Spaces:
Build error
Build error
| import argparse | |
| import base64 | |
| import os | |
| from pathlib import Path | |
| from io import BytesIO | |
| import time | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS, cross_origin | |
| from consts import IMAGES_OUTPUT_DIR | |
| from utils import parse_arg_boolean, parse_arg_dalle_version | |
| from consts import ModelSize | |
| import gradio as gr | |
def greet(name):
    """Return a greeting for *name*, e.g. ``greet("Ada") == "Hello Ada!!"``.

    *name* must be a ``str``; any other type raises ``TypeError``.
    """
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
# Placeholder Gradio UI wrapping `greet`.
# By default `launch()` blocks the main thread while the Gradio server runs,
# which would keep the Flask / DALL-E setup that follows from ever executing.
# With prevent_thread_lock=True, launch() returns immediately and the script
# continues. Gradio stops when the script finishes, so the process must stay
# alive; the blocking `app.run(...)` at the end of this file does that.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch(prevent_thread_lock=True)
# Flask application; flask-cors' CORS(app) enables cross-origin requests on
# all of its routes.
app = Flask(__name__)
CORS(app)

# Print the startup notice first, presumably so it is visible while the
# model module below is imported (which may be slow).
print("--> Starting DALL-E Server. This might take up to two minutes.")
from dalle_model import DalleModel

# Replaced with a DalleModel instance by the warm-up step later in this file.
dalle_model = None

# Command-line options for the server.
parser = argparse.ArgumentParser(
    description="A DALL-E app to turn your textual prompts into visionary delights",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="backend port",
)
parser.add_argument(
    "--model_version",
    type=parse_arg_dalle_version,
    default=ModelSize.MINI,
    help="Mini, Mega, or Mega_full",
)
parser.add_argument(
    "--save_to_disk",
    type=parse_arg_boolean,
    default=False,
    help="Should save generated images to disk",
)
args = parser.parse_args()
def _safe_dir_component(text, max_len=100):
    """Turn an arbitrary prompt into one portable directory-name component.

    Keeps ASCII letters, digits, spaces, ``-`` and ``_``. Every other
    character becomes ``_``, including path separators, ``.`` and ``:``.
    This prevents the prompt from escaping IMAGES_OUTPUT_DIR or producing a
    name that is invalid on Windows. The result is stripped and truncated to
    *max_len* characters (ASCII only, so at most *max_len* bytes), and it is
    never empty.
    """
    cleaned = "".join(
        c if (c.isalnum() and ord(c) < 128) or c in " -_" else "_" for c in text
    )
    return cleaned.strip()[:max_len] or "prompt"


# NOTE(review): this view function was not registered with the app anywhere in
# this file, so it was unreachable. `cross_origin` is imported but unused, which
# suggests route decorators were lost. The "/dalle" POST path is an assumption;
# confirm it matches what the frontend calls.
@app.route("/dalle", methods=["POST"])
def generate_images_api():
    """Generate images for a text prompt and return them base64-encoded.

    Expects a JSON body ``{"text": <str>, "num_images": <int>}``. The body is
    parsed as JSON whatever its Content-Type (``force=True``).

    On success, responds with a JSON list of base64-encoded JPEG images, one
    per generated image. A body without ``text`` / ``num_images``, a
    non-string ``text``, or a ``num_images`` that is not an integer >= 1 gets
    HTTP 400 with an ``error`` message.

    When ``--save_to_disk`` is set, each image is also written as
    ``<idx>.jpeg`` into a timestamped sub-directory of IMAGES_OUTPUT_DIR.
    """
    json_data = request.get_json(force=True)
    try:
        text_prompt = json_data["text"]
        num_images = int(json_data["num_images"])
    except (KeyError, TypeError, ValueError):
        # KeyError: missing field. TypeError: body is not a JSON object
        # (list, string, null) or num_images is not numeric.
        # ValueError: num_images is a non-numeric string.
        return jsonify(error="JSON body must contain 'text' and an integer 'num_images'"), 400
    if not isinstance(text_prompt, str) or num_images < 1:
        return jsonify(error="'text' must be a string and 'num_images' must be >= 1"), 400

    generated_imgs = dalle_model.generate_images(text_prompt, num_images)

    img_format = "JPEG"
    dir_name = None
    if args.save_to_disk:
        # Use '-' rather than ':' in the time, since ':' is not allowed in
        # Windows paths. The prompt comes from the client, so sanitize it
        # before using it as a path component.
        dir_name = os.path.join(
            IMAGES_OUTPUT_DIR,
            f"{time.strftime('%Y-%m-%d_%H-%M-%S')}_{_safe_dir_component(text_prompt)}",
        )
        Path(dir_name).mkdir(parents=True, exist_ok=True)

    generated_images = []
    for idx, img in enumerate(generated_imgs):
        if dir_name is not None:
            img.save(os.path.join(dir_name, f"{idx}.jpeg"), format=img_format)
        buffered = BytesIO()
        img.save(buffered, format=img_format)
        generated_images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

    print(f"Created {num_images} images from text prompt [{text_prompt}]")
    return jsonify(generated_images)
# NOTE(review): this view function was not registered with the app anywhere in
# this file, so it was unreachable. Serving it at "/" (GET) is an assumption;
# confirm with whatever probes the server.
@app.route("/", methods=["GET"])
def health_check():
    """Liveness probe: always respond with the JSON ``{"success": true}``."""
    return jsonify(success=True)
# Build the selected model once at import time, inside a Flask application
# context. Then run one throw-away generation, presumably to warm the model up
# before the first real request, and report readiness on stdout.
with app.app_context():
    dalle_model = DalleModel(args.model_version)
    dalle_model.generate_images("warm-up", 1)  # result intentionally discarded
    print(
        "--> DALL-E Server is up and running!",
        f"--> Model selected - DALL-E {args.model_version}",
        sep="\n",
    )
# Script entry point: serve the Flask API on all interfaces at the port given
# by --port, with Flask's debug mode disabled.
if __name__ == "__main__":
    app.run(
        host="0.0.0.0",
        port=args.port,
        debug=False,
    )