# Uploaded by dn6 (HF Staff) using huggingface_hub — commit 56f2217 (verified), 23.4 kB.
import hashlib
import os
from typing import List, Optional, Union
import torch
from diffusers import FluxModularPipeline, ModularPipelineBlocks
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
from diffusers.utils import (
USE_PEFT_BACKEND,
logger,
scale_lora_layers,
unscale_lora_layers,
)
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
class CachedFluxTextEncoderStep(ModularPipelineBlocks):
    """Flux text-encoder step that caches prompt embeddings in memory and on disk.

    CLIP pooled embeddings and T5 sequence embeddings are cached per prompt under a
    SHA-256 hash of the prompt text plus an encoder-specific suffix (``"_clip"`` /
    ``"_t5"``; a non-default T5 ``max_sequence_length`` is also encoded in the suffix).
    The whole in-memory cache is written to ``cache.safetensors`` in ``cache_dir``
    after every batch that produced new embeddings.

    NOTE: cache keys depend only on the prompt text (and the T5 length). They do not
    capture the encoder weights, textual-inversion tokens or the LoRA scale, so call
    ``clear_cache_from_disk`` after changing any of those.
    """

    model_name = "flux"

    # Name of the single safetensors file holding the persisted cache.
    _CACHE_FILENAME = "cache.safetensors"
    # T5 keys for this length keep the plain "_t5" suffix so cache files written
    # before the length was part of the key stay valid.
    _DEFAULT_T5_MAX_SEQUENCE_LENGTH = 512

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: Optional[str] = None,
        load_from_disk: bool = True,
    ) -> None:
        """Initialize the cached Flux text encoder step.

        Args:
            use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
            cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
            load_from_disk: Whether to load existing cache from disk on initialization. Defaults to True.
        """
        super().__init__()
        # Maps cache key -> embedding tensor; None means caching is disabled.
        self.cache = {} if use_cache else None
        if use_cache:
            self.cache_dir = cache_dir or os.path.join(
                os.path.expanduser("~"), ".cache", "flux_prompt_cache"
            )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None
        # Load existing cache if requested
        if load_from_disk and use_cache:
            self.load_cache_from_disk()

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text_embeddings to guide the image generation"

    @property
    def expected_components(self):
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Raise ValueError unless `prompt` / `prompt_2` are each None, a str, or a list."""
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and not isinstance(prompt, (str, list)):
                raise ValueError(
                    f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
                )

    def _cache_file_path(self) -> Optional[str]:
        """Return the path of the on-disk cache file, or None when caching is disabled."""
        if not self.cache_dir:
            return None
        return os.path.join(self.cache_dir, self._CACHE_FILENAME)

    def save_cache_to_disk(self):
        """Write the whole in-memory cache to ``cache.safetensors`` in ``cache_dir``.

        The file is written to a temporary path and then moved over the old file with
        ``os.replace``, so an interrupted write cannot leave a truncated cache behind.
        Does nothing when caching is disabled or the cache is empty. Errors from
        ``save_file`` propagate to the caller.
        """
        cache_file = self._cache_file_path()
        if not self.cache or cache_file is None:
            return
        # safetensors requires CPU, contiguous tensors; both calls are no-ops when the
        # tensor already satisfies them.
        tensors_to_save = {
            key: tensor.cpu().contiguous() for key, tensor in self.cache.items()
        }
        tmp_file = cache_file + ".tmp"
        try:
            save_file(tensors_to_save, tmp_file)
            os.replace(tmp_file, cache_file)
        except Exception:
            # Don't leave a half-written temp file around; re-raise the original error.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
        logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")

    def load_cache_from_disk(self):
        """Load every tensor from ``cache.safetensors`` into the in-memory cache.

        Best-effort: a missing file is ignored, and any error while reading is logged
        as a warning instead of being raised.
        """
        cache_file = self._cache_file_path()
        if cache_file is None or self.cache is None:
            return
        if not os.path.exists(cache_file):
            return
        try:
            with safe_open(cache_file, framework="pt", device="cpu") as f:
                loaded_count = 0
                for key in f.keys():
                    self.cache[key] = f.get_tensor(key)
                    loaded_count += 1
            logger.debug(
                f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
            )
        except Exception as e:
            logger.warning(f"Failed to load cache from disk: {e}")

    def clear_cache_from_disk(self):
        """Delete the cache file (if present) and empty the in-memory cache."""
        cache_file = self._cache_file_path()
        if cache_file is None:
            return
        if os.path.exists(cache_file):
            os.remove(cache_file)
            logger.info(f"Cleared cache file: {cache_file}")
        # Also clear the in-memory cache
        if self.cache is not None:
            self.cache.clear()

    def get_cache_size(self):
        """Return the size of the on-disk cache file in MB (0 if absent or caching is disabled)."""
        cache_file = self._cache_file_path()
        if cache_file is not None and os.path.exists(cache_file):
            return os.path.getsize(cache_file) / (1024 * 1024)  # Convert to MB
        return 0

    @staticmethod
    def _to_cache_key(prompt: str) -> str:
        """Return the hex SHA-256 digest of ``prompt`` (UTF-8 encoded) as a cache key."""
        return hashlib.sha256(prompt.encode()).hexdigest()

    @staticmethod
    def _t5_cache_suffix(max_sequence_length: int) -> str:
        """Return the T5 cache-key suffix for a given tokenizer ``max_sequence_length``.

        T5 embeddings have ``max_sequence_length`` tokens, so entries for different
        lengths must not collide. The default length keeps the legacy ``"_t5"`` suffix.
        """
        if max_sequence_length == CachedFluxTextEncoderStep._DEFAULT_T5_MAX_SEQUENCE_LENGTH:
            return "_t5"
        return f"_t5_len{max_sequence_length}"

    @staticmethod
    def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
        """Split prompts into cached and new, returning indices for reconstruction.

        Args:
            prompts: List of prompt strings to check against cache.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
            device: Optional device to move cached tensors to.

        Returns:
            tuple: (cached_embeds, prompts_to_encode, prompt_indices)
                - cached_embeds: List of (idx, embedding) tuples for cached prompts
                - prompts_to_encode: List of prompts that need encoding
                - prompt_indices: List of original indices for prompts_to_encode
        """
        cache = cache_instance.cache if cache_instance is not None else None
        cached_embeds = []
        prompts_to_encode = []
        prompt_indices = []
        for idx, prompt in enumerate(prompts):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
            if cache and cache_key in cache:
                cached_tensor = cache[cache_key]
                # Cached tensors live on CPU; move them next to the freshly encoded ones.
                if device is not None and cached_tensor.device != device:
                    cached_tensor = cached_tensor.to(device)
                cached_embeds.append((idx, cached_tensor))
            else:
                prompts_to_encode.append(prompt)
                prompt_indices.append(idx)
        return cached_embeds, prompts_to_encode, prompt_indices

    @staticmethod
    def _cache_prompt_embeds(
        prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
    ):
        """Store newly computed embeddings in the cache and persist the cache to disk.

        Args:
            prompts: Original full list of prompts.
            prompt_indices: Indices of newly encoded prompts in the original list.
            prompt_embeds: Newly computed embeddings tensor, one row per entry of
                ``prompt_indices``.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
        """
        if cache_instance is None or cache_instance.cache is None:
            return
        for i, idx in enumerate(prompt_indices):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(
                prompts[idx] + cache_suffix
            )
            # Store an independent CPU copy of the row: a plain slice would keep the
            # whole batch tensor alive on the accelerator, and safetensors refuses to
            # save tensors that share storage.
            cache_instance.cache[cache_key] = (
                prompt_embeds[i : i + 1].detach().to("cpu", copy=True)
            )
        # Save updated cache to disk
        cache_instance.save_cache_to_disk()

    @staticmethod
    def _merge_cached_prompt_embeds(
        cached_embeds, prompt_indices, prompt_embeds, batch_size
    ):
        """Merge cached and newly computed embeddings back into original batch order.

        Args:
            cached_embeds: List of (idx, embedding) tuples from cache.
            prompt_indices: Indices where new embeddings should be placed.
            prompt_embeds: Newly computed embeddings tensor, or None if all cached.
            batch_size: Total batch size for output tensor.

        Returns:
            torch.Tensor: Combined embeddings tensor in correct batch order.
        """
        all_embeds = [None] * batch_size
        for idx, embed in cached_embeds:
            all_embeds[idx] = embed
        if prompt_embeds is not None:
            for i, idx in enumerate(prompt_indices):
                all_embeds[idx] = prompt_embeds[i : i + 1]
        return torch.cat(all_embeds, dim=0)

    @staticmethod
    def _encode_t5(components, prompts, max_sequence_length, device):
        """Tokenize ``prompts`` with ``tokenizer_2`` and return ``text_encoder_2``'s first output.

        Logs a warning with the removed text when a prompt is truncated to
        ``max_sequence_length`` tokens.
        """
        if isinstance(components, TextualInversionLoaderMixin):
            prompts = components.maybe_convert_prompt(prompts, components.tokenizer_2)

        text_inputs = components.tokenizer_2(
            prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Check for truncation
        untruncated_ids = components.tokenizer_2(
            prompts, padding="longest", return_tensors="pt"
        ).input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = components.tokenizer_2.batch_decode(
                untruncated_ids[:, max_sequence_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        return components.text_encoder_2(
            text_input_ids.to(device), output_hidden_states=False
        )[0]

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using T5 text encoder with caching support.

        Args:
            components: Pipeline components containing T5 encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            max_sequence_length: Maximum sequence length for tokenization; also part of
                the cache key when it differs from the default 512.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: T5 prompt embeddings of shape
            ``(batch_size * num_images_per_prompt, seq_len, hidden)``, in the T5
            encoder's dtype (when the encoder is available) and on ``device``.
        """
        text_encoder_2 = components.text_encoder_2
        dtype = text_encoder_2.dtype if text_encoder_2 is not None else None
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)
        cache_suffix = CachedFluxTextEncoderStep._t5_cache_suffix(max_sequence_length)

        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, cache_suffix, device
            )
        )

        new_embeds = None
        if prompts_to_encode:
            new_embeds = CachedFluxTextEncoderStep._encode_t5(
                components, prompts_to_encode, max_sequence_length, device
            )
            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, new_embeds, cache_instance, cache_suffix
            )

        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, new_embeds, batch_size
        )

        # Duplicate for num_images_per_prompt
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )
        # Applied on every path so cached and fresh embeddings share dtype/device.
        return prompt_embeds.to(dtype=dtype, device=device)

    @staticmethod
    def _encode_clip(components, prompts, device):
        """Tokenize ``prompts`` with the CLIP tokenizer and return the pooled CLIP output.

        The result is cast to ``text_encoder``'s dtype and moved to ``device``. Logs a
        warning with the removed text when a prompt exceeds the tokenizer's
        ``model_max_length``.
        """
        if isinstance(components, TextualInversionLoaderMixin):
            prompts = components.maybe_convert_prompt(prompts, components.tokenizer)

        tokenizer_max_length = components.tokenizer.model_max_length
        text_inputs = components.tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        untruncated_ids = components.tokenizer(
            prompts, padding="longest", return_tensors="pt"
        ).input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = components.tokenizer.batch_decode(
                untruncated_ids[:, tokenizer_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer_max_length} tokens: {removed_text}"
            )

        outputs = components.text_encoder(
            text_input_ids.to(device), output_hidden_states=False
        )
        # Use pooled output of CLIPTextModel
        return outputs.pooler_output.to(
            dtype=components.text_encoder.dtype, device=device
        )

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using CLIP text encoder with caching support.

        Args:
            components: Pipeline components containing CLIP encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: CLIP pooled prompt embeddings of shape
            ``(batch_size * num_images_per_prompt, hidden)``, in the CLIP encoder's
            dtype (when the encoder is available) and on ``device``.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Split cached and new prompts
        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_clip", device
            )
        )

        new_embeds = None
        if prompts_to_encode:
            new_embeds = CachedFluxTextEncoderStep._encode_clip(
                components, prompts_to_encode, device
            )
            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, new_embeds, cache_instance, "_clip"
            )

        # Combine cached and newly encoded embeddings in correct order
        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, new_embeds, batch_size
        )

        # Duplicate for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        # Applied on every path so cached and fresh embeddings share dtype/device.
        text_encoder = components.text_encoder
        dtype = text_encoder.dtype if text_encoder is not None else None
        return prompt_embeds.to(dtype=dtype, device=device)

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
        cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device; defaults to `components._execution_device`
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            max_sequence_length (`int`):
                Maximum number of T5 tokens per prompt.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            cache_instance (`CachedFluxTextEncoderStep`, *optional*):
                Step whose cache is consulted/updated; None disables caching.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds, text_ids)`, where `text_ids` is a zero tensor of shape
            `(prompt_embeds.shape[1], 3)`.
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                cache_instance=cache_instance,
            )
            prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                cache_instance=cache_instance,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = (
            components.text_encoder.dtype
            if components.text_encoder is not None
            else torch.bfloat16
        )
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode `prompt` / `prompt_2` from `state` and write the embeddings back into it."""
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.text_ids,
        ) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            # Declared and validated as an input; encode_prompt falls back to
            # `prompt` when it is None.
            prompt_2=block_state.prompt_2,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,  # TODO: hardcoded for now.
            max_sequence_length=512,
            lora_scale=block_state.text_encoder_lora_scale,
            # Pass self as cache_instance only when caching is enabled.
            cache_instance=self if self.cache is not None else None,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state