# skyreelsinfer/skyreels_video_infer.py
import logging
import os
import threading
import time
from datetime import timedelta
from typing import Any
from typing import Dict
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
from . import TaskType
from .offload import Offload
from .offload import OffloadConfig
from .pipelines import SkyreelsVideoPipeline
logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
f"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SkyReelsVideoSingleGpuInfer:
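    """Per-rank worker: loads the pipeline onto one GPU and serves requests pulled from a queue."""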
def _load_model(
self,
model_id: str,
base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
quant_model: bool = True,
gpu_device: str = "cuda:0",
) -> SkyreelsVideoPipeline:
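        # Load the text encoder and transformer on CPU first; when quantization is enabled,
        # weights are cast to float8 (weight-only) on the GPU and moved back to CPU so the
        # pipeline starts with minimal VRAM usage.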
logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
text_encoder = LlamaModel.from_pretrained(
base_model_id,
subfolder="text_encoder",
torch_dtype=torch.bfloat16,
).to("cpu")
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id,
# subfolder="transformer",
torch_dtype=torch.bfloat16,
device="cpu",
).to("cpu")
if quant_model:
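            # torchao float8 weight-only quantization runs on the GPU; each module is then
            # returned to CPU and the CUDA cache is cleared to free VRAM.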
quantize_(text_encoder, float8_weight_only(), device=gpu_device)
text_encoder.to("cpu")
torch.cuda.empty_cache()
quantize_(transformer, float8_weight_only(), device=gpu_device)
transformer.to("cpu")
torch.cuda.empty_cache()
pipe = SkyreelsVideoPipeline.from_pretrained(
base_model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16,
).to("cpu")
pipe.vae.enable_tiling()
torch.cuda.empty_cache()
return pipe
def __init__(
self,
task_type: TaskType,
model_id: str,
quant_model: bool = True,
local_rank: int = 0,
world_size: int = 1,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
enable_cfg_parallel: bool = True,
):
self.task_type = task_type
self.gpu_rank = local_rank
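        # One worker process per GPU; NCCL rendezvous happens over a fixed localhost port,
        # so all ranks are expected to live on the same node.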
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:23456",
timeout=timedelta(seconds=600),
world_size=world_size,
rank=local_rank,
)
os.environ["LOCAL_RANK"] = str(local_rank)
logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
torch.cuda.set_device(dist.get_rank())
torch.backends.cuda.enable_cudnn_sdp(False)
gpu_device = f"cuda:{dist.get_rank()}"
self.pipe: SkyreelsVideoPipeline = self._load_model(
model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
)
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
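        # Ranks are split between classifier-free-guidance parallelism (batch dim of 2 when
        # enabled) and Ulysses-style sequence parallelism over the remaining ranks.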
max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
max_ulysses_dim_size = int(world_size / max_batch_dim_size)
logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
mesh = init_context_parallel_mesh(
self.pipe.device.type,
max_ring_dim_size=1,
max_batch_dim_size=max_batch_dim_size,
)
parallelize_pipe(self.pipe, mesh=mesh)
parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
if is_offload:
Offload.offload(
pipeline=self.pipe,
config=offload_config,
)
else:
self.pipe.to(gpu_device)
if offload_config.compiler_transformer:
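            # Compile the transformer with a persistent inductor FX-graph cache (keyed by
            # world size) so subsequent runs can skip recompilation.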
torch._dynamo.config.suppress_errors = True
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
self.pipe.transformer = torch.compile(
self.pipe.transformer,
mode="max-autotune-no-cudagraphs",
dynamic=True,
)
self.warm_up()
def warm_up(self):
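        # Run a single denoising step at the serving resolution so CUDA kernels (and the
        # compiled transformer, if enabled) are initialized before real requests arrive.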
init_kwargs = {
"prompt": "A woman is dancing in a room",
"height": 544,
"width": 960,
"guidance_scale": 6,
"num_inference_steps": 1,
"negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
"num_frames": 97,
"generator": torch.Generator("cuda").manual_seed(42),
"embedded_guidance_scale": 1.0,
}
if self.task_type == TaskType.I2V:
init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
self.pipe(**init_kwargs)
    def daemon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
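        # Signal readiness, then block on the request queue forever; only rank 0 pushes the
        # decoded frames back, so the caller receives a single response per request.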
response_queue.put(f"rank:{self.gpu_rank} ready")
logger.info(f"rank:{self.gpu_rank} finish init pipe")
while True:
logger.info(f"rank:{self.gpu_rank} waiting for request")
kwargs = request_queue.get()
logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
if "seed" in kwargs:
kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
del kwargs["seed"]
start_time = time.time()
assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
out = self.pipe(**kwargs).frames[0]
logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
if dist.get_rank() == 0:
response_queue.put(out)
def single_gpu_run(
rank,
task_type: TaskType,
model_id: str,
request_queue: mp.Queue,
response_queue: mp.Queue,
quant_model: bool = True,
world_size: int = 1,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
enable_cfg_parallel: bool = True,
):
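    # Entry point for each process spawned by mp.spawn: build the per-GPU worker and hand it
    # the shared request/response queues.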
pipe = SkyReelsVideoSingleGpuInfer(
task_type=task_type,
model_id=model_id,
quant_model=quant_model,
local_rank=rank,
world_size=world_size,
is_offload=is_offload,
offload_config=offload_config,
enable_cfg_parallel=enable_cfg_parallel,
)
    pipe.daemon_inference(request_queue, response_queue)
class SkyReelsVideoInfer:
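    """Front end that spawns one worker per GPU and talks to them over multiprocessing queues."""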
def __init__(
self,
task_type: TaskType,
model_id: str,
quant_model: bool = True,
world_size: int = 1,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
enable_cfg_parallel: bool = True,
):
self.world_size = world_size
smp = mp.get_context("spawn")
self.REQ_QUEUES: mp.Queue = smp.Queue()
self.RESP_QUEUE: mp.Queue = smp.Queue()
        assert self.world_size > 0, "world_size must be greater than 0"
spawn_thread = threading.Thread(
            target=self.launch_single_gpu_infer,
args=(task_type, model_id, quant_model, world_size, is_offload, offload_config, enable_cfg_parallel),
daemon=True,
)
spawn_thread.start()
logger.info(f"Started multi-GPU thread with GPU_NUM: {world_size}")
print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
# Block and wait for the prediction process to start
for _ in range(world_size):
msg = self.RESP_QUEUE.get()
logger.info(f"launch_multi_gpu get init msg: {msg}")
print(f"launch_multi_gpu get init msg: {msg}")
    def launch_single_gpu_infer(
self,
task_type: TaskType,
model_id: str,
quant_model: bool = True,
world_size: int = 1,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
enable_cfg_parallel: bool = True,
):
mp.spawn(
single_gpu_run,
nprocs=world_size,
join=True,
daemon=True,
args=(
task_type,
model_id,
self.REQ_QUEUES,
self.RESP_QUEUE,
quant_model,
world_size,
is_offload,
offload_config,
enable_cfg_parallel,
),
)
logger.info(f"finish lanch multi gpu infer, world_size:{world_size}")
def inference(self, kwargs: Dict[str, Any]):
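        """Broadcast one request to every worker and return the frames produced by rank 0."""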
        # Put one copy of the request on the shared queue so every rank picks one up.
for _ in range(self.world_size):
self.REQ_QUEUES.put(kwargs)
return self.RESP_QUEUE.get()
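

# Minimal usage sketch (illustrative only). The checkpoint id, the input image, and the
# sampling settings below are assumptions, not values defined by this module; the Space's
# actual entry point may drive SkyReelsVideoInfer differently.
if __name__ == "__main__":
    infer = SkyReelsVideoInfer(
        task_type=TaskType.I2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",  # assumed checkpoint id
        world_size=1,
    )
    frames = infer.inference(
        {
            "prompt": "A woman is dancing in a room",
            "image": Image.new("RGB", (960, 544), color="black"),  # placeholder I2V input
            "height": 544,
            "width": 960,
            "num_frames": 97,
            "num_inference_steps": 30,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "seed": 42,
        }
    )
    # `frames` holds the decoded frames returned by the rank-0 worker.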