"""Batch-generate leaf segmentation masks with a scribble-guided U-Net.

Every image in ``INPUT_IMAGE_DIR`` is paired with a synthetic horizontal
scribble through its centre, run through the model downloaded from the
Hugging Face Hub, and the resulting binary mask is written as a PNG with
the same base name into ``OUTPUT_MASK_DIR``.
"""

import os
from pathlib import Path

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm

HF_USERNAME = "Subh75"
HF_ORGNAME = "LeafNet75"
MODEL_NAME = "Leaf-Annotate-v2"
HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"

# Set to your original image and output folder respectively
INPUT_IMAGE_DIR = "newimgs/images"
OUTPUT_MASK_DIR = "newimgs/masks"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Side length (pixels) that image and scribble are resized to for inference.
IMG_SIZE = 256
# Sigmoid probability above which a pixel is marked as foreground.
CONFIDENCE_THRESHOLD = 0.5


def load_model_from_hub(repo_id: str) -> torch.nn.Module:
    """Load the interactive segmentation model from the Hugging Face Hub.

    Builds a U-Net with a MobileNetV2 encoder taking 4 input channels
    (RGB + scribble) and one output class, downloads ``best_model.pth``
    from ``repo_id`` and loads it as the model's state dict.

    Args:
        repo_id: Hub repository id in ``"<owner>/<name>"`` form.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    print(f"Loading model '{repo_id}' from Hugging Face Hub...")
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,  # weights come from the checkpoint below
        in_channels=4,  # RGB + Scribble
        classes=1,
    )
    model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    # NOTE(review): torch.load unpickles the downloaded file. Only point this
    # at a trusted repo; consider weights_only=True if the installed torch
    # version supports it.
    state_dict = torch.load(model_weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model


def predict_scribble(model, pil_image, scribble_mask):
    """Run inference using a scribble and return a binary mask.

    Args:
        model: Segmentation model accepting a ``(1, 4, IMG_SIZE, IMG_SIZE)``
            float tensor (RGB channels followed by the scribble channel).
        pil_image: Input PIL image; converted to RGB before use.
        scribble_mask: Single-channel array marking the scribble. It is
            divided by 255, so values are expected in the 0-255 range.

    Returns:
        ``uint8`` array of shape ``(pil_image.height, pil_image.width)``
        holding 255 where the predicted probability exceeds
        ``CONFIDENCE_THRESHOLD`` and 0 elsewhere.
    """
    # permute(2, 0, 1) below needs an H x W x 3 array, so grayscale / RGBA /
    # palette inputs are normalised to RGB first (a no-op copy for RGB input).
    rgb_image = pil_image.convert("RGB")
    img_resized = np.array(
        rgb_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
    )
    # Nearest-neighbour keeps the scribble strictly two-valued after resizing.
    scribble_resized = cv2.resize(
        scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )

    # HWC uint8 -> CHW float in [0, 1]
    img_tensor = (
        torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    )
    scribble_tensor = (
        torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    )
    input_tensor = (
        torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
    )

    with torch.no_grad():
        output = model(input_tensor)

    probs = torch.sigmoid(output)
    binary_mask_resized = (
        (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
    )
    # Scale the mask back up to the original image size; cv2 takes (w, h).
    final_mask = cv2.resize(
        binary_mask_resized,
        (pil_image.width, pil_image.height),
        interpolation=cv2.INTER_NEAREST,
    )
    return (final_mask * 255).astype(np.uint8)


def _center_scribble(width: int, height: int) -> np.ndarray:
    """Return a uint8 (height, width) mask with a horizontal centre stroke."""
    scribble = np.zeros((height, width), dtype=np.uint8)
    center_x, center_y = width // 2, height // 2
    # Stroke length is 20% of the shorter image side.
    length = int(min(width, height) * 0.2)
    start_point = (center_x - length // 2, center_y)
    end_point = (center_x + length // 2, center_y)
    cv2.line(scribble, start_point, end_point, 255, thickness=25)
    return scribble


def main():
    """Run batch inference on every image found in ``INPUT_IMAGE_DIR``.

    Masks are written to ``OUTPUT_MASK_DIR`` as ``<stem>.png``. A failure on
    one image is reported and does not stop the rest of the batch.
    """
    if not os.path.isdir(INPUT_IMAGE_DIR):
        print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
        return

    os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
    model = load_model_from_hub(HF_MODEL_REPO_ID)

    # Sorted so the processing order is deterministic across runs/platforms.
    image_files = sorted(
        f
        for f in os.listdir(INPUT_IMAGE_DIR)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    print(f"\nFound {len(image_files)} images to process.")

    failed = 0
    for filename in tqdm(image_files, desc="Generating Masks"):
        image_path = os.path.join(INPUT_IMAGE_DIR, filename)
        try:
            # Close the underlying file as soon as the pixels are decoded so
            # large batches do not accumulate open file handles.
            with Image.open(image_path) as raw_image:
                original_image = raw_image.convert("RGB")

            # Create a dummy scribble (center line)
            scribble = _center_scribble(original_image.width, original_image.height)

            # Predict mask
            predicted_mask = predict_scribble(model, original_image, scribble)
            mask_image = Image.fromarray(predicted_mask)

            # Keep same base name, save as .png in OUTPUT_MASK_DIR
            # NOTE: inputs that share a stem (e.g. a.jpg and a.png) map to the
            # same output file; the one processed last wins.
            base_name = Path(filename).stem
            output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}.png")
            mask_image.save(output_path)
        except Exception as e:
            # Best-effort batch: report the failure and continue.
            failed += 1
            print(f"\n Could not process {filename}. Error: {e}")

    print(f"\n Done! Masks saved in '{OUTPUT_MASK_DIR}' with same names as input images.")
    if failed:
        print(f"{failed} of {len(image_files)} images could not be processed.")


if __name__ == "__main__":
    main()