import spaces from dataclasses import dataclass import json import logging import os import random import re import sys import warnings from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler import gradio as gr import torch from transformers import AutoModel, AutoTokenizer sys.path.append(os.path.dirname(os.path.abspath(__file__))) from diffusers import ZImagePipeline from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel from pe import prompt_template # ==================== Environment Variables ================================== MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo") ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true" ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true" ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3") DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY") HF_TOKEN = os.environ.get("HF_TOKEN") # ============================================================================= os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings("ignore") logging.getLogger("transformers").setLevel(logging.ERROR) RES_CHOICES = { "1024": [ "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )", "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )", "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )", "1344x576 ( 21:9 )", "576x1344 ( 9:21 )", ], "1280": [ "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )", "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )", "1024x1536 ( 2:3 )", "1600x896 ( 16:9 )", "896x1600 ( 9:16 )", # not 900 coz divided by 16 needed "1680x720 ( 21:9 )", "720x1680 ( 9:21 )", ], } RESOLUTION_SET = [] for resolutions in RES_CHOICES.values(): RESOLUTION_SET.extend(resolutions) EXAMPLE_PROMPTS = [ ["한 남성과 그의 푸들이 어울리는 의상을 입고 실내 조명 아래 관객들이 있는 배경에서 개 쇼에 참가하고 있는 모습."], [ "분위기 있는 어두운 톤의 인물 사진, 우아한 중국 여성이 어두운 방에 있다. 강한 빛이 셔터를 통과해 그녀의 얼굗에 번개 모양의 선명한 빛과 그림자를 투사하며 한쪽 눈만을 정확히 비춘다. 높은 대비, 명암 경계가 선명하며, 신비로운 느낌, 라이카 카메라 색조." ], [ "밝게 조명된 엘리베이터 안에서 긴 검은 머리를 한 젊은 동아시아 여성이 거울을 향해 셀카를 찍는 중간 거리 스마트폰 셀카 사진. 그녀는 흰색 꽃무늬가 있는 검은색 오프숄더 크롭탑과 어두운 청바지를 입고 있다. 머리를 약간 기울이고 입술을 뾰족하게 내밀어 키스하는 듯한 포즈로 매우 귀엽고 장난스러운 모습이다. 오른손에 짙은 회색 스마트폰을 들고 얼굴 일부를 가리고 있으며, 후면 카메라 렌즈가 거울을 향하고 있다." ], [ "빨간 한푸를 입은 젊은 중국 여성, 정교한 자수. 완벽한 메이크업, 붉은 꽃무늬 이마 장식. 정교한 높은 쪽진 머리, 금빛 봉황 머리 장식, 붉은 꽃, 구슬. 여인과 나무, 새가 그려진 둥근 접이식 부채를 들고 있다. 네온 번개 모양 램프 (⚡️), 밝은 노란색 빛, 펼친 왼쪽 손바닥 위에. 부드럽게 조명된 야외 밤 배경, 실루엣의 다층 탑(서안 대안탑), 흐릿한 컬러 먼 불빛들." ], [ '''고요하고 장엄한 중국 풍경을 묘사한 세로 형식의 디지털 일러스트레이션으로, 전통적인 산수화 스타일을 현대적이고 깔끔한 미학으로 재해석했다. 장면은 중앙 계곡을 둘러싼 다양한 파란색과 청록색 음영의 우뚝 솟은 가파른 절벽이 지배한다. 멀리 산들이 층층이 연한 파란색과 흰색 안개 속으로 사라지며 강한 대기 원근감과 깊이를 만들어낸다. 고요한 청록색 강이 구성의 중앙을 가로질러 흐르며, 작은 전통 중국 배, 아마도 삼판이 물 위를 항해하고 있다. 배는 밝은 노란색 천막과 붉은 선체를 가지고 있으며 뒤에 부드러운 물결을 남긴다. 여러 명의 희미한 인물들을 태우고 있다. 녹색 나무와 일부 맨가지 나무를 포함한 드문드문한 식생이 바위 선반과 봉우리에 붙어 있다. 전체 조명은 부드럽고 확산되어 전체 장면에 평온한 빛을 드리운다. 이미지 중앙에 텍스트가 겹쳐져 있다. 텍스트 블록 상단에는 양식화된 문자가 포함된 작고 빨간색의 원형 도장 같은 로고가 있다. 그 아래 작은 검은색 산세리프 글꼴로 'Zao-Xiang * East Beauty & West Fashion * Z-Image'라는 단어가 있다. 그 바로 아래 더 큰 우아한 검은색 세리프 글꼴로 'SHOW & SHARE CREATIVITY WITH THE WORLD'라는 단어가 있다. 그 중에는 "SHOW & SHARE", "CREATIVITY", "WITH THE WORLD"가 있다.''' ], [ """가상의 영화 《회상의 맛》(The Taste of Memory)의 영화 포스터. 장면은 소박한 19세기 스타일 주방에 설정되어 있다. 화면 중앙에 적갈색 머리와 작은 콧수염을 가진 중년 남성(배우 아서 펜할리건 연기)이 나무 테이블 뒤에 서 있으며, 흰색 셔츠, 검은색 조끼, 베이지색 앞치마를 입고 있고 한 여성을 바라보며 손에 큰 덩어리의 생고기를 들고 있으며 아래에는 나무 도마가 있다. 그의 오른쪽에는 높은 쪽진 머리를 한 검은 머리 여성(배우 엘리너 밴스 연기)이 테이블에 기대어 그에게 부드럽게 미소짓고 있다. 그녀는 연한 색 셔츠와 상단은 흰색, 하단은 파란색인 긴 치마를 입고 있다. 테이블 위에는 다진 파와 양배추 채가 있는 도마 외에도 흰색 도자기 접시, 신선한 허브가 있고, 왼쪽 나무 상자 위에는 짙은 색 포도 한 송이가 놓여 있다. 배경은 거칠게 회백색으로 미장된 벽이며 풍경화 한 점이 걸려 있다. 가장 오른쪽 작업대 위에는 복고풍 오일 램프가 놓여 있다. 포스터에는 많은 텍스트 정보가 있다. 왼쪽 상단에는 흰색 산세리프 글꼴로 "ARTISAN FILMS PRESENTS"가 있고 그 아래에 "ELEANOR VANCE"와 "ACADEMY AWARD® WINNER"가 있다. 오른쪽 상단에는 "ARTHUR PENHALIGON"과 "GOLDEN GLOBE® AWARD WINNER"가 쓰여 있다. 상단 중앙에는 선댄스 영화제 월계관 로고가 있고 아래에 "SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"가 쓰여 있다. 주요 제목 "THE TASTE OF MEMORY"는 흰색의 큰 세리프 글꼴로 하단에 눈에 띄게 표시되어 있다. 제목 아래에는 "A FILM BY Tongyi Interaction Lab"이 명시되어 있다. 하단 영역에는 흰색 작은 글씨로 "SCREENPLAY BY ANNA REID", "CULINARY DIRECTION BY JAMES CARTER" 및 Artisan Films, Riverstone Pictures, Heritage Media 등 수많은 제작사 로고를 포함한 전체 출연진 및 제작진 명단이 나열되어 있다. 전체적인 스타일은 사실주의로 따뜻하고 부드러운 조명 방식을 채택하여 친밀한 분위기를 조성한다. 색조는 갈색, 베이지, 부드러운 녹색 등 대지색 톤이 주를 이룬다. 두 배우의 몸은 모두 허리에서 잘려 있다.""" ], [ """정사각형 구도의 클로즈업 사진으로, 거대하고 선명한 녹색 식물 잎이 주제이며 텍스트가 겹쳐져 포스터나 잡지 표지 같은 외관을 갖추고 있다. 주요 피사체는 왼쪽 하단에서 오른쪽 상단으로 대각선으로 구부러져 프레임을 가로지르는 두껍고 왁스 같은 질감의 잎이다. 표면이 매우 반사적이어서 밝은 직사광원을 포착하여 두드러진 하이라이트를 형성하고 밝은 면 아래 평행한 미세 잎맥이 드러난다. 배경은 다른 짙은 녹색 잎들로 구성되어 있으며 약간 초점이 흐려져 얕은 피사계 심도 효과를 만들어 전경의 주요 잎을 강조한다. 전체적인 스타일은 사실적인 사진으로 밝은 잎과 어두운 그림자 배경 사이에 높은 대비를 형성한다. 이미지에는 여러 렌더링된 텍스트가 있다. 왼쪽 상단에는 흰색 세리프 글꼴로 "PIXEL-PEEPERS GUILD Presents"라는 텍스트가 있다. 오른쪽 상단에도 흰색 세리프 글꼴로 "[Instant Noodle] 泡面调料包"라는 텍스트가 있다. 왼쪽에는 수직으로 배열된 제목 "Render Distance: Max"가 흰색 세리프 글꼴로 되어 있다. 왼쪽 하단에는 다섯 개의 거대한 흰색 송체 한자 "显卡在...燃烧"가 있다. 오른쪽 하단에는 작은 흰색 세리프 글꼴로 "Leica Glow™ Unobtanium X-1"이 있고, 그 바로 위에는 흰색 송체로 쓰인 이름 "蔡几"가 있다. 식별된 핵심 개체에는 브랜드 픽셀 피퍼스 길드, 제품 라인 인스턴트 누들 조미료 패키지, 카메라 모델 Unobtanium™ X-1 및 사진가 이름 Zao-Xiang이 포함된다.""" ], ] def get_resolution(resolution): match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution) if match: return int(match.group(1)), int(match.group(2)) return 1024, 1024 def load_models(model_path, enable_compile=False, attention_backend="native"): print(f"Loading models from {model_path}...") use_auth_token = HF_TOKEN if HF_TOKEN else True if not os.path.exists(model_path): vae = AutoencoderKL.from_pretrained( f"{model_path}", subfolder="vae", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ) text_encoder = AutoModel.from_pretrained( f"{model_path}", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ).eval() tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token) else: vae = AutoencoderKL.from_pretrained( os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda" ) text_encoder = AutoModel.from_pretrained( os.path.join(model_path, "text_encoder"), torch_dtype=torch.bfloat16, device_map="cuda", ).eval() tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer")) tokenizer.padding_side = "left" if enable_compile: print("Enabling torch.compile optimizations...") torch._inductor.config.conv_1x1_as_mm = True torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.max_autotune_gemm = True torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN" torch._inductor.config.triton.cudagraphs = False pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None) if enable_compile: pipe.vae.disable_tiling() if not os.path.exists(model_path): transformer = ZImageTransformer2DModel.from_pretrained( f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token ).to("cuda", torch.bfloat16) else: transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to( "cuda", torch.bfloat16 ) pipe.transformer = transformer pipe.transformer.set_attention_backend(attention_backend) if enable_compile: print("Compiling transformer...") pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False) pipe.to("cuda", torch.bfloat16) return pipe def generate_image( pipe, prompt, resolution="1024x1024", seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True), ): width, height = get_resolution(resolution) generator = torch.Generator("cuda").manual_seed(seed) scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) pipe.scheduler = scheduler image = pipe( prompt=prompt, height=height, width=width, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, max_sequence_length=max_sequence_length, ).images[0] return image def warmup_model(pipe, resolutions): print("Starting warmup phase...") dummy_prompt = "warmup" for res_str in resolutions: print(f"Warming up for resolution: {res_str}") try: for i in range(3): generate_image( pipe, prompt=dummy_prompt, resolution=res_str, num_inference_steps=9, guidance_scale=0.0, seed=42 + i, ) except Exception as e: print(f"Warmup failed for {res_str}: {e}") print("Warmup completed.") # ==================== Prompt Expander ==================== @dataclass class PromptOutput: status: bool prompt: str seed: int system_prompt: str message: str class PromptExpander: def __init__(self, backend="api", **kwargs): self.backend = backend def decide_system_prompt(self, template_name=None): return prompt_template class APIPromptExpander(PromptExpander): def __init__(self, api_config=None, **kwargs): super().__init__(backend="api", **kwargs) self.api_config = api_config or {} self.client = self._init_api_client() def _init_api_client(self): try: from openai import OpenAI api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1") if not api_key: print("Warning: DASHSCOPE_API_KEY not found.") return None return OpenAI(api_key=api_key, base_url=base_url) except ImportError: print("Please install openai: pip install openai") return None except Exception as e: print(f"Failed to initialize API client: {e}") return None def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs): return self.extend(prompt, system_prompt, seed, **kwargs) def extend(self, prompt, system_prompt=None, seed=-1, **kwargs): if self.client is None: return PromptOutput(False, "", seed, system_prompt, "API client not initialized") if system_prompt is None: system_prompt = self.decide_system_prompt() if "{prompt}" in system_prompt: system_prompt = system_prompt.format(prompt=prompt) prompt = " " try: model = self.api_config.get("model", "qwen3-max-preview") response = self.client.chat.completions.create( model=model, messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], temperature=0.7, top_p=0.8, ) content = response.choices[0].message.content json_start = content.find("```json") if json_start != -1: json_end = content.find("```", json_start + 7) try: json_str = content[json_start + 7 : json_end].strip() data = json.loads(json_str) expanded_prompt = data.get("revised_prompt", content) except: expanded_prompt = content else: expanded_prompt = content return PromptOutput( status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content ) except Exception as e: return PromptOutput(False, "", seed, system_prompt, str(e)) def create_prompt_expander(backend="api", **kwargs): if backend == "api": return APIPromptExpander(**kwargs) raise ValueError("Only 'api' backend is supported.") pipe = None prompt_expander = None def init_app(): global pipe, prompt_expander try: pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND) print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}") if ENABLE_WARMUP: all_resolutions = [] for cat in RES_CHOICES.values(): all_resolutions.extend(cat) warmup_model(pipe, all_resolutions) except Exception as e: print(f"Error loading model: {e}") pipe = None try: prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"}) print("Prompt expander initialized.") except Exception as e: print(f"Error initializing prompt expander: {e}") prompt_expander = None def prompt_enhance(prompt, enable_enhance): if not enable_enhance or not prompt_expander: return prompt, "프롬프트 향상이 비활성화되었거나 사용할 수 없습니다." if not prompt.strip(): return "", "프롬프트를 입력해주세요." try: result = prompt_expander(prompt) if result.status: return result.prompt, result.message else: return prompt, f"향상 실패: {result.message}" except Exception as e: return prompt, f"오류: {str(e)}" @spaces.GPU def generate( prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, enhance=False, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True) ): """ Generate an image using the Z-Image model based on the provided prompt and settings. This function is triggered when the user clicks the "Generate" button. It processes the input prompt (optionally enhancing it), configures generation parameters, and produces an image using the Z-Image diffusion transformer pipeline. Args: prompt (str): Text prompt describing the desired image content resolution (str): Output resolution in format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )") valid options, 1024 category: - "1024x1024 ( 1:1 )" - "1152x896 ( 9:7 )" - "896x1152 ( 7:9 )" - "1152x864 ( 4:3 )" - "864x1152 ( 3:4 )" - "1248x832 ( 3:2 )" - "832x1248 ( 2:3 )" - "1280x720 ( 16:9 )" - "720x1280 ( 9:16 )" - "1344x576 ( 21:9 )" - "576x1344 ( 9:21 )" 1280 category: - "1280x1280 ( 1:1 )" - "1440x1120 ( 9:7 )" - "1120x1440 ( 7:9 )" - "1472x1104 ( 4:3 )" - "1104x1472 ( 3:4 )" - "1536x1024 ( 3:2 )" - "1024x1536 ( 2:3 )" - "1600x896 ( 16:9 )" - "896x1600 ( 9:16 )" - "1680x720 ( 21:9 )" - "720x1680 ( 9:21 )" seed (int): Seed for reproducible generation steps (int): Number of inference steps for the diffusion process shift (float): Time shift parameter for the flow matching scheduler enhance (bool): This was Whether to enhance the prompt (DISABLED! Do not use) random_seed (bool): Whether to generate a new random seed, if True will ignore the seed input gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI) progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI) Returns: tuple: (gallery_images, seed_str, seed_int) - gallery_images: Updated list of generated images including the new image - seed_str: String representation of the seed used for generation - seed_int: Integer representation of the seed used for generation """ if pipe is None: raise gr.Error("모델이 로드되지 않았습니다.") final_prompt = prompt if enhance: final_prompt, _ = prompt_enhance(prompt, True) print(f"Enhanced prompt: {final_prompt}") if random_seed: new_seed = random.randint(1, 1000000) else: new_seed = seed if seed != -1 else random.randint(1, 1000000) try: resolution_str = resolution.split(" ")[0] except: resolution_str = "1024x1024" image = generate_image( pipe=pipe, prompt=final_prompt, resolution=resolution_str, seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps + 1), shift=shift, ) if gallery_images is None: gallery_images = [] gallery_images.append(image) return gallery_images, str(new_seed), int(new_seed) init_app() # ==================== AoTI (Ahead of Time Inductor compilation) ==================== pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"] spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3") with gr.Blocks(title="Z-Image 데모") as demo: gr.Markdown( """
# Z-Image 이미지 생성 데모 [![GitHub](https://img.shields.io/badge/GitHub-Z--Image-181717?logo=github&logoColor=white)](https://github.com/Tongyi-MAI/Z-Image) *단일 스트림 디퓨전 트랜스포머를 사용한 효율적인 이미지 생성 기반 모델*
""" ) with gr.Row(): with gr.Column(scale=1): prompt_input = gr.Textbox(label="프롬프트", lines=3, placeholder="프롬프트를 입력하세요...") # PE components (Temporarily disabled) # with gr.Row(): # enable_enhance = gr.Checkbox(label="프롬프트 향상 (DashScope)", value=False) # enhance_btn = gr.Button("향상만 실행") with gr.Row(): choices = [int(k) for k in RES_CHOICES.keys()] res_cat = gr.Dropdown(value=1024, choices=choices, label="해상도 카테고리") initial_res_choices = RES_CHOICES["1024"] resolution = gr.Dropdown(value=initial_res_choices[0], choices=RESOLUTION_SET, label="너비 x 높이 (비율)") with gr.Row(): seed = gr.Number(label="시드", value=42, precision=0) random_seed = gr.Checkbox(label="랜덤 시드", value=True) with gr.Row(): steps = gr.Slider(label="스텝 수", minimum=1, maximum=100, value=8, step=1, interactive=False) shift = gr.Slider(label="시간 이동", minimum=1.0, maximum=10.0, value=3.0, step=0.1) generate_btn = gr.Button("생성", variant="primary") # Example prompts gr.Markdown("### 📝 예제 프롬프트") gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None) with gr.Column(scale=1): output_gallery = gr.Gallery( label="생성된 이미지", columns=2, rows=2, height=600, object_fit="contain", format="png", interactive=False ) used_seed = gr.Textbox(label="사용된 시드", interactive=False) def update_res_choices(_res_cat): if str(_res_cat) in RES_CHOICES: res_choices = RES_CHOICES[str(_res_cat)] else: res_choices = RES_CHOICES["1024"] return gr.update(value=res_choices[0], choices=res_choices) res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private") # PE enhancement button (Temporarily disabled) # enhance_btn.click( # prompt_enhance, # inputs=[prompt_input, enable_enhance], # outputs=[prompt_input, final_prompt_output] # ) # Dummy enable_enhance variable set to False enable_enhance = gr.State(value=False) generate_btn.click( generate, inputs=[prompt_input, resolution, seed, steps, shift, enable_enhance, random_seed, output_gallery], outputs=[output_gallery, used_seed, seed], api_visibility="public", ) css=''' .fillable{max-width: 1230px !important} ''' if __name__ == "__main__": demo.launch(css=css, mcp_server=True)