# Leaf-Annotate-v2 / Batch_Inference.py
# Last commit by Subh775: "Update Batch_Inference.py" (461ad4e, verified)
import os
from pathlib import Path
import torch
from PIL import Image
import numpy as np
import cv2
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# --- Hugging Face Hub location of the pretrained checkpoint -----------------
HF_USERNAME = "Subh75"
HF_ORGNAME = "LeafNet75"
MODEL_NAME = "Leaf-Annotate-v2"
HF_MODEL_REPO_ID = "/".join((HF_ORGNAME, MODEL_NAME))

# --- Filesystem locations (edit these for your data) ------------------------
# Folder holding the source images to segment.
INPUT_IMAGE_DIR = "newimgs/images"
# Folder that receives one PNG mask per source image.
OUTPUT_MASK_DIR = "newimgs/masks"

# --- Inference settings -----------------------------------------------------
# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Side length (pixels) of the square input fed to the network.
IMG_SIZE = 256
# Sigmoid probability above which a pixel is marked as foreground.
CONFIDENCE_THRESHOLD = 0.5
def load_model_from_hub(repo_id: str):
    """Build the scribble-guided U-Net and load its weights from the Hub.

    Downloads ``best_model.pth`` from *repo_id* with ``hf_hub_download`` and
    loads it into an ``smp.Unet`` (MobileNet-V2 encoder, 4 input channels for
    RGB + scribble, 1 output channel). The model is moved to ``DEVICE`` and
    switched to evaluation mode before being returned.

    Args:
        repo_id: Hugging Face Hub repository id in ``"owner/name"`` form.

    Returns:
        The ready-to-use ``smp.Unet`` instance.
    """
    print(f"Loading model '{repo_id}' from Hugging Face Hub...")
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,  # all weights come from the checkpoint below
        in_channels=4,  # RGB + Scribble
        classes=1,
    )
    model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    # The checkpoint is fetched over the network, so do not let torch.load
    # unpickle arbitrary objects: weights_only=True limits it to tensors and
    # plain containers (enough for a state_dict, and the default since
    # torch 2.6 -- this makes older torch versions behave the same way).
    state_dict = torch.load(model_weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model
def predict_scribble(model, pil_image, scribble_mask):
    """Segment *pil_image* guided by *scribble_mask* and return a binary mask.

    The image (bilinear) and scribble (nearest-neighbour) are resized to
    ``IMG_SIZE`` x ``IMG_SIZE``, scaled by 1/255 and stacked into a
    ``(1, 4, IMG_SIZE, IMG_SIZE)`` float tensor (RGB + scribble) on
    ``DEVICE``. Pixels whose sigmoid probability exceeds
    ``CONFIDENCE_THRESHOLD`` are foreground; the mask is then resized back
    to the original image size with nearest-neighbour interpolation.

    Args:
        model: Network mapping the 4-channel input to a single logit channel.
        pil_image: Input PIL image. Non-RGB modes (e.g. ``"L"``, ``"P"``,
            ``"RGBA"``) are converted to RGB first.
        scribble_mask: 2-D array the same layout as a single-channel image;
            values are divided by 255, so presumably uint8 with 255 marking
            the scribble -- confirm with callers.

    Returns:
        ``uint8`` array of shape ``(pil_image.height, pil_image.width)`` with
        values 0 (background) or 255 (foreground).
    """
    # The network expects exactly three colour channels. Grayscale/palette
    # images would give a 2-D array (permute fails) and RGBA would give five
    # input channels, so normalise the mode up front.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    img_resized = np.array(
        pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
    )
    # Nearest-neighbour so scribble values are not blended with background.
    scribble_resized = cv2.resize(
        scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )

    # HWC uint8 -> CHW float in [0, 1]; scribble gains a leading channel axis.
    img_tensor = (
        torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    )
    scribble_tensor = (
        torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    )
    input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        probs = torch.sigmoid(model(input_tensor))
    binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()

    # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary.
    final_mask = cv2.resize(
        binary_mask_resized,
        (pil_image.width, pil_image.height),
        interpolation=cv2.INTER_NEAREST,
    )
    return (final_mask * 255).astype(np.uint8)
def _center_line_scribble(width, height):
    """Return a ``(height, width)`` uint8 dummy scribble: a horizontal centre line.

    The line is drawn with value 255 and thickness 25 on a zero background and
    spans 20% of the shorter image side, centred on the image.
    """
    scribble = np.zeros((height, width), dtype=np.uint8)
    center_x, center_y = width // 2, height // 2
    length = int(min(width, height) * 0.2)
    start_point = (center_x - length // 2, center_y)
    end_point = (center_x + length // 2, center_y)
    cv2.line(scribble, start_point, end_point, 255, thickness=25)
    return scribble


def main():
    """Generate a mask for every image in INPUT_IMAGE_DIR into OUTPUT_MASK_DIR.

    Returns early (after printing an error) if INPUT_IMAGE_DIR does not exist.
    Otherwise loads the model once, then for each ``.png``/``.jpg``/``.jpeg``
    file (case-insensitive) predicts a mask from a dummy centre-line scribble
    and saves it as ``<stem>.png``. A failing image is reported and skipped so
    the rest of the batch still runs.
    """
    if not os.path.isdir(INPUT_IMAGE_DIR):
        print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
        return
    os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
    model = load_model_from_hub(HF_MODEL_REPO_ID)

    # os.listdir order is arbitrary; sort for reproducible processing order.
    image_files = sorted(
        f
        for f in os.listdir(INPUT_IMAGE_DIR)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    print(f"\nFound {len(image_files)} images to process.")

    written = {}  # output path -> source filename, to detect stem collisions
    failures = 0
    for filename in tqdm(image_files, desc="Generating Masks"):
        image_path = os.path.join(INPUT_IMAGE_DIR, filename)
        # Keep same base name, save as .png in OUTPUT_MASK_DIR
        output_path = os.path.join(OUTPUT_MASK_DIR, f"{Path(filename).stem}.png")
        if output_path in written:
            # e.g. "leaf.jpg" and "leaf.png" both map to "leaf.png".
            print(
                f"\n Warning: {filename} overwrites the mask of "
                f"{written[output_path]} at '{output_path}'"
            )
        try:
            # The context manager closes the file handle promptly; convert()
            # loads the pixels into an independent image first.
            with Image.open(image_path) as img:
                original_image = img.convert("RGB")
            scribble = _center_line_scribble(original_image.width, original_image.height)
            predicted_mask = predict_scribble(model, original_image, scribble)
            Image.fromarray(predicted_mask).save(output_path)
        except Exception as e:
            # Batch boundary: report and move on to the next image.
            failures += 1
            print(f"\n Could not process {filename}. Error: {e}")
        else:
            written[output_path] = filename

    if failures:
        print(f"\n {failures} of {len(image_files)} images could not be processed.")
    print(f"\n Done! Masks saved in '{OUTPUT_MASK_DIR}' with same names as input images.")
# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()