import concurrent.futures
import functools
import typing as tp

import gradio as gr
import numpy as np

from magenta_rt import system, audio as audio_lib


class AudioFade:
  """Handles the cross-fade between audio chunks.

  Args:
    chunk_size: Number of audio samples per predicted frame (current
      SpectroStream models produce 25Hz frames, corresponding to 1920 audio
      samples at 48kHz).
    num_chunks: Number of audio chunks to fade between.
    stereo: Whether the predicted audio is stereo or mono.
  """

  def __init__(self, chunk_size: int, num_chunks: int, stereo: bool):
    fade_size = chunk_size * num_chunks
    self.fade_size = fade_size
    self.num_chunks = num_chunks
    self.previous_chunk = np.zeros(fade_size)
    # sin^2 ramp rising from 0 to 1; its time-reversed mirror is a cos^2
    # fade-out, so the two cross-fade weights always sum to one.
    self.ramp = np.sin(np.linspace(0, np.pi / 2, fade_size)) ** 2
    if stereo:
      self.previous_chunk = self.previous_chunk[:, np.newaxis]
      self.ramp = self.ramp[:, np.newaxis]

  def reset(self):
    self.previous_chunk = np.zeros_like(self.previous_chunk)

  def __call__(self, chunk: np.ndarray) -> np.ndarray:
    # Fade the head of the new chunk in and mix in the faded-out tail saved
    # from the previous chunk; then fade out and store this chunk's tail so
    # it can be mixed into the head of the next one.
    chunk[: self.fade_size] *= self.ramp
    chunk[: self.fade_size] += self.previous_chunk
    self.previous_chunk = chunk[-self.fade_size:] * np.flip(self.ramp)
    return chunk[: -self.fade_size]
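

# Quick sanity check of the fade (illustrative only, not part of the app
# logic): the sin^2 fade-in ramp and its time-reversed mirror (a cos^2
# fade-out) sum to one, so after the initial fade-in from silence,
# constant-amplitude chunks pass through at unity gain.
_fade = AudioFade(chunk_size=4, num_chunks=1, stereo=False)
_fade(np.ones(8))                           # first chunk fades in from silence
assert np.allclose(_fade(np.ones(8)), 1.0)  # steady state: unity gain
del _fade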


class MagentaRTStreamer:
  """Audio streamer class for our open-weights Magenta RT model.

  This class holds a pretrained Magenta RT model, a cross-fade state, a
  generation state, and an asynchronous executor that embeds text prompts
  without interrupting the audio thread.

  Args:
    system: A MagentaRTBase instance.
  """

  def __init__(self, system: system.MagentaRTBase):
    super().__init__()
    self.system = system
    self.fade = AudioFade(chunk_size=1920, num_chunks=1, stereo=True)
    self.state = None
    self.executor = concurrent.futures.ThreadPoolExecutor()

  def warmup(self):
    return True

  @functools.cache
  def embed_style(self, style: str):
    # Cached so repeated calls for the same prompt return the same future;
    # reset() clears the cache via cache_clear().
    return self.executor.submit(self.system.embed_style, style)

  @functools.cache
  def embed_audio(self, audio: tuple[float, ...]):
    # Audio prompts arrive as hashable tuples so they can be cached too;
    # the underlying embed_style also accepts waveforms.
    audio = audio_lib.Waveform(np.asarray(audio), 16000)
    return self.executor.submit(self.system.embed_style, audio)

  def get_style_embedding(self, prompts: dict, force_wait: bool = False):
    """Returns the weighted mean of all finished prompt embeddings.

    Pending futures are skipped (and excluded from the total weight), so the
    audio thread never blocks on embedding unless force_wait is set.
    """
    weighted_embedding = np.zeros((768,), dtype=np.float32)
    total_weight = 0.0
    for text_or_audio, weight in prompts.items():
      if not weight:
        continue
      if isinstance(text_or_audio, np.ndarray):
        embedding = self.embed_audio(tuple(text_or_audio))
      else:
        if not text_or_audio:
          continue
        embedding = self.embed_style(text_or_audio)
      if force_wait:
        embedding.result()
      if embedding.done():
        weighted_embedding += embedding.result() * weight
        total_weight += weight
    if total_weight > 0:
      weighted_embedding /= total_weight
    return weighted_embedding

  def on_stream_start(self, prompts: dict):
    # The first pass submits the embedding jobs; since the embed calls are
    # cached, the second pass reuses the same futures and blocks until the
    # initial style embedding is ready.
    self.get_style_embedding(prompts, force_wait=False)
    self.get_style_embedding(prompts, force_wait=True)

  def reset(self):
    self.state = None
    self.fade.reset()
    self.embed_style.cache_clear()
    self.embed_audio.cache_clear()

  def generate(self, prompts: dict) -> tuple[int, np.ndarray]:
    chunk, self.state = self.system.generate_chunk(
        state=self.state,
        style=self.get_style_embedding(prompts),
        seed=None,
        # **ui_params,
    )
    return chunk.sample_rate, self.fade(chunk.samples)

  def stop(self):
    self.executor.shutdown(wait=True)
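

# Usage sketch (illustrative only; the Gradio app below is the real driver):
#
#   streamer = MagentaRTStreamer(system.MagentaRT(tag="large", device="gpu"))
#   streamer.on_stream_start({"dark synthesizer": 1.0})  # block until embedded
#   for _ in range(4):
#     sample_rate, samples = streamer.generate({"dark synthesizer": 1.0})
#   streamer.stop()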


prompts = {
    "dark synthesizer": 1.0,
    "flamenco guitar": 0.7,
    "funky bass": 0.0,
    "ambient pads": 0.0,
    "drum machine": 0.0,
    "vocal chops": 0.0,
}


def update_prompts(prompts_state: dict):
  global prompts
  prompts.clear()
  prompts.update(prompts_state)


running = False
# Load the pretrained model once at startup.
MRT = system.MagentaRT(tag="large", device="gpu", lazy=False)


def update_state(*args) -> dict:
  """Rebuilds the prompt config from interleaved (prompt, weight) args."""
  new_config = {}
  for i in range(0, len(args), 2):
    if i + 1 >= len(args):
      break
    prompt, weight = args[i], args[i + 1]
    if not prompt or not weight:
      continue
    new_config[prompt] = weight
  return new_config
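
# For example (illustrative): update_state("piano", 0.5, "drums", 0.0, "", 0.3)
# returns {"piano": 0.5}; empty prompts and zero weights are dropped.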


def play():
  global running
  streamer = MagentaRTStreamer(MRT)
  streamer.on_stream_start(prompts)
  running = True
  while running:
    print("Generating audio chunk...")
    sample_rate, samples = streamer.generate(prompts)
    print("Generated audio chunk.")
    yield sample_rate, samples
  streamer.stop()


def stop():
  global running
  if not running:
    gr.Info("No audio is currently playing.")
    return
  running = False
  gr.Info("Audio playback stopped.")


with gr.Blocks() as block:
  gr.Markdown("# Magenta RT Audio Player")
  with gr.Group():
    with gr.Row():
      audio_out = gr.Audio(
          label="Magenta RT",
          streaming=True,
          autoplay=True,
          loop=False,
      )
  with gr.Group():
    gr.Markdown(
        "This app plays audio generated by the Magenta RT model. "
        "You can start and stop the audio playback using the buttons below."
    )
    prompts_state = gr.State({})
    all_components = []
    for prompt, weight in prompts.items():
      with gr.Row():
        textbox = gr.Textbox(value=prompt, label="Prompt")
        slider = gr.Slider(
            value=weight, label="Weight", minimum=0.0, maximum=1.0, step=0.01
        )
        all_components.extend([textbox, slider])
    for component in all_components:
      if isinstance(component, gr.Textbox):
        component.submit(update_state, inputs=all_components, outputs=prompts_state)
      else:
        component.change(update_state, inputs=all_components, outputs=prompts_state)
    prompts_state.change(update_prompts, inputs=prompts_state)
  with gr.Group():
    with gr.Row():
      play_button = gr.Button("Play", variant="primary")
      stop_button = gr.Button("Stop", variant="secondary")
  play_button.click(play, outputs=audio_out)
  stop_button.click(stop)

block.launch()