Spaces:
Paused
Paused
File size: 6,542 Bytes
af748a3 38aaf0e af748a3 38aaf0e af748a3 38aaf0e af748a3 38aaf0e af748a3 38aaf0e af748a3 369b73f af748a3 38aaf0e 7977760 38aaf0e 7977760 f93db06 7977760 ee6d5df af748a3 7977760 af748a3 ee6d5df af748a3 7977760 af748a3 ee6d5df 38aaf0e 206e2a1 369b73f af748a3 ee6d5df af748a3 ee6d5df af748a3 acdc971 af748a3 0e78953 af748a3 38aaf0e 7977760 38aaf0e 7977760 f93db06 38aaf0e af748a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import concurrent.futures
import functools
import typing as tp
import gradio as gr
import numpy as np
from magenta_rt import system, audio as audio_lib
class AudioFade:
    """Handles the cross-fade between consecutive audio chunks.

    Each incoming chunk is faded in over its first ``fade_size`` samples while
    the retained (pre-faded-out) tail of the previous chunk is mixed in; the
    final ``fade_size`` samples of the current chunk are withheld for mixing
    into the next call, so every call returns ``len(chunk) - fade_size``
    samples.

    Args:
        chunk_size: Number of audio samples per predicted frame (current
            SpectroStream models produce 25Hz frames corresponding to 1920
            audio samples at 48kHz).
        num_chunks: Number of audio chunks to fade between.
        stereo: Whether the predicted audio is stereo or mono.
    """

    def __init__(self, chunk_size: int, num_chunks: int, stereo: bool):
        fade_size = chunk_size * num_chunks
        self.fade_size = fade_size
        self.num_chunks = num_chunks
        self.previous_chunk = np.zeros(fade_size)
        # sin^2 ramp rises 0 -> 1; its flip is the matching fade-out, and the
        # two sum to 1 at every sample (constant-amplitude cross-fade).
        self.ramp = np.sin(np.linspace(0, np.pi / 2, fade_size)) ** 2
        if stereo:
            # Add a channel axis so the ramp broadcasts over (samples, channels).
            self.previous_chunk = self.previous_chunk[:, np.newaxis]
            self.ramp = self.ramp[:, np.newaxis]

    def reset(self):
        """Clear the retained tail so the next chunk fades in from silence."""
        self.previous_chunk = np.zeros_like(self.previous_chunk)

    def __call__(self, chunk: np.ndarray) -> np.ndarray:
        """Cross-fade ``chunk`` against the stored tail and return the mix.

        Args:
            chunk: Audio samples; assumes ``len(chunk) >= 2 * fade_size`` so
                the head and tail regions do not overlap — TODO confirm.

        Returns:
            The first ``len(chunk) - fade_size`` samples, cross-faded.
        """
        # Fix: operate on a copy — the previous implementation modified the
        # caller's array in place via the ``*=``/``+=`` below.
        chunk = chunk.copy()
        chunk[: self.fade_size] *= self.ramp
        chunk[: self.fade_size] += self.previous_chunk
        # Pre-apply the fade-out to the tail we hold back for the next call.
        self.previous_chunk = chunk[-self.fade_size:] * np.flip(self.ramp)
        return chunk[: -self.fade_size]
class MagentaRTStreamer:
    """Audio streamer class for our open weights Magenta RT model.

    This class holds a pretrained Magenta RT model, a cross-fade state, a
    generation state and an asynchronous executor to handle the embedding of
    text prompts without interrupting the audio thread.

    Args:
        system: A MagentaRTBase instance.
    """

    def __init__(self, system: system.MagentaRTBase):
        super().__init__()
        self.system = system
        # 1920 samples/frame; fade across a single frame of stereo audio.
        self.fade = AudioFade(chunk_size=1920, num_chunks=1, stereo=True)
        # Opaque generation state threaded through generate_chunk().
        self.state = None
        self.executor = concurrent.futures.ThreadPoolExecutor()

    @property
    def warmup(self):
        """Whether the stream should be warmed up before playback starts."""
        return True

    # NOTE(review): functools.cache on instance methods keys on `self` and
    # keeps the instance alive for the cache's lifetime (ruff B019); tolerable
    # here since one streamer exists per playback session.
    @functools.cache
    def embed_style(self, style: str):
        """Schedule embedding of a text prompt; returns a Future."""
        return self.executor.submit(self.system.embed_style, style)

    @functools.cache
    def embed_audio(self, audio: tuple[float]):
        """Schedule embedding of an audio prompt; returns a Future.

        Args:
            audio: Hashable tuple of samples, presumably at 16kHz given the
                Waveform sample rate below — TODO confirm.
        """
        waveform = audio_lib.Waveform(np.asarray(audio), 16000)
        # NOTE(review): routes through system.embed_style — assumed to accept
        # Waveform inputs as well as text; verify against MagentaRTBase.
        return self.executor.submit(self.system.embed_style, waveform)

    def get_style_embedding(self, prompts: dict, force_wait: bool = False):
        """Return the weight-averaged style embedding for ``prompts``.

        Embedding Futures that have not completed are skipped unless
        ``force_wait`` is True, so the audio thread never blocks on them.

        Args:
            prompts: Mapping of text prompt (or audio samples) to weight.
            force_wait: Block until every requested embedding is ready.

        Returns:
            A float32 vector of shape (768,); zeros if nothing contributed.
        """
        weighted_embedding = np.zeros((768,), dtype=np.float32)
        total_weight = 0.0
        for text_or_audio, weight in prompts.items():
            if not weight:
                continue  # zero-weight prompts contribute nothing
            if isinstance(text_or_audio, np.ndarray):
                # NOTE(review): ndarrays are unhashable and cannot be plain
                # dict keys — confirm which mapping type callers pass here.
                embedding = self.embed_audio(tuple(text_or_audio))
            else:
                if not text_or_audio:
                    continue  # skip empty prompt strings
                embedding = self.embed_style(text_or_audio)
            if force_wait:
                embedding.result()  # block until the Future resolves
            if embedding.done():
                weighted_embedding += embedding.result() * weight
                total_weight += weight
        if total_weight > 0:
            weighted_embedding /= total_weight
        return weighted_embedding

    def on_stream_start(self, prompts: dict):
        """Kick off embedding Futures, then block until they are all ready."""
        self.get_style_embedding(prompts, force_wait=False)
        self.get_style_embedding(prompts, force_wait=True)

    def reset(self):
        """Drop generation/fade state and clear both embedding caches."""
        self.state = None
        self.fade.reset()
        self.embed_style.cache_clear()
        # Fix: the audio-prompt cache was previously never cleared here,
        # leaving it inconsistent with the text-prompt cache.
        self.embed_audio.cache_clear()

    def generate(self, prompts: dict) -> tuple[int, np.ndarray]:
        """Generate the next chunk of audio and cross-fade it.

        Returns:
            (sample_rate, samples) suitable for streaming playback.
        """
        chunk, self.state = self.system.generate_chunk(
            state=self.state,
            style=self.get_style_embedding(prompts),
            seed=None,
        )
        return chunk.sample_rate, self.fade(chunk.samples)

    def stop(self):
        """Shut down the embedding executor, waiting for in-flight work."""
        self.executor.shutdown(wait=True)
