File size: 3,977 Bytes
3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 043f7ed 3b68505 461ad4e 043f7ed 3b68505 043f7ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
from pathlib import Path
import torch
from PIL import Image
import numpy as np
import cv2
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# Hugging Face Hub coordinates of the trained checkpoint.
HF_USERNAME = "Subh75"
HF_ORGNAME = "LeafNet75"
MODEL_NAME = "Leaf-Annotate-v2"
HF_MODEL_REPO_ID = "/".join((HF_ORGNAME, MODEL_NAME))

# Set to your original image and output folder respectively
INPUT_IMAGE_DIR = "newimgs/images"
OUTPUT_MASK_DIR = "newimgs/masks"

# Inference settings.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
IMG_SIZE = 256  # image and scribble are resized to IMG_SIZE x IMG_SIZE before inference
CONFIDENCE_THRESHOLD = 0.5  # sigmoid probability above which a pixel counts as foreground
def load_model_from_hub(repo_id: str):
    """Load the interactive (scribble-conditioned) segmentation model from the Hub.

    Builds a ``segmentation_models_pytorch`` U-Net (MobileNetV2 encoder,
    4 input channels for RGB + scribble, 1 output class). It then downloads
    ``best_model.pth`` from ``repo_id`` and loads it as the model's state dict.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"org/model-name"``.

    Returns:
        The loaded model, moved to ``DEVICE`` and switched to eval mode.
    """
    print(f"Loading model '{repo_id}' from Hugging Face Hub...")
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,  # no pretrained encoder; all weights come from the checkpoint below
        in_channels=4,  # RGB + Scribble
        classes=1,
    )
    model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    # The checkpoint comes from a remote repository. Plain torch.load unpickles
    # arbitrary objects, which can execute code. weights_only=True restricts
    # loading to tensors and primitive containers, which is all a state dict needs.
    state_dict = torch.load(model_weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model
def predict_scribble(model, pil_image, scribble_mask):
    """Run scribble-conditioned inference and return a binary mask.

    Processing steps:

    1. Resize the image (bilinear) and the scribble (nearest) to
       ``IMG_SIZE`` x ``IMG_SIZE``.
    2. Divide both by 255 and stack them into a 4-channel tensor
       (RGB + scribble).
    3. Run ``model`` under ``torch.no_grad``.
    4. Threshold the sigmoid output at ``CONFIDENCE_THRESHOLD``.
    5. Resize the result back to the original image size (nearest).

    Args:
        model: Network accepting a ``(1, 4, IMG_SIZE, IMG_SIZE)`` float tensor
            on ``DEVICE``. Presumably it returns single-channel logits of
            shape ``(1, 1, IMG_SIZE, IMG_SIZE)``; confirm against the model.
        pil_image: Input ``PIL.Image`` of any mode. It is converted to RGB
            here so grayscale, palette or RGBA images still produce exactly
            three colour channels.
        scribble_mask: 2-D array of any size (it is resized). It is divided
            by 255, so values are presumably 0 (no scribble) / 255 (scribble).

    Returns:
        ``uint8`` array of shape ``(pil_image.height, pil_image.width)`` with
        values 0 (background) or 255 (foreground).
    """
    # Without this, an "L" image crashes at permute(2, 0, 1), and an RGBA
    # image would feed 5 channels into a 4-channel model.
    rgb_image = pil_image.convert("RGB")
    img_resized = np.array(
        rgb_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
    )
    # Nearest-neighbour only copies existing pixel values, so the scribble is not blurred.
    scribble_resized = cv2.resize(
        scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )
    # HWC uint8 -> CHW float in [0, 1].
    img_tensor = (
        torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    )
    # HW -> 1HW so it can be concatenated as the fourth channel.
    scribble_tensor = (
        torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    )
    input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.sigmoid(output)
    binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
    # cv2 takes dsize as (width, height).
    final_mask = cv2.resize(
        binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST
    )
    return (final_mask * 255).astype(np.uint8)
def _make_center_scribble(width, height):
    """Return a synthetic scribble: a thick horizontal line through the image centre.

    The canvas is a zero ``uint8`` array of shape ``(height, width)``. The line
    is drawn with value 255, is 25 px thick and spans 20% of the shorter side.
    """
    scribble = np.zeros((height, width), dtype=np.uint8)
    center_x, center_y = width // 2, height // 2
    length = int(min(width, height) * 0.2)
    start_point = (center_x - length // 2, center_y)
    end_point = (center_x + length // 2, center_y)
    cv2.line(scribble, start_point, end_point, 255, thickness=25)
    return scribble


def main():
    """Run batch inference on every image in INPUT_IMAGE_DIR.

    For each file ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive),
    in sorted filename order:

    1. Draw a synthetic centre-line scribble.
    2. Predict a mask with the Hub model.
    3. Save the mask to OUTPUT_MASK_DIR as ``<base name>.png``.

    Files that fail are reported and skipped; the batch continues. If two
    inputs share a base name (e.g. ``a.jpg`` and ``a.png``), a warning is
    printed because the later mask overwrites the earlier one. Returns early
    if INPUT_IMAGE_DIR does not exist.
    """
    if not os.path.isdir(INPUT_IMAGE_DIR):
        print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
        return
    os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
    model = load_model_from_hub(HF_MODEL_REPO_ID)
    # os.listdir order is arbitrary; sorting makes runs (and which clashing
    # file is written last) reproducible.
    image_files = sorted(
        f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    print(f"\nFound {len(image_files)} images to process.")
    saved_by_stem = {}  # output base name -> input filename that produced it
    for filename in tqdm(image_files, desc="Generating Masks"):
        image_path = os.path.join(INPUT_IMAGE_DIR, filename)
        try:
            # Context manager guarantees the file handle is released.
            with Image.open(image_path) as img:
                original_image = img.convert("RGB")
            scribble = _make_center_scribble(original_image.width, original_image.height)
            predicted_mask = predict_scribble(model, original_image, scribble)
            mask_image = Image.fromarray(predicted_mask)
            # Keep same base name, save as .png in OUTPUT_MASK_DIR
            base_name = Path(filename).stem
            output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}.png")
            if base_name in saved_by_stem:
                tqdm.write(
                    f"\n Warning: {filename} and {saved_by_stem[base_name]} share the base "
                    f"name '{base_name}'; overwriting {output_path}."
                )
            mask_image.save(output_path)
            saved_by_stem[base_name] = filename
        except Exception as e:
            # Per-file boundary: report and continue with the rest of the batch.
            # tqdm.write avoids corrupting the progress bar.
            tqdm.write(f"\n Could not process {filename}. Error: {e}")
    print(
        f"\n Done! Masks saved in '{OUTPUT_MASK_DIR}' as .png files with the same "
        "base names as the input images."
    )


if __name__ == "__main__":
    main()
|