Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +12 -0
- LICENSE +201 -0
- README.md +103 -12
- app.py +251 -0
- banner.gif +3 -0
- dia2/__init__.py +20 -0
- dia2/assets.py +65 -0
- dia2/audio/__init__.py +13 -0
- dia2/audio/codec.py +58 -0
- dia2/audio/grid.py +79 -0
- dia2/cli.py +122 -0
- dia2/config.py +180 -0
- dia2/core/__init__.py +10 -0
- dia2/core/cache.py +106 -0
- dia2/core/depformer.py +264 -0
- dia2/core/layers.py +209 -0
- dia2/core/model.py +72 -0
- dia2/core/precision.py +23 -0
- dia2/core/transformer.py +140 -0
- dia2/engine.py +230 -0
- dia2/generation.py +158 -0
- dia2/runtime/__init__.py +7 -0
- dia2/runtime/audio_io.py +69 -0
- dia2/runtime/context.py +132 -0
- dia2/runtime/generator.py +420 -0
- dia2/runtime/guidance.py +38 -0
- dia2/runtime/logger.py +33 -0
- dia2/runtime/sampler.py +37 -0
- dia2/runtime/script_parser.py +69 -0
- dia2/runtime/state_machine.py +170 -0
- dia2/runtime/voice_clone.py +190 -0
- input.txt +1 -0
- pyproject.toml +45 -0
- uv.lock +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
banner.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
_kyutai
|
| 3 |
+
__pycache__
|
| 4 |
+
*.npz
|
| 5 |
+
*.safetensors
|
| 6 |
+
*.model
|
| 7 |
+
*.DS_Store
|
| 8 |
+
*.parquet
|
| 9 |
+
*.wav
|
| 10 |
+
*.mp3
|
| 11 |
+
weights/
|
| 12 |
+
*.egg-info/
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2025 Nari Labs
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,13 +1,104 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<a href="https://huggingface.co/nari-labs/Dia2-2B"><img src="https://img.shields.io/badge/HF%20Repo-Dia2--2B-orange?style=for-the-badge"></a>
|
| 5 |
+
<a href="https://discord.gg/bJq6vjRRKv"><img src="https://img.shields.io/badge/Discord-Join%20Chat-7289DA?logo=discord&style=for-the-badge"></a>
|
| 6 |
+
<a href="https://github.com/nari-labs/dia2/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg?style=for-the-badge"></a>
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
**Dia2** is a **streaming dialogue TTS model** created by Nari Labs.
|
| 11 |
+
|
| 12 |
+
The model does not need the entire text to produce the audio, and can start generating as the first few words are given as input. You can condition the output on audio, enabling natural conversations in realtime.
|
| 13 |
+
|
| 14 |
+
We provide model checkpoints (1B, 2B) and inference code to accelerate research. The model only supports up to 2 minutes of generation in English.
|
| 15 |
+
|
| 16 |
+
⚠️ Quality and voices vary per generation, as the model is not fine-tuned on a specific voice. Use with prefix or fine-tune in order to obtain stable output.
|
| 17 |
+
|
| 18 |
+
## Upcoming
|
| 19 |
+
|
| 20 |
+
- Bonsai (JAX) implementation
|
| 21 |
+
- Dia2 TTS Server: Real streaming support
|
| 22 |
+
- Sori: Dia2-powered speech-to-speech engine written in Rust
|
| 23 |
+
|
| 24 |
+
## Quickstart
|
| 25 |
+
|
| 26 |
+
> **Requirement** — install [uv](https://docs.astral.sh/uv/) and use CUDA 12.8+
|
| 27 |
+
> drivers. All commands below run through `uv run …` as a rule.
|
| 28 |
+
|
| 29 |
+
1. **Install dependencies (one-time):**
|
| 30 |
+
```bash
|
| 31 |
+
uv sync
|
| 32 |
+
```
|
| 33 |
+
2. **Prepare a script:** edit `input.txt` using `[S1]` / `[S2]` speaker tags.
|
| 34 |
+
3. **Generate audio:**
|
| 35 |
+
```bash
|
| 36 |
+
uv run -m dia2.cli \
|
| 37 |
+
--hf nari-labs/Dia2-2B \
|
| 38 |
+
--input input.txt \
|
| 39 |
+
--cfg 6.0 --temperature 0.8 \
|
| 40 |
+
--cuda-graph --verbose \
|
| 41 |
+
output.wav
|
| 42 |
+
```
|
| 43 |
+
The first run downloads weights/tokenizer/Mimi. The CLI auto-selects CUDA when available (otherwise CPU) and defaults to bfloat16 precision—override with `--device` / `--dtype` if needed.
|
| 44 |
+
4. **Conditional Generation (recommended for stable use):**
|
| 45 |
+
```bash
|
| 46 |
+
uv run -m dia2.cli \
|
| 47 |
+
--hf nari-labs/Dia2-2B \
|
| 48 |
+
--input input.txt \
|
| 49 |
+
--prefix-speaker-1 example_prefix1.wav \
|
| 50 |
+
--prefix-speaker-2 example_prefix2.wav \
|
| 51 |
+
--cuda-graph --verbose \
|
| 52 |
+
output_conditioned.wav
|
| 53 |
+
```
|
| 54 |
+
Condition the generation on previous conversational context in order to generate natural output for your speech-to-speech system. For example, place the voice of your assistant as prefix speaker 1, place user's audio input as prefix speaker 2, and generate the response to user's input.
|
| 55 |
|
| 56 |
+
Whisper is used to transcribe each prefix file, which takes additional time. We include example prefix files as `example_prefix1.wav` and `example_prefix2.wav` (both files are output created by the model).
|
| 57 |
+
6. **Gradio for Easy Usage**
|
| 58 |
+
```bash
|
| 59 |
+
uv run gradio_app.py
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Programmatic Usage
|
| 63 |
+
```python
|
| 64 |
+
from dia2 import Dia2, GenerationConfig, SamplingConfig
|
| 65 |
+
|
| 66 |
+
dia = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
|
| 67 |
+
config = GenerationConfig(
|
| 68 |
+
cfg_scale=2.0,
|
| 69 |
+
audio=SamplingConfig(temperature=0.8, top_k=50),
|
| 70 |
+
use_cuda_graph=True,
|
| 71 |
+
)
|
| 72 |
+
result = dia.generate("[S1] Hello Dia2!", config=config, output_wav="hello.wav", verbose=True)
|
| 73 |
+
```
|
| 74 |
+
Generation runs until the runtime config's `max_context_steps` (1500, 2 minutes)
|
| 75 |
+
or until EOS is detected. `GenerationResult` includes audio tokens, waveform tensor,
|
| 76 |
+
and word timestamps relative to Mimi’s ~12.5 Hz frame rate.
|
| 77 |
+
|
| 78 |
+
## Hugging Face
|
| 79 |
+
|
| 80 |
+
| Variant | Repo |
|
| 81 |
+
| --- | --- |
|
| 82 |
+
| Dia2-1B | [`nari-labs/Dia2-1B`](https://huggingface.co/nari-labs/Dia2-1B)
|
| 83 |
+
| Dia2-2B | [`nari-labs/Dia2-2B`](https://huggingface.co/nari-labs/Dia2-2B)
|
| 84 |
+
|
| 85 |
+
## License & Attribution
|
| 86 |
+
|
| 87 |
+
Licensed under [Apache 2.0](LICENSE). All third-party assets (Kyutai Mimi codec, etc.) retain their original licenses.
|
| 88 |
+
|
| 89 |
+
## Disclaimer
|
| 90 |
+
|
| 91 |
+
This project offers a high-fidelity speech generation model intended for research and educational use. The following uses are **strictly forbidden**:
|
| 92 |
+
|
| 93 |
+
- **Identity Misuse**: Do not produce audio resembling real individuals without permission.
|
| 94 |
+
- **Deceptive Content**: Do not use this model to generate misleading content (e.g. fake news)
|
| 95 |
+
- **Illegal or Malicious Use**: Do not use this model for activities that are illegal or intended to cause harm.
|
| 96 |
+
|
| 97 |
+
By using this model, you agree to uphold relevant legal standards and ethical responsibilities. We **are not responsible** for any misuse and firmly oppose any unethical usage of this technology.
|
| 98 |
+
|
| 99 |
+
## Acknowledgements
|
| 100 |
+
- We thank the [TPU Research Cloud](https://sites.research.google/trc/about/) program for providing compute for training.
|
| 101 |
+
- Our work was heavily inspired by [KyutaiTTS](https://kyutai.org/next/tts) and [Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice)
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
Questions? Join our [Discord](https://discord.gg/bJq6vjRRKv) or open an issue.
|
app.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Tuple
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from dia2 import Dia2, GenerationConfig, SamplingConfig
|
| 13 |
+
|
| 14 |
+
DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
|
| 15 |
+
MAX_TURNS = 10
|
| 16 |
+
INITIAL_TURNS = 2
|
| 17 |
+
|
| 18 |
+
_dia: Dia2 | None = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_dia() -> Dia2:
|
| 22 |
+
global _dia
|
| 23 |
+
if _dia is None:
|
| 24 |
+
_dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
|
| 25 |
+
return _dia
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _concat_script(turn_count: int, turn_values: List[str]) -> str:
|
| 29 |
+
lines: List[str] = []
|
| 30 |
+
for idx in range(min(turn_count, len(turn_values))):
|
| 31 |
+
text = (turn_values[idx] or "").strip()
|
| 32 |
+
if not text:
|
| 33 |
+
continue
|
| 34 |
+
speaker = "[S1]" if idx % 2 == 0 else "[S2]"
|
| 35 |
+
lines.append(f"{speaker} {text}")
|
| 36 |
+
return "\n".join(lines)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
EXAMPLES: dict[str, List[str]] = {
|
| 40 |
+
"Intro": [
|
| 41 |
+
"Hello Dia2 fans! Today we're unveiling the new open TTS model.",
|
| 42 |
+
"Sounds exciting. Can you show a sample right now?",
|
| 43 |
+
"Absolutely. (laughs) Just press generate.",
|
| 44 |
+
],
|
| 45 |
+
"Customer Support": [
|
| 46 |
+
"Thanks for calling. How can I help you today?",
|
| 47 |
+
"My parcel never arrived and it's been two weeks.",
|
| 48 |
+
"I'm sorry about that. Let me check your tracking number.",
|
| 49 |
+
"Appreciate it. I really need that package soon.",
|
| 50 |
+
],
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _apply_turn_visibility(count: int) -> List[gr.Update]:
|
| 55 |
+
return [gr.update(visible=i < count) for i in range(MAX_TURNS)]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _add_turn(count: int):
|
| 59 |
+
count = min(count + 1, MAX_TURNS)
|
| 60 |
+
return (count, *_apply_turn_visibility(count))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _remove_turn(count: int):
|
| 64 |
+
count = max(1, count - 1)
|
| 65 |
+
return (count, *_apply_turn_visibility(count))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _load_example(name: str, count: int):
|
| 69 |
+
data = EXAMPLES.get(name)
|
| 70 |
+
if not data:
|
| 71 |
+
return (count, *_apply_turn_visibility(count))
|
| 72 |
+
new_count = min(len(data), MAX_TURNS)
|
| 73 |
+
updates: List[gr.Update] = []
|
| 74 |
+
for idx in range(MAX_TURNS):
|
| 75 |
+
if idx < new_count:
|
| 76 |
+
updates.append(gr.update(value=data[idx], visible=True))
|
| 77 |
+
else:
|
| 78 |
+
updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
|
| 79 |
+
return (new_count, *updates)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _prepare_prefix(file_path: str | None) -> str | None:
|
| 83 |
+
if not file_path:
|
| 84 |
+
return None
|
| 85 |
+
path = Path(file_path)
|
| 86 |
+
if not path.exists():
|
| 87 |
+
return None
|
| 88 |
+
return str(path)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def generate_audio(
|
| 92 |
+
turn_count: int,
|
| 93 |
+
*inputs,
|
| 94 |
+
):
|
| 95 |
+
turn_values = list(inputs[:MAX_TURNS])
|
| 96 |
+
voice_s1 = inputs[MAX_TURNS]
|
| 97 |
+
voice_s2 = inputs[MAX_TURNS + 1]
|
| 98 |
+
cfg_scale = float(inputs[MAX_TURNS + 2])
|
| 99 |
+
text_temperature = float(inputs[MAX_TURNS + 3])
|
| 100 |
+
audio_temperature = float(inputs[MAX_TURNS + 4])
|
| 101 |
+
text_top_k = int(inputs[MAX_TURNS + 5])
|
| 102 |
+
audio_top_k = int(inputs[MAX_TURNS + 6])
|
| 103 |
+
include_prefix = bool(inputs[MAX_TURNS + 7])
|
| 104 |
+
|
| 105 |
+
script = _concat_script(turn_count, turn_values)
|
| 106 |
+
if not script.strip():
|
| 107 |
+
raise gr.Error("Please enter at least one non-empty speaker turn.")
|
| 108 |
+
|
| 109 |
+
dia = _get_dia()
|
| 110 |
+
config = GenerationConfig(
|
| 111 |
+
cfg_scale=cfg_scale,
|
| 112 |
+
text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
|
| 113 |
+
audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
|
| 114 |
+
use_cuda_graph=True,
|
| 115 |
+
)
|
| 116 |
+
kwargs = {
|
| 117 |
+
"prefix_speaker_1": _prepare_prefix(voice_s1),
|
| 118 |
+
"prefix_speaker_2": _prepare_prefix(voice_s2),
|
| 119 |
+
"include_prefix": include_prefix,
|
| 120 |
+
}
|
| 121 |
+
buffer = io.StringIO()
|
| 122 |
+
with contextlib.redirect_stdout(buffer):
|
| 123 |
+
result = dia.generate(
|
| 124 |
+
script,
|
| 125 |
+
config=config,
|
| 126 |
+
output_wav=None,
|
| 127 |
+
verbose=True,
|
| 128 |
+
**kwargs,
|
| 129 |
+
)
|
| 130 |
+
waveform = result.waveform.detach().cpu().numpy()
|
| 131 |
+
sample_rate = result.sample_rate
|
| 132 |
+
timestamps = result.timestamps
|
| 133 |
+
log_text = buffer.getvalue().strip()
|
| 134 |
+
table = [[w, round(t, 3)] for w, t in timestamps]
|
| 135 |
+
return (sample_rate, waveform), table, log_text or "Generation finished."
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def build_interface() -> gr.Blocks:
|
| 139 |
+
with gr.Blocks(
|
| 140 |
+
title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
|
| 141 |
+
) as demo:
|
| 142 |
+
gr.Markdown(
|
| 143 |
+
"""## Dia2 — Open TTS Model
|
| 144 |
+
Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
|
| 145 |
+
)
|
| 146 |
+
turn_state = gr.State(INITIAL_TURNS)
|
| 147 |
+
with gr.Row(equal_height=True):
|
| 148 |
+
example_dropdown = gr.Dropdown(
|
| 149 |
+
choices=["(select example)"] + list(EXAMPLES.keys()),
|
| 150 |
+
label="Examples",
|
| 151 |
+
value="(select example)",
|
| 152 |
+
)
|
| 153 |
+
with gr.Row(equal_height=True):
|
| 154 |
+
with gr.Column(scale=1):
|
| 155 |
+
with gr.Group():
|
| 156 |
+
gr.Markdown("### Script")
|
| 157 |
+
controls = []
|
| 158 |
+
for idx in range(MAX_TURNS):
|
| 159 |
+
speaker = "[S1]" if idx % 2 == 0 else "[S2]"
|
| 160 |
+
box = gr.Textbox(
|
| 161 |
+
label=f"{speaker} turn {idx + 1}",
|
| 162 |
+
lines=2,
|
| 163 |
+
elem_classes=["compact-turn"],
|
| 164 |
+
placeholder=f"Enter dialogue for {speaker}…",
|
| 165 |
+
visible=idx < INITIAL_TURNS,
|
| 166 |
+
)
|
| 167 |
+
controls.append(box)
|
| 168 |
+
with gr.Row():
|
| 169 |
+
add_btn = gr.Button("Add Turn")
|
| 170 |
+
remove_btn = gr.Button("Remove Turn")
|
| 171 |
+
with gr.Group():
|
| 172 |
+
gr.Markdown("### Voice Prompts")
|
| 173 |
+
with gr.Row():
|
| 174 |
+
voice_s1 = gr.File(
|
| 175 |
+
label="[S1] voice (wav/mp3)", type="filepath"
|
| 176 |
+
)
|
| 177 |
+
voice_s2 = gr.File(
|
| 178 |
+
label="[S2] voice (wav/mp3)", type="filepath"
|
| 179 |
+
)
|
| 180 |
+
with gr.Group():
|
| 181 |
+
gr.Markdown("### Sampling")
|
| 182 |
+
cfg_scale = gr.Slider(
|
| 183 |
+
1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
|
| 184 |
+
)
|
| 185 |
+
with gr.Group():
|
| 186 |
+
gr.Markdown("#### Text Sampling")
|
| 187 |
+
text_temperature = gr.Slider(
|
| 188 |
+
0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
|
| 189 |
+
)
|
| 190 |
+
text_top_k = gr.Slider(
|
| 191 |
+
1, 200, value=50, step=1, label="Text Top-K"
|
| 192 |
+
)
|
| 193 |
+
with gr.Group():
|
| 194 |
+
gr.Markdown("#### Audio Sampling")
|
| 195 |
+
audio_temperature = gr.Slider(
|
| 196 |
+
0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
|
| 197 |
+
)
|
| 198 |
+
audio_top_k = gr.Slider(
|
| 199 |
+
1, 200, value=50, step=1, label="Audio Top-K"
|
| 200 |
+
)
|
| 201 |
+
include_prefix = gr.Checkbox(
|
| 202 |
+
label="Keep prefix audio in output", value=False
|
| 203 |
+
)
|
| 204 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 205 |
+
with gr.Column(scale=1):
|
| 206 |
+
gr.Markdown("### Output")
|
| 207 |
+
audio_out = gr.Audio(label="Waveform", interactive=False)
|
| 208 |
+
timestamps = gr.Dataframe(
|
| 209 |
+
headers=["word", "seconds"], label="Timestamps"
|
| 210 |
+
)
|
| 211 |
+
log_box = gr.Textbox(label="Logs", lines=8)
|
| 212 |
+
|
| 213 |
+
add_btn.click(
|
| 214 |
+
lambda c: _add_turn(c),
|
| 215 |
+
inputs=turn_state,
|
| 216 |
+
outputs=[turn_state, *controls],
|
| 217 |
+
)
|
| 218 |
+
remove_btn.click(
|
| 219 |
+
lambda c: _remove_turn(c),
|
| 220 |
+
inputs=turn_state,
|
| 221 |
+
outputs=[turn_state, *controls],
|
| 222 |
+
)
|
| 223 |
+
example_dropdown.change(
|
| 224 |
+
lambda name, c: _load_example(name, c),
|
| 225 |
+
inputs=[example_dropdown, turn_state],
|
| 226 |
+
outputs=[turn_state, *controls],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
generate_btn.click(
|
| 230 |
+
generate_audio,
|
| 231 |
+
inputs=[
|
| 232 |
+
turn_state,
|
| 233 |
+
*controls,
|
| 234 |
+
voice_s1,
|
| 235 |
+
voice_s2,
|
| 236 |
+
cfg_scale,
|
| 237 |
+
text_temperature,
|
| 238 |
+
audio_temperature,
|
| 239 |
+
text_top_k,
|
| 240 |
+
audio_top_k,
|
| 241 |
+
include_prefix,
|
| 242 |
+
],
|
| 243 |
+
outputs=[audio_out, timestamps, log_box],
|
| 244 |
+
)
|
| 245 |
+
return demo
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
app = build_interface()
|
| 250 |
+
app.queue(default_concurrency_limit=1)
|
| 251 |
+
app.launch(share=True)
|
banner.gif
ADDED
|
Git LFS Details
|
dia2/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import DiaConfig, load_config
|
| 2 |
+
from .core.model import Dia2Model
|
| 3 |
+
from .engine import Dia2
|
| 4 |
+
from .generation import (
|
| 5 |
+
GenerationConfig,
|
| 6 |
+
GenerationResult,
|
| 7 |
+
PrefixConfig,
|
| 8 |
+
SamplingConfig,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"DiaConfig",
|
| 13 |
+
"Dia2Model",
|
| 14 |
+
"load_config",
|
| 15 |
+
"GenerationConfig",
|
| 16 |
+
"GenerationResult",
|
| 17 |
+
"PrefixConfig",
|
| 18 |
+
"SamplingConfig",
|
| 19 |
+
"Dia2",
|
| 20 |
+
]
|
dia2/assets.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
|
| 11 |
+
ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class AssetBundle:
|
| 16 |
+
config_path: str
|
| 17 |
+
weights_path: str
|
| 18 |
+
tokenizer_id: Optional[str]
|
| 19 |
+
mimi_id: Optional[str]
|
| 20 |
+
repo_id: Optional[str]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def resolve_assets(
|
| 24 |
+
*,
|
| 25 |
+
repo: Optional[str],
|
| 26 |
+
config_path: Optional[str | Path],
|
| 27 |
+
weights_path: Optional[str | Path],
|
| 28 |
+
manifest_name: Optional[str] = None,
|
| 29 |
+
) -> AssetBundle:
|
| 30 |
+
repo_id = repo
|
| 31 |
+
manifest_name = manifest_name or ASSET_MANIFEST
|
| 32 |
+
if repo_id and (config_path or weights_path):
|
| 33 |
+
raise ValueError("Provide either repo or config+weights, not both")
|
| 34 |
+
if config_path is None or weights_path is None:
|
| 35 |
+
if repo_id is None:
|
| 36 |
+
raise ValueError("Must specify repo or config+weights")
|
| 37 |
+
manifest = load_manifest(repo_id, manifest_name)
|
| 38 |
+
config_name = manifest.get("config", "config.json")
|
| 39 |
+
weights_name = manifest.get("weights", "model.safetensors")
|
| 40 |
+
config_local = hf_hub_download(repo_id, config_name)
|
| 41 |
+
weights_local = hf_hub_download(repo_id, weights_name)
|
| 42 |
+
return AssetBundle(
|
| 43 |
+
config_path=config_local,
|
| 44 |
+
weights_path=weights_local,
|
| 45 |
+
tokenizer_id=manifest.get("tokenizer") or repo_id,
|
| 46 |
+
mimi_id=manifest.get("mimi"),
|
| 47 |
+
repo_id=repo_id,
|
| 48 |
+
)
|
| 49 |
+
return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_manifest(repo_id: str, manifest_name: str) -> dict:
|
| 53 |
+
if not manifest_name:
|
| 54 |
+
return {}
|
| 55 |
+
try:
|
| 56 |
+
path = hf_hub_download(repo_id, manifest_name)
|
| 57 |
+
except Exception:
|
| 58 |
+
return {}
|
| 59 |
+
try:
|
| 60 |
+
return json.loads(Path(path).read_text())
|
| 61 |
+
except json.JSONDecodeError:
|
| 62 |
+
return {}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
__all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]
|
dia2/audio/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .codec import MimiCodec, DEFAULT_MIMI_MODEL_ID, MimiConfig
|
| 2 |
+
from .grid import delay_frames, undelay_frames, mask_audio_logits, fill_audio_channels, write_wav
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"MimiCodec",
|
| 6 |
+
"DEFAULT_MIMI_MODEL_ID",
|
| 7 |
+
"MimiConfig",
|
| 8 |
+
"delay_frames",
|
| 9 |
+
"undelay_frames",
|
| 10 |
+
"mask_audio_logits",
|
| 11 |
+
"fill_audio_channels",
|
| 12 |
+
"write_wav",
|
| 13 |
+
]
|
dia2/audio/codec.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from transformers import MimiModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class MimiConfig:
|
| 16 |
+
model_id: str = DEFAULT_MIMI_MODEL_ID
|
| 17 |
+
dtype: Optional[torch.dtype] = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MimiCodec(nn.Module):
|
| 21 |
+
"""Thin wrapper around transformers' MimiModel for decoding audio tokens."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, model: MimiModel, device: torch.device) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.model = model
|
| 26 |
+
self.device = device
|
| 27 |
+
cfg = getattr(model, "config", None)
|
| 28 |
+
self.sample_rate = getattr(cfg, "sampling_rate", 24000)
|
| 29 |
+
self.frame_rate = getattr(cfg, "frame_rate", 12.5)
|
| 30 |
+
self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def from_pretrained(
|
| 34 |
+
cls,
|
| 35 |
+
model_id: str = DEFAULT_MIMI_MODEL_ID,
|
| 36 |
+
*,
|
| 37 |
+
device: torch.device,
|
| 38 |
+
dtype: Optional[torch.dtype] = None,
|
| 39 |
+
) -> "MimiCodec":
|
| 40 |
+
model = MimiModel.from_pretrained(
|
| 41 |
+
model_id,
|
| 42 |
+
torch_dtype=dtype,
|
| 43 |
+
low_cpu_mem_usage=True,
|
| 44 |
+
)
|
| 45 |
+
model = model.to(device)
|
| 46 |
+
model.eval()
|
| 47 |
+
return cls(model, device)
|
| 48 |
+
|
| 49 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
codes = codes.to(self.device)
|
| 51 |
+
with torch.inference_mode():
|
| 52 |
+
audio, _ = self.model.decode(codes, return_dict=False)
|
| 53 |
+
return torch.clamp(audio, -1.0, 1.0)
|
| 54 |
+
|
| 55 |
+
def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
|
| 56 |
+
audio = audio.to(self.device)
|
| 57 |
+
with torch.inference_mode():
|
| 58 |
+
return self.model.encode(audio, return_dict=return_dict)
|
dia2/audio/grid.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Sequence
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
|
| 11 |
+
channels, total = aligned.shape
|
| 12 |
+
max_delay = max(delays) if delays else 0
|
| 13 |
+
out = aligned.new_full((channels, total + max_delay), pad_id)
|
| 14 |
+
for idx, delay in enumerate(delays):
|
| 15 |
+
out[idx, delay : delay + total] = aligned[idx]
|
| 16 |
+
return out
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
|
| 20 |
+
channels, total = delayed.shape
|
| 21 |
+
max_delay = max(delays) if delays else 0
|
| 22 |
+
target = max(0, total - max_delay)
|
| 23 |
+
out = delayed.new_full((channels, target), pad_id)
|
| 24 |
+
for idx, delay in enumerate(delays):
|
| 25 |
+
out[idx] = delayed[idx, delay : delay + target]
|
| 26 |
+
return out
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor:
|
| 30 |
+
if logits.shape[-1] == 0:
|
| 31 |
+
return logits
|
| 32 |
+
max_idx = logits.shape[-1] - 1
|
| 33 |
+
targets = [idx for idx in (pad_idx, bos_idx) if 0 <= idx <= max_idx]
|
| 34 |
+
if not targets:
|
| 35 |
+
return logits
|
| 36 |
+
masked = logits.clone()
|
| 37 |
+
neg_inf = torch.finfo(masked.dtype).min
|
| 38 |
+
for idx in targets:
|
| 39 |
+
masked[..., idx] = neg_inf
|
| 40 |
+
return masked
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def fill_audio_channels(
|
| 44 |
+
delays: Sequence[int],
|
| 45 |
+
constants,
|
| 46 |
+
step: int,
|
| 47 |
+
step_tokens: torch.Tensor,
|
| 48 |
+
audio_buf: torch.Tensor,
|
| 49 |
+
) -> None:
|
| 50 |
+
for cb, delay in enumerate(delays):
|
| 51 |
+
idx = step - delay
|
| 52 |
+
in_bounds = idx >= 0 and step < audio_buf.shape[-1]
|
| 53 |
+
if in_bounds:
|
| 54 |
+
step_tokens[:, 2 + cb, 0] = audio_buf[:, cb, step]
|
| 55 |
+
else:
|
| 56 |
+
step_tokens[:, 2 + cb, 0] = constants.audio_bos
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None:
|
| 60 |
+
path = Path(path)
|
| 61 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
audio = np.clip(audio, -1.0, 1.0)
|
| 63 |
+
pcm16 = (audio * 32767.0).astype(np.int16)
|
| 64 |
+
import wave
|
| 65 |
+
|
| 66 |
+
with wave.open(str(path), "wb") as handle:
|
| 67 |
+
handle.setnchannels(1)
|
| 68 |
+
handle.setsampwidth(2)
|
| 69 |
+
handle.setframerate(sample_rate)
|
| 70 |
+
handle.writeframes(pcm16.tobytes())
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
__all__ = [
|
| 74 |
+
"delay_frames",
|
| 75 |
+
"undelay_frames",
|
| 76 |
+
"mask_audio_logits",
|
| 77 |
+
"fill_audio_channels",
|
| 78 |
+
"write_wav",
|
| 79 |
+
]
|
dia2/cli.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .engine import Dia2
|
| 8 |
+
from .generation import (
|
| 9 |
+
build_generation_config,
|
| 10 |
+
load_script_text,
|
| 11 |
+
validate_generation_params,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main() -> None:
|
| 16 |
+
parser = argparse.ArgumentParser(description="Generate audio with Dia2")
|
| 17 |
+
parser.add_argument("--config", help="Path to config.json (overrides repo lookup)")
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--weights", help="Path to model.safetensors (overrides repo lookup)"
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--hf",
|
| 23 |
+
required=False,
|
| 24 |
+
help="Hugging Face repo id to download config/weights from (e.g. nari-labs/Dia2-2B)",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--input", default="input.txt", help="Script text file (default: input.txt)"
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument("output", help="Output WAV path")
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--device",
|
| 32 |
+
default=None,
|
| 33 |
+
help="Computation device (defaults to cuda if available, else cpu)",
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--dtype",
|
| 37 |
+
choices=["auto", "float32", "bfloat16"],
|
| 38 |
+
default="bfloat16",
|
| 39 |
+
help="Computation dtype (default: bfloat16)",
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument("--topk", type=int, default=50)
|
| 42 |
+
parser.add_argument("--temperature", type=float, default=0.8)
|
| 43 |
+
parser.add_argument("--cfg", type=float, default=1.0)
|
| 44 |
+
parser.add_argument("--tokenizer", help="Tokenizer repo or local path override")
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--mimi", help="Mimi repo id override (defaults to config/assets)"
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument("--prefix-speaker-1", help="Prefix audio file for speaker 1")
|
| 49 |
+
parser.add_argument("--prefix-speaker-2", help="Prefix audio file for speaker 2")
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--include-prefix",
|
| 52 |
+
action="store_true",
|
| 53 |
+
help="Keep prefix audio in the final waveform (default: trimmed)",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--verbose", action="store_true", help="Print generation progress logs"
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--cuda-graph",
|
| 60 |
+
action="store_true",
|
| 61 |
+
help="Run generation with CUDA graph capture",
|
| 62 |
+
)
|
| 63 |
+
args = parser.parse_args()
|
| 64 |
+
|
| 65 |
+
device = args.device
|
| 66 |
+
if device is None or device == "auto":
|
| 67 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 68 |
+
dtype = args.dtype or "bfloat16"
|
| 69 |
+
|
| 70 |
+
repo = args.hf
|
| 71 |
+
if repo:
|
| 72 |
+
dia = Dia2(
|
| 73 |
+
repo=repo,
|
| 74 |
+
device=device,
|
| 75 |
+
dtype=dtype,
|
| 76 |
+
tokenizer_id=args.tokenizer,
|
| 77 |
+
mimi_id=args.mimi,
|
| 78 |
+
)
|
| 79 |
+
elif args.config and args.weights:
|
| 80 |
+
dia = Dia2.from_local(
|
| 81 |
+
config_path=args.config,
|
| 82 |
+
weights_path=args.weights,
|
| 83 |
+
device=device,
|
| 84 |
+
dtype=dtype,
|
| 85 |
+
tokenizer_id=args.tokenizer,
|
| 86 |
+
mimi_id=args.mimi,
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
raise ValueError("Provide --hf/--variant or both --config and --weights")
|
| 90 |
+
|
| 91 |
+
script = load_script_text(args.input)
|
| 92 |
+
temperature, top_k, cfg_scale = validate_generation_params(
|
| 93 |
+
temperature=args.temperature,
|
| 94 |
+
top_k=args.topk,
|
| 95 |
+
cfg_scale=args.cfg,
|
| 96 |
+
)
|
| 97 |
+
config = build_generation_config(
|
| 98 |
+
temperature=temperature,
|
| 99 |
+
top_k=top_k,
|
| 100 |
+
cfg_scale=cfg_scale,
|
| 101 |
+
)
|
| 102 |
+
overrides = {}
|
| 103 |
+
if args.cuda_graph:
|
| 104 |
+
overrides["use_cuda_graph"] = True
|
| 105 |
+
if args.prefix_speaker_1:
|
| 106 |
+
overrides["prefix_speaker_1"] = args.prefix_speaker_1
|
| 107 |
+
if args.prefix_speaker_2:
|
| 108 |
+
overrides["prefix_speaker_2"] = args.prefix_speaker_2
|
| 109 |
+
if args.include_prefix:
|
| 110 |
+
overrides["include_prefix"] = True
|
| 111 |
+
|
| 112 |
+
dia.generate(
|
| 113 |
+
script,
|
| 114 |
+
config=config,
|
| 115 |
+
output_wav=args.output,
|
| 116 |
+
verbose=args.verbose,
|
| 117 |
+
**overrides,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
main()
|
dia2/config.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
|
| 10 |
+
class DataConfig:
|
| 11 |
+
channels: int
|
| 12 |
+
text_vocab_size: int
|
| 13 |
+
audio_vocab_size: int
|
| 14 |
+
action_vocab_size: int
|
| 15 |
+
text_pad_token_id: int
|
| 16 |
+
text_new_word_token_id: int
|
| 17 |
+
text_zero_token_id: int
|
| 18 |
+
audio_pad_token_id: int
|
| 19 |
+
audio_bos_token_id: int
|
| 20 |
+
action_pad_token_id: int
|
| 21 |
+
action_new_word_token_id: int
|
| 22 |
+
delay_pattern: List[int]
|
| 23 |
+
first_word_min_start: int
|
| 24 |
+
max_pad: int
|
| 25 |
+
second_stream_ahead: int
|
| 26 |
+
tokenizer_path: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class DecoderConfig:
|
| 31 |
+
n_layer: int
|
| 32 |
+
n_embd: int
|
| 33 |
+
n_hidden: int
|
| 34 |
+
gqa_query_heads: int
|
| 35 |
+
kv_heads: int
|
| 36 |
+
gqa_head_dim: int
|
| 37 |
+
dropout: float
|
| 38 |
+
low_rank_dim: int | None = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass(frozen=True)
|
| 42 |
+
class DepformerConfig:
|
| 43 |
+
n_layer: int
|
| 44 |
+
n_embd: int
|
| 45 |
+
n_hidden: int
|
| 46 |
+
gqa_query_heads: int
|
| 47 |
+
kv_heads: int
|
| 48 |
+
gqa_head_dim: int
|
| 49 |
+
apply_rope: bool
|
| 50 |
+
text_embedding: bool
|
| 51 |
+
mlp_activations: List[str]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass(frozen=True)
|
| 55 |
+
class LinearHeadConfig:
|
| 56 |
+
mlp_activations: List[str]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass(frozen=True)
|
| 60 |
+
class ModelConfig:
|
| 61 |
+
decoder: DecoderConfig
|
| 62 |
+
depformer: DepformerConfig
|
| 63 |
+
linear: LinearHeadConfig
|
| 64 |
+
dropout: float
|
| 65 |
+
rope_min_timescale: int
|
| 66 |
+
rope_max_timescale: int
|
| 67 |
+
normalization_layer_epsilon: float
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass(frozen=True)
|
| 71 |
+
class RuntimeConfig:
|
| 72 |
+
weights_schedule: List[int]
|
| 73 |
+
max_context_steps: int
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass(frozen=True)
|
| 77 |
+
class AssetsConfig:
|
| 78 |
+
tokenizer: Optional[str]
|
| 79 |
+
mimi: Optional[str]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass(frozen=True)
|
| 83 |
+
class DiaConfig:
|
| 84 |
+
data: DataConfig
|
| 85 |
+
model: ModelConfig
|
| 86 |
+
runtime: RuntimeConfig
|
| 87 |
+
assets: AssetsConfig
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
|
| 91 |
+
block = block or {}
|
| 92 |
+
weights_schedule = block.get("weights_schedule")
|
| 93 |
+
if weights_schedule is None:
|
| 94 |
+
audio_channels = max(0, data_cfg.channels - 2)
|
| 95 |
+
weights_schedule = list(range(max(audio_channels - 1, 0)))
|
| 96 |
+
max_context = block.get("max_context_steps", 1500)
|
| 97 |
+
return RuntimeConfig(
|
| 98 |
+
weights_schedule=list(weights_schedule),
|
| 99 |
+
max_context_steps=int(max_context),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_config(path: str | Path) -> DiaConfig:
|
| 104 |
+
cfg = json.loads(Path(path).read_text())
|
| 105 |
+
data = cfg["data"]
|
| 106 |
+
model = cfg["model"]
|
| 107 |
+
runtime_cfg_raw = cfg.get("runtime")
|
| 108 |
+
if runtime_cfg_raw is None:
|
| 109 |
+
raise ValueError(f"Config '{path}' is missing a runtime block")
|
| 110 |
+
|
| 111 |
+
decoder_cfg = DecoderConfig(
|
| 112 |
+
n_layer=model["decoder"]["n_layer"],
|
| 113 |
+
n_embd=model["decoder"]["n_embd"],
|
| 114 |
+
n_hidden=model["decoder"]["n_hidden"],
|
| 115 |
+
gqa_query_heads=model["decoder"]["gqa_query_heads"],
|
| 116 |
+
kv_heads=model["decoder"]["kv_heads"],
|
| 117 |
+
gqa_head_dim=model["decoder"]["gqa_head_dim"],
|
| 118 |
+
dropout=model.get("dropout", 0.0),
|
| 119 |
+
low_rank_dim=model["decoder"].get("low_rank_dim"),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
depformer_cfg = DepformerConfig(
|
| 123 |
+
n_layer=model["depformer"]["n_layer"],
|
| 124 |
+
n_embd=model["depformer"]["n_embd"],
|
| 125 |
+
n_hidden=model["depformer"]["n_hidden"],
|
| 126 |
+
gqa_query_heads=model["depformer"]["gqa_query_heads"],
|
| 127 |
+
kv_heads=model["depformer"]["kv_heads"],
|
| 128 |
+
gqa_head_dim=model["depformer"]["gqa_head_dim"],
|
| 129 |
+
apply_rope=model["depformer"].get("apply_rope", True),
|
| 130 |
+
text_embedding=model["depformer"].get("text_embedding", True),
|
| 131 |
+
mlp_activations=model["depformer"].get("mlp_activations", ["silu", "linear"]),
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
data_cfg = DataConfig(
|
| 135 |
+
channels=data["channels"],
|
| 136 |
+
text_vocab_size=data["text_vocab_size"],
|
| 137 |
+
audio_vocab_size=data["audio_vocab_size"],
|
| 138 |
+
action_vocab_size=data["action_vocab_size"],
|
| 139 |
+
text_pad_token_id=data["text_pad_token_id"],
|
| 140 |
+
text_new_word_token_id=data["text_new_word_token_id"],
|
| 141 |
+
text_zero_token_id=data.get("text_zero_token_id", 7),
|
| 142 |
+
audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
|
| 143 |
+
audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
|
| 144 |
+
action_pad_token_id=data["action_pad_token_id"],
|
| 145 |
+
action_new_word_token_id=data["action_new_word_token_id"],
|
| 146 |
+
delay_pattern=list(data.get("delay_pattern", [])),
|
| 147 |
+
first_word_min_start=data.get("first_word_min_start", 0),
|
| 148 |
+
max_pad=data.get("max_pad", 0),
|
| 149 |
+
second_stream_ahead=data.get("second_stream_ahead", 0),
|
| 150 |
+
tokenizer_path=data.get("tokenizer_path"),
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)
|
| 154 |
+
|
| 155 |
+
linear_cfg = LinearHeadConfig(
|
| 156 |
+
mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
model_cfg = ModelConfig(
|
| 160 |
+
decoder=decoder_cfg,
|
| 161 |
+
depformer=depformer_cfg,
|
| 162 |
+
linear=linear_cfg,
|
| 163 |
+
dropout=model.get("dropout", 0.0),
|
| 164 |
+
rope_min_timescale=model.get("rope_min_timescale", 1),
|
| 165 |
+
rope_max_timescale=model.get("rope_max_timescale", 10000),
|
| 166 |
+
normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
assets_raw = cfg.get("assets") or {}
|
| 170 |
+
assets_cfg = AssetsConfig(
|
| 171 |
+
tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
|
| 172 |
+
mimi=assets_raw.get("mimi"),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return DiaConfig(
|
| 176 |
+
data=data_cfg,
|
| 177 |
+
model=model_cfg,
|
| 178 |
+
runtime=runtime_cfg,
|
| 179 |
+
assets=assets_cfg,
|
| 180 |
+
)
|
dia2/core/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import Dia2Model, DecodeState
|
| 2 |
+
from .transformer import TransformerDecoder
|
| 3 |
+
from .depformer import Depformer
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"Dia2Model",
|
| 7 |
+
"DecodeState",
|
| 8 |
+
"TransformerDecoder",
|
| 9 |
+
"Depformer",
|
| 10 |
+
]
|
dia2/core/cache.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class CacheSlot:
|
| 11 |
+
keys: torch.Tensor
|
| 12 |
+
values: torch.Tensor
|
| 13 |
+
|
| 14 |
+
def __post_init__(self) -> None:
|
| 15 |
+
self.max_steps = self.keys.shape[2]
|
| 16 |
+
self.head_dim = self.keys.shape[3]
|
| 17 |
+
self.flat_heads = self.keys.shape[0] * self.keys.shape[1]
|
| 18 |
+
device = self.keys.device
|
| 19 |
+
self.length = torch.zeros((), dtype=torch.long, device=device)
|
| 20 |
+
self.positions = torch.arange(self.max_steps, dtype=torch.long, device=device)
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def allocate(
|
| 24 |
+
cls,
|
| 25 |
+
*,
|
| 26 |
+
batch_size: int,
|
| 27 |
+
heads: int,
|
| 28 |
+
max_steps: int,
|
| 29 |
+
head_dim: int,
|
| 30 |
+
device: torch.device,
|
| 31 |
+
dtype: torch.dtype,
|
| 32 |
+
) -> "CacheSlot":
|
| 33 |
+
keys = torch.zeros(batch_size, heads, max_steps, head_dim, device=device, dtype=dtype)
|
| 34 |
+
values = torch.zeros_like(keys)
|
| 35 |
+
return cls(keys, values)
|
| 36 |
+
|
| 37 |
+
def reset(self) -> None:
|
| 38 |
+
self.length.zero_()
|
| 39 |
+
|
| 40 |
+
def write_and_view(
|
| 41 |
+
self,
|
| 42 |
+
key_chunk: torch.Tensor,
|
| 43 |
+
value_chunk: torch.Tensor,
|
| 44 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 45 |
+
step = key_chunk.shape[2]
|
| 46 |
+
start = self.length
|
| 47 |
+
indices = self.positions[:step] + start
|
| 48 |
+
expanded = indices.unsqueeze(0).expand(self.flat_heads, -1)
|
| 49 |
+
|
| 50 |
+
flat_keys = self.keys.view(self.flat_heads, self.max_steps, self.head_dim)
|
| 51 |
+
flat_values = self.values.view(self.flat_heads, self.max_steps, self.head_dim)
|
| 52 |
+
flat_key_chunk = key_chunk.reshape(self.flat_heads, step, self.head_dim)
|
| 53 |
+
flat_value_chunk = value_chunk.reshape(self.flat_heads, step, self.head_dim)
|
| 54 |
+
scatter_index = expanded.unsqueeze(-1).expand_as(flat_key_chunk)
|
| 55 |
+
flat_keys.scatter_(1, scatter_index, flat_key_chunk)
|
| 56 |
+
flat_values.scatter_(1, scatter_index, flat_value_chunk)
|
| 57 |
+
|
| 58 |
+
self.length.add_(step)
|
| 59 |
+
bool_mask = (self.positions >= self.length).view(1, 1, 1, self.max_steps)
|
| 60 |
+
mask_dtype = self.keys.dtype
|
| 61 |
+
mask_value = torch.finfo(mask_dtype).min
|
| 62 |
+
attn_mask = torch.zeros_like(bool_mask, dtype=mask_dtype)
|
| 63 |
+
attn_mask = attn_mask.masked_fill(bool_mask, mask_value)
|
| 64 |
+
return self.keys, self.values, attn_mask
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class KVCache:
|
| 68 |
+
def __init__(self, slots: List[CacheSlot]) -> None:
|
| 69 |
+
self.slots = slots
|
| 70 |
+
|
| 71 |
+
@classmethod
|
| 72 |
+
def allocate(
|
| 73 |
+
cls,
|
| 74 |
+
*,
|
| 75 |
+
num_layers: int,
|
| 76 |
+
batch_size: int,
|
| 77 |
+
heads: int,
|
| 78 |
+
max_steps: int,
|
| 79 |
+
head_dim: int,
|
| 80 |
+
device: torch.device,
|
| 81 |
+
dtype: torch.dtype,
|
| 82 |
+
) -> "KVCache":
|
| 83 |
+
slots = [
|
| 84 |
+
CacheSlot.allocate(
|
| 85 |
+
batch_size=batch_size,
|
| 86 |
+
heads=heads,
|
| 87 |
+
max_steps=max_steps,
|
| 88 |
+
head_dim=head_dim,
|
| 89 |
+
device=device,
|
| 90 |
+
dtype=dtype,
|
| 91 |
+
)
|
| 92 |
+
for _ in range(num_layers)
|
| 93 |
+
]
|
| 94 |
+
return cls(slots)
|
| 95 |
+
|
| 96 |
+
def get_slot(self, index: int) -> CacheSlot:
|
| 97 |
+
return self.slots[index]
|
| 98 |
+
|
| 99 |
+
def reset(self) -> None:
|
| 100 |
+
for slot in self.slots:
|
| 101 |
+
slot.reset()
|
| 102 |
+
|
| 103 |
+
clear = reset
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
__all__ = ["CacheSlot", "KVCache"]
|
dia2/core/depformer.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from ..config import DiaConfig
|
| 10 |
+
from .cache import KVCache
|
| 11 |
+
from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
|
| 12 |
+
from .precision import Precision
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ScheduleAttention(nn.Module):
|
| 16 |
+
"""Depformer attention that mirrors dia_v2 ScheduleAttention."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
|
| 19 |
+
super().__init__()
|
| 20 |
+
dep_cfg = config.model.depformer
|
| 21 |
+
runtime = config.runtime
|
| 22 |
+
self.schedule = runtime.weights_schedule
|
| 23 |
+
self.num_query_heads = dep_cfg.gqa_query_heads
|
| 24 |
+
self.num_kv_heads = dep_cfg.kv_heads
|
| 25 |
+
self.head_dim = dep_cfg.gqa_head_dim
|
| 26 |
+
self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
|
| 27 |
+
self.apply_rope = dep_cfg.apply_rope
|
| 28 |
+
self.used_ids = sorted(set(self.schedule))
|
| 29 |
+
self.compute_dtype = compute_dtype
|
| 30 |
+
|
| 31 |
+
self.in_proj = nn.ModuleDict(
|
| 32 |
+
{
|
| 33 |
+
str(i): nn.Linear(
|
| 34 |
+
dep_cfg.n_embd,
|
| 35 |
+
3 * self.num_query_heads * self.head_dim,
|
| 36 |
+
bias=False,
|
| 37 |
+
)
|
| 38 |
+
for i in self.used_ids
|
| 39 |
+
}
|
| 40 |
+
)
|
| 41 |
+
self.out_proj = nn.ModuleDict(
|
| 42 |
+
{
|
| 43 |
+
str(i): nn.Linear(
|
| 44 |
+
self.num_query_heads * self.head_dim,
|
| 45 |
+
dep_cfg.n_embd,
|
| 46 |
+
bias=False,
|
| 47 |
+
)
|
| 48 |
+
for i in self.used_ids
|
| 49 |
+
}
|
| 50 |
+
)
|
| 51 |
+
eps = config.model.normalization_layer_epsilon
|
| 52 |
+
self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
|
| 53 |
+
self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
|
| 54 |
+
|
| 55 |
+
if self.apply_rope:
|
| 56 |
+
self.rotary = RotaryEmbedding(
|
| 57 |
+
self.head_dim,
|
| 58 |
+
config.model.rope_min_timescale,
|
| 59 |
+
config.model.rope_max_timescale,
|
| 60 |
+
)
|
| 61 |
+
stage_count = max(len(self.schedule), 1)
|
| 62 |
+
self.register_buffer(
|
| 63 |
+
"stage_positions",
|
| 64 |
+
torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
|
| 65 |
+
persistent=False,
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
self.rotary = None
|
| 69 |
+
self.register_buffer(
|
| 70 |
+
"stage_positions",
|
| 71 |
+
torch.zeros(0, 1, dtype=torch.long),
|
| 72 |
+
persistent=False,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward_incremental(
|
| 76 |
+
self,
|
| 77 |
+
x_t: torch.Tensor,
|
| 78 |
+
stage_index: int,
|
| 79 |
+
cache_slot,
|
| 80 |
+
) -> Tuple[torch.Tensor, object]:
|
| 81 |
+
bsz, seq, _ = x_t.shape
|
| 82 |
+
if seq != 1:
|
| 83 |
+
raise ValueError("ScheduleAttention expects seq len 1 during decoding")
|
| 84 |
+
orig_dtype = x_t.dtype
|
| 85 |
+
module_index = self.schedule[stage_index]
|
| 86 |
+
proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
|
| 87 |
+
proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)
|
| 88 |
+
|
| 89 |
+
q_proj = self.q_norm(proj[:, :, 0])
|
| 90 |
+
k_proj = self.k_norm(proj[:, :, 1])
|
| 91 |
+
v_proj = proj[:, :, 2]
|
| 92 |
+
|
| 93 |
+
if self.apply_rope:
|
| 94 |
+
pos_ids = self.stage_positions[stage_index : stage_index + 1]
|
| 95 |
+
if pos_ids.device != x_t.device:
|
| 96 |
+
pos_ids = pos_ids.to(x_t.device)
|
| 97 |
+
q_proj = self.rotary(q_proj, pos_ids)
|
| 98 |
+
k_proj = self.rotary(k_proj, pos_ids)
|
| 99 |
+
|
| 100 |
+
q = q_proj.transpose(1, 2)
|
| 101 |
+
k = k_proj.transpose(1, 2)
|
| 102 |
+
v = v_proj.transpose(1, 2)
|
| 103 |
+
|
| 104 |
+
if cache_slot is not None:
|
| 105 |
+
k, v, attn_mask = cache_slot.write_and_view(k, v)
|
| 106 |
+
else:
|
| 107 |
+
attn_mask = None
|
| 108 |
+
|
| 109 |
+
attn = F.scaled_dot_product_attention(
|
| 110 |
+
q,
|
| 111 |
+
k,
|
| 112 |
+
v,
|
| 113 |
+
scale=1.0,
|
| 114 |
+
attn_mask=attn_mask,
|
| 115 |
+
enable_gqa=self.num_gqa_groups > 1,
|
| 116 |
+
)
|
| 117 |
+
attn = attn.transpose(1, 2).contiguous()
|
| 118 |
+
flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
|
| 119 |
+
out = self.out_proj[str(module_index)](flat.to(torch.float32))
|
| 120 |
+
return out.to(orig_dtype), cache_slot
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DepformerLayer(nn.Module):
|
| 124 |
+
def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
|
| 125 |
+
super().__init__()
|
| 126 |
+
dep_cfg = config.model.depformer
|
| 127 |
+
eps = config.model.normalization_layer_epsilon
|
| 128 |
+
self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
|
| 129 |
+
self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
|
| 130 |
+
self.self_attention = ScheduleAttention(config, compute_dtype)
|
| 131 |
+
self.mlp = Mlp(
|
| 132 |
+
dep_cfg.n_embd,
|
| 133 |
+
dep_cfg.n_hidden,
|
| 134 |
+
compute_dtype,
|
| 135 |
+
tuple(config.model.depformer.mlp_activations),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def decode_step(
|
| 139 |
+
self,
|
| 140 |
+
x_t: torch.Tensor,
|
| 141 |
+
stage_index: int,
|
| 142 |
+
cache_slot,
|
| 143 |
+
) -> Tuple[torch.Tensor, object]:
|
| 144 |
+
residual = x_t
|
| 145 |
+
x_norm = self.pre_norm(x_t)
|
| 146 |
+
sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
|
| 147 |
+
x = residual + sa_out
|
| 148 |
+
residual2 = x
|
| 149 |
+
x_norm2 = self.post_norm(x)
|
| 150 |
+
mlp_out = self.mlp(x_norm2)
|
| 151 |
+
return residual2 + mlp_out, cache_slot
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class Depformer(nn.Module):
|
| 155 |
+
def __init__(self, config: DiaConfig, precision: Precision):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.config = config
|
| 158 |
+
self.precision = precision
|
| 159 |
+
dep_cfg = config.model.depformer
|
| 160 |
+
data_cfg = config.data
|
| 161 |
+
runtime = config.runtime
|
| 162 |
+
|
| 163 |
+
self.num_audio_channels = max(0, data_cfg.channels - 2)
|
| 164 |
+
self.num_depth = max(self.num_audio_channels - 1, 0)
|
| 165 |
+
self.weights_schedule = runtime.weights_schedule
|
| 166 |
+
|
| 167 |
+
self.audio_embeds = nn.ModuleList(
|
| 168 |
+
[nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
|
| 169 |
+
)
|
| 170 |
+
if dep_cfg.text_embedding:
|
| 171 |
+
self.text_embed = MultiStreamEmbedding(
|
| 172 |
+
data_cfg.text_vocab_size,
|
| 173 |
+
dep_cfg.n_embd,
|
| 174 |
+
pad_id=data_cfg.text_pad_token_id,
|
| 175 |
+
output_dtype=precision.compute,
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
self.text_embed = None
|
| 179 |
+
|
| 180 |
+
used_ids = sorted(set(self.weights_schedule))
|
| 181 |
+
self.depformer_in = nn.ModuleDict(
|
| 182 |
+
{
|
| 183 |
+
str(i): nn.Linear(
|
| 184 |
+
config.model.decoder.n_embd,
|
| 185 |
+
dep_cfg.n_embd,
|
| 186 |
+
bias=False,
|
| 187 |
+
)
|
| 188 |
+
for i in used_ids
|
| 189 |
+
}
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.layers = nn.ModuleList([DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)])
|
| 193 |
+
self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
|
| 194 |
+
self.logits_dtype = precision.logits
|
| 195 |
+
self.logits = nn.ModuleList(
|
| 196 |
+
[
|
| 197 |
+
nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
|
| 198 |
+
for _ in range(self.num_depth)
|
| 199 |
+
]
|
| 200 |
+
)
|
| 201 |
+
self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)
|
| 202 |
+
|
| 203 |
+
def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
|
| 204 |
+
heads = self.layers[0].self_attention.num_kv_heads
|
| 205 |
+
head_dim = self.layers[0].self_attention.head_dim
|
| 206 |
+
return KVCache.allocate(
|
| 207 |
+
num_layers=len(self.layers),
|
| 208 |
+
batch_size=batch_size,
|
| 209 |
+
heads=heads,
|
| 210 |
+
max_steps=max_steps,
|
| 211 |
+
head_dim=head_dim,
|
| 212 |
+
device=device,
|
| 213 |
+
dtype=self.precision.compute,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def forward_step(
|
| 217 |
+
self,
|
| 218 |
+
prev_audio: torch.Tensor,
|
| 219 |
+
transformer_out: torch.Tensor,
|
| 220 |
+
stage_index: int,
|
| 221 |
+
cache: KVCache,
|
| 222 |
+
main_text: Optional[torch.Tensor],
|
| 223 |
+
second_text: Optional[torch.Tensor],
|
| 224 |
+
) -> Tuple[torch.Tensor, KVCache]:
|
| 225 |
+
self._validate_inputs(stage_index, cache)
|
| 226 |
+
return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)
|
| 227 |
+
|
| 228 |
+
def _forward_stage(
|
| 229 |
+
self,
|
| 230 |
+
stage_index: int,
|
| 231 |
+
prev_audio: torch.Tensor,
|
| 232 |
+
transformer_out: torch.Tensor,
|
| 233 |
+
cache: KVCache,
|
| 234 |
+
main_text: Optional[torch.Tensor],
|
| 235 |
+
second_text: Optional[torch.Tensor],
|
| 236 |
+
) -> Tuple[torch.Tensor, KVCache]:
|
| 237 |
+
prev_audio = prev_audio.long()
|
| 238 |
+
weight_idx = self.weights_schedule[stage_index]
|
| 239 |
+
token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
|
| 240 |
+
if stage_index == 0 and self.text_embed is not None:
|
| 241 |
+
if main_text is None or second_text is None:
|
| 242 |
+
raise ValueError("stage 0 requires text tokens")
|
| 243 |
+
token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])
|
| 244 |
+
|
| 245 |
+
dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
|
| 246 |
+
dep_in = dep_in.to(self.precision.compute)
|
| 247 |
+
dep_in = dep_in + token_emb.to(dep_in.dtype)
|
| 248 |
+
x = dep_in
|
| 249 |
+
for idx, layer in enumerate(self.layers):
|
| 250 |
+
slot = cache.get_slot(idx)
|
| 251 |
+
x, _ = layer.decode_step(x, stage_index, slot)
|
| 252 |
+
|
| 253 |
+
hidden = self.norm(x)
|
| 254 |
+
logits = self.logits[stage_index](hidden.to(torch.float32))
|
| 255 |
+
logits = logits.to(self.logits_dtype)
|
| 256 |
+
logits = logits.unsqueeze(1)
|
| 257 |
+
logits = logits[..., : self.audio_vocab_limit]
|
| 258 |
+
return logits, cache
|
| 259 |
+
|
| 260 |
+
def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
|
| 261 |
+
if stage_index < 0 or stage_index >= self.num_depth:
|
| 262 |
+
raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
|
| 263 |
+
if cache is None:
|
| 264 |
+
raise ValueError("depformer cache must be initialized")
|
dia2/core/layers.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Tuple, Union, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RotaryEmbedding(nn.Module):
|
| 13 |
+
def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
|
| 14 |
+
super().__init__()
|
| 15 |
+
if head_dim % 2 != 0:
|
| 16 |
+
raise ValueError("RoPE dimension must be even")
|
| 17 |
+
half_dim = head_dim // 2
|
| 18 |
+
fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
|
| 19 |
+
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
|
| 20 |
+
inv_freq = 1.0 / timescale
|
| 21 |
+
self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)
|
| 22 |
+
|
| 23 |
+
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
|
| 24 |
+
pos = position_ids.to(self.inv_freq.dtype)
|
| 25 |
+
freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
|
| 26 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 27 |
+
while emb.dim() < x.dim():
|
| 28 |
+
emb = emb.unsqueeze(-2)
|
| 29 |
+
cos = emb.cos().to(x.dtype)
|
| 30 |
+
sin = emb.sin().to(x.dtype)
|
| 31 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
| 32 |
+
rotated = torch.cat((-x2, x1), dim=-1)
|
| 33 |
+
return (x * cos) + (rotated * sin)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
x1 = x[..., ::2]
|
| 38 |
+
x2 = x[..., 1::2]
|
| 39 |
+
return torch.stack((-x2, x1), dim=-1).reshape_as(x)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _get_activation(name: str) -> nn.Module:
|
| 43 |
+
name = name.lower()
|
| 44 |
+
if name in ("silu", "swish", "swiglu"):
|
| 45 |
+
return nn.SiLU()
|
| 46 |
+
if name in ("gelu", "geglu"):
|
| 47 |
+
return nn.GELU()
|
| 48 |
+
if name == "relu":
|
| 49 |
+
return nn.ReLU()
|
| 50 |
+
if name == "linear":
|
| 51 |
+
return nn.Identity()
|
| 52 |
+
raise ValueError(f"Unsupported activation {name}")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class AttentionShape:
|
| 57 |
+
dim: int
|
| 58 |
+
heads: int
|
| 59 |
+
kv_heads: int
|
| 60 |
+
head_dim: int
|
| 61 |
+
rope_min: int
|
| 62 |
+
rope_max: int
|
| 63 |
+
apply_rope: bool
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Attention(nn.Module):
|
| 67 |
+
"""Byte-for-byte port of dia_v2 Attention.forward_incremental."""
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
|
| 70 |
+
super().__init__()
|
| 71 |
+
dec = config.model.decoder
|
| 72 |
+
self.num_query_heads = dec.gqa_query_heads
|
| 73 |
+
self.num_kv_heads = dec.kv_heads
|
| 74 |
+
self.head_dim = dec.gqa_head_dim
|
| 75 |
+
self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
|
| 76 |
+
self.compute_dtype = compute_dtype
|
| 77 |
+
self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
|
| 78 |
+
self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
|
| 79 |
+
self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
|
| 80 |
+
self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
|
| 81 |
+
eps = config.model.normalization_layer_epsilon
|
| 82 |
+
self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
|
| 83 |
+
self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
|
| 84 |
+
self.rotary = RotaryEmbedding(
|
| 85 |
+
self.head_dim,
|
| 86 |
+
config.model.rope_min_timescale,
|
| 87 |
+
config.model.rope_max_timescale,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def forward_incremental(
|
| 91 |
+
self,
|
| 92 |
+
x: torch.Tensor,
|
| 93 |
+
pos: Optional[torch.Tensor],
|
| 94 |
+
cache_slot,
|
| 95 |
+
) -> Tuple[torch.Tensor, object]:
|
| 96 |
+
B, T, _ = x.shape
|
| 97 |
+
if T != 1:
|
| 98 |
+
raise ValueError("Attention expects sequence length 1 during decoding")
|
| 99 |
+
orig_dtype = x.dtype
|
| 100 |
+
q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
|
| 101 |
+
k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
|
| 102 |
+
v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
|
| 103 |
+
q_proj = self.q_norm(q_proj)
|
| 104 |
+
k_proj = self.k_norm(k_proj)
|
| 105 |
+
if pos is not None:
|
| 106 |
+
q_proj = self.rotary(q_proj, pos)
|
| 107 |
+
k_proj = self.rotary(k_proj, pos)
|
| 108 |
+
q = q_proj.transpose(1, 2)
|
| 109 |
+
k = k_proj.transpose(1, 2)
|
| 110 |
+
v = v_proj.transpose(1, 2)
|
| 111 |
+
if cache_slot is not None:
|
| 112 |
+
k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
|
| 113 |
+
else:
|
| 114 |
+
k_cache, v_cache = k, v
|
| 115 |
+
attn_mask = None
|
| 116 |
+
attn = F.scaled_dot_product_attention(
|
| 117 |
+
q,
|
| 118 |
+
k_cache,
|
| 119 |
+
v_cache,
|
| 120 |
+
scale=1.0,
|
| 121 |
+
attn_mask=attn_mask,
|
| 122 |
+
enable_gqa=self.num_gqa_groups > 1,
|
| 123 |
+
)
|
| 124 |
+
attn = attn.transpose(1, 2).contiguous()
|
| 125 |
+
flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
|
| 126 |
+
out = self.o_proj(flat.to(torch.float32))
|
| 127 |
+
return out.to(orig_dtype), cache_slot
|
| 128 |
+
|
| 129 |
+
def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
|
| 130 |
+
proj = layer(x.to(torch.float32))
|
| 131 |
+
B, T, _ = proj.shape
|
| 132 |
+
proj = proj.view(B, T, heads, self.head_dim)
|
| 133 |
+
return proj.to(self.compute_dtype)
|
| 134 |
+
|
| 135 |
+
def forward(
|
| 136 |
+
self,
|
| 137 |
+
x: torch.Tensor,
|
| 138 |
+
positions: Optional[torch.Tensor],
|
| 139 |
+
cache=None,
|
| 140 |
+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 141 |
+
return self.forward_incremental(x, positions, cache)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class MultiStreamEmbedding(nn.Module):
|
| 146 |
+
"""Port of dia_v2 MultiStreamEmbed."""
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
vocab_size: int,
|
| 151 |
+
dim: int,
|
| 152 |
+
pad_id: int,
|
| 153 |
+
*,
|
| 154 |
+
output_dtype: torch.dtype,
|
| 155 |
+
low_rank_dim: Optional[int] = None,
|
| 156 |
+
) -> None:
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.pad_id = pad_id
|
| 159 |
+
self.dtype = output_dtype
|
| 160 |
+
base_dim = low_rank_dim if low_rank_dim is not None else dim
|
| 161 |
+
self.embedding = nn.Embedding(vocab_size, base_dim)
|
| 162 |
+
self.main_proj = nn.Linear(base_dim, dim, bias=False)
|
| 163 |
+
self.second_proj = nn.Linear(base_dim, dim, bias=False)
|
| 164 |
+
|
| 165 |
+
def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
|
| 166 |
+
main_inputs = main_inputs.long()
|
| 167 |
+
second_inputs = second_inputs.long()
|
| 168 |
+
if self.pad_id is not None:
|
| 169 |
+
second_is_pad = second_inputs == self.pad_id
|
| 170 |
+
else:
|
| 171 |
+
second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
|
| 172 |
+
use_second = ~second_is_pad
|
| 173 |
+
emb_main = self.embedding(main_inputs)
|
| 174 |
+
emb_second = self.embedding(second_inputs)
|
| 175 |
+
out_main = self.main_proj(emb_main.to(torch.float32))
|
| 176 |
+
out_second = self.second_proj(emb_second.to(torch.float32))
|
| 177 |
+
zeros = torch.zeros_like(out_second)
|
| 178 |
+
y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
|
| 179 |
+
target_dtype = self.dtype if self.dtype is not None else y.dtype
|
| 180 |
+
return y.to(target_dtype)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class Mlp(nn.Module):
|
| 184 |
+
"""Port of dia_v2 MlpBlock (two-activation gated MLP)."""
|
| 185 |
+
|
| 186 |
+
def __init__(
|
| 187 |
+
self,
|
| 188 |
+
dim: int,
|
| 189 |
+
hidden: int,
|
| 190 |
+
compute_dtype: torch.dtype,
|
| 191 |
+
activations: Sequence[str],
|
| 192 |
+
) -> None:
|
| 193 |
+
super().__init__()
|
| 194 |
+
if len(activations) != 2:
|
| 195 |
+
raise ValueError("Mlp expects two activation functions.")
|
| 196 |
+
self.dtype = compute_dtype
|
| 197 |
+
self.hidden = hidden
|
| 198 |
+
self.branch_count = len(activations)
|
| 199 |
+
self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
|
| 200 |
+
self.wo = nn.Linear(hidden, dim, bias=False)
|
| 201 |
+
self.activation_fns = [_get_activation(activations[0]), _get_activation(activations[1])]
|
| 202 |
+
|
| 203 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 204 |
+
proj = self.wi(x.to(torch.float32))
|
| 205 |
+
proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
|
| 206 |
+
gate, up = proj.unbind(dim=-2)
|
| 207 |
+
hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
|
| 208 |
+
out = self.wo(hidden.to(torch.float32))
|
| 209 |
+
return out.to(self.dtype)
|
dia2/core/model.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from ..config import DiaConfig
|
| 9 |
+
from .cache import KVCache
|
| 10 |
+
from .depformer import Depformer
|
| 11 |
+
from .precision import Precision
|
| 12 |
+
from .transformer import TransformerDecoder
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class DecodeState:
|
| 17 |
+
transformer: KVCache
|
| 18 |
+
depformer: KVCache
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Dia2Model(nn.Module):
|
| 22 |
+
def __init__(self, config: DiaConfig, precision: Precision):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.config = config
|
| 25 |
+
self.precision = precision
|
| 26 |
+
self.transformer = TransformerDecoder(config, precision)
|
| 27 |
+
self.depformer = Depformer(config, precision)
|
| 28 |
+
self._cast_norms_to_compute()
|
| 29 |
+
|
| 30 |
+
def init_state(self, batch_size: int, device: torch.device, max_steps: int) -> DecodeState:
|
| 31 |
+
transformer_cache = self.transformer.init_cache(batch_size, device, max_steps)
|
| 32 |
+
depformer_cache = self.depformer.init_cache(batch_size, device, self.depformer.num_depth)
|
| 33 |
+
return DecodeState(transformer_cache, depformer_cache)
|
| 34 |
+
|
| 35 |
+
def step_text(
|
| 36 |
+
self,
|
| 37 |
+
tokens: torch.Tensor,
|
| 38 |
+
positions: torch.Tensor,
|
| 39 |
+
state: DecodeState,
|
| 40 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 41 |
+
hidden, action, cb0, cache = self.transformer.forward_step(tokens, positions, state.transformer)
|
| 42 |
+
state.transformer = cache
|
| 43 |
+
return hidden, action, cb0
|
| 44 |
+
|
| 45 |
+
def step_audio_stage(
|
| 46 |
+
self,
|
| 47 |
+
stage_index: int,
|
| 48 |
+
prev_audio: torch.Tensor,
|
| 49 |
+
transformer_hidden: torch.Tensor,
|
| 50 |
+
state: DecodeState,
|
| 51 |
+
main_text: Optional[torch.Tensor],
|
| 52 |
+
second_text: Optional[torch.Tensor],
|
| 53 |
+
) -> torch.Tensor:
|
| 54 |
+
cache = state.depformer
|
| 55 |
+
logits, new_cache = self.depformer.forward_step(
|
| 56 |
+
prev_audio,
|
| 57 |
+
transformer_hidden,
|
| 58 |
+
stage_index,
|
| 59 |
+
cache,
|
| 60 |
+
main_text,
|
| 61 |
+
second_text,
|
| 62 |
+
)
|
| 63 |
+
state.depformer = new_cache
|
| 64 |
+
return logits
|
| 65 |
+
|
| 66 |
+
def _cast_norms_to_compute(self) -> None:
|
| 67 |
+
"""Cast RMSNorm weights/biases to the compute dtype to avoid bf16 warnings."""
|
| 68 |
+
def _convert(module: nn.Module) -> None:
|
| 69 |
+
if isinstance(module, nn.RMSNorm):
|
| 70 |
+
module.to(self.precision.compute)
|
| 71 |
+
|
| 72 |
+
self.apply(_convert)
|
dia2/core/precision.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
|
| 9 |
+
class Precision:
|
| 10 |
+
compute: torch.dtype
|
| 11 |
+
logits: torch.dtype
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def resolve_precision(kind: str | None, device: torch.device) -> Precision:
|
| 15 |
+
normalized = (kind or "auto").lower()
|
| 16 |
+
if normalized == "auto":
|
| 17 |
+
normalized = "bfloat16" if device.type == "cuda" else "float32"
|
| 18 |
+
if normalized == "bfloat16":
|
| 19 |
+
compute = torch.bfloat16 if device.type == "cuda" else torch.float32
|
| 20 |
+
return Precision(compute=compute, logits=torch.float32)
|
| 21 |
+
if normalized == "float32":
|
| 22 |
+
return Precision(compute=torch.float32, logits=torch.float32)
|
| 23 |
+
raise ValueError(f"Unsupported dtype '{kind}'")
|
dia2/core/transformer.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from ..config import DiaConfig
|
| 10 |
+
from .cache import KVCache
|
| 11 |
+
from .precision import Precision
|
| 12 |
+
from .layers import (
|
| 13 |
+
AttentionShape,
|
| 14 |
+
MultiStreamEmbedding,
|
| 15 |
+
Mlp,
|
| 16 |
+
Attention,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TransformerDecoder(nn.Module):
|
| 21 |
+
"""Inference-time port of dia_v2.model.Transformer."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, config: DiaConfig, precision: Precision):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.config = config
|
| 26 |
+
self.precision = precision
|
| 27 |
+
data_cfg = config.data
|
| 28 |
+
dec_cfg = config.model.decoder
|
| 29 |
+
|
| 30 |
+
self.audio_embeds = nn.ModuleList(
|
| 31 |
+
[
|
| 32 |
+
nn.Embedding(
|
| 33 |
+
data_cfg.audio_vocab_size,
|
| 34 |
+
dec_cfg.n_embd,
|
| 35 |
+
)
|
| 36 |
+
for _ in range(max(0, data_cfg.channels - 2))
|
| 37 |
+
]
|
| 38 |
+
)
|
| 39 |
+
self.text_embed = MultiStreamEmbedding(
|
| 40 |
+
data_cfg.text_vocab_size,
|
| 41 |
+
dec_cfg.n_embd,
|
| 42 |
+
pad_id=data_cfg.text_pad_token_id,
|
| 43 |
+
output_dtype=self.precision.compute,
|
| 44 |
+
low_rank_dim=dec_cfg.low_rank_dim,
|
| 45 |
+
)
|
| 46 |
+
self.layers = nn.ModuleList([DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)])
|
| 47 |
+
self.norm = nn.RMSNorm(dec_cfg.n_embd, eps=config.model.normalization_layer_epsilon, dtype=torch.float32)
|
| 48 |
+
|
| 49 |
+
self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
|
| 50 |
+
self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
|
| 51 |
+
|
| 52 |
+
def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
|
| 53 |
+
heads = self.layers[0].attn.num_kv_heads
|
| 54 |
+
head_dim = self.layers[0].attn.head_dim
|
| 55 |
+
return KVCache.allocate(
|
| 56 |
+
num_layers=len(self.layers),
|
| 57 |
+
batch_size=batch_size,
|
| 58 |
+
heads=heads,
|
| 59 |
+
max_steps=max_steps,
|
| 60 |
+
head_dim=head_dim,
|
| 61 |
+
device=device,
|
| 62 |
+
dtype=self.precision.compute,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def forward_step(
|
| 66 |
+
self,
|
| 67 |
+
tokens: torch.Tensor,
|
| 68 |
+
positions: torch.Tensor,
|
| 69 |
+
cache: KVCache,
|
| 70 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
|
| 71 |
+
if cache is None:
|
| 72 |
+
raise ValueError("Transformer cache must be initialized")
|
| 73 |
+
|
| 74 |
+
B, C, T1 = tokens.shape
|
| 75 |
+
if T1 != 1:
|
| 76 |
+
raise ValueError("forward_step expects sequence length 1")
|
| 77 |
+
num_audio_channels = max(0, C - 2)
|
| 78 |
+
|
| 79 |
+
hidden_t = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
|
| 80 |
+
for idx in range(num_audio_channels):
|
| 81 |
+
audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
|
| 82 |
+
hidden_t.add_(audio_emb)
|
| 83 |
+
hidden_t = hidden_t.to(self.precision.compute)
|
| 84 |
+
|
| 85 |
+
x = hidden_t
|
| 86 |
+
for idx, layer in enumerate(self.layers):
|
| 87 |
+
slot = cache.get_slot(idx)
|
| 88 |
+
x, _ = layer.decode_step(x, positions, slot)
|
| 89 |
+
|
| 90 |
+
hidden_norm = self.norm(x)
|
| 91 |
+
action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
|
| 92 |
+
cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
|
| 93 |
+
return hidden_norm, action_logits, cb0_logits, cache
|
| 94 |
+
|
| 95 |
+
def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
|
| 96 |
+
B, C, T1 = tokens.shape
|
| 97 |
+
if T1 != 1:
|
| 98 |
+
raise ValueError("_embed expects sequence length 1")
|
| 99 |
+
num_audio_channels = max(0, C - 2)
|
| 100 |
+
text_hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
|
| 101 |
+
audio_terms: list[torch.Tensor] = []
|
| 102 |
+
for idx in range(num_audio_channels):
|
| 103 |
+
audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
|
| 104 |
+
audio_terms.append(audio_emb)
|
| 105 |
+
hidden = text_hidden
|
| 106 |
+
for term in audio_terms:
|
| 107 |
+
hidden = hidden + term
|
| 108 |
+
final = hidden.to(self.precision.compute)
|
| 109 |
+
return final
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class DecoderLayer(nn.Module):
|
| 113 |
+
def __init__(self, config: DiaConfig, precision: Precision):
|
| 114 |
+
super().__init__()
|
| 115 |
+
dec = config.model.decoder
|
| 116 |
+
eps = config.model.normalization_layer_epsilon
|
| 117 |
+
self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
|
| 118 |
+
self.attn = Attention(config, dec.n_embd, precision.compute)
|
| 119 |
+
self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
|
| 120 |
+
self.mlp = Mlp(
|
| 121 |
+
dec.n_embd,
|
| 122 |
+
dec.n_hidden,
|
| 123 |
+
precision.compute,
|
| 124 |
+
tuple(config.model.linear.mlp_activations),
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def decode_step(
|
| 128 |
+
self,
|
| 129 |
+
x: torch.Tensor,
|
| 130 |
+
pos: torch.Tensor,
|
| 131 |
+
cache_slot,
|
| 132 |
+
) -> Tuple[torch.Tensor, object]:
|
| 133 |
+
residual = x
|
| 134 |
+
x_norm = self.pre_norm(x)
|
| 135 |
+
attn_out, _ = self.attn(x_norm, pos, cache_slot)
|
| 136 |
+
x = residual + attn_out
|
| 137 |
+
residual2 = x
|
| 138 |
+
x_norm2 = self.post_norm(x)
|
| 139 |
+
mlp_out = self.mlp(x_norm2)
|
| 140 |
+
return residual2 + mlp_out, cache_slot
|
dia2/engine.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Sequence
|
| 5 |
+
|
| 6 |
+
from .assets import resolve_assets
|
| 7 |
+
from .runtime.context import RuntimeContext, build_runtime
|
| 8 |
+
from .runtime.generator import (
|
| 9 |
+
build_initial_state,
|
| 10 |
+
decode_audio,
|
| 11 |
+
run_generation_loop,
|
| 12 |
+
warmup_with_prefix,
|
| 13 |
+
)
|
| 14 |
+
from .runtime.script_parser import parse_script
|
| 15 |
+
from .audio.grid import undelay_frames, write_wav
|
| 16 |
+
from .runtime.voice_clone import build_prefix_plan
|
| 17 |
+
from .generation import (
|
| 18 |
+
GenerationConfig,
|
| 19 |
+
GenerationResult,
|
| 20 |
+
merge_generation_config,
|
| 21 |
+
normalize_script,
|
| 22 |
+
)
|
| 23 |
+
from .runtime.logger import RuntimeLogger
|
| 24 |
+
|
| 25 |
+
class Dia2:
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
*,
|
| 29 |
+
repo: Optional[str] = None,
|
| 30 |
+
config_path: Optional[str | Path] = None,
|
| 31 |
+
weights_path: Optional[str | Path] = None,
|
| 32 |
+
tokenizer_id: Optional[str | Path] = None,
|
| 33 |
+
mimi_id: Optional[str] = None,
|
| 34 |
+
device: str = "cuda",
|
| 35 |
+
dtype: str = "auto",
|
| 36 |
+
default_config: Optional[GenerationConfig] = None,
|
| 37 |
+
) -> None:
|
| 38 |
+
bundle = resolve_assets(
|
| 39 |
+
repo=repo,
|
| 40 |
+
config_path=config_path,
|
| 41 |
+
weights_path=weights_path,
|
| 42 |
+
)
|
| 43 |
+
self._config_path = bundle.config_path
|
| 44 |
+
self._weights_path = bundle.weights_path
|
| 45 |
+
self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
|
| 46 |
+
self._repo_id = bundle.repo_id
|
| 47 |
+
self._mimi_id = mimi_id or bundle.mimi_id
|
| 48 |
+
self.device = device
|
| 49 |
+
self._dtype_pref = dtype or "auto"
|
| 50 |
+
self.default_config = default_config or GenerationConfig()
|
| 51 |
+
self._runtime: Optional[RuntimeContext] = None
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def from_repo(
|
| 55 |
+
cls,
|
| 56 |
+
repo: str,
|
| 57 |
+
*,
|
| 58 |
+
device: str = "cuda",
|
| 59 |
+
dtype: str = "auto",
|
| 60 |
+
tokenizer_id: Optional[str] = None,
|
| 61 |
+
mimi_id: Optional[str] = None,
|
| 62 |
+
) -> "Dia2":
|
| 63 |
+
return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id)
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def from_local(
|
| 67 |
+
cls,
|
| 68 |
+
config_path: str | Path,
|
| 69 |
+
weights_path: str | Path,
|
| 70 |
+
*,
|
| 71 |
+
device: str = "cuda",
|
| 72 |
+
dtype: str = "auto",
|
| 73 |
+
tokenizer_id: Optional[str | Path] = None,
|
| 74 |
+
mimi_id: Optional[str] = None,
|
| 75 |
+
) -> "Dia2":
|
| 76 |
+
return cls(
|
| 77 |
+
config_path=config_path,
|
| 78 |
+
weights_path=weights_path,
|
| 79 |
+
tokenizer_id=tokenizer_id,
|
| 80 |
+
device=device,
|
| 81 |
+
dtype=dtype,
|
| 82 |
+
mimi_id=mimi_id,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
|
| 86 |
+
desired_dtype = dtype or self._dtype_pref
|
| 87 |
+
if self.device == device and desired_dtype == self._dtype_pref:
|
| 88 |
+
return
|
| 89 |
+
self.device = device
|
| 90 |
+
self._dtype_pref = desired_dtype
|
| 91 |
+
self._runtime = None
|
| 92 |
+
|
| 93 |
+
def close(self) -> None:
|
| 94 |
+
self._runtime = None
|
| 95 |
+
|
| 96 |
+
def _ensure_runtime(self) -> RuntimeContext:
|
| 97 |
+
if self._runtime is None:
|
| 98 |
+
self._runtime = self._build_runtime()
|
| 99 |
+
return self._runtime
|
| 100 |
+
|
| 101 |
+
def generate(
|
| 102 |
+
self,
|
| 103 |
+
script: str | Sequence[str],
|
| 104 |
+
*,
|
| 105 |
+
config: Optional[GenerationConfig] = None,
|
| 106 |
+
output_wav: Optional[str | Path] = None,
|
| 107 |
+
prefix_speaker_1: Optional[str] = None,
|
| 108 |
+
prefix_speaker_2: Optional[str] = None,
|
| 109 |
+
include_prefix: Optional[bool] = None,
|
| 110 |
+
verbose: bool = False,
|
| 111 |
+
**overrides,
|
| 112 |
+
):
|
| 113 |
+
runtime = self._ensure_runtime()
|
| 114 |
+
logger = RuntimeLogger(verbose)
|
| 115 |
+
merged_overrides = dict(overrides)
|
| 116 |
+
if prefix_speaker_1 is not None:
|
| 117 |
+
merged_overrides["prefix_speaker_1"] = prefix_speaker_1
|
| 118 |
+
if prefix_speaker_2 is not None:
|
| 119 |
+
merged_overrides["prefix_speaker_2"] = prefix_speaker_2
|
| 120 |
+
if include_prefix is not None:
|
| 121 |
+
merged_overrides["include_prefix"] = include_prefix
|
| 122 |
+
merged = merge_generation_config(base=config or self.default_config, overrides=merged_overrides)
|
| 123 |
+
max_context = runtime.config.runtime.max_context_steps
|
| 124 |
+
text = normalize_script(script)
|
| 125 |
+
prefix_plan = build_prefix_plan(runtime, merged.prefix)
|
| 126 |
+
entries = []
|
| 127 |
+
if prefix_plan is not None:
|
| 128 |
+
entries.extend(prefix_plan.entries)
|
| 129 |
+
entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))
|
| 130 |
+
runtime.machine.initial_padding = merged.initial_padding
|
| 131 |
+
logger.event(
|
| 132 |
+
f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
|
| 133 |
+
f"device={self.device} dtype={self._dtype_pref}"
|
| 134 |
+
)
|
| 135 |
+
state = runtime.machine.new_state(entries)
|
| 136 |
+
cfg_active = merged.cfg_scale != 1.0
|
| 137 |
+
if cfg_active:
|
| 138 |
+
logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
|
| 139 |
+
else:
|
| 140 |
+
logger.event("classifier-free guidance disabled (scale=1.0)")
|
| 141 |
+
gen_state = build_initial_state(
|
| 142 |
+
runtime,
|
| 143 |
+
prefix=prefix_plan,
|
| 144 |
+
)
|
| 145 |
+
include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)
|
| 146 |
+
start_step = 0
|
| 147 |
+
if prefix_plan is not None:
|
| 148 |
+
logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
|
| 149 |
+
start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
|
| 150 |
+
if include_prefix_audio:
|
| 151 |
+
logger.event("prefix audio will be kept in output")
|
| 152 |
+
else:
|
| 153 |
+
logger.event("prefix audio trimmed from output")
|
| 154 |
+
first_word_frame, audio_buf = run_generation_loop(
|
| 155 |
+
runtime,
|
| 156 |
+
state=state,
|
| 157 |
+
generation=gen_state,
|
| 158 |
+
config=merged,
|
| 159 |
+
start_step=start_step,
|
| 160 |
+
logger=logger,
|
| 161 |
+
)
|
| 162 |
+
aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
|
| 163 |
+
crop = 0 if include_prefix_audio else max(first_word_frame, 0)
|
| 164 |
+
if crop > 0 and crop < aligned.shape[-1]:
|
| 165 |
+
aligned = aligned[:, :, crop:]
|
| 166 |
+
elif crop >= aligned.shape[-1]:
|
| 167 |
+
crop = 0
|
| 168 |
+
logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
|
| 169 |
+
waveform = decode_audio(runtime, aligned)
|
| 170 |
+
if output_wav is not None:
|
| 171 |
+
write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
|
| 172 |
+
duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
|
| 173 |
+
logger.event(f"saved {output_wav} ({duration:.2f}s)")
|
| 174 |
+
frame_rate = max(runtime.frame_rate, 1.0)
|
| 175 |
+
prefix_entry_count = len(prefix_plan.entries) if prefix_plan is not None else 0
|
| 176 |
+
transcript_entries = state.transcript
|
| 177 |
+
if prefix_plan is not None and not include_prefix_audio:
|
| 178 |
+
if len(transcript_entries) > prefix_entry_count:
|
| 179 |
+
transcript_entries = transcript_entries[prefix_entry_count:]
|
| 180 |
+
else:
|
| 181 |
+
transcript_entries = []
|
| 182 |
+
timestamps = []
|
| 183 |
+
for word, step in transcript_entries:
|
| 184 |
+
adj = step - crop
|
| 185 |
+
if adj < 0:
|
| 186 |
+
continue
|
| 187 |
+
timestamps.append((word, adj / frame_rate))
|
| 188 |
+
logger.event(f"generation finished in {logger.elapsed():.2f}s")
|
| 189 |
+
return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)
|
| 190 |
+
|
| 191 |
+
def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
|
| 192 |
+
return self.generate(script, output_wav=path, **kwargs)
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def sample_rate(self) -> int:
|
| 196 |
+
return self._ensure_runtime().mimi.sample_rate
|
| 197 |
+
|
| 198 |
+
@property
|
| 199 |
+
def tokenizer_id(self) -> Optional[str]:
|
| 200 |
+
if self._tokenizer_id:
|
| 201 |
+
return self._tokenizer_id
|
| 202 |
+
if self._runtime is not None:
|
| 203 |
+
return getattr(self._runtime.tokenizer, "name_or_path", None)
|
| 204 |
+
return self._repo_id
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def dtype(self) -> str:
|
| 208 |
+
return self._dtype_pref
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def max_context_steps(self) -> int:
|
| 212 |
+
return self._ensure_runtime().config.runtime.max_context_steps
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def repo(self) -> Optional[str]:
|
| 216 |
+
return self._repo_id
|
| 217 |
+
|
| 218 |
+
def _build_runtime(self) -> RuntimeContext:
|
| 219 |
+
runtime, tokenizer_ref, mimi_ref = build_runtime(
|
| 220 |
+
config_path=self._config_path,
|
| 221 |
+
weights_path=self._weights_path,
|
| 222 |
+
tokenizer_id=self._tokenizer_id,
|
| 223 |
+
repo_id=self._repo_id,
|
| 224 |
+
mimi_id=self._mimi_id,
|
| 225 |
+
device=self.device,
|
| 226 |
+
dtype_pref=self._dtype_pref,
|
| 227 |
+
)
|
| 228 |
+
self._tokenizer_id = tokenizer_ref
|
| 229 |
+
self._mimi_id = mimi_ref
|
| 230 |
+
return runtime
|
dia2/generation.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Mapping, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
|
| 12 |
+
class SamplingConfig:
|
| 13 |
+
temperature: float = 0.8
|
| 14 |
+
top_k: int = 50
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _default_text_sampling() -> SamplingConfig:
|
| 18 |
+
return SamplingConfig(temperature=0.6, top_k=50)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _default_audio_sampling() -> SamplingConfig:
|
| 22 |
+
return SamplingConfig(temperature=0.8, top_k=50)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
|
| 26 |
+
class PrefixConfig:
|
| 27 |
+
speaker_1: Optional[str] = None
|
| 28 |
+
speaker_2: Optional[str] = None
|
| 29 |
+
include_audio: bool = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass(frozen=True)
|
| 33 |
+
class GenerationConfig:
|
| 34 |
+
text: SamplingConfig = field(default_factory=_default_text_sampling)
|
| 35 |
+
audio: SamplingConfig = field(default_factory=_default_audio_sampling)
|
| 36 |
+
cfg_scale: float = 2.0
|
| 37 |
+
cfg_filter_k: int = 50
|
| 38 |
+
initial_padding: int = 2
|
| 39 |
+
prefix: Optional["PrefixConfig"] = None
|
| 40 |
+
use_cuda_graph: bool = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass(frozen=True)
|
| 44 |
+
class GenerationResult:
|
| 45 |
+
audio_tokens: torch.Tensor
|
| 46 |
+
waveform: torch.Tensor
|
| 47 |
+
sample_rate: int
|
| 48 |
+
timestamps: List[Tuple[str, float]]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_script(script: str | Sequence[str]) -> str:
|
| 52 |
+
if isinstance(script, str):
|
| 53 |
+
return script.strip()
|
| 54 |
+
return "\n".join(line.strip() for line in script)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_script_text(path: str | Path) -> str:
|
| 58 |
+
if path == "-":
|
| 59 |
+
return sys.stdin.read().strip()
|
| 60 |
+
path_obj = Path(path)
|
| 61 |
+
if path_obj.exists():
|
| 62 |
+
return path_obj.read_text().strip()
|
| 63 |
+
return str(path).strip()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def validate_generation_params(
|
| 67 |
+
*,
|
| 68 |
+
temperature: float,
|
| 69 |
+
top_k: int,
|
| 70 |
+
cfg_scale: float,
|
| 71 |
+
) -> tuple[float, int, float]:
|
| 72 |
+
if temperature <= 0:
|
| 73 |
+
raise ValueError("temperature must be positive")
|
| 74 |
+
if top_k <= 0:
|
| 75 |
+
raise ValueError("top_k must be positive")
|
| 76 |
+
if cfg_scale <= 0:
|
| 77 |
+
raise ValueError("cfg_scale must be positive")
|
| 78 |
+
return temperature, top_k, cfg_scale
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_generation_config(
|
| 82 |
+
*,
|
| 83 |
+
temperature: float,
|
| 84 |
+
top_k: int,
|
| 85 |
+
cfg_scale: float,
|
| 86 |
+
) -> GenerationConfig:
|
| 87 |
+
sampling = SamplingConfig(temperature=temperature, top_k=top_k)
|
| 88 |
+
return GenerationConfig(
|
| 89 |
+
text=sampling,
|
| 90 |
+
audio=sampling,
|
| 91 |
+
cfg_scale=cfg_scale,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def merge_generation_config(
|
| 96 |
+
*,
|
| 97 |
+
base: GenerationConfig,
|
| 98 |
+
overrides: Mapping[str, object],
|
| 99 |
+
) -> GenerationConfig:
|
| 100 |
+
clean_overrides = {k: v for k, v in overrides.items() if v is not None}
|
| 101 |
+
text_temp = clean_overrides.pop("temp_text", None)
|
| 102 |
+
text_topk = clean_overrides.pop("topk_text", None)
|
| 103 |
+
audio_temp = clean_overrides.pop("temp_audio", None)
|
| 104 |
+
audio_topk = clean_overrides.pop("topk_audio", None)
|
| 105 |
+
prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
|
| 106 |
+
prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
|
| 107 |
+
include_prefix = clean_overrides.pop("include_prefix", None)
|
| 108 |
+
|
| 109 |
+
text_sampling = base.text
|
| 110 |
+
if text_temp is not None or text_topk is not None:
|
| 111 |
+
text_sampling = SamplingConfig(
|
| 112 |
+
temperature=text_temp if text_temp is not None else text_sampling.temperature,
|
| 113 |
+
top_k=text_topk if text_topk is not None else text_sampling.top_k,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
audio_sampling = base.audio
|
| 117 |
+
if audio_temp is not None or audio_topk is not None:
|
| 118 |
+
audio_sampling = SamplingConfig(
|
| 119 |
+
temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
|
| 120 |
+
top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
prefix_cfg = base.prefix
|
| 124 |
+
if (
|
| 125 |
+
prefix_speaker_1 is not None
|
| 126 |
+
or prefix_speaker_2 is not None
|
| 127 |
+
or include_prefix is not None
|
| 128 |
+
or prefix_cfg is not None
|
| 129 |
+
):
|
| 130 |
+
prefix_cfg = prefix_cfg or PrefixConfig()
|
| 131 |
+
prefix_cfg = PrefixConfig(
|
| 132 |
+
speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
|
| 133 |
+
speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
|
| 134 |
+
include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
return GenerationConfig(
|
| 138 |
+
text=text_sampling,
|
| 139 |
+
audio=audio_sampling,
|
| 140 |
+
cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
|
| 141 |
+
cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
|
| 142 |
+
initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
|
| 143 |
+
prefix=prefix_cfg,
|
| 144 |
+
use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
__all__ = [
|
| 149 |
+
"SamplingConfig",
|
| 150 |
+
"GenerationConfig",
|
| 151 |
+
"GenerationResult",
|
| 152 |
+
"PrefixConfig",
|
| 153 |
+
"normalize_script",
|
| 154 |
+
"load_script_text",
|
| 155 |
+
"validate_generation_params",
|
| 156 |
+
"build_generation_config",
|
| 157 |
+
"merge_generation_config",
|
| 158 |
+
]
|
dia2/runtime/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .state_machine import Entry, StateMachine, TokenIds
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"Entry",
|
| 5 |
+
"StateMachine",
|
| 6 |
+
"TokenIds",
|
| 7 |
+
]
|
dia2/runtime/audio_io.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import sphn
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from ..audio import MimiCodec
|
| 12 |
+
|
| 13 |
+
PathLike = Union[str, Path]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
|
| 17 |
+
"""Read an audio file, convert to mono float32, and resample to target_sr."""
|
| 18 |
+
path = str(path)
|
| 19 |
+
try:
|
| 20 |
+
audio, sr = sphn.read_wav(path)
|
| 21 |
+
except Exception:
|
| 22 |
+
import soundfile as sf # Local fallback
|
| 23 |
+
|
| 24 |
+
audio, sr = sf.read(path, dtype="float32", always_2d=False)
|
| 25 |
+
audio = np.asarray(audio, dtype=np.float32)
|
| 26 |
+
if audio.ndim == 2:
|
| 27 |
+
audio = audio.mean(axis=1)
|
| 28 |
+
if sr != target_sr:
|
| 29 |
+
if hasattr(sphn, "resample_audio"):
|
| 30 |
+
audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
|
| 31 |
+
else:
|
| 32 |
+
audio = _resample_linear(audio, sr, target_sr)
|
| 33 |
+
return audio
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
|
| 37 |
+
"""Convert mono PCM samples into shape [1, 1, T] tensor."""
|
| 38 |
+
tensor = torch.from_numpy(audio).to(device)
|
| 39 |
+
if tensor.dim() == 1:
|
| 40 |
+
tensor = tensor.unsqueeze(0)
|
| 41 |
+
if tensor.dim() == 2:
|
| 42 |
+
tensor = tensor.unsqueeze(0)
|
| 43 |
+
return tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
|
| 47 |
+
"""Encode PCM audio into Mimi codebook tokens [C, T]."""
|
| 48 |
+
waveform = audio_to_tensor(audio, mimi.device)
|
| 49 |
+
with torch.inference_mode():
|
| 50 |
+
codes, *_ = mimi.encode(waveform, return_dict=False)
|
| 51 |
+
if isinstance(codes, (tuple, list)):
|
| 52 |
+
codes = codes[0]
|
| 53 |
+
# Mimi.encode returns [B, num_codebooks, T]; select batch 0.
|
| 54 |
+
codes = codes[0].to(torch.long)
|
| 55 |
+
return codes
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
|
| 59 |
+
if src_sr == dst_sr:
|
| 60 |
+
return audio.astype(np.float32)
|
| 61 |
+
length = audio.shape[0]
|
| 62 |
+
new_length = max(1, int(round(length * dst_sr / src_sr)))
|
| 63 |
+
tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0).unsqueeze(0)
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
|
| 66 |
+
return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
__all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]
|
dia2/runtime/context.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
| 11 |
+
|
| 12 |
+
from ..config import DiaConfig, load_config
|
| 13 |
+
from ..core.model import Dia2Model
|
| 14 |
+
from ..core.precision import Precision, resolve_precision
|
| 15 |
+
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
|
| 16 |
+
from .state_machine import StateMachine, TokenIds
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class RuntimeContext:
|
| 21 |
+
config: DiaConfig
|
| 22 |
+
model: Dia2Model
|
| 23 |
+
precision: Precision
|
| 24 |
+
tokenizer: PreTrainedTokenizerBase
|
| 25 |
+
mimi: MimiCodec
|
| 26 |
+
device: torch.device
|
| 27 |
+
machine: StateMachine
|
| 28 |
+
transformer_step: callable
|
| 29 |
+
depformer_step: callable
|
| 30 |
+
constants: TokenIds
|
| 31 |
+
audio_delays: list[int]
|
| 32 |
+
audio_delay_tensor: torch.Tensor
|
| 33 |
+
frame_rate: float
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_runtime(
|
| 37 |
+
*,
|
| 38 |
+
config_path: str | Path,
|
| 39 |
+
weights_path: str | Path,
|
| 40 |
+
tokenizer_id: Optional[str],
|
| 41 |
+
repo_id: Optional[str],
|
| 42 |
+
mimi_id: Optional[str],
|
| 43 |
+
device: str,
|
| 44 |
+
dtype_pref: str,
|
| 45 |
+
) -> tuple[RuntimeContext, str, str]:
|
| 46 |
+
device_obj = torch.device(device)
|
| 47 |
+
if device_obj.type == "cuda":
|
| 48 |
+
cuda_matmul = torch.backends.cuda.matmul
|
| 49 |
+
cudnn_conv = torch.backends.cudnn.conv
|
| 50 |
+
if hasattr(cuda_matmul, "fp32_precision"):
|
| 51 |
+
cuda_matmul.fp32_precision = "tf32"
|
| 52 |
+
with warnings.catch_warnings():
|
| 53 |
+
warnings.filterwarnings(
|
| 54 |
+
"ignore",
|
| 55 |
+
message="Please use the new API settings",
|
| 56 |
+
)
|
| 57 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 58 |
+
else: # pragma: no cover - compatibility with older PyTorch
|
| 59 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 60 |
+
if hasattr(cudnn_conv, "fp32_precision"):
|
| 61 |
+
cudnn_conv.fp32_precision = "tf32"
|
| 62 |
+
with warnings.catch_warnings():
|
| 63 |
+
warnings.filterwarnings(
|
| 64 |
+
"ignore",
|
| 65 |
+
message="Please use the new API settings",
|
| 66 |
+
)
|
| 67 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 68 |
+
else: # pragma: no cover
|
| 69 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 70 |
+
precision = resolve_precision(dtype_pref, device_obj)
|
| 71 |
+
config = load_config(config_path)
|
| 72 |
+
model = Dia2Model(config, precision)
|
| 73 |
+
state = load_file(str(weights_path))
|
| 74 |
+
model.load_state_dict(state)
|
| 75 |
+
model = model.to(device_obj)
|
| 76 |
+
|
| 77 |
+
tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
|
| 78 |
+
if tokenizer_ref is None:
|
| 79 |
+
raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
|
| 80 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 81 |
+
tokenizer_ref,
|
| 82 |
+
use_fast=False,
|
| 83 |
+
trust_remote_code=True,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
|
| 87 |
+
mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)
|
| 88 |
+
|
| 89 |
+
data_cfg = config.data
|
| 90 |
+
constants = TokenIds(
|
| 91 |
+
card=data_cfg.text_vocab_size,
|
| 92 |
+
new_word=data_cfg.text_new_word_token_id,
|
| 93 |
+
pad=data_cfg.text_pad_token_id,
|
| 94 |
+
bos=getattr(tokenizer, "bos_token_id", 1) or 1,
|
| 95 |
+
zero=data_cfg.text_zero_token_id,
|
| 96 |
+
spk1=tokenizer.convert_tokens_to_ids("[S1]") if "[S1]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
|
| 97 |
+
spk2=tokenizer.convert_tokens_to_ids("[S2]") if "[S2]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
|
| 98 |
+
audio_pad=data_cfg.audio_pad_token_id,
|
| 99 |
+
audio_bos=data_cfg.audio_bos_token_id,
|
| 100 |
+
)
|
| 101 |
+
machine = StateMachine(
|
| 102 |
+
token_ids=constants,
|
| 103 |
+
second_stream_ahead=data_cfg.second_stream_ahead,
|
| 104 |
+
max_padding=6,
|
| 105 |
+
initial_padding=0,
|
| 106 |
+
)
|
| 107 |
+
audio_delays = list(data_cfg.delay_pattern)
|
| 108 |
+
audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long) if audio_delays else torch.empty(0, dtype=torch.long, device=device_obj)
|
| 109 |
+
frame_rate = getattr(mimi, "frame_rate", 75.0)
|
| 110 |
+
|
| 111 |
+
runtime = RuntimeContext(
|
| 112 |
+
config=config,
|
| 113 |
+
precision=precision,
|
| 114 |
+
model=model,
|
| 115 |
+
tokenizer=tokenizer,
|
| 116 |
+
mimi=mimi,
|
| 117 |
+
device=device_obj,
|
| 118 |
+
machine=machine,
|
| 119 |
+
constants=constants,
|
| 120 |
+
audio_delays=audio_delays,
|
| 121 |
+
audio_delay_tensor=audio_delay_tensor,
|
| 122 |
+
frame_rate=frame_rate,
|
| 123 |
+
transformer_step=model.transformer.forward_step,
|
| 124 |
+
depformer_step=model.depformer.forward_step,
|
| 125 |
+
)
|
| 126 |
+
return runtime, tokenizer_ref, mimi_ref
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
__all__ = [
|
| 130 |
+
"RuntimeContext",
|
| 131 |
+
"build_runtime",
|
| 132 |
+
]
|
dia2/runtime/generator.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ..core.cache import KVCache
|
| 9 |
+
from ..core.model import DecodeState
|
| 10 |
+
from ..generation import GenerationConfig
|
| 11 |
+
from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
|
| 12 |
+
from .context import RuntimeContext
|
| 13 |
+
from .state_machine import State, TokenIds
|
| 14 |
+
from .guidance import apply_classifier_guidance, sample_audio_logits
|
| 15 |
+
from .sampler import sample_token
|
| 16 |
+
from .voice_clone import PrefixPlan
|
| 17 |
+
from .logger import RuntimeLogger
|
| 18 |
+
|
| 19 |
+
_GRAPH_CUBLAS_READY = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _ensure_graph_cublas_ready(device: torch.device) -> None:
|
| 23 |
+
global _GRAPH_CUBLAS_READY
|
| 24 |
+
if _GRAPH_CUBLAS_READY or device.type != "cuda":
|
| 25 |
+
return
|
| 26 |
+
tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
|
| 27 |
+
torch.matmul(tmp, tmp)
|
| 28 |
+
torch.cuda.synchronize()
|
| 29 |
+
_GRAPH_CUBLAS_READY = True
|
| 30 |
+
@dataclass
|
| 31 |
+
class GenerationState:
|
| 32 |
+
decode: DecodeState
|
| 33 |
+
step_tokens: torch.Tensor
|
| 34 |
+
audio_buf: torch.Tensor
|
| 35 |
+
|
| 36 |
+
def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
|
| 37 |
+
trimmed = self.audio_buf[:, :, :limit]
|
| 38 |
+
pad = torch.full_like(trimmed, pad_token)
|
| 39 |
+
trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
|
| 40 |
+
self.audio_buf = trimmed
|
| 41 |
+
return trimmed
|
| 42 |
+
|
| 43 |
+
@property
|
| 44 |
+
def transformer_cache(self) -> KVCache:
|
| 45 |
+
return self.decode.transformer
|
| 46 |
+
|
| 47 |
+
@transformer_cache.setter
|
| 48 |
+
def transformer_cache(self, cache: KVCache) -> None:
|
| 49 |
+
self.decode.transformer = cache
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def depformer_cache(self) -> KVCache:
|
| 53 |
+
return self.decode.depformer
|
| 54 |
+
|
| 55 |
+
@depformer_cache.setter
|
| 56 |
+
def depformer_cache(self, cache: KVCache) -> None:
|
| 57 |
+
self.decode.depformer = cache
|
| 58 |
+
|
| 59 |
+
def reset_dep_cache(self) -> None:
|
| 60 |
+
self.decode.depformer.reset()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class NetworkBuffers:
|
| 65 |
+
text: torch.Tensor
|
| 66 |
+
cb0: torch.Tensor
|
| 67 |
+
dep: list[torch.Tensor]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
|
| 71 |
+
device = runtime.device
|
| 72 |
+
logits_dtype = runtime.precision.logits
|
| 73 |
+
data_cfg = runtime.config.data
|
| 74 |
+
text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
|
| 75 |
+
cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
|
| 76 |
+
dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
|
| 77 |
+
dep_logits = [
|
| 78 |
+
torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
|
| 79 |
+
for _ in range(runtime.model.depformer.num_depth)
|
| 80 |
+
]
|
| 81 |
+
return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_initial_state(
|
| 85 |
+
runtime: RuntimeContext,
|
| 86 |
+
*,
|
| 87 |
+
prefix: PrefixPlan | None = None,
|
| 88 |
+
) -> GenerationState:
|
| 89 |
+
dep_q = runtime.model.depformer.num_audio_channels
|
| 90 |
+
channels = 2 + dep_q
|
| 91 |
+
branches = 2
|
| 92 |
+
token_ids = runtime.constants
|
| 93 |
+
step_tokens = torch.full(
|
| 94 |
+
(branches, channels, 1),
|
| 95 |
+
token_ids.pad,
|
| 96 |
+
dtype=torch.long,
|
| 97 |
+
device=runtime.device,
|
| 98 |
+
)
|
| 99 |
+
step_tokens[0, 0, 0] = token_ids.bos
|
| 100 |
+
step_tokens[0, 1, 0] = token_ids.pad
|
| 101 |
+
step_tokens[1, 0, 0] = token_ids.zero
|
| 102 |
+
step_tokens[1, 1, 0] = token_ids.pad
|
| 103 |
+
prefix_len = 0
|
| 104 |
+
if prefix is not None:
|
| 105 |
+
delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad)
|
| 106 |
+
prefix_len = delayed.shape[1]
|
| 107 |
+
limit = runtime.config.runtime.max_context_steps
|
| 108 |
+
total_steps = max(limit + prefix_len + 1, limit)
|
| 109 |
+
decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
|
| 110 |
+
audio_buf = torch.full(
|
| 111 |
+
(branches, dep_q, total_steps),
|
| 112 |
+
token_ids.ungenerated,
|
| 113 |
+
dtype=torch.long,
|
| 114 |
+
device=runtime.device,
|
| 115 |
+
)
|
| 116 |
+
if prefix is not None:
|
| 117 |
+
delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
|
| 118 |
+
audio_buf[0, :, : delayed.shape[1]] = delayed
|
| 119 |
+
if branches > 1:
|
| 120 |
+
audio_buf[1:, :, : delayed.shape[1]] = delayed
|
| 121 |
+
return GenerationState(decode_state, step_tokens, audio_buf)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _fill_audio_channels(
|
| 125 |
+
step_tokens: torch.Tensor,
|
| 126 |
+
audio_buf: torch.Tensor,
|
| 127 |
+
delays: torch.Tensor,
|
| 128 |
+
step: int,
|
| 129 |
+
bos_token: int,
|
| 130 |
+
) -> None:
|
| 131 |
+
channels = delays.numel()
|
| 132 |
+
if channels == 0:
|
| 133 |
+
return
|
| 134 |
+
target = step_tokens[:, 2 : 2 + channels, 0]
|
| 135 |
+
if step < audio_buf.shape[-1]:
|
| 136 |
+
target.copy_(audio_buf[:, :channels, step])
|
| 137 |
+
else:
|
| 138 |
+
target.fill_(bos_token)
|
| 139 |
+
mask = delays > step
|
| 140 |
+
if mask.any().item():
|
| 141 |
+
target[:, mask] = bos_token
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _execute_transformer_step(
|
| 145 |
+
step_tokens: torch.Tensor,
|
| 146 |
+
positions_view: torch.Tensor,
|
| 147 |
+
generation: GenerationState,
|
| 148 |
+
transformer_step,
|
| 149 |
+
buffers: NetworkBuffers,
|
| 150 |
+
) -> torch.Tensor:
|
| 151 |
+
hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
|
| 152 |
+
step_tokens,
|
| 153 |
+
positions_view,
|
| 154 |
+
generation.transformer_cache,
|
| 155 |
+
)
|
| 156 |
+
buffers.text.copy_(text_logits_t)
|
| 157 |
+
buffers.cb0.copy_(cb0_logits_t)
|
| 158 |
+
generation.transformer_cache = present
|
| 159 |
+
return hidden_t
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _execute_depformer_stage(
|
| 163 |
+
stage_index: int,
|
| 164 |
+
prev_audio: torch.Tensor,
|
| 165 |
+
hidden_t: torch.Tensor,
|
| 166 |
+
generation: GenerationState,
|
| 167 |
+
depformer_step,
|
| 168 |
+
main_tokens: Optional[torch.Tensor],
|
| 169 |
+
second_tokens: Optional[torch.Tensor],
|
| 170 |
+
buffers: NetworkBuffers,
|
| 171 |
+
) -> None:
|
| 172 |
+
logits_stage, dep_present = depformer_step(
|
| 173 |
+
prev_audio=prev_audio,
|
| 174 |
+
transformer_out=hidden_t,
|
| 175 |
+
stage_index=stage_index,
|
| 176 |
+
cache=generation.depformer_cache,
|
| 177 |
+
main_text=main_tokens if stage_index == 0 else None,
|
| 178 |
+
second_text=second_tokens if stage_index == 0 else None,
|
| 179 |
+
)
|
| 180 |
+
target = buffers.dep[stage_index]
|
| 181 |
+
if logits_stage.shape != target.shape:
|
| 182 |
+
raise RuntimeError(
|
| 183 |
+
f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
|
| 184 |
+
)
|
| 185 |
+
target.copy_(logits_stage)
|
| 186 |
+
generation.depformer_cache = dep_present
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def run_generation_loop(
|
| 192 |
+
runtime: RuntimeContext,
|
| 193 |
+
*,
|
| 194 |
+
state: State,
|
| 195 |
+
generation: GenerationState,
|
| 196 |
+
config: GenerationConfig,
|
| 197 |
+
start_step: int = 0,
|
| 198 |
+
logger: RuntimeLogger | None = None,
|
| 199 |
+
) -> tuple[Optional[int], torch.Tensor]:
|
| 200 |
+
step_tokens = generation.step_tokens
|
| 201 |
+
audio_buf = generation.audio_buf
|
| 202 |
+
branches = step_tokens.shape[0]
|
| 203 |
+
max_context = runtime.config.runtime.max_context_steps
|
| 204 |
+
if max_context <= 0:
|
| 205 |
+
raise ValueError("Runtime configuration must specify a positive max_context_steps")
|
| 206 |
+
positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
|
| 207 |
+
main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
|
| 208 |
+
aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
|
| 209 |
+
cfg_active = config.cfg_scale != 1.0
|
| 210 |
+
token_ids = runtime.constants
|
| 211 |
+
delay_tensor = runtime.audio_delay_tensor
|
| 212 |
+
max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
|
| 213 |
+
flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
|
| 214 |
+
first_word_frame: Optional[int] = None
|
| 215 |
+
eos_cutoff: Optional[int] = None
|
| 216 |
+
last_step = start_step - 1
|
| 217 |
+
use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
|
| 218 |
+
transformer_step = runtime.transformer_step
|
| 219 |
+
depformer_step = runtime.depformer_step
|
| 220 |
+
buffers = _allocate_network_buffers(runtime, branches)
|
| 221 |
+
positions_view = positions.expand(branches, -1)
|
| 222 |
+
transformer_capture = None
|
| 223 |
+
dep_captures: list[dict] | None = None
|
| 224 |
+
if use_graph:
|
| 225 |
+
_ensure_graph_cublas_ready(runtime.device)
|
| 226 |
+
processed_steps = 0
|
| 227 |
+
report_interval = 12
|
| 228 |
+
with torch.inference_mode():
|
| 229 |
+
for offset in range(max_context):
|
| 230 |
+
t = start_step + offset
|
| 231 |
+
if eos_cutoff is not None and t >= eos_cutoff:
|
| 232 |
+
break
|
| 233 |
+
if t + 1 >= audio_buf.shape[-1]:
|
| 234 |
+
break
|
| 235 |
+
generation.reset_dep_cache()
|
| 236 |
+
positions.fill_(t)
|
| 237 |
+
_fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
|
| 238 |
+
if branches > 1:
|
| 239 |
+
step_tokens[1:, 0, 0] = token_ids.zero
|
| 240 |
+
step_tokens[1:, 1, 0] = token_ids.pad
|
| 241 |
+
if use_graph:
|
| 242 |
+
if transformer_capture is None:
|
| 243 |
+
torch.cuda.synchronize()
|
| 244 |
+
graph = torch.cuda.CUDAGraph()
|
| 245 |
+
with torch.cuda.graph(graph):
|
| 246 |
+
hidden_ref = _execute_transformer_step(
|
| 247 |
+
step_tokens,
|
| 248 |
+
positions_view,
|
| 249 |
+
generation,
|
| 250 |
+
transformer_step,
|
| 251 |
+
buffers,
|
| 252 |
+
)
|
| 253 |
+
transformer_capture = (graph, hidden_ref)
|
| 254 |
+
if runtime.model.depformer.num_depth > 0:
|
| 255 |
+
dep_captures = []
|
| 256 |
+
for idx in range(runtime.model.depformer.num_depth):
|
| 257 |
+
capture = {
|
| 258 |
+
"graph": torch.cuda.CUDAGraph(),
|
| 259 |
+
"captured": False,
|
| 260 |
+
"prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
|
| 261 |
+
"main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
|
| 262 |
+
"second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
|
| 263 |
+
}
|
| 264 |
+
dep_captures.append(capture)
|
| 265 |
+
else:
|
| 266 |
+
transformer_capture[0].replay()
|
| 267 |
+
hidden_t = transformer_capture[1]
|
| 268 |
+
else:
|
| 269 |
+
hidden_t = _execute_transformer_step(
|
| 270 |
+
step_tokens,
|
| 271 |
+
positions_view,
|
| 272 |
+
generation,
|
| 273 |
+
transformer_step,
|
| 274 |
+
buffers,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
|
| 278 |
+
if guided_text.shape[0] > 1:
|
| 279 |
+
guided_text = guided_text[:1]
|
| 280 |
+
text_token = sample_token(
|
| 281 |
+
guided_text,
|
| 282 |
+
temp=config.text.temperature,
|
| 283 |
+
top_k=config.text.top_k,
|
| 284 |
+
).item()
|
| 285 |
+
|
| 286 |
+
main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
|
| 287 |
+
second_token = aux_token if aux_token != -1 else token_ids.pad
|
| 288 |
+
if first_word_frame is None and main_token == token_ids.new_word:
|
| 289 |
+
first_word_frame = t - config.initial_padding
|
| 290 |
+
step_tokens[:, 0, 0] = main_token
|
| 291 |
+
step_tokens[:, 1, 0] = second_token
|
| 292 |
+
|
| 293 |
+
guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
|
| 294 |
+
if guided_cb0.shape[0] > 1:
|
| 295 |
+
guided_cb0 = guided_cb0[:1]
|
| 296 |
+
masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
|
| 297 |
+
codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
|
| 298 |
+
audio_buf[:, 0, t + 1] = codebook_token
|
| 299 |
+
|
| 300 |
+
prev_audio = codebook_token.expand(branches)
|
| 301 |
+
main_tokens.fill_(main_token)
|
| 302 |
+
aux_tokens.fill_(second_token)
|
| 303 |
+
for stage in range(runtime.model.depformer.num_depth):
|
| 304 |
+
if use_graph and dep_captures is not None:
|
| 305 |
+
capture = dep_captures[stage]
|
| 306 |
+
capture["prev_audio"].copy_(prev_audio)
|
| 307 |
+
if capture["main_tokens"] is not None and stage == 0:
|
| 308 |
+
capture["main_tokens"].copy_(main_tokens)
|
| 309 |
+
capture["second_tokens"].copy_(aux_tokens)
|
| 310 |
+
if not capture["captured"]:
|
| 311 |
+
torch.cuda.synchronize()
|
| 312 |
+
with torch.cuda.graph(capture["graph"]):
|
| 313 |
+
_execute_depformer_stage(
|
| 314 |
+
stage_index=stage,
|
| 315 |
+
prev_audio=capture["prev_audio"],
|
| 316 |
+
hidden_t=hidden_t,
|
| 317 |
+
generation=generation,
|
| 318 |
+
depformer_step=depformer_step,
|
| 319 |
+
main_tokens=capture["main_tokens"],
|
| 320 |
+
second_tokens=capture["second_tokens"],
|
| 321 |
+
buffers=buffers,
|
| 322 |
+
)
|
| 323 |
+
capture["captured"] = True
|
| 324 |
+
else:
|
| 325 |
+
capture["graph"].replay()
|
| 326 |
+
else:
|
| 327 |
+
_execute_depformer_stage(
|
| 328 |
+
stage_index=stage,
|
| 329 |
+
prev_audio=prev_audio,
|
| 330 |
+
hidden_t=hidden_t,
|
| 331 |
+
generation=generation,
|
| 332 |
+
depformer_step=depformer_step,
|
| 333 |
+
main_tokens=main_tokens,
|
| 334 |
+
second_tokens=aux_tokens,
|
| 335 |
+
buffers=buffers,
|
| 336 |
+
)
|
| 337 |
+
dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
|
| 338 |
+
if dep_logits.shape[0] > 1:
|
| 339 |
+
dep_logits = dep_logits[:1]
|
| 340 |
+
stage_token = sample_audio_logits(
|
| 341 |
+
dep_logits,
|
| 342 |
+
config.audio.temperature,
|
| 343 |
+
config.audio.top_k,
|
| 344 |
+
)
|
| 345 |
+
audio_buf[:, stage + 1, t + 1] = stage_token
|
| 346 |
+
prev_audio = stage_token.expand(branches)
|
| 347 |
+
last_step = t
|
| 348 |
+
if eos_cutoff is None and state.end_step is not None:
|
| 349 |
+
eos_cutoff = state.end_step + flush_tail
|
| 350 |
+
processed_steps = offset + 1
|
| 351 |
+
if logger and processed_steps % report_interval == 0:
|
| 352 |
+
logger.progress(processed_steps, max_context)
|
| 353 |
+
|
| 354 |
+
if logger and processed_steps and processed_steps % report_interval != 0:
|
| 355 |
+
logger.progress(processed_steps, max_context)
|
| 356 |
+
|
| 357 |
+
if first_word_frame is None:
|
| 358 |
+
first_word_frame = start_step
|
| 359 |
+
if last_step < start_step:
|
| 360 |
+
limit = min(start_step + 1, audio_buf.shape[-1])
|
| 361 |
+
else:
|
| 362 |
+
limit = min(last_step + 2, audio_buf.shape[-1])
|
| 363 |
+
trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
|
| 364 |
+
return first_word_frame, trimmed
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
|
| 368 |
+
if tokens.shape[-1] == 0:
|
| 369 |
+
return torch.zeros(0, device=runtime.device)
|
| 370 |
+
with torch.inference_mode():
|
| 371 |
+
pcm = runtime.mimi.decode(tokens.to(runtime.device))
|
| 372 |
+
return pcm[0, 0]
|
| 373 |
+
|
| 374 |
+
def warmup_with_prefix(
|
| 375 |
+
runtime: RuntimeContext,
|
| 376 |
+
plan: PrefixPlan,
|
| 377 |
+
state: State,
|
| 378 |
+
generation: GenerationState,
|
| 379 |
+
) -> int:
|
| 380 |
+
step_tokens = generation.step_tokens
|
| 381 |
+
model_state = generation.decode
|
| 382 |
+
branches = step_tokens.shape[0]
|
| 383 |
+
device = runtime.device
|
| 384 |
+
tokens = plan.aligned_tokens.to(device)
|
| 385 |
+
new_word_steps = set(plan.new_word_steps)
|
| 386 |
+
positions = torch.empty(1, 1, dtype=torch.long, device=device)
|
| 387 |
+
|
| 388 |
+
with torch.inference_mode():
|
| 389 |
+
for t in range(plan.aligned_frames):
|
| 390 |
+
positions.fill_(t)
|
| 391 |
+
channels = tokens.shape[0]
|
| 392 |
+
for cb in range(channels):
|
| 393 |
+
delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
|
| 394 |
+
idx = t - delay
|
| 395 |
+
value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
|
| 396 |
+
step_tokens[:, 2 + cb, 0] = value
|
| 397 |
+
hidden, text_logits, cb0_logits, present = runtime.model.transformer.forward_step(
|
| 398 |
+
step_tokens,
|
| 399 |
+
positions.expand(branches, -1),
|
| 400 |
+
model_state.transformer,
|
| 401 |
+
)
|
| 402 |
+
model_state.transformer = present
|
| 403 |
+
|
| 404 |
+
forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
|
| 405 |
+
main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
|
| 406 |
+
second_token = runtime.constants.pad if aux_token == -1 else aux_token
|
| 407 |
+
step_tokens[0, 0, 0] = main_token
|
| 408 |
+
step_tokens[0, 1, 0] = second_token
|
| 409 |
+
if branches > 1:
|
| 410 |
+
step_tokens[1:, 0, 0] = runtime.constants.zero
|
| 411 |
+
step_tokens[1:, 1, 0] = runtime.constants.pad
|
| 412 |
+
|
| 413 |
+
return max(plan.aligned_frames - 1, 0)
|
| 414 |
+
__all__ = [
|
| 415 |
+
"build_initial_state",
|
| 416 |
+
"run_generation_loop",
|
| 417 |
+
"decode_audio",
|
| 418 |
+
"warmup_with_prefix",
|
| 419 |
+
"GenerationState",
|
| 420 |
+
]
|
dia2/runtime/guidance.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .sampler import sample_token
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def apply_classifier_guidance(
|
| 9 |
+
logits: torch.Tensor,
|
| 10 |
+
cfg_active: bool,
|
| 11 |
+
scale: float,
|
| 12 |
+
top_k: int,
|
| 13 |
+
) -> torch.Tensor:
|
| 14 |
+
if not cfg_active:
|
| 15 |
+
return logits
|
| 16 |
+
conditional = logits[0:1]
|
| 17 |
+
unconditional = logits[1:2]
|
| 18 |
+
cond32 = conditional.to(torch.float32)
|
| 19 |
+
uncond32 = unconditional.to(torch.float32)
|
| 20 |
+
guided = torch.lerp(uncond32, cond32, scale)
|
| 21 |
+
if top_k > 0 and guided.shape[-1] > 0:
|
| 22 |
+
k = min(top_k, guided.shape[-1])
|
| 23 |
+
threshold = torch.topk(guided, k=k, dim=-1, sorted=False).values[..., -1:]
|
| 24 |
+
mask = guided >= threshold
|
| 25 |
+
neg_inf = torch.full_like(cond32, float("-inf"))
|
| 26 |
+
cond32 = torch.where(mask, cond32, neg_inf)
|
| 27 |
+
return cond32.to(conditional.dtype)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
|
| 31 |
+
"""Sample a single audio token (shape [1]) from logits."""
|
| 32 |
+
return (
|
| 33 |
+
sample_token(
|
| 34 |
+
logits,
|
| 35 |
+
temp=temp,
|
| 36 |
+
top_k=top_k,
|
| 37 |
+
).view(1)
|
| 38 |
+
)
|
dia2/runtime/logger.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
class RuntimeLogger:
|
| 5 |
+
def __init__(self, enabled: bool) -> None:
|
| 6 |
+
self.enabled = enabled
|
| 7 |
+
self.start_time = time.perf_counter()
|
| 8 |
+
self.last_time = self.start_time
|
| 9 |
+
self.last_step = 0
|
| 10 |
+
|
| 11 |
+
def event(self, message: str) -> None:
|
| 12 |
+
if self.enabled:
|
| 13 |
+
print(f"[dia2] {message}")
|
| 14 |
+
|
| 15 |
+
def progress(self, step: int, total: Optional[int] = None) -> None:
|
| 16 |
+
if not self.enabled:
|
| 17 |
+
return
|
| 18 |
+
now = time.perf_counter()
|
| 19 |
+
delta_t = max(now - self.last_time, 1e-6)
|
| 20 |
+
delta_steps = max(step - self.last_step, 1)
|
| 21 |
+
speed = delta_steps / delta_t
|
| 22 |
+
if total is None:
|
| 23 |
+
self.event(f"step {step} :: {speed:.1f} toks/s")
|
| 24 |
+
else:
|
| 25 |
+
self.event(f"step {step}/{total} :: {speed:.1f} toks/s")
|
| 26 |
+
self.last_time = now
|
| 27 |
+
self.last_step = step
|
| 28 |
+
|
| 29 |
+
def elapsed(self) -> float:
|
| 30 |
+
return time.perf_counter() - self.start_time
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
__all__ = ["RuntimeLogger"]
|
dia2/runtime/sampler.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def sample_token(
|
| 7 |
+
logits: torch.Tensor,
|
| 8 |
+
*,
|
| 9 |
+
temp: float,
|
| 10 |
+
top_k: int = 0,
|
| 11 |
+
) -> torch.Tensor:
|
| 12 |
+
logits32 = logits.to(torch.float32)
|
| 13 |
+
if temp <= 0.0:
|
| 14 |
+
return torch.argmax(logits32, dim=-1, keepdim=True)
|
| 15 |
+
probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
|
| 16 |
+
probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 17 |
+
probs = torch.clamp_min(probs, 0.0)
|
| 18 |
+
flat = probs.reshape(-1, probs.shape[-1])
|
| 19 |
+
norm = flat.sum(dim=-1, keepdim=True)
|
| 20 |
+
zero_mask = norm <= 0
|
| 21 |
+
norm = norm.clamp_min(1e-12)
|
| 22 |
+
flat = flat / norm
|
| 23 |
+
if zero_mask.any():
|
| 24 |
+
filler = torch.zeros_like(flat)
|
| 25 |
+
filler[..., 0] = 1.0
|
| 26 |
+
mask = zero_mask.expand_as(flat)
|
| 27 |
+
flat = torch.where(mask, filler, flat)
|
| 28 |
+
vocab = flat.shape[-1]
|
| 29 |
+
if top_k > 0 and top_k < vocab:
|
| 30 |
+
topv, indices = torch.topk(flat, top_k, dim=-1)
|
| 31 |
+
topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
|
| 32 |
+
draws = torch.multinomial(topv, num_samples=1)
|
| 33 |
+
picks = torch.gather(indices, dim=-1, index=draws)
|
| 34 |
+
else:
|
| 35 |
+
picks = torch.multinomial(flat, num_samples=1)
|
| 36 |
+
picks = picks.reshape(*probs.shape[:-1], 1)
|
| 37 |
+
return picks
|
dia2/runtime/script_parser.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional, Sequence
|
| 5 |
+
|
| 6 |
+
from .state_machine import Entry
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_script(
|
| 10 |
+
script: Sequence[str],
|
| 11 |
+
tokenizer,
|
| 12 |
+
constants,
|
| 13 |
+
frame_rate: float,
|
| 14 |
+
) -> List[Entry]:
|
| 15 |
+
entries: List[Entry] = []
|
| 16 |
+
speaker_tokens = [constants.spk1, constants.spk2]
|
| 17 |
+
padding_between = 1
|
| 18 |
+
event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
|
| 19 |
+
last_speaker_idx = [None]
|
| 20 |
+
|
| 21 |
+
def add_entry(idx: int, word: str, *, pending: Optional[int], first_content: List[bool]):
|
| 22 |
+
tokens: List[int]
|
| 23 |
+
if pending is not None:
|
| 24 |
+
prefix = "[S1]" if pending == constants.spk1 else "[S2]"
|
| 25 |
+
tokens = tokenizer.encode(f"{prefix} {word}", add_special_tokens=False)
|
| 26 |
+
else:
|
| 27 |
+
tokens = tokenizer.encode(word, add_special_tokens=False)
|
| 28 |
+
if first_content[0]:
|
| 29 |
+
if speaker_tokens:
|
| 30 |
+
speaker_idx = idx % len(speaker_tokens)
|
| 31 |
+
speaker_token = speaker_tokens[speaker_idx]
|
| 32 |
+
if speaker_token is not None and last_speaker_idx[0] != speaker_idx:
|
| 33 |
+
if not tokens or tokens[0] != speaker_token:
|
| 34 |
+
tokens.insert(0, speaker_token)
|
| 35 |
+
last_speaker_idx[0] = speaker_idx
|
| 36 |
+
first_content[0] = False
|
| 37 |
+
padding = max(0, padding_between + len(tokens) - 1)
|
| 38 |
+
entries.append(Entry(tokens=tokens, text=word, padding=padding))
|
| 39 |
+
|
| 40 |
+
for idx, line in enumerate(script):
|
| 41 |
+
normalized = line.replace("’", "'").replace(":", " ")
|
| 42 |
+
remaining = normalized
|
| 43 |
+
first_content = [True]
|
| 44 |
+
pending_speaker: Optional[int] = None
|
| 45 |
+
while remaining:
|
| 46 |
+
match = event_re.search(remaining)
|
| 47 |
+
if match is None:
|
| 48 |
+
segment = remaining
|
| 49 |
+
remaining = ""
|
| 50 |
+
else:
|
| 51 |
+
segment = remaining[: match.start()]
|
| 52 |
+
remaining = remaining[match.end() :]
|
| 53 |
+
if segment:
|
| 54 |
+
for raw_word in segment.split():
|
| 55 |
+
if raw_word in ("[S1]", "[S2]"):
|
| 56 |
+
pending_speaker = (
|
| 57 |
+
constants.spk1 if raw_word == "[S1]" else constants.spk2
|
| 58 |
+
)
|
| 59 |
+
continue
|
| 60 |
+
add_entry(idx, raw_word, pending=pending_speaker, first_content=first_content)
|
| 61 |
+
pending_speaker = None
|
| 62 |
+
if match and match.group(1):
|
| 63 |
+
seconds = float(match.group(1))
|
| 64 |
+
padding = int(round(seconds * frame_rate))
|
| 65 |
+
if padding > 0:
|
| 66 |
+
entries.append(Entry(tokens=[], text="", padding=padding))
|
| 67 |
+
if remaining:
|
| 68 |
+
continue
|
| 69 |
+
return entries
|
dia2/runtime/state_machine.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import deque
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Deque, Iterable, List, Sequence, Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class TokenIds:
|
| 10 |
+
card: int
|
| 11 |
+
new_word: int
|
| 12 |
+
pad: int
|
| 13 |
+
bos: int
|
| 14 |
+
zero: int
|
| 15 |
+
spk1: int
|
| 16 |
+
spk2: int
|
| 17 |
+
audio_pad: int
|
| 18 |
+
audio_bos: int
|
| 19 |
+
ungenerated: int = -2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class Entry:
|
| 24 |
+
tokens: List[int]
|
| 25 |
+
text: str
|
| 26 |
+
padding: int = 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class State:
|
| 31 |
+
entries: Deque[Entry]
|
| 32 |
+
padding_budget: int
|
| 33 |
+
forced_padding: int
|
| 34 |
+
pending_tokens: Deque[int] = field(default_factory=deque)
|
| 35 |
+
lookahead_tokens: Deque[int] = field(default_factory=deque)
|
| 36 |
+
end_step: int | None = None
|
| 37 |
+
consumption_times: List[int] = field(default_factory=list)
|
| 38 |
+
transcript: List[Tuple[str, int]] = field(default_factory=list)
|
| 39 |
+
|
| 40 |
+
def peek_tokens(self, count: int) -> List[int]:
|
| 41 |
+
"""Return tokens from upcoming entries (used for second-stream lookahead)."""
|
| 42 |
+
assert count > 0
|
| 43 |
+
for entry in self.entries:
|
| 44 |
+
if entry.tokens:
|
| 45 |
+
count -= 1
|
| 46 |
+
if count == 0:
|
| 47 |
+
return entry.tokens
|
| 48 |
+
return []
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class StateMachine:
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
token_ids: TokenIds,
|
| 55 |
+
*,
|
| 56 |
+
second_stream_ahead: int = 0,
|
| 57 |
+
max_padding: int = 6,
|
| 58 |
+
initial_padding: int = 0,
|
| 59 |
+
) -> None:
|
| 60 |
+
self.token_ids = token_ids
|
| 61 |
+
self.second_stream_ahead = second_stream_ahead
|
| 62 |
+
self.max_padding = max_padding
|
| 63 |
+
self.initial_padding = initial_padding
|
| 64 |
+
|
| 65 |
+
def new_state(self, entries: Iterable[Entry]) -> State:
|
| 66 |
+
return State(
|
| 67 |
+
entries=deque(entries),
|
| 68 |
+
padding_budget=self.initial_padding,
|
| 69 |
+
forced_padding=self.initial_padding,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def process(
|
| 73 |
+
self,
|
| 74 |
+
step: int,
|
| 75 |
+
state: State,
|
| 76 |
+
token: int,
|
| 77 |
+
is_forced: bool = False,
|
| 78 |
+
) -> Tuple[int, int, bool]:
|
| 79 |
+
token = self._sanitize_token(token)
|
| 80 |
+
token = self._enforce_token_constraints(state, token, is_forced)
|
| 81 |
+
token, consumed_new_word = self._handle_new_word(step, state, token)
|
| 82 |
+
output_token = self._select_output_token(state, token)
|
| 83 |
+
final_main, final_second = self._maybe_multiplex_second_stream(
|
| 84 |
+
state, output_token
|
| 85 |
+
)
|
| 86 |
+
return final_main, final_second, consumed_new_word
|
| 87 |
+
|
| 88 |
+
def _sanitize_token(self, token: int) -> int:
|
| 89 |
+
if token == 1:
|
| 90 |
+
token = self.token_ids.new_word
|
| 91 |
+
elif token == 0:
|
| 92 |
+
token = self.token_ids.pad
|
| 93 |
+
if token not in (self.token_ids.new_word, self.token_ids.pad):
|
| 94 |
+
return self.token_ids.pad
|
| 95 |
+
return token
|
| 96 |
+
|
| 97 |
+
def _enforce_token_constraints(
|
| 98 |
+
self, state: State, token: int, is_forced: bool
|
| 99 |
+
) -> int:
|
| 100 |
+
if state.pending_tokens:
|
| 101 |
+
return self.token_ids.pad
|
| 102 |
+
if is_forced:
|
| 103 |
+
return token
|
| 104 |
+
if state.forced_padding > 0:
|
| 105 |
+
if token != self.token_ids.pad:
|
| 106 |
+
token = self.token_ids.pad
|
| 107 |
+
return token
|
| 108 |
+
if state.padding_budget <= 0 and token != self.token_ids.new_word:
|
| 109 |
+
return self.token_ids.new_word
|
| 110 |
+
return token
|
| 111 |
+
|
| 112 |
+
def _handle_new_word(
|
| 113 |
+
self, step: int, state: State, token: int
|
| 114 |
+
) -> Tuple[int, bool]:
|
| 115 |
+
if token != self.token_ids.new_word:
|
| 116 |
+
return token, False
|
| 117 |
+
if state.entries:
|
| 118 |
+
entry = state.entries.popleft()
|
| 119 |
+
state.consumption_times.append(step)
|
| 120 |
+
if entry.tokens:
|
| 121 |
+
state.transcript.append((entry.text, step))
|
| 122 |
+
state.pending_tokens.extend(entry.tokens)
|
| 123 |
+
if self.second_stream_ahead:
|
| 124 |
+
state.lookahead_tokens.extend(
|
| 125 |
+
state.peek_tokens(self.second_stream_ahead)
|
| 126 |
+
)
|
| 127 |
+
state.padding_budget = self.max_padding
|
| 128 |
+
else:
|
| 129 |
+
token = self.token_ids.pad
|
| 130 |
+
state.forced_padding = entry.padding
|
| 131 |
+
return token, True
|
| 132 |
+
token = self.token_ids.pad
|
| 133 |
+
if self.second_stream_ahead and state.end_step is None:
|
| 134 |
+
token = self.token_ids.new_word
|
| 135 |
+
if state.end_step is None:
|
| 136 |
+
state.end_step = step
|
| 137 |
+
return token, False
|
| 138 |
+
|
| 139 |
+
def _select_output_token(self, state: State, token: int) -> int:
|
| 140 |
+
if token == self.token_ids.pad:
|
| 141 |
+
if state.padding_budget > 0:
|
| 142 |
+
state.padding_budget -= 1
|
| 143 |
+
if state.forced_padding > 0:
|
| 144 |
+
state.forced_padding -= 1
|
| 145 |
+
if state.pending_tokens:
|
| 146 |
+
return state.pending_tokens.popleft()
|
| 147 |
+
return self.token_ids.pad
|
| 148 |
+
if token == self.token_ids.new_word:
|
| 149 |
+
return self.token_ids.new_word
|
| 150 |
+
if token == self.token_ids.zero:
|
| 151 |
+
return token
|
| 152 |
+
raise RuntimeError(f"Invalid token {token}")
|
| 153 |
+
|
| 154 |
+
def _maybe_multiplex_second_stream(
|
| 155 |
+
self, state: State, output: int
|
| 156 |
+
) -> Tuple[int, int]:
|
| 157 |
+
if not self.second_stream_ahead:
|
| 158 |
+
return output, output
|
| 159 |
+
second = -1
|
| 160 |
+
if output == self.token_ids.new_word:
|
| 161 |
+
second = self.token_ids.new_word
|
| 162 |
+
if state.pending_tokens:
|
| 163 |
+
output = state.pending_tokens.popleft()
|
| 164 |
+
else:
|
| 165 |
+
output = self.token_ids.pad
|
| 166 |
+
elif state.lookahead_tokens:
|
| 167 |
+
second = state.lookahead_tokens.popleft()
|
| 168 |
+
else:
|
| 169 |
+
second = self.token_ids.pad
|
| 170 |
+
return output, second
|
dia2/runtime/voice_clone.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from ..generation import PrefixConfig
|
| 10 |
+
from .audio_io import encode_audio_tokens, load_mono_audio
|
| 11 |
+
from .state_machine import Entry
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING: # pragma: no cover
|
| 14 |
+
from .context import RuntimeContext
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class WhisperWord:
|
| 19 |
+
text: str
|
| 20 |
+
start: float
|
| 21 |
+
end: float
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class PrefixPlan:
|
| 26 |
+
entries: List[Entry]
|
| 27 |
+
new_word_steps: List[int]
|
| 28 |
+
aligned_tokens: torch.Tensor
|
| 29 |
+
aligned_frames: int
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_prefix_plan(
|
| 33 |
+
runtime: "RuntimeContext",
|
| 34 |
+
prefix: Optional[PrefixConfig],
|
| 35 |
+
*,
|
| 36 |
+
transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
|
| 37 |
+
load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
|
| 38 |
+
encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
|
| 39 |
+
) -> Optional[PrefixPlan]:
|
| 40 |
+
if prefix is None:
|
| 41 |
+
return None
|
| 42 |
+
if not prefix.speaker_1:
|
| 43 |
+
if prefix.speaker_2:
|
| 44 |
+
raise ValueError("speaker_2 requires speaker_1 to be provided")
|
| 45 |
+
return None
|
| 46 |
+
|
| 47 |
+
transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
|
| 48 |
+
load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
|
| 49 |
+
encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))
|
| 50 |
+
|
| 51 |
+
entries1, steps1, tokens1 = _process_prefix_audio(
|
| 52 |
+
runtime=runtime,
|
| 53 |
+
audio_path=prefix.speaker_1,
|
| 54 |
+
speaker_token=runtime.constants.spk1,
|
| 55 |
+
transcribe=transcribe,
|
| 56 |
+
load_audio=load_audio,
|
| 57 |
+
encode_audio=encode_audio,
|
| 58 |
+
)
|
| 59 |
+
offset = 3 # Match legacy BOS/PAD offset
|
| 60 |
+
entries = list(entries1)
|
| 61 |
+
new_word_steps = [step + offset for step in steps1]
|
| 62 |
+
audio_tokens = tokens1.to(runtime.device)
|
| 63 |
+
|
| 64 |
+
if prefix.speaker_2:
|
| 65 |
+
entries2, steps2, tokens2 = _process_prefix_audio(
|
| 66 |
+
runtime=runtime,
|
| 67 |
+
audio_path=prefix.speaker_2,
|
| 68 |
+
speaker_token=runtime.constants.spk2,
|
| 69 |
+
transcribe=transcribe,
|
| 70 |
+
load_audio=load_audio,
|
| 71 |
+
encode_audio=encode_audio,
|
| 72 |
+
)
|
| 73 |
+
spk1_frames = audio_tokens.shape[-1]
|
| 74 |
+
new_word_steps.extend(step + spk1_frames for step in steps2)
|
| 75 |
+
entries.extend(entries2)
|
| 76 |
+
audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)
|
| 77 |
+
|
| 78 |
+
return PrefixPlan(
|
| 79 |
+
entries=entries,
|
| 80 |
+
new_word_steps=new_word_steps,
|
| 81 |
+
aligned_tokens=audio_tokens,
|
| 82 |
+
aligned_frames=audio_tokens.shape[-1],
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _process_prefix_audio(
|
| 87 |
+
runtime: "RuntimeContext",
|
| 88 |
+
audio_path: str,
|
| 89 |
+
speaker_token: int,
|
| 90 |
+
*,
|
| 91 |
+
transcribe: Callable[[str, torch.device], List[WhisperWord]],
|
| 92 |
+
load_audio: Callable[[str, int], np.ndarray],
|
| 93 |
+
encode_audio: Callable[[np.ndarray], torch.Tensor],
|
| 94 |
+
) -> tuple[List[Entry], List[int], torch.Tensor]:
|
| 95 |
+
words = transcribe(audio_path, runtime.device)
|
| 96 |
+
entries, steps = words_to_entries(
|
| 97 |
+
words=words,
|
| 98 |
+
tokenizer=runtime.tokenizer,
|
| 99 |
+
speaker_token=speaker_token,
|
| 100 |
+
frame_rate=runtime.frame_rate,
|
| 101 |
+
)
|
| 102 |
+
audio = load_audio(audio_path, runtime.mimi.sample_rate)
|
| 103 |
+
tokens = encode_audio(audio)
|
| 104 |
+
return entries, steps, tokens
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def transcribe_words(
|
| 108 |
+
audio_path: str,
|
| 109 |
+
device: torch.device,
|
| 110 |
+
language: Optional[str] = None,
|
| 111 |
+
) -> List[WhisperWord]:
|
| 112 |
+
import whisper_timestamped as wts # Imported lazily
|
| 113 |
+
|
| 114 |
+
model = wts.load_model("openai/whisper-large-v3", device=str(device))
|
| 115 |
+
result = wts.transcribe(model, audio_path, language=language)
|
| 116 |
+
|
| 117 |
+
words: List[WhisperWord] = []
|
| 118 |
+
for segment in result.get("segments", []):
|
| 119 |
+
for word in segment.get("words", []):
|
| 120 |
+
text = (word.get("text") or word.get("word") or "").strip()
|
| 121 |
+
if not text:
|
| 122 |
+
continue
|
| 123 |
+
words.append(
|
| 124 |
+
WhisperWord(
|
| 125 |
+
text=text,
|
| 126 |
+
start=float(word.get("start", 0.0)),
|
| 127 |
+
end=float(word.get("end", 0.0)),
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
return words
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def words_to_entries(
|
| 134 |
+
*,
|
| 135 |
+
words: Sequence[WhisperWord],
|
| 136 |
+
tokenizer,
|
| 137 |
+
speaker_token: int,
|
| 138 |
+
frame_rate: float,
|
| 139 |
+
) -> tuple[List[Entry], List[int]]:
|
| 140 |
+
entries: List[Entry] = []
|
| 141 |
+
new_word_steps: List[int] = []
|
| 142 |
+
if not words:
|
| 143 |
+
return entries, new_word_steps
|
| 144 |
+
|
| 145 |
+
convert = getattr(tokenizer, "convert_tokens_to_ids", None)
|
| 146 |
+
speaker_prefix: Optional[str] = None
|
| 147 |
+
if callable(convert):
|
| 148 |
+
s1_id = convert("[S1]")
|
| 149 |
+
s2_id = convert("[S2]")
|
| 150 |
+
if speaker_token == s1_id:
|
| 151 |
+
speaker_prefix = "[S1]"
|
| 152 |
+
elif speaker_token == s2_id:
|
| 153 |
+
speaker_prefix = "[S2]"
|
| 154 |
+
pending_prefix: Optional[str] = speaker_prefix
|
| 155 |
+
current_pos = 0
|
| 156 |
+
|
| 157 |
+
for idx, word in enumerate(words):
|
| 158 |
+
tokens = _encode_word(word.text, tokenizer, pending_prefix)
|
| 159 |
+
pending_prefix = None
|
| 160 |
+
start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
|
| 161 |
+
end_frame = start_frame + len(tokens)
|
| 162 |
+
new_word_steps.append(start_frame - 1)
|
| 163 |
+
|
| 164 |
+
if idx < len(words) - 1:
|
| 165 |
+
next_start = int(round(words[idx + 1].start * frame_rate))
|
| 166 |
+
next_word_start = max(end_frame + 1, next_start)
|
| 167 |
+
else:
|
| 168 |
+
end_time = int(round(words[-1].end * frame_rate))
|
| 169 |
+
next_word_start = max(end_frame + 1, end_time)
|
| 170 |
+
|
| 171 |
+
padding = max(0, next_word_start - start_frame - 1)
|
| 172 |
+
entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
|
| 173 |
+
current_pos = end_frame
|
| 174 |
+
|
| 175 |
+
return entries, new_word_steps
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
|
| 179 |
+
if prefix:
|
| 180 |
+
return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
|
| 181 |
+
return tokenizer.encode(text, add_special_tokens=False)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
__all__ = [
|
| 185 |
+
"PrefixPlan",
|
| 186 |
+
"WhisperWord",
|
| 187 |
+
"build_prefix_plan",
|
| 188 |
+
"transcribe_words",
|
| 189 |
+
"words_to_entries",
|
| 190 |
+
]
|
input.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[S1] Um, like, I don't know, I've never actually, like, been on a real vacation, you know? [S2] Oh, seriously? That's wild. I've, uh, only been on, like, one trip myself, and it was kinda stressful. [S1] Yeah, I always see people going places on, like, Instagram, but then I'm just, um, at home thinking, "Maybe next year." [S2] Honestly, same. I, like, plan stuff in my head but then forget or just, you know, bail at the last minute. [S1] So, we should, like, totally go somewhere together one day. [S2] For real, that would be awesome.
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=70.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "dia2"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Dia2 CUDA-only text-to-speech runtime"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
authors = [{ name = "Dia Contributors" }]
|
| 12 |
+
license = { file = "LICENSE" }
|
| 13 |
+
dependencies = [
|
| 14 |
+
"torch>=2.8.0",
|
| 15 |
+
"numpy>=2.1.0,<3.0",
|
| 16 |
+
"transformers>=4.55.3",
|
| 17 |
+
"safetensors==0.5.3",
|
| 18 |
+
"huggingface-hub>=0.24.7",
|
| 19 |
+
"sphn>=0.2.0",
|
| 20 |
+
"soundfile>=0.12.1",
|
| 21 |
+
"whisper-timestamped>=1.14.2",
|
| 22 |
+
"gradio>=4.44.1",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
dev = [
|
| 27 |
+
"ruff>=0.6.9",
|
| 28 |
+
"pyright>=1.1.385",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[tool.uv]
|
| 32 |
+
package = true
|
| 33 |
+
|
| 34 |
+
[tool.uv.sources]
|
| 35 |
+
torch = [
|
| 36 |
+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
[[tool.uv.index]]
|
| 40 |
+
name = "pytorch-cu128"
|
| 41 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 42 |
+
explicit = true
|
| 43 |
+
|
| 44 |
+
[tool.setuptools]
|
| 45 |
+
packages = ["dia2"]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|