NariLabs committed on
Commit 1315cad · verified · 1 Parent(s): 8d6458b

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ banner.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ .venv
+ _kyutai
+ __pycache__
+ *.npz
+ *.safetensors
+ *.model
+ *.DS_Store
+ *.parquet
+ *.wav
+ *.mp3
+ weights/
+ *.egg-info/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Nari Labs
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,104 @@
- ---
- title: Dia2 2B
- emoji: 💻
- colorFrom: red
- colorTo: purple
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ![Banner](banner.gif)
+
+ <div align="center">
+ <a href="https://huggingface.co/nari-labs/Dia2-2B"><img src="https://img.shields.io/badge/HF%20Repo-Dia2--2B-orange?style=for-the-badge"></a>
+ <a href="https://discord.gg/bJq6vjRRKv"><img src="https://img.shields.io/badge/Discord-Join%20Chat-7289DA?logo=discord&style=for-the-badge"></a>
+ <a href="https://github.com/nari-labs/dia2/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg?style=for-the-badge"></a>
+ </div>
+
+ **Dia2** is a **streaming dialogue TTS model** created by Nari Labs.
+
+ The model does not need the entire text before it produces audio: it can start generating as soon as the first few words are given as input. You can also condition the output on audio, enabling natural conversations in real time.
+
+ We provide model checkpoints (1B, 2B) and inference code to accelerate research. The model supports up to 2 minutes of generation, in English only.
+
+ ⚠️ Quality and voices vary per generation, as the model is not fine-tuned on a specific voice. Use a prefix or fine-tune the model to obtain stable output.
+
+ ## Upcoming
+
+ - Bonsai (JAX) implementation
+ - Dia2 TTS Server: real streaming support
+ - Sori: Dia2-powered speech-to-speech engine written in Rust
+
+ ## Quickstart
+
+ > **Requirement** — install [uv](https://docs.astral.sh/uv/) and use CUDA 12.8+
+ > drivers. All commands below run through `uv run …` as a rule.
+
+ 1. **Install dependencies (one-time):**
+    ```bash
+    uv sync
+    ```
+ 2. **Prepare a script:** edit `input.txt` using `[S1]` / `[S2]` speaker tags (a sample script follows this list).
+ 3. **Generate audio:**
+    ```bash
+    uv run -m dia2.cli \
+        --hf nari-labs/Dia2-2B \
+        --input input.txt \
+        --cfg 6.0 --temperature 0.8 \
+        --cuda-graph --verbose \
+        output.wav
+    ```
+    The first run downloads the weights, tokenizer, and Mimi codec. The CLI auto-selects CUDA when available (otherwise CPU) and defaults to bfloat16 precision; override with `--device` / `--dtype` if needed.
+ 4. **Conditional generation (recommended for stable use):**
+    ```bash
+    uv run -m dia2.cli \
+        --hf nari-labs/Dia2-2B \
+        --input input.txt \
+        --prefix-speaker-1 example_prefix1.wav \
+        --prefix-speaker-2 example_prefix2.wav \
+        --cuda-graph --verbose \
+        output_conditioned.wav
+    ```
+    Condition the generation on previous conversational context to obtain natural output for your speech-to-speech system. For example, place your assistant's voice as prefix speaker 1 and the user's audio input as prefix speaker 2, then generate the response to the user's input.
+
+    Whisper is used to transcribe each prefix file, which takes additional time. We include example prefix files as `example_prefix1.wav` and `example_prefix2.wav` (both were generated by the model).
+ 5. **Gradio for Easy Usage:**
+    ```bash
+    uv run gradio_app.py
+    ```
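+
+ A minimal `input.txt` looks like the following (alternating `[S1]` / `[S2]` speaker tags; the dialogue itself is just an illustration):
+
+ ```
+ [S1] Hey, have you tried the new Dia2 checkpoint yet?
+ [S2] Not yet. Does it really start speaking before the script is finished?
+ [S1] It does. Let me generate a sample right now.
+ ```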
+
+ ### Programmatic Usage
+ ```python
+ from dia2 import Dia2, GenerationConfig, SamplingConfig
+
+ dia = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
+ config = GenerationConfig(
+     cfg_scale=2.0,
+     audio=SamplingConfig(temperature=0.8, top_k=50),
+     use_cuda_graph=True,
+ )
+ result = dia.generate("[S1] Hello Dia2!", config=config, output_wav="hello.wav", verbose=True)
+ ```
+ Generation runs until the runtime config's `max_context_steps` limit (1500 steps, about 2 minutes) or until EOS is detected. `GenerationResult` includes the audio tokens, the waveform tensor, and word timestamps relative to Mimi's ~12.5 Hz frame rate.
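+
+ Prefix conditioning is also available programmatically. The sketch below mirrors the CLI flags above; the keyword arguments follow the bundled `app.py`, and the file paths are placeholders:
+ ```python
+ from dia2 import Dia2, GenerationConfig, SamplingConfig
+
+ dia = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
+ config = GenerationConfig(
+     cfg_scale=6.0,
+     audio=SamplingConfig(temperature=0.8, top_k=50),
+     use_cuda_graph=True,
+ )
+ result = dia.generate(
+     "[S1] Thanks for waiting. [S2] No problem, what did you find?",
+     config=config,
+     prefix_speaker_1="example_prefix1.wav",  # e.g. the assistant's voice
+     prefix_speaker_2="example_prefix2.wav",  # e.g. the user's audio input
+     include_prefix=False,                    # drop the prefix audio from the output
+     output_wav="reply.wav",
+     verbose=True,
+ )
+ for word, seconds in result.timestamps:  # word-level timings in seconds
+     print(f"{seconds:6.2f}s  {word}")
+ ```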
+
+ ## Hugging Face
+
+ | Variant | Repo |
+ | --- | --- |
+ | Dia2-1B | [`nari-labs/Dia2-1B`](https://huggingface.co/nari-labs/Dia2-1B) |
+ | Dia2-2B | [`nari-labs/Dia2-2B`](https://huggingface.co/nari-labs/Dia2-2B) |
+
+ ## License & Attribution
+
+ Licensed under [Apache 2.0](LICENSE). All third-party assets (Kyutai Mimi codec, etc.) retain their original licenses.
+
+ ## Disclaimer
+
+ This project offers a high-fidelity speech generation model intended for research and educational use. The following uses are **strictly forbidden**:
+
+ - **Identity Misuse**: Do not produce audio resembling real individuals without permission.
+ - **Deceptive Content**: Do not use this model to generate misleading content (e.g. fake news).
+ - **Illegal or Malicious Use**: Do not use this model for activities that are illegal or intended to cause harm.
+
+ By using this model, you agree to uphold relevant legal standards and ethical responsibilities. We **are not responsible** for any misuse and firmly oppose any unethical use of this technology.
+
+ ## Acknowledgements
+ - We thank the [TPU Research Cloud](https://sites.research.google/trc/about/) program for providing compute for training.
+ - Our work was heavily inspired by [KyutaiTTS](https://kyutai.org/next/tts) and [Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).
+
+ ---
+ Questions? Join our [Discord](https://discord.gg/bJq6vjRRKv) or open an issue.
app.py ADDED
@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import io
5
+ import os
6
+ from pathlib import Path
7
+ from typing import List, Tuple
8
+
9
+ import gradio as gr
10
+ import torch
11
+
12
+ from dia2 import Dia2, GenerationConfig, SamplingConfig
13
+
14
+ DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
15
+ MAX_TURNS = 10
16
+ INITIAL_TURNS = 2
17
+
18
+ _dia: Dia2 | None = None
19
+
20
+
21
+ def _get_dia() -> Dia2:
22
+ global _dia
23
+ if _dia is None:
24
+ _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
25
+ return _dia
26
+
27
+
28
+ def _concat_script(turn_count: int, turn_values: List[str]) -> str:
29
+ lines: List[str] = []
30
+ for idx in range(min(turn_count, len(turn_values))):
31
+ text = (turn_values[idx] or "").strip()
32
+ if not text:
33
+ continue
34
+ speaker = "[S1]" if idx % 2 == 0 else "[S2]"
35
+ lines.append(f"{speaker} {text}")
36
+ return "\n".join(lines)
37
+
38
+
39
+ EXAMPLES: dict[str, List[str]] = {
40
+ "Intro": [
41
+ "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
42
+ "Sounds exciting. Can you show a sample right now?",
43
+ "Absolutely. (laughs) Just press generate.",
44
+ ],
45
+ "Customer Support": [
46
+ "Thanks for calling. How can I help you today?",
47
+ "My parcel never arrived and it's been two weeks.",
48
+ "I'm sorry about that. Let me check your tracking number.",
49
+ "Appreciate it. I really need that package soon.",
50
+ ],
51
+ }
52
+
53
+
54
+ def _apply_turn_visibility(count: int) -> List[gr.Update]:
55
+ return [gr.update(visible=i < count) for i in range(MAX_TURNS)]
56
+
57
+
58
+ def _add_turn(count: int):
59
+ count = min(count + 1, MAX_TURNS)
60
+ return (count, *_apply_turn_visibility(count))
61
+
62
+
63
+ def _remove_turn(count: int):
64
+ count = max(1, count - 1)
65
+ return (count, *_apply_turn_visibility(count))
66
+
67
+
68
+ def _load_example(name: str, count: int):
69
+ data = EXAMPLES.get(name)
70
+ if not data:
71
+ return (count, *_apply_turn_visibility(count))
72
+ new_count = min(len(data), MAX_TURNS)
73
+ updates: List[gr.Update] = []
74
+ for idx in range(MAX_TURNS):
75
+ if idx < new_count:
76
+ updates.append(gr.update(value=data[idx], visible=True))
77
+ else:
78
+ updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
79
+ return (new_count, *updates)
80
+
81
+
82
+ def _prepare_prefix(file_path: str | None) -> str | None:
83
+ if not file_path:
84
+ return None
85
+ path = Path(file_path)
86
+ if not path.exists():
87
+ return None
88
+ return str(path)
89
+
90
+
91
+ def generate_audio(
92
+ turn_count: int,
93
+ *inputs,
94
+ ):
95
+ turn_values = list(inputs[:MAX_TURNS])
96
+ voice_s1 = inputs[MAX_TURNS]
97
+ voice_s2 = inputs[MAX_TURNS + 1]
98
+ cfg_scale = float(inputs[MAX_TURNS + 2])
99
+ text_temperature = float(inputs[MAX_TURNS + 3])
100
+ audio_temperature = float(inputs[MAX_TURNS + 4])
101
+ text_top_k = int(inputs[MAX_TURNS + 5])
102
+ audio_top_k = int(inputs[MAX_TURNS + 6])
103
+ include_prefix = bool(inputs[MAX_TURNS + 7])
104
+
105
+ script = _concat_script(turn_count, turn_values)
106
+ if not script.strip():
107
+ raise gr.Error("Please enter at least one non-empty speaker turn.")
108
+
109
+ dia = _get_dia()
110
+ config = GenerationConfig(
111
+ cfg_scale=cfg_scale,
112
+ text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
113
+ audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
114
+ use_cuda_graph=True,
115
+ )
116
+ kwargs = {
117
+ "prefix_speaker_1": _prepare_prefix(voice_s1),
118
+ "prefix_speaker_2": _prepare_prefix(voice_s2),
119
+ "include_prefix": include_prefix,
120
+ }
121
+ buffer = io.StringIO()
122
+ with contextlib.redirect_stdout(buffer):
123
+ result = dia.generate(
124
+ script,
125
+ config=config,
126
+ output_wav=None,
127
+ verbose=True,
128
+ **kwargs,
129
+ )
130
+ waveform = result.waveform.detach().cpu().numpy()
131
+ sample_rate = result.sample_rate
132
+ timestamps = result.timestamps
133
+ log_text = buffer.getvalue().strip()
134
+ table = [[w, round(t, 3)] for w, t in timestamps]
135
+ return (sample_rate, waveform), table, log_text or "Generation finished."
136
+
137
+
138
+ def build_interface() -> gr.Blocks:
139
+ with gr.Blocks(
140
+ title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
141
+ ) as demo:
142
+ gr.Markdown(
143
+ """## Dia2 — Open TTS Model
144
+ Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
145
+ )
146
+ turn_state = gr.State(INITIAL_TURNS)
147
+ with gr.Row(equal_height=True):
148
+ example_dropdown = gr.Dropdown(
149
+ choices=["(select example)"] + list(EXAMPLES.keys()),
150
+ label="Examples",
151
+ value="(select example)",
152
+ )
153
+ with gr.Row(equal_height=True):
154
+ with gr.Column(scale=1):
155
+ with gr.Group():
156
+ gr.Markdown("### Script")
157
+ controls = []
158
+ for idx in range(MAX_TURNS):
159
+ speaker = "[S1]" if idx % 2 == 0 else "[S2]"
160
+ box = gr.Textbox(
161
+ label=f"{speaker} turn {idx + 1}",
162
+ lines=2,
163
+ elem_classes=["compact-turn"],
164
+ placeholder=f"Enter dialogue for {speaker}…",
165
+ visible=idx < INITIAL_TURNS,
166
+ )
167
+ controls.append(box)
168
+ with gr.Row():
169
+ add_btn = gr.Button("Add Turn")
170
+ remove_btn = gr.Button("Remove Turn")
171
+ with gr.Group():
172
+ gr.Markdown("### Voice Prompts")
173
+ with gr.Row():
174
+ voice_s1 = gr.File(
175
+ label="[S1] voice (wav/mp3)", type="filepath"
176
+ )
177
+ voice_s2 = gr.File(
178
+ label="[S2] voice (wav/mp3)", type="filepath"
179
+ )
180
+ with gr.Group():
181
+ gr.Markdown("### Sampling")
182
+ cfg_scale = gr.Slider(
183
+ 1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
184
+ )
185
+ with gr.Group():
186
+ gr.Markdown("#### Text Sampling")
187
+ text_temperature = gr.Slider(
188
+ 0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
189
+ )
190
+ text_top_k = gr.Slider(
191
+ 1, 200, value=50, step=1, label="Text Top-K"
192
+ )
193
+ with gr.Group():
194
+ gr.Markdown("#### Audio Sampling")
195
+ audio_temperature = gr.Slider(
196
+ 0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
197
+ )
198
+ audio_top_k = gr.Slider(
199
+ 1, 200, value=50, step=1, label="Audio Top-K"
200
+ )
201
+ include_prefix = gr.Checkbox(
202
+ label="Keep prefix audio in output", value=False
203
+ )
204
+ generate_btn = gr.Button("Generate", variant="primary")
205
+ with gr.Column(scale=1):
206
+ gr.Markdown("### Output")
207
+ audio_out = gr.Audio(label="Waveform", interactive=False)
208
+ timestamps = gr.Dataframe(
209
+ headers=["word", "seconds"], label="Timestamps"
210
+ )
211
+ log_box = gr.Textbox(label="Logs", lines=8)
212
+
213
+ add_btn.click(
214
+ lambda c: _add_turn(c),
215
+ inputs=turn_state,
216
+ outputs=[turn_state, *controls],
217
+ )
218
+ remove_btn.click(
219
+ lambda c: _remove_turn(c),
220
+ inputs=turn_state,
221
+ outputs=[turn_state, *controls],
222
+ )
223
+ example_dropdown.change(
224
+ lambda name, c: _load_example(name, c),
225
+ inputs=[example_dropdown, turn_state],
226
+ outputs=[turn_state, *controls],
227
+ )
228
+
229
+ generate_btn.click(
230
+ generate_audio,
231
+ inputs=[
232
+ turn_state,
233
+ *controls,
234
+ voice_s1,
235
+ voice_s2,
236
+ cfg_scale,
237
+ text_temperature,
238
+ audio_temperature,
239
+ text_top_k,
240
+ audio_top_k,
241
+ include_prefix,
242
+ ],
243
+ outputs=[audio_out, timestamps, log_box],
244
+ )
245
+ return demo
246
+
247
+
248
+ if __name__ == "__main__":
249
+ app = build_interface()
250
+ app.queue(default_concurrency_limit=1)
251
+ app.launch(share=True)
banner.gif ADDED

Git LFS Details

  • SHA256: 43198c9f6216a8884031b8875aea5efb8822f1ea188ee4e059512001cc837d11
  • Pointer size: 132 Bytes
  • Size of remote file: 2.01 MB
dia2/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from .config import DiaConfig, load_config
+ from .core.model import Dia2Model
+ from .engine import Dia2
+ from .generation import (
+     GenerationConfig,
+     GenerationResult,
+     PrefixConfig,
+     SamplingConfig,
+ )
+
+ __all__ = [
+     "DiaConfig",
+     "Dia2Model",
+     "load_config",
+     "GenerationConfig",
+     "GenerationResult",
+     "PrefixConfig",
+     "SamplingConfig",
+     "Dia2",
+ ]
dia2/assets.py ADDED
@@ -0,0 +1,65 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class AssetBundle:
16
+ config_path: str
17
+ weights_path: str
18
+ tokenizer_id: Optional[str]
19
+ mimi_id: Optional[str]
20
+ repo_id: Optional[str]
21
+
22
+
23
+ def resolve_assets(
24
+ *,
25
+ repo: Optional[str],
26
+ config_path: Optional[str | Path],
27
+ weights_path: Optional[str | Path],
28
+ manifest_name: Optional[str] = None,
29
+ ) -> AssetBundle:
30
+ repo_id = repo
31
+ manifest_name = manifest_name or ASSET_MANIFEST
32
+ if repo_id and (config_path or weights_path):
33
+ raise ValueError("Provide either repo or config+weights, not both")
34
+ if config_path is None or weights_path is None:
35
+ if repo_id is None:
36
+ raise ValueError("Must specify repo or config+weights")
37
+ manifest = load_manifest(repo_id, manifest_name)
38
+ config_name = manifest.get("config", "config.json")
39
+ weights_name = manifest.get("weights", "model.safetensors")
40
+ config_local = hf_hub_download(repo_id, config_name)
41
+ weights_local = hf_hub_download(repo_id, weights_name)
42
+ return AssetBundle(
43
+ config_path=config_local,
44
+ weights_path=weights_local,
45
+ tokenizer_id=manifest.get("tokenizer") or repo_id,
46
+ mimi_id=manifest.get("mimi"),
47
+ repo_id=repo_id,
48
+ )
49
+ return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)
50
+
51
+
52
+ def load_manifest(repo_id: str, manifest_name: str) -> dict:
53
+ if not manifest_name:
54
+ return {}
55
+ try:
56
+ path = hf_hub_download(repo_id, manifest_name)
57
+ except Exception:
58
+ return {}
59
+ try:
60
+ return json.loads(Path(path).read_text())
61
+ except json.JSONDecodeError:
62
+ return {}
63
+
64
+
65
+ __all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]
dia2/audio/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .codec import MimiCodec, DEFAULT_MIMI_MODEL_ID, MimiConfig
+ from .grid import delay_frames, undelay_frames, mask_audio_logits, fill_audio_channels, write_wav
+
+ __all__ = [
+     "MimiCodec",
+     "DEFAULT_MIMI_MODEL_ID",
+     "MimiConfig",
+     "delay_frames",
+     "undelay_frames",
+     "mask_audio_logits",
+     "fill_audio_channels",
+     "write_wav",
+ ]
dia2/audio/codec.py ADDED
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+ from transformers import MimiModel
9
+
10
+
11
+ DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class MimiConfig:
16
+ model_id: str = DEFAULT_MIMI_MODEL_ID
17
+ dtype: Optional[torch.dtype] = None
18
+
19
+
20
+ class MimiCodec(nn.Module):
21
+ """Thin wrapper around transformers' MimiModel for decoding audio tokens."""
22
+
23
+ def __init__(self, model: MimiModel, device: torch.device) -> None:
24
+ super().__init__()
25
+ self.model = model
26
+ self.device = device
27
+ cfg = getattr(model, "config", None)
28
+ self.sample_rate = getattr(cfg, "sampling_rate", 24000)
29
+ self.frame_rate = getattr(cfg, "frame_rate", 12.5)
30
+ self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0
31
+
32
+ @classmethod
33
+ def from_pretrained(
34
+ cls,
35
+ model_id: str = DEFAULT_MIMI_MODEL_ID,
36
+ *,
37
+ device: torch.device,
38
+ dtype: Optional[torch.dtype] = None,
39
+ ) -> "MimiCodec":
40
+ model = MimiModel.from_pretrained(
41
+ model_id,
42
+ torch_dtype=dtype,
43
+ low_cpu_mem_usage=True,
44
+ )
45
+ model = model.to(device)
46
+ model.eval()
47
+ return cls(model, device)
48
+
49
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
50
+ codes = codes.to(self.device)
51
+ with torch.inference_mode():
52
+ audio, _ = self.model.decode(codes, return_dict=False)
53
+ return torch.clamp(audio, -1.0, 1.0)
54
+
55
+ def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
56
+ audio = audio.to(self.device)
57
+ with torch.inference_mode():
58
+ return self.model.encode(audio, return_dict=return_dict)
dia2/audio/grid.py ADDED
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Sequence
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
11
+ channels, total = aligned.shape
12
+ max_delay = max(delays) if delays else 0
13
+ out = aligned.new_full((channels, total + max_delay), pad_id)
14
+ for idx, delay in enumerate(delays):
15
+ out[idx, delay : delay + total] = aligned[idx]
16
+ return out
17
+
18
+
19
+ def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
20
+ channels, total = delayed.shape
21
+ max_delay = max(delays) if delays else 0
22
+ target = max(0, total - max_delay)
23
+ out = delayed.new_full((channels, target), pad_id)
24
+ for idx, delay in enumerate(delays):
25
+ out[idx] = delayed[idx, delay : delay + target]
26
+ return out
27
+
28
+
29
+ def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor:
30
+ if logits.shape[-1] == 0:
31
+ return logits
32
+ max_idx = logits.shape[-1] - 1
33
+ targets = [idx for idx in (pad_idx, bos_idx) if 0 <= idx <= max_idx]
34
+ if not targets:
35
+ return logits
36
+ masked = logits.clone()
37
+ neg_inf = torch.finfo(masked.dtype).min
38
+ for idx in targets:
39
+ masked[..., idx] = neg_inf
40
+ return masked
41
+
42
+
43
+ def fill_audio_channels(
44
+ delays: Sequence[int],
45
+ constants,
46
+ step: int,
47
+ step_tokens: torch.Tensor,
48
+ audio_buf: torch.Tensor,
49
+ ) -> None:
50
+ for cb, delay in enumerate(delays):
51
+ idx = step - delay
52
+ in_bounds = idx >= 0 and step < audio_buf.shape[-1]
53
+ if in_bounds:
54
+ step_tokens[:, 2 + cb, 0] = audio_buf[:, cb, step]
55
+ else:
56
+ step_tokens[:, 2 + cb, 0] = constants.audio_bos
57
+
58
+
59
+ def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None:
60
+ path = Path(path)
61
+ path.parent.mkdir(parents=True, exist_ok=True)
62
+ audio = np.clip(audio, -1.0, 1.0)
63
+ pcm16 = (audio * 32767.0).astype(np.int16)
64
+ import wave
65
+
66
+ with wave.open(str(path), "wb") as handle:
67
+ handle.setnchannels(1)
68
+ handle.setsampwidth(2)
69
+ handle.setframerate(sample_rate)
70
+ handle.writeframes(pcm16.tobytes())
71
+
72
+
73
+ __all__ = [
74
+ "delay_frames",
75
+ "undelay_frames",
76
+ "mask_audio_logits",
77
+ "fill_audio_channels",
78
+ "write_wav",
79
+ ]
dia2/cli.py ADDED
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ import torch
6
+
7
+ from .engine import Dia2
8
+ from .generation import (
9
+ build_generation_config,
10
+ load_script_text,
11
+ validate_generation_params,
12
+ )
13
+
14
+
15
+ def main() -> None:
16
+ parser = argparse.ArgumentParser(description="Generate audio with Dia2")
17
+ parser.add_argument("--config", help="Path to config.json (overrides repo lookup)")
18
+ parser.add_argument(
19
+ "--weights", help="Path to model.safetensors (overrides repo lookup)"
20
+ )
21
+ parser.add_argument(
22
+ "--hf",
23
+ required=False,
24
+ help="Hugging Face repo id to download config/weights from (e.g. nari-labs/Dia2-2B)",
25
+ )
26
+ parser.add_argument(
27
+ "--input", default="input.txt", help="Script text file (default: input.txt)"
28
+ )
29
+ parser.add_argument("output", help="Output WAV path")
30
+ parser.add_argument(
31
+ "--device",
32
+ default=None,
33
+ help="Computation device (defaults to cuda if available, else cpu)",
34
+ )
35
+ parser.add_argument(
36
+ "--dtype",
37
+ choices=["auto", "float32", "bfloat16"],
38
+ default="bfloat16",
39
+ help="Computation dtype (default: bfloat16)",
40
+ )
41
+ parser.add_argument("--topk", type=int, default=50)
42
+ parser.add_argument("--temperature", type=float, default=0.8)
43
+ parser.add_argument("--cfg", type=float, default=1.0)
44
+ parser.add_argument("--tokenizer", help="Tokenizer repo or local path override")
45
+ parser.add_argument(
46
+ "--mimi", help="Mimi repo id override (defaults to config/assets)"
47
+ )
48
+ parser.add_argument("--prefix-speaker-1", help="Prefix audio file for speaker 1")
49
+ parser.add_argument("--prefix-speaker-2", help="Prefix audio file for speaker 2")
50
+ parser.add_argument(
51
+ "--include-prefix",
52
+ action="store_true",
53
+ help="Keep prefix audio in the final waveform (default: trimmed)",
54
+ )
55
+ parser.add_argument(
56
+ "--verbose", action="store_true", help="Print generation progress logs"
57
+ )
58
+ parser.add_argument(
59
+ "--cuda-graph",
60
+ action="store_true",
61
+ help="Run generation with CUDA graph capture",
62
+ )
63
+ args = parser.parse_args()
64
+
65
+ device = args.device
66
+ if device is None or device == "auto":
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ dtype = args.dtype or "bfloat16"
69
+
70
+ repo = args.hf
71
+ if repo:
72
+ dia = Dia2(
73
+ repo=repo,
74
+ device=device,
75
+ dtype=dtype,
76
+ tokenizer_id=args.tokenizer,
77
+ mimi_id=args.mimi,
78
+ )
79
+ elif args.config and args.weights:
80
+ dia = Dia2.from_local(
81
+ config_path=args.config,
82
+ weights_path=args.weights,
83
+ device=device,
84
+ dtype=dtype,
85
+ tokenizer_id=args.tokenizer,
86
+ mimi_id=args.mimi,
87
+ )
88
+ else:
89
+ raise ValueError("Provide --hf/--variant or both --config and --weights")
90
+
91
+ script = load_script_text(args.input)
92
+ temperature, top_k, cfg_scale = validate_generation_params(
93
+ temperature=args.temperature,
94
+ top_k=args.topk,
95
+ cfg_scale=args.cfg,
96
+ )
97
+ config = build_generation_config(
98
+ temperature=temperature,
99
+ top_k=top_k,
100
+ cfg_scale=cfg_scale,
101
+ )
102
+ overrides = {}
103
+ if args.cuda_graph:
104
+ overrides["use_cuda_graph"] = True
105
+ if args.prefix_speaker_1:
106
+ overrides["prefix_speaker_1"] = args.prefix_speaker_1
107
+ if args.prefix_speaker_2:
108
+ overrides["prefix_speaker_2"] = args.prefix_speaker_2
109
+ if args.include_prefix:
110
+ overrides["include_prefix"] = True
111
+
112
+ dia.generate(
113
+ script,
114
+ config=config,
115
+ output_wav=args.output,
116
+ verbose=args.verbose,
117
+ **overrides,
118
+ )
119
+
120
+
121
+ if __name__ == "__main__":
122
+ main()
dia2/config.py ADDED
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import List, Optional
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class DataConfig:
11
+ channels: int
12
+ text_vocab_size: int
13
+ audio_vocab_size: int
14
+ action_vocab_size: int
15
+ text_pad_token_id: int
16
+ text_new_word_token_id: int
17
+ text_zero_token_id: int
18
+ audio_pad_token_id: int
19
+ audio_bos_token_id: int
20
+ action_pad_token_id: int
21
+ action_new_word_token_id: int
22
+ delay_pattern: List[int]
23
+ first_word_min_start: int
24
+ max_pad: int
25
+ second_stream_ahead: int
26
+ tokenizer_path: Optional[str] = None
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class DecoderConfig:
31
+ n_layer: int
32
+ n_embd: int
33
+ n_hidden: int
34
+ gqa_query_heads: int
35
+ kv_heads: int
36
+ gqa_head_dim: int
37
+ dropout: float
38
+ low_rank_dim: int | None = None
39
+
40
+
41
+ @dataclass(frozen=True)
42
+ class DepformerConfig:
43
+ n_layer: int
44
+ n_embd: int
45
+ n_hidden: int
46
+ gqa_query_heads: int
47
+ kv_heads: int
48
+ gqa_head_dim: int
49
+ apply_rope: bool
50
+ text_embedding: bool
51
+ mlp_activations: List[str]
52
+
53
+
54
+ @dataclass(frozen=True)
55
+ class LinearHeadConfig:
56
+ mlp_activations: List[str]
57
+
58
+
59
+ @dataclass(frozen=True)
60
+ class ModelConfig:
61
+ decoder: DecoderConfig
62
+ depformer: DepformerConfig
63
+ linear: LinearHeadConfig
64
+ dropout: float
65
+ rope_min_timescale: int
66
+ rope_max_timescale: int
67
+ normalization_layer_epsilon: float
68
+
69
+
70
+ @dataclass(frozen=True)
71
+ class RuntimeConfig:
72
+ weights_schedule: List[int]
73
+ max_context_steps: int
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class AssetsConfig:
78
+ tokenizer: Optional[str]
79
+ mimi: Optional[str]
80
+
81
+
82
+ @dataclass(frozen=True)
83
+ class DiaConfig:
84
+ data: DataConfig
85
+ model: ModelConfig
86
+ runtime: RuntimeConfig
87
+ assets: AssetsConfig
88
+
89
+
90
+ def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
91
+ block = block or {}
92
+ weights_schedule = block.get("weights_schedule")
93
+ if weights_schedule is None:
94
+ audio_channels = max(0, data_cfg.channels - 2)
95
+ weights_schedule = list(range(max(audio_channels - 1, 0)))
96
+ max_context = block.get("max_context_steps", 1500)
97
+ return RuntimeConfig(
98
+ weights_schedule=list(weights_schedule),
99
+ max_context_steps=int(max_context),
100
+ )
101
+
102
+
103
+ def load_config(path: str | Path) -> DiaConfig:
104
+ cfg = json.loads(Path(path).read_text())
105
+ data = cfg["data"]
106
+ model = cfg["model"]
107
+ runtime_cfg_raw = cfg.get("runtime")
108
+ if runtime_cfg_raw is None:
109
+ raise ValueError(f"Config '{path}' is missing a runtime block")
110
+
111
+ decoder_cfg = DecoderConfig(
112
+ n_layer=model["decoder"]["n_layer"],
113
+ n_embd=model["decoder"]["n_embd"],
114
+ n_hidden=model["decoder"]["n_hidden"],
115
+ gqa_query_heads=model["decoder"]["gqa_query_heads"],
116
+ kv_heads=model["decoder"]["kv_heads"],
117
+ gqa_head_dim=model["decoder"]["gqa_head_dim"],
118
+ dropout=model.get("dropout", 0.0),
119
+ low_rank_dim=model["decoder"].get("low_rank_dim"),
120
+ )
121
+
122
+ depformer_cfg = DepformerConfig(
123
+ n_layer=model["depformer"]["n_layer"],
124
+ n_embd=model["depformer"]["n_embd"],
125
+ n_hidden=model["depformer"]["n_hidden"],
126
+ gqa_query_heads=model["depformer"]["gqa_query_heads"],
127
+ kv_heads=model["depformer"]["kv_heads"],
128
+ gqa_head_dim=model["depformer"]["gqa_head_dim"],
129
+ apply_rope=model["depformer"].get("apply_rope", True),
130
+ text_embedding=model["depformer"].get("text_embedding", True),
131
+ mlp_activations=model["depformer"].get("mlp_activations", ["silu", "linear"]),
132
+ )
133
+
134
+ data_cfg = DataConfig(
135
+ channels=data["channels"],
136
+ text_vocab_size=data["text_vocab_size"],
137
+ audio_vocab_size=data["audio_vocab_size"],
138
+ action_vocab_size=data["action_vocab_size"],
139
+ text_pad_token_id=data["text_pad_token_id"],
140
+ text_new_word_token_id=data["text_new_word_token_id"],
141
+ text_zero_token_id=data.get("text_zero_token_id", 7),
142
+ audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
143
+ audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
144
+ action_pad_token_id=data["action_pad_token_id"],
145
+ action_new_word_token_id=data["action_new_word_token_id"],
146
+ delay_pattern=list(data.get("delay_pattern", [])),
147
+ first_word_min_start=data.get("first_word_min_start", 0),
148
+ max_pad=data.get("max_pad", 0),
149
+ second_stream_ahead=data.get("second_stream_ahead", 0),
150
+ tokenizer_path=data.get("tokenizer_path"),
151
+ )
152
+
153
+ runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)
154
+
155
+ linear_cfg = LinearHeadConfig(
156
+ mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
157
+ )
158
+
159
+ model_cfg = ModelConfig(
160
+ decoder=decoder_cfg,
161
+ depformer=depformer_cfg,
162
+ linear=linear_cfg,
163
+ dropout=model.get("dropout", 0.0),
164
+ rope_min_timescale=model.get("rope_min_timescale", 1),
165
+ rope_max_timescale=model.get("rope_max_timescale", 10000),
166
+ normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
167
+ )
168
+
169
+ assets_raw = cfg.get("assets") or {}
170
+ assets_cfg = AssetsConfig(
171
+ tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
172
+ mimi=assets_raw.get("mimi"),
173
+ )
174
+
175
+ return DiaConfig(
176
+ data=data_cfg,
177
+ model=model_cfg,
178
+ runtime=runtime_cfg,
179
+ assets=assets_cfg,
180
+ )
dia2/core/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .model import Dia2Model, DecodeState
+ from .transformer import TransformerDecoder
+ from .depformer import Depformer
+
+ __all__ = [
+     "Dia2Model",
+     "DecodeState",
+     "TransformerDecoder",
+     "Depformer",
+ ]
dia2/core/cache.py ADDED
@@ -0,0 +1,106 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+
9
+ @dataclass
10
+ class CacheSlot:
11
+ keys: torch.Tensor
12
+ values: torch.Tensor
13
+
14
+ def __post_init__(self) -> None:
15
+ self.max_steps = self.keys.shape[2]
16
+ self.head_dim = self.keys.shape[3]
17
+ self.flat_heads = self.keys.shape[0] * self.keys.shape[1]
18
+ device = self.keys.device
19
+ self.length = torch.zeros((), dtype=torch.long, device=device)
20
+ self.positions = torch.arange(self.max_steps, dtype=torch.long, device=device)
21
+
22
+ @classmethod
23
+ def allocate(
24
+ cls,
25
+ *,
26
+ batch_size: int,
27
+ heads: int,
28
+ max_steps: int,
29
+ head_dim: int,
30
+ device: torch.device,
31
+ dtype: torch.dtype,
32
+ ) -> "CacheSlot":
33
+ keys = torch.zeros(batch_size, heads, max_steps, head_dim, device=device, dtype=dtype)
34
+ values = torch.zeros_like(keys)
35
+ return cls(keys, values)
36
+
37
+ def reset(self) -> None:
38
+ self.length.zero_()
39
+
40
+ def write_and_view(
41
+ self,
42
+ key_chunk: torch.Tensor,
43
+ value_chunk: torch.Tensor,
44
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
+ step = key_chunk.shape[2]
46
+ start = self.length
47
+ indices = self.positions[:step] + start
48
+ expanded = indices.unsqueeze(0).expand(self.flat_heads, -1)
49
+
50
+ flat_keys = self.keys.view(self.flat_heads, self.max_steps, self.head_dim)
51
+ flat_values = self.values.view(self.flat_heads, self.max_steps, self.head_dim)
52
+ flat_key_chunk = key_chunk.reshape(self.flat_heads, step, self.head_dim)
53
+ flat_value_chunk = value_chunk.reshape(self.flat_heads, step, self.head_dim)
54
+ scatter_index = expanded.unsqueeze(-1).expand_as(flat_key_chunk)
55
+ flat_keys.scatter_(1, scatter_index, flat_key_chunk)
56
+ flat_values.scatter_(1, scatter_index, flat_value_chunk)
57
+
58
+ self.length.add_(step)
59
+ bool_mask = (self.positions >= self.length).view(1, 1, 1, self.max_steps)
60
+ mask_dtype = self.keys.dtype
61
+ mask_value = torch.finfo(mask_dtype).min
62
+ attn_mask = torch.zeros_like(bool_mask, dtype=mask_dtype)
63
+ attn_mask = attn_mask.masked_fill(bool_mask, mask_value)
64
+ return self.keys, self.values, attn_mask
65
+
66
+
67
+ class KVCache:
68
+ def __init__(self, slots: List[CacheSlot]) -> None:
69
+ self.slots = slots
70
+
71
+ @classmethod
72
+ def allocate(
73
+ cls,
74
+ *,
75
+ num_layers: int,
76
+ batch_size: int,
77
+ heads: int,
78
+ max_steps: int,
79
+ head_dim: int,
80
+ device: torch.device,
81
+ dtype: torch.dtype,
82
+ ) -> "KVCache":
83
+ slots = [
84
+ CacheSlot.allocate(
85
+ batch_size=batch_size,
86
+ heads=heads,
87
+ max_steps=max_steps,
88
+ head_dim=head_dim,
89
+ device=device,
90
+ dtype=dtype,
91
+ )
92
+ for _ in range(num_layers)
93
+ ]
94
+ return cls(slots)
95
+
96
+ def get_slot(self, index: int) -> CacheSlot:
97
+ return self.slots[index]
98
+
99
+ def reset(self) -> None:
100
+ for slot in self.slots:
101
+ slot.reset()
102
+
103
+ clear = reset
104
+
105
+
106
+ __all__ = ["CacheSlot", "KVCache"]
dia2/core/depformer.py ADDED
@@ -0,0 +1,264 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from ..config import DiaConfig
10
+ from .cache import KVCache
11
+ from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
12
+ from .precision import Precision
13
+
14
+
15
+ class ScheduleAttention(nn.Module):
16
+ """Depformer attention that mirrors dia_v2 ScheduleAttention."""
17
+
18
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
19
+ super().__init__()
20
+ dep_cfg = config.model.depformer
21
+ runtime = config.runtime
22
+ self.schedule = runtime.weights_schedule
23
+ self.num_query_heads = dep_cfg.gqa_query_heads
24
+ self.num_kv_heads = dep_cfg.kv_heads
25
+ self.head_dim = dep_cfg.gqa_head_dim
26
+ self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
27
+ self.apply_rope = dep_cfg.apply_rope
28
+ self.used_ids = sorted(set(self.schedule))
29
+ self.compute_dtype = compute_dtype
30
+
31
+ self.in_proj = nn.ModuleDict(
32
+ {
33
+ str(i): nn.Linear(
34
+ dep_cfg.n_embd,
35
+ 3 * self.num_query_heads * self.head_dim,
36
+ bias=False,
37
+ )
38
+ for i in self.used_ids
39
+ }
40
+ )
41
+ self.out_proj = nn.ModuleDict(
42
+ {
43
+ str(i): nn.Linear(
44
+ self.num_query_heads * self.head_dim,
45
+ dep_cfg.n_embd,
46
+ bias=False,
47
+ )
48
+ for i in self.used_ids
49
+ }
50
+ )
51
+ eps = config.model.normalization_layer_epsilon
52
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
53
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
54
+
55
+ if self.apply_rope:
56
+ self.rotary = RotaryEmbedding(
57
+ self.head_dim,
58
+ config.model.rope_min_timescale,
59
+ config.model.rope_max_timescale,
60
+ )
61
+ stage_count = max(len(self.schedule), 1)
62
+ self.register_buffer(
63
+ "stage_positions",
64
+ torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
65
+ persistent=False,
66
+ )
67
+ else:
68
+ self.rotary = None
69
+ self.register_buffer(
70
+ "stage_positions",
71
+ torch.zeros(0, 1, dtype=torch.long),
72
+ persistent=False,
73
+ )
74
+
75
+ def forward_incremental(
76
+ self,
77
+ x_t: torch.Tensor,
78
+ stage_index: int,
79
+ cache_slot,
80
+ ) -> Tuple[torch.Tensor, object]:
81
+ bsz, seq, _ = x_t.shape
82
+ if seq != 1:
83
+ raise ValueError("ScheduleAttention expects seq len 1 during decoding")
84
+ orig_dtype = x_t.dtype
85
+ module_index = self.schedule[stage_index]
86
+ proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
87
+ proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)
88
+
89
+ q_proj = self.q_norm(proj[:, :, 0])
90
+ k_proj = self.k_norm(proj[:, :, 1])
91
+ v_proj = proj[:, :, 2]
92
+
93
+ if self.apply_rope:
94
+ pos_ids = self.stage_positions[stage_index : stage_index + 1]
95
+ if pos_ids.device != x_t.device:
96
+ pos_ids = pos_ids.to(x_t.device)
97
+ q_proj = self.rotary(q_proj, pos_ids)
98
+ k_proj = self.rotary(k_proj, pos_ids)
99
+
100
+ q = q_proj.transpose(1, 2)
101
+ k = k_proj.transpose(1, 2)
102
+ v = v_proj.transpose(1, 2)
103
+
104
+ if cache_slot is not None:
105
+ k, v, attn_mask = cache_slot.write_and_view(k, v)
106
+ else:
107
+ attn_mask = None
108
+
109
+ attn = F.scaled_dot_product_attention(
110
+ q,
111
+ k,
112
+ v,
113
+ scale=1.0,
114
+ attn_mask=attn_mask,
115
+ enable_gqa=self.num_gqa_groups > 1,
116
+ )
117
+ attn = attn.transpose(1, 2).contiguous()
118
+ flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
119
+ out = self.out_proj[str(module_index)](flat.to(torch.float32))
120
+ return out.to(orig_dtype), cache_slot
121
+
122
+
123
+ class DepformerLayer(nn.Module):
124
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
125
+ super().__init__()
126
+ dep_cfg = config.model.depformer
127
+ eps = config.model.normalization_layer_epsilon
128
+ self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
129
+ self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
130
+ self.self_attention = ScheduleAttention(config, compute_dtype)
131
+ self.mlp = Mlp(
132
+ dep_cfg.n_embd,
133
+ dep_cfg.n_hidden,
134
+ compute_dtype,
135
+ tuple(config.model.depformer.mlp_activations),
136
+ )
137
+
138
+ def decode_step(
139
+ self,
140
+ x_t: torch.Tensor,
141
+ stage_index: int,
142
+ cache_slot,
143
+ ) -> Tuple[torch.Tensor, object]:
144
+ residual = x_t
145
+ x_norm = self.pre_norm(x_t)
146
+ sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
147
+ x = residual + sa_out
148
+ residual2 = x
149
+ x_norm2 = self.post_norm(x)
150
+ mlp_out = self.mlp(x_norm2)
151
+ return residual2 + mlp_out, cache_slot
152
+
153
+
154
+ class Depformer(nn.Module):
155
+ def __init__(self, config: DiaConfig, precision: Precision):
156
+ super().__init__()
157
+ self.config = config
158
+ self.precision = precision
159
+ dep_cfg = config.model.depformer
160
+ data_cfg = config.data
161
+ runtime = config.runtime
162
+
163
+ self.num_audio_channels = max(0, data_cfg.channels - 2)
164
+ self.num_depth = max(self.num_audio_channels - 1, 0)
165
+ self.weights_schedule = runtime.weights_schedule
166
+
167
+ self.audio_embeds = nn.ModuleList(
168
+ [nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
169
+ )
170
+ if dep_cfg.text_embedding:
171
+ self.text_embed = MultiStreamEmbedding(
172
+ data_cfg.text_vocab_size,
173
+ dep_cfg.n_embd,
174
+ pad_id=data_cfg.text_pad_token_id,
175
+ output_dtype=precision.compute,
176
+ )
177
+ else:
178
+ self.text_embed = None
179
+
180
+ used_ids = sorted(set(self.weights_schedule))
181
+ self.depformer_in = nn.ModuleDict(
182
+ {
183
+ str(i): nn.Linear(
184
+ config.model.decoder.n_embd,
185
+ dep_cfg.n_embd,
186
+ bias=False,
187
+ )
188
+ for i in used_ids
189
+ }
190
+ )
191
+
192
+ self.layers = nn.ModuleList([DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)])
193
+ self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
194
+ self.logits_dtype = precision.logits
195
+ self.logits = nn.ModuleList(
196
+ [
197
+ nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
198
+ for _ in range(self.num_depth)
199
+ ]
200
+ )
201
+ self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)
202
+
203
+ def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
204
+ heads = self.layers[0].self_attention.num_kv_heads
205
+ head_dim = self.layers[0].self_attention.head_dim
206
+ return KVCache.allocate(
207
+ num_layers=len(self.layers),
208
+ batch_size=batch_size,
209
+ heads=heads,
210
+ max_steps=max_steps,
211
+ head_dim=head_dim,
212
+ device=device,
213
+ dtype=self.precision.compute,
214
+ )
215
+
216
+ def forward_step(
217
+ self,
218
+ prev_audio: torch.Tensor,
219
+ transformer_out: torch.Tensor,
220
+ stage_index: int,
221
+ cache: KVCache,
222
+ main_text: Optional[torch.Tensor],
223
+ second_text: Optional[torch.Tensor],
224
+ ) -> Tuple[torch.Tensor, KVCache]:
225
+ self._validate_inputs(stage_index, cache)
226
+ return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)
227
+
228
+ def _forward_stage(
229
+ self,
230
+ stage_index: int,
231
+ prev_audio: torch.Tensor,
232
+ transformer_out: torch.Tensor,
233
+ cache: KVCache,
234
+ main_text: Optional[torch.Tensor],
235
+ second_text: Optional[torch.Tensor],
236
+ ) -> Tuple[torch.Tensor, KVCache]:
237
+ prev_audio = prev_audio.long()
238
+ weight_idx = self.weights_schedule[stage_index]
239
+ token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
240
+ if stage_index == 0 and self.text_embed is not None:
241
+ if main_text is None or second_text is None:
242
+ raise ValueError("stage 0 requires text tokens")
243
+ token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])
244
+
245
+ dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
246
+ dep_in = dep_in.to(self.precision.compute)
247
+ dep_in = dep_in + token_emb.to(dep_in.dtype)
248
+ x = dep_in
249
+ for idx, layer in enumerate(self.layers):
250
+ slot = cache.get_slot(idx)
251
+ x, _ = layer.decode_step(x, stage_index, slot)
252
+
253
+ hidden = self.norm(x)
254
+ logits = self.logits[stage_index](hidden.to(torch.float32))
255
+ logits = logits.to(self.logits_dtype)
256
+ logits = logits.unsqueeze(1)
257
+ logits = logits[..., : self.audio_vocab_limit]
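+ # Keep only ids below the audio pad/BOS tokens so special tokens are never sampled.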
258
+ return logits, cache
259
+
260
+ def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
261
+ if stage_index < 0 or stage_index >= self.num_depth:
262
+ raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
263
+ if cache is None:
264
+ raise ValueError("depformer cache must be initialized")
dia2/core/layers.py ADDED
@@ -0,0 +1,209 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Sequence, Tuple, Union, List
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class RotaryEmbedding(nn.Module):
13
+ def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
14
+ super().__init__()
15
+ if head_dim % 2 != 0:
16
+ raise ValueError("RoPE dimension must be even")
17
+ half_dim = head_dim // 2
18
+ fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
19
+ timescale = min_timescale * (max_timescale / min_timescale) ** fraction
20
+ inv_freq = 1.0 / timescale
21
+ self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)
22
+
23
+ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
24
+ pos = position_ids.to(self.inv_freq.dtype)
25
+ freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
26
+ emb = torch.cat((freqs, freqs), dim=-1)
27
+ while emb.dim() < x.dim():
28
+ emb = emb.unsqueeze(-2)
29
+ cos = emb.cos().to(x.dtype)
30
+ sin = emb.sin().to(x.dtype)
31
+ x1, x2 = torch.chunk(x, 2, dim=-1)
32
+ rotated = torch.cat((-x2, x1), dim=-1)
33
+ return (x * cos) + (rotated * sin)
34
+
35
+
36
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
37
+ x1 = x[..., ::2]
38
+ x2 = x[..., 1::2]
39
+ return torch.stack((-x2, x1), dim=-1).reshape_as(x)
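+ # Note: _rotate_half uses the interleaved-pair convention, whereas RotaryEmbedding.forward
+ # rotates concatenated halves; the helper is not referenced elsewhere in this module.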
40
+
41
+
42
+ def _get_activation(name: str) -> nn.Module:
43
+ name = name.lower()
44
+ if name in ("silu", "swish", "swiglu"):
45
+ return nn.SiLU()
46
+ if name in ("gelu", "geglu"):
47
+ return nn.GELU()
48
+ if name == "relu":
49
+ return nn.ReLU()
50
+ if name == "linear":
51
+ return nn.Identity()
52
+ raise ValueError(f"Unsupported activation {name}")
53
+
54
+
55
+ @dataclass
56
+ class AttentionShape:
57
+ dim: int
58
+ heads: int
59
+ kv_heads: int
60
+ head_dim: int
61
+ rope_min: int
62
+ rope_max: int
63
+ apply_rope: bool
64
+
65
+
66
+ class Attention(nn.Module):
67
+ """Byte-for-byte port of dia_v2 Attention.forward_incremental."""
68
+
69
+ def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
70
+ super().__init__()
71
+ dec = config.model.decoder
72
+ self.num_query_heads = dec.gqa_query_heads
73
+ self.num_kv_heads = dec.kv_heads
74
+ self.head_dim = dec.gqa_head_dim
75
+ self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
76
+ self.compute_dtype = compute_dtype
77
+ self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
78
+ self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
79
+ self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
80
+ self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
81
+ eps = config.model.normalization_layer_epsilon
82
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
83
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
84
+ self.rotary = RotaryEmbedding(
85
+ self.head_dim,
86
+ config.model.rope_min_timescale,
87
+ config.model.rope_max_timescale,
88
+ )
89
+
90
+ def forward_incremental(
91
+ self,
92
+ x: torch.Tensor,
93
+ pos: Optional[torch.Tensor],
94
+ cache_slot,
95
+ ) -> Tuple[torch.Tensor, object]:
96
+ B, T, _ = x.shape
97
+ if T != 1:
98
+ raise ValueError("Attention expects sequence length 1 during decoding")
99
+ orig_dtype = x.dtype
100
+ q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
101
+ k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
102
+ v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
103
+ q_proj = self.q_norm(q_proj)
104
+ k_proj = self.k_norm(k_proj)
105
+ if pos is not None:
106
+ q_proj = self.rotary(q_proj, pos)
107
+ k_proj = self.rotary(k_proj, pos)
108
+ q = q_proj.transpose(1, 2)
109
+ k = k_proj.transpose(1, 2)
110
+ v = v_proj.transpose(1, 2)
111
+ if cache_slot is not None:
112
+ k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
113
+ else:
114
+ k_cache, v_cache = k, v
115
+ attn_mask = None
116
+ attn = F.scaled_dot_product_attention(
117
+ q,
118
+ k_cache,
119
+ v_cache,
120
+ scale=1.0,
121
+ attn_mask=attn_mask,
122
+ enable_gqa=self.num_gqa_groups > 1,
123
+ )
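+ # scale=1.0 disables SDPA's default 1/sqrt(head_dim) scaling to match the ported
+ # dia_v2 attention (q and k are RMSNorm-ed above).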
124
+ attn = attn.transpose(1, 2).contiguous()
125
+ flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
126
+ out = self.o_proj(flat.to(torch.float32))
127
+ return out.to(orig_dtype), cache_slot
128
+
129
+ def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
130
+ proj = layer(x.to(torch.float32))
131
+ B, T, _ = proj.shape
132
+ proj = proj.view(B, T, heads, self.head_dim)
133
+ return proj.to(self.compute_dtype)
134
+
135
+ def forward(
136
+ self,
137
+ x: torch.Tensor,
138
+ positions: Optional[torch.Tensor],
139
+ cache=None,
140
+ ) -> Tuple[torch.Tensor, object]:
141
+ return self.forward_incremental(x, positions, cache)
142
+
143
+
144
+
145
+ class MultiStreamEmbedding(nn.Module):
146
+ """Port of dia_v2 MultiStreamEmbed."""
147
+
148
+ def __init__(
149
+ self,
150
+ vocab_size: int,
151
+ dim: int,
152
+ pad_id: int,
153
+ *,
154
+ output_dtype: torch.dtype,
155
+ low_rank_dim: Optional[int] = None,
156
+ ) -> None:
157
+ super().__init__()
158
+ self.pad_id = pad_id
159
+ self.dtype = output_dtype
160
+ base_dim = low_rank_dim if low_rank_dim is not None else dim
161
+ self.embedding = nn.Embedding(vocab_size, base_dim)
162
+ self.main_proj = nn.Linear(base_dim, dim, bias=False)
163
+ self.second_proj = nn.Linear(base_dim, dim, bias=False)
164
+
165
+ def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
166
+ main_inputs = main_inputs.long()
167
+ second_inputs = second_inputs.long()
168
+ if self.pad_id is not None:
169
+ second_is_pad = second_inputs == self.pad_id
170
+ else:
171
+ second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
172
+ use_second = ~second_is_pad
173
+ emb_main = self.embedding(main_inputs)
174
+ emb_second = self.embedding(second_inputs)
175
+ out_main = self.main_proj(emb_main.to(torch.float32))
176
+ out_second = self.second_proj(emb_second.to(torch.float32))
177
+ zeros = torch.zeros_like(out_second)
178
+ y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
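+ # The second stream contributes only at non-pad positions; pad slots keep the
+ # main-stream embedding alone.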
179
+ target_dtype = self.dtype if self.dtype is not None else y.dtype
180
+ return y.to(target_dtype)
181
+
182
+
183
+ class Mlp(nn.Module):
184
+ """Port of dia_v2 MlpBlock (two-activation gated MLP)."""
185
+
186
+ def __init__(
187
+ self,
188
+ dim: int,
189
+ hidden: int,
190
+ compute_dtype: torch.dtype,
191
+ activations: Sequence[str],
192
+ ) -> None:
193
+ super().__init__()
194
+ if len(activations) != 2:
195
+ raise ValueError("Mlp expects two activation functions.")
196
+ self.dtype = compute_dtype
197
+ self.hidden = hidden
198
+ self.branch_count = len(activations)
199
+ self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
200
+ self.wo = nn.Linear(hidden, dim, bias=False)
201
+ self.activation_fns = [_get_activation(activations[0]), _get_activation(activations[1])]
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ proj = self.wi(x.to(torch.float32))
205
+ proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
206
+ gate, up = proj.unbind(dim=-2)
207
+ hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
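+ # e.g. activations ("silu", "linear") give a SwiGLU-style gate: silu(gate) * up.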
208
+ out = self.wo(hidden.to(torch.float32))
209
+ return out.to(self.dtype)
dia2/core/model.py ADDED
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from ..config import DiaConfig
9
+ from .cache import KVCache
10
+ from .depformer import Depformer
11
+ from .precision import Precision
12
+ from .transformer import TransformerDecoder
13
+
14
+
15
+ @dataclass
16
+ class DecodeState:
17
+ transformer: KVCache
18
+ depformer: KVCache
19
+
20
+
21
+ class Dia2Model(nn.Module):
22
+ def __init__(self, config: DiaConfig, precision: Precision):
23
+ super().__init__()
24
+ self.config = config
25
+ self.precision = precision
26
+ self.transformer = TransformerDecoder(config, precision)
27
+ self.depformer = Depformer(config, precision)
28
+ self._cast_norms_to_compute()
29
+
30
+ def init_state(self, batch_size: int, device: torch.device, max_steps: int) -> DecodeState:
31
+ transformer_cache = self.transformer.init_cache(batch_size, device, max_steps)
32
+ depformer_cache = self.depformer.init_cache(batch_size, device, self.depformer.num_depth)
33
+ return DecodeState(transformer_cache, depformer_cache)
34
+
35
+ def step_text(
36
+ self,
37
+ tokens: torch.Tensor,
38
+ positions: torch.Tensor,
39
+ state: DecodeState,
40
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
41
+ hidden, action, cb0, cache = self.transformer.forward_step(tokens, positions, state.transformer)
42
+ state.transformer = cache
43
+ return hidden, action, cb0
44
+
45
+ def step_audio_stage(
46
+ self,
47
+ stage_index: int,
48
+ prev_audio: torch.Tensor,
49
+ transformer_hidden: torch.Tensor,
50
+ state: DecodeState,
51
+ main_text: Optional[torch.Tensor],
52
+ second_text: Optional[torch.Tensor],
53
+ ) -> torch.Tensor:
54
+ cache = state.depformer
55
+ logits, new_cache = self.depformer.forward_step(
56
+ prev_audio,
57
+ transformer_hidden,
58
+ stage_index,
59
+ cache,
60
+ main_text,
61
+ second_text,
62
+ )
63
+ state.depformer = new_cache
64
+ return logits
65
+
66
+ def _cast_norms_to_compute(self) -> None:
67
+ """Cast RMSNorm weights/biases to the compute dtype to avoid bf16 warnings."""
68
+ def _convert(module: nn.Module) -> None:
69
+ if isinstance(module, nn.RMSNorm):
70
+ module.to(self.precision.compute)
71
+
72
+ self.apply(_convert)
dia2/core/precision.py ADDED
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+
7
+
8
+ @dataclass(frozen=True)
9
+ class Precision:
10
+ compute: torch.dtype
11
+ logits: torch.dtype
12
+
13
+
14
+ def resolve_precision(kind: str | None, device: torch.device) -> Precision:
15
+ normalized = (kind or "auto").lower()
16
+ if normalized == "auto":
17
+ normalized = "bfloat16" if device.type == "cuda" else "float32"
18
+ if normalized == "bfloat16":
19
+ compute = torch.bfloat16 if device.type == "cuda" else torch.float32
20
+ return Precision(compute=compute, logits=torch.float32)
21
+ if normalized == "float32":
22
+ return Precision(compute=torch.float32, logits=torch.float32)
23
+ raise ValueError(f"Unsupported dtype '{kind}'")
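+ # Illustrative usage (not part of the module): resolve_precision("auto", torch.device("cuda"))
+ # returns Precision(compute=torch.bfloat16, logits=torch.float32); on CPU, both "auto"
+ # and "bfloat16" fall back to float32 compute.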
dia2/core/transformer.py ADDED
@@ -0,0 +1,140 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from ..config import DiaConfig
10
+ from .cache import KVCache
11
+ from .precision import Precision
12
+ from .layers import (
13
+ AttentionShape,
14
+ MultiStreamEmbedding,
15
+ Mlp,
16
+ Attention,
17
+ )
18
+
19
+
20
+ class TransformerDecoder(nn.Module):
21
+ """Inference-time port of dia_v2.model.Transformer."""
22
+
23
+ def __init__(self, config: DiaConfig, precision: Precision):
24
+ super().__init__()
25
+ self.config = config
26
+ self.precision = precision
27
+ data_cfg = config.data
28
+ dec_cfg = config.model.decoder
29
+
30
+ self.audio_embeds = nn.ModuleList(
31
+ [
32
+ nn.Embedding(
33
+ data_cfg.audio_vocab_size,
34
+ dec_cfg.n_embd,
35
+ )
36
+ for _ in range(max(0, data_cfg.channels - 2))
37
+ ]
38
+ )
39
+ self.text_embed = MultiStreamEmbedding(
40
+ data_cfg.text_vocab_size,
41
+ dec_cfg.n_embd,
42
+ pad_id=data_cfg.text_pad_token_id,
43
+ output_dtype=self.precision.compute,
44
+ low_rank_dim=dec_cfg.low_rank_dim,
45
+ )
46
+ self.layers = nn.ModuleList([DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)])
47
+ self.norm = nn.RMSNorm(dec_cfg.n_embd, eps=config.model.normalization_layer_epsilon, dtype=torch.float32)
48
+
49
+ self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
50
+ self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
51
+
52
+ def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
53
+ heads = self.layers[0].attn.num_kv_heads
54
+ head_dim = self.layers[0].attn.head_dim
55
+ return KVCache.allocate(
56
+ num_layers=len(self.layers),
57
+ batch_size=batch_size,
58
+ heads=heads,
59
+ max_steps=max_steps,
60
+ head_dim=head_dim,
61
+ device=device,
62
+ dtype=self.precision.compute,
63
+ )
64
+
65
+ def forward_step(
66
+ self,
67
+ tokens: torch.Tensor,
68
+ positions: torch.Tensor,
69
+ cache: KVCache,
70
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
71
+ if cache is None:
72
+ raise ValueError("Transformer cache must be initialized")
73
+
74
+ B, C, T1 = tokens.shape
75
+ if T1 != 1:
76
+ raise ValueError("forward_step expects sequence length 1")
77
+ num_audio_channels = max(0, C - 2)
78
+
79
+ hidden_t = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
80
+ for idx in range(num_audio_channels):
81
+ audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
82
+ hidden_t.add_(audio_emb)
83
+ hidden_t = hidden_t.to(self.precision.compute)
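+ # hidden_t now holds the text-stream embedding plus every audio-codebook embedding
+ # for this single decoding step.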
84
+
85
+ x = hidden_t
86
+ for idx, layer in enumerate(self.layers):
87
+ slot = cache.get_slot(idx)
88
+ x, _ = layer.decode_step(x, positions, slot)
89
+
90
+ hidden_norm = self.norm(x)
91
+ action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
92
+ cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
93
+ return hidden_norm, action_logits, cb0_logits, cache
94
+
95
+ def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
96
+ B, C, T1 = tokens.shape
97
+ if T1 != 1:
98
+ raise ValueError("_embed expects sequence length 1")
99
+ num_audio_channels = max(0, C - 2)
100
+ text_hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
101
+ audio_terms: list[torch.Tensor] = []
102
+ for idx in range(num_audio_channels):
103
+ audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
104
+ audio_terms.append(audio_emb)
105
+ hidden = text_hidden
106
+ for term in audio_terms:
107
+ hidden = hidden + term
108
+ final = hidden.to(self.precision.compute)
109
+ return final
110
+
111
+
112
+ class DecoderLayer(nn.Module):
113
+ def __init__(self, config: DiaConfig, precision: Precision):
114
+ super().__init__()
115
+ dec = config.model.decoder
116
+ eps = config.model.normalization_layer_epsilon
117
+ self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
118
+ self.attn = Attention(config, dec.n_embd, precision.compute)
119
+ self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
120
+ self.mlp = Mlp(
121
+ dec.n_embd,
122
+ dec.n_hidden,
123
+ precision.compute,
124
+ tuple(config.model.linear.mlp_activations),
125
+ )
126
+
127
+ def decode_step(
128
+ self,
129
+ x: torch.Tensor,
130
+ pos: torch.Tensor,
131
+ cache_slot,
132
+ ) -> Tuple[torch.Tensor, object]:
133
+ residual = x
134
+ x_norm = self.pre_norm(x)
135
+ attn_out, _ = self.attn(x_norm, pos, cache_slot)
136
+ x = residual + attn_out
137
+ residual2 = x
138
+ x_norm2 = self.post_norm(x)
139
+ mlp_out = self.mlp(x_norm2)
140
+ return residual2 + mlp_out, cache_slot
dia2/engine.py ADDED
@@ -0,0 +1,230 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Sequence
5
+
6
+ from .assets import resolve_assets
7
+ from .runtime.context import RuntimeContext, build_runtime
8
+ from .runtime.generator import (
9
+ build_initial_state,
10
+ decode_audio,
11
+ run_generation_loop,
12
+ warmup_with_prefix,
13
+ )
14
+ from .runtime.script_parser import parse_script
15
+ from .audio.grid import undelay_frames, write_wav
16
+ from .runtime.voice_clone import build_prefix_plan
17
+ from .generation import (
18
+ GenerationConfig,
19
+ GenerationResult,
20
+ merge_generation_config,
21
+ normalize_script,
22
+ )
23
+ from .runtime.logger import RuntimeLogger
24
+
25
+ class Dia2:
26
+ def __init__(
27
+ self,
28
+ *,
29
+ repo: Optional[str] = None,
30
+ config_path: Optional[str | Path] = None,
31
+ weights_path: Optional[str | Path] = None,
32
+ tokenizer_id: Optional[str | Path] = None,
33
+ mimi_id: Optional[str] = None,
34
+ device: str = "cuda",
35
+ dtype: str = "auto",
36
+ default_config: Optional[GenerationConfig] = None,
37
+ ) -> None:
38
+ bundle = resolve_assets(
39
+ repo=repo,
40
+ config_path=config_path,
41
+ weights_path=weights_path,
42
+ )
43
+ self._config_path = bundle.config_path
44
+ self._weights_path = bundle.weights_path
45
+ self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
46
+ self._repo_id = bundle.repo_id
47
+ self._mimi_id = mimi_id or bundle.mimi_id
48
+ self.device = device
49
+ self._dtype_pref = dtype or "auto"
50
+ self.default_config = default_config or GenerationConfig()
51
+ self._runtime: Optional[RuntimeContext] = None
52
+
53
+ @classmethod
54
+ def from_repo(
55
+ cls,
56
+ repo: str,
57
+ *,
58
+ device: str = "cuda",
59
+ dtype: str = "auto",
60
+ tokenizer_id: Optional[str] = None,
61
+ mimi_id: Optional[str] = None,
62
+ ) -> "Dia2":
63
+ return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id)
64
+
65
+ @classmethod
66
+ def from_local(
67
+ cls,
68
+ config_path: str | Path,
69
+ weights_path: str | Path,
70
+ *,
71
+ device: str = "cuda",
72
+ dtype: str = "auto",
73
+ tokenizer_id: Optional[str | Path] = None,
74
+ mimi_id: Optional[str] = None,
75
+ ) -> "Dia2":
76
+ return cls(
77
+ config_path=config_path,
78
+ weights_path=weights_path,
79
+ tokenizer_id=tokenizer_id,
80
+ device=device,
81
+ dtype=dtype,
82
+ mimi_id=mimi_id,
83
+ )
84
+
85
+ def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
86
+ desired_dtype = dtype or self._dtype_pref
87
+ if self.device == device and desired_dtype == self._dtype_pref:
88
+ return
89
+ self.device = device
90
+ self._dtype_pref = desired_dtype
91
+ self._runtime = None
92
+
93
+ def close(self) -> None:
94
+ self._runtime = None
95
+
96
+ def _ensure_runtime(self) -> RuntimeContext:
97
+ if self._runtime is None:
98
+ self._runtime = self._build_runtime()
99
+ return self._runtime
100
+
101
+ def generate(
102
+ self,
103
+ script: str | Sequence[str],
104
+ *,
105
+ config: Optional[GenerationConfig] = None,
106
+ output_wav: Optional[str | Path] = None,
107
+ prefix_speaker_1: Optional[str] = None,
108
+ prefix_speaker_2: Optional[str] = None,
109
+ include_prefix: Optional[bool] = None,
110
+ verbose: bool = False,
111
+ **overrides,
112
+ ):
113
+ runtime = self._ensure_runtime()
114
+ logger = RuntimeLogger(verbose)
115
+ merged_overrides = dict(overrides)
116
+ if prefix_speaker_1 is not None:
117
+ merged_overrides["prefix_speaker_1"] = prefix_speaker_1
118
+ if prefix_speaker_2 is not None:
119
+ merged_overrides["prefix_speaker_2"] = prefix_speaker_2
120
+ if include_prefix is not None:
121
+ merged_overrides["include_prefix"] = include_prefix
122
+ merged = merge_generation_config(base=config or self.default_config, overrides=merged_overrides)
123
+ max_context = runtime.config.runtime.max_context_steps
124
+ text = normalize_script(script)
125
+ prefix_plan = build_prefix_plan(runtime, merged.prefix)
126
+ entries = []
127
+ if prefix_plan is not None:
128
+ entries.extend(prefix_plan.entries)
129
+ entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))
130
+ runtime.machine.initial_padding = merged.initial_padding
131
+ logger.event(
132
+ f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
133
+ f"device={self.device} dtype={self._dtype_pref}"
134
+ )
135
+ state = runtime.machine.new_state(entries)
136
+ cfg_active = merged.cfg_scale != 1.0
137
+ if cfg_active:
138
+ logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
139
+ else:
140
+ logger.event("classifier-free guidance disabled (scale=1.0)")
141
+ gen_state = build_initial_state(
142
+ runtime,
143
+ prefix=prefix_plan,
144
+ )
145
+ include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)
146
+ start_step = 0
147
+ if prefix_plan is not None:
148
+ logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
149
+ start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
150
+ if include_prefix_audio:
151
+ logger.event("prefix audio will be kept in output")
152
+ else:
153
+ logger.event("prefix audio trimmed from output")
154
+ first_word_frame, audio_buf = run_generation_loop(
155
+ runtime,
156
+ state=state,
157
+ generation=gen_state,
158
+ config=merged,
159
+ start_step=start_step,
160
+ logger=logger,
161
+ )
162
+ aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
163
+ crop = 0 if include_prefix_audio else max(first_word_frame, 0)
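+ # When prefix audio is not kept, trim everything before the first generated word.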
164
+ if crop > 0 and crop < aligned.shape[-1]:
165
+ aligned = aligned[:, :, crop:]
166
+ elif crop >= aligned.shape[-1]:
167
+ crop = 0
168
+ logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
169
+ waveform = decode_audio(runtime, aligned)
170
+ if output_wav is not None:
171
+ write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
172
+ duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
173
+ logger.event(f"saved {output_wav} ({duration:.2f}s)")
174
+ frame_rate = max(runtime.frame_rate, 1.0)
175
+ prefix_entry_count = len(prefix_plan.entries) if prefix_plan is not None else 0
176
+ transcript_entries = state.transcript
177
+ if prefix_plan is not None and not include_prefix_audio:
178
+ if len(transcript_entries) > prefix_entry_count:
179
+ transcript_entries = transcript_entries[prefix_entry_count:]
180
+ else:
181
+ transcript_entries = []
182
+ timestamps = []
183
+ for word, step in transcript_entries:
184
+ adj = step - crop
185
+ if adj < 0:
186
+ continue
187
+ timestamps.append((word, adj / frame_rate))
188
+ logger.event(f"generation finished in {logger.elapsed():.2f}s")
189
+ return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)
190
+
191
+ def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
192
+ return self.generate(script, output_wav=path, **kwargs)
193
+
194
+ @property
195
+ def sample_rate(self) -> int:
196
+ return self._ensure_runtime().mimi.sample_rate
197
+
198
+ @property
199
+ def tokenizer_id(self) -> Optional[str]:
200
+ if self._tokenizer_id:
201
+ return self._tokenizer_id
202
+ if self._runtime is not None:
203
+ return getattr(self._runtime.tokenizer, "name_or_path", None)
204
+ return self._repo_id
205
+
206
+ @property
207
+ def dtype(self) -> str:
208
+ return self._dtype_pref
209
+
210
+ @property
211
+ def max_context_steps(self) -> int:
212
+ return self._ensure_runtime().config.runtime.max_context_steps
213
+
214
+ @property
215
+ def repo(self) -> Optional[str]:
216
+ return self._repo_id
217
+
218
+ def _build_runtime(self) -> RuntimeContext:
219
+ runtime, tokenizer_ref, mimi_ref = build_runtime(
220
+ config_path=self._config_path,
221
+ weights_path=self._weights_path,
222
+ tokenizer_id=self._tokenizer_id,
223
+ repo_id=self._repo_id,
224
+ mimi_id=self._mimi_id,
225
+ device=self.device,
226
+ dtype_pref=self._dtype_pref,
227
+ )
228
+ self._tokenizer_id = tokenizer_ref
229
+ self._mimi_id = mimi_ref
230
+ return runtime
dia2/generation.py ADDED
@@ -0,0 +1,158 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import List, Mapping, Optional, Sequence, Tuple
7
+
8
+ import torch
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class SamplingConfig:
13
+ temperature: float = 0.8
14
+ top_k: int = 50
15
+
16
+
17
+ def _default_text_sampling() -> SamplingConfig:
18
+ return SamplingConfig(temperature=0.6, top_k=50)
19
+
20
+
21
+ def _default_audio_sampling() -> SamplingConfig:
22
+ return SamplingConfig(temperature=0.8, top_k=50)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class PrefixConfig:
27
+ speaker_1: Optional[str] = None
28
+ speaker_2: Optional[str] = None
29
+ include_audio: bool = False
30
+
31
+
32
+ @dataclass(frozen=True)
33
+ class GenerationConfig:
34
+ text: SamplingConfig = field(default_factory=_default_text_sampling)
35
+ audio: SamplingConfig = field(default_factory=_default_audio_sampling)
36
+ cfg_scale: float = 2.0
37
+ cfg_filter_k: int = 50
38
+ initial_padding: int = 2
39
+ prefix: Optional["PrefixConfig"] = None
40
+ use_cuda_graph: bool = False
41
+
42
+
43
+ @dataclass(frozen=True)
44
+ class GenerationResult:
45
+ audio_tokens: torch.Tensor
46
+ waveform: torch.Tensor
47
+ sample_rate: int
48
+ timestamps: List[Tuple[str, float]]
49
+
50
+
51
+ def normalize_script(script: str | Sequence[str]) -> str:
52
+ if isinstance(script, str):
53
+ return script.strip()
54
+ return "\n".join(line.strip() for line in script)
55
+
56
+
57
+ def load_script_text(path: str | Path) -> str:
58
+ if path == "-":
59
+ return sys.stdin.read().strip()
60
+ path_obj = Path(path)
61
+ if path_obj.exists():
62
+ return path_obj.read_text().strip()
63
+ return str(path).strip()
64
+
65
+
66
+ def validate_generation_params(
67
+ *,
68
+ temperature: float,
69
+ top_k: int,
70
+ cfg_scale: float,
71
+ ) -> tuple[float, int, float]:
72
+ if temperature <= 0:
73
+ raise ValueError("temperature must be positive")
74
+ if top_k <= 0:
75
+ raise ValueError("top_k must be positive")
76
+ if cfg_scale <= 0:
77
+ raise ValueError("cfg_scale must be positive")
78
+ return temperature, top_k, cfg_scale
79
+
80
+
81
+ def build_generation_config(
82
+ *,
83
+ temperature: float,
84
+ top_k: int,
85
+ cfg_scale: float,
86
+ ) -> GenerationConfig:
87
+ sampling = SamplingConfig(temperature=temperature, top_k=top_k)
88
+ return GenerationConfig(
89
+ text=sampling,
90
+ audio=sampling,
91
+ cfg_scale=cfg_scale,
92
+ )
93
+
94
+
95
+ def merge_generation_config(
96
+ *,
97
+ base: GenerationConfig,
98
+ overrides: Mapping[str, object],
99
+ ) -> GenerationConfig:
100
+ clean_overrides = {k: v for k, v in overrides.items() if v is not None}
101
+ text_temp = clean_overrides.pop("temp_text", None)
102
+ text_topk = clean_overrides.pop("topk_text", None)
103
+ audio_temp = clean_overrides.pop("temp_audio", None)
104
+ audio_topk = clean_overrides.pop("topk_audio", None)
105
+ prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
106
+ prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
107
+ include_prefix = clean_overrides.pop("include_prefix", None)
108
+
109
+ text_sampling = base.text
110
+ if text_temp is not None or text_topk is not None:
111
+ text_sampling = SamplingConfig(
112
+ temperature=text_temp if text_temp is not None else text_sampling.temperature,
113
+ top_k=text_topk if text_topk is not None else text_sampling.top_k,
114
+ )
115
+
116
+ audio_sampling = base.audio
117
+ if audio_temp is not None or audio_topk is not None:
118
+ audio_sampling = SamplingConfig(
119
+ temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
120
+ top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
121
+ )
122
+
123
+ prefix_cfg = base.prefix
124
+ if (
125
+ prefix_speaker_1 is not None
126
+ or prefix_speaker_2 is not None
127
+ or include_prefix is not None
128
+ or prefix_cfg is not None
129
+ ):
130
+ prefix_cfg = prefix_cfg or PrefixConfig()
131
+ prefix_cfg = PrefixConfig(
132
+ speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
133
+ speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
134
+ include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
135
+ )
136
+
137
+ return GenerationConfig(
138
+ text=text_sampling,
139
+ audio=audio_sampling,
140
+ cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
141
+ cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
142
+ initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
143
+ prefix=prefix_cfg,
144
+ use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
145
+ )
146
+
147
+
148
+ __all__ = [
149
+ "SamplingConfig",
150
+ "GenerationConfig",
151
+ "GenerationResult",
152
+ "PrefixConfig",
153
+ "normalize_script",
154
+ "load_script_text",
155
+ "validate_generation_params",
156
+ "build_generation_config",
157
+ "merge_generation_config",
158
+ ]
dia2/runtime/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .state_machine import Entry, StateMachine, TokenIds
2
+
3
+ __all__ = [
4
+ "Entry",
5
+ "StateMachine",
6
+ "TokenIds",
7
+ ]
dia2/runtime/audio_io.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import sphn
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ from ..audio import MimiCodec
12
+
13
+ PathLike = Union[str, Path]
14
+
15
+
16
+ def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
17
+ """Read an audio file, convert to mono float32, and resample to target_sr."""
18
+ path = str(path)
19
+ try:
20
+ audio, sr = sphn.read_wav(path)
21
+ except Exception:
22
+ import soundfile as sf # Local fallback
23
+
24
+ audio, sr = sf.read(path, dtype="float32", always_2d=False)
25
+ audio = np.asarray(audio, dtype=np.float32)
26
+ if audio.ndim == 2:
27
+ audio = audio.mean(axis=1)
28
+ if sr != target_sr:
29
+ if hasattr(sphn, "resample_audio"):
30
+ audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
31
+ else:
32
+ audio = _resample_linear(audio, sr, target_sr)
33
+ return audio
34
+
35
+
36
+ def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
37
+ """Convert mono PCM samples into shape [1, 1, T] tensor."""
38
+ tensor = torch.from_numpy(audio).to(device)
39
+ if tensor.dim() == 1:
40
+ tensor = tensor.unsqueeze(0)
41
+ if tensor.dim() == 2:
42
+ tensor = tensor.unsqueeze(0)
43
+ return tensor
44
+
45
+
46
+ def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
47
+ """Encode PCM audio into Mimi codebook tokens [C, T]."""
48
+ waveform = audio_to_tensor(audio, mimi.device)
49
+ with torch.inference_mode():
50
+ codes, *_ = mimi.encode(waveform, return_dict=False)
51
+ if isinstance(codes, (tuple, list)):
52
+ codes = codes[0]
53
+ # Mimi.encode returns [B, num_codebooks, T]; select batch 0.
54
+ codes = codes[0].to(torch.long)
55
+ return codes
56
+
57
+
58
+ def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
59
+ if src_sr == dst_sr:
60
+ return audio.astype(np.float32)
61
+ length = audio.shape[0]
62
+ new_length = max(1, int(round(length * dst_sr / src_sr)))
63
+ tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0).unsqueeze(0)
64
+ with torch.no_grad():
65
+ resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
66
+ return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)
67
+
68
+
69
+ __all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]
dia2/runtime/context.py ADDED
@@ -0,0 +1,132 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ import warnings
7
+
8
+ import torch
9
+ from safetensors.torch import load_file
10
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
11
+
12
+ from ..config import DiaConfig, load_config
13
+ from ..core.model import Dia2Model
14
+ from ..core.precision import Precision, resolve_precision
15
+ from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
16
+ from .state_machine import StateMachine, TokenIds
17
+
18
+
19
+ @dataclass
20
+ class RuntimeContext:
21
+ config: DiaConfig
22
+ model: Dia2Model
23
+ precision: Precision
24
+ tokenizer: PreTrainedTokenizerBase
25
+ mimi: MimiCodec
26
+ device: torch.device
27
+ machine: StateMachine
28
+ transformer_step: callable
29
+ depformer_step: callable
30
+ constants: TokenIds
31
+ audio_delays: list[int]
32
+ audio_delay_tensor: torch.Tensor
33
+ frame_rate: float
34
+
35
+
36
+ def build_runtime(
37
+ *,
38
+ config_path: str | Path,
39
+ weights_path: str | Path,
40
+ tokenizer_id: Optional[str],
41
+ repo_id: Optional[str],
42
+ mimi_id: Optional[str],
43
+ device: str,
44
+ dtype_pref: str,
45
+ ) -> tuple[RuntimeContext, str, str]:
46
+ device_obj = torch.device(device)
47
+ if device_obj.type == "cuda":
48
+ cuda_matmul = torch.backends.cuda.matmul
49
+ cudnn_conv = torch.backends.cudnn.conv
50
+ if hasattr(cuda_matmul, "fp32_precision"):
51
+ cuda_matmul.fp32_precision = "tf32"
52
+ with warnings.catch_warnings():
53
+ warnings.filterwarnings(
54
+ "ignore",
55
+ message="Please use the new API settings",
56
+ )
57
+ torch.backends.cuda.matmul.allow_tf32 = True
58
+ else: # pragma: no cover - compatibility with older PyTorch
59
+ torch.backends.cuda.matmul.allow_tf32 = True
60
+ if hasattr(cudnn_conv, "fp32_precision"):
61
+ cudnn_conv.fp32_precision = "tf32"
62
+ with warnings.catch_warnings():
63
+ warnings.filterwarnings(
64
+ "ignore",
65
+ message="Please use the new API settings",
66
+ )
67
+ torch.backends.cudnn.allow_tf32 = True
68
+ else: # pragma: no cover
69
+ torch.backends.cudnn.allow_tf32 = True
70
+ precision = resolve_precision(dtype_pref, device_obj)
71
+ config = load_config(config_path)
72
+ model = Dia2Model(config, precision)
73
+ state = load_file(str(weights_path))
74
+ model.load_state_dict(state)
75
+ model = model.to(device_obj)
76
+
77
+ tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
78
+ if tokenizer_ref is None:
79
+ raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
80
+ tokenizer = AutoTokenizer.from_pretrained(
81
+ tokenizer_ref,
82
+ use_fast=False,
83
+ trust_remote_code=True,
84
+ )
85
+
86
+ mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
87
+ mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)
88
+
89
+ data_cfg = config.data
90
+ constants = TokenIds(
91
+ card=data_cfg.text_vocab_size,
92
+ new_word=data_cfg.text_new_word_token_id,
93
+ pad=data_cfg.text_pad_token_id,
94
+ bos=getattr(tokenizer, "bos_token_id", 1) or 1,
95
+ zero=data_cfg.text_zero_token_id,
96
+ spk1=tokenizer.convert_tokens_to_ids("[S1]") if "[S1]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
97
+ spk2=tokenizer.convert_tokens_to_ids("[S2]") if "[S2]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
98
+ audio_pad=data_cfg.audio_pad_token_id,
99
+ audio_bos=data_cfg.audio_bos_token_id,
100
+ )
101
+ machine = StateMachine(
102
+ token_ids=constants,
103
+ second_stream_ahead=data_cfg.second_stream_ahead,
104
+ max_padding=6,
105
+ initial_padding=0,
106
+ )
107
+ audio_delays = list(data_cfg.delay_pattern)
108
+ audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long) if audio_delays else torch.empty(0, dtype=torch.long, device=device_obj)
109
+ frame_rate = getattr(mimi, "frame_rate", 75.0)
110
+
111
+ runtime = RuntimeContext(
112
+ config=config,
113
+ precision=precision,
114
+ model=model,
115
+ tokenizer=tokenizer,
116
+ mimi=mimi,
117
+ device=device_obj,
118
+ machine=machine,
119
+ constants=constants,
120
+ audio_delays=audio_delays,
121
+ audio_delay_tensor=audio_delay_tensor,
122
+ frame_rate=frame_rate,
123
+ transformer_step=model.transformer.forward_step,
124
+ depformer_step=model.depformer.forward_step,
125
+ )
126
+ return runtime, tokenizer_ref, mimi_ref
127
+
128
+
129
+ __all__ = [
130
+ "RuntimeContext",
131
+ "build_runtime",
132
+ ]
dia2/runtime/generator.py ADDED
@@ -0,0 +1,420 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from ..core.cache import KVCache
9
+ from ..core.model import DecodeState
10
+ from ..generation import GenerationConfig
11
+ from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
12
+ from .context import RuntimeContext
13
+ from .state_machine import State, TokenIds
14
+ from .guidance import apply_classifier_guidance, sample_audio_logits
15
+ from .sampler import sample_token
16
+ from .voice_clone import PrefixPlan
17
+ from .logger import RuntimeLogger
18
+
19
+ _GRAPH_CUBLAS_READY = False
20
+
21
+
22
+ def _ensure_graph_cublas_ready(device: torch.device) -> None:
23
+ global _GRAPH_CUBLAS_READY
24
+ if _GRAPH_CUBLAS_READY or device.type != "cuda":
25
+ return
26
+ tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
27
+ torch.matmul(tmp, tmp)
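+ # A dummy matmul forces cuBLAS to allocate its workspace now; allocations are not
+ # allowed later while a CUDA graph is being captured.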
28
+ torch.cuda.synchronize()
29
+ _GRAPH_CUBLAS_READY = True
+
+
30
+ @dataclass
31
+ class GenerationState:
32
+ decode: DecodeState
33
+ step_tokens: torch.Tensor
34
+ audio_buf: torch.Tensor
35
+
36
+ def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
37
+ trimmed = self.audio_buf[:, :, :limit]
38
+ pad = torch.full_like(trimmed, pad_token)
39
+ trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
40
+ self.audio_buf = trimmed
41
+ return trimmed
42
+
43
+ @property
44
+ def transformer_cache(self) -> KVCache:
45
+ return self.decode.transformer
46
+
47
+ @transformer_cache.setter
48
+ def transformer_cache(self, cache: KVCache) -> None:
49
+ self.decode.transformer = cache
50
+
51
+ @property
52
+ def depformer_cache(self) -> KVCache:
53
+ return self.decode.depformer
54
+
55
+ @depformer_cache.setter
56
+ def depformer_cache(self, cache: KVCache) -> None:
57
+ self.decode.depformer = cache
58
+
59
+ def reset_dep_cache(self) -> None:
60
+ self.decode.depformer.reset()
61
+
62
+
63
+ @dataclass
64
+ class NetworkBuffers:
65
+ text: torch.Tensor
66
+ cb0: torch.Tensor
67
+ dep: list[torch.Tensor]
68
+
69
+
70
+ def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
71
+ device = runtime.device
72
+ logits_dtype = runtime.precision.logits
73
+ data_cfg = runtime.config.data
74
+ text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
75
+ cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
76
+ dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
77
+ dep_logits = [
78
+ torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
79
+ for _ in range(runtime.model.depformer.num_depth)
80
+ ]
81
+ return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)
82
+
83
+
84
+ def build_initial_state(
85
+ runtime: RuntimeContext,
86
+ *,
87
+ prefix: PrefixPlan | None = None,
88
+ ) -> GenerationState:
89
+ dep_q = runtime.model.depformer.num_audio_channels
90
+ channels = 2 + dep_q
91
+ branches = 2
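+ # Two branches: index 0 is the conditional pass, index 1 the unconditional pass
+ # used for classifier-free guidance.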
92
+ token_ids = runtime.constants
93
+ step_tokens = torch.full(
94
+ (branches, channels, 1),
95
+ token_ids.pad,
96
+ dtype=torch.long,
97
+ device=runtime.device,
98
+ )
99
+ step_tokens[0, 0, 0] = token_ids.bos
100
+ step_tokens[0, 1, 0] = token_ids.pad
101
+ step_tokens[1, 0, 0] = token_ids.zero
102
+ step_tokens[1, 1, 0] = token_ids.pad
103
+ prefix_len = 0
104
+ if prefix is not None:
105
+ delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad)
106
+ prefix_len = delayed.shape[1]
107
+ limit = runtime.config.runtime.max_context_steps
108
+ total_steps = max(limit + prefix_len + 1, limit)
109
+ decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
110
+ audio_buf = torch.full(
111
+ (branches, dep_q, total_steps),
112
+ token_ids.ungenerated,
113
+ dtype=torch.long,
114
+ device=runtime.device,
115
+ )
116
+ if prefix is not None:
117
+ delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
118
+ audio_buf[0, :, : delayed.shape[1]] = delayed
119
+ if branches > 1:
120
+ audio_buf[1:, :, : delayed.shape[1]] = delayed
121
+ return GenerationState(decode_state, step_tokens, audio_buf)
122
+
123
+
124
+ def _fill_audio_channels(
125
+ step_tokens: torch.Tensor,
126
+ audio_buf: torch.Tensor,
127
+ delays: torch.Tensor,
128
+ step: int,
129
+ bos_token: int,
130
+ ) -> None:
131
+ channels = delays.numel()
132
+ if channels == 0:
133
+ return
134
+ target = step_tokens[:, 2 : 2 + channels, 0]
135
+ if step < audio_buf.shape[-1]:
136
+ target.copy_(audio_buf[:, :channels, step])
137
+ else:
138
+ target.fill_(bos_token)
139
+ mask = delays > step
140
+ if mask.any().item():
141
+ target[:, mask] = bos_token
142
+
143
+
144
+ def _execute_transformer_step(
145
+ step_tokens: torch.Tensor,
146
+ positions_view: torch.Tensor,
147
+ generation: GenerationState,
148
+ transformer_step,
149
+ buffers: NetworkBuffers,
150
+ ) -> torch.Tensor:
151
+ hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
152
+ step_tokens,
153
+ positions_view,
154
+ generation.transformer_cache,
155
+ )
156
+ buffers.text.copy_(text_logits_t)
157
+ buffers.cb0.copy_(cb0_logits_t)
158
+ generation.transformer_cache = present
159
+ return hidden_t
160
+
161
+
162
+ def _execute_depformer_stage(
163
+ stage_index: int,
164
+ prev_audio: torch.Tensor,
165
+ hidden_t: torch.Tensor,
166
+ generation: GenerationState,
167
+ depformer_step,
168
+ main_tokens: Optional[torch.Tensor],
169
+ second_tokens: Optional[torch.Tensor],
170
+ buffers: NetworkBuffers,
171
+ ) -> None:
172
+ logits_stage, dep_present = depformer_step(
173
+ prev_audio=prev_audio,
174
+ transformer_out=hidden_t,
175
+ stage_index=stage_index,
176
+ cache=generation.depformer_cache,
177
+ main_text=main_tokens if stage_index == 0 else None,
178
+ second_text=second_tokens if stage_index == 0 else None,
179
+ )
180
+ target = buffers.dep[stage_index]
181
+ if logits_stage.shape != target.shape:
182
+ raise RuntimeError(
183
+ f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
184
+ )
185
+ target.copy_(logits_stage)
186
+ generation.depformer_cache = dep_present
187
+
188
+
189
+
190
+
191
+ def run_generation_loop(
192
+ runtime: RuntimeContext,
193
+ *,
194
+ state: State,
195
+ generation: GenerationState,
196
+ config: GenerationConfig,
197
+ start_step: int = 0,
198
+ logger: RuntimeLogger | None = None,
199
+ ) -> tuple[Optional[int], torch.Tensor]:
200
+ step_tokens = generation.step_tokens
201
+ audio_buf = generation.audio_buf
202
+ branches = step_tokens.shape[0]
203
+ max_context = runtime.config.runtime.max_context_steps
204
+ if max_context <= 0:
205
+ raise ValueError("Runtime configuration must specify a positive max_context_steps")
206
+ positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
207
+ main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
208
+ aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
209
+ cfg_active = config.cfg_scale != 1.0
210
+ token_ids = runtime.constants
211
+ delay_tensor = runtime.audio_delay_tensor
212
+ max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
213
+ flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
214
+ first_word_frame: Optional[int] = None
215
+ eos_cutoff: Optional[int] = None
216
+ last_step = start_step - 1
217
+ use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
218
+ transformer_step = runtime.transformer_step
219
+ depformer_step = runtime.depformer_step
220
+ buffers = _allocate_network_buffers(runtime, branches)
221
+ positions_view = positions.expand(branches, -1)
222
+ transformer_capture = None
223
+ dep_captures: list[dict] | None = None
224
+ if use_graph:
225
+ _ensure_graph_cublas_ready(runtime.device)
226
+ processed_steps = 0
227
+ report_interval = 12
228
+ with torch.inference_mode():
229
+ for offset in range(max_context):
230
+ t = start_step + offset
231
+ if eos_cutoff is not None and t >= eos_cutoff:
232
+ break
233
+ if t + 1 >= audio_buf.shape[-1]:
234
+ break
235
+ generation.reset_dep_cache()
236
+ positions.fill_(t)
237
+ _fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
238
+ if branches > 1:
239
+ step_tokens[1:, 0, 0] = token_ids.zero
240
+ step_tokens[1:, 1, 0] = token_ids.pad
241
+ if use_graph:
242
+ if transformer_capture is None:
243
+ torch.cuda.synchronize()
244
+ graph = torch.cuda.CUDAGraph()
245
+ with torch.cuda.graph(graph):
246
+ hidden_ref = _execute_transformer_step(
247
+ step_tokens,
248
+ positions_view,
249
+ generation,
250
+ transformer_step,
251
+ buffers,
252
+ )
253
+ transformer_capture = (graph, hidden_ref)
254
+ if runtime.model.depformer.num_depth > 0:
255
+ dep_captures = []
256
+ for idx in range(runtime.model.depformer.num_depth):
257
+ capture = {
258
+ "graph": torch.cuda.CUDAGraph(),
259
+ "captured": False,
260
+ "prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
261
+ "main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
262
+ "second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
263
+ }
264
+ dep_captures.append(capture)
265
+ else:
266
+ transformer_capture[0].replay()
267
+ hidden_t = transformer_capture[1]
268
+ else:
269
+ hidden_t = _execute_transformer_step(
270
+ step_tokens,
271
+ positions_view,
272
+ generation,
273
+ transformer_step,
274
+ buffers,
275
+ )
276
+
277
+ guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
278
+ if guided_text.shape[0] > 1:
279
+ guided_text = guided_text[:1]
280
+ text_token = sample_token(
281
+ guided_text,
282
+ temp=config.text.temperature,
283
+ top_k=config.text.top_k,
284
+ ).item()
285
+
286
+ main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
287
+ second_token = aux_token if aux_token != -1 else token_ids.pad
288
+ if first_word_frame is None and main_token == token_ids.new_word:
289
+ first_word_frame = t - config.initial_padding
290
+ step_tokens[:, 0, 0] = main_token
291
+ step_tokens[:, 1, 0] = second_token
292
+
293
+ guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
294
+ if guided_cb0.shape[0] > 1:
295
+ guided_cb0 = guided_cb0[:1]
296
+ masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
297
+ codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
298
+ audio_buf[:, 0, t + 1] = codebook_token
299
+
300
+ prev_audio = codebook_token.expand(branches)
301
+ main_tokens.fill_(main_token)
302
+ aux_tokens.fill_(second_token)
303
+ for stage in range(runtime.model.depformer.num_depth):
304
+ if use_graph and dep_captures is not None:
305
+ capture = dep_captures[stage]
306
+ capture["prev_audio"].copy_(prev_audio)
307
+ if capture["main_tokens"] is not None and stage == 0:
308
+ capture["main_tokens"].copy_(main_tokens)
309
+ capture["second_tokens"].copy_(aux_tokens)
310
+ if not capture["captured"]:
311
+ torch.cuda.synchronize()
312
+ with torch.cuda.graph(capture["graph"]):
313
+ _execute_depformer_stage(
314
+ stage_index=stage,
315
+ prev_audio=capture["prev_audio"],
316
+ hidden_t=hidden_t,
317
+ generation=generation,
318
+ depformer_step=depformer_step,
319
+ main_tokens=capture["main_tokens"],
320
+ second_tokens=capture["second_tokens"],
321
+ buffers=buffers,
322
+ )
323
+ capture["captured"] = True
324
+ else:
325
+ capture["graph"].replay()
326
+ else:
327
+ _execute_depformer_stage(
328
+ stage_index=stage,
329
+ prev_audio=prev_audio,
330
+ hidden_t=hidden_t,
331
+ generation=generation,
332
+ depformer_step=depformer_step,
333
+ main_tokens=main_tokens,
334
+ second_tokens=aux_tokens,
335
+ buffers=buffers,
336
+ )
337
+ dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
338
+ if dep_logits.shape[0] > 1:
339
+ dep_logits = dep_logits[:1]
340
+ stage_token = sample_audio_logits(
341
+ dep_logits,
342
+ config.audio.temperature,
343
+ config.audio.top_k,
344
+ )
345
+ audio_buf[:, stage + 1, t + 1] = stage_token
346
+ prev_audio = stage_token.expand(branches)
347
+ last_step = t
348
+ if eos_cutoff is None and state.end_step is not None:
349
+ eos_cutoff = state.end_step + flush_tail
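+ # Keep decoding for max_delay + max_padding extra steps so delayed codebooks can flush.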
350
+ processed_steps = offset + 1
351
+ if logger and processed_steps % report_interval == 0:
352
+ logger.progress(processed_steps, max_context)
353
+
354
+ if logger and processed_steps and processed_steps % report_interval != 0:
355
+ logger.progress(processed_steps, max_context)
356
+
357
+ if first_word_frame is None:
358
+ first_word_frame = start_step
359
+ if last_step < start_step:
360
+ limit = min(start_step + 1, audio_buf.shape[-1])
361
+ else:
362
+ limit = min(last_step + 2, audio_buf.shape[-1])
363
+ trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
364
+ return first_word_frame, trimmed
365
+
366
+
367
+ def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
368
+ if tokens.shape[-1] == 0:
369
+ return torch.zeros(0, device=runtime.device)
370
+ with torch.inference_mode():
371
+ pcm = runtime.mimi.decode(tokens.to(runtime.device))
372
+ return pcm[0, 0]
373
+
374
+ def warmup_with_prefix(
375
+ runtime: RuntimeContext,
376
+ plan: PrefixPlan,
377
+ state: State,
378
+ generation: GenerationState,
379
+ ) -> int:
380
+ step_tokens = generation.step_tokens
381
+ model_state = generation.decode
382
+ branches = step_tokens.shape[0]
383
+ device = runtime.device
384
+ tokens = plan.aligned_tokens.to(device)
385
+ new_word_steps = set(plan.new_word_steps)
386
+ positions = torch.empty(1, 1, dtype=torch.long, device=device)
387
+
388
+ with torch.inference_mode():
389
+ for t in range(plan.aligned_frames):
390
+ positions.fill_(t)
391
+ channels = tokens.shape[0]
392
+ for cb in range(channels):
393
+ delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
394
+ idx = t - delay
395
+ value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
396
+ step_tokens[:, 2 + cb, 0] = value
397
+ hidden, text_logits, cb0_logits, present = runtime.model.transformer.forward_step(
398
+ step_tokens,
399
+ positions.expand(branches, -1),
400
+ model_state.transformer,
401
+ )
402
+ model_state.transformer = present
403
+
404
+ forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
405
+ main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
406
+ second_token = runtime.constants.pad if aux_token == -1 else aux_token
407
+ step_tokens[0, 0, 0] = main_token
408
+ step_tokens[0, 1, 0] = second_token
409
+ if branches > 1:
410
+ step_tokens[1:, 0, 0] = runtime.constants.zero
411
+ step_tokens[1:, 1, 0] = runtime.constants.pad
412
+
413
+ return max(plan.aligned_frames - 1, 0)
+
+
414
+ __all__ = [
415
+ "build_initial_state",
416
+ "run_generation_loop",
417
+ "decode_audio",
418
+ "warmup_with_prefix",
419
+ "GenerationState",
420
+ ]
dia2/runtime/guidance.py ADDED
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from .sampler import sample_token
6
+
7
+
8
+ def apply_classifier_guidance(
9
+ logits: torch.Tensor,
10
+ cfg_active: bool,
11
+ scale: float,
12
+ top_k: int,
13
+ ) -> torch.Tensor:
14
+ if not cfg_active:
15
+ return logits
16
+ conditional = logits[0:1]
17
+ unconditional = logits[1:2]
18
+ cond32 = conditional.to(torch.float32)
19
+ uncond32 = unconditional.to(torch.float32)
20
+ guided = torch.lerp(uncond32, cond32, scale)
21
+ if top_k > 0 and guided.shape[-1] > 0:
22
+ k = min(top_k, guided.shape[-1])
23
+ threshold = torch.topk(guided, k=k, dim=-1, sorted=False).values[..., -1:]
24
+ mask = guided >= threshold
25
+ neg_inf = torch.full_like(cond32, float("-inf"))
26
+ cond32 = torch.where(mask, cond32, neg_inf)
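+ # The guided scores (uncond + scale * (cond - uncond), via torch.lerp) only pick the
+ # top-k candidate ids; sampling then uses the conditional logits restricted to that set.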
27
+ return cond32.to(conditional.dtype)
28
+
29
+
30
+ def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
31
+ """Sample a single audio token (shape [1]) from logits."""
32
+ return (
33
+ sample_token(
34
+ logits,
35
+ temp=temp,
36
+ top_k=top_k,
37
+ ).view(1)
38
+ )
dia2/runtime/logger.py ADDED
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ import time
+ from typing import Optional
+
+
4
+ class RuntimeLogger:
5
+ def __init__(self, enabled: bool) -> None:
6
+ self.enabled = enabled
7
+ self.start_time = time.perf_counter()
8
+ self.last_time = self.start_time
9
+ self.last_step = 0
10
+
11
+ def event(self, message: str) -> None:
12
+ if self.enabled:
13
+ print(f"[dia2] {message}")
14
+
15
+ def progress(self, step: int, total: Optional[int] = None) -> None:
16
+ if not self.enabled:
17
+ return
18
+ now = time.perf_counter()
19
+ delta_t = max(now - self.last_time, 1e-6)
20
+ delta_steps = max(step - self.last_step, 1)
21
+ speed = delta_steps / delta_t
22
+ if total is None:
23
+ self.event(f"step {step} :: {speed:.1f} toks/s")
24
+ else:
25
+ self.event(f"step {step}/{total} :: {speed:.1f} toks/s")
26
+ self.last_time = now
27
+ self.last_step = step
28
+
29
+ def elapsed(self) -> float:
30
+ return time.perf_counter() - self.start_time
31
+
32
+
33
+ __all__ = ["RuntimeLogger"]
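
Usage sketch for RuntimeLogger; the sleep below is only a placeholder for real decoding work, and the import path mirrors the file path in this diff.

    import time

    from dia2.runtime.logger import RuntimeLogger

    logger = RuntimeLogger(enabled=True)
    logger.event("starting generation")
    for step in range(1, 4):
        time.sleep(0.1)  # stand-in for one decode step
        logger.progress(step, total=3)
    logger.event(f"finished in {logger.elapsed():.2f}s")
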
dia2/runtime/sampler.py ADDED
@@ -0,0 +1,37 @@
+ from __future__ import annotations
+
+ import torch
+
+
+ def sample_token(
+     logits: torch.Tensor,
+     *,
+     temp: float,
+     top_k: int = 0,
+ ) -> torch.Tensor:
+     logits32 = logits.to(torch.float32)
+     if temp <= 0.0:
+         return torch.argmax(logits32, dim=-1, keepdim=True)
+     probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
+     probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
+     probs = torch.clamp_min(probs, 0.0)
+     flat = probs.reshape(-1, probs.shape[-1])
+     norm = flat.sum(dim=-1, keepdim=True)
+     zero_mask = norm <= 0
+     norm = norm.clamp_min(1e-12)
+     flat = flat / norm
+     if zero_mask.any():
+         filler = torch.zeros_like(flat)
+         filler[..., 0] = 1.0
+         mask = zero_mask.expand_as(flat)
+         flat = torch.where(mask, filler, flat)
+     vocab = flat.shape[-1]
+     if top_k > 0 and top_k < vocab:
+         topv, indices = torch.topk(flat, top_k, dim=-1)
+         topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
+         draws = torch.multinomial(topv, num_samples=1)
+         picks = torch.gather(indices, dim=-1, index=draws)
+     else:
+         picks = torch.multinomial(flat, num_samples=1)
+     picks = picks.reshape(*probs.shape[:-1], 1)
+     return picks
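
Usage sketch for sample_token; the 2048-entry vocabulary is an arbitrary choice for illustration.

    import torch

    from dia2.runtime.sampler import sample_token

    logits = torch.randn(1, 2048)

    # temp <= 0.0 degenerates to greedy argmax; otherwise softmax sampling
    # with an optional top-k restriction.
    greedy = sample_token(logits, temp=0.0)
    sampled = sample_token(logits, temp=0.9, top_k=50)
    print(greedy.shape, sampled.shape)  # both torch.Size([1, 1])
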
dia2/runtime/script_parser.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ import re
+ from typing import List, Optional, Sequence
+
+ from .state_machine import Entry
+
+
+ def parse_script(
+     script: Sequence[str],
+     tokenizer,
+     constants,
+     frame_rate: float,
+ ) -> List[Entry]:
+     entries: List[Entry] = []
+     speaker_tokens = [constants.spk1, constants.spk2]
+     padding_between = 1
+     event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:\.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
+     last_speaker_idx = [None]
+
+     def add_entry(idx: int, word: str, *, pending: Optional[int], first_content: List[bool]):
+         tokens: List[int]
+         if pending is not None:
+             prefix = "[S1]" if pending == constants.spk1 else "[S2]"
+             tokens = tokenizer.encode(f"{prefix} {word}", add_special_tokens=False)
+         else:
+             tokens = tokenizer.encode(word, add_special_tokens=False)
+         if first_content[0]:
+             if speaker_tokens:
+                 speaker_idx = idx % len(speaker_tokens)
+                 speaker_token = speaker_tokens[speaker_idx]
+                 if speaker_token is not None and last_speaker_idx[0] != speaker_idx:
+                     if not tokens or tokens[0] != speaker_token:
+                         tokens.insert(0, speaker_token)
+                 last_speaker_idx[0] = speaker_idx
+             first_content[0] = False
+         padding = max(0, padding_between + len(tokens) - 1)
+         entries.append(Entry(tokens=tokens, text=word, padding=padding))
+
+     for idx, line in enumerate(script):
+         normalized = line.replace("’", "'").replace(":", " ")
+         remaining = normalized
+         first_content = [True]
+         pending_speaker: Optional[int] = None
+         while remaining:
+             match = event_re.search(remaining)
+             if match is None:
+                 segment = remaining
+                 remaining = ""
+             else:
+                 segment = remaining[: match.start()]
+                 remaining = remaining[match.end() :]
+             if segment:
+                 for raw_word in segment.split():
+                     if raw_word in ("[S1]", "[S2]"):
+                         pending_speaker = (
+                             constants.spk1 if raw_word == "[S1]" else constants.spk2
+                         )
+                         continue
+                     add_entry(idx, raw_word, pending=pending_speaker, first_content=first_content)
+                     pending_speaker = None
+             if match and match.group(1):
+                 seconds = float(match.group(1))
+                 padding = int(round(seconds * frame_rate))
+                 if padding > 0:
+                     entries.append(Entry(tokens=[], text="", padding=padding))
+             if remaining:
+                 continue
+     return entries
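
Usage sketch for parse_script; the stub tokenizer, the speaker token ids, and the 12.5 frames-per-second rate are hypothetical stand-ins for the real runtime objects.

    from types import SimpleNamespace

    from dia2.runtime.script_parser import parse_script


    class StubTokenizer:
        """Hypothetical stand-in: encodes each character as its ordinal."""

        def encode(self, text, add_special_tokens=False):
            return [ord(ch) for ch in text]


    constants = SimpleNamespace(spk1=9001, spk2=9002)  # hypothetical speaker ids
    entries = parse_script(
        ['[S1] Hello there. <break time="0.5s"/> [S2] Hi!'],
        StubTokenizer(),
        constants,
        frame_rate=12.5,
    )
    for entry in entries:
        # Break tags become empty entries whose padding spans the pause.
        print(repr(entry.text), len(entry.tokens), entry.padding)
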
dia2/runtime/state_machine.py ADDED
@@ -0,0 +1,170 @@
+ from __future__ import annotations
+
+ from collections import deque
+ from dataclasses import dataclass, field
+ from typing import Deque, Iterable, List, Sequence, Tuple
+
+
+ @dataclass
+ class TokenIds:
+     card: int
+     new_word: int
+     pad: int
+     bos: int
+     zero: int
+     spk1: int
+     spk2: int
+     audio_pad: int
+     audio_bos: int
+     ungenerated: int = -2
+
+
+ @dataclass
+ class Entry:
+     tokens: List[int]
+     text: str
+     padding: int = 0
+
+
+ @dataclass
+ class State:
+     entries: Deque[Entry]
+     padding_budget: int
+     forced_padding: int
+     pending_tokens: Deque[int] = field(default_factory=deque)
+     lookahead_tokens: Deque[int] = field(default_factory=deque)
+     end_step: int | None = None
+     consumption_times: List[int] = field(default_factory=list)
+     transcript: List[Tuple[str, int]] = field(default_factory=list)
+
+     def peek_tokens(self, count: int) -> List[int]:
+         """Return tokens from upcoming entries (used for second-stream lookahead)."""
+         assert count > 0
+         for entry in self.entries:
+             if entry.tokens:
+                 count -= 1
+                 if count == 0:
+                     return entry.tokens
+         return []
+
+
+ class StateMachine:
+     def __init__(
+         self,
+         token_ids: TokenIds,
+         *,
+         second_stream_ahead: int = 0,
+         max_padding: int = 6,
+         initial_padding: int = 0,
+     ) -> None:
+         self.token_ids = token_ids
+         self.second_stream_ahead = second_stream_ahead
+         self.max_padding = max_padding
+         self.initial_padding = initial_padding
+
+     def new_state(self, entries: Iterable[Entry]) -> State:
+         return State(
+             entries=deque(entries),
+             padding_budget=self.initial_padding,
+             forced_padding=self.initial_padding,
+         )
+
+     def process(
+         self,
+         step: int,
+         state: State,
+         token: int,
+         is_forced: bool = False,
+     ) -> Tuple[int, int, bool]:
+         token = self._sanitize_token(token)
+         token = self._enforce_token_constraints(state, token, is_forced)
+         token, consumed_new_word = self._handle_new_word(step, state, token)
+         output_token = self._select_output_token(state, token)
+         final_main, final_second = self._maybe_multiplex_second_stream(
+             state, output_token
+         )
+         return final_main, final_second, consumed_new_word
+
+     def _sanitize_token(self, token: int) -> int:
+         if token == 1:
+             token = self.token_ids.new_word
+         elif token == 0:
+             token = self.token_ids.pad
+         if token not in (self.token_ids.new_word, self.token_ids.pad):
+             return self.token_ids.pad
+         return token
+
+     def _enforce_token_constraints(
+         self, state: State, token: int, is_forced: bool
+     ) -> int:
+         if state.pending_tokens:
+             return self.token_ids.pad
+         if is_forced:
+             return token
+         if state.forced_padding > 0:
+             if token != self.token_ids.pad:
+                 token = self.token_ids.pad
+             return token
+         if state.padding_budget <= 0 and token != self.token_ids.new_word:
+             return self.token_ids.new_word
+         return token
+
+     def _handle_new_word(
+         self, step: int, state: State, token: int
+     ) -> Tuple[int, bool]:
+         if token != self.token_ids.new_word:
+             return token, False
+         if state.entries:
+             entry = state.entries.popleft()
+             state.consumption_times.append(step)
+             if entry.tokens:
+                 state.transcript.append((entry.text, step))
+                 state.pending_tokens.extend(entry.tokens)
+                 if self.second_stream_ahead:
+                     state.lookahead_tokens.extend(
+                         state.peek_tokens(self.second_stream_ahead)
+                     )
+                 state.padding_budget = self.max_padding
+             else:
+                 token = self.token_ids.pad
+             state.forced_padding = entry.padding
+             return token, True
+         token = self.token_ids.pad
+         if self.second_stream_ahead and state.end_step is None:
+             token = self.token_ids.new_word
+         if state.end_step is None:
+             state.end_step = step
+         return token, False
+
+     def _select_output_token(self, state: State, token: int) -> int:
+         if token == self.token_ids.pad:
+             if state.padding_budget > 0:
+                 state.padding_budget -= 1
+             if state.forced_padding > 0:
+                 state.forced_padding -= 1
+             if state.pending_tokens:
+                 return state.pending_tokens.popleft()
+             return self.token_ids.pad
+         if token == self.token_ids.new_word:
+             return self.token_ids.new_word
+         if token == self.token_ids.zero:
+             return token
+         raise RuntimeError(f"Invalid token {token}")
+
+     def _maybe_multiplex_second_stream(
+         self, state: State, output: int
+     ) -> Tuple[int, int]:
+         if not self.second_stream_ahead:
+             return output, output
+         second = -1
+         if output == self.token_ids.new_word:
+             second = self.token_ids.new_word
+             if state.pending_tokens:
+                 output = state.pending_tokens.popleft()
+             else:
+                 output = self.token_ids.pad
+         elif state.lookahead_tokens:
+             second = state.lookahead_tokens.popleft()
+         else:
+             second = self.token_ids.pad
+         return output, second
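
A short driving example for StateMachine.process; the token ids below are arbitrary values chosen only for illustration.

    from dia2.runtime.state_machine import Entry, StateMachine, TokenIds

    ids = TokenIds(
        card=2048, new_word=3, pad=4, bos=1, zero=5,
        spk1=6, spk2=7, audio_pad=2048, audio_bos=2049,
    )
    machine = StateMachine(ids, max_padding=6)
    state = machine.new_state([Entry(tokens=[101, 102], text="hello", padding=2)])

    # A forced new-word event consumes the entry and queues its text tokens.
    print(machine.process(0, state, ids.new_word, is_forced=True))  # (3, 3, True)
    # Subsequent pad steps drain the queued tokens, one per frame.
    print(machine.process(1, state, ids.pad, is_forced=True))       # (101, 101, False)
    print(machine.process(2, state, ids.pad, is_forced=True))       # (102, 102, False)
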
dia2/runtime/voice_clone.py ADDED
@@ -0,0 +1,190 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Callable, List, Optional, Sequence, TYPE_CHECKING
+
+ import numpy as np
+ import torch
+
+ from ..generation import PrefixConfig
+ from .audio_io import encode_audio_tokens, load_mono_audio
+ from .state_machine import Entry
+
+ if TYPE_CHECKING:  # pragma: no cover
+     from .context import RuntimeContext
+
+
+ @dataclass
+ class WhisperWord:
+     text: str
+     start: float
+     end: float
+
+
+ @dataclass
+ class PrefixPlan:
+     entries: List[Entry]
+     new_word_steps: List[int]
+     aligned_tokens: torch.Tensor
+     aligned_frames: int
+
+
+ def build_prefix_plan(
+     runtime: "RuntimeContext",
+     prefix: Optional[PrefixConfig],
+     *,
+     transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
+     load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
+     encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
+ ) -> Optional[PrefixPlan]:
+     if prefix is None:
+         return None
+     if not prefix.speaker_1:
+         if prefix.speaker_2:
+             raise ValueError("speaker_2 requires speaker_1 to be provided")
+         return None
+
+     transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
+     load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
+     encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))
+
+     entries1, steps1, tokens1 = _process_prefix_audio(
+         runtime=runtime,
+         audio_path=prefix.speaker_1,
+         speaker_token=runtime.constants.spk1,
+         transcribe=transcribe,
+         load_audio=load_audio,
+         encode_audio=encode_audio,
+     )
+     offset = 3  # Match legacy BOS/PAD offset
+     entries = list(entries1)
+     new_word_steps = [step + offset for step in steps1]
+     audio_tokens = tokens1.to(runtime.device)
+
+     if prefix.speaker_2:
+         entries2, steps2, tokens2 = _process_prefix_audio(
+             runtime=runtime,
+             audio_path=prefix.speaker_2,
+             speaker_token=runtime.constants.spk2,
+             transcribe=transcribe,
+             load_audio=load_audio,
+             encode_audio=encode_audio,
+         )
+         spk1_frames = audio_tokens.shape[-1]
+         new_word_steps.extend(step + spk1_frames for step in steps2)
+         entries.extend(entries2)
+         audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)
+
+     return PrefixPlan(
+         entries=entries,
+         new_word_steps=new_word_steps,
+         aligned_tokens=audio_tokens,
+         aligned_frames=audio_tokens.shape[-1],
+     )
+
+
+ def _process_prefix_audio(
+     runtime: "RuntimeContext",
+     audio_path: str,
+     speaker_token: int,
+     *,
+     transcribe: Callable[[str, torch.device], List[WhisperWord]],
+     load_audio: Callable[[str, int], np.ndarray],
+     encode_audio: Callable[[np.ndarray], torch.Tensor],
+ ) -> tuple[List[Entry], List[int], torch.Tensor]:
+     words = transcribe(audio_path, runtime.device)
+     entries, steps = words_to_entries(
+         words=words,
+         tokenizer=runtime.tokenizer,
+         speaker_token=speaker_token,
+         frame_rate=runtime.frame_rate,
+     )
+     audio = load_audio(audio_path, runtime.mimi.sample_rate)
+     tokens = encode_audio(audio)
+     return entries, steps, tokens
+
+
+ def transcribe_words(
+     audio_path: str,
+     device: torch.device,
+     language: Optional[str] = None,
+ ) -> List[WhisperWord]:
+     import whisper_timestamped as wts  # Imported lazily
+
+     model = wts.load_model("openai/whisper-large-v3", device=str(device))
+     result = wts.transcribe(model, audio_path, language=language)
+
+     words: List[WhisperWord] = []
+     for segment in result.get("segments", []):
+         for word in segment.get("words", []):
+             text = (word.get("text") or word.get("word") or "").strip()
+             if not text:
+                 continue
+             words.append(
+                 WhisperWord(
+                     text=text,
+                     start=float(word.get("start", 0.0)),
+                     end=float(word.get("end", 0.0)),
+                 )
+             )
+     return words
+
+
+ def words_to_entries(
+     *,
+     words: Sequence[WhisperWord],
+     tokenizer,
+     speaker_token: int,
+     frame_rate: float,
+ ) -> tuple[List[Entry], List[int]]:
+     entries: List[Entry] = []
+     new_word_steps: List[int] = []
+     if not words:
+         return entries, new_word_steps
+
+     convert = getattr(tokenizer, "convert_tokens_to_ids", None)
+     speaker_prefix: Optional[str] = None
+     if callable(convert):
+         s1_id = convert("[S1]")
+         s2_id = convert("[S2]")
+         if speaker_token == s1_id:
+             speaker_prefix = "[S1]"
+         elif speaker_token == s2_id:
+             speaker_prefix = "[S2]"
+     pending_prefix: Optional[str] = speaker_prefix
+     current_pos = 0
+
+     for idx, word in enumerate(words):
+         tokens = _encode_word(word.text, tokenizer, pending_prefix)
+         pending_prefix = None
+         start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
+         end_frame = start_frame + len(tokens)
+         new_word_steps.append(start_frame - 1)
+
+         if idx < len(words) - 1:
+             next_start = int(round(words[idx + 1].start * frame_rate))
+             next_word_start = max(end_frame + 1, next_start)
+         else:
+             end_time = int(round(words[-1].end * frame_rate))
+             next_word_start = max(end_frame + 1, end_time)
+
+         padding = max(0, next_word_start - start_frame - 1)
+         entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
+         current_pos = end_frame
+
+     return entries, new_word_steps
+
+
+ def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
+     if prefix:
+         return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
+     return tokenizer.encode(text, add_special_tokens=False)
+
+
+ __all__ = [
+     "PrefixPlan",
+     "WhisperWord",
+     "build_prefix_plan",
+     "transcribe_words",
+     "words_to_entries",
+ ]
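
A sketch of words_to_entries on hand-made word timings; the stub tokenizer, its [S1]/[S2] ids, and the 12.5 frames-per-second rate are assumptions, and importing the module requires the dia2 package plus its dependencies to be installed.

    from dia2.runtime.voice_clone import WhisperWord, words_to_entries


    class StubTokenizer:
        """Hypothetical stand-in for the runtime text tokenizer."""

        def encode(self, text, add_special_tokens=False):
            return [ord(ch) for ch in text]

        def convert_tokens_to_ids(self, token):
            return {"[S1]": 9001, "[S2]": 9002}[token]


    words = [
        WhisperWord(text="hello", start=0.10, end=0.40),
        WhisperWord(text="world", start=0.55, end=0.90),
    ]
    entries, steps = words_to_entries(
        words=words,
        tokenizer=StubTokenizer(),
        speaker_token=9001,  # matches the stub's [S1] id, so the first word is prefixed
        frame_rate=12.5,     # assumed frame rate
    )
    for entry, step in zip(entries, steps):
        print(step, repr(entry.text), len(entry.tokens), entry.padding)
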
input.txt ADDED
@@ -0,0 +1 @@
+ [S1] Um, like, I don't know, I've never actually, like, been on a real vacation, you know? [S2] Oh, seriously? That's wild. I've, uh, only been on, like, one trip myself, and it was kinda stressful. [S1] Yeah, I always see people going places on, like, Instagram, but then I'm just, um, at home thinking, "Maybe next year." [S2] Honestly, same. I, like, plan stuff in my head but then forget or just, you know, bail at the last minute. [S1] So, we should, like, totally go somewhere together one day. [S2] For real, that would be awesome.
pyproject.toml ADDED
@@ -0,0 +1,45 @@
+ [build-system]
+ requires = ["setuptools>=70.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "dia2"
+ version = "0.1.0"
+ description = "Dia2 CUDA-only text-to-speech runtime"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ authors = [{ name = "Dia Contributors" }]
+ license = { file = "LICENSE" }
+ dependencies = [
+     "torch>=2.8.0",
+     "numpy>=2.1.0,<3.0",
+     "transformers>=4.55.3",
+     "safetensors==0.5.3",
+     "huggingface-hub>=0.24.7",
+     "sphn>=0.2.0",
+     "soundfile>=0.12.1",
+     "whisper-timestamped>=1.14.2",
+     "gradio>=4.44.1",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "ruff>=0.6.9",
+     "pyright>=1.1.385",
+ ]
+
+ [tool.uv]
+ package = true
+
+ [tool.uv.sources]
+ torch = [
+     { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ ]
+
+ [[tool.uv.index]]
+ name = "pytorch-cu128"
+ url = "https://download.pytorch.org/whl/cu128"
+ explicit = true
+
+ [tool.setuptools]
+ packages = ["dia2"]
uv.lock ADDED
The diff for this file is too large to render. See raw diff