import base64
import logging
import math
import uuid
from io import BytesIO
from typing import Optional

import av
import numpy as np
import torch
from PIL import Image

from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoContainer, VideoCodec

from ._helpers import mimetype_to_extension


def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
    """Converts image data from BytesIO to a torch.Tensor.

    Args:
        image_bytesio: BytesIO object containing the image data.
        mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").

    Returns:
        A torch.Tensor representing the image (1, H, W, C).

    Raises:
        PIL.UnidentifiedImageError: If the image data cannot be identified.
        ValueError: If the specified mode is invalid.
    """
    image = Image.open(image_bytesio)
    image = image.convert(mode)
    image_array = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(image_array).unsqueeze(0)
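

# Hedged usage sketch for bytesio_to_image_tensor: builds a tiny in-memory PNG
# with PIL and decodes it back. The solid-red 8x8 image is a made-up stand-in
# for real image bytes; illustrative only, not part of the module API.
def _example_decode_png_bytes() -> torch.Tensor:
    buf = BytesIO()
    Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")
    buf.seek(0)
    return bytesio_to_image_tensor(buf, mode="RGB")  # shape: [1, 8, 8, 3]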


def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
    """
    Converts a pair of image tensors to a batch tensor.

    If the images are not the same size, the second image is resized to
    match the first.
    """
    if image1.shape[1:] != image2.shape[1:]:
        image2 = common_upscale(
            image2.movedim(-1, 1),
            image1.shape[2],
            image1.shape[1],
            "bilinear",
            "center",
        ).movedim(1, -1)
    return torch.cat((image1, image2), dim=0)
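

# Hedged usage sketch for image_tensor_pair_to_batch: the random tensors below
# are made-up inputs; the second image is upscaled to 64x64 before batching.
def _example_pair_to_batch() -> torch.Tensor:
    image1 = torch.rand(1, 64, 64, 3)
    image2 = torch.rand(1, 32, 32, 3)
    return image_tensor_pair_to_batch(image1, image2)  # shape: [2, 64, 64, 3]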


def tensor_to_bytesio(
    image: torch.Tensor,
    name: Optional[str] = None,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> BytesIO:
    """Converts a torch.Tensor image to a named BytesIO object.

    Args:
        image: Input torch.Tensor image.
        name: Optional filename for the BytesIO object.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

    Returns:
        Named BytesIO object containing the image data, with the pointer set to
        the start of the buffer.
    """
    if not mime_type:
        mime_type = "image/png"

    pil_image = tensor_to_pil(image, total_pixels=total_pixels)
    img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
    img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
    return img_binary
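

# Hedged round-trip sketch: encodes a made-up random tensor to a named PNG
# buffer and decodes it back (assumes mimetype_to_extension maps "image/png"
# to "png"); illustrative only, not part of the module API.
def _example_tensor_png_roundtrip() -> torch.Tensor:
    dummy = torch.rand(1, 64, 64, 3)
    buf = tensor_to_bytesio(dummy, name="example", mime_type="image/png")
    assert buf.name == "example.png"  # assumed extension mapping
    return bytesio_to_image_tensor(buf, mode="RGB")  # shape: [1, 64, 64, 3]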


def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
    """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
    if len(image.shape) > 3:
        image = image[0]
    # TODO: remove alpha if not allowed and present
    input_tensor = image.cpu()
    input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
    image_np = (input_tensor.numpy() * 255).astype(np.uint8)
    img = Image.fromarray(image_np)
    return img


def tensor_to_base64_string(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Convert a [B, H, W, C] or [H, W, C] tensor to a base64 string.

    Args:
        image_tensor: Input torch.Tensor image.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

    Returns:
        Base64 encoded string of the image.
    """
    pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
    img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
    img_bytes = img_byte_arr.getvalue()
    # Encode bytes to base64 string
    base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
    return base64_encoded_string


def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
    """Converts a PIL Image to a BytesIO object."""
    if not mime_type:
        mime_type = "image/png"

    img_byte_arr = BytesIO()
    # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
    pil_format = mime_type.split("/")[-1].upper()
    if pil_format == "JPG":
        pil_format = "JPEG"
    img.save(img_byte_arr, format=pil_format)
    img_byte_arr.seek(0)
    return img_byte_arr


def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
    """Downscale input image tensor to roughly the specified total pixels."""
    samples = image.movedim(-1, 1)
    total = int(total_pixels)
    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
    if scale_by >= 1:
        return image
    width = round(samples.shape[3] * scale_by)
    height = round(samples.shape[2] * scale_by)

    s = common_upscale(samples, width, height, "lanczos", "disabled")
    s = s.movedim(1, -1)
    return s
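

# Hedged sketch of the downscale math: a 2048x2048 input against the default
# 1536*1024 pixel budget gives scale_by = sqrt(1572864 / 4194304) ~= 0.612,
# so the output is about 1254x1254. The random tensor is a made-up input.
def _example_downscale() -> torch.Tensor:
    big = torch.rand(1, 2048, 2048, 3)
    return downscale_image_tensor(big)  # shape: approx. [1, 1254, 1254, 3]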


def tensor_to_data_uri(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Converts a tensor image to a Data URI string.

    Args:
        image_tensor: Input torch.Tensor image.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

    Returns:
        Data URI string (e.g., 'data:image/png;base64,...').
    """
    base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
    return f"data:{mime_type};base64,{base64_string}"


def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
    """Converts an audio input to a base64 string."""
    sample_rate: int = audio["sample_rate"]
    waveform: torch.Tensor = audio["waveform"]
    audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
    audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
    audio_bytes = audio_bytes_io.getvalue()
    return base64.b64encode(audio_bytes).decode("utf-8")
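

# Hedged usage sketch: encodes one second of a synthetic 440 Hz sine tone. The
# dict literal stands in for a Comfy AUDIO (Input.Audio) value; requires an
# AAC encoder in the local PyAV/FFmpeg build.
def _example_audio_to_base64() -> str:
    sample_rate = 44100
    t = torch.linspace(0, 1, sample_rate)
    waveform = torch.sin(2 * math.pi * 440.0 * t).reshape(1, 1, -1)  # [1, 1, T]
    return audio_to_base64_string({"waveform": waveform, "sample_rate": sample_rate})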


def video_to_base64_string(
    video: Input.Video,
    container_format: Optional[VideoContainer] = None,
    codec: Optional[VideoCodec] = None,
) -> str:
    """
    Converts a video input to a base64 string.

    Args:
        video: The video input to convert
        container_format: Optional container format to use (defaults to video.container if available)
        codec: Optional codec to use (defaults to video.codec if available)
    """
    video_bytes_io = BytesIO()

    # Use provided format/codec if specified, otherwise use video's own if available
    format_to_use = container_format if container_format is not None else getattr(video, "container", VideoContainer.MP4)
    codec_to_use = codec if codec is not None else getattr(video, "codec", VideoCodec.H264)

    video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
    video_bytes_io.seek(0)
    return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
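

# Hedged usage sketch (illustrative only): `clip` is assumed to be an existing
# Input.Video, e.g. the output of an upstream video-loading node; container
# and codec default to the clip's own when available.
#
#     b64 = video_to_base64_string(clip, VideoContainer.MP4, VideoCodec.H264)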


def audio_ndarray_to_bytesio(
    audio_data_np: np.ndarray,
    sample_rate: int,
    container_format: str = "mp4",
    codec_name: str = "aac",
) -> BytesIO:
    """
    Encodes a (channels, samples) numpy array of float32 audio data into a
    BytesIO object.
    """
    audio_bytes_io = BytesIO()
    with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
        audio_stream = output_container.add_stream(codec_name, rate=sample_rate)

        frame = av.AudioFrame.from_ndarray(
            audio_data_np,
            format="fltp",
            layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
        )
        frame.sample_rate = sample_rate
        frame.pts = 0

        for packet in audio_stream.encode(frame):
            output_container.mux(packet)

        # Flush stream
        for packet in audio_stream.encode(None):
            output_container.mux(packet)

    audio_bytes_io.seek(0)
    return audio_bytes_io


def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
    """
    Prepares an audio waveform for the av library by converting it to a
    contiguous numpy array.

    Args:
        waveform: a tensor of shape (batch, channels, samples) derived from a
            Comfy `AUDIO` type.

    Returns:
        Contiguous numpy array of the audio waveform. If the audio was batched,
        the first item is taken.
    """
    if waveform.ndim != 3:
        raise ValueError("Expected waveform tensor shape (batch, channels, samples)")

    # If batched, keep only the first item
    if waveform.shape[0] > 1:
        waveform = waveform[:1]

    # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
    audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
    if audio_data_np.dtype != np.float32:
        audio_data_np = audio_data_np.astype(np.float32)
    return audio_data_np


def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
    """Encodes an audio input as 320 kbps MP3 and returns it as a BytesIO object."""
    # Drop the batch dimension: (1, channels, samples) -> (channels, samples)
    waveform = audio["waveform"].cpu().squeeze(0)

    output_buffer = BytesIO()
    output_container = av.open(output_buffer, mode="w", format="mp3")

    out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
    out_stream.bit_rate = 320000

    # Interleave to packed samples: (channels, samples) -> (1, samples * channels)
    frame = av.AudioFrame.from_ndarray(
        waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
        format="flt",
        layout="mono" if waveform.shape[0] == 1 else "stereo",
    )
    frame.sample_rate = audio["sample_rate"]
    frame.pts = 0
    output_container.mux(out_stream.encode(frame))

    # Flush encoder and finalize container
    output_container.mux(out_stream.encode(None))
    output_container.close()

    output_buffer.seek(0)
    return output_buffer
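

# Hedged usage sketch for audio_input_to_mp3 (illustrative only; requires a
# libmp3lame encoder in the local FFmpeg build). The dict literal stands in
# for a Comfy AUDIO (Input.Audio) value:
#
#     mp3_buf = audio_input_to_mp3({"waveform": torch.rand(1, 2, 44100),
#                                   "sample_rate": 44100})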


def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
    """
    Returns a new VideoInput object trimmed from the beginning to the specified duration,
    using av to avoid loading the entire video into memory.

    Args:
        video: Input video to trim
        duration_sec: Duration in seconds to keep from the beginning

    Returns:
        VideoFromFile object that owns the output buffer
    """
    output_buffer = BytesIO()

    input_container = None
    output_container = None
    try:
        # Get the stream source - this avoids loading the entire video into
        # memory when the source is already a file path
        input_source = video.get_stream_source()

        # Open containers
        input_container = av.open(input_source, mode="r")
        output_container = av.open(output_buffer, mode="w", format="mp4")

        # Set up output streams for re-encoding
        video_stream = None
        audio_stream = None

        for stream in input_container.streams:
            logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
            if isinstance(stream, av.VideoStream):
                # Create output video stream with same parameters
                video_stream = output_container.add_stream("h264", rate=stream.average_rate)
                video_stream.width = stream.width
                video_stream.height = stream.height
                video_stream.pix_fmt = "yuv420p"
                logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
            elif isinstance(stream, av.AudioStream):
                # Create output audio stream with same parameters
                audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
                audio_stream.sample_rate = stream.sample_rate
                audio_stream.layout = stream.layout
                logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)

        # Calculate target frame count that's divisible by 16
        fps = input_container.streams.video[0].average_rate
        estimated_frames = int(duration_sec * fps)
        target_frames = (estimated_frames // 16) * 16  # Round down to nearest multiple of 16

        if target_frames == 0:
            raise ValueError("Video too short: need at least 16 frames for Moonvalley")

        frame_count = 0
        audio_frame_count = 0

        # Decode and re-encode video frames
        if video_stream:
            for frame in input_container.decode(video=0):
                if frame_count >= target_frames:
                    break

                # Re-encode frame
                for packet in video_stream.encode(frame):
                    output_container.mux(packet)
                frame_count += 1

            # Flush encoder
            for packet in video_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)

        # Decode and re-encode audio frames
        if audio_stream:
            input_container.seek(0)  # Reset to beginning for audio

            for frame in input_container.decode(audio=0):
                if frame.time >= duration_sec:
                    break

                # Re-encode frame
                for packet in audio_stream.encode(frame):
                    output_container.mux(packet)
                audio_frame_count += 1

            # Flush encoder
            for packet in audio_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s audio frames", audio_frame_count)

        # Close containers
        output_container.close()
        input_container.close()

        # Return as VideoFromFile using the buffer
        output_buffer.seek(0)
        return InputImpl.VideoFromFile(output_buffer)

    except Exception as e:
        # Clean up on error
        if input_container is not None:
            input_container.close()
        if output_container is not None:
            output_container.close()
        raise RuntimeError(f"Failed to trim video: {str(e)}") from e
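

# Hedged usage sketch for trim_video (illustrative only): `clip` is assumed to
# be an existing Input.Video, e.g. the output of an upstream video-loading
# node, and is re-encoded to keep just its first two seconds (rounded down to
# a multiple of 16 frames).
#
#     trimmed = trim_video(clip, duration_sec=2.0)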


def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to 32-bit float PCM format. Copied from the nodes_audio.py file."""
    if wav.dtype.is_floating_point:
        return wav
    elif wav.dtype == torch.int16:
        return wav.float() / (2**15)
    elif wav.dtype == torch.int32:
        return wav.float() / (2**31)
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
    """
    Decode any common audio container from bytes using PyAV and return
    a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
    """
    with av.open(BytesIO(audio_bytes)) as af:
        if not af.streams.audio:
            raise ValueError("No audio stream found in response.")

        stream = af.streams.audio[0]
        in_sr = int(stream.codec_context.sample_rate)
        out_sr = in_sr

        frames: list[torch.Tensor] = []
        n_channels = stream.channels or 1

        for frame in af.decode(streams=stream.index):
            arr = frame.to_ndarray()  # shape can be [C, T] or [T, C] or [T]
            buf = torch.from_numpy(arr)
            if buf.ndim == 1:
                buf = buf.unsqueeze(0)  # [T] -> [1, T]
            elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
                buf = buf.transpose(0, 1).contiguous()  # [T, C] -> [C, T]
            elif buf.shape[0] != n_channels:
                buf = buf.reshape(-1, n_channels).t().contiguous()  # fallback to [C, T]
            frames.append(buf)

        if not frames:
            raise ValueError("Decoded zero audio frames.")

        wav = torch.cat(frames, dim=1)  # [C, T]
        wav = _f32_pcm(wav)
        return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}