# Source: Hugging Face Space by frascuchon (HF Staff)
# Commit f93db06 — "fix: update prompts"
import concurrent.futures
import functools
import typing as tp
import gradio as gr
import numpy as np
from magenta_rt import system, audio as audio_lib
class AudioFade:
    """Cross-fades consecutive audio chunks with a raised-sine ramp.

    Args:
        chunk_size: Number of audio samples per predicted frame (current
            SpectroStream models produce 25Hz frames corresponding to 1920
            audio samples at 48kHz).
        num_chunks: Number of audio chunks to fade between.
        stereo: Whether the predicted audio is stereo or mono.
    """

    def __init__(self, chunk_size: int, num_chunks: int, stereo: bool):
        self.num_chunks = num_chunks
        self.fade_size = chunk_size * num_chunks
        # sin^2 rises smoothly from 0 to 1 over the fade window.
        self.ramp = np.sin(np.linspace(0, np.pi / 2, self.fade_size)) ** 2
        self.previous_chunk = np.zeros(self.fade_size)
        if stereo:
            # Add a trailing channel axis so broadcasting works on
            # (samples, channels)-shaped audio.
            self.ramp = self.ramp[:, np.newaxis]
            self.previous_chunk = self.previous_chunk[:, np.newaxis]

    def reset(self):
        """Forget the stored tail (e.g. when a new stream starts)."""
        self.previous_chunk = np.zeros_like(self.previous_chunk)

    def __call__(self, chunk: np.ndarray) -> np.ndarray:
        """Blend `chunk` with the previous tail; mutates `chunk` in place."""
        fade = self.fade_size
        head = chunk[:fade]          # view: in-place ops write into `chunk`
        head *= self.ramp            # fade the new chunk in...
        head += self.previous_chunk  # ...as the previous tail fades out
        # Stash the faded-out tail for the next call.
        self.previous_chunk = np.flip(self.ramp) * chunk[-fade:]
        return chunk[:-fade]
class MagentaRTStreamer:
    """Audio streamer class for our open weights Magenta RT model.

    This class holds a pretrained Magenta RT model, a cross-fade state, a
    generation state and an asynchronous executor to handle the embedding of
    text prompts without interrupting the audio thread.

    Args:
        system: A MagentaRTBase instance.
    """

    def __init__(self, system: "system.MagentaRTBase"):
        super().__init__()
        self.system = system
        # 1920 samples per predicted frame at 48kHz; model output is stereo.
        self.fade = AudioFade(chunk_size=1920, num_chunks=1, stereo=True)
        # Opaque generation state threaded through generate_chunk calls.
        self.state = None
        self.executor = concurrent.futures.ThreadPoolExecutor()

    @property
    def warmup(self):
        # Always request a warmup pass before streaming starts.
        return True

    # NOTE: functools.cache on a method keys on `self` and keeps the instance
    # alive for the cache's lifetime; acceptable here because one streamer
    # instance lives for the duration of a stream.
    @functools.cache
    def embed_style(self, style: str):
        """Schedule embedding of a text prompt; returns a Future."""
        return self.executor.submit(self.system.embed_style, style)

    @functools.cache
    def embed_audio(self, audio: tuple[float, ...]):
        """Schedule embedding of an audio prompt; returns a Future.

        Args:
            audio: Audio samples as a hashable tuple (assumed 16kHz mono —
                TODO confirm against callers).
        """
        waveform = audio_lib.Waveform(np.asarray(audio), 16000)
        # NOTE(review): this relies on system.embed_style accepting a Waveform
        # as well as text — confirm against the MagentaRTBase API.
        return self.executor.submit(self.system.embed_style, waveform)

    def get_style_embedding(self, prompts: dict, force_wait: bool = False):
        """Return the weighted average of the prompt embeddings that are ready.

        Args:
            prompts: Mapping of text prompt (str) or audio samples (ndarray)
                to blend weight.
            force_wait: If True, block on each pending embedding Future.

        Returns:
            A (768,) float32 embedding; all zeros when nothing is ready.
        """
        weighted_embedding = np.zeros((768,), dtype=np.float32)
        total_weight = 0.0
        for text_or_audio, weight in prompts.items():
            if not weight:
                continue
            if isinstance(text_or_audio, np.ndarray):
                # Tuples are hashable, so they can serve as cache keys.
                embedding = self.embed_audio(tuple(text_or_audio))
            else:
                if not text_or_audio:
                    continue
                embedding = self.embed_style(text_or_audio)
            if force_wait:
                embedding.result()
            # Skip embeddings still computing so the audio thread never blocks.
            if embedding.done():
                weighted_embedding += embedding.result() * weight
                total_weight += weight
        if total_weight > 0:
            weighted_embedding /= total_weight
        return weighted_embedding

    def on_stream_start(self, prompts: dict):
        """Kick off prompt embedding, then block until it has completed."""
        self.get_style_embedding(prompts, force_wait=False)
        self.get_style_embedding(prompts, force_wait=True)

    def reset(self):
        """Clear generation, cross-fade and embedding-cache state."""
        self.state = None
        self.fade.reset()
        self.embed_style.cache_clear()
        # Fix: previously only the text-embedding cache was cleared, leaving
        # stale audio embeddings behind after a reset.
        self.embed_audio.cache_clear()

    def generate(self, prompts: dict) -> tuple[int, np.ndarray]:
        """Generate the next cross-faded audio chunk for the given prompts.

        Returns:
            A (sample_rate, samples) tuple suitable for a streaming gr.Audio.
        """
        chunk, self.state = self.system.generate_chunk(
            state=self.state,
            style=self.get_style_embedding(prompts),
            seed=None,
            # **ui_params,
        )
        return chunk.sample_rate, self.fade(chunk.samples)

    def stop(self):
        """Shut down the embedding executor, waiting for pending work."""
        self.executor.shutdown(wait=True)
