"""Batch inference: generate leaf segmentation masks for a folder of images.

Downloads the Leaf-Annotate-v2 scribble-guided U-Net from the Hugging Face Hub,
draws an automatic scribble at the center of each image, and writes the
predicted binary masks to OUTPUT_MASK_DIR as PNG files.
"""

import os
from pathlib import Path

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm

HF_USERNAME = "Subh75"
HF_ORGNAME = "LeafNet75"
MODEL_NAME = "Leaf-Annotate-v2"
HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"

INPUT_IMAGE_DIR = "newimgs/images"
OUTPUT_MASK_DIR = "newimgs/masks"

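# Inference settings: images are resized to IMG_SIZE x IMG_SIZE for the network,
# and sigmoid outputs above CONFIDENCE_THRESHOLD are treated as foreground.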
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 256
CONFIDENCE_THRESHOLD = 0.5


def load_model_from_hub(repo_id: str):
    """Loads the interactive segmentation model from the Hub."""
    print(f"Loading model '{repo_id}' from Hugging Face Hub...")

    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,
        classes=1,
    )

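    # Fetch the checkpoint from the Hub (cached locally by huggingface_hub)
    # and load the weights onto the selected device.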
    model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model


def predict_scribble(model, pil_image, scribble_mask):
    """Runs inference using a scribble and returns a binary mask."""
    img_resized = np.array(
        pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
    )
    scribble_resized = cv2.resize(
        scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )

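    # Scale both inputs to [0, 1] and stack them into a single 4-channel tensor
    # (RGB + scribble) with a batch dimension.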
    img_tensor = (
        torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    )
    scribble_tensor = (
        torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    )

    input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)

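    # Convert logits to probabilities and threshold them into a binary mask.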
    probs = torch.sigmoid(output)
    binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()

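    # Resize the mask back to the original image size and return it as uint8
    # (0 = background, 255 = foreground).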
    final_mask = cv2.resize(
        binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST
    )
    return (final_mask * 255).astype(np.uint8)


def main():
    """Main function to run batch inference on a folder of images."""
    if not os.path.isdir(INPUT_IMAGE_DIR):
        print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
        return

    os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)

    model = load_model_from_hub(HF_MODEL_REPO_ID)

    image_files = [
        f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]

    print(f"\nFound {len(image_files)} images to process.")

    for filename in tqdm(image_files, desc="Generating Masks"):
        image_path = os.path.join(INPUT_IMAGE_DIR, filename)

        try:
            original_image = Image.open(image_path).convert("RGB")
            h, w = original_image.height, original_image.width

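            # Batch mode has no human input, so synthesize a prompt: a short
            # horizontal scribble through the center of the image.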
            scribble = np.zeros((h, w), dtype=np.uint8)
            center_x, center_y = w // 2, h // 2
            length = int(min(w, h) * 0.2)

            start_point = (center_x - length // 2, center_y)
            end_point = (center_x + length // 2, center_y)
            cv2.line(scribble, start_point, end_point, 255, thickness=25)

            predicted_mask = predict_scribble(model, original_image, scribble)

            mask_image = Image.fromarray(predicted_mask)

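            # Save the mask as a PNG under the same base name as the input image.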
            base_name = Path(filename).stem
            output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}.png")

            mask_image.save(output_path)

        except Exception as e:
            print(f"\nCould not process {filename}. Error: {e}")

    print(f"\nDone! Masks saved in '{OUTPUT_MASK_DIR}' with the same names as the input images.")


if __name__ == "__main__":
    main()