Upload folder using huggingface_hub
Browse files- LICENSE +17 -0
- README.md +197 -0
- model.py +193 -0
- model.safetensors +3 -0
LICENSE
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2025 HighkeyPrxneeth
|
| 6 |
+
|
| 7 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
you may not use this file except in compliance with the License.
|
| 9 |
+
You may obtain a copy of the License at
|
| 10 |
+
|
| 11 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
|
| 13 |
+
Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
See the License for the specific language governing permissions and
|
| 17 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- "en"
|
| 4 |
+
pretty_name: "ModernTrajectoryNet: Transaction Embedding Classifier"
|
| 5 |
+
tags:
|
| 6 |
+
- embedding
|
| 7 |
+
- pytorch
|
| 8 |
+
- finance
|
| 9 |
+
- transaction-classifier
|
| 10 |
+
- contrastive-learning
|
| 11 |
+
license: "apache-2.0"
|
| 12 |
+
datasets:
|
| 13 |
+
- "HighkeyPrxneeth/BusinessTransactions"
|
| 14 |
+
library_name: "pytorch"
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# ModernTrajectoryNet: Transaction Embedding Classifier
|
| 18 |
+
|
| 19 |
+
A state-of-the-art PyTorch embedding classifier trained with modern deep learning techniques for transaction categorization. The model learns to project transaction embeddings toward their target category embeddings through trajectory-based contrastive learning.
|
| 20 |
+
|
| 21 |
+
## Model Architecture
|
| 22 |
+
|
| 23 |
+
**ModernTrajectoryNet** combines several modern architectural innovations:
|
| 24 |
+
|
| 25 |
+
### Core Components
|
| 26 |
+
|
| 27 |
+
1. **RMSNorm (Root Mean Square Layer Normalization)**
|
| 28 |
+
- More stable and computationally efficient than LayerNorm
|
| 29 |
+
- Used in LLaMA, PaLM, and Gopher
|
| 30 |
+
- Provides consistent gradient flow through deep networks
|
| 31 |
+
|
| 32 |
+
2. **SwiGLU (Swish-Gated Linear Unit)**
|
| 33 |
+
- SOTA activation function for feed-forward networks
|
| 34 |
+
- Outperforms GELU and ReLU in expressivity
|
| 35 |
+
- Gate mechanism: `(x * sigmoid(x)) * linear(x)`
|
| 36 |
+
|
| 37 |
+
3. **SEBlock (Squeeze-and-Excitation)**
|
| 38 |
+
- Channel attention mechanism
|
| 39 |
+
- Allows dynamic weighting of embedding dimensions
|
| 40 |
+
- Context-aware feature recalibration
|
| 41 |
+
|
| 42 |
+
4. **ModernBlock (Pre-Norm Architecture)**
|
| 43 |
+
- RMSNorm → SwiGLU → SEBlock → Residual Connection
|
| 44 |
+
- Incorporates layer scaling and stochastic depth (DropPath)
|
| 45 |
+
- Enables training of very deep networks
|
| 46 |
+
|
| 47 |
+
### Configuration
|
| 48 |
+
|
| 49 |
+
- **Input dimension**: 768 (embedding size)
|
| 50 |
+
- **Hidden layers**: 12 transformer-style blocks
|
| 51 |
+
- **Expansion ratio**: 4x hidden dimension in SwiGLU
|
| 52 |
+
- **Dropout**: 0.1
|
| 53 |
+
- **Stochastic depth**: Linear decay across layers (0.0 → 0.1)
|
| 54 |
+
|
| 55 |
+
## Training Objective: Hybrid Trajectory Learning
|
| 56 |
+
|
| 57 |
+
The model is trained with **HybridTrajectoryLoss**, combining two objectives:
|
| 58 |
+
|
| 59 |
+
### 1. Adaptive InfoNCE (Contrastive Component)
|
| 60 |
+
- Learnable temperature parameter for dynamic scaling
|
| 61 |
+
- Contrastive loss with label smoothing (0.1)
|
| 62 |
+
- Ensures the model maps input embeddings close to their true target embedding
|
| 63 |
+
- Equation: `L_contrastive = CrossEntropy(logits / T, labels)`
|
| 64 |
+
|
| 65 |
+
### 2. Monotonic Ranking (Trajectory Component)
|
| 66 |
+
- Enforces **monotonically increasing similarity** through the transaction sequence
|
| 67 |
+
- Each step in the trajectory should have higher similarity than the previous step
|
| 68 |
+
- Final embedding must achieve high similarity (ideally 1.0) with target
|
| 69 |
+
- Margin constraint: `sim[i+1] > sim[i] + 0.01`
|
| 70 |
+
- Ensures the model learns the **path** to the target, not just the endpoint
|
| 71 |
+
|
| 72 |
+
### Loss Formulation
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
Total Loss = InfoNCE Loss + Monotonicity Loss
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
**Why Trajectory Learning?**
|
| 79 |
+
- Transactions often evolve gradually toward their correct category
|
| 80 |
+
- Intermediate embeddings should show progression toward the target
|
| 81 |
+
- This inductive bias improves generalization and interpretability
|
| 82 |
+
|
| 83 |
+
## Training Details
|
| 84 |
+
|
| 85 |
+
- **Optimizer**: AdamW with weight decay (1e-4)
|
| 86 |
+
- **Learning rate**: Cosine annealing from 3e-4 to 1e-6
|
| 87 |
+
- **Batch size**: 128
|
| 88 |
+
- **Gradient clipping**: 1.0
|
| 89 |
+
- **Epochs**: 50 with early stopping (patience=5)
|
| 90 |
+
- **EMA (Exponential Moving Average)**: Decay=0.99 for evaluation stability
|
| 91 |
+
- **Augmentation**: Input masking (p=0.15) and Gaussian noise (std=0.01) during training
|
| 92 |
+
- **Mixed Precision**: AMP enabled for faster training on CUDA
|
| 93 |
+
|
| 94 |
+
## Performance Metrics
|
| 95 |
+
|
| 96 |
+
The model optimizes for:
|
| 97 |
+
1. **Last Similarity**: Similarity of final embedding with target (Target: ≈1.0)
|
| 98 |
+
2. **Monotonicity Accuracy**: % of transitions with strictly increasing similarity (Target: 100%)
|
| 99 |
+
3. **Contrastive Accuracy**: Ability to distinguish true target from other targets in batch
|
| 100 |
+
|
| 101 |
+
## How to Load
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
from safetensors.torch import load_file
|
| 105 |
+
import torch
|
| 106 |
+
from config import Config
|
| 107 |
+
from model import ModernTrajectoryNet
|
| 108 |
+
|
| 109 |
+
# Load weights
|
| 110 |
+
weights = load_file("model.safetensors")
|
| 111 |
+
|
| 112 |
+
# Instantiate model
|
| 113 |
+
config = Config()
|
| 114 |
+
model = ModernTrajectoryNet(config)
|
| 115 |
+
model.load_state_dict(weights)
|
| 116 |
+
model.eval()
|
| 117 |
+
|
| 118 |
+
# Use model
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
input_embedding = torch.randn(1, 768) # Your transaction embedding
|
| 121 |
+
output_embedding = model(input_embedding)
|
| 122 |
+
print(output_embedding.shape) # [1, 768]
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Usage Example
|
| 126 |
+
|
| 127 |
+
```python
|
| 128 |
+
import torch
|
| 129 |
+
from torch.nn.functional import normalize
|
| 130 |
+
|
| 131 |
+
# Assuming you have transaction embeddings and category embeddings
|
| 132 |
+
transaction_emb = model(input_embedding) # [B, 768]
|
| 133 |
+
|
| 134 |
+
# Compute similarity with category embeddings
|
| 135 |
+
category_embs = normalize(category_embeddings, p=2, dim=1) # [N_cats, 768]
|
| 136 |
+
transaction_emb_norm = normalize(transaction_emb, p=2, dim=1) # [B, 768]
|
| 137 |
+
|
| 138 |
+
similarities = torch.matmul(transaction_emb_norm, category_embs.t()) # [B, N_cats]
|
| 139 |
+
predicted_category = torch.argmax(similarities, dim=1) # [B]
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## Intended Uses
|
| 143 |
+
|
| 144 |
+
- **Transaction categorization**: Classify business transactions into merchant categories
|
| 145 |
+
- **Embedding refinement**: Project raw transaction embeddings to discriminative space
|
| 146 |
+
- **Contrastive learning**: Extract improved embeddings for downstream tasks
|
| 147 |
+
- **Research**: Study trajectory-based learning for sequential decision problems
|
| 148 |
+
|
| 149 |
+
## Limitations & Biases
|
| 150 |
+
|
| 151 |
+
- **Synthetic data**: Trained on synthetic transaction strings generated from Foursquare Open-Source (FSQ OS) business names and categories using `qwen2.5-4b-instruct` LLM
|
| 152 |
+
- **FSQ OS biases**: Inherits biases from the FSQ OS dataset (e.g., geographic coverage, business type distribution)
|
| 153 |
+
- **Generation artifacts**: LLM-based synthetic data may not reflect real-world transaction diversity
|
| 154 |
+
- **Category coverage**: Limited to categories present in FSQ OS (typically 200-500 merchant types)
|
| 155 |
+
- **Language**: Trained on English transaction strings; may not generalize to other languages
|
| 156 |
+
|
| 157 |
+
**Recommendation**: Validate performance on your specific transaction domain before production deployment.
|
| 158 |
+
|
| 159 |
+
## Dataset
|
| 160 |
+
|
| 161 |
+
- **Source**: Foursquare Open-Source (FSQ OS) business names and categories
|
| 162 |
+
- **Processing**: LLM-based synthetic transaction generation
|
| 163 |
+
- **Size**: ~1M synthetic transaction embeddings
|
| 164 |
+
- **Train/Val split**: 90% / 10%
|
| 165 |
+
|
| 166 |
+
See the [dataset](https://huggingface.co/datasets/HighkeyPrxneeth/BusinessTransactions) for more details.
|
| 167 |
+
|
| 168 |
+
## Files in This Repository
|
| 169 |
+
|
| 170 |
+
- `model.safetensors`: Model weights in HuggingFace SafeTensors format (160MB)
|
| 171 |
+
- `README.md`: This file
|
| 172 |
+
- `LICENSE`: Apache 2.0 license
|
| 173 |
+
|
| 174 |
+
## License
|
| 175 |
+
|
| 176 |
+
Apache License 2.0. See LICENSE file for details.
|
| 177 |
+
|
| 178 |
+
## Citation
|
| 179 |
+
|
| 180 |
+
If you use this model, please cite:
|
| 181 |
+
|
| 182 |
+
```bibtex
|
| 183 |
+
@software{transactionclassifier2024,
|
| 184 |
+
title={TransactionClassifier: Embedding-based Transaction Categorization},
|
| 185 |
+
author={HighkeyPrxneeth},
|
| 186 |
+
year={2024},
|
| 187 |
+
url={https://huggingface.co/HighkeyPrxneeth/ModernTrajectoryNet}
|
| 188 |
+
}
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
## Contact & Support
|
| 192 |
+
|
| 193 |
+
- **Repository**: [GitHub - TransactionClassifier](https://github.com/HighkeyPrxneeth/TransactionClassifier)
|
| 194 |
+
- **Issues**: Open an issue in the main project repository
|
| 195 |
+
- **Author**: HighkeyPrxneeth
|
| 196 |
+
|
| 197 |
+
For questions about the model architecture, training, or usage, feel free to reach out!
|
model.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class RMSNorm(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
Root Mean Square Layer Normalization.
|
| 8 |
+
More stable and computationally efficient than LayerNorm.
|
| 9 |
+
Used in LLaMA, PaLM, Gopher.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.eps = eps
|
| 14 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 15 |
+
|
| 16 |
+
def _norm(self, x):
|
| 17 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
output = self._norm(x.float()).type_as(x)
|
| 21 |
+
return output * self.weight
|
| 22 |
+
|
| 23 |
+
class SwiGLU(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
Swish-Gated Linear Unit.
|
| 26 |
+
SOTA activation function for FFNs (outperforms GELU/ReLU).
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
| 31 |
+
self.w2 = nn.Linear(dim, hidden_dim, bias=False)
|
| 32 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=False)
|
| 33 |
+
self.dropout = nn.Dropout(dropout)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
# Gate mechanism: (x * sigmoid(x)) * linear(x)
|
| 37 |
+
x1 = self.w1(x)
|
| 38 |
+
x2 = self.w2(x)
|
| 39 |
+
hidden = F.silu(x1) * x2
|
| 40 |
+
return self.w3(self.dropout(hidden))
|
| 41 |
+
|
| 42 |
+
class SEBlock(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
Squeeze-and-Excitation Block.
|
| 45 |
+
Allows the model to dynamically weight different dimensions of the embedding
|
| 46 |
+
based on global context.
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, dim: int, reduction: int = 4):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.avg_pool = nn.AdaptiveAvgPool1d(1)
|
| 51 |
+
self.fc = nn.Sequential(
|
| 52 |
+
nn.Linear(dim, dim // reduction, bias=False),
|
| 53 |
+
nn.ReLU(inplace=True),
|
| 54 |
+
nn.Linear(dim // reduction, dim, bias=False),
|
| 55 |
+
nn.Sigmoid()
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
# Input: [B, D] -> unsqueeze to [B, D, 1] for pool/conv compatibility if needed
|
| 60 |
+
# But here we are working with vectors, so we simulate it.
|
| 61 |
+
b, d = x.shape
|
| 62 |
+
y = self.fc(x) # [B, D]
|
| 63 |
+
return x * y
|
| 64 |
+
|
| 65 |
+
class DropPath(nn.Module):
|
| 66 |
+
"""Stochastic depth regularizer (Improved)."""
|
| 67 |
+
def __init__(self, drop_prob: float = 0.0):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.drop_prob = drop_prob
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
if self.drop_prob == 0.0 or not self.training:
|
| 73 |
+
return x
|
| 74 |
+
keep_prob = 1.0 - self.drop_prob
|
| 75 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
| 76 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 77 |
+
random_tensor.floor_()
|
| 78 |
+
return x.div(keep_prob) * random_tensor
|
| 79 |
+
|
| 80 |
+
class ModernBlock(nn.Module):
|
| 81 |
+
"""
|
| 82 |
+
A Pre-Norm Block combining RMSNorm, SwiGLU, and Channel Attention.
|
| 83 |
+
"""
|
| 84 |
+
def __init__(self, dim: int, expand: int = 4, dropout: float = 0.1,
|
| 85 |
+
layer_scale_init: float = 1e-6, drop_path: float = 0.0):
|
| 86 |
+
super().__init__()
|
| 87 |
+
|
| 88 |
+
# 1. Normalization
|
| 89 |
+
self.norm = RMSNorm(dim)
|
| 90 |
+
|
| 91 |
+
# 2. SOTA Feed Forward (SwiGLU)
|
| 92 |
+
# SwiGLU usually requires 2/3 hidden dim of standard MLP to match params,
|
| 93 |
+
# but we keep it high for expressivity.
|
| 94 |
+
self.ffn = SwiGLU(dim, int(dim * expand * 2 / 3), dropout=dropout)
|
| 95 |
+
|
| 96 |
+
# 3. Channel Attention (Context awareness)
|
| 97 |
+
self.se = SEBlock(dim, reduction=4)
|
| 98 |
+
|
| 99 |
+
# 4. Regularization
|
| 100 |
+
self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale_init) if layer_scale_init > 0 else None
|
| 101 |
+
self.drop_path = DropPath(drop_path)
|
| 102 |
+
|
| 103 |
+
def forward(self, x):
|
| 104 |
+
residual = x
|
| 105 |
+
|
| 106 |
+
# Pre-Norm Architecture
|
| 107 |
+
out = self.norm(x)
|
| 108 |
+
out = self.ffn(out)
|
| 109 |
+
out = self.se(out) # Apply attention
|
| 110 |
+
|
| 111 |
+
if self.layer_scale is not None:
|
| 112 |
+
out = out * self.layer_scale
|
| 113 |
+
|
| 114 |
+
out = self.drop_path(out)
|
| 115 |
+
|
| 116 |
+
return residual + out
|
| 117 |
+
|
| 118 |
+
class ModernTrajectoryNet(nn.Module):
|
| 119 |
+
def __init__(self, config):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.d_model = config.d_model
|
| 122 |
+
self.n_layers = config.n_layers
|
| 123 |
+
|
| 124 |
+
# Config defaults
|
| 125 |
+
dropout = getattr(config, "dropout", 0.1)
|
| 126 |
+
expand = getattr(config, "expand", 4)
|
| 127 |
+
drop_path_rate = getattr(config, "drop_path_rate", 0.1)
|
| 128 |
+
|
| 129 |
+
# Input Projection (Projects to latent space)
|
| 130 |
+
self.input_proj = nn.Sequential(
|
| 131 |
+
RMSNorm(self.d_model),
|
| 132 |
+
nn.Linear(self.d_model, self.d_model)
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Backbone
|
| 136 |
+
self.blocks = nn.ModuleList([
|
| 137 |
+
ModernBlock(
|
| 138 |
+
dim=self.d_model,
|
| 139 |
+
expand=expand,
|
| 140 |
+
dropout=dropout,
|
| 141 |
+
drop_path=drop_path_rate * (i / (self.n_layers - 1)) # Linear decay
|
| 142 |
+
) for i in range(self.n_layers)
|
| 143 |
+
])
|
| 144 |
+
|
| 145 |
+
self.final_norm = RMSNorm(self.d_model)
|
| 146 |
+
|
| 147 |
+
# Projector Head (SimCLR / CLIP style)
|
| 148 |
+
# Important: Keep high dimension for the final linear probe
|
| 149 |
+
self.head = nn.Sequential(
|
| 150 |
+
nn.Linear(self.d_model, self.d_model),
|
| 151 |
+
nn.GELU(),
|
| 152 |
+
nn.Linear(self.d_model, self.d_model)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
self.apply(self._init_weights)
|
| 156 |
+
|
| 157 |
+
def _init_weights(self, m):
|
| 158 |
+
if isinstance(m, nn.Linear):
|
| 159 |
+
torch.nn.init.trunc_normal_(m.weight, std=.02)
|
| 160 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 161 |
+
nn.init.constant_(m.bias, 0)
|
| 162 |
+
elif isinstance(m, nn.LayerNorm):
|
| 163 |
+
nn.init.constant_(m.bias, 0)
|
| 164 |
+
nn.init.constant_(m.weight, 1.0)
|
| 165 |
+
|
| 166 |
+
def forward(self, x, return_trajectory=False):
|
| 167 |
+
# Handle sequence dimension if present
|
| 168 |
+
if x.dim() == 3:
|
| 169 |
+
x = x.mean(dim=1)
|
| 170 |
+
|
| 171 |
+
x = self.input_proj(x)
|
| 172 |
+
|
| 173 |
+
trajectory = []
|
| 174 |
+
for block in self.blocks:
|
| 175 |
+
x = block(x)
|
| 176 |
+
trajectory.append(x)
|
| 177 |
+
|
| 178 |
+
x = self.final_norm(x)
|
| 179 |
+
|
| 180 |
+
# Residual connection to original input is implicit via the blocks,
|
| 181 |
+
# but for trajectory learning, we want the final head to dictate the shift.
|
| 182 |
+
output = self.head(x)
|
| 183 |
+
|
| 184 |
+
# OPTIONAL: Add Denoising / Residual connection to input
|
| 185 |
+
# output = output + input_tensor_if_saved
|
| 186 |
+
|
| 187 |
+
if return_trajectory:
|
| 188 |
+
return output, torch.stack(trajectory, dim=1)
|
| 189 |
+
|
| 190 |
+
return output
|
| 191 |
+
|
| 192 |
+
# Backwards compatibility
|
| 193 |
+
HybridMambaAttentionModel = ModernTrajectoryNet
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd2615d46943643727d117d14950b2437000a0a7346d1a969d15728c5cadd56b
|
| 3 |
+
size 167580472
|