File size: 8,278 Bytes
82f3ba5
f022eee
eac9355
3470339
3da3af7
f022eee
3da3af7
 
 
529c8dc
3da3af7
3470339
 
 
 
 
3da3af7
 
3470339
3a03753
4761c77
3da3af7
 
 
 
3470339
 
 
82f3ba5
3470339
f022eee
3470339
 
 
 
f022eee
3470339
 
 
e06f2f9
3470339
f022eee
3470339
 
 
 
 
 
f022eee
3470339
 
 
 
e06f2f9
3470339
 
f022eee
1444940
3470339
3da3af7
3470339
 
e06f2f9
3470339
 
eac9355
3470339
 
3da3af7
3470339
eac9355
3da3af7
 
e06f2f9
3470339
e06f2f9
3470339
 
 
 
 
e06f2f9
 
3470339
e06f2f9
 
 
 
 
 
3da3af7
 
e06f2f9
3da3af7
ca34a82
238793e
 
 
3da3af7
 
3470339
3da3af7
 
 
 
 
 
 
 
 
 
 
 
 
 
238793e
3da3af7
 
 
e06f2f9
238793e
3da3af7
 
 
 
 
 
 
 
 
 
 
 
 
e06f2f9
 
3da3af7
e06f2f9
3da3af7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3470339
f022eee
3470339
 
 
 
f022eee
3470339
f022eee
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# engineers/deformes3D.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 1.5.1
#
# This version was designed to keep the core FLUX-based keyframe generation and
# add the LTX-based "enrichment" as a secondary, experimental step for each
# keyframe, allowing direct comparison without altering the primary workflow.
# In the current code the FLUX step is disabled, so every keyframe is produced
# by the LTX path alone.

from PIL import Image, ImageOps
import os
import time
import logging
import gradio as gr
import yaml
import torch
import numpy as np

from managers.flux_kontext_manager import flux_kontext_singleton
from engineers.deformes2D_thinker import deformes2d_thinker_singleton
from aduc_types import LatentConditioningItem
from managers.ltx_manager import ltx_manager_singleton
from managers.vae_manager import vae_manager_singleton
from managers.latent_enhancer_manager import latent_enhancer_specialist_singleton

logger = logging.getLogger(__name__)

