frascuchon (HF Staff) committed
Commit af748a3 · 1 Parent(s): 6a8e8ce

create basic app

Files changed (2):
  1. app.py +177 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ import concurrent.futures
+ import functools
+ import typing as tp
+
+ import gradio as gr
+ import numpy as np
+
+ from magenta_rt import system, audio as audio_lib
+
+
+ class AudioFade:
+     """Handles the cross-fade between audio chunks.
+
+     Args:
+         chunk_size: Number of audio samples per predicted frame (current
+             SpectroStream models produce 25Hz frames corresponding to 1920 audio
+             samples at 48kHz).
+         num_chunks: Number of audio chunks to fade between.
+         stereo: Whether the predicted audio is stereo or mono.
+     """
+
+     def __init__(self, chunk_size: int, num_chunks: int, stereo: bool):
+         fade_size = chunk_size * num_chunks
+         self.fade_size = fade_size
+         self.num_chunks = num_chunks
+
+         self.previous_chunk = np.zeros(fade_size)
+         # Rising sin^2 ramp; it sums to one with its flipped copy, so the
+         # overlapped region keeps constant amplitude.
+         self.ramp = np.sin(np.linspace(0, np.pi / 2, fade_size)) ** 2
+
+         if stereo:
+             self.previous_chunk = self.previous_chunk[:, np.newaxis]
+             self.ramp = self.ramp[:, np.newaxis]
+
+     def reset(self):
+         self.previous_chunk = np.zeros_like(self.previous_chunk)
+
+     def __call__(self, chunk: np.ndarray) -> np.ndarray:
+         # Fade the head in, add the tail held back from the previous call,
+         # then hold back this chunk's faded-out tail for the next call.
+         chunk[: self.fade_size] *= self.ramp
+         chunk[: self.fade_size] += self.previous_chunk
+         self.previous_chunk = chunk[-self.fade_size:] * np.flip(self.ramp)
+         return chunk[: -self.fade_size]
+
+
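A minimal sketch of the fade contract, using toy sizes and mono audio: with num_chunks=1, each call returns len(chunk) - fade_size samples, mixing the tail held back from the previous call into the head of the new chunk.

    fade = AudioFade(chunk_size=4, num_chunks=1, stereo=False)
    first = fade(np.ones(12))   # 8 samples; the head ramps up from initial silence
    second = fade(np.ones(12))  # 8 samples; the head cross-fades with the held tail of `first`
    assert first.shape == second.shape == (8,)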
+ class MagentaRTStreamer:
+     """Audio streamer class for our open-weights Magenta RT model.
+
+     This class holds a pretrained Magenta RT model, a cross-fade state, a
+     generation state, and an asynchronous executor that embeds text prompts
+     without interrupting the audio thread.
+
+     Args:
+         system: A MagentaRTBase instance.
+     """
+
+     def __init__(self, system: system.MagentaRTBase):
+         super().__init__()
+         self.system = system
+         self.fade = AudioFade(chunk_size=1920, num_chunks=1, stereo=True)
+         self.state = None
+         self.executor = concurrent.futures.ThreadPoolExecutor()
+
+     @property
+     def warmup(self):
+         return True
+
+     @functools.cache
+     def embed_style(self, style: str):
+         # Submit the (slow) style embedding to a worker thread; callers get a
+         # future and can keep generating audio while it resolves.
+         return self.executor.submit(self.system.embed_style, style)
+
+     @functools.cache
+     def embed_audio(self, audio: tuple[float, ...]):
+         waveform = audio_lib.Waveform(np.asarray(audio), 16000)
+         return self.executor.submit(self.system.embed_style, waveform)
+
+     def get_style_embedding(self, force_wait: bool = False):
+         prompts = [
+             ("synthesizer", 1),
+             ("flamenco guitar", 0.7),
+         ]  # Parameterize with your prompts
+
+         # Weighted average of all prompt embeddings that have finished.
+         weighted_embedding = np.zeros((768,), dtype=np.float32)
+         total_weight = 0.0
+         for text_or_audio, weight in prompts:
+             if not weight:
+                 continue
+
+             if isinstance(text_or_audio, np.ndarray):
+                 embedding = self.embed_audio(tuple(text_or_audio))
+             else:
+                 if not text_or_audio:
+                     continue
+                 embedding = self.embed_style(text_or_audio)
+
+             if force_wait:
+                 embedding.result()  # Block until the embedding is ready.
+             if embedding.done():
+                 weighted_embedding += embedding.result() * weight
+                 total_weight += weight
+
+         if total_weight > 0:
+             weighted_embedding /= total_weight
+
+         return weighted_embedding
+
+     def on_stream_start(self):
+         # Kick off the embeddings, then block once so the first generated
+         # chunk already reflects the prompts.
+         self.get_style_embedding(force_wait=False)
+         self.get_style_embedding(force_wait=True)
+
+     def reset(self):
+         self.state = None
+         self.fade.reset()
+         self.embed_style.cache_clear()
+         self.embed_audio.cache_clear()
+
+     def generate(self):
+         chunk, self.state = self.system.generate_chunk(
+             state=self.state,
+             style=self.get_style_embedding(),
+             seed=None,
+             # **ui_params,
+         )
+         chunk = self.fade(chunk.samples)
+         return chunk
+
+     def stop(self):
+         self.executor.shutdown(wait=True)
+
+
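The streamer also composes with consumers other than the Gradio loop below. A minimal headless sketch, reusing the constructor arguments from this file (chunk length and sample rate depend on the loaded model):

    mrt = system.MagentaRT(tag="large", device="gpu", lazy=False)
    s = MagentaRTStreamer(mrt)
    s.on_stream_start()  # block until the prompt embeddings are ready
    audio = np.concatenate([s.generate() for _ in range(4)], axis=0)  # four chunks of cross-faded stereo
    s.stop()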
+ is_stopped = False
+ MRT = system.MagentaRT(tag="large", device="gpu", lazy=False)
+ streamer: tp.Union[MagentaRTStreamer, None] = None
+
+
+ def play():
+     global is_stopped, streamer
+
+     if streamer is not None:
+         gr.Info("Audio is already playing.")
+         return
+
+     is_stopped = False
+     streamer = MagentaRTStreamer(MRT)
+     streamer.on_stream_start()
+
+     while not is_stopped:
+         waveform = streamer.generate()
+         # gr.Audio expects (sample_rate, samples); Magenta RT outputs 48kHz.
+         yield 48000, waveform
+
+
+ def stop():
+     global is_stopped, streamer
+
+     if streamer is None:
+         gr.Info("No audio is currently playing.")
+         return
+
+     is_stopped = True
+     streamer.stop()
+     streamer = None
+
+     gr.Info("Audio playback stopped.")
+
+
+ with gr.Blocks() as block:
+     gr.Markdown("# Magenta RT Audio Player")
+     with gr.Group():
+         with gr.Row():
+             audio_out = gr.Audio(label="Magenta RT", streaming=True, autoplay=True, loop=False)
+             # text_out = gr.Textbox(label="Output Text", placeholder="Generated text will appear here", lines=2)
+
+         with gr.Row():
+             play_button = gr.Button("Play", variant="primary")
+             stop_button = gr.Button("Stop", variant="secondary")
+
+     play_button.click(play, outputs=audio_out)
+     stop_button.click(stop)
+
+ block.launch()
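`play` relies on Gradio's generator streaming: each yielded (sample_rate, samples) pair is appended to the streaming `gr.Audio` output. A self-contained analogue with a synthetic tone (the 440 Hz sine and chunk count are arbitrary choices, assuming a recent Gradio):

    import gradio as gr
    import numpy as np

    def tone():
        sr = 48000
        t = np.arange(sr // 2) / sr  # 0.5 s per chunk
        for _ in range(10):
            yield sr, np.sin(2 * np.pi * 440 * t).astype(np.float32)

    gr.Interface(tone, inputs=None, outputs=gr.Audio(streaming=True, autoplay=True)).launch()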
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ magenta-rt[gpu] @ git+https://github.com/magenta/magenta-realtime.git@main#egg=magenta_rt
+ tf-nightly==2.20.0.dev20250619
+ tensorflow-text-nightly==2.20.0.dev20250316
+ tf-hub-nightly