import concurrent.futures
import functools
import typing as tp

import gradio as gr
import numpy as np

from magenta_rt import system, audio as audio_lib


class AudioFade:
    """Handles the cross-fade between audio chunks.

    Args:
      chunk_size: Number of audio samples per predicted frame (current
        SpectroStream models produce 25 Hz frames, i.e. 1920 audio samples
        at 48 kHz).
      num_chunks: Number of audio chunks to fade between.
      stereo: Whether the predicted audio is stereo or mono.
    """

    def __init__(self, chunk_size: int, num_chunks: int, stereo: bool):
        fade_size = chunk_size * num_chunks
        self.fade_size = fade_size
        self.num_chunks = num_chunks

        self.previous_chunk = np.zeros(fade_size)
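        # sin^2 fade-in ramp; its time-reversed copy acts as the matching cos^2
        # fade-out, so overlapping chunk boundaries sum back to unity gain.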
        self.ramp = np.sin(np.linspace(0, np.pi / 2, fade_size)) ** 2

        if stereo:
            self.previous_chunk = self.previous_chunk[:, np.newaxis]
            self.ramp = self.ramp[:, np.newaxis]

    def reset(self):
        self.previous_chunk = np.zeros_like(self.previous_chunk)

    def __call__(self, chunk: np.ndarray) -> np.ndarray:
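        # Overlap-add cross-fade: fade in the head of the new chunk, add the stored
        # faded-out tail of the previous chunk, then keep this chunk's tail for the
        # next call (the returned audio is shorter by fade_size samples).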
        chunk[: self.fade_size] *= self.ramp
        chunk[: self.fade_size] += self.previous_chunk
        self.previous_chunk = chunk[-self.fade_size:] * np.flip(self.ramp)
        return chunk[: -self.fade_size]


class MagentaRTStreamer:
    """Audio streamer class for our open weights Magenta RT model.

    This class holds a pretrained Magenta RT model, a cross-fade state, a
    generation state and an asynchronous executor to handle the embedding of text
    prompt without interrupting the audio thread.

    Args:
      system: A MagentaRTBase instance.
    """

    def __init__(self, system: system.MagentaRTBase):
        super().__init__()
        self.system = system
        self.fade = AudioFade(chunk_size=1920, num_chunks=1, stereo=True)
        self.state = None
        self.executor = concurrent.futures.ThreadPoolExecutor()

    @property
    def warmup(self):
        return True

    @functools.cache
    def embed_style(self, style: str):
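        # Memoize one embedding Future per prompt string so repeated prompts are
        # embedded only once, off the audio thread.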
        return self.executor.submit(self.system.embed_style, style)

    @functools.cache
    def embed_audio(self, audio: tuple[float, ...]):
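        # Audio prompts arrive as hashable tuples so functools.cache can memoize
        # them; rebuild the 16 kHz waveform and embed it through the same
        # embed_style entry point.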
        audio = audio_lib.Waveform(np.asarray(audio), 16000)
        return self.executor.submit(self.system.embed_style, audio)

    def get_style_embedding(self, prompts: dict, force_wait: bool = False):
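        """Mixes the embeddings of all active prompts into one weighted average.

        Embedding futures that are still pending are skipped unless force_wait is
        set, so slow text or audio embeddings never stall the audio loop.
        """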

        weighted_embedding = np.zeros((768,), dtype=np.float32)
        total_weight = 0.0
        for text_or_audio, weight in prompts.items():
            if not weight:
                continue

            if isinstance(text_or_audio, np.ndarray):
                embedding = self.embed_audio(tuple(text_or_audio))
            else:
                if not text_or_audio:
                    continue
                embedding = self.embed_style(text_or_audio)

            if force_wait:
                embedding.result()
            if embedding.done():
                weighted_embedding += embedding.result() * weight
                total_weight += weight

        if total_weight > 0:
            weighted_embedding /= total_weight

        return weighted_embedding

    def on_stream_start(self, prompts: dict):
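        # The first pass submits every embedding job to the executor without
        # blocking; the second waits so playback starts with a fully resolved style.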
        self.get_style_embedding(prompts, force_wait=False)
        self.get_style_embedding(prompts, force_wait=True)

    def reset(self):
        self.state = None
        self.fade.reset()
        self.embed_style.cache_clear()

    def generate(self, prompts: dict) -> tuple[int, np.ndarray]:
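        # Advance the generation state by one chunk and cross-fade its boundary
        # with the previous chunk before streaming it out.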
        chunk, self.state = self.system.generate_chunk(
            state=self.state,
            style=self.get_style_embedding(prompts),
            seed=None,
            # **ui_params,
        )

        return chunk.sample_rate, self.fade(chunk.samples)

    def stop(self):
        self.executor.shutdown(wait=True)


prompts = {
    "dark synthesizer": 1,
    "flamenco guitar": 0.7,
    "funky bass": 0.0,
    "ambient pads": 0.0,
    "drum machine": 0.0,
    "vocal chops": 0.0,
}


def update_prompts(prompts_state: dict):
    global prompts

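    # Update the module-level prompts mapping in place; the running play() loop
    # reads it on every generate() call and picks up the new weights.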
    prompts.clear()
    prompts.update(prompts_state)


running = False
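# Load the Magenta RT model once at module import; every playback session below
# reuses this single instance.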
MRT = system.MagentaRT(tag="large", device="gpu", lazy=False)


def update_state(*args) -> dict:
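    # Gradio passes the current values of all_components positionally, interleaved
    # as (textbox value, slider value) pairs, so walk the args two at a time.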
    new_config = {}

    for i in range(0, len(args), 2):
        if i + 1 >= len(args):
            break

        prompt, weight = args[i], args[i + 1]

        if not prompt or not weight:
            continue

        new_config[prompt] = weight

    return new_config


def play():
    global running
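    # play() is a generator: Gradio streams each yielded (sample_rate, samples)
    # pair into the streaming Audio output until stop() clears the flag.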

    streamer = MagentaRTStreamer(MRT)
    streamer.on_stream_start(prompts)

    running = True
    while running:
        print("Generating audio chunk...")

        sample_rate, samples = streamer.generate(prompts)
        print("Generated audio chunk...")

        yield sample_rate, samples


def stop():
    global running

    if not running:
        gr.Info("No audio is currently playing.")
        return

    running = False
    gr.Info("Audio playback stopped.")


with gr.Blocks() as block:
    gr.Markdown("# Magenta RT Audio Player")
    with gr.Group():
        with gr.Row():
            audio_out = gr.Audio(
                label="Magenta RT",
                streaming=True,
                autoplay=True,
                loop=False,
            )

    with gr.Group():
        gr.Markdown(
            "This app plays audio generated by the Magenta RT model. "
            "You can start and stop the audio playback using the buttons below."
        )

        prompts_state = gr.State({})

        all_components = []
        for prompt, weight in prompts.items():
            with gr.Row():
                textbox = gr.Textbox(value=prompt, label="Prompt")
                slider = gr.Slider(
                    value=weight, label="Weight", minimum=0.0, maximum=1.0, step=0.01
                )

                all_components.extend([textbox, slider])

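        # Textboxes rebuild the prompt/weight mapping on submit, sliders on change;
        # both write the result into prompts_state.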
        for component in all_components:
            if isinstance(component, gr.Textbox):
                component.submit(update_state, inputs=all_components, outputs=prompts_state)
            else:
                component.change(update_state, inputs=all_components, outputs=prompts_state)

    prompts_state.change(update_prompts, inputs=prompts_state)

    with gr.Group():
        with gr.Row():
            play_button = gr.Button("Play", variant="primary")
            stop_button = gr.Button("Stop", variant="secondary")

    play_button.click(play, outputs=audio_out)
    stop_button.click(stop)

block.launch()