Spaces: Running on Zero

github-actions[bot] committed · commit 039e024

Sync to HuggingFace Spaces
Files changed:
- .gitattributes +35 -0
- .github/workflows/sync.yml +26 -0
- .gitignore +6 -0
- README.md +31 -0
- app.py +175 -0
- headers.yaml +9 -0
- requirements.txt +6 -0
- youtube.py +42 -0
- zero.py +45 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/sync.yml
ADDED
@@ -0,0 +1,26 @@
+name: Sync to Hugging Face Spaces
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  sync:
+    name: Sync
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+
+      - name: Sync to Hugging Face Spaces
+        uses: JacobLinCool/huggingface-sync@v1
+        with:
+          github: ${{ secrets.GITHUB_TOKEN }}
+          user: jacoblincool # Hugging Face username or organization name
+          space: vocal-separation # Hugging Face space name
+          token: ${{ secrets.HF_TOKEN }} # Hugging Face token
+          configuration: headers.yaml
.gitignore
ADDED
@@ -0,0 +1,6 @@
+.DS_Store
+
+*.wav
+*.mp3
+
+__pycache__/
README.md
ADDED
@@ -0,0 +1,31 @@
+---
+title: Vocal Separation SOTA
+emoji: 🎤
+colorFrom: red
+colorTo: gray
+sdk: gradio
+sdk_version: 4.37.2
+app_file: app.py
+pinned: false
+license: mit
+---
+
+# Vocal Separation SOTA
+
+[](https://huggingface.co/spaces/JacobLinCool/vocal-separation)
+
+This is a demo of SOTA vocal separation models. Upload an audio file and the model will separate the vocals from the background music.
+
+Based on the results of [MDX23](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/leaderboards), the current SOTA model is [BS-RoFormer](https://arxiv.org/abs/2309.02612).
+
+For comparison, you can also try the Mel-RoFormer model (a variant of BS-RoFormer) and the popular HTDemucs FT model.
+
+## Models
+
+- BS-RoFormer
+- Mel-RoFormer
+- HTDemucs FT
+
+> The models are trained by the [UVR project](https://github.com/Anjok07/ultimatevocalremovergui).
+
+> The code of this app is available on [GitHub](https://github.com/JacobLinCool/vocal-separation); any contributions should go there. The Hugging Face Space is force-pushed by GitHub Actions.
app.py
ADDED
@@ -0,0 +1,175 @@
+import os
+from typing import Tuple
+import gradio as gr
+import tempfile
+import numpy as np
+import soundfile as sf
+import librosa
+import matplotlib.pyplot as plt
+from audio_separator.separator import Separator
+from zero import dynGPU
+from youtube import youtube
+
+
+separators = {
+    "BS-RoFormer": Separator(output_dir=tempfile.gettempdir(), output_format="mp3"),
+    "Mel-RoFormer": Separator(output_dir=tempfile.gettempdir(), output_format="mp3"),
+    "HTDemucs-FT": Separator(output_dir=tempfile.gettempdir(), output_format="mp3"),
+}
+
+
+def load():
+    separators["BS-RoFormer"].load_model("model_bs_roformer_ep_317_sdr_12.9755.ckpt")
+    separators["Mel-RoFormer"].load_model(
+        "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"
+    )
+    separators["HTDemucs-FT"].load_model("htdemucs_ft.yaml")
+
+
+# sometimes the network might be down, so we retry a few times
+for _ in range(3):
+    try:
+        load()
+        break
+    except Exception as e:
+        print(e)
+
+
+def merge(outs):
+    print(f"Merging {outs}")
+    bgm = np.sum(np.array([sf.read(out)[0] for out in outs]), axis=0)
+    print(f"Merged shape: {bgm.shape}")
+    tmp_file = os.path.join(tempfile.gettempdir(), f"{outs[0].split('/')[-1]}_merged")
+    sf.write(tmp_file + ".mp3", bgm, 44100)
+    return tmp_file + ".mp3"
+
+
+def measure_duration(audio: str, model: str) -> int:
+    y, sr = librosa.load(audio, sr=44100)
+    return int(librosa.get_duration(y=y, sr=sr) / 3.0)
+
+
+@dynGPU(duration=measure_duration)
+def separate(audio: str, model: str) -> Tuple[str, str]:
+    separator = separators[model]
+    outs = separator.separate(audio)
+    outs = [os.path.join(tempfile.gettempdir(), out) for out in outs]
+    # roformers
+    if len(outs) == 2:
+        return outs[1], outs[0]
+    # demucs
+    if len(outs) == 4:
+        bgm = merge(outs[:3])
+        return outs[3], bgm
+    raise gr.Error("Unknown output format")
+
+
+def from_youtube(url: str, model: str) -> Tuple[str, str, str]:
+    audio = youtube(url)
+    return audio, *separate(audio, model)
+
+
+def plot_spectrogram(audio: str):
+    y, sr = librosa.load(audio, sr=44100)
+    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
+    S_dB = librosa.power_to_db(S, ref=np.max)
+    fig = plt.figure(figsize=(15, 5))
+    librosa.display.specshow(S_dB, sr=sr, x_axis="time", y_axis="mel")
+    plt.colorbar(format="%+2.0f dB")
+    plt.title("Mel-frequency spectrogram")
+    fig.tight_layout()
+    return fig
+
+
+with gr.Blocks() as app:
+    with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
+        README = f.read()
+        # remove yaml front matter
+        blocks = README.split("---")
+        if len(blocks) > 1:
+            README = "---".join(blocks[2:])
+
+    gr.Markdown(README)
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("## Upload an audio file")
+            audio = gr.Audio(label="Upload an audio file", type="filepath")
+        with gr.Column():
+            gr.Markdown(
+                "## or use a YouTube URL\n\nTry something on [The First Take](https://www.youtube.com/@The_FirstTake)?"
+            )
+            yt = gr.Textbox(
+                label="YouTube URL", placeholder="https://www.youtube.com/watch?v=..."
+            )
+            yt_btn = gr.Button("Use this YouTube URL")
+
+    with gr.Row():
+        model = gr.Radio(
+            label="Select a model",
+            choices=[s for s in separators.keys()],
+            value="BS-RoFormer",
+        )
+        btn = gr.Button("Separate", variant="primary")
+
+    with gr.Row():
+        with gr.Column():
+            vocals = gr.Audio(
+                label="Vocals", format="mp3", type="filepath", interactive=False
+            )
+        with gr.Column():
+            bgm = gr.Audio(
+                label="Background", format="mp3", type="filepath", interactive=False
+            )
+
+    with gr.Row():
+        with gr.Column():
+            vocal_spec = gr.Plot(label="Vocal spectrogram")
+        with gr.Column():
+            bgm_spec = gr.Plot(label="Background spectrogram")
+
+    gr.Examples(
+        examples=[
+            # I don't have any good examples, please contribute some!
+            # Suno's generated music seems to have too many artifacts
+        ],
+        inputs=[audio],
+    )
+
+    gr.Markdown(
+        """
+- BS-RoFormer: https://arxiv.org/abs/2309.02612
+- Mel-RoFormer: https://arxiv.org/abs/2310.01809
+"""
+    )
+
+    btn.click(
+        fn=separate,
+        inputs=[audio, model],
+        outputs=[vocals, bgm],
+        api_name="separate",
+    ).success(
+        fn=plot_spectrogram,
+        inputs=[vocals],
+        outputs=[vocal_spec],
+    ).success(
+        fn=plot_spectrogram,
+        inputs=[bgm],
+        outputs=[bgm_spec],
+    )
+
+    yt_btn.click(
+        fn=from_youtube,
+        inputs=[yt, model],
+        outputs=[audio, vocals, bgm],
+    ).success(
+        fn=plot_spectrogram,
+        inputs=[vocals],
+        outputs=[vocal_spec],
+    ).success(
+        fn=plot_spectrogram,
+        inputs=[bgm],
+        outputs=[bgm_spec],
+    )
+
+app.launch(show_error=True)
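Since app.py registers the click handler with `api_name="separate"`, the pipeline is also reachable as an API endpoint. Below is a minimal sketch (not part of this commit) of calling it with `gradio_client`; the endpoint name and the (vocals, background) output order follow app.py, the Space id follows the sync workflow, and the local file path is purely illustrative.

```python
# Sketch: call the Space's /separate endpoint remotely.
# Assumes the gradio_client release that pairs with the pinned gradio 4.37 SDK.
from gradio_client import Client, handle_file

client = Client("JacobLinCool/vocal-separation")

vocals_path, bgm_path = client.predict(
    handle_file("song.mp3"),  # illustrative local audio file
    "BS-RoFormer",            # one of the models defined in `separators`
    api_name="/separate",
)
print("vocals:", vocals_path)
print("background:", bgm_path)
```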
headers.yaml
ADDED
@@ -0,0 +1,9 @@
+title: Vocal Separation SOTA
+emoji: 🎤
+colorFrom: red
+colorTo: gray
+sdk: gradio
+sdk_version: 4.37.2
+app_file: app.py
+pinned: false
+license: mit
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+gradio
+audio-separator[gpu]; sys_platform != 'darwin'
+audio-separator[cpu]; sys_platform == 'darwin'
+yt_dlp
+librosa
+spaces
youtube.py
ADDED
@@ -0,0 +1,42 @@
+import os
+import gradio as gr
+from gradio_client import Client
+import yt_dlp
+import tempfile
+import hashlib
+import shutil
+
+
+def youtube(url: str) -> str:
+    if not url:
+        raise gr.Error("Please input a YouTube URL")
+
+    hash = hashlib.md5(url.encode()).hexdigest()
+    tmp_file = os.path.join(tempfile.gettempdir(), f"{hash}")
+
+    try:
+        ydl_opts = {
+            "format": "bestaudio/best",
+            "outtmpl": tmp_file,
+            "postprocessors": [
+                {
+                    "key": "FFmpegExtractAudio",
+                    "preferredcodec": "mp3",
+                    "preferredquality": "192",
+                }
+            ],
+        }
+
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            ydl.download([url])
+    except Exception as e:
+        print(e)
+        try:
+            ytdl = Client("JacobLinCool/yt-dlp")
+            file = ytdl.predict(api_name="/download", url=url)
+            shutil.move(file, tmp_file + ".mp3")
+        except Exception as e:
+            print(e)
+            raise gr.Error(f"Failed to download YouTube audio from {url}")
+
+    return tmp_file + ".mp3"
zero.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Callable
+from functools import partial
+import gradio as gr
+import spaces
+import spaces.config
+from spaces.zero.decorator import P, R
+
+
+def _dynGPU(
+    fn: Callable[P, R] | None, duration: Callable[P, int], min=30, max=300, step=10
+) -> Callable[P, R]:
+    if not spaces.config.Config.zero_gpu:
+        return fn
+
+    funcs = [
+        (t, spaces.GPU(duration=t)(lambda *args, **kwargs: fn(*args, **kwargs)))
+        for t in range(min, max + 1, step)
+    ]
+
+    def wrapper(*args, **kwargs):
+        requirement = duration(*args, **kwargs)
+
+        # find the function that satisfies the duration requirement
+        for t, func in funcs:
+            if t >= requirement:
+                gr.Info(f"Acquiring ZeroGPU for {t} seconds")
+                return func(*args, **kwargs)
+
+        # if no function satisfies it, use the last (longest) one
+        gr.Info(f"Acquiring ZeroGPU for {funcs[-1][0]} seconds")
+        return funcs[-1][1](*args, **kwargs)
+
+    return wrapper
+
+
+def dynGPU(
+    fn: Callable[P, R] | None = None,
+    duration: Callable[P, int] = lambda: 60,
+    min=30,
+    max=300,
+    step=10,
+) -> Callable[P, R]:
+    if fn is None:
+        return partial(_dynGPU, duration=duration, min=min, max=max, step=step)
+    return _dynGPU(fn, duration, min, max, step)