File size: 3,977 Bytes
3b68505
 
 
 
 
 
 
 
 
 
 
 
043f7ed
3b68505
 
 
043f7ed
 
 
3b68505
 
 
 
 
 
 
 
 
043f7ed
3b68505
 
 
043f7ed
3b68505
 
 
 
 
 
 
 
 
 
043f7ed
3b68505
 
043f7ed
 
 
 
 
 
3b68505
043f7ed
 
 
 
 
 
3b68505
 
043f7ed
3b68505
 
043f7ed
3b68505
 
043f7ed
 
 
 
3b68505
 
043f7ed
3b68505
043f7ed
3b68505
 
 
043f7ed
3b68505
043f7ed
3b68505
043f7ed
 
 
 
3b68505
 
043f7ed
3b68505
 
043f7ed
3b68505
 
 
 
043f7ed
3b68505
 
043f7ed
 
3b68505
 
 
043f7ed
 
3b68505
043f7ed
3b68505
043f7ed
 
3b68505
043f7ed
 
3b68505
 
 
461ad4e
043f7ed
 
3b68505
 
 
043f7ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from pathlib import Path
import torch
from PIL import Image
import numpy as np
import cv2
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
from tqdm import tqdm


# --- Hugging Face Hub location of the pretrained checkpoint -----------------
# NOTE(review): HF_USERNAME is not referenced in the visible code; kept in case
# other tooling imports it.
HF_USERNAME = "Subh75"
HF_ORGNAME = "LeafNet75"
MODEL_NAME = "Leaf-Annotate-v2"
HF_MODEL_REPO_ID = "/".join((HF_ORGNAME, MODEL_NAME))  # "<org>/<model>"

# --- I/O folders --------------------------------------------------------------
# Point these at the folder of source images and the folder that should
# receive the generated masks.
INPUT_IMAGE_DIR = "newimgs/images"
OUTPUT_MASK_DIR = "newimgs/masks"

# --- Inference settings -------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
IMG_SIZE = 256  # square side length images and scribbles are resized to
CONFIDENCE_THRESHOLD = 0.5  # sigmoid-probability cut-off for foreground


def load_model_from_hub(repo_id: str):
    """Build the scribble-conditioned U-Net and load its weights from the Hub.

    Args:
        repo_id: Hugging Face Hub repository id (``"<org>/<name>"``) that hosts
            a ``best_model.pth`` state dict for the architecture built below.

    Returns:
        The ``smp.Unet`` model with the downloaded weights, moved to ``DEVICE``
        and switched to eval mode.
    """
    print(f"Loading model '{repo_id}' from Hugging Face Hub...")

    # The architecture must match the checkpoint exactly. Encoder weights are
    # not fetched separately because the full state dict comes from the Hub.
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,  # RGB + Scribble
        classes=1,
    )

    model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    # The checkpoint is downloaded from the network. weights_only=True restricts
    # unpickling to tensors/containers, so a tampered file cannot run arbitrary
    # code. A plain state dict (what load_state_dict expects) loads fine this way.
    state_dict = torch.load(model_weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model


def predict_scribble(model, pil_image, scribble_mask, threshold=None):
    """Run scribble-conditioned inference and return a full-resolution binary mask.

    Args:
        model: Network that takes a ``(1, 4, IMG_SIZE, IMG_SIZE)`` tensor
            (RGB + scribble channel) and returns one logit channel, e.g. the
            model from ``load_model_from_hub``.
        pil_image: Source image. Any PIL mode is accepted; it is converted to
            RGB so the network always receives exactly three colour channels.
        scribble_mask: 2-D array with the same height/width as ``pil_image``.
            It is divided by 255, so presumably uint8 with 255 on scribbled
            pixels. Confirm for callers other than this script.
        threshold: Sigmoid-probability cut-off for foreground. ``None`` (the
            default) uses the module-level ``CONFIDENCE_THRESHOLD``.

    Returns:
        ``uint8`` array of shape ``(pil_image.height, pil_image.width)`` with
        values 0 (background) or 255 (foreground).
    """
    if threshold is None:
        threshold = CONFIDENCE_THRESHOLD

    # RGBA / L / P images would otherwise yield 4, 1 or 2-D channel arrays and
    # break the fixed 4-channel model input. convert() is a plain copy for RGB.
    rgb_image = pil_image.convert("RGB")
    img_resized = np.array(
        rgb_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
    )
    # Nearest-neighbour sampling does not blend values, so a binary scribble
    # stays binary after resizing.
    scribble_resized = cv2.resize(
        scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )

    # HWC uint8 -> CHW float in [0, 1]
    img_tensor = (
        torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    )
    # HW -> 1HW float in [0, 1]
    scribble_tensor = (
        torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    )

    # Stack into a single 4-channel image and add the batch dimension.
    input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)

    probs = torch.sigmoid(output)
    binary_mask_resized = (probs > threshold).float().squeeze().cpu().numpy()

    # Back to the caller's original resolution; cv2 dsize is (width, height).
    final_mask = cv2.resize(
        binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST
    )
    return (final_mask * 255).astype(np.uint8)


