Spaces:
Runtime error
Runtime error
| import gc | |
| import os | |
| import torch | |
| from diffusers import AutoencoderKL, UNet2DConditionModel | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( | |
| retrieve_latents, | |
| ) | |
| from polygraphy import cuda | |
| from ...pipeline import StreamDiffusion | |
| from .builder import EngineBuilder, create_onnx_path | |
| from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine | |
| from .models import VAE, BaseModel, UNet, VAEEncoder | |
| class TorchVAEEncoder(torch.nn.Module): | |
| def __init__(self, vae: AutoencoderKL): | |
| super().__init__() | |
| self.vae = vae | |
| def forward(self, x: torch.Tensor): | |
| return retrieve_latents(self.vae.encode(x)) | |
| def compile_vae_encoder( | |
| vae: TorchVAEEncoder, | |
| model_data: BaseModel, | |
| onnx_path: str, | |
| onnx_opt_path: str, | |
| engine_path: str, | |
| opt_batch_size: int = 1, | |
| engine_build_options: dict = {}, | |
| ): | |
| builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) | |
| builder.build( | |
| onnx_path, | |
| onnx_opt_path, | |
| engine_path, | |
| opt_batch_size=opt_batch_size, | |
| **engine_build_options, | |
| ) | |
| def compile_vae_decoder( | |
| vae: AutoencoderKL, | |
| model_data: BaseModel, | |
| onnx_path: str, | |
| onnx_opt_path: str, | |
| engine_path: str, | |
| opt_batch_size: int = 1, | |
| engine_build_options: dict = {}, | |
| ): | |
| vae = vae.to(torch.device("cuda")) | |
| builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) | |
| builder.build( | |
| onnx_path, | |
| onnx_opt_path, | |
| engine_path, | |
| opt_batch_size=opt_batch_size, | |
| **engine_build_options, | |
| ) | |
| def compile_unet( | |
| unet: UNet2DConditionModel, | |
| model_data: BaseModel, | |
| onnx_path: str, | |
| onnx_opt_path: str, | |
| engine_path: str, | |
| opt_batch_size: int = 1, | |
| engine_build_options: dict = {}, | |
| ): | |
| unet = unet.to(torch.device("cuda"), dtype=torch.float16) | |
| builder = EngineBuilder(model_data, unet, device=torch.device("cuda")) | |
| builder.build( | |
| onnx_path, | |
| onnx_opt_path, | |
| engine_path, | |
| opt_batch_size=opt_batch_size, | |
| **engine_build_options, | |
| ) | |
| def accelerate_with_tensorrt( | |
| stream: StreamDiffusion, | |
| engine_dir: str, | |
| max_batch_size: int = 2, | |
| min_batch_size: int = 1, | |
| use_cuda_graph: bool = False, | |
| engine_build_options: dict = {}, | |
| ): | |
| if "opt_batch_size" not in engine_build_options or engine_build_options["opt_batch_size"] is None: | |
| engine_build_options["opt_batch_size"] = max_batch_size | |
| text_encoder = stream.text_encoder | |
| unet = stream.unet | |
| vae = stream.vae | |
| del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae | |
| vae_config = vae.config | |
| vae_dtype = vae.dtype | |
| unet.to(torch.device("cpu")) | |
| vae.to(torch.device("cpu")) | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| onnx_dir = os.path.join(engine_dir, "onnx") | |
| os.makedirs(onnx_dir, exist_ok=True) | |
| unet_engine_path = f"{engine_dir}/unet.engine" | |
| vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine" | |
| vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine" | |
| unet_model = UNet( | |
| fp16=True, | |
| device=stream.device, | |
| max_batch_size=max_batch_size, | |
| min_batch_size=min_batch_size, | |
| embedding_dim=text_encoder.config.hidden_size, | |
| unet_dim=unet.config.in_channels, | |
| ) | |
| vae_decoder_model = VAE( | |
| device=stream.device, | |
| max_batch_size=max_batch_size, | |
| min_batch_size=min_batch_size, | |
| ) | |
| vae_encoder_model = VAEEncoder( | |
| device=stream.device, | |
| max_batch_size=max_batch_size, | |
| min_batch_size=min_batch_size, | |
| ) | |
| if not os.path.exists(unet_engine_path): | |
| compile_unet( | |
| unet, | |
| unet_model, | |
| create_onnx_path("unet", onnx_dir, opt=False), | |
| create_onnx_path("unet", onnx_dir, opt=True), | |
| unet_engine_path, | |
| **engine_build_options, | |
| ) | |
| else: | |
| del unet | |
| if not os.path.exists(vae_decoder_engine_path): | |
| vae.forward = vae.decode | |
| compile_vae_decoder( | |
| vae, | |
| vae_decoder_model, | |
| create_onnx_path("vae_decoder", onnx_dir, opt=False), | |
| create_onnx_path("vae_decoder", onnx_dir, opt=True), | |
| vae_decoder_engine_path, | |
| **engine_build_options, | |
| ) | |
| if not os.path.exists(vae_encoder_engine_path): | |
| vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda")) | |
| compile_vae_encoder( | |
| vae_encoder, | |
| vae_encoder_model, | |
| create_onnx_path("vae_encoder", onnx_dir, opt=False), | |
| create_onnx_path("vae_encoder", onnx_dir, opt=True), | |
| vae_encoder_engine_path, | |
| **engine_build_options, | |
| ) | |
| del vae | |
| cuda_steram = cuda.Stream() | |
| stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_steram, use_cuda_graph=use_cuda_graph) | |
| stream.vae = AutoencoderKLEngine( | |
| vae_encoder_engine_path, | |
| vae_decoder_engine_path, | |
| cuda_steram, | |
| stream.pipe.vae_scale_factor, | |
| use_cuda_graph=use_cuda_graph, | |
| ) | |
| setattr(stream.vae, "config", vae_config) | |
| setattr(stream.vae, "dtype", vae_dtype) | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| return stream | |