class Deformes3DEngine:
    """
    ADUC Specialist for static image (keyframe) generation.

    Each keyframe is produced by a short LTX latent generation conditioned on
    the previous keyframe plus a few general reference images. Only the last
    latent frame is kept; it is upscaled, decoded with the VAE and saved as a
    PNG in the workspace.
    """

    # Maximum number of general reference images added to the LTX context
    # (in addition to the current base image).
    LTX_MAX_GENERAL_REFS = 3
    # Conditioning weight of the first context item; each following item gets
    # LTX_WEIGHT_STEP less (0.6, 0.5, 0.4, 0.3 for a full context).
    LTX_INITIAL_WEIGHT = 0.6
    LTX_WEIGHT_STEP = 0.1
    # Sampler settings forwarded to ltx_manager_singleton.generate_latent_fragment.
    LTX_BASE_PARAMS = {"guidance_scale": 1.0, "stg_scale": 0.001, "num_inference_steps": 25}
    LTX_TOTAL_FRAMES = 48
    LTX_FPS = 24

    def __init__(self, workspace_dir):
        """
        Args:
            workspace_dir: Directory where generated keyframe images are written.
        """
        self.workspace_dir = workspace_dir
        # FLUX Kontext helper; used only by _generate_single_keyframe.
        self.image_generation_helper = flux_kontext_singleton
        logger.info("3D Engine (Image Specialist) ready to receive orders from the Maestro.")

    def _generate_single_keyframe(self, prompt: str, reference_images: list[Image.Image], output_filename: str, width: int, height: int, callback: callable = None) -> str:
        """
        Low-level function that generates a single image with the FLUX Kontext
        helper (``self.image_generation_helper``) and saves it to the workspace.

        Not called by generate_keyframes_from_storyboard while its FLUX step is
        disabled.

        Args:
            prompt: Text prompt for the image.
            reference_images: Reference images forwarded to the helper.
            output_filename: File name, relative to ``self.workspace_dir``.
            width: Requested output width in pixels.
            height: Requested output height in pixels.
            callback: Optional progress callback forwarded to the helper.

        Returns:
            Full path of the saved image.
        """
        logger.info(f"Generating keyframe '{output_filename}' with prompt: '{prompt}'")
        generated_image = self.image_generation_helper.generate_image(
            reference_images=reference_images, prompt=prompt, width=width,
            # Time-based seed: calls within the same second reuse the same seed.
            height=height, seed=int(time.time()), callback=callback
        )
        final_path = os.path.join(self.workspace_dir, output_filename)
        generated_image.save(final_path)
        logger.info(f"Keyframe successfully saved to: {final_path}")
        return final_path

    def generate_keyframes_from_storyboard(self, storyboard: list, initial_ref_path: str, global_prompt: str, keyframe_resolution: int, general_ref_paths: list, progress_callback_factory: callable = None):
        """
        Orchestrates the generation of all keyframes.

        One keyframe is generated per consecutive pair of storyboard scenes
        (``len(storyboard) - 1`` in total). Each generated keyframe becomes the
        base image for the next one, forming a chain that starts at
        ``initial_ref_path``.

        Args:
            storyboard: Ordered scene descriptions.
            initial_ref_path: Path of the image that seeds the chain.
            global_prompt: Overall prompt forwarded to the prompt thinker.
            keyframe_resolution: Side length in pixels; keyframes are square.
            general_ref_paths: Extra reference image paths; ``None`` is treated
                as an empty list.
            progress_callback_factory: Optional ``factory(scene_index, total)``
                returning a progress callback.

        Returns:
            List of paths of the generated keyframe images, in order.
        """
        ref_paths = list(general_ref_paths or [])
        current_base_image_path = initial_ref_path
        previous_prompt = "N/A (initial reference image)"
        final_keyframes_gallery = []
        width, height = keyframe_resolution, keyframe_resolution
        target_resolution_tuple = (width, height)

        # Clamp so an empty storyboard does not log a negative keyframe count.
        num_keyframes_to_generate = max(len(storyboard) - 1, 0)
        logger.info(f"IMAGE SPECIALIST: Received order to generate {num_keyframes_to_generate} keyframes (LTX versions).")

        for i in range(num_keyframes_to_generate):
            scene_index = i + 1
            current_scene = storyboard[i]
            future_scene = storyboard[i + 1]
            # NOTE(review): this callback fed the FLUX step, which is disabled;
            # the factory is still invoked in case it has side effects.
            progress_callback_flux = progress_callback_factory(scene_index, num_keyframes_to_generate) if progress_callback_factory else None

            logger.info(f"--> Generating Keyframe {scene_index}/{num_keyframes_to_generate}...")

            # --- STEP A: Build the keyframe prompt ---
            logger.info("    - Step A: Generating with keyframe...")

            img_prompt = deformes2d_thinker_singleton.get_anticipatory_keyframe_prompt(
                global_prompt=global_prompt, scene_history=previous_prompt,
                current_scene_desc=current_scene, future_scene_desc=future_scene,
                last_image_path=current_base_image_path, fixed_ref_paths=ref_paths
            )

            # NOTE: the FLUX generation step (_generate_single_keyframe) is
            # disabled; the keyframe is produced by the LTX path below.

            # --- STEP B: LTX generation ---
            context_paths = self._select_context_paths(current_base_image_path, ref_paths)
            ltx_context_paths = list(reversed(context_paths))
            logger.info(f"    - LTX Context Order (Reversed): {[os.path.basename(p) for p in ltx_context_paths]}")

            ltx_conditioning_items = self._build_ltx_conditioning_items(ltx_context_paths, target_resolution_tuple)

            generated_latents, _ = ltx_manager_singleton.generate_latent_fragment(
                height=height, width=width,
                conditioning_items_data=ltx_conditioning_items,
                motion_prompt=img_prompt,
                video_total_frames=self.LTX_TOTAL_FRAMES,
                video_fps=self.LTX_FPS,
                **self.LTX_BASE_PARAMS,
            )

            # Keep only the last slice along dim 2 (presumably the time axis — TODO confirm).
            final_latent = generated_latents[:, :, -1:, :, :]
            upscaled_latent = latent_enhancer_specialist_singleton.upscale(final_latent)
            enriched_pixel_tensor = vae_manager_singleton.decode(upscaled_latent)

            ltx_keyframe_path = os.path.join(self.workspace_dir, f"keyframe_{scene_index}_ltx.png")
            self.save_image_from_tensor(enriched_pixel_tensor, ltx_keyframe_path)
            final_keyframes_gallery.append(ltx_keyframe_path)

            # The LTX keyframe just produced is the base image for the next iteration.
            current_base_image_path = ltx_keyframe_path
            previous_prompt = img_prompt

        logger.info("IMAGE SPECIALIST: Generation of all keyframe versions (LTX) complete.")
        return final_keyframes_gallery

    # --- HELPER FUNCTIONS ---

    def _select_context_paths(self, base_image_path: str, ref_paths: list) -> list:
        """Return the base image followed by up to LTX_MAX_GENERAL_REFS references that differ from it."""
        extra_refs = [p for p in ref_paths if p != base_image_path][: self.LTX_MAX_GENERAL_REFS]
        return [base_image_path] + extra_refs

    def _build_ltx_conditioning_items(self, context_paths: list, target_resolution: tuple) -> list:
        """VAE-encode each context image and wrap it in a LatentConditioningItem with a decreasing weight."""
        items = []
        weight = self.LTX_INITIAL_WEIGHT
        for path in context_paths:
            # Close the file handle promptly; convert() returns an independent, loaded copy.
            with Image.open(path) as img:
                img_pil = img.convert("RGB")
            img_processed = self._preprocess_image_for_latent_conversion(img_pil, target_resolution)
            pixel_tensor = self._pil_to_pixel_tensor(img_processed)
            latent_tensor = vae_manager_singleton.encode(pixel_tensor)
            # NOTE(review): positional args presumably (latent, frame index, weight) — confirm
            # against aduc_types.LatentConditioningItem. With the caller's reversed order the
            # current base image comes last and therefore gets the smallest weight.
            items.append(LatentConditioningItem(latent_tensor, 0, weight))
            weight -= self.LTX_WEIGHT_STEP
        return items

    def _preprocess_image_for_latent_conversion(self, image: Image.Image, target_resolution: tuple) -> Image.Image:
        """Resize and center-crop ``image`` to ``target_resolution`` (width, height) for VAE encoding."""
        if image.size != target_resolution:
            return ImageOps.fit(image, target_resolution, Image.Resampling.LANCZOS)
        return image

    def _pil_to_pixel_tensor(self, pil_image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a float32 tensor of shape (1, 3, 1, H, W) scaled to [-1, 1].

        Axes are presumably (batch, channels, frames, height, width) as expected
        by the VAE. Non-RGB images (L, RGBA, P, ...) are converted to RGB first
        so the channel axis is always 3.
        """
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        image_np = np.array(pil_image).astype(np.float32) / 255.0
        # (H, W, C) -> (C, H, W) -> (1, C, H, W) -> (1, C, 1, H, W)
        tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).unsqueeze(2)
        return (tensor * 2.0) - 1.0

    def save_image_from_tensor(self, pixel_tensor: torch.Tensor, path: str):
        """
        Save a single-frame pixel tensor with values in [-1, 1] as an 8-bit image.

        Expects the (1, C, 1, H, W) layout produced by _pil_to_pixel_tensor.
        Values are clamped, mapped to [0, 255] and rounded to the nearest
        integer (not truncated) before saving.
        """
        tensor_chw = pixel_tensor.detach().squeeze(0).squeeze(1)
        tensor_hwc = tensor_chw.permute(1, 2, 0).float()
        tensor_hwc = (tensor_hwc.clamp(-1, 1) + 1) / 2.0
        image_u8 = (tensor_hwc * 255).round().to(torch.uint8).cpu().contiguous()
        Image.fromarray(image_u8.numpy()).save(path)

# --- Singleton Instantiation ---
try:
    # NOTE(review): path is relative to the process working directory.
    # Read explicitly as UTF-8 (the YAML default) so the platform locale cannot change decoding.
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    WORKSPACE_DIR = config['application']['workspace_dir']
    deformes3d_engine_singleton = Deformes3DEngine(workspace_dir=WORKSPACE_DIR)
except Exception as e:
    # Deliberate best-effort at import time: log the full traceback and expose
    # None instead of making the import of this module fail.
    logger.error(f"Could not initialize Deformes3DEngine: {e}", exc_info=True)
    deformes3d_engine_singleton = None