def _make_center_scribble(height, width, thickness=25):
    """Return a (height, width) uint8 mask with a horizontal 255-valued stroke at the centre.

    The stroke's length is 20% of the shorter image side. It serves as a dummy
    scribble so images can be segmented without user input.
    """
    scribble = np.zeros((height, width), dtype=np.uint8)
    center_x, center_y = width // 2, height // 2
    half_length = int(min(width, height) * 0.2) // 2
    cv2.line(
        scribble,
        (center_x - half_length, center_y),
        (center_x + half_length, center_y),
        255,
        thickness=thickness,
    )
    return scribble


def _list_input_images(directory):
    """Return regular files in *directory* with a .png/.jpg/.jpeg extension, sorted by name."""
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join(directory, name))
    )


def main():
    """Segment every image in INPUT_IMAGE_DIR and write masks to OUTPUT_MASK_DIR.

    Each image is segmented with an automatically placed centre scribble, and
    the mask is saved as ``<input stem>.png``. Inputs that share a stem (e.g.
    ``a.jpg`` and ``a.png``) map to the same output file, so the later one in
    sorted order overwrites the earlier. Images that fail to load or segment
    are reported and skipped, so one bad file does not abort the batch.
    """
    if not os.path.isdir(INPUT_IMAGE_DIR):
        print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
        return

    os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)

    model = load_model_from_hub(HF_MODEL_REPO_ID)

    # Sorted so runs are reproducible regardless of filesystem listing order.
    image_files = _list_input_images(INPUT_IMAGE_DIR)

    print(f"\nFound {len(image_files)} images to process.")

    failures = 0
    for filename in tqdm(image_files, desc="Generating Masks"):
        image_path = os.path.join(INPUT_IMAGE_DIR, filename)

        try:
            # The context manager releases the file handle deterministically.
            # convert() returns an independent, fully loaded copy.
            with Image.open(image_path) as img:
                original_image = img.convert("RGB")

            scribble = _make_center_scribble(original_image.height, original_image.width)
            predicted_mask = predict_scribble(model, original_image, scribble)

            # Keep same base name, save as .png in OUTPUT_MASK_DIR
            output_path = os.path.join(OUTPUT_MASK_DIR, f"{Path(filename).stem}.png")
            Image.fromarray(predicted_mask).save(output_path)

        except Exception as e:
            # Deliberate best-effort: record the failure and move on.
            # tqdm.write prints above the progress bar instead of garbling it.
            failures += 1
            tqdm.write(f"\n Could not process {filename}. Error: {e}")

    print(f"\n Done! Masks saved in '{OUTPUT_MASK_DIR}' with same names as input images.")
    if failures:
        print(f" {failures} of {len(image_files)} images could not be processed.")


if __name__ == "__main__":
    main()