# Default prompt -> weight mapping shown in the UI at startup.
prompts = {
    "dark synthesizer": 1,
    "flamenco guitar": 0.7,
    "funky bass": 0.0,
    "ambient pads": 0.0,
    "drum machine": 0.0,
    "vocal chops": 0.0,
}


def update_prompts(prompts_state: dict):
    """Sync the module-level `prompts` dict to the latest UI state in place."""
    global prompts
    prompts.clear()
    prompts |= prompts_state
# Streaming control flag flipped by the Play/Stop buttons below.
running = False
# Load the pretrained Magenta RT model eagerly (lazy=False) on the GPU.
MRT = system.MagentaRT(tag="large", device="gpu", lazy=False)
def update_state(*args) -> dict:
    """Build a {prompt: weight} mapping from interleaved UI component values.

    Args:
        *args: Alternating (prompt, weight) values, as emitted by the
            interleaved Textbox/Slider components; a trailing unpaired value
            is ignored, as are empty prompts and zero weights.
    """
    # zip truncates at the shorter slice, dropping any unpaired trailing arg.
    pairs = zip(args[::2], args[1::2])
    return {prompt: weight for prompt, weight in pairs if prompt and weight}
def play():
    """Yield (sample_rate, samples) chunks until `running` is cleared."""
    global running
    streamer = MagentaRTStreamer(MRT)
    # Block until the initial prompt embeddings are ready.
    streamer.on_stream_start(prompts)
    running = True
    while running:
        print("Generating audio chunk...")
        rate, audio = streamer.generate(prompts)
        print("Generated audio chunk...")
        yield rate, audio
def stop():
    """Stop audio playback, notifying the user via the Gradio UI."""
    global running
    if not running:
        gr.Info("No audio is currently playing.")
        # Fix: without this return, the function fell through and also
        # announced "Audio playback stopped." when nothing was playing.
        return
    running = False
    gr.Info("Audio playback stopped.")
# --- Gradio UI wiring ---------------------------------------------------
with gr.Blocks() as block:
    gr.Markdown("# Magenta RT Audio Player")
    with gr.Group():
        with gr.Row():
            # Streaming output fed by the `play` generator's
            # (sample_rate, samples) tuples.
            audio_out = gr.Audio(
                label="Magenta RT",
                streaming=True,
                autoplay=True,
                loop=False,
            )
    with gr.Group():
        gr.Markdown(
            "This app plays audio generated by the Magenta RT model. "
            "You can start and stop the audio playback using the buttons below."
        )
        # Holds the {prompt: weight} mapping assembled by update_state.
        prompts_state = gr.State({})
        # Interleaved [Textbox, Slider, Textbox, Slider, ...]; all of them
        # are passed together as inputs whenever any one changes.
        all_components = []
        for prompt, weight in prompts.items():
            with gr.Row():
                # NOTE(review): the loop variable `prompt` (the dict key) is
                # shadowed here by the Textbox component — works, but a
                # distinct name would be clearer.
                prompt = gr.Textbox(value=prompt, label="Prompt")
                slider = gr.Slider(value=weight, label="Weight", minimum=0.0, maximum=1.0, step=0.01)
                all_components.extend([prompt, slider])
        # Textboxes fire on submit (Enter); sliders fire on change.
        for component in all_components:
            if isinstance(component, gr.Textbox):
                component.submit(update_state, inputs=all_components, outputs=prompts_state)
            else:
                component.change(update_state, inputs=all_components, outputs=prompts_state)
        # Push the new UI state into the module-level `prompts` dict.
        prompts_state.change(update_prompts, inputs=prompts_state)
    with gr.Group():
        with gr.Row():
            play_button = gr.Button("Play", variant="primary")
            stop_button = gr.Button("Stop", variant="secondary")
    # `play` is a generator, so this click streams chunks into audio_out.
    play_button.click(play, outputs=audio_out)
    stop_button.click(stop)
block.launch()