# Spaces: Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
# -------- MODEL DEFINITION --------
class ImprovedCNN(nn.Module):
    """Three-stage convolutional network that regresses one scalar (the age).

    Each feature stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2). The padded convolutions keep the spatial size, and each pool
    halves it (flooring), so a square input of side ``input_size`` leaves a
    128-channel map of side ``input_size // 8``. The classifier flattens that
    map and passes it through a 512-unit hidden layer with dropout to a
    single output.

    Args:
        input_size: Side length in pixels of the square images fed to the
            network. It must match the preprocessing resize. The default of 128
            reproduces the original ``128 * 16 * 16`` flatten size, so existing
            checkpoints load unchanged.

    Raises:
        ValueError: If ``input_size`` is smaller than 8. Three 2x poolings
            would then leave an empty feature map.
    """

    def __init__(self, input_size: int = 128):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be at least 8 (got {input_size}); three "
                "MaxPool2d(2) stages would leave an empty feature map"
            )
        # Spatial side of the final feature map after three MaxPool2d(2) stages.
        pooled_side = input_size // 8

        # Module order is kept identical to the original so state_dict keys
        # (features.0.weight, classifier.1.weight, ...) still match checkpoints.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * pooled_side * pooled_side, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image tensor of shape (batch, 3, input_size, input_size). The
                channel count is fixed by the first Conv2d, and the spatial
                size by the classifier's Linear input width.

        Returns:
            Tensor of shape (batch, 1) holding one regression output per image.
        """
        x = self.features(x)
        return self.classifier(x)
# -------- LOAD MODEL --------
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = ImprovedCNN().to(device)
model_path = "age_prediction_model3.pth"
# The loaded object is only used as a state_dict (a mapping of tensors), so
# weights_only=True is sufficient. It makes torch.load refuse arbitrary
# pickled objects instead of executing them. map_location remaps tensors saved
# on another device (e.g. a GPU) onto the device chosen above.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch BatchNorm to its running statistics and disable Dropout for inference.
model.eval()
print(f"✅ Model loaded from {model_path}")
# -------- IMAGE PREPROCESSING --------
def _build_transform() -> transforms.Compose:
    """Return the preprocessing pipeline applied to every input image.

    The pipeline resizes to 128x128, the size ImprovedCNN's default flatten
    width (128 * 16 * 16) is built for. It then converts the image to a float
    tensor and normalizes each channel.
    """
    # Standard ImageNet channel statistics. These are presumably the ones used
    # at training time (TODO confirm against the training script).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)


transform = _build_transform()
# -------- PREDICTION FUNCTION --------
def predict_age(image: Image.Image) -> float:
    """Predict the age of the face in *image* using the loaded CNN.

    Args:
        image: A PIL image. Gradio passes ``None`` when the image input is
            empty, so that case is handled explicitly.

    Returns:
        The model's scalar output, rounded to two decimal places.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of an opaque failure from the transform.
    """
    if image is None:
        raise gr.Error("Please upload a face image.")
    # The first Conv2d expects exactly 3 channels. Normalize grayscale,
    # palette or RGBA images, e.g. when this is called outside the Gradio UI.
    image = image.convert("RGB")
    # Add a leading batch dimension of 1 and move to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    return round(output.item(), 2)
# -------- GRADIO UI --------
def _build_demo() -> gr.Interface:
    """Assemble the single-page Gradio interface around ``predict_age``.

    The interface takes one uploaded image as a PIL image in RGB mode and
    shows the predicted age as a number.
    """
    face_input = gr.Image(type="pil", image_mode="RGB", label="Upload Face Image")
    age_output = gr.Number(label="Predicted Age")
    return gr.Interface(
        fn=predict_age,
        inputs=face_input,
        outputs=age_output,
        title="Face Age Prediction",
        description="Upload a face image to predict age using a CNN model.",
    )


demo = _build_demo()
# -------- LAUNCH --------
def _main() -> None:
    """Start the Gradio server for the age-prediction demo."""
    demo.launch()


if __name__ == "__main__":
    _main()