Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image, UnidentifiedImageError | |
| from transformers import AutoProcessor, Blip2ForConditionalGeneration | |
| import torch | |
| import io | |
# Application instance that middleware (and any routes) attach to.
app = FastAPI()

# CORS policy: any origin, method and header is accepted, with credentials
# enabled.
# NOTE(review): FastAPI's CORS docs advise against combining wildcard origins
# with allow_credentials=True; confirm whether credentialed cross-origin
# requests are really needed, otherwise list explicit origins or disable
# credentials.
_CORS_POLICY = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_POLICY)
# Load the model and processor once, at import time, so that request handlers
# can reuse the module-level ``model``, ``processor`` and ``device`` objects.
try:
    # Base BLIP-2 checkpoint; no torch_dtype is passed, so the library default
    # precision is used for the weights.
    model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded")
    # 'blip-cpu-model' is resolved by the library (local path or hub id).
    model.load_adapter('blip-cpu-model')
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Fail fast at startup, keeping the original exception as the explicit
    # cause so the real loading error stays visible in the traceback.
    raise RuntimeError(f"Failed to load the model or processor: {e}") from e
async def generate_caption(file: UploadFile = File(...)):
    """Generate a text caption for an uploaded image.

    Args:
        file: The uploaded file; its bytes are read and opened with Pillow.

    Returns:
        A dict of the form ``{"caption": <decoded caption text>}``.

    Raises:
        HTTPException: 400 when Pillow cannot identify the upload as an
            image; 500 for any other failure while opening the image or
            while running the processor/model.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this function in this view; confirm it is registered on ``app``.
    try:
        image = Image.open(io.BytesIO(await file.read()))
    except UnidentifiedImageError as e:
        # Raise a 400 error if the file is not a valid image
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from e
    except Exception as e:
        # Catch any other unexpected errors related to image processing
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while processing the image: {e}",
        ) from e

    try:
        # Cast floating-point inputs to the dtype the model actually holds
        # rather than hard-coding float16: the model is loaded without an
        # explicit torch_dtype and may run on CPU, so forcing half precision
        # can mismatch the weights and make generate() fail.
        inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
        with torch.no_grad():
            caption_ids = model.generate(**inputs, max_length=128)
        caption = processor.decode(caption_ids[0], skip_special_tokens=True)
        return {"caption": caption}
    except Exception as e:
        # Catch any errors during the caption generation process
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while generating the caption: {e}",
        ) from e