MiniMind Max2 - Efficient MoE Language Model
- .gitignore +88 -0
- LICENSE +131 -0
- README.md +216 -0
- android/README.md +125 -0
- android/app/ChatScreen.kt +270 -0
- android/app/Mind2Model.kt +256 -0
- android/app/build.gradle +103 -0
- android/jni/CMakeLists.txt +53 -0
- android/jni/mind2_jni.cpp +301 -0
- config.json +59 -0
- configs/__init__.py +15 -0
- configs/model_config.py +154 -0
- examples/quickstart.py +94 -0
- model/__init__.py +52 -0
- model/components.py +274 -0
- model/mind2_model.py +185 -0
- optimization/__init__.py +10 -0
- optimization/export.py +365 -0
- optimization/pruning.py +346 -0
- optimization/quantization.py +311 -0
- pyproject.toml +108 -0
- requirements.txt +32 -0
- scripts/export.py +101 -0
- scripts/train.py +165 -0
- setup.py +91 -0
- training/__init__.py +6 -0
- training/dataset.py +128 -0
- training/distillation.py +320 -0
- training/trainer.py +274 -0
.gitignore
ADDED
@@ -0,0 +1,88 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db

# Project specific
outputs/
checkpoints/
*.pt
*.bin
*.safetensors
*.onnx
*.gguf
logs/
wandb/
data/
models/
LICENSE
ADDED
@@ -0,0 +1,131 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work.

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to the Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner.

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear.

You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

5. Submission of Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor.

7. Disclaimer of Warranty.

8. Limitation of Liability.

9. Accepting Warranty or Additional Liability.

Copyright 2024 MiniMind Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
@@ -0,0 +1,216 @@
---
license: apache-2.0
language:
- en
library_name: pytorch
tags:
- text-generation
- moe
- mixture-of-experts
- gqa
- grouped-query-attention
- edge-deployment
- mobile
- android
- efficient
- llama-cpp
pipeline_tag: text-generation
model-index:
- name: MiniMind-Max2
  results: []
---

# MiniMind Max2

**Tiny Model, Powerful Experience** - A lightweight, efficient language model designed for edge deployment, inspired by MiniMax M2's efficient activated-parameters design.

## Model Description

MiniMind Max2 is a family of efficient language models that leverage a Mixture of Experts (MoE) architecture to achieve high performance with minimal active parameters. Only 25% of parameters are activated per token, enabling deployment on resource-constrained devices such as smartphones, tablets, and IoT devices.

## Key Features

- **Efficient MoE Architecture**: Only 25% of parameters activated per token
- **Grouped Query Attention (GQA)**: 4:1 ratio for memory efficiency
- **Multiple Model Sizes**: From 500M (Nano) to 3B (Pro) parameters
- **Edge-Ready**: Runs on Android, iOS, and embedded devices
- **Easy Deployment**: Export to ONNX, GGUF (llama.cpp), TFLite

## Model Variants

| Model | Total Params | Active Params | Size (INT4) | Target Device |
|-------|--------------|---------------|-------------|---------------|
| **max2-nano** | 500M | 125M | ~300MB | Smartwatch, IoT |
| **max2-lite** | 1.5B | 375M | ~900MB | Mobile phones |
| **max2-pro** | 3B | 750M | ~1.8GB | Tablets, laptops |

## Quick Start

### Installation

```bash
# Clone from HuggingFace
git clone https://huggingface.co/fariasultana/MiniMind
cd MiniMind
pip install -r requirements.txt
```

### Basic Usage

```python
import torch
from model import create_model

# Create model (options: max2-nano, max2-lite, max2-pro)
model = create_model("max2-lite", device="cuda", dtype=torch.float16)

# Generate text (assumes a compatible tokenizer has already been loaded separately)
input_ids = tokenizer.encode("Hello, I am", return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0]))
```

### Using with Transformers (Custom)

```python
import torch
from configs.model_config import get_config
from model import Max2ForCausalLM

# Load configuration
config = get_config("max2-nano")

# Create model
model = Max2ForCausalLM(config)

# Forward pass
input_ids = torch.randint(0, config.vocab_size, (1, 32))
loss, logits, cache, aux_loss = model(input_ids, labels=input_ids)
```

## Training

```bash
# Standard training
python scripts/train.py \
    --model max2-lite \
    --train-data data/train.jsonl \
    --epochs 3 \
    --batch-size 8 \
    --output-dir outputs/

# Knowledge distillation from a larger teacher model
python scripts/train.py \
    --model max2-lite \
    --train-data data/train.jsonl \
    --teacher-model path/to/teacher.pt \
    --temperature 2.0 \
    --alpha-kd 0.5
```
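
The `--teacher-model`, `--temperature`, and `--alpha-kd` flags correspond to a standard soft-target distillation objective: a KL term between temperature-softened teacher and student distributions, blended with the usual cross-entropy on the labels. A minimal sketch is shown below; the repository's actual loss lives in `training/distillation.py`, and details such as masking, reduction, and exactly how `alpha-kd` weights the two terms are assumptions here.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature: float = 2.0, alpha_kd: float = 0.5):
    """Blend soft-target KL (scaled by T^2) with hard-label cross-entropy."""
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)), labels.view(-1)
    )
    return alpha_kd * kd + (1.0 - alpha_kd) * ce
```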

## Export for Deployment

```bash
# Export to ONNX and GGUF
python scripts/export.py \
    --model max2-lite \
    --checkpoint outputs/final/model.pt \
    --format onnx gguf \
    --quantize int4_awq

# Export for Android
python scripts/export.py \
    --model max2-nano \
    --format android \
    --quantize int4_awq
```

## Architecture Details

### Mixture of Experts (MoE)
- 8 experts with top-2 routing (25% of parameters active per token); see the routing sketch below
- Load-balancing auxiliary loss to keep expert utilization even
- Efficient sparse computation
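
The routing above can be pictured with a minimal PyTorch sketch: a linear gate scores the experts, each token keeps its top-2, and an auxiliary loss penalizes uneven expert usage. This is illustrative only; the repository's implementation is in `model/components.py`, and details such as weight renormalization and capacity handling are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top2Router(nn.Module):
    """Score n_experts per token, keep the top-2, and report a balance loss."""
    def __init__(self, hidden_size: int, n_experts: int = 8):
        super().__init__()
        self.gate = nn.Linear(hidden_size, n_experts, bias=False)
        self.n_experts = n_experts

    def forward(self, x):                                   # x: (tokens, hidden)
        probs = F.softmax(self.gate(x), dim=-1)             # (tokens, n_experts)
        weights, idx = probs.topk(2, dim=-1)                 # top-2 experts per token
        weights = weights / weights.sum(dim=-1, keepdim=True)
        # Load-balancing term: routed-token fraction times mean router prob per expert
        load = F.one_hot(idx, self.n_experts).float().sum(dim=1).mean(dim=0)
        importance = probs.mean(dim=0)
        aux_loss = self.n_experts * (load * importance).sum()
        return weights, idx, aux_loss
```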

### Grouped Query Attention (GQA)
- 4:1 ratio (4 query heads per KV head); see the expansion sketch below
- Reduced memory footprint for the KV cache
- Maintains quality with fewer parameters
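
Concretely, the KV cache stores one key/value head for every four query heads, and the stored heads are repeated at attention time. A rough sketch of that expansion is below (head counts are illustrative; the actual attention module is in `model/components.py`):

```python
import torch

def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, n_kv_heads, seq, head_dim) to n_kv_heads * n_rep heads."""
    b, h_kv, s, d = kv.shape
    return kv[:, :, None, :, :].expand(b, h_kv, n_rep, s, d).reshape(b, h_kv * n_rep, s, d)

# With a 4:1 ratio, e.g. 16 query heads share 4 stored KV heads:
k = torch.randn(1, 4, 128, 64)       # cache holds only 4 heads
k_for_attn = repeat_kv(k, n_rep=4)   # expanded to 16 heads for the attention scores
```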

### Core Optimizations
- **RMSNorm**: Faster than standard LayerNorm (no mean centering or bias); sketched below
- **SwiGLU**: Gated activation that improves feed-forward quality; sketched below
- **RoPE**: Rotary Position Embeddings for long context
- **Flash Attention**: Compatible with memory-efficient attention kernels
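
For reference, RMSNorm and SwiGLU are small components; the sketches below follow their common definitions, and the repository's versions in `model/components.py` may differ in details such as the epsilon value or hidden sizing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Scale by the RMS only: no mean subtraction, no bias, cheaper than LayerNorm."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward: silu(x W1) * (x W3), projected back with W2."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```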

## Project Structure

```
MiniMind/
├── configs/
│   └── model_config.py      # Model configurations
├── model/
│   ├── components.py        # RMSNorm, RoPE, GQA, MoE
│   └── mind2_model.py       # Main model implementation
├── training/
│   ├── trainer.py           # Training loop with AMP
│   ├── distillation.py      # Knowledge distillation
│   └── dataset.py           # Data loading utilities
├── optimization/
│   ├── quantization.py      # INT4/INT8 quantization
│   ├── pruning.py           # Structured/unstructured pruning
│   └── export.py            # ONNX/GGUF export
├── android/
│   ├── app/                 # Android app code
│   ├── jni/                 # Native JNI bridge
│   └── README.md            # Android deployment guide
├── examples/
│   └── quickstart.py        # Quick start example
└── scripts/
    ├── train.py             # Training script
    └── export.py            # Export script
```

## Performance Benchmarks

| Device | Model | Tokens/sec | Memory |
|--------|-------|------------|--------|
| RTX 4090 | max2-pro | 150+ | 4GB |
| M2 MacBook | max2-lite | 45 | 2GB |
| Pixel 8 Pro | max2-nano | 45 | 400MB |
| iPhone 15 Pro | max2-nano | 50 | 400MB |

## Android Deployment

See [android/README.md](android/README.md) for detailed Android deployment instructions.

Quick overview:
1. Export the model to GGUF format
2. Build llama.cpp for the Android NDK
3. Integrate with the provided Kotlin wrapper
4. Use the streaming API for a responsive UI

## Citation

```bibtex
@misc{minimind-max2,
  title={MiniMind Max2: Efficient Language Models for Edge Deployment},
  author={Faria Sultana},
  year={2024},
  url={https://huggingface.co/fariasultana/MiniMind}
}
```

## License

Apache 2.0

## Acknowledgments

- Inspired by [MiniMax M2](https://www.minimax.io/news/minimax-m2)'s efficient activated-parameters design
- Built with PyTorch and llama.cpp
- Thanks to the open-source AI community

---

**MiniMind Max2** - Bringing powerful AI to every device
android/README.md
ADDED
@@ -0,0 +1,125 @@
# MiniMind Android Deployment Guide

Deploy MiniMind (Mind2) models on Android devices using any of several runtime options.

## Deployment Options

| Runtime | Size | Speed | Ease of Use |
|---------|------|-------|-------------|
| **llama.cpp** | ★★★★★ | ★★★★☆ | ★★★★☆ |
| **ONNX Runtime** | ★★★★☆ | ★★★☆☆ | ★★★★★ |
| **MLC-LLM** | ★★★★☆ | ★★★★★ | ★★★☆☆ |
| **TensorFlow Lite** | ★★★★★ | ★★★☆☆ | ★★★★☆ |

## Quick Start

### Option 1: llama.cpp (Recommended)

```bash
# 1. Export model to GGUF format
python scripts/export_gguf.py --model mind2-lite --output models/mind2-lite.gguf

# 2. Build llama.cpp for Android
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
mkdir build-android && cd build-android
cmake .. -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-28
make -j

# 3. Copy to Android project
cp libllama.so ../android/app/src/main/jniLibs/arm64-v8a/
```

### Option 2: ONNX Runtime

```bash
# 1. Export model to ONNX
python scripts/export_onnx.py --model mind2-lite --output models/mind2-lite.onnx

# 2. Add ONNX Runtime to the Android project
# In app/build.gradle:
dependencies {
    implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.0'
}
```

### Option 3: MLC-LLM

```bash
# 1. Install MLC-LLM
pip install mlc-llm

# 2. Compile model for Android
mlc_llm compile mind2-lite --target android

# 3. Package for deployment
mlc_llm package mind2-lite --target android --output ./android/app/src/main/assets/
```

## Project Structure

```
android/
├── app/
│   ├── src/main/
│   │   ├── java/com/minimind/
│   │   │   ├── Mind2Model.java       # Model wrapper
│   │   │   ├── Mind2Tokenizer.java   # Tokenizer
│   │   │   └── Mind2Chat.java        # Chat interface
│   │   ├── jniLibs/
│   │   │   └── arm64-v8a/
│   │   │       └── libllama.so
│   │   └── assets/
│   │       ├── mind2-lite.gguf
│   │       └── tokenizer.json
│   └── build.gradle
├── jni/
│   ├── mind2_jni.cpp                 # JNI bridge
│   └── CMakeLists.txt
└── README.md
```

## Memory Requirements

| Model | RAM (INT4) | RAM (FP16) | Storage |
|-------|------------|------------|---------|
| mind2-nano | ~400MB | ~800MB | ~300MB |
| mind2-lite | ~1.2GB | ~2.4GB | ~900MB |
| mind2-pro | ~2.4GB | ~4.8GB | ~1.8GB |

## Performance Benchmarks

Tested on common Android devices:

| Device | Model | Tokens/sec |
|--------|-------|------------|
| Pixel 8 Pro | mind2-nano | 45 |
| Pixel 8 Pro | mind2-lite | 22 |
| Samsung S24 | mind2-nano | 52 |
| Samsung S24 | mind2-lite | 28 |

## Best Practices

1. **Use INT4 quantization** for the best size/performance balance
2. **Limit context length** to 512-1024 tokens on mobile
3. **Enable the KV cache** for faster generation
4. **Use streaming** for a responsive UI
5. **Handle memory pressure** gracefully

## Troubleshooting

### Out of Memory
- Use a smaller model (nano instead of lite)
- Reduce the context length
- Enable swap if available

### Slow Inference
- Check the CPU governor (set it to performance)
- Ensure NEON/ARM optimizations are enabled
- Consider GPU acceleration (MLC-LLM)

### Model Loading Failed
- Verify GGUF file integrity
- Check storage permissions
- Ensure there is enough free space
android/app/ChatScreen.kt
ADDED
@@ -0,0 +1,270 @@
package com.minimind.mind2.ui

import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Send
import androidx.compose.material.icons.filled.Stop
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.minimind.mind2.Mind2Model
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch

/**
 * Chat ViewModel for MiniMind
 */
class ChatViewModel : ViewModel() {
    private val model = Mind2Model.getInstance()

    var messages = mutableStateListOf<ChatMessage>()
        private set

    var isGenerating by mutableStateOf(false)
        private set

    var isLoading by mutableStateOf(false)
        private set

    var error by mutableStateOf<String?>(null)
        private set

    var modelInfo by mutableStateOf("")
        private set

    private var currentResponse = StringBuilder()

    fun loadModel(context: android.content.Context, modelName: String = "mind2-lite.gguf") {
        viewModelScope.launch {
            isLoading = true
            error = null

            model.load(context, modelName)
                .onSuccess {
                    modelInfo = model.getInfo()
                }
                .onFailure {
                    error = "Failed to load model: ${it.message}"
                }

            isLoading = false
        }
    }

    fun sendMessage(content: String) {
        if (content.isBlank() || isGenerating) return

        // Add user message
        messages.add(ChatMessage("user", content))

        // Add placeholder for assistant
        currentResponse.clear()
        messages.add(ChatMessage("assistant", ""))

        isGenerating = true
        error = null

        val history = messages.dropLast(1).map {
            Mind2Model.ChatMessage(it.role, it.content)
        }

        viewModelScope.launch {
            model.chatStream(content, history)
                .catch { e ->
                    error = "Generation error: ${e.message}"
                    isGenerating = false
                }
                .collect { token ->
                    currentResponse.append(token)
                    // Update last message
                    val lastIndex = messages.lastIndex
                    messages[lastIndex] = ChatMessage("assistant", currentResponse.toString())
                }

            isGenerating = false
        }
    }

    fun stopGeneration() {
        model.stop()
        isGenerating = false
    }

    fun clearChat() {
        messages.clear()
        currentResponse.clear()
    }

    override fun onCleared() {
        super.onCleared()
        model.release()
    }
}

data class ChatMessage(
    val role: String,
    val content: String
)

/**
 * Chat Screen Composable
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ChatScreen(
    viewModel: ChatViewModel
) {
    var inputText by remember { mutableStateOf("") }
    val listState = rememberLazyListState()

    // Auto-scroll to bottom when new messages arrive
    LaunchedEffect(viewModel.messages.size) {
        if (viewModel.messages.isNotEmpty()) {
            listState.animateScrollToItem(viewModel.messages.lastIndex)
        }
    }

    Scaffold(
        topBar = {
            TopAppBar(
                title = {
                    Column {
                        Text("MiniMind", fontWeight = FontWeight.Bold)
                        if (viewModel.isLoading) {
                            Text(
                                "Loading model...",
                                style = MaterialTheme.typography.bodySmall,
                                color = MaterialTheme.colorScheme.onSurfaceVariant
                            )
                        }
                    }
                },
                colors = TopAppBarDefaults.topAppBarColors(
                    containerColor = MaterialTheme.colorScheme.primaryContainer
                )
            )
        }
    ) { padding ->
        Column(
            modifier = Modifier
                .fillMaxSize()
                .padding(padding)
        ) {
            // Error banner
            viewModel.error?.let { errorMsg ->
                Surface(
                    color = MaterialTheme.colorScheme.errorContainer,
                    modifier = Modifier.fillMaxWidth()
                ) {
                    Text(
                        text = errorMsg,
                        color = MaterialTheme.colorScheme.onErrorContainer,
                        modifier = Modifier.padding(16.dp)
                    )
                }
            }

            // Messages list
            LazyColumn(
                state = listState,
                modifier = Modifier
                    .weight(1f)
                    .fillMaxWidth(),
                contentPadding = PaddingValues(16.dp),
                verticalArrangement = Arrangement.spacedBy(12.dp)
            ) {
                items(viewModel.messages) { message ->
                    MessageBubble(message)
                }
            }

            // Input area
            Surface(
                tonalElevation = 3.dp,
                modifier = Modifier.fillMaxWidth()
            ) {
                Row(
                    modifier = Modifier
                        .padding(16.dp)
                        .fillMaxWidth(),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    OutlinedTextField(
                        value = inputText,
                        onValueChange = { inputText = it },
                        modifier = Modifier.weight(1f),
                        placeholder = { Text("Type a message...") },
                        shape = RoundedCornerShape(24.dp),
                        enabled = !viewModel.isLoading && !viewModel.isGenerating
                    )

                    Spacer(modifier = Modifier.width(8.dp))

                    if (viewModel.isGenerating) {
                        FilledIconButton(
                            onClick = { viewModel.stopGeneration() },
                            colors = IconButtonDefaults.filledIconButtonColors(
                                containerColor = MaterialTheme.colorScheme.error
                            )
                        ) {
                            Icon(Icons.Default.Stop, contentDescription = "Stop")
                        }
                    } else {
                        FilledIconButton(
                            onClick = {
                                viewModel.sendMessage(inputText)
                                inputText = ""
                            },
                            enabled = inputText.isNotBlank() && !viewModel.isLoading
                        ) {
                            Icon(Icons.Default.Send, contentDescription = "Send")
                        }
                    }
                }
            }
        }
    }
}

@Composable
fun MessageBubble(message: ChatMessage) {
    val isUser = message.role == "user"

    Row(
        modifier = Modifier.fillMaxWidth(),
        horizontalArrangement = if (isUser) Arrangement.End else Arrangement.Start
    ) {
        Surface(
            shape = RoundedCornerShape(
                topStart = 16.dp,
                topEnd = 16.dp,
                bottomStart = if (isUser) 16.dp else 4.dp,
                bottomEnd = if (isUser) 4.dp else 16.dp
            ),
            color = if (isUser)
                MaterialTheme.colorScheme.primary
            else
                MaterialTheme.colorScheme.surfaceVariant,
            modifier = Modifier.widthIn(max = 300.dp)
        ) {
            Text(
                text = message.content.ifEmpty { "..." },
                modifier = Modifier.padding(12.dp),
                color = if (isUser)
                    MaterialTheme.colorScheme.onPrimary
                else
                    MaterialTheme.colorScheme.onSurfaceVariant
            )
        }
    }
}
android/app/Mind2Model.kt
ADDED
@@ -0,0 +1,256 @@
package com.minimind.mind2

import android.content.Context
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.*
import java.io.File

/**
 * MiniMind (Mind2) Model Interface
 * Kotlin wrapper for native llama.cpp inference
 */
class Mind2Model private constructor() {

    companion object {
        init {
            System.loadLibrary("mind2")
        }

        private var instance: Mind2Model? = null

        @JvmStatic
        fun getInstance(): Mind2Model {
            return instance ?: synchronized(this) {
                instance ?: Mind2Model().also { instance = it }
            }
        }
    }

    // Model state
    private var isLoaded = false
    private var modelPath: String? = null

    // Generation parameters
    data class GenerationConfig(
        val maxTokens: Int = 256,
        val temperature: Float = 0.7f,
        val topP: Float = 0.9f,
        val topK: Int = 40,
        val repeatPenalty: Float = 1.1f,
        val stopTokens: List<String> = listOf("<|endoftext|>", "<|im_end|>")
    )

    /**
     * Load model from assets or file path
     */
    suspend fun load(
        context: Context,
        modelName: String = "mind2-lite.gguf",
        contextLength: Int = 2048,
        threads: Int = 0 // 0 = auto
    ): Result<Unit> = withContext(Dispatchers.IO) {
        try {
            // Check if the model is bundled in assets
            val assetPath = "models/$modelName"
            val modelFile = File(context.filesDir, modelName)

            if (!modelFile.exists()) {
                // Copy from assets to internal storage
                context.assets.open(assetPath).use { input ->
                    modelFile.outputStream().use { output ->
                        input.copyTo(output)
                    }
                }
            }

            modelPath = modelFile.absolutePath
            val success = nativeInit(modelPath!!, contextLength, threads)

            if (success) {
                isLoaded = true
                Result.success(Unit)
            } else {
                Result.failure(RuntimeException("Failed to load model"))
            }
        } catch (e: Exception) {
            Result.failure(e)
        }
    }

    /**
     * Generate text (non-streaming)
     */
    suspend fun generate(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Result<String> = withContext(Dispatchers.IO) {
        if (!isLoaded) {
            return@withContext Result.failure(IllegalStateException("Model not loaded"))
        }

        try {
            val result = nativeGenerate(
                prompt,
                config.maxTokens,
                config.temperature,
                config.topP,
                config.topK
            )
            Result.success(result)
        } catch (e: Exception) {
            Result.failure(e)
        }
    }

    /**
     * Generate text with streaming
     */
    fun generateStream(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> = callbackFlow {
        if (!isLoaded) {
            throw IllegalStateException("Model not loaded")
        }

        val callback = object : TokenCallback {
            override fun onToken(token: String) {
                trySend(token)
            }

            override fun onComplete() {
                channel.close()
            }
        }

        nativeGenerateStream(
            prompt,
            config.maxTokens,
            config.temperature,
            config.topP,
            config.topK,
            callback
        )

        awaitClose { stop() }
    }.flowOn(Dispatchers.IO)

    /**
     * Chat with conversation history
     */
    suspend fun chat(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Result<String> {
        val prompt = buildChatPrompt(message, history)
        return generate(prompt, config)
    }

    /**
     * Chat with streaming
     */
    fun chatStream(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> {
        val prompt = buildChatPrompt(message, history)
        return generateStream(prompt, config)
    }

    private fun buildChatPrompt(message: String, history: List<ChatMessage>): String {
        val sb = StringBuilder()

        // System prompt
        sb.append("<|im_start|>system\n")
        sb.append("You are Mind2, a helpful AI assistant running locally on this device.\n")
        sb.append("<|im_end|>\n")

        // History
        for (msg in history) {
            sb.append("<|im_start|>${msg.role}\n")
            sb.append("${msg.content}\n")
            sb.append("<|im_end|>\n")
        }

        // Current message
        sb.append("<|im_start|>user\n")
        sb.append("$message\n")
        sb.append("<|im_end|>\n")
        sb.append("<|im_start|>assistant\n")

        return sb.toString()
    }

    /**
     * Stop ongoing generation
     */
    fun stop() {
        nativeStop()
    }

    /**
     * Release resources
     */
    fun release() {
        nativeRelease()
        isLoaded = false
        modelPath = null
    }

    /**
     * Get model info
     */
    fun getInfo(): String = nativeGetInfo()

    /**
     * Benchmark inference speed
     */
    suspend fun benchmark(tokens: Int = 100): Float = withContext(Dispatchers.IO) {
        nativeBenchmark(tokens)
    }

    // Native methods
    private external fun nativeInit(modelPath: String, nCtx: Int, nThreads: Int): Boolean
    private external fun nativeGenerate(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int
    ): String
    private external fun nativeGenerateStream(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int,
        callback: TokenCallback
    )
    private external fun nativeStop()
    private external fun nativeRelease()
    private external fun nativeGetInfo(): String
    private external fun nativeBenchmark(nTokens: Int): Float

    interface TokenCallback {
        fun onToken(token: String)
        fun onComplete()
    }

    data class ChatMessage(
        val role: String, // "user" or "assistant"
        val content: String
    )
}

/**
 * Extension function for easy initialization
 */
suspend fun Context.loadMind2Model(
    modelName: String = "mind2-lite.gguf",
    contextLength: Int = 2048
): Result<Mind2Model> {
    val model = Mind2Model.getInstance()
    return model.load(this, modelName, contextLength).map { model }
}
android/app/build.gradle
ADDED
@@ -0,0 +1,103 @@
plugins {
    id 'com.android.application'
    id 'org.jetbrains.kotlin.android'
}

android {
    namespace 'com.minimind.mind2'
    compileSdk 34

    defaultConfig {
        applicationId "com.minimind.mind2"
        minSdk 26
        targetSdk 34
        versionCode 1
        versionName "1.0.0"

        ndk {
            abiFilters 'arm64-v8a', 'armeabi-v7a'
        }

        externalNativeBuild {
            cmake {
                cppFlags "-std=c++17 -O3 -ffast-math"
                arguments "-DANDROID_ARM_NEON=TRUE"
            }
        }
    }

    buildTypes {
        release {
            minifyEnabled true
            shrinkResources true
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
        debug {
            debuggable true
        }
    }

    externalNativeBuild {
        cmake {
            path file('../jni/CMakeLists.txt')
            version '3.22.1'
        }
    }

    compileOptions {
        sourceCompatibility JavaVersion.VERSION_17
        targetCompatibility JavaVersion.VERSION_17
    }

    kotlinOptions {
        jvmTarget = '17'
    }

    buildFeatures {
        viewBinding true
        compose true
    }

    composeOptions {
        kotlinCompilerExtensionVersion '1.5.3'
    }

    packagingOptions {
        jniLibs {
            useLegacyPackaging true
        }
    }

    // Asset compression settings: keep model files uncompressed in the APK
    aaptOptions {
        noCompress 'gguf', 'onnx', 'bin'
    }
}

dependencies {
    // Core Android
    implementation 'androidx.core:core-ktx:1.12.0'
    implementation 'androidx.appcompat:appcompat:1.6.1'
    implementation 'com.google.android.material:material:1.11.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.1.4'

    // Jetpack Compose
    implementation platform('androidx.compose:compose-bom:2024.01.00')
    implementation 'androidx.compose.ui:ui'
    implementation 'androidx.compose.ui:ui-graphics'
    implementation 'androidx.compose.ui:ui-tooling-preview'
    implementation 'androidx.compose.material3:material3'
    implementation 'androidx.activity:activity-compose:1.8.2'
    implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.7.0'

    // ONNX Runtime (optional - for ONNX deployment)
    implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.3'

    // Coroutines
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3'

    // Testing
    testImplementation 'junit:junit:4.13.2'
    androidTestImplementation 'androidx.test.ext:junit:1.1.5'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
}
android/jni/CMakeLists.txt
ADDED
@@ -0,0 +1,53 @@
cmake_minimum_required(VERSION 3.22.1)
# Enable both C and C++: the bundled ggml sources below are plain C
project(mind2_android VERSION 1.0.0 LANGUAGES C CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Optimization flags
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fno-finite-math-only")

# ARM NEON optimizations
if(ANDROID_ABI STREQUAL "arm64-v8a")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8-a+fp+simd")
elseif(ANDROID_ABI STREQUAL "armeabi-v7a")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon -mfloat-abi=softfp")
endif()

# Include directories
include_directories(
    ${CMAKE_SOURCE_DIR}/include
    ${CMAKE_SOURCE_DIR}/llama.cpp
)

# llama.cpp source files (subset needed for inference)
set(LLAMA_SOURCES
    llama.cpp/ggml.c
    llama.cpp/ggml-alloc.c
    llama.cpp/ggml-backend.c
    llama.cpp/ggml-quants.c
    llama.cpp/llama.cpp
)

# Mind2 JNI bridge
set(MIND2_SOURCES
    mind2_jni.cpp
)

# Build shared library
add_library(mind2_jni SHARED
    ${LLAMA_SOURCES}
    ${MIND2_SOURCES}
)

# Link libraries
target_link_libraries(mind2_jni
    android
    log
)

# Set output name (loaded from Kotlin via System.loadLibrary("mind2"))
set_target_properties(mind2_jni PROPERTIES
    OUTPUT_NAME "mind2"
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/../app/src/main/jniLibs/${ANDROID_ABI}"
)
android/jni/mind2_jni.cpp
ADDED
@@ -0,0 +1,301 @@
/**
 * MiniMind (Mind2) JNI Bridge
 * Provides Java/Kotlin interface to llama.cpp inference engine
 */

#include <jni.h>
#include <android/log.h>
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>

#include <string>
#include <vector>
#include <memory>
#include <thread>
#include <atomic>
#include <mutex>
#include <chrono>   // std::chrono::milliseconds for the simulated streaming delay
#include <cstdio>   // snprintf
#include <cstdlib>  // rand

// If using llama.cpp, include these headers
// #include "llama.h"
// #include "ggml.h"

#define LOG_TAG "Mind2"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)

namespace {

// Model context (placeholder - would use llama_context in real implementation)
struct Mind2Context {
    std::string model_path;
    int n_ctx = 2048;
    int n_threads = 4;
    bool loaded = false;
    std::atomic<bool> generating{false};
    std::mutex mutex;

    // llama_model* model = nullptr;
    // llama_context* ctx = nullptr;
};

std::unique_ptr<Mind2Context> g_context;

// Token callback for streaming
JavaVM* g_jvm = nullptr;
jobject g_callback = nullptr;
jmethodID g_callback_method = nullptr;

void stream_token(const std::string& token) {
    if (!g_jvm || !g_callback) return;

    JNIEnv* env = nullptr;
    bool attached = false;

    if (g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6) != JNI_OK) {
        g_jvm->AttachCurrentThread(&env, nullptr);
        attached = true;
    }

    if (env && g_callback && g_callback_method) {
        jstring jtoken = env->NewStringUTF(token.c_str());
        env->CallVoidMethod(g_callback, g_callback_method, jtoken);
        env->DeleteLocalRef(jtoken);
    }

    if (attached) {
        g_jvm->DetachCurrentThread();
    }
}

} // anonymous namespace

extern "C" {

JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
    g_jvm = vm;
    LOGI("Mind2 JNI loaded");
    return JNI_VERSION_1_6;
}

JNIEXPORT void JNICALL JNI_OnUnload(JavaVM* vm, void* reserved) {
    g_context.reset();
    g_jvm = nullptr;
    LOGI("Mind2 JNI unloaded");
}

/**
 * Initialize the model
 */
JNIEXPORT jboolean JNICALL
Java_com_minimind_mind2_Mind2Model_nativeInit(
    JNIEnv* env,
    jobject thiz,
    jstring model_path,
    jint n_ctx,
    jint n_threads
) {
    const char* path = env->GetStringUTFChars(model_path, nullptr);
    LOGI("Initializing Mind2 with model: %s", path);

    g_context = std::make_unique<Mind2Context>();
    g_context->model_path = path;
    g_context->n_ctx = n_ctx;
    g_context->n_threads = n_threads > 0 ? n_threads : std::thread::hardware_concurrency();

    env->ReleaseStringUTFChars(model_path, path);

    // TODO: Actual llama.cpp initialization
    // llama_model_params model_params = llama_model_default_params();
    // g_context->model = llama_load_model_from_file(g_context->model_path.c_str(), model_params);
    // if (!g_context->model) {
    //     LOGE("Failed to load model");
    //     return JNI_FALSE;
    // }
    //
    // llama_context_params ctx_params = llama_context_default_params();
    // ctx_params.n_ctx = g_context->n_ctx;
    // ctx_params.n_threads = g_context->n_threads;
    // g_context->ctx = llama_new_context_with_model(g_context->model, ctx_params);

    g_context->loaded = true;
    LOGI("Mind2 initialized successfully (threads: %d, ctx: %d)",
         g_context->n_threads, g_context->n_ctx);

    return JNI_TRUE;
}

/**
 * Generate text from prompt
 */
JNIEXPORT jstring JNICALL
Java_com_minimind_mind2_Mind2Model_nativeGenerate(
    JNIEnv* env,
    jobject thiz,
    jstring prompt,
    jint max_tokens,
    jfloat temperature,
    jfloat top_p,
    jint top_k
) {
    if (!g_context || !g_context->loaded) {
        LOGE("Model not initialized");
        return env->NewStringUTF("");
    }

    std::lock_guard<std::mutex> lock(g_context->mutex);

    const char* prompt_str = env->GetStringUTFChars(prompt, nullptr);
    std::string result;

    LOGI("Generating with prompt: %.50s...", prompt_str);

    // TODO: Actual generation with llama.cpp
    // This is a placeholder that returns the prompt
    result = std::string(prompt_str) + "\n\n[Generated response would appear here]";

    // Actual implementation would be:
    // std::vector<llama_token> tokens = llama_tokenize(g_context->ctx, prompt_str, true);
    // for (int i = 0; i < max_tokens; i++) {
    //     llama_token new_token = llama_sample_token(g_context->ctx, ...);
    //     if (new_token == llama_token_eos(g_context->ctx)) break;
    //     result += llama_token_to_piece(g_context->ctx, new_token);
    //     stream_token(llama_token_to_piece(g_context->ctx, new_token));
    // }

    env->ReleaseStringUTFChars(prompt, prompt_str);

    return env->NewStringUTF(result.c_str());
}

/**
 * Generate with streaming callback
 */
JNIEXPORT void JNICALL
Java_com_minimind_mind2_Mind2Model_nativeGenerateStream(
    JNIEnv* env,
    jobject thiz,
    jstring prompt,
    jint max_tokens,
    jfloat temperature,
    jfloat top_p,
    jint top_k,
    jobject callback
) {
    if (!g_context || !g_context->loaded) {
        LOGE("Model not initialized");
        return;
    }

    // Store callback reference
    g_callback = env->NewGlobalRef(callback);
    jclass callback_class = env->GetObjectClass(callback);
    g_callback_method = env->GetMethodID(callback_class, "onToken", "(Ljava/lang/String;)V");

    const char* prompt_str = env->GetStringUTFChars(prompt, nullptr);

    g_context->generating = true;

    // TODO: Actual streaming generation
    // Simulated streaming for now
    std::vector<std::string> demo_tokens = {
        "Hello", "!", " ", "I", "'m", " ", "Mind2", ",",
        " ", "a", " ", "lightweight", " ", "AI", " ", "assistant", "."
    };

    for (const auto& token : demo_tokens) {
        if (!g_context->generating) break;
        stream_token(token);
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }

    // Signal completion
    jmethodID complete_method = env->GetMethodID(callback_class, "onComplete", "()V");
    if (complete_method) {
        env->CallVoidMethod(callback, complete_method);
    }

    env->ReleaseStringUTFChars(prompt, prompt_str);
    env->DeleteGlobalRef(g_callback);
    g_callback = nullptr;
}

/**
 * Stop ongoing generation
 */
JNIEXPORT void JNICALL
Java_com_minimind_mind2_Mind2Model_nativeStop(
    JNIEnv* env,
    jobject thiz
) {
    if (g_context) {
        g_context->generating = false;
        LOGI("Generation stopped");
    }
}

/**
 * Release model resources
 */
JNIEXPORT void JNICALL
Java_com_minimind_mind2_Mind2Model_nativeRelease(
    JNIEnv* env,
    jobject thiz
) {
    if (g_context) {
        std::lock_guard<std::mutex> lock(g_context->mutex);

        // TODO: Release llama.cpp resources
        // if (g_context->ctx) llama_free(g_context->ctx);
        // if (g_context->model) llama_free_model(g_context->model);

        g_context->loaded = false;
        LOGI("Mind2 resources released");
    }
}

/**
 * Get model info
 */
JNIEXPORT jstring JNICALL
Java_com_minimind_mind2_Mind2Model_nativeGetInfo(
    JNIEnv* env,
    jobject thiz
) {
    if (!g_context) {
        return env->NewStringUTF("{}");
    }

    char info[512];
    snprintf(info, sizeof(info),
        "{\"loaded\": %s, \"model\": \"%s\", \"n_ctx\": %d, \"n_threads\": %d}",
        g_context->loaded ? "true" : "false",
        g_context->model_path.c_str(),
        g_context->n_ctx,
        g_context->n_threads
    );

    return env->NewStringUTF(info);
}

/**
 * Benchmark inference speed
 */
JNIEXPORT jfloat JNICALL
Java_com_minimind_mind2_Mind2Model_nativeBenchmark(
    JNIEnv* env,
    jobject thiz,
    jint n_tokens
) {
    if (!g_context || !g_context->loaded) {
        return 0.0f;
    }

    // TODO: Actual benchmark
    // Simulated result
    float tokens_per_second = 25.0f + (rand() % 10);

    LOGI("Benchmark: %.1f tokens/sec", tokens_per_second);
    return tokens_per_second;
}

} // extern "C"

config.json
ADDED
|
@@ -0,0 +1,59 @@
{
  "architectures": ["Max2ForCausalLM"],
  "model_type": "max2",
  "auto_map": {
    "AutoConfig": "configs.model_config--Max2Config",
    "AutoModelForCausalLM": "model.mind2_model--Max2ForCausalLM"
  },
  "hidden_size": 1536,
  "intermediate_size": 4096,
  "num_hidden_layers": 24,
  "num_attention_heads": 12,
  "num_key_value_heads": 3,
  "vocab_size": 32000,
  "max_position_embeddings": 8192,
  "rope_theta": 10000.0,
  "use_moe": true,
  "num_experts": 8,
  "num_experts_per_tok": 2,
  "expert_hidden_size": 1024,
  "router_aux_loss_coef": 0.01,
  "rms_norm_eps": 1e-6,
  "hidden_act": "silu",
  "hidden_dropout": 0.0,
  "attention_dropout": 0.0,
  "pad_token_id": 0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "use_cache": true,
  "use_flash_attention": true,
  "torch_dtype": "float16",
  "transformers_version": "4.40.0",
  "model_variants": {
    "max2-nano": {
      "hidden_size": 768,
      "num_hidden_layers": 12,
      "num_experts": 4,
      "num_experts_per_tok": 1,
      "total_params": "500M",
      "active_params": "125M"
    },
    "max2-lite": {
      "hidden_size": 1536,
      "num_hidden_layers": 24,
      "num_experts": 8,
      "num_experts_per_tok": 2,
      "total_params": "1.5B",
      "active_params": "375M"
    },
    "max2-pro": {
      "hidden_size": 2560,
      "num_hidden_layers": 32,
      "num_experts": 8,
      "num_experts_per_tok": 2,
      "total_params": "3B",
      "active_params": "750M"
    }
  }
}

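As a quick sanity check, a minimal sketch of reading this config.json back into the Python config class defined in configs/model_config.py. It assumes the script is run from the repository root; Max2Config.from_dict silently drops keys that are not dataclass fields (such as "architectures", "auto_map", and "model_variants"), so the hypothetical variable names below are only illustrative.

# Sketch: load config.json into a Max2Config (run from the repo root).
import json
from configs.model_config import Max2Config

with open("config.json") as f:
    raw = json.load(f)

cfg = Max2Config.from_dict(raw)            # unknown keys are filtered out
print(cfg.hidden_size, cfg.num_experts)    # 1536 8
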
configs/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
| 1 |
+
"""MiniMind Max2 Configuration Module"""
|
| 2 |
+
from .model_config import Max2Config, get_config, estimate_params, MAX2_CONFIGS
|
| 3 |
+
|
| 4 |
+
# Backward compatibility
|
| 5 |
+
Mind2Config = Max2Config
|
| 6 |
+
MIND2_CONFIGS = MAX2_CONFIGS
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"Max2Config",
|
| 10 |
+
"Mind2Config",
|
| 11 |
+
"get_config",
|
| 12 |
+
"estimate_params",
|
| 13 |
+
"MAX2_CONFIGS",
|
| 14 |
+
"MIND2_CONFIGS",
|
| 15 |
+
]
|
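A small check of the backward-compatibility re-exports above; a sketch that assumes the repository root is on PYTHONPATH. The Mind2* names are plain aliases, so both spellings refer to the same objects.

# Sketch: old Mind2 imports keep working because they alias the Max2 names.
from configs import Max2Config, Mind2Config, MAX2_CONFIGS, MIND2_CONFIGS

assert Mind2Config is Max2Config
assert MIND2_CONFIGS is MAX2_CONFIGS
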
configs/model_config.py
ADDED
|
@@ -0,0 +1,154 @@
"""
MiniMind Max2 Model Configuration
Inspired by MiniMax M2's efficient activated parameters design
"""

from dataclasses import dataclass
from typing import Optional, Dict, Any


@dataclass
class Max2Config:
    """Configuration for MiniMind Max2 models."""

    # Model identification
    model_name: str = "max2-lite"
    model_version: str = "1.0.0"

    # Architecture dimensions
    hidden_size: int = 1536
    intermediate_size: int = 4096
    num_hidden_layers: int = 24
    num_attention_heads: int = 12
    num_key_value_heads: int = 3  # GQA ratio 4:1

    # Vocabulary and embeddings
    vocab_size: int = 32000
    max_position_embeddings: int = 8192
    rope_theta: float = 10000.0

    # MoE (Mixture of Experts) configuration
    use_moe: bool = True
    num_experts: int = 8
    num_experts_per_tok: int = 2  # Only 25% activation
    expert_hidden_size: int = 1024
    router_aux_loss_coef: float = 0.01

    # Normalization and activation
    rms_norm_eps: float = 1e-6
    hidden_act: str = "silu"

    # Regularization
    hidden_dropout: float = 0.0
    attention_dropout: float = 0.0

    # Special tokens
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2

    # Initialization
    initializer_range: float = 0.02

    # Memory optimization
    use_cache: bool = True
    use_flash_attention: bool = True
    gradient_checkpointing: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return {k: v for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "Max2Config":
        return cls(**{k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__})


# Predefined model configurations
MAX2_CONFIGS = {
    "max2-nano": Max2Config(
        model_name="max2-nano",
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=3,
        num_experts=4,
        num_experts_per_tok=1,
        expert_hidden_size=512,
        max_position_embeddings=4096,
    ),
    "max2-lite": Max2Config(
        model_name="max2-lite",
        hidden_size=1536,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=12,
        num_key_value_heads=3,
        num_experts=8,
        num_experts_per_tok=2,
        expert_hidden_size=1024,
        max_position_embeddings=8192,
    ),
    "max2-pro": Max2Config(
        model_name="max2-pro",
        hidden_size=2560,
        intermediate_size=6912,
        num_hidden_layers=32,
        num_attention_heads=20,
        num_key_value_heads=4,
        num_experts=8,
        num_experts_per_tok=2,
        expert_hidden_size=1728,
        max_position_embeddings=16384,
    ),
}

# Aliases for backward compatibility
Mind2Config = Max2Config
MIND2_CONFIGS = MAX2_CONFIGS


def get_config(model_name: str) -> Max2Config:
    """Get predefined configuration by name."""
    if model_name not in MAX2_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MAX2_CONFIGS.keys())}")
    return MAX2_CONFIGS[model_name]


def estimate_params(config: Max2Config) -> dict:
    """Estimate parameter counts for a configuration."""
    embed_params = config.vocab_size * config.hidden_size
    head_dim = config.hidden_size // config.num_attention_heads

    # Attention parameters per layer (GQA)
    q_params = config.hidden_size * config.hidden_size
    kv_params = 2 * config.hidden_size * (config.num_key_value_heads * head_dim)
    o_params = config.hidden_size * config.hidden_size
    attn_params_per_layer = q_params + kv_params + o_params

    # MoE FFN parameters per layer
    if config.use_moe:
        router_params = config.hidden_size * config.num_experts
        expert_params = 3 * config.hidden_size * config.expert_hidden_size
        ffn_params_per_layer = router_params + (config.num_experts * expert_params)
        active_ffn_params = router_params + (config.num_experts_per_tok * expert_params)
    else:
        ffn_params_per_layer = 3 * config.hidden_size * config.intermediate_size
        active_ffn_params = ffn_params_per_layer

    norm_params_per_layer = 2 * config.hidden_size
    layer_params = attn_params_per_layer + ffn_params_per_layer + norm_params_per_layer
    active_layer_params = attn_params_per_layer + active_ffn_params + norm_params_per_layer

    total_params = embed_params + (config.num_hidden_layers * layer_params) + embed_params
    active_params = embed_params + (config.num_hidden_layers * active_layer_params) + embed_params

    return {
        "total_params": total_params,
        "active_params": active_params,
        "activation_ratio": active_params / total_params,
        "total_params_b": total_params / 1e9,
        "active_params_b": active_params / 1e9,
        "estimated_size_fp16_gb": (total_params * 2) / (1024**3),
        "estimated_size_int4_gb": (total_params * 0.5) / (1024**3),
    }

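For reference, a short sketch that tabulates the three predefined variants using get_config/estimate_params above. The numbers it prints are analytical estimates from estimate_params (which counts the embedding and untied LM head separately), not measured checkpoint sizes; the script assumes the repo root is on PYTHONPATH.

# Sketch: compare the predefined variants with the helpers defined above.
from configs.model_config import MAX2_CONFIGS, estimate_params

for name, cfg in MAX2_CONFIGS.items():
    p = estimate_params(cfg)
    print(f"{name:10s} total={p['total_params_b']:.2f}B "
          f"active={p['active_params_b']:.2f}B "
          f"ratio={p['activation_ratio']:.0%} "
          f"int4~{p['estimated_size_int4_gb']:.2f}GB")
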
examples/quickstart.py
ADDED
|
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""
MiniMind Max2 Quick Start Example
Demonstrates basic usage of the Max2 model.
"""

import sys
from pathlib import Path

# Add parent directory
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch


def main():
    print("=" * 60)
    print("MiniMind Max2 Quick Start")
    print("=" * 60)

    # Import model components
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    # Select model variant
    model_name = "max2-nano"  # Options: max2-nano, max2-lite, max2-pro
    print(f"\n1. Creating {model_name} model...")

    config = get_config(model_name)
    model = Max2ForCausalLM(config)

    # Show model info
    params = estimate_params(config)
    print(f"   Total parameters: {params['total_params_b']:.3f}B")
    print(f"   Active parameters: {params['active_params_b']:.3f}B")
    print(f"   Activation ratio: {params['activation_ratio']:.1%}")
    print(f"   Estimated size (INT4): {params['estimated_size_int4_gb']:.2f}GB")

    # Move to device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = model.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    # Test forward pass
    print("\n3. Testing forward pass...")
    batch_size, seq_len = 2, 64
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)

    model.eval()
    with torch.no_grad():
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)

    print(f"   Input shape: {input_ids.shape}")
    print(f"   Output logits shape: {logits.shape}")
    print(f"   Loss: {loss:.4f}")
    print(f"   MoE auxiliary loss: {aux_loss:.6f}")

    # Test generation
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)

    with torch.no_grad():
        generated = model.generate(
            prompt,
            max_new_tokens=20,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            do_sample=True,
        )

    print(f"   Prompt length: {prompt.shape[1]}")
    print(f"   Generated length: {generated.shape[1]}")
    print(f"   New tokens: {generated.shape[1] - prompt.shape[1]}")

    # Memory usage
    if device == "cuda":
        memory_used = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {memory_used:.2f}GB")

    print("\n" + "=" * 60)
    print("Quick start complete!")
    print("=" * 60)

    # Usage hints
    print("\nNext steps:")
    print("  - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print("  - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print("  - See README.md for full documentation")


if __name__ == "__main__":
    main()

model/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
"""
MiniMind Max2 Model Package
A lightweight, efficient language model designed for edge deployment.
"""

from .mind2_model import (
    Max2ForCausalLM,
    Max2Model,
    Mind2ForCausalLM,
    Mind2Model,
    create_model
)
from .components import (
    Max2Attention,
    Max2MoE,
    Max2DecoderLayer,
    Max2RMSNorm,
    Max2RotaryEmbedding,
    Max2MLP,
    Max2Expert,
    # Backward compatibility
    Mind2Attention,
    Mind2MoE,
    Mind2DecoderLayer,
    Mind2RMSNorm,
    Mind2RotaryEmbedding,
)

__all__ = [
    # Max2 (primary)
    "Max2ForCausalLM",
    "Max2Model",
    "Max2Attention",
    "Max2MoE",
    "Max2DecoderLayer",
    "Max2RMSNorm",
    "Max2RotaryEmbedding",
    "Max2MLP",
    "Max2Expert",
    # Mind2 (backward compatibility)
    "Mind2ForCausalLM",
    "Mind2Model",
    "Mind2Attention",
    "Mind2MoE",
    "Mind2DecoderLayer",
    "Mind2RMSNorm",
    "Mind2RotaryEmbedding",
    # Factory
    "create_model",
]

__version__ = "1.0.0"

model/components.py
ADDED
|
@@ -0,0 +1,274 @@
"""
MiniMind Max2 Model Components
Core building blocks: RMSNorm, RoPE, GQA Attention, MoE
"""

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Max2Config


class Max2RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (faster than LayerNorm)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


class Max2RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for efficient position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors."""
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Max2Attention(nn.Module):
    """Grouped Query Attention (GQA) - fewer KV heads than Q heads for memory efficiency."""

    def __init__(self, config: Max2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = Max2RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
        self.attention_dropout = config.attention_dropout

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        if n_rep == 1:
            return hidden_states
        bs, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(bs, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(bs, num_kv_heads * n_rep, seq_len, head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        batch_size, seq_len, _ = hidden_states.shape

        query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE at the correct absolute positions: offset by the cached length
        # so that incremental decoding with a KV cache does not reuse position 0.
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        cos, sin = self.rotary_emb(value_states, past_len + seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos[past_len:], sin[past_len:])

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = self._repeat_kv(key_states, self.num_key_value_groups)
        value_states = self._repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value


class Max2MLP(nn.Module):
    """SwiGLU Feed-Forward Network."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Max2Expert(nn.Module):
    """Single expert in the Mixture of Experts layer."""

    def __init__(self, hidden_size: int, expert_hidden_size: int):
        super().__init__()
        self.mlp = Max2MLP(hidden_size, expert_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class Max2MoE(nn.Module):
    """
    Mixture of Experts (MoE) layer.
    Efficient parameter activation - only top-k experts are used per token.
    Inspired by MiniMax M2's efficient activated parameters design.
    """

    def __init__(self, config: Max2Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.expert_hidden_size = config.expert_hidden_size

        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList([
            Max2Expert(self.hidden_size, self.expert_hidden_size)
            for _ in range(self.num_experts)
        ])
        self.router_aux_loss_coef = config.router_aux_loss_coef

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(hidden_states_flat)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        router_weights, selected_experts = torch.topk(router_probs, self.num_experts_per_tok, dim=-1)
        router_weights = router_weights.to(hidden_states.dtype)
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = torch.zeros_like(hidden_states_flat)
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert = self.experts[expert_idx]
            for top_k_idx in range(self.num_experts_per_tok):
                token_indices = expert_mask[expert_idx, top_k_idx].nonzero(as_tuple=True)[0]
                if token_indices.numel() > 0:
                    expert_input = hidden_states_flat[token_indices]
                    expert_output = expert(expert_input)
                    weights = router_weights[token_indices, top_k_idx].unsqueeze(-1)
                    final_hidden_states[token_indices] += weights * expert_output

        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)

        num_tokens = router_probs.shape[0]
        expert_mask_float = F.one_hot(selected_experts, num_classes=self.num_experts).float()
        tokens_per_expert = expert_mask_float.sum(dim=(0, 1)) / num_tokens
        router_prob_per_expert = router_probs.mean(dim=0)
        aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum() * self.router_aux_loss_coef

        return final_hidden_states, aux_loss


class Max2DecoderLayer(nn.Module):
    """Single transformer decoder layer with GQA attention and MoE FFN."""

    def __init__(self, config: Max2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Max2Attention(config, layer_idx)

        if config.use_moe:
            self.mlp = Max2MoE(config)
            self.use_moe = True
        else:
            self.mlp = Max2MLP(config.hidden_size, config.intermediate_size)
            self.use_moe = False

        self.input_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, present_key_value = self.self_attn(hidden_states, attention_mask, past_key_value, use_cache)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        if self.use_moe:
            hidden_states, aux_loss = self.mlp(hidden_states)
        else:
            hidden_states = self.mlp(hidden_states)
            aux_loss = torch.tensor(0.0, device=hidden_states.device)

        hidden_states = residual + hidden_states

        return hidden_states, present_key_value, aux_loss


# Backward compatibility aliases
Mind2RMSNorm = Max2RMSNorm
Mind2RotaryEmbedding = Max2RotaryEmbedding
Mind2Attention = Max2Attention
Mind2MLP = Max2MLP
Mind2Expert = Max2Expert
Mind2MoE = Max2MoE
Mind2DecoderLayer = Max2DecoderLayer

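To make the top-k routing concrete, a minimal sketch that runs the MoE block on random activations; it assumes the repository root is on PYTHONPATH and uses the max2-nano config (4 experts, top-1 routing). The output keeps the input's shape, and the second return value is the load-balancing auxiliary loss scaled by router_aux_loss_coef.

# Sketch: exercise Max2MoE in isolation to see shapes and the aux loss.
import torch
from configs.model_config import get_config
from model.components import Max2MoE

cfg = get_config("max2-nano")
moe = Max2MoE(cfg).eval()

x = torch.randn(2, 16, cfg.hidden_size)
with torch.no_grad():
    y, aux_loss = moe(x)

print(y.shape)          # torch.Size([2, 16, 768]) - same shape as the input
print(float(aux_loss))  # scalar load-balancing penalty
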
model/mind2_model.py
ADDED
|
@@ -0,0 +1,185 @@
"""
MiniMind Max2 Main Model
Complete implementation of the Max2 language model.
"""

from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Max2Config, get_config
from .components import Max2DecoderLayer, Max2RMSNorm


class Max2Model(nn.Module):
    """Max2 Transformer Model - outputs raw hidden states."""

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList([Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]:
        batch_size, seq_len = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)

        causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:
            # Build the additive padding mask with masked_fill; a multiplicative
            # (1 - mask) * -inf formulation would produce 0 * -inf = NaN.
            padding_mask = torch.zeros_like(attention_mask, dtype=hidden_states.dtype)[:, None, None, :]
            padding_mask = padding_mask.masked_fill(attention_mask[:, None, None, :] == 0, float("-inf"))
            causal_mask = causal_mask + padding_mask

        next_cache = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device)

        for idx, layer in enumerate(self.layers):
            past_kv = past_key_values[idx] if past_key_values else None
            hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache)

            if use_cache:
                next_cache.append(present_kv)
            total_aux_loss = total_aux_loss + aux_loss

        hidden_states = self.norm(hidden_states)
        return hidden_states, next_cache, total_aux_loss


class Max2ForCausalLM(nn.Module):
    """Max2 Model with Language Modeling head for text generation."""

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.model = Max2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]:
        hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache)
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
            loss = loss + aux_loss

        return loss, logits, next_cache, aux_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
    ) -> torch.LongTensor:
        """Simple generation with top-k/top-p sampling."""
        generated = input_ids
        past_key_values = None

        for _ in range(max_new_tokens):
            if past_key_values is None:
                _, logits, past_key_values, _ = self(generated, use_cache=True)
            else:
                _, logits, past_key_values, _ = self(generated[:, -1:], past_key_values=past_key_values, use_cache=True)

            next_token_logits = logits[:, -1, :] / temperature

            if do_sample:
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            if (next_token == self.config.eos_token_id).all():
                break

        return generated


# Backward compatibility aliases
Mind2Model = Max2Model
Mind2ForCausalLM = Max2ForCausalLM


def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM:
    """Factory function to create a Max2 model."""
    config = get_config(model_name)
    model = Max2ForCausalLM(config)
    return model.to(device=device, dtype=dtype) if torch.cuda.is_available() else model


if __name__ == "__main__":
    for model_name in ["max2-nano", "max2-lite", "max2-pro"]:
        print(f"\n{'='*50}\nTesting {model_name}\n{'='*50}")
        config = get_config(model_name)
        model = Max2ForCausalLM(config)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total Parameters: {total_params / 1e9:.3f}B")

        input_ids = torch.randint(0, config.vocab_size, (2, 128))
        model.eval()
        with torch.no_grad():
            loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}")
        print("Forward pass successful!")

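A minimal decoding sketch for the KV-cache path in generate(): greedy decoding on CPU with random weights, so the token ids are meaningless and this only checks shapes and the cached forward. It assumes the repository root is on sys.path.

# Sketch: greedy decoding with the KV cache on CPU (random weights).
import torch
from model.mind2_model import create_model

model = create_model("max2-nano", device="cpu", dtype=torch.float32).eval()
prompt = torch.randint(0, model.config.vocab_size, (1, 8))

with torch.no_grad():
    out = model.generate(prompt, max_new_tokens=16, do_sample=False)

print(prompt.shape[1], "->", out.shape[1])  # 8 -> up to 24 tokens
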
optimization/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
"""MiniMind Optimization Package"""
from .quantization import Mind2Quantizer, quantize_model
from .pruning import Mind2Pruner, prune_model
from .export import export_to_onnx, export_to_gguf

__all__ = [
    "Mind2Quantizer", "quantize_model",
    "Mind2Pruner", "prune_model",
    "export_to_onnx", "export_to_gguf",
]

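A hedged usage sketch for the export helper re-exported here (export_to_onnx and ExportConfig are defined in optimization/export.py below). The output path is only an example, and whether the MoE routing loop exports cleanly depends on the ONNX tracer, so treat this as illustrative rather than a guaranteed recipe.

# Sketch: ONNX export of a small variant on CPU.
import torch
from model import create_model
from optimization import export_to_onnx
from optimization.export import ExportConfig

model = create_model("max2-nano", device="cpu", dtype=torch.float32)
export_to_onnx(
    model,
    "outputs/max2-nano.onnx",                                   # example path
    config=ExportConfig(max_seq_len=256, optimize_for_mobile=False),
)
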
optimization/export.py
ADDED
|
@@ -0,0 +1,365 @@
| 1 |
+
"""
|
| 2 |
+
MiniMind Export Utilities
|
| 3 |
+
Export models to ONNX, GGUF (llama.cpp), and other formats.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import struct
|
| 8 |
+
from typing import Optional, Dict, Any, List
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from dataclasses import dataclass, asdict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ExportConfig:
|
| 18 |
+
"""Configuration for model export."""
|
| 19 |
+
# ONNX settings
|
| 20 |
+
opset_version: int = 17
|
| 21 |
+
use_external_data: bool = False
|
| 22 |
+
optimize_for_mobile: bool = True
|
| 23 |
+
|
| 24 |
+
# GGUF settings
|
| 25 |
+
gguf_quant_type: str = "Q4_K_M" # Q4_0, Q4_K_M, Q5_K_M, Q8_0, F16
|
| 26 |
+
gguf_use_mmap: bool = True
|
| 27 |
+
|
| 28 |
+
# General
|
| 29 |
+
max_seq_len: int = 2048
|
| 30 |
+
batch_size: int = 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def export_to_onnx(
|
| 34 |
+
model: nn.Module,
|
| 35 |
+
output_path: str,
|
| 36 |
+
config: Optional[ExportConfig] = None,
|
| 37 |
+
sample_input: Optional[torch.Tensor] = None,
|
| 38 |
+
) -> str:
|
| 39 |
+
"""
|
| 40 |
+
Export model to ONNX format.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
model: PyTorch model to export
|
| 44 |
+
output_path: Path to save ONNX model
|
| 45 |
+
config: Export configuration
|
| 46 |
+
sample_input: Sample input tensor for tracing
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Path to exported model
|
| 50 |
+
"""
|
| 51 |
+
config = config or ExportConfig()
|
| 52 |
+
output_path = Path(output_path)
|
| 53 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
model.eval()
|
| 56 |
+
device = next(model.parameters()).device
|
| 57 |
+
|
| 58 |
+
# Create sample input if not provided
|
| 59 |
+
if sample_input is None:
|
| 60 |
+
sample_input = torch.randint(
|
| 61 |
+
0, 1000,
|
| 62 |
+
(config.batch_size, config.max_seq_len),
|
| 63 |
+
dtype=torch.long,
|
| 64 |
+
device=device,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Dynamic axes for variable sequence length
|
| 68 |
+
dynamic_axes = {
|
| 69 |
+
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
| 70 |
+
"logits": {0: "batch_size", 1: "sequence_length"},
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
# Wrapper to simplify output
|
| 74 |
+
class ONNXWrapper(nn.Module):
|
| 75 |
+
def __init__(self, model):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.model = model
|
| 78 |
+
|
| 79 |
+
def forward(self, input_ids):
|
| 80 |
+
_, logits, _, _ = self.model(input_ids)
|
| 81 |
+
return logits
|
| 82 |
+
|
| 83 |
+
wrapped_model = ONNXWrapper(model)
|
| 84 |
+
|
| 85 |
+
# Export
|
| 86 |
+
torch.onnx.export(
|
| 87 |
+
wrapped_model,
|
| 88 |
+
(sample_input,),
|
| 89 |
+
str(output_path),
|
| 90 |
+
opset_version=config.opset_version,
|
| 91 |
+
input_names=["input_ids"],
|
| 92 |
+
output_names=["logits"],
|
| 93 |
+
dynamic_axes=dynamic_axes,
|
| 94 |
+
do_constant_folding=True,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
print(f"ONNX model exported to {output_path}")
|
| 98 |
+
|
| 99 |
+
# Optimize for mobile if requested
|
| 100 |
+
if config.optimize_for_mobile:
|
| 101 |
+
try:
|
| 102 |
+
import onnx
|
| 103 |
+
from onnxruntime.transformers import optimizer
|
| 104 |
+
|
| 105 |
+
optimized_path = output_path.with_suffix(".optimized.onnx")
|
| 106 |
+
onnx_model = onnx.load(str(output_path))
|
| 107 |
+
|
| 108 |
+
# Basic optimization
|
| 109 |
+
from onnx import optimizer as onnx_optimizer
|
| 110 |
+
passes = ["fuse_bn_into_conv", "fuse_consecutive_transposes"]
|
| 111 |
+
optimized_model = onnx_optimizer.optimize(onnx_model, passes)
|
| 112 |
+
onnx.save(optimized_model, str(optimized_path))
|
| 113 |
+
|
| 114 |
+
print(f"Optimized ONNX model saved to {optimized_path}")
|
| 115 |
+
except ImportError:
|
| 116 |
+
print("Note: Install onnx and onnxruntime for optimization")
|
| 117 |
+
|
| 118 |
+
return str(output_path)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# GGUF format constants
|
| 122 |
+
GGUF_MAGIC = 0x46554747 # "GGUF" in little endian
|
| 123 |
+
GGUF_VERSION = 3
|
| 124 |
+
|
| 125 |
+
GGUF_TYPE_UINT8 = 0
|
| 126 |
+
GGUF_TYPE_INT8 = 1
|
| 127 |
+
GGUF_TYPE_UINT16 = 2
|
| 128 |
+
GGUF_TYPE_INT16 = 3
|
| 129 |
+
GGUF_TYPE_UINT32 = 4
|
| 130 |
+
GGUF_TYPE_INT32 = 5
|
| 131 |
+
GGUF_TYPE_FLOAT32 = 6
|
| 132 |
+
GGUF_TYPE_BOOL = 7
|
| 133 |
+
GGUF_TYPE_STRING = 8
|
| 134 |
+
GGUF_TYPE_ARRAY = 9
|
| 135 |
+
GGUF_TYPE_UINT64 = 10
|
| 136 |
+
GGUF_TYPE_INT64 = 11
|
| 137 |
+
GGUF_TYPE_FLOAT64 = 12
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class GGUFWriter:
|
| 141 |
+
"""Writer for GGUF format (llama.cpp compatible)."""
|
| 142 |
+
|
| 143 |
+
def __init__(self, output_path: str):
|
| 144 |
+
self.output_path = Path(output_path)
|
| 145 |
+
self.metadata: Dict[str, Any] = {}
|
| 146 |
+
self.tensors: List[Dict[str, Any]] = []
|
| 147 |
+
|
| 148 |
+
def add_metadata(self, key: str, value: Any, value_type: int = None):
|
| 149 |
+
"""Add metadata key-value pair."""
|
| 150 |
+
self.metadata[key] = {"value": value, "type": value_type}
|
| 151 |
+
|
| 152 |
+
def add_tensor(self, name: str, tensor: torch.Tensor, quant_type: str = "F32"):
|
| 153 |
+
"""Add a tensor to be written."""
|
| 154 |
+
self.tensors.append({
|
| 155 |
+
"name": name,
|
| 156 |
+
"data": tensor.cpu().numpy(),
|
| 157 |
+
"quant_type": quant_type,
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
def _write_string(self, f, s: str):
|
| 161 |
+
"""Write a string in GGUF format."""
|
| 162 |
+
encoded = s.encode("utf-8")
|
| 163 |
+
f.write(struct.pack("<Q", len(encoded)))
|
| 164 |
+
f.write(encoded)
|
| 165 |
+
|
| 166 |
+
def _write_metadata_value(self, f, value: Any, value_type: int):
|
| 167 |
+
"""Write a metadata value."""
|
| 168 |
+
f.write(struct.pack("<I", value_type))
|
| 169 |
+
|
| 170 |
+
if value_type == GGUF_TYPE_UINT32:
|
| 171 |
+
f.write(struct.pack("<I", value))
|
| 172 |
+
elif value_type == GGUF_TYPE_INT32:
|
| 173 |
+
f.write(struct.pack("<i", value))
|
| 174 |
+
elif value_type == GGUF_TYPE_FLOAT32:
|
| 175 |
+
f.write(struct.pack("<f", value))
|
| 176 |
+
elif value_type == GGUF_TYPE_STRING:
|
| 177 |
+
self._write_string(f, value)
|
| 178 |
+
elif value_type == GGUF_TYPE_BOOL:
|
| 179 |
+
f.write(struct.pack("<?", value))
|
| 180 |
+
|
| 181 |
+
def write(self):
|
| 182 |
+
"""Write the GGUF file."""
|
| 183 |
+
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 184 |
+
|
| 185 |
+
with open(self.output_path, "wb") as f:
|
| 186 |
+
# Header
|
| 187 |
+
f.write(struct.pack("<I", GGUF_MAGIC))
|
| 188 |
+
f.write(struct.pack("<I", GGUF_VERSION))
|
| 189 |
+
f.write(struct.pack("<Q", len(self.tensors)))
|
| 190 |
+
f.write(struct.pack("<Q", len(self.metadata)))
|
| 191 |
+
|
| 192 |
+
# Metadata
|
| 193 |
+
for key, meta in self.metadata.items():
|
| 194 |
+
self._write_string(f, key)
|
| 195 |
+
self._write_metadata_value(f, meta["value"], meta["type"])
|
| 196 |
+
|
| 197 |
+
# Tensor info (headers)
|
| 198 |
+
tensor_data_offset = f.tell()
|
| 199 |
+
for tensor_info in self.tensors:
|
| 200 |
+
self._write_string(f, tensor_info["name"])
|
| 201 |
+
data = tensor_info["data"]
|
| 202 |
+
|
| 203 |
+
# Number of dimensions
|
| 204 |
+
f.write(struct.pack("<I", len(data.shape)))
|
| 205 |
+
|
| 206 |
+
# Dimensions
|
| 207 |
+
for dim in data.shape:
|
| 208 |
+
f.write(struct.pack("<Q", dim))
|
| 209 |
+
|
| 210 |
+
# Data type (simplified - using F32 for now)
|
| 211 |
+
f.write(struct.pack("<I", GGUF_TYPE_FLOAT32))
|
| 212 |
+
|
| 213 |
+
# Offset placeholder (this simplified writer does not back-patch real data offsets)
|
| 214 |
+
f.write(struct.pack("<Q", 0))
|
| 215 |
+
|
| 216 |
+
# Alignment padding
|
| 217 |
+
alignment = 32
|
| 218 |
+
current_pos = f.tell()
|
| 219 |
+
padding = (alignment - (current_pos % alignment)) % alignment
|
| 220 |
+
f.write(b"\x00" * padding)
|
| 221 |
+
|
| 222 |
+
# Tensor data
|
| 223 |
+
for tensor_info in self.tensors:
|
| 224 |
+
data = tensor_info["data"].astype("float32")
|
| 225 |
+
f.write(data.tobytes())
|
| 226 |
+
|
| 227 |
+
print(f"GGUF model written to {self.output_path}")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def export_to_gguf(
|
| 231 |
+
model: nn.Module,
|
| 232 |
+
output_path: str,
|
| 233 |
+
model_config: Any,
|
| 234 |
+
config: Optional[ExportConfig] = None,
|
| 235 |
+
) -> str:
|
| 236 |
+
"""
|
| 237 |
+
Export model to GGUF format for llama.cpp.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
model: PyTorch model to export
|
| 241 |
+
output_path: Path to save GGUF model
|
| 242 |
+
model_config: Model configuration
|
| 243 |
+
config: Export configuration
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
Path to exported model
|
| 247 |
+
"""
|
| 248 |
+
config = config or ExportConfig()
|
| 249 |
+
writer = GGUFWriter(output_path)
|
| 250 |
+
|
| 251 |
+
# Add model metadata
|
| 252 |
+
writer.add_metadata("general.architecture", "mind2", GGUF_TYPE_STRING)
|
| 253 |
+
writer.add_metadata("general.name", model_config.model_name, GGUF_TYPE_STRING)
|
| 254 |
+
writer.add_metadata("mind2.context_length", model_config.max_position_embeddings, GGUF_TYPE_UINT32)
|
| 255 |
+
writer.add_metadata("mind2.embedding_length", model_config.hidden_size, GGUF_TYPE_UINT32)
|
| 256 |
+
writer.add_metadata("mind2.block_count", model_config.num_hidden_layers, GGUF_TYPE_UINT32)
|
| 257 |
+
writer.add_metadata("mind2.attention.head_count", model_config.num_attention_heads, GGUF_TYPE_UINT32)
|
| 258 |
+
writer.add_metadata("mind2.attention.head_count_kv", model_config.num_key_value_heads, GGUF_TYPE_UINT32)
|
| 259 |
+
writer.add_metadata("mind2.rope.freq_base", model_config.rope_theta, GGUF_TYPE_FLOAT32)
|
| 260 |
+
writer.add_metadata("mind2.expert_count", model_config.num_experts, GGUF_TYPE_UINT32)
|
| 261 |
+
writer.add_metadata("mind2.expert_used_count", model_config.num_experts_per_tok, GGUF_TYPE_UINT32)
|
| 262 |
+
|
| 263 |
+
# Add tokenizer metadata (placeholder)
|
| 264 |
+
writer.add_metadata("tokenizer.ggml.model", "gpt2", GGUF_TYPE_STRING)
|
| 265 |
+
|
| 266 |
+
# Export tensors
|
| 267 |
+
state_dict = model.state_dict()
|
| 268 |
+
tensor_name_map = {
|
| 269 |
+
"model.embed_tokens.weight": "token_embd.weight",
|
| 270 |
+
"model.norm.weight": "output_norm.weight",
|
| 271 |
+
"lm_head.weight": "output.weight",
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
for name, tensor in state_dict.items():
|
| 275 |
+
# Map tensor names to GGUF convention
|
| 276 |
+
gguf_name = tensor_name_map.get(name, name)
|
| 277 |
+
|
| 278 |
+
# Layer-specific mappings
|
| 279 |
+
if "layers." in name:
|
| 280 |
+
parts = name.split(".")
|
| 281 |
+
layer_idx = parts[2]
|
| 282 |
+
|
| 283 |
+
if "self_attn.q_proj" in name:
|
| 284 |
+
gguf_name = f"blk.{layer_idx}.attn_q.weight"
|
| 285 |
+
elif "self_attn.k_proj" in name:
|
| 286 |
+
gguf_name = f"blk.{layer_idx}.attn_k.weight"
|
| 287 |
+
elif "self_attn.v_proj" in name:
|
| 288 |
+
gguf_name = f"blk.{layer_idx}.attn_v.weight"
|
| 289 |
+
elif "self_attn.o_proj" in name:
|
| 290 |
+
gguf_name = f"blk.{layer_idx}.attn_output.weight"
|
| 291 |
+
elif "input_layernorm" in name:
|
| 292 |
+
gguf_name = f"blk.{layer_idx}.attn_norm.weight"
|
| 293 |
+
elif "post_attention_layernorm" in name:
|
| 294 |
+
gguf_name = f"blk.{layer_idx}.ffn_norm.weight"
|
| 295 |
+
elif "mlp.gate" in name:
|
| 296 |
+
gguf_name = f"blk.{layer_idx}.ffn_gate.weight"
|
| 297 |
+
elif "experts" in name:
|
| 298 |
+
expert_idx = parts[5]  # e.g. model.layers.{i}.mlp.experts.{j}.gate_proj.weight
|
| 299 |
+
if "gate_proj" in name:
|
| 300 |
+
gguf_name = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.weight"
|
| 301 |
+
elif "up_proj" in name:
|
| 302 |
+
gguf_name = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.weight"
|
| 303 |
+
elif "down_proj" in name:
|
| 304 |
+
gguf_name = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.weight"
|
| 305 |
+
|
| 306 |
+
writer.add_tensor(gguf_name, tensor)
|
| 307 |
+
|
| 308 |
+
writer.write()
|
| 309 |
+
return str(output_path)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def export_for_android(
|
| 313 |
+
model: nn.Module,
|
| 314 |
+
output_dir: str,
|
| 315 |
+
model_config: Any,
|
| 316 |
+
export_formats: List[str] = ["onnx", "gguf"],
|
| 317 |
+
) -> Dict[str, str]:
|
| 318 |
+
"""
|
| 319 |
+
Export model in formats suitable for Android deployment.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
model: PyTorch model
|
| 323 |
+
output_dir: Output directory
|
| 324 |
+
model_config: Model configuration
|
| 325 |
+
export_formats: List of formats to export
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
Dictionary mapping format to output path
|
| 329 |
+
"""
|
| 330 |
+
output_dir = Path(output_dir)
|
| 331 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 332 |
+
outputs = {}
|
| 333 |
+
|
| 334 |
+
config = ExportConfig(
|
| 335 |
+
optimize_for_mobile=True,
|
| 336 |
+
max_seq_len=512, # Shorter for mobile
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
if "onnx" in export_formats:
|
| 340 |
+
onnx_path = output_dir / f"{model_config.model_name}.onnx"
|
| 341 |
+
outputs["onnx"] = export_to_onnx(model, str(onnx_path), config)
|
| 342 |
+
|
| 343 |
+
if "gguf" in export_formats:
|
| 344 |
+
gguf_path = output_dir / f"{model_config.model_name}.gguf"
|
| 345 |
+
outputs["gguf"] = export_to_gguf(model, str(gguf_path), model_config, config)
|
| 346 |
+
|
| 347 |
+
# Create model info JSON for Android app
|
| 348 |
+
model_info = {
|
| 349 |
+
"model_name": model_config.model_name,
|
| 350 |
+
"vocab_size": model_config.vocab_size,
|
| 351 |
+
"hidden_size": model_config.hidden_size,
|
| 352 |
+
"num_layers": model_config.num_hidden_layers,
|
| 353 |
+
"num_heads": model_config.num_attention_heads,
|
| 354 |
+
"max_seq_len": config.max_seq_len,
|
| 355 |
+
"exports": {k: str(v) for k, v in outputs.items()},
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
info_path = output_dir / "model_info.json"
|
| 359 |
+
with open(info_path, "w") as f:
|
| 360 |
+
json.dump(model_info, f, indent=2)
|
| 361 |
+
|
| 362 |
+
print(f"Model info saved to {info_path}")
|
| 363 |
+
outputs["info"] = str(info_path)
|
| 364 |
+
|
| 365 |
+
return outputs
|
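
For reference, a minimal sketch of driving this module directly from Python rather than through scripts/export.py; it reuses `get_config` and `Mind2ForCausalLM` exactly as that script does, and the output directory is a placeholder.

```python
import torch

from configs.model_config import get_config
from model import Mind2ForCausalLM
from optimization.export import export_for_android

# An untrained mind2-nano keeps the round-trip fast; any checkpoint can be loaded first.
config = get_config("mind2-nano")
model = Mind2ForCausalLM(config).eval()

with torch.no_grad():
    outputs = export_for_android(model, "exports/android-demo", config)

for fmt, path in outputs.items():
    print(f"{fmt}: {path}")
```
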
optimization/pruning.py
ADDED
|
@@ -0,0 +1,346 @@
|
| 1 |
+
"""
|
| 2 |
+
MiniMind Pruning Toolkit
|
| 3 |
+
Structured and unstructured pruning for model compression.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Dict, List, Tuple
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from enum import Enum
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.utils.prune as prune
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PruningMethod(Enum):
|
| 17 |
+
"""Supported pruning methods."""
|
| 18 |
+
MAGNITUDE = "magnitude" # L1 magnitude pruning
|
| 19 |
+
STRUCTURED = "structured" # Channel/head pruning
|
| 20 |
+
MOVEMENT = "movement" # Movement pruning (requires training)
|
| 21 |
+
WANDA = "wanda" # Weights AND Activations
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class PruningConfig:
|
| 26 |
+
"""Configuration for pruning."""
|
| 27 |
+
method: PruningMethod = PruningMethod.MAGNITUDE
|
| 28 |
+
sparsity: float = 0.5 # Target sparsity ratio
|
| 29 |
+
structured: bool = False # Whether to use structured pruning
|
| 30 |
+
prune_heads: bool = True # Prune attention heads
|
| 31 |
+
prune_experts: bool = True # Prune MoE experts
|
| 32 |
+
prune_ffn: bool = True # Prune FFN neurons
|
| 33 |
+
min_heads: int = 2 # Minimum attention heads to keep
|
| 34 |
+
min_experts: int = 2 # Minimum experts to keep
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Mind2Pruner:
|
| 38 |
+
"""Pruner for MiniMind models."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, config: Optional[PruningConfig] = None):
|
| 41 |
+
self.config = config or PruningConfig()
|
| 42 |
+
|
| 43 |
+
def prune(
|
| 44 |
+
self,
|
| 45 |
+
model: nn.Module,
|
| 46 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 47 |
+
) -> nn.Module:
|
| 48 |
+
"""
|
| 49 |
+
Prune the model.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
model: Model to prune
|
| 53 |
+
calibration_data: Data for importance estimation
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Pruned model
|
| 57 |
+
"""
|
| 58 |
+
if self.config.method == PruningMethod.MAGNITUDE:
|
| 59 |
+
return self._magnitude_pruning(model)
|
| 60 |
+
elif self.config.method == PruningMethod.STRUCTURED:
|
| 61 |
+
return self._structured_pruning(model, calibration_data)
|
| 62 |
+
elif self.config.method == PruningMethod.WANDA:
|
| 63 |
+
return self._wanda_pruning(model, calibration_data)
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"Unsupported pruning method: {self.config.method}")
|
| 66 |
+
|
| 67 |
+
def _magnitude_pruning(self, model: nn.Module) -> nn.Module:
|
| 68 |
+
"""Apply unstructured magnitude pruning."""
|
| 69 |
+
modules_to_prune = []
|
| 70 |
+
|
| 71 |
+
for name, module in model.named_modules():
|
| 72 |
+
if isinstance(module, nn.Linear):
|
| 73 |
+
modules_to_prune.append((module, "weight"))
|
| 74 |
+
|
| 75 |
+
# Apply global unstructured pruning
|
| 76 |
+
prune.global_unstructured(
|
| 77 |
+
modules_to_prune,
|
| 78 |
+
pruning_method=prune.L1Unstructured,
|
| 79 |
+
amount=self.config.sparsity,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Make pruning permanent
|
| 83 |
+
for module, _ in modules_to_prune:
|
| 84 |
+
prune.remove(module, "weight")
|
| 85 |
+
|
| 86 |
+
return model
|
| 87 |
+
|
| 88 |
+
def _structured_pruning(
|
| 89 |
+
self,
|
| 90 |
+
model: nn.Module,
|
| 91 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 92 |
+
) -> nn.Module:
|
| 93 |
+
"""Apply structured pruning (channels/heads)."""
|
| 94 |
+
# Compute importance scores
|
| 95 |
+
importance_scores = self._compute_importance(model, calibration_data)
|
| 96 |
+
|
| 97 |
+
# Prune attention heads
|
| 98 |
+
if self.config.prune_heads:
|
| 99 |
+
model = self._prune_attention_heads(model, importance_scores)
|
| 100 |
+
|
| 101 |
+
# Prune FFN neurons
|
| 102 |
+
if self.config.prune_ffn:
|
| 103 |
+
model = self._prune_ffn_neurons(model, importance_scores)
|
| 104 |
+
|
| 105 |
+
# Prune experts
|
| 106 |
+
if self.config.prune_experts:
|
| 107 |
+
model = self._prune_experts(model, importance_scores)
|
| 108 |
+
|
| 109 |
+
return model
|
| 110 |
+
|
| 111 |
+
def _compute_importance(
|
| 112 |
+
self,
|
| 113 |
+
model: nn.Module,
|
| 114 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 115 |
+
) -> Dict[str, torch.Tensor]:
|
| 116 |
+
"""Compute importance scores for different components."""
|
| 117 |
+
importance = {}
|
| 118 |
+
|
| 119 |
+
# Head importance (based on output norm)
|
| 120 |
+
for name, module in model.named_modules():
|
| 121 |
+
if hasattr(module, "num_heads"):
|
| 122 |
+
# Use weight magnitude as proxy for importance
|
| 123 |
+
q_weight = getattr(module, "q_proj", None)
|
| 124 |
+
if q_weight is not None:
|
| 125 |
+
weight = q_weight.weight.data
|
| 126 |
+
num_heads = module.num_heads
|
| 127 |
+
head_dim = weight.shape[0] // num_heads
|
| 128 |
+
|
| 129 |
+
head_importance = torch.zeros(num_heads)
|
| 130 |
+
for h in range(num_heads):
|
| 131 |
+
start = h * head_dim
|
| 132 |
+
end = (h + 1) * head_dim
|
| 133 |
+
head_importance[h] = weight[start:end].norm()
|
| 134 |
+
|
| 135 |
+
importance[f"{name}.heads"] = head_importance
|
| 136 |
+
|
| 137 |
+
# FFN neuron importance
|
| 138 |
+
for name, module in model.named_modules():
|
| 139 |
+
if isinstance(module, nn.Linear) and "gate_proj" in name:
|
| 140 |
+
weight = module.weight.data
|
| 141 |
+
neuron_importance = weight.norm(dim=1)
|
| 142 |
+
importance[f"{name}.neurons"] = neuron_importance
|
| 143 |
+
|
| 144 |
+
# Expert importance (for MoE)
|
| 145 |
+
for name, module in model.named_modules():
|
| 146 |
+
if hasattr(module, "experts"):
|
| 147 |
+
expert_importance = torch.zeros(len(module.experts))
|
| 148 |
+
for i, expert in enumerate(module.experts):
|
| 149 |
+
expert_params = sum(p.numel() for p in expert.parameters())
|
| 150 |
+
expert_norm = sum(p.data.norm() for p in expert.parameters())
|
| 151 |
+
expert_importance[i] = expert_norm / max(1, expert_params)
|
| 152 |
+
|
| 153 |
+
importance[f"{name}.experts"] = expert_importance
|
| 154 |
+
|
| 155 |
+
return importance
|
| 156 |
+
|
| 157 |
+
def _prune_attention_heads(
|
| 158 |
+
self,
|
| 159 |
+
model: nn.Module,
|
| 160 |
+
importance: Dict[str, torch.Tensor],
|
| 161 |
+
) -> nn.Module:
|
| 162 |
+
"""Prune least important attention heads."""
|
| 163 |
+
for name, module in model.named_modules():
|
| 164 |
+
if hasattr(module, "num_heads"):
|
| 165 |
+
head_key = f"{name}.heads"
|
| 166 |
+
if head_key in importance:
|
| 167 |
+
scores = importance[head_key]
|
| 168 |
+
num_heads = len(scores)
|
| 169 |
+
num_prune = int(num_heads * self.config.sparsity)
|
| 170 |
+
num_keep = max(self.config.min_heads, num_heads - num_prune)
|
| 171 |
+
|
| 172 |
+
# Get indices of heads to keep
|
| 173 |
+
_, keep_indices = torch.topk(scores, num_keep)
|
| 174 |
+
keep_indices = keep_indices.sort()[0]
|
| 175 |
+
|
| 176 |
+
# Create mask for pruning
|
| 177 |
+
head_dim = module.head_dim
|
| 178 |
+
mask = torch.zeros(num_heads * head_dim)
|
| 179 |
+
for idx in keep_indices:
|
| 180 |
+
start = idx * head_dim
|
| 181 |
+
end = (idx + 1) * head_dim
|
| 182 |
+
mask[start:end] = 1
|
| 183 |
+
|
| 184 |
+
# Apply the head mask to the Q projection (output rows) and the O projection (input columns)
|
| 185 |
+
for proj_name in ["q_proj", "o_proj"]:
|
| 186 |
+
proj = getattr(module, proj_name, None)
|
| 187 |
+
if proj is not None:
|
| 188 |
+
if proj_name == "q_proj":
|
| 189 |
+
proj.weight.data *= mask.unsqueeze(1).to(proj.weight.device)
|
| 190 |
+
else:
|
| 191 |
+
proj.weight.data *= mask.unsqueeze(0).to(proj.weight.device)
|
| 192 |
+
|
| 193 |
+
return model
|
| 194 |
+
|
| 195 |
+
def _prune_ffn_neurons(
|
| 196 |
+
self,
|
| 197 |
+
model: nn.Module,
|
| 198 |
+
importance: Dict[str, torch.Tensor],
|
| 199 |
+
) -> nn.Module:
|
| 200 |
+
"""Prune least important FFN neurons."""
|
| 201 |
+
for name, module in model.named_modules():
|
| 202 |
+
if isinstance(module, nn.Linear) and "gate_proj" in name:
|
| 203 |
+
neuron_key = f"{name}.neurons"
|
| 204 |
+
if neuron_key in importance:
|
| 205 |
+
scores = importance[neuron_key]
|
| 206 |
+
num_neurons = len(scores)
|
| 207 |
+
num_prune = int(num_neurons * self.config.sparsity)
|
| 208 |
+
num_keep = num_neurons - num_prune
|
| 209 |
+
|
| 210 |
+
_, keep_indices = torch.topk(scores, num_keep)
|
| 211 |
+
|
| 212 |
+
# Create neuron mask
|
| 213 |
+
mask = torch.zeros(num_neurons)
|
| 214 |
+
mask[keep_indices] = 1
|
| 215 |
+
|
| 216 |
+
# Zero pruned neurons in the gate projection; the gated product then zeroes the matching up-projection outputs
|
| 217 |
+
module.weight.data *= mask.unsqueeze(1).to(module.weight.device)
|
| 218 |
+
|
| 219 |
+
return model
|
| 220 |
+
|
| 221 |
+
def _prune_experts(
|
| 222 |
+
self,
|
| 223 |
+
model: nn.Module,
|
| 224 |
+
importance: Dict[str, torch.Tensor],
|
| 225 |
+
) -> nn.Module:
|
| 226 |
+
"""Prune least important MoE experts."""
|
| 227 |
+
for name, module in model.named_modules():
|
| 228 |
+
if hasattr(module, "experts"):
|
| 229 |
+
expert_key = f"{name}.experts"
|
| 230 |
+
if expert_key in importance:
|
| 231 |
+
scores = importance[expert_key]
|
| 232 |
+
num_experts = len(scores)
|
| 233 |
+
num_prune = int(num_experts * self.config.sparsity)
|
| 234 |
+
num_keep = max(self.config.min_experts, num_experts - num_prune)
|
| 235 |
+
|
| 236 |
+
_, keep_indices = torch.topk(scores, num_keep)
|
| 237 |
+
keep_indices = keep_indices.sort()[0].tolist()
|
| 238 |
+
|
| 239 |
+
# Zero out pruned experts (actual removal requires model restructuring)
|
| 240 |
+
for i, expert in enumerate(module.experts):
|
| 241 |
+
if i not in keep_indices:
|
| 242 |
+
for param in expert.parameters():
|
| 243 |
+
param.data.zero_()
|
| 244 |
+
|
| 245 |
+
print(f"Pruned experts in {name}: keeping {keep_indices}")
|
| 246 |
+
|
| 247 |
+
return model
|
| 248 |
+
|
| 249 |
+
def _wanda_pruning(
|
| 250 |
+
self,
|
| 251 |
+
model: nn.Module,
|
| 252 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 253 |
+
) -> nn.Module:
|
| 254 |
+
"""
|
| 255 |
+
Apply WANDA (Weights AND Activations) pruning.
|
| 256 |
+
Combines weight magnitude with activation magnitude.
|
| 257 |
+
"""
|
| 258 |
+
if calibration_data is None:
|
| 259 |
+
print("Warning: WANDA requires calibration data, falling back to magnitude pruning")
|
| 260 |
+
return self._magnitude_pruning(model)
|
| 261 |
+
|
| 262 |
+
model.eval()
|
| 263 |
+
activation_norms = {}
|
| 264 |
+
|
| 265 |
+
# Hook to capture activations
|
| 266 |
+
def hook_fn(name):
|
| 267 |
+
def hook(module, input, output):
|
| 268 |
+
if isinstance(input, tuple):
|
| 269 |
+
input = input[0]
|
| 270 |
+
activation_norms[name] = input.abs().mean(dim=(0, 1))
|
| 271 |
+
return hook
|
| 272 |
+
|
| 273 |
+
# Register hooks
|
| 274 |
+
handles = []
|
| 275 |
+
for name, module in model.named_modules():
|
| 276 |
+
if isinstance(module, nn.Linear):
|
| 277 |
+
handles.append(module.register_forward_hook(hook_fn(name)))
|
| 278 |
+
|
| 279 |
+
# Run calibration
|
| 280 |
+
with torch.no_grad():
|
| 281 |
+
model(calibration_data)
|
| 282 |
+
|
| 283 |
+
# Remove hooks
|
| 284 |
+
for handle in handles:
|
| 285 |
+
handle.remove()
|
| 286 |
+
|
| 287 |
+
# Compute WANDA scores and prune
|
| 288 |
+
for name, module in model.named_modules():
|
| 289 |
+
if isinstance(module, nn.Linear) and name in activation_norms:
|
| 290 |
+
weight = module.weight.data
|
| 291 |
+
act_norm = activation_norms[name].to(weight.device)
|
| 292 |
+
|
| 293 |
+
# WANDA score: |W| * |X|
|
| 294 |
+
wanda_score = weight.abs() * act_norm.unsqueeze(0)
|
| 295 |
+
|
| 296 |
+
# Prune based on scores
|
| 297 |
+
threshold = torch.quantile(wanda_score.flatten(), self.config.sparsity)
|
| 298 |
+
mask = (wanda_score >= threshold).float()
|
| 299 |
+
module.weight.data *= mask
|
| 300 |
+
|
| 301 |
+
return model
|
| 302 |
+
|
| 303 |
+
def compute_sparsity(self, model: nn.Module) -> Dict[str, float]:
|
| 304 |
+
"""Compute actual sparsity of the model."""
|
| 305 |
+
total_params = 0
|
| 306 |
+
zero_params = 0
|
| 307 |
+
layer_sparsity = {}
|
| 308 |
+
|
| 309 |
+
for name, module in model.named_modules():
|
| 310 |
+
if isinstance(module, nn.Linear):
|
| 311 |
+
params = module.weight.numel()
|
| 312 |
+
zeros = (module.weight == 0).sum().item()
|
| 313 |
+
total_params += params
|
| 314 |
+
zero_params += zeros
|
| 315 |
+
layer_sparsity[name] = zeros / params
|
| 316 |
+
|
| 317 |
+
return {
|
| 318 |
+
"total_sparsity": zero_params / max(1, total_params),
|
| 319 |
+
"layer_sparsity": layer_sparsity,
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def prune_model(
|
| 324 |
+
model: nn.Module,
|
| 325 |
+
sparsity: float = 0.5,
|
| 326 |
+
method: str = "magnitude",
|
| 327 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 328 |
+
) -> nn.Module:
|
| 329 |
+
"""
|
| 330 |
+
Convenience function to prune a model.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
model: Model to prune
|
| 334 |
+
sparsity: Target sparsity ratio
|
| 335 |
+
method: Pruning method (magnitude, structured, wanda)
|
| 336 |
+
calibration_data: Calibration data for importance estimation
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
Pruned model
|
| 340 |
+
"""
|
| 341 |
+
config = PruningConfig(
|
| 342 |
+
method=PruningMethod(method),
|
| 343 |
+
sparsity=sparsity,
|
| 344 |
+
)
|
| 345 |
+
pruner = Mind2Pruner(config)
|
| 346 |
+
return pruner.prune(model, calibration_data)
|
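
A minimal usage sketch of the pruner above (model construction via `get_config`/`Mind2ForCausalLM` as in the training and export scripts; the 50% sparsity target is illustrative):

```python
from configs.model_config import get_config
from model import Mind2ForCausalLM
from optimization.pruning import Mind2Pruner, PruningConfig, PruningMethod

config = get_config("mind2-nano")
model = Mind2ForCausalLM(config)

# Unstructured L1 magnitude pruning at 50% sparsity, then check what was actually achieved.
pruner = Mind2Pruner(PruningConfig(method=PruningMethod.MAGNITUDE, sparsity=0.5))
model = pruner.prune(model)

stats = pruner.compute_sparsity(model)
print(f"overall sparsity: {stats['total_sparsity']:.1%}")
```
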
optimization/quantization.py
ADDED
|
@@ -0,0 +1,311 @@
|
| 1 |
+
"""
|
| 2 |
+
MiniMind Quantization Toolkit
|
| 3 |
+
INT4/INT8 quantization for efficient inference on edge devices.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Optional, Dict, Any, Tuple, List
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class QuantizationType(Enum):
|
| 18 |
+
"""Supported quantization types."""
|
| 19 |
+
INT8_DYNAMIC = "int8_dynamic"
|
| 20 |
+
INT8_STATIC = "int8_static"
|
| 21 |
+
INT4_AWQ = "int4_awq"
|
| 22 |
+
INT4_GPTQ = "int4_gptq"
|
| 23 |
+
FP8 = "fp8"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class QuantizationConfig:
|
| 28 |
+
"""Configuration for quantization."""
|
| 29 |
+
quant_type: QuantizationType = QuantizationType.INT4_AWQ
|
| 30 |
+
bits: int = 4
|
| 31 |
+
group_size: int = 128
|
| 32 |
+
use_double_quant: bool = False
|
| 33 |
+
compute_dtype: torch.dtype = torch.float16
|
| 34 |
+
calibration_samples: int = 128
|
| 35 |
+
calibration_seq_len: int = 512
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Int4Linear(nn.Module):
|
| 39 |
+
"""INT4 quantized linear layer with group-wise quantization."""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
in_features: int,
|
| 44 |
+
out_features: int,
|
| 45 |
+
bias: bool = False,
|
| 46 |
+
group_size: int = 128,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.in_features = in_features
|
| 50 |
+
self.out_features = out_features
|
| 51 |
+
self.group_size = group_size
|
| 52 |
+
|
| 53 |
+
# Number of groups
|
| 54 |
+
self.num_groups = math.ceil(in_features / group_size)
|
| 55 |
+
|
| 56 |
+
# Packed INT4 weights (2 values per byte)
|
| 57 |
+
packed_size = out_features * math.ceil(in_features / 2)
|
| 58 |
+
self.register_buffer("qweight", torch.zeros(packed_size, dtype=torch.uint8))
|
| 59 |
+
|
| 60 |
+
# Scales and zeros per group
|
| 61 |
+
self.register_buffer("scales", torch.zeros(out_features, self.num_groups, dtype=torch.float16))
|
| 62 |
+
self.register_buffer("zeros", torch.zeros(out_features, self.num_groups, dtype=torch.float16))
|
| 63 |
+
|
| 64 |
+
if bias:
|
| 65 |
+
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))
|
| 66 |
+
else:
|
| 67 |
+
self.bias = None
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def pack_int4(values: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
"""Pack two INT4 values into one INT8."""
|
| 72 |
+
assert values.shape[-1] % 2 == 0
|
| 73 |
+
low = values[..., 0::2] & 0xF
|
| 74 |
+
high = values[..., 1::2] & 0xF
|
| 75 |
+
return (high << 4 | low).to(torch.uint8)
|
| 76 |
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
|
| 79 |
+
"""Unpack INT8 to two INT4 values."""
|
| 80 |
+
low = packed & 0xF
|
| 81 |
+
high = (packed >> 4) & 0xF
|
| 82 |
+
return torch.stack([low, high], dim=-1).flatten(-2)
|
| 83 |
+
|
| 84 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 85 |
+
"""Dequantize and compute linear transformation."""
|
| 86 |
+
input_dtype = x.dtype
|
| 87 |
+
|
| 88 |
+
# Unpack weights
|
| 89 |
+
unpacked = self.unpack_int4(self.qweight)
|
| 90 |
+
unpacked = unpacked.view(self.out_features, self.in_features)
|
| 91 |
+
|
| 92 |
+
# Dequantize
|
| 93 |
+
weight = torch.zeros(self.out_features, self.in_features, dtype=self.scales.dtype, device=x.device)
|
| 94 |
+
for g in range(self.num_groups):
|
| 95 |
+
start = g * self.group_size
|
| 96 |
+
end = min((g + 1) * self.group_size, self.in_features)
|
| 97 |
+
weight[:, start:end] = (unpacked[:, start:end].float() - self.zeros[:, g:g+1]) * self.scales[:, g:g+1]
|
| 98 |
+
|
| 99 |
+
weight = weight.to(input_dtype)
|
| 100 |
+
output = F.linear(x, weight, self.bias)
|
| 101 |
+
return output
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def from_float(cls, module: nn.Linear, group_size: int = 128) -> "Int4Linear":
|
| 105 |
+
"""Convert a float linear layer to INT4."""
|
| 106 |
+
int4_layer = cls(
|
| 107 |
+
module.in_features,
|
| 108 |
+
module.out_features,
|
| 109 |
+
bias=module.bias is not None,
|
| 110 |
+
group_size=group_size,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
weight = module.weight.data.float()
|
| 114 |
+
out_features, in_features = weight.shape
|
| 115 |
+
|
| 116 |
+
# Quantize per group
|
| 117 |
+
num_groups = math.ceil(in_features / group_size)
|
| 118 |
+
qweight = torch.zeros_like(weight, dtype=torch.int8)
|
| 119 |
+
|
| 120 |
+
for g in range(num_groups):
|
| 121 |
+
start = g * group_size
|
| 122 |
+
end = min((g + 1) * group_size, in_features)
|
| 123 |
+
group_weight = weight[:, start:end]
|
| 124 |
+
|
| 125 |
+
# Compute scales and zeros
|
| 126 |
+
min_val = group_weight.min(dim=1, keepdim=True)[0]
|
| 127 |
+
max_val = group_weight.max(dim=1, keepdim=True)[0]
|
| 128 |
+
|
| 129 |
+
scale = (max_val - min_val) / 15.0
|
| 130 |
+
scale = scale.clamp(min=1e-8)
|
| 131 |
+
zero = -min_val / scale
|
| 132 |
+
|
| 133 |
+
int4_layer.scales[:, g] = scale.squeeze().to(torch.float16)
|
| 134 |
+
int4_layer.zeros[:, g] = zero.squeeze().to(torch.float16)
|
| 135 |
+
|
| 136 |
+
# Quantize
|
| 137 |
+
qweight[:, start:end] = ((group_weight / scale + zero).round().clamp(0, 15)).to(torch.int8)
|
| 138 |
+
|
| 139 |
+
# Pack weights
|
| 140 |
+
int4_layer.qweight.copy_(cls.pack_int4(qweight.flatten()))
|
| 141 |
+
|
| 142 |
+
if module.bias is not None:
|
| 143 |
+
int4_layer.bias = module.bias.data.to(torch.float16)
|
| 144 |
+
|
| 145 |
+
return int4_layer
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Mind2Quantizer:
|
| 149 |
+
"""Quantizer for MiniMind models."""
|
| 150 |
+
|
| 151 |
+
def __init__(self, config: Optional[QuantizationConfig] = None):
|
| 152 |
+
self.config = config or QuantizationConfig()
|
| 153 |
+
|
| 154 |
+
def quantize(
|
| 155 |
+
self,
|
| 156 |
+
model: nn.Module,
|
| 157 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 158 |
+
) -> nn.Module:
|
| 159 |
+
"""
|
| 160 |
+
Quantize the model.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
model: Model to quantize
|
| 164 |
+
calibration_data: Calibration data for static quantization
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Quantized model
|
| 168 |
+
"""
|
| 169 |
+
if self.config.quant_type == QuantizationType.INT8_DYNAMIC:
|
| 170 |
+
return self._quantize_int8_dynamic(model)
|
| 171 |
+
elif self.config.quant_type == QuantizationType.INT4_AWQ:
|
| 172 |
+
return self._quantize_int4_awq(model, calibration_data)
|
| 173 |
+
elif self.config.quant_type == QuantizationType.INT4_GPTQ:
|
| 174 |
+
return self._quantize_int4_gptq(model, calibration_data)
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(f"Unsupported quantization type: {self.config.quant_type}")
|
| 177 |
+
|
| 178 |
+
def _quantize_int8_dynamic(self, model: nn.Module) -> nn.Module:
|
| 179 |
+
"""Apply INT8 dynamic quantization."""
|
| 180 |
+
return torch.quantization.quantize_dynamic(
|
| 181 |
+
model,
|
| 182 |
+
{nn.Linear},
|
| 183 |
+
dtype=torch.qint8,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def _quantize_int4_awq(
|
| 187 |
+
self,
|
| 188 |
+
model: nn.Module,
|
| 189 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 190 |
+
) -> nn.Module:
|
| 191 |
+
"""Apply AWQ-style INT4 quantization."""
|
| 192 |
+
model = model.cpu().float()
|
| 193 |
+
|
| 194 |
+
# Replace linear layers
|
| 195 |
+
for name, module in model.named_modules():
|
| 196 |
+
if isinstance(module, nn.Linear) and module.weight.shape[0] >= 64:
|
| 197 |
+
parent_name = ".".join(name.split(".")[:-1])
|
| 198 |
+
child_name = name.split(".")[-1]
|
| 199 |
+
|
| 200 |
+
parent = model
|
| 201 |
+
for part in parent_name.split("."):
|
| 202 |
+
if part:
|
| 203 |
+
parent = getattr(parent, part)
|
| 204 |
+
|
| 205 |
+
int4_linear = Int4Linear.from_float(module, self.config.group_size)
|
| 206 |
+
setattr(parent, child_name, int4_linear)
|
| 207 |
+
|
| 208 |
+
return model
|
| 209 |
+
|
| 210 |
+
def _quantize_int4_gptq(
|
| 211 |
+
self,
|
| 212 |
+
model: nn.Module,
|
| 213 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 214 |
+
) -> nn.Module:
|
| 215 |
+
"""Apply GPTQ-style INT4 quantization with calibration."""
|
| 216 |
+
# GPTQ requires calibration for optimal quantization
|
| 217 |
+
if calibration_data is None:
|
| 218 |
+
print("Warning: GPTQ without calibration, falling back to AWQ")
|
| 219 |
+
return self._quantize_int4_awq(model, calibration_data)
|
| 220 |
+
|
| 221 |
+
model = model.cpu().float()
|
| 222 |
+
|
| 223 |
+
# Run calibration to collect activation statistics
|
| 224 |
+
model.eval()
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
model(calibration_data)
|
| 227 |
+
|
| 228 |
+
# Apply GPTQ quantization
|
| 229 |
+
for name, module in model.named_modules():
|
| 230 |
+
if isinstance(module, nn.Linear) and module.weight.shape[0] >= 64:
|
| 231 |
+
parent_name = ".".join(name.split(".")[:-1])
|
| 232 |
+
child_name = name.split(".")[-1]
|
| 233 |
+
|
| 234 |
+
parent = model
|
| 235 |
+
for part in parent_name.split("."):
|
| 236 |
+
if part:
|
| 237 |
+
parent = getattr(parent, part)
|
| 238 |
+
|
| 239 |
+
int4_linear = Int4Linear.from_float(module, self.config.group_size)
|
| 240 |
+
setattr(parent, child_name, int4_linear)
|
| 241 |
+
|
| 242 |
+
return model
|
| 243 |
+
|
| 244 |
+
def estimate_model_size(self, model: nn.Module) -> Dict[str, float]:
|
| 245 |
+
"""Estimate model size in different formats."""
|
| 246 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 247 |
+
|
| 248 |
+
return {
|
| 249 |
+
"params": total_params,
|
| 250 |
+
"fp32_gb": (total_params * 4) / (1024**3),
|
| 251 |
+
"fp16_gb": (total_params * 2) / (1024**3),
|
| 252 |
+
"int8_gb": (total_params * 1) / (1024**3),
|
| 253 |
+
"int4_gb": (total_params * 0.5) / (1024**3),
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def quantize_model(
|
| 258 |
+
model: nn.Module,
|
| 259 |
+
quant_type: str = "int4_awq",
|
| 260 |
+
group_size: int = 128,
|
| 261 |
+
calibration_data: Optional[torch.Tensor] = None,
|
| 262 |
+
) -> nn.Module:
|
| 263 |
+
"""
|
| 264 |
+
Convenience function to quantize a model.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
model: Model to quantize
|
| 268 |
+
quant_type: Quantization type (int4_awq, int4_gptq, int8_dynamic)
|
| 269 |
+
group_size: Group size for INT4 quantization
|
| 270 |
+
calibration_data: Calibration data for GPTQ
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
Quantized model
|
| 274 |
+
"""
|
| 275 |
+
config = QuantizationConfig(
|
| 276 |
+
quant_type=QuantizationType(quant_type),
|
| 277 |
+
group_size=group_size,
|
| 278 |
+
)
|
| 279 |
+
quantizer = Mind2Quantizer(config)
|
| 280 |
+
return quantizer.quantize(model, calibration_data)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
if __name__ == "__main__":
|
| 284 |
+
# Test quantization
|
| 285 |
+
import sys
|
| 286 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 287 |
+
from model import create_model
|
| 288 |
+
|
| 289 |
+
print("Testing quantization...")
|
| 290 |
+
|
| 291 |
+
# Create a small model for testing
|
| 292 |
+
model = create_model("mind2-nano", device="cpu", dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
quantizer = Mind2Quantizer()
|
| 295 |
+
|
| 296 |
+
# Estimate sizes
|
| 297 |
+
sizes = quantizer.estimate_model_size(model)
|
| 298 |
+
print(f"Model sizes:")
|
| 299 |
+
for fmt, size in sizes.items():
|
| 300 |
+
print(f" {fmt}: {size:.3f}")
|
| 301 |
+
|
| 302 |
+
# Quantize
|
| 303 |
+
print("\nQuantizing to INT4...")
|
| 304 |
+
quantized_model = quantizer.quantize(model)
|
| 305 |
+
|
| 306 |
+
# Test inference
|
| 307 |
+
input_ids = torch.randint(0, 1000, (1, 32))
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
_, logits, _, _ = quantized_model(input_ids)
|
| 310 |
+
print(f"Output shape: {logits.shape}")
|
| 311 |
+
print("✓ Quantization successful!")
|
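
To make the group-wise scheme in `Int4Linear.from_float` concrete, here is a small numeric round-trip on a single weight group (the eight values are illustrative). The scale maps the group's [min, max] range onto the 16 INT4 levels, and dequantization inverts it with the stored zero point, so the reconstruction error per weight stays within about half a scale step.

```python
import torch

# One group of 8 weights (illustrative values).
w = torch.tensor([-0.30, -0.05, 0.02, 0.10, 0.25, 0.40, -0.15, 0.33])

# Same formulas as Int4Linear.from_float.
min_val, max_val = w.min(), w.max()
scale = ((max_val - min_val) / 15.0).clamp(min=1e-8)
zero = -min_val / scale

q = (w / scale + zero).round().clamp(0, 15)   # INT4 codes in [0, 15]
w_hat = (q - zero) * scale                    # dequantized weights

print(float((w - w_hat).abs().max()), float(scale / 2))  # max error stays within ~scale / 2
```
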
pyproject.toml
ADDED
|
@@ -0,0 +1,108 @@
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "minimind"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "MiniMind (Mind2) - Lightweight language models for edge deployment"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "Apache-2.0"}
|
| 11 |
+
authors = [
|
| 12 |
+
{name = "Matrix Agent", email = "contact@minimind.ai"}
|
| 13 |
+
]
|
| 14 |
+
requires-python = ">=3.9"
|
| 15 |
+
classifiers = [
|
| 16 |
+
"Development Status :: 4 - Beta",
|
| 17 |
+
"Intended Audience :: Developers",
|
| 18 |
+
"Intended Audience :: Science/Research",
|
| 19 |
+
"License :: OSI Approved :: Apache Software License",
|
| 20 |
+
"Operating System :: OS Independent",
|
| 21 |
+
"Programming Language :: Python :: 3",
|
| 22 |
+
"Programming Language :: Python :: 3.9",
|
| 23 |
+
"Programming Language :: Python :: 3.10",
|
| 24 |
+
"Programming Language :: Python :: 3.11",
|
| 25 |
+
"Programming Language :: Python :: 3.12",
|
| 26 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 27 |
+
]
|
| 28 |
+
dependencies = [
|
| 29 |
+
"torch>=2.1.0",
|
| 30 |
+
"numpy>=1.24.0",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[project.optional-dependencies]
|
| 34 |
+
train = [
|
| 35 |
+
"transformers>=4.35.0",
|
| 36 |
+
"datasets>=2.14.0",
|
| 37 |
+
"accelerate>=0.24.0",
|
| 38 |
+
"wandb>=0.15.0",
|
| 39 |
+
]
|
| 40 |
+
export = [
|
| 41 |
+
"onnx>=1.14.0",
|
| 42 |
+
"onnxruntime>=1.16.0",
|
| 43 |
+
]
|
| 44 |
+
dev = [
|
| 45 |
+
"pytest>=7.4.0",
|
| 46 |
+
"black>=23.0.0",
|
| 47 |
+
"isort>=5.12.0",
|
| 48 |
+
"mypy>=1.5.0",
|
| 49 |
+
"ruff>=0.1.0",
|
| 50 |
+
]
|
| 51 |
+
all = [
|
| 52 |
+
"minimind[train,export,dev]",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
[project.scripts]
|
| 56 |
+
minimind-train = "scripts.train:main"
|
| 57 |
+
minimind-export = "scripts.export:main"
|
| 58 |
+
|
| 59 |
+
[project.urls]
|
| 60 |
+
Homepage = "https://github.com/minimind/minimind"
|
| 61 |
+
Documentation = "https://github.com/minimind/minimind#readme"
|
| 62 |
+
Repository = "https://github.com/minimind/minimind"
|
| 63 |
+
Issues = "https://github.com/minimind/minimind/issues"
|
| 64 |
+
|
| 65 |
+
[tool.setuptools.packages.find]
|
| 66 |
+
exclude = ["tests*", "android*"]
|
| 67 |
+
|
| 68 |
+
[tool.black]
|
| 69 |
+
line-length = 100
|
| 70 |
+
target-version = ["py39", "py310", "py311", "py312"]
|
| 71 |
+
include = '\.pyi?$'
|
| 72 |
+
exclude = '''
|
| 73 |
+
/(
|
| 74 |
+
\.git
|
| 75 |
+
| \.mypy_cache
|
| 76 |
+
| \.venv
|
| 77 |
+
| build
|
| 78 |
+
| dist
|
| 79 |
+
| android
|
| 80 |
+
)/
|
| 81 |
+
'''
|
| 82 |
+
|
| 83 |
+
[tool.isort]
|
| 84 |
+
profile = "black"
|
| 85 |
+
line_length = 100
|
| 86 |
+
skip = [".git", ".venv", "build", "dist", "android"]
|
| 87 |
+
|
| 88 |
+
[tool.ruff]
|
| 89 |
+
line-length = 100
|
| 90 |
+
target-version = "py39"
|
| 91 |
+
exclude = [".git", ".venv", "build", "dist", "android"]
|
| 92 |
+
|
| 93 |
+
[tool.ruff.lint]
|
| 94 |
+
select = ["E", "F", "W", "I", "N", "B", "C4"]
|
| 95 |
+
ignore = ["E501"]
|
| 96 |
+
|
| 97 |
+
[tool.mypy]
|
| 98 |
+
python_version = "3.9"
|
| 99 |
+
warn_return_any = true
|
| 100 |
+
warn_unused_configs = true
|
| 101 |
+
ignore_missing_imports = true
|
| 102 |
+
exclude = ["android", "build", "dist"]
|
| 103 |
+
|
| 104 |
+
[tool.pytest.ini_options]
|
| 105 |
+
testpaths = ["tests"]
|
| 106 |
+
python_files = ["test_*.py"]
|
| 107 |
+
python_functions = ["test_*"]
|
| 108 |
+
addopts = "-v --tb=short"
|
requirements.txt
ADDED
|
@@ -0,0 +1,32 @@
|
| 1 |
+
# MiniMind (Mind2) Requirements
|
| 2 |
+
|
| 3 |
+
# Core
|
| 4 |
+
torch>=2.1.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
|
| 7 |
+
# Training
|
| 8 |
+
transformers>=4.35.0
|
| 9 |
+
datasets>=2.14.0
|
| 10 |
+
accelerate>=0.24.0
|
| 11 |
+
wandb>=0.15.0
|
| 12 |
+
|
| 13 |
+
# Optimization & Export
|
| 14 |
+
onnx>=1.14.0
|
| 15 |
+
onnxruntime>=1.16.0
|
| 16 |
+
|
| 17 |
+
# Utilities
|
| 18 |
+
tqdm>=4.65.0
|
| 19 |
+
pyyaml>=6.0
|
| 20 |
+
jsonlines>=3.1.0
|
| 21 |
+
|
| 22 |
+
# Optional: Flash Attention (install separately)
|
| 23 |
+
# pip install flash-attn --no-build-isolation
|
| 24 |
+
|
| 25 |
+
# Optional: For INT4 quantization
|
| 26 |
+
# auto-gptq>=0.4.0
|
| 27 |
+
# autoawq>=0.1.0
|
| 28 |
+
|
| 29 |
+
# Development
|
| 30 |
+
pytest>=7.4.0
|
| 31 |
+
black>=23.0.0
|
| 32 |
+
isort>=5.12.0
|
scripts/export.py
ADDED
|
@@ -0,0 +1,101 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MiniMind Export Script
|
| 4 |
+
Export models to ONNX and GGUF formats for deployment.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from configs.model_config import get_config
|
| 16 |
+
from model import Mind2ForCausalLM
|
| 17 |
+
from optimization.export import export_to_onnx, export_to_gguf, export_for_android, ExportConfig
|
| 18 |
+
from optimization.quantization import quantize_model, QuantizationConfig, QuantizationType
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
|
| 22 |
+
parser = argparse.ArgumentParser(description="Export MiniMind models")
|
| 23 |
+
|
| 24 |
+
parser.add_argument("--model", type=str, default="mind2-lite",
|
| 25 |
+
choices=["mind2-nano", "mind2-lite", "mind2-pro"])
|
| 26 |
+
parser.add_argument("--checkpoint", type=str, default=None,
|
| 27 |
+
help="Path to model checkpoint")
|
| 28 |
+
parser.add_argument("--output-dir", type=str, default="./exports")
|
| 29 |
+
|
| 30 |
+
parser.add_argument("--format", type=str, nargs="+",
|
| 31 |
+
default=["onnx", "gguf"],
|
| 32 |
+
choices=["onnx", "gguf", "android"])
|
| 33 |
+
|
| 34 |
+
parser.add_argument("--quantize", type=str, default=None,
|
| 35 |
+
choices=["int4_awq", "int4_gptq", "int8_dynamic"])
|
| 36 |
+
parser.add_argument("--max-seq-len", type=int, default=2048)
|
| 37 |
+
|
| 38 |
+
return parser.parse_args()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main():
|
| 42 |
+
args = parse_args()
|
| 43 |
+
|
| 44 |
+
print(f"=" * 60)
|
| 45 |
+
print(f"MiniMind Export")
|
| 46 |
+
print(f"=" * 60)
|
| 47 |
+
print(f"Model: {args.model}")
|
| 48 |
+
print(f"Formats: {args.format}")
|
| 49 |
+
print(f"Quantization: {args.quantize or 'None'}")
|
| 50 |
+
|
| 51 |
+
# Load model
|
| 52 |
+
config = get_config(args.model)
|
| 53 |
+
model = Mind2ForCausalLM(config)
|
| 54 |
+
|
| 55 |
+
if args.checkpoint:
|
| 56 |
+
print(f"Loading checkpoint from {args.checkpoint}")
|
| 57 |
+
state_dict = torch.load(args.checkpoint, map_location="cpu")
|
| 58 |
+
model.load_state_dict(state_dict)
|
| 59 |
+
|
| 60 |
+
model.eval()
|
| 61 |
+
|
| 62 |
+
# Quantize if requested
|
| 63 |
+
if args.quantize:
|
| 64 |
+
print(f"\nQuantizing to {args.quantize}...")
|
| 65 |
+
model = quantize_model(model, args.quantize)
|
| 66 |
+
print("Quantization complete!")
|
| 67 |
+
|
| 68 |
+
# Export
|
| 69 |
+
output_dir = Path(args.output_dir)
|
| 70 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 71 |
+
|
| 72 |
+
export_config = ExportConfig(
|
| 73 |
+
max_seq_len=args.max_seq_len,
|
| 74 |
+
optimize_for_mobile=True,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
outputs = {}
|
| 78 |
+
|
| 79 |
+
if "android" in args.format:
|
| 80 |
+
print(f"\nExporting for Android...")
|
| 81 |
+
outputs = export_for_android(model, str(output_dir / "android"), config)
|
| 82 |
+
else:
|
| 83 |
+
if "onnx" in args.format:
|
| 84 |
+
print(f"\nExporting to ONNX...")
|
| 85 |
+
onnx_path = output_dir / f"{args.model}.onnx"
|
| 86 |
+
outputs["onnx"] = export_to_onnx(model, str(onnx_path), export_config)
|
| 87 |
+
|
| 88 |
+
if "gguf" in args.format:
|
| 89 |
+
print(f"\nExporting to GGUF...")
|
| 90 |
+
gguf_path = output_dir / f"{args.model}.gguf"
|
| 91 |
+
outputs["gguf"] = export_to_gguf(model, str(gguf_path), config, export_config)
|
| 92 |
+
|
| 93 |
+
print(f"\n" + "=" * 60)
|
| 94 |
+
print("Export complete!")
|
| 95 |
+
print("=" * 60)
|
| 96 |
+
for fmt, path in outputs.items():
|
| 97 |
+
print(f" {fmt}: {path}")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
|
| 101 |
+
main()
|
scripts/train.py
ADDED
|
@@ -0,0 +1,165 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MiniMind Training Script
|
| 4 |
+
Train Mind2 models from scratch or with knowledge distillation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Add parent directory to path
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
|
| 17 |
+
from configs.model_config import get_config, estimate_params
|
| 18 |
+
from model import Mind2ForCausalLM
|
| 19 |
+
from training.trainer import Mind2Trainer, TrainingConfig
|
| 20 |
+
from training.distillation import DistillationTrainer, DistillationConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_args():
|
| 24 |
+
parser = argparse.ArgumentParser(description="Train MiniMind (Mind2) models")
|
| 25 |
+
|
| 26 |
+
# Model
|
| 27 |
+
parser.add_argument("--model", type=str, default="mind2-lite",
|
| 28 |
+
choices=["mind2-nano", "mind2-lite", "mind2-pro"],
|
| 29 |
+
help="Model variant to train")
|
| 30 |
+
|
| 31 |
+
# Data
|
| 32 |
+
parser.add_argument("--train-data", type=str, required=True,
|
| 33 |
+
help="Path to training data (JSONL format)")
|
| 34 |
+
parser.add_argument("--eval-data", type=str, default=None,
|
| 35 |
+
help="Path to evaluation data")
|
| 36 |
+
|
| 37 |
+
# Training
|
| 38 |
+
parser.add_argument("--epochs", type=int, default=3)
|
| 39 |
+
parser.add_argument("--batch-size", type=int, default=8)
|
| 40 |
+
parser.add_argument("--grad-accum", type=int, default=4)
|
| 41 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
| 42 |
+
parser.add_argument("--warmup-steps", type=int, default=1000)
|
| 43 |
+
parser.add_argument("--max-steps", type=int, default=None)
|
| 44 |
+
|
| 45 |
+
# Distillation
|
| 46 |
+
parser.add_argument("--teacher-model", type=str, default=None,
|
| 47 |
+
help="Path to teacher model for distillation")
|
| 48 |
+
parser.add_argument("--temperature", type=float, default=2.0)
|
| 49 |
+
parser.add_argument("--alpha-kd", type=float, default=0.5)
|
| 50 |
+
|
| 51 |
+
# Output
|
| 52 |
+
parser.add_argument("--output-dir", type=str, default="./outputs")
|
| 53 |
+
parser.add_argument("--save-steps", type=int, default=1000)
|
| 54 |
+
|
| 55 |
+
# Hardware
|
| 56 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 57 |
+
parser.add_argument("--dtype", type=str, default="float16",
|
| 58 |
+
choices=["float16", "bfloat16", "float32"])
|
| 59 |
+
|
| 60 |
+
return parser.parse_args()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main():
|
| 64 |
+
args = parse_args()
|
| 65 |
+
|
| 66 |
+
# Setup
|
| 67 |
+
device = args.device if torch.cuda.is_available() else "cpu"
|
| 68 |
+
dtype = getattr(torch, args.dtype)
|
| 69 |
+
|
| 70 |
+
print(f"=" * 60)
|
| 71 |
+
print(f"MiniMind Training")
|
| 72 |
+
print(f"=" * 60)
|
| 73 |
+
print(f"Model: {args.model}")
|
| 74 |
+
print(f"Device: {device}, Dtype: {args.dtype}")
|
| 75 |
+
|
| 76 |
+
# Create model
|
| 77 |
+
config = get_config(args.model)
|
| 78 |
+
model = Mind2ForCausalLM(config).to(device=device, dtype=dtype)
|
| 79 |
+
|
| 80 |
+
# Print model info
|
| 81 |
+
params = estimate_params(config)
|
| 82 |
+
print(f"Total params: {params['total_params_b']:.2f}B")
|
| 83 |
+
print(f"Active params: {params['active_params_b']:.2f}B")
|
| 84 |
+
print(f"Activation ratio: {params['activation_ratio']:.1%}")
|
| 85 |
+
|
| 86 |
+
# Create dummy dataloader (replace with actual data loading)
|
| 87 |
+
print(f"\nNote: Using dummy data. Replace with actual data loading.")
|
| 88 |
+
train_data = torch.randint(0, config.vocab_size, (1000, 512))
|
| 89 |
+
train_loader = DataLoader(
|
| 90 |
+
torch.utils.data.TensorDataset(train_data, train_data),
|
| 91 |
+
batch_size=args.batch_size,
|
| 92 |
+
shuffle=True
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Training configuration
|
| 96 |
+
if args.teacher_model:
|
| 97 |
+
# Knowledge distillation
|
| 98 |
+
print(f"\nUsing knowledge distillation from: {args.teacher_model}")
|
| 99 |
+
|
| 100 |
+
distill_config = DistillationConfig(
|
| 101 |
+
learning_rate=args.lr,
|
| 102 |
+
num_epochs=args.epochs,
|
| 103 |
+
batch_size=args.batch_size,
|
| 104 |
+
gradient_accumulation_steps=args.grad_accum,
|
| 105 |
+
temperature=args.temperature,
|
| 106 |
+
alpha_kd=args.alpha_kd,
|
| 107 |
+
alpha_ce=1.0 - args.alpha_kd,
|
| 108 |
+
warmup_steps=args.warmup_steps,
|
| 109 |
+
max_steps=args.max_steps,
|
| 110 |
+
save_steps=args.save_steps,
|
| 111 |
+
output_dir=args.output_dir,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Load teacher (placeholder)
|
| 115 |
+
teacher = None # Load actual teacher model
|
| 116 |
+
|
| 117 |
+
trainer = DistillationTrainer(
|
| 118 |
+
student_model=model,
|
| 119 |
+
teacher_model=teacher,
|
| 120 |
+
train_dataloader=train_loader,
|
| 121 |
+
config=distill_config,
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
# Standard training
|
| 125 |
+
train_config = TrainingConfig(
|
| 126 |
+
learning_rate=args.lr,
|
| 127 |
+
num_epochs=args.epochs,
|
| 128 |
+
batch_size=args.batch_size,
|
| 129 |
+
gradient_accumulation_steps=args.grad_accum,
|
| 130 |
+
warmup_steps=args.warmup_steps,
|
| 131 |
+
max_steps=args.max_steps,
|
| 132 |
+
save_steps=args.save_steps,
|
| 133 |
+
output_dir=args.output_dir,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Wrap dataloader to return dict format
|
| 137 |
+
class DictDataLoader:
|
| 138 |
+
def __init__(self, loader):
|
| 139 |
+
self.loader = loader
|
| 140 |
+
|
| 141 |
+
def __iter__(self):
|
| 142 |
+
for input_ids, labels in self.loader:
|
| 143 |
+
yield {
|
| 144 |
+
"input_ids": input_ids,
|
| 145 |
+
"labels": labels,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
def __len__(self):
|
| 149 |
+
return len(self.loader)
|
| 150 |
+
|
| 151 |
+
trainer = Mind2Trainer(
|
| 152 |
+
model=model,
|
| 153 |
+
train_dataloader=DictDataLoader(train_loader),
|
| 154 |
+
config=train_config,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Train
|
| 158 |
+
print(f"\nStarting training...")
|
| 159 |
+
results = trainer.train()
|
| 160 |
+
print(f"\nTraining complete!")
|
| 161 |
+
print(f"Results: {results}")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
main()
|
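
The script above deliberately trains on random tokens. Below is a minimal sketch of swapping in real pre-tokenized data; the `.pt` path and tensor shape are placeholders, and the dict-wrapping mirrors the `DictDataLoader` defined in the script. For raw text files, `TextDataset`/`create_dataloader` from training/dataset.py are the intended route once a tokenizer is wired in.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: data/token_ids.pt holds a LongTensor of shape (num_samples, seq_len)
# produced offline by whichever tokenizer the model will be trained with.
token_ids = torch.load("data/token_ids.pt")

# Labels mirror the inputs, matching the dummy setup used in scripts/train.py.
dataset = TensorDataset(token_ids, token_ids)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Mind2Trainer expects dict batches, so reuse the DictDataLoader wrapper from the script:
# trainer = Mind2Trainer(model=model, train_dataloader=DictDataLoader(loader), config=train_config)
```
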
setup.py
ADDED
|
@@ -0,0 +1,91 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MiniMind (Mind2) - Setup Script
|
| 4 |
+
Lightweight language models for edge deployment.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from setuptools import setup, find_packages
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Read README
|
| 11 |
+
readme_path = Path(__file__).parent / "README.md"
|
| 12 |
+
long_description = readme_path.read_text(encoding="utf-8") if readme_path.exists() else ""
|
| 13 |
+
|
| 14 |
+
# Read requirements
|
| 15 |
+
req_path = Path(__file__).parent / "requirements.txt"
|
| 16 |
+
requirements = []
|
| 17 |
+
if req_path.exists():
|
| 18 |
+
requirements = [
|
| 19 |
+
line.strip() for line in req_path.read_text().splitlines()
|
| 20 |
+
if line.strip() and not line.startswith("#")
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
setup(
|
| 24 |
+
name="minimind",
|
| 25 |
+
version="1.0.0",
|
| 26 |
+
author="Matrix Agent",
|
| 27 |
+
author_email="contact@minimind.ai",
|
| 28 |
+
description="MiniMind (Mind2) - Lightweight language models for edge deployment",
|
| 29 |
+
long_description=long_description,
|
| 30 |
+
long_description_content_type="text/markdown",
|
| 31 |
+
url="https://github.com/minimind/minimind",
|
| 32 |
+
project_urls={
|
| 33 |
+
"Documentation": "https://github.com/minimind/minimind#readme",
|
| 34 |
+
"Bug Tracker": "https://github.com/minimind/minimind/issues",
|
| 35 |
+
},
|
| 36 |
+
packages=find_packages(exclude=["tests", "tests.*", "android", "android.*"]),
|
| 37 |
+
classifiers=[
|
| 38 |
+
"Development Status :: 4 - Beta",
|
| 39 |
+
"Intended Audience :: Developers",
|
| 40 |
+
"Intended Audience :: Science/Research",
|
| 41 |
+
"License :: OSI Approved :: Apache Software License",
|
| 42 |
+
"Operating System :: OS Independent",
|
| 43 |
+
"Programming Language :: Python :: 3",
|
| 44 |
+
"Programming Language :: Python :: 3.9",
|
| 45 |
+
"Programming Language :: Python :: 3.10",
|
| 46 |
+
"Programming Language :: Python :: 3.11",
|
| 47 |
+
"Programming Language :: Python :: 3.12",
|
| 48 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 49 |
+
],
|
| 50 |
+
python_requires=">=3.9",
|
| 51 |
+
install_requires=[
|
| 52 |
+
"torch>=2.1.0",
|
| 53 |
+
"numpy>=1.24.0",
|
| 54 |
+
],
|
| 55 |
+
extras_require={
|
| 56 |
+
"train": [
|
| 57 |
+
"transformers>=4.35.0",
|
| 58 |
+
"datasets>=2.14.0",
|
| 59 |
+
"accelerate>=0.24.0",
|
| 60 |
+
"wandb>=0.15.0",
|
| 61 |
+
],
|
| 62 |
+
"export": [
|
| 63 |
+
"onnx>=1.14.0",
|
| 64 |
+
"onnxruntime>=1.16.0",
|
| 65 |
+
],
|
| 66 |
+
"dev": [
|
| 67 |
+
"pytest>=7.4.0",
|
| 68 |
+
"black>=23.0.0",
|
| 69 |
+
"isort>=5.12.0",
|
| 70 |
+
"mypy>=1.5.0",
|
| 71 |
+
],
|
| 72 |
+
"all": [
|
| 73 |
+
"transformers>=4.35.0",
|
| 74 |
+
"datasets>=2.14.0",
|
| 75 |
+
"accelerate>=0.24.0",
|
| 76 |
+
"wandb>=0.15.0",
|
| 77 |
+
"onnx>=1.14.0",
|
| 78 |
+
"onnxruntime>=1.16.0",
|
| 79 |
+
"pytest>=7.4.0",
|
| 80 |
+
"black>=23.0.0",
|
| 81 |
+
],
|
| 82 |
+
},
|
| 83 |
+
entry_points={
|
| 84 |
+
"console_scripts": [
|
| 85 |
+
"minimind-train=scripts.train:main",
|
| 86 |
+
"minimind-export=scripts.export:main",
|
| 87 |
+
],
|
| 88 |
+
},
|
| 89 |
+
include_package_data=True,
|
| 90 |
+
zip_safe=False,
|
| 91 |
+
)
|
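The extras declared above can be combined at install time: for example, an editable install with training and ONNX export support is `pip install -e ".[train,export]"`, while `pip install -e ".[dev]"` pulls in the test and lint tooling.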
training/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
| 1 |
+
"""MiniMind Training Package"""
|
| 2 |
+
from .trainer import Mind2Trainer
|
| 3 |
+
from .distillation import DistillationTrainer
|
| 4 |
+
from .dataset import TextDataset, create_dataloader
|
| 5 |
+
|
| 6 |
+
__all__ = ["Mind2Trainer", "DistillationTrainer", "TextDataset", "create_dataloader"]
|
training/dataset.py
ADDED
|
@@ -0,0 +1,128 @@
| 1 |
+
"""
|
| 2 |
+
MiniMind Dataset and DataLoader utilities
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from typing import Optional, List, Dict, Any
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TextDataset(Dataset):
|
| 13 |
+
"""Simple text dataset for language model training."""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
data_path: str,
|
| 18 |
+
tokenizer: Any,
|
| 19 |
+
max_length: int = 2048,
|
| 20 |
+
format_type: str = "jsonl",  # "jsonl" or "txt"; parquet is not yet handled by _load_data
|
| 21 |
+
):
|
| 22 |
+
self.tokenizer = tokenizer
|
| 23 |
+
self.max_length = max_length
|
| 24 |
+
self.data = self._load_data(data_path, format_type)
|
| 25 |
+
|
| 26 |
+
def _load_data(self, data_path: str, format_type: str) -> List[str]:
|
| 27 |
+
data = []
|
| 28 |
+
path = Path(data_path)
|
| 29 |
+
|
| 30 |
+
if format_type == "jsonl":
|
| 31 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 32 |
+
for line in f:
|
| 33 |
+
item = json.loads(line.strip())
|
| 34 |
+
text = item.get("text", item.get("content", ""))
|
| 35 |
+
if text:
|
| 36 |
+
data.append(text)
|
| 37 |
+
elif format_type == "txt":
|
| 38 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 39 |
+
data = [line.strip() for line in f if line.strip()]
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError(f"Unsupported format: {format_type}")
|
| 42 |
+
|
| 43 |
+
return data
|
| 44 |
+
|
| 45 |
+
def __len__(self) -> int:
|
| 46 |
+
return len(self.data)
|
| 47 |
+
|
| 48 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 49 |
+
text = self.data[idx]
|
| 50 |
+
encoding = self.tokenizer(
|
| 51 |
+
text,
|
| 52 |
+
truncation=True,
|
| 53 |
+
max_length=self.max_length,
|
| 54 |
+
padding="max_length",
|
| 55 |
+
return_tensors="pt",
|
| 56 |
+
)
|
| 57 |
+
return {
|
| 58 |
+
"input_ids": encoding["input_ids"].squeeze(0),
|
| 59 |
+
"attention_mask": encoding["attention_mask"].squeeze(0),
|
| 60 |
+
"labels": encoding["input_ids"].squeeze(0),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DistillationDataset(Dataset):
|
| 65 |
+
"""Dataset for knowledge distillation with teacher logits."""
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
data_path: str,
|
| 70 |
+
tokenizer: Any,
|
| 71 |
+
teacher_logits_path: Optional[str] = None,
|
| 72 |
+
max_length: int = 2048,
|
| 73 |
+
):
|
| 74 |
+
self.tokenizer = tokenizer
|
| 75 |
+
self.max_length = max_length
|
| 76 |
+
self.data = self._load_data(data_path)
|
| 77 |
+
self.teacher_logits = self._load_teacher_logits(teacher_logits_path) if teacher_logits_path else None
|
| 78 |
+
|
| 79 |
+
def _load_data(self, data_path: str) -> List[str]:
|
| 80 |
+
with open(data_path, "r", encoding="utf-8") as f:
|
| 81 |
+
return [json.loads(line.strip()).get("text", "") for line in f if line.strip()]
|
| 82 |
+
|
| 83 |
+
def _load_teacher_logits(self, path: str) -> Optional[torch.Tensor]:
|
| 84 |
+
if Path(path).exists():
|
| 85 |
+
return torch.load(path, map_location="cpu")
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
def __len__(self) -> int:
|
| 89 |
+
return len(self.data)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 92 |
+
text = self.data[idx]
|
| 93 |
+
encoding = self.tokenizer(
|
| 94 |
+
text,
|
| 95 |
+
truncation=True,
|
| 96 |
+
max_length=self.max_length,
|
| 97 |
+
padding="max_length",
|
| 98 |
+
return_tensors="pt",
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
item = {
|
| 102 |
+
"input_ids": encoding["input_ids"].squeeze(0),
|
| 103 |
+
"attention_mask": encoding["attention_mask"].squeeze(0),
|
| 104 |
+
"labels": encoding["input_ids"].squeeze(0),
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
if self.teacher_logits is not None:
|
| 108 |
+
item["teacher_logits"] = self.teacher_logits[idx]
|
| 109 |
+
|
| 110 |
+
return item
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def create_dataloader(
|
| 114 |
+
dataset: Dataset,
|
| 115 |
+
batch_size: int = 8,
|
| 116 |
+
shuffle: bool = True,
|
| 117 |
+
num_workers: int = 4,
|
| 118 |
+
pin_memory: bool = True,
|
| 119 |
+
) -> DataLoader:
|
| 120 |
+
"""Create a DataLoader with optimal settings."""
|
| 121 |
+
return DataLoader(
|
| 122 |
+
dataset,
|
| 123 |
+
batch_size=batch_size,
|
| 124 |
+
shuffle=shuffle,
|
| 125 |
+
num_workers=num_workers,
|
| 126 |
+
pin_memory=pin_memory,
|
| 127 |
+
drop_last=True,
|
| 128 |
+
)
|
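A minimal sketch of wiring `TextDataset` into a dataloader. The tokenizer name and data path below are placeholders; any tokenizer whose `__call__` returns `input_ids` and `attention_mask` tensors (e.g. a Hugging Face tokenizer from the optional `train` extra) should work:

```python
from transformers import AutoTokenizer  # optional "train" extra

from training.dataset import TextDataset, create_dataloader

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token            # GPT-2 has no pad token by default
dataset = TextDataset("data/corpus.jsonl", tokenizer, max_length=1024, format_type="jsonl")
loader = create_dataloader(dataset, batch_size=8, num_workers=2)  # drop_last=True, so needs >= 8 samples

batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([8, 1024])
```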
training/distillation.py
ADDED
|
@@ -0,0 +1,320 @@
| 1 |
+
"""
|
| 2 |
+
Knowledge Distillation for MiniMind
|
| 3 |
+
Train smaller models using larger teacher models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Optional, Dict, Any, Callable
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class DistillationConfig:
|
| 20 |
+
"""Configuration for knowledge distillation."""
|
| 21 |
+
# Distillation parameters
|
| 22 |
+
temperature: float = 2.0
|
| 23 |
+
alpha_ce: float = 0.5 # Weight for hard label loss
|
| 24 |
+
alpha_kd: float = 0.5 # Weight for distillation loss
|
| 25 |
+
alpha_hidden: float = 0.0 # Weight for hidden state matching
|
| 26 |
+
|
| 27 |
+
# Optimization
|
| 28 |
+
learning_rate: float = 1e-4
|
| 29 |
+
min_learning_rate: float = 1e-5
|
| 30 |
+
weight_decay: float = 0.1
|
| 31 |
+
warmup_steps: int = 500
|
| 32 |
+
grad_clip: float = 1.0
|
| 33 |
+
|
| 34 |
+
# Training
|
| 35 |
+
num_epochs: int = 5
|
| 36 |
+
batch_size: int = 4
|
| 37 |
+
gradient_accumulation_steps: int = 8
|
| 38 |
+
max_steps: Optional[int] = None
|
| 39 |
+
|
| 40 |
+
# Mixed precision
|
| 41 |
+
use_amp: bool = True
|
| 42 |
+
|
| 43 |
+
# Checkpointing
|
| 44 |
+
save_steps: int = 500
|
| 45 |
+
output_dir: str = "./distill_outputs"
|
| 46 |
+
log_steps: int = 10
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DistillationTrainer:
|
| 50 |
+
"""
|
| 51 |
+
Knowledge Distillation Trainer.
|
| 52 |
+
Supports:
|
| 53 |
+
- Soft label distillation (KL divergence)
|
| 54 |
+
- Hard label training (CE loss)
|
| 55 |
+
- Hidden state matching (optional)
|
| 56 |
+
- Online and offline distillation
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
student_model: nn.Module,
|
| 62 |
+
teacher_model: Optional[nn.Module] = None,
|
| 63 |
+
train_dataloader: Optional[DataLoader] = None,
|
| 64 |
+
config: Optional[DistillationConfig] = None,
|
| 65 |
+
projection_layer: Optional[nn.Module] = None,
|
| 66 |
+
):
|
| 67 |
+
self.student = student_model
|
| 68 |
+
self.teacher = teacher_model
|
| 69 |
+
self.train_dataloader = train_dataloader
|
| 70 |
+
self.config = config or DistillationConfig()
|
| 71 |
+
self.projection_layer = projection_layer # For hidden state matching
|
| 72 |
+
|
| 73 |
+
self.device = next(student_model.parameters()).device
|
| 74 |
+
|
| 75 |
+
if self.teacher is not None:
|
| 76 |
+
self.teacher.eval()
|
| 77 |
+
for param in self.teacher.parameters():
|
| 78 |
+
param.requires_grad = False
|
| 79 |
+
|
| 80 |
+
self.optimizer = self._create_optimizer()
|
| 81 |
+
self.scheduler = self._create_scheduler()
|
| 82 |
+
self.scaler = GradScaler() if self.config.use_amp else None
|
| 83 |
+
|
| 84 |
+
self.global_step = 0
|
| 85 |
+
Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)
|
| 86 |
+
|
| 87 |
+
def _create_optimizer(self) -> torch.optim.Optimizer:
|
| 88 |
+
params = list(self.student.parameters())
|
| 89 |
+
if self.projection_layer is not None:
|
| 90 |
+
params += list(self.projection_layer.parameters())
|
| 91 |
+
|
| 92 |
+
return torch.optim.AdamW(
|
| 93 |
+
params,
|
| 94 |
+
lr=self.config.learning_rate,
|
| 95 |
+
weight_decay=self.config.weight_decay,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def _create_scheduler(self):
|
| 99 |
+
total_steps = self._get_total_steps()
|
| 100 |
+
|
| 101 |
+
def lr_lambda(step):
|
| 102 |
+
if step < self.config.warmup_steps:
|
| 103 |
+
return step / max(1, self.config.warmup_steps)
|
| 104 |
+
progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
|
| 105 |
+
return max(
|
| 106 |
+
self.config.min_learning_rate / self.config.learning_rate,
|
| 107 |
+
0.5 * (1.0 + math.cos(math.pi * progress))
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
|
| 111 |
+
|
| 112 |
+
def _get_total_steps(self) -> int:
|
| 113 |
+
if self.config.max_steps:
|
| 114 |
+
return self.config.max_steps
|
| 115 |
+
steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
|
| 116 |
+
return steps_per_epoch * self.config.num_epochs
|
| 117 |
+
|
| 118 |
+
def distillation_loss(
|
| 119 |
+
self,
|
| 120 |
+
student_logits: torch.Tensor,
|
| 121 |
+
teacher_logits: torch.Tensor,
|
| 122 |
+
labels: torch.Tensor,
|
| 123 |
+
student_hidden: Optional[torch.Tensor] = None,
|
| 124 |
+
teacher_hidden: Optional[torch.Tensor] = None,
|
| 125 |
+
) -> Dict[str, torch.Tensor]:
|
| 126 |
+
"""
|
| 127 |
+
Compute combined distillation loss.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
student_logits: Student model output logits [B, T, V]
|
| 131 |
+
teacher_logits: Teacher model output logits [B, T, V]
|
| 132 |
+
labels: Ground truth labels [B, T]
|
| 133 |
+
student_hidden: Student hidden states (optional)
|
| 134 |
+
teacher_hidden: Teacher hidden states (optional)
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Dictionary with loss components and total loss
|
| 138 |
+
"""
|
| 139 |
+
# Temperature-scaled soft labels
|
| 140 |
+
T = self.config.temperature
|
| 141 |
+
|
| 142 |
+
# Soft label loss (KL divergence)
|
| 143 |
+
student_log_probs = F.log_softmax(student_logits / T, dim=-1)
|
| 144 |
+
teacher_probs = F.softmax(teacher_logits / T, dim=-1)
|
| 145 |
+
kd_loss = F.kl_div(
|
| 146 |
+
student_log_probs,
|
| 147 |
+
teacher_probs,
|
| 148 |
+
reduction="batchmean"
|
| 149 |
+
) * (T ** 2)
|
| 150 |
+
|
| 151 |
+
# Hard label loss (Cross entropy)
|
| 152 |
+
shift_logits = student_logits[..., :-1, :].contiguous()
|
| 153 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 154 |
+
ce_loss = F.cross_entropy(
|
| 155 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 156 |
+
shift_labels.view(-1),
|
| 157 |
+
ignore_index=-100,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Hidden state matching (optional)
|
| 161 |
+
hidden_loss = torch.tensor(0.0, device=self.device)
|
| 162 |
+
if student_hidden is not None and teacher_hidden is not None and self.projection_layer is not None:
|
| 163 |
+
projected_student = self.projection_layer(student_hidden)
|
| 164 |
+
hidden_loss = F.mse_loss(projected_student, teacher_hidden)
|
| 165 |
+
|
| 166 |
+
# Combined loss
|
| 167 |
+
total_loss = (
|
| 168 |
+
self.config.alpha_ce * ce_loss +
|
| 169 |
+
self.config.alpha_kd * kd_loss +
|
| 170 |
+
self.config.alpha_hidden * hidden_loss
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return {
|
| 174 |
+
"total_loss": total_loss,
|
| 175 |
+
"ce_loss": ce_loss,
|
| 176 |
+
"kd_loss": kd_loss,
|
| 177 |
+
"hidden_loss": hidden_loss,
|
| 178 |
+
}
|
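Written out, the combined objective computed above is (with temperature $T$, the configured weights, the usual next-token shift inside the CE term, and $W$ the optional projection layer):

```latex
\mathcal{L} = \alpha_{ce}\,\mathrm{CE}(z_s, y)
            + \alpha_{kd}\, T^{2}\, \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T)\,\big\|\,\mathrm{softmax}(z_s / T)\right)
            + \alpha_{hidden}\,\mathrm{MSE}\!\left(W h_s,\; h_t\right)
```

where $z_s, z_t$ are the student and teacher logits and $h_s, h_t$ their hidden states.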
| 179 |
+
|
| 180 |
+
def train(self) -> Dict[str, float]:
|
| 181 |
+
"""Main distillation training loop."""
|
| 182 |
+
self.student.train()
|
| 183 |
+
total_steps = self._get_total_steps()
|
| 184 |
+
|
| 185 |
+
print(f"Starting knowledge distillation for {total_steps} steps")
|
| 186 |
+
print(f" Temperature: {self.config.temperature}")
|
| 187 |
+
print(f" Alpha CE: {self.config.alpha_ce}, Alpha KD: {self.config.alpha_kd}")
|
| 188 |
+
|
| 189 |
+
running_losses = {"total": 0.0, "ce": 0.0, "kd": 0.0}
|
| 190 |
+
|
| 191 |
+
for epoch in range(self.config.num_epochs):
|
| 192 |
+
for step, batch in enumerate(self.train_dataloader):
|
| 193 |
+
losses = self._training_step(batch)
|
| 194 |
+
|
| 195 |
+
for key in running_losses:
|
| 196 |
+
value = losses.get(f"{key}_loss")
if isinstance(value, torch.Tensor):
    running_losses[key] += value.item()
|
| 197 |
+
|
| 198 |
+
if (step + 1) % self.config.gradient_accumulation_steps == 0:
|
| 199 |
+
self._optimizer_step()
|
| 200 |
+
self.global_step += 1
|
| 201 |
+
|
| 202 |
+
if self.global_step % self.config.log_steps == 0:
|
| 203 |
+
avg_losses = {k: v / self.config.log_steps for k, v in running_losses.items()}
|
| 204 |
+
print(
|
| 205 |
+
f"Step {self.global_step}/{total_steps} | "
|
| 206 |
+
f"Total: {avg_losses['total']:.4f} | "
|
| 207 |
+
f"CE: {avg_losses['ce']:.4f} | "
|
| 208 |
+
f"KD: {avg_losses['kd']:.4f}"
|
| 209 |
+
)
|
| 210 |
+
running_losses = {k: 0.0 for k in running_losses}
|
| 211 |
+
|
| 212 |
+
if self.global_step % self.config.save_steps == 0:
|
| 213 |
+
self._save_checkpoint()
|
| 214 |
+
|
| 215 |
+
if self.config.max_steps and self.global_step >= self.config.max_steps:
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
if self.config.max_steps and self.global_step >= self.config.max_steps:
|
| 219 |
+
break
|
| 220 |
+
|
| 221 |
+
self._save_checkpoint(final=True)
|
| 222 |
+
return {"final_step": self.global_step}
|
| 223 |
+
|
| 224 |
+
def _training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 225 |
+
"""Single distillation training step."""
|
| 226 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 227 |
+
attention_mask = batch.get("attention_mask")
|
| 228 |
+
if attention_mask is not None:
|
| 229 |
+
attention_mask = attention_mask.to(self.device)
|
| 230 |
+
labels = batch["labels"].to(self.device)
|
| 231 |
+
|
| 232 |
+
# Check for pre-computed teacher logits
|
| 233 |
+
teacher_logits = batch.get("teacher_logits")
|
| 234 |
+
if teacher_logits is not None:
|
| 235 |
+
teacher_logits = teacher_logits.to(self.device)
|
| 236 |
+
elif self.teacher is not None:
|
| 237 |
+
with torch.no_grad():
|
| 238 |
+
_, teacher_logits, _, _ = self.teacher(input_ids, attention_mask)
|
| 239 |
+
|
| 240 |
+
if self.config.use_amp:
|
| 241 |
+
with autocast(dtype=torch.float16):
|
| 242 |
+
_, student_logits, _, _ = self.student(input_ids, attention_mask)
|
| 243 |
+
losses = self.distillation_loss(student_logits, teacher_logits, labels)
|
| 244 |
+
loss = losses["total_loss"] / self.config.gradient_accumulation_steps
|
| 245 |
+
self.scaler.scale(loss).backward()
|
| 246 |
+
else:
|
| 247 |
+
_, student_logits, _, _ = self.student(input_ids, attention_mask)
|
| 248 |
+
losses = self.distillation_loss(student_logits, teacher_logits, labels)
|
| 249 |
+
loss = losses["total_loss"] / self.config.gradient_accumulation_steps
|
| 250 |
+
loss.backward()
|
| 251 |
+
|
| 252 |
+
return losses
|
| 253 |
+
|
| 254 |
+
def _optimizer_step(self):
|
| 255 |
+
if self.config.use_amp:
|
| 256 |
+
self.scaler.unscale_(self.optimizer)
|
| 257 |
+
|
| 258 |
+
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.config.grad_clip)
|
| 259 |
+
|
| 260 |
+
if self.config.use_amp:
|
| 261 |
+
self.scaler.step(self.optimizer)
|
| 262 |
+
self.scaler.update()
|
| 263 |
+
else:
|
| 264 |
+
self.optimizer.step()
|
| 265 |
+
|
| 266 |
+
self.scheduler.step()
|
| 267 |
+
self.optimizer.zero_grad()
|
| 268 |
+
|
| 269 |
+
def _save_checkpoint(self, final: bool = False):
|
| 270 |
+
name = "final" if final else f"step_{self.global_step}"
|
| 271 |
+
path = Path(self.config.output_dir) / name
|
| 272 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 273 |
+
|
| 274 |
+
torch.save(self.student.state_dict(), path / "student_model.pt")
|
| 275 |
+
if self.projection_layer is not None:
|
| 276 |
+
torch.save(self.projection_layer.state_dict(), path / "projection.pt")
|
| 277 |
+
|
| 278 |
+
print(f"Checkpoint saved to {path}")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def generate_teacher_logits(
|
| 282 |
+
teacher_model: nn.Module,
|
| 283 |
+
dataloader: DataLoader,
|
| 284 |
+
output_path: str,
|
| 285 |
+
device: str = "cuda",
|
| 286 |
+
top_k: int = 100, # Only save top-k logits to reduce storage
|
| 287 |
+
):
|
| 288 |
+
"""
|
| 289 |
+
Pre-generate teacher logits for offline distillation.
|
| 290 |
+
Saves storage by only keeping top-k logits per position.
|
| 291 |
+
"""
|
| 292 |
+
teacher_model.eval()
|
| 293 |
+
teacher_model.to(device)
|
| 294 |
+
|
| 295 |
+
all_logits = []
|
| 296 |
+
|
| 297 |
+
print(f"Generating teacher logits for {len(dataloader)} batches...")
|
| 298 |
+
|
| 299 |
+
with torch.no_grad():
|
| 300 |
+
for batch in dataloader:
|
| 301 |
+
input_ids = batch["input_ids"].to(device)
|
| 302 |
+
attention_mask = batch.get("attention_mask")
|
| 303 |
+
if attention_mask is not None:
|
| 304 |
+
attention_mask = attention_mask.to(device)
|
| 305 |
+
|
| 306 |
+
_, logits, _, _ = teacher_model(input_ids, attention_mask)
|
| 307 |
+
|
| 308 |
+
# Keep only top-k logits
|
| 309 |
+
if top_k > 0 and top_k < logits.shape[-1]:
|
| 310 |
+
topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)
|
| 311 |
+
sparse_logits = {
|
| 312 |
+
"values": topk_values.cpu(),
|
| 313 |
+
"indices": topk_indices.cpu(),
|
| 314 |
+
}
|
| 315 |
+
all_logits.append(sparse_logits)
|
| 316 |
+
else:
|
| 317 |
+
all_logits.append(logits.cpu())
|
| 318 |
+
|
| 319 |
+
torch.save(all_logits, output_path)
|
| 320 |
+
print(f"Teacher logits saved to {output_path}")
|
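A minimal sketch of online distillation with the trainer above. `student`, `teacher`, and `train_loader` are placeholders for an already-built (smaller) Mind2 model, a frozen teacher on the same device returning the same `(loss, logits, ...)` tuple, and a dict-batch dataloader such as one from `training.dataset`:

```python
from training.distillation import DistillationConfig, DistillationTrainer

config = DistillationConfig(
    temperature=2.0,
    alpha_ce=0.5,
    alpha_kd=0.5,
    num_epochs=1,
    output_dir="./distill_outputs",
)

trainer = DistillationTrainer(
    student_model=student,          # placeholder: smaller Mind2 model on the target device
    teacher_model=teacher,          # placeholder: larger teacher; frozen by the trainer
    train_dataloader=train_loader,  # dict batches with input_ids / labels
    config=config,
)
trainer.train()
```

For offline distillation, `generate_teacher_logits` can be run once to cache (top-k) teacher logits, which are then attached to batches as `teacher_logits` instead of passing a live teacher.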
training/trainer.py
ADDED
|
@@ -0,0 +1,274 @@
| 1 |
+
"""
|
| 2 |
+
MiniMind Training Utilities
|
| 3 |
+
Standard training loop with mixed precision and gradient accumulation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import math
|
| 8 |
+
import time
|
| 9 |
+
from typing import Optional, Dict, Any
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 20 |
+
from configs.model_config import Mind2Config
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class TrainingConfig:
|
| 25 |
+
"""Training configuration."""
|
| 26 |
+
# Optimization
|
| 27 |
+
learning_rate: float = 3e-4
|
| 28 |
+
min_learning_rate: float = 3e-5
|
| 29 |
+
weight_decay: float = 0.1
|
| 30 |
+
beta1: float = 0.9
|
| 31 |
+
beta2: float = 0.95
|
| 32 |
+
grad_clip: float = 1.0
|
| 33 |
+
warmup_steps: int = 1000
|
| 34 |
+
|
| 35 |
+
# Training
|
| 36 |
+
num_epochs: int = 3
|
| 37 |
+
batch_size: int = 8
|
| 38 |
+
gradient_accumulation_steps: int = 4
|
| 39 |
+
max_steps: Optional[int] = None
|
| 40 |
+
|
| 41 |
+
# Mixed precision
|
| 42 |
+
use_amp: bool = True
|
| 43 |
+
amp_dtype: str = "float16" # float16 or bfloat16
|
| 44 |
+
|
| 45 |
+
# Checkpointing
|
| 46 |
+
save_steps: int = 1000
|
| 47 |
+
eval_steps: int = 500
|
| 48 |
+
output_dir: str = "./outputs"
|
| 49 |
+
resume_from: Optional[str] = None
|
| 50 |
+
|
| 51 |
+
# Logging
|
| 52 |
+
log_steps: int = 10
|
| 53 |
+
wandb_project: Optional[str] = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Mind2Trainer:
|
| 57 |
+
"""Trainer for MiniMind models."""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
model: nn.Module,
|
| 62 |
+
train_dataloader: DataLoader,
|
| 63 |
+
eval_dataloader: Optional[DataLoader] = None,
|
| 64 |
+
config: Optional[TrainingConfig] = None,
|
| 65 |
+
):
|
| 66 |
+
self.model = model
|
| 67 |
+
self.train_dataloader = train_dataloader
|
| 68 |
+
self.eval_dataloader = eval_dataloader
|
| 69 |
+
self.config = config or TrainingConfig()
|
| 70 |
+
|
| 71 |
+
self.device = next(model.parameters()).device
|
| 72 |
+
self.global_step = 0
|
| 73 |
+
self.epoch = 0
|
| 74 |
+
|
| 75 |
+
# Setup optimizer
|
| 76 |
+
self.optimizer = self._create_optimizer()
|
| 77 |
+
self.scheduler = self._create_scheduler()
|
| 78 |
+
|
| 79 |
+
# Mixed precision
|
| 80 |
+
self.scaler = GradScaler() if self.config.use_amp else None
|
| 81 |
+
self.amp_dtype = torch.float16 if self.config.amp_dtype == "float16" else torch.bfloat16
|
| 82 |
+
|
| 83 |
+
# Output directory
|
| 84 |
+
Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)
|
| 85 |
+
|
| 86 |
+
def _create_optimizer(self) -> torch.optim.Optimizer:
|
| 87 |
+
"""Create AdamW optimizer with weight decay."""
|
| 88 |
+
decay_params = []
|
| 89 |
+
no_decay_params = []
|
| 90 |
+
|
| 91 |
+
for name, param in self.model.named_parameters():
|
| 92 |
+
if not param.requires_grad:
|
| 93 |
+
continue
|
| 94 |
+
if "bias" in name or "norm" in name or "layernorm" in name:
|
| 95 |
+
no_decay_params.append(param)
|
| 96 |
+
else:
|
| 97 |
+
decay_params.append(param)
|
| 98 |
+
|
| 99 |
+
optimizer_groups = [
|
| 100 |
+
{"params": decay_params, "weight_decay": self.config.weight_decay},
|
| 101 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
return torch.optim.AdamW(
|
| 105 |
+
optimizer_groups,
|
| 106 |
+
lr=self.config.learning_rate,
|
| 107 |
+
betas=(self.config.beta1, self.config.beta2),
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def _create_scheduler(self):
|
| 111 |
+
"""Create cosine annealing scheduler with warmup."""
|
| 112 |
+
total_steps = self._get_total_steps()
|
| 113 |
+
|
| 114 |
+
def lr_lambda(step):
|
| 115 |
+
if step < self.config.warmup_steps:
|
| 116 |
+
return step / max(1, self.config.warmup_steps)
|
| 117 |
+
progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
|
| 118 |
+
return max(
|
| 119 |
+
self.config.min_learning_rate / self.config.learning_rate,
|
| 120 |
+
0.5 * (1.0 + math.cos(math.pi * progress))
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
|
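The lambda above implements linear warmup followed by cosine decay to a floor; as a multiplier on the base learning rate it is:

```latex
\lambda(s) =
\begin{cases}
  s / s_{warmup}, & s < s_{warmup} \\[4pt]
  \max\!\left(\dfrac{lr_{min}}{lr_{max}},\; \tfrac{1}{2}\bigl(1 + \cos(\pi\, p)\bigr)\right),
  & p = \dfrac{s - s_{warmup}}{s_{total} - s_{warmup}}
\end{cases}
```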
| 124 |
+
|
| 125 |
+
def _get_total_steps(self) -> int:
|
| 126 |
+
if self.config.max_steps:
|
| 127 |
+
return self.config.max_steps
|
| 128 |
+
steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
|
| 129 |
+
return steps_per_epoch * self.config.num_epochs
|
| 130 |
+
|
| 131 |
+
def train(self) -> Dict[str, float]:
|
| 132 |
+
"""Main training loop."""
|
| 133 |
+
self.model.train()
|
| 134 |
+
total_steps = self._get_total_steps()
|
| 135 |
+
|
| 136 |
+
print(f"Starting training for {total_steps} steps")
|
| 137 |
+
print(f" Batch size: {self.config.batch_size}")
|
| 138 |
+
print(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
|
| 139 |
+
print(f" Effective batch size: {self.config.batch_size * self.config.gradient_accumulation_steps}")
|
| 140 |
+
|
| 141 |
+
running_loss = 0.0
|
| 142 |
+
start_time = time.time()
|
| 143 |
+
|
| 144 |
+
for epoch in range(self.config.num_epochs):
|
| 145 |
+
self.epoch = epoch
|
| 146 |
+
|
| 147 |
+
for step, batch in enumerate(self.train_dataloader):
|
| 148 |
+
loss = self._training_step(batch)
|
| 149 |
+
running_loss += loss
|
| 150 |
+
|
| 151 |
+
if (step + 1) % self.config.gradient_accumulation_steps == 0:
|
| 152 |
+
self._optimizer_step()
|
| 153 |
+
self.global_step += 1
|
| 154 |
+
|
| 155 |
+
# Logging
|
| 156 |
+
if self.global_step % self.config.log_steps == 0:
|
| 157 |
+
avg_loss = running_loss / self.config.log_steps
|
| 158 |
+
elapsed = time.time() - start_time
|
| 159 |
+
tokens_per_sec = (
|
| 160 |
+
self.config.batch_size * self.config.gradient_accumulation_steps *
|
| 161 |
+
batch["input_ids"].shape[1] * self.config.log_steps / elapsed
|
| 162 |
+
)
|
| 163 |
+
print(
|
| 164 |
+
f"Step {self.global_step}/{total_steps} | "
|
| 165 |
+
f"Loss: {avg_loss:.4f} | "
|
| 166 |
+
f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
|
| 167 |
+
f"Tokens/s: {tokens_per_sec:.0f}"
|
| 168 |
+
)
|
| 169 |
+
running_loss = 0.0
|
| 170 |
+
start_time = time.time()
|
| 171 |
+
|
| 172 |
+
# Evaluation
|
| 173 |
+
if self.eval_dataloader and self.global_step % self.config.eval_steps == 0:
|
| 174 |
+
eval_loss = self.evaluate()
|
| 175 |
+
print(f"Eval Loss: {eval_loss:.4f}")
|
| 176 |
+
self.model.train()
|
| 177 |
+
|
| 178 |
+
# Save checkpoint
|
| 179 |
+
if self.global_step % self.config.save_steps == 0:
|
| 180 |
+
self.save_checkpoint()
|
| 181 |
+
|
| 182 |
+
if self.config.max_steps and self.global_step >= self.config.max_steps:
|
| 183 |
+
break
|
| 184 |
+
|
| 185 |
+
if self.config.max_steps and self.global_step >= self.config.max_steps:
|
| 186 |
+
break
|
| 187 |
+
|
| 188 |
+
self.save_checkpoint(final=True)
|
| 189 |
+
return {"final_loss": running_loss}
|
| 190 |
+
|
| 191 |
+
def _training_step(self, batch: Dict[str, torch.Tensor]) -> float:
|
| 192 |
+
"""Single training step."""
|
| 193 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 194 |
+
attention_mask = batch.get("attention_mask", None)
|
| 195 |
+
if attention_mask is not None:
|
| 196 |
+
attention_mask = attention_mask.to(self.device)
|
| 197 |
+
labels = batch["labels"].to(self.device)
|
| 198 |
+
|
| 199 |
+
if self.config.use_amp:
|
| 200 |
+
with autocast(dtype=self.amp_dtype):
|
| 201 |
+
loss, _, _, _ = self.model(input_ids, attention_mask, labels)
|
| 202 |
+
loss = loss / self.config.gradient_accumulation_steps
|
| 203 |
+
self.scaler.scale(loss).backward()
|
| 204 |
+
else:
|
| 205 |
+
loss, _, _, _ = self.model(input_ids, attention_mask, labels)
|
| 206 |
+
loss = loss / self.config.gradient_accumulation_steps
|
| 207 |
+
loss.backward()
|
| 208 |
+
|
| 209 |
+
return loss.item() * self.config.gradient_accumulation_steps
|
| 210 |
+
|
| 211 |
+
def _optimizer_step(self):
|
| 212 |
+
"""Optimizer step with gradient clipping."""
|
| 213 |
+
if self.config.use_amp:
|
| 214 |
+
self.scaler.unscale_(self.optimizer)
|
| 215 |
+
|
| 216 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
|
| 217 |
+
|
| 218 |
+
if self.config.use_amp:
|
| 219 |
+
self.scaler.step(self.optimizer)
|
| 220 |
+
self.scaler.update()
|
| 221 |
+
else:
|
| 222 |
+
self.optimizer.step()
|
| 223 |
+
|
| 224 |
+
self.scheduler.step()
|
| 225 |
+
self.optimizer.zero_grad()
|
| 226 |
+
|
| 227 |
+
@torch.no_grad()
|
| 228 |
+
def evaluate(self) -> float:
|
| 229 |
+
"""Evaluate model on eval dataset."""
|
| 230 |
+
self.model.eval()
|
| 231 |
+
total_loss = 0.0
|
| 232 |
+
num_batches = 0
|
| 233 |
+
|
| 234 |
+
for batch in self.eval_dataloader:
|
| 235 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 236 |
+
attention_mask = batch.get("attention_mask")
|
| 237 |
+
if attention_mask is not None:
|
| 238 |
+
attention_mask = attention_mask.to(self.device)
|
| 239 |
+
labels = batch["labels"].to(self.device)
|
| 240 |
+
|
| 241 |
+
loss, _, _, _ = self.model(input_ids, attention_mask, labels)
|
| 242 |
+
total_loss += loss.item()
|
| 243 |
+
num_batches += 1
|
| 244 |
+
|
| 245 |
+
return total_loss / max(1, num_batches)
|
| 246 |
+
|
| 247 |
+
def save_checkpoint(self, final: bool = False):
|
| 248 |
+
"""Save model checkpoint."""
|
| 249 |
+
checkpoint_name = "final" if final else f"step_{self.global_step}"
|
| 250 |
+
checkpoint_path = Path(self.config.output_dir) / checkpoint_name
|
| 251 |
+
|
| 252 |
+
checkpoint_path.mkdir(parents=True, exist_ok=True)
|
| 253 |
+
|
| 254 |
+
torch.save(self.model.state_dict(), checkpoint_path / "model.pt")
|
| 255 |
+
torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.pt")
|
| 256 |
+
torch.save({
|
| 257 |
+
"global_step": self.global_step,
|
| 258 |
+
"epoch": self.epoch,
|
| 259 |
+
"config": self.config,
|
| 260 |
+
}, checkpoint_path / "trainer_state.pt")
|
| 261 |
+
|
| 262 |
+
print(f"Checkpoint saved to {checkpoint_path}")
|
| 263 |
+
|
| 264 |
+
def load_checkpoint(self, checkpoint_path: str):
|
| 265 |
+
"""Load model checkpoint."""
|
| 266 |
+
path = Path(checkpoint_path)
|
| 267 |
+
self.model.load_state_dict(torch.load(path / "model.pt", map_location=self.device))
|
| 268 |
+
self.optimizer.load_state_dict(torch.load(path / "optimizer.pt", map_location=self.device))
|
| 269 |
+
|
| 270 |
+
state = torch.load(path / "trainer_state.pt", map_location=self.device)
|
| 271 |
+
self.global_step = state["global_step"]
|
| 272 |
+
self.epoch = state["epoch"]
|
| 273 |
+
|
| 274 |
+
print(f"Checkpoint loaded from {checkpoint_path}")
|
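Putting the pieces together, a minimal sketch of a fine-tuning run. The `model` and `tokenizer` variables are placeholders for an already-constructed Mind2 model (whose forward returns `(loss, logits, ...)`) and a tokenizer, and the data path is an assumption:

```python
from training import Mind2Trainer, TextDataset, create_dataloader
from training.trainer import TrainingConfig

dataset = TextDataset("data/corpus.jsonl", tokenizer, max_length=1024)  # tokenizer assumed
loader = create_dataloader(dataset, batch_size=4)

config = TrainingConfig(
    learning_rate=3e-4,
    num_epochs=1,
    gradient_accumulation_steps=4,
    output_dir="./outputs",
)

trainer = Mind2Trainer(model=model, train_dataloader=loader, config=config)  # model assumed
trainer.train()  # saves a final checkpoint under ./outputs/final

# To resume later: trainer.load_checkpoint("./outputs/step_1000")
```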