# Initial prompt -> weight table shown in the UI and fed to the streamer.
# Entries with weight 0.0 are skipped when computing the style embedding.
prompts = {
    "dark synthesizer": 1,
    "flamenco guitar": 0.7,
    "funky bass": 0.0,
    "ambient pads": 0.0,
    "drum machine": 0.0,
    "vocal chops": 0.0,
}
def update_prompts(prompts_state: dict):
    """Mirror the UI's prompt/weight state into the module-level ``prompts``.

    Mutates the existing dict in place so any code already holding a
    reference to it (e.g. the play() loop) observes the new values.
    """
    prompts.clear()
    prompts.update(prompts_state)
# Playback flag toggled by the Play/Stop handlers; polled by the play() loop.
running = False
# Load the pretrained model eagerly (lazy=False) on GPU at import time.
MRT = system.MagentaRT(tag="large", device="gpu", lazy=False)
def update_state(*args) -> dict:
    """Collect interleaved (prompt, weight) widget values into a config dict.

    ``args`` alternates Textbox and Slider values; pairs with an empty prompt
    or a falsy weight are dropped, and a trailing unpaired value is ignored.
    """
    values = iter(args)
    # zip(values, values) pairs consecutive items and silently drops any
    # unpaired trailing element.
    return {
        prompt: weight
        for prompt, weight in zip(values, values)
        if prompt and weight
    }
def play():
    """Generator streaming (sample_rate, samples) tuples until stopped.

    Builds a fresh streamer, blocks until the style embeddings are ready,
    then yields chunks for as long as the global ``running`` flag is set.
    """
    global running
    streamer = MagentaRTStreamer(MRT)
    streamer.on_stream_start(prompts)
    running = True
    while running:
        print("Generating audio chunk...")
        result = streamer.generate(prompts)
        print("Generated audio chunk...")
        yield result
def stop():
    """Signal the play() loop to exit; no-op with a notice if idle.

    Clears the global ``running`` flag, which the play() generator polls.
    """
    global running
    if not running:
        gr.Info("No audio is currently playing.")
        # Fix: previously fell through and also reported "Audio playback
        # stopped." even though nothing was playing.
        return
    running = False
    gr.Info("Audio playback stopped.")
# --- Gradio UI: streaming player plus per-prompt weight controls. ---
with gr.Blocks() as block:
    gr.Markdown("# Magenta RT Audio Player")
    with gr.Group():
        with gr.Row():
            # Streaming output component fed by the play() generator.
            audio_out = gr.Audio(
                label="Magenta RT",
                streaming=True,
                autoplay=True,
                loop=False,
            )
    with gr.Group():
        gr.Markdown(
            "This app plays audio generated by the Magenta RT model. "
            "You can start and stop the audio playback using the buttons below."
        )
        # Server-side mirror of the current prompt -> weight mapping.
        prompts_state = gr.State({})
        all_components = []
        # One (Textbox, Slider) row per initial prompt.
        for prompt, weight in prompts.items():
            with gr.Row():
                # NOTE(review): `prompt` is rebound here from the prompt
                # string to the Textbox component.
                prompt = gr.Textbox(value=prompt, label="Prompt")
                slider = gr.Slider(value=weight, label="Weight", minimum=0.0, maximum=1.0, step=0.01)
                all_components.extend([prompt, slider])
        # Recompute prompts_state from every widget whenever any one changes
        # (Textboxes fire on submit, Sliders on change).
        for component in all_components:
            if isinstance(component, gr.Textbox):
                component.submit(update_state, inputs=all_components, outputs=prompts_state)
            else:
                component.change(update_state, inputs=all_components, outputs=prompts_state)
        # Push the recomputed state into the module-level prompts dict.
        prompts_state.change(update_prompts, inputs=prompts_state)
    with gr.Group():
        with gr.Row():
            play_button = gr.Button("Play", variant="primary")
            stop_button = gr.Button("Stop", variant="secondary")
    # Wire the transport buttons to the generator / stop handlers.
    play_button.click(play, outputs=audio_out)
    stop_button.click(stop)
block.launch()
|