Ram07 commited on
Commit
83505eb
·
verified ·
1 Parent(s): fd0654a

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - bitnet
8
+ - quantization
9
+ - early-exit
10
+ - layer-skipping
11
+ - efficient-transformers
12
+ datasets:
13
+ - roneneldan/TinyStories
14
+ ---
15
+
16
+ # bitskip-v1-earlyexit
17
+
18
+ BitSkip v1 with 8-bit activation quantization and ternary weights (no Hadamard transform)
19
+
20
+ ## Model Description
21
+
22
+ This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.
23
+
24
+ ## Architecture Details
25
+
26
+ - **Layers**: 24
27
+ - **Hidden dimension**: 2048
28
+ - **Attention heads**: 32 (64-dimensional each)
29
+ - **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
30
+ - **FFN intermediate size**: 4096
31
+ - **Position embeddings**: Rotary Position Embeddings (RoPE)
32
+ - **Normalization**: RMSNorm
33
+ - **Activation**: SwiGLU (for MLP)
34
+ - **Parameters**: ~1.06B
35
+
36
+ ### Quantization Scheme
37
+
38
+ - **Weights**: Ternary {-1, 0, 1}
39
+ - **Activations**: 8-bit quantization
40
+ - **Hadamard**: No
41
+
42
+ ## Training Details
43
+
44
+ ### Dataset
45
+ - **Source**: TinyStories (2.1M stories)
46
+ - **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
47
+ - **Sequence length**: 512 tokens
48
+
49
+ ### Training Techniques
50
+
51
+ **Quadratic Layer Dropout:**
52
+ - Progressive dropout: p_l = 0.5 × (l/L)²
53
+ - Normalized so Σp_l = 1.0
54
+ - Never drops final layer
55
+ - Makes earlier layers more accurate
56
+
57
+ **Early Exit Loss:**
58
+ - All layers share the same LM head
59
+ - Loss = main_loss + 0.3 × early_exit_loss
60
+ - Layer-proportional weighting: w_i = (i+1)/L
61
+ - Enables flexible early exit at inference
62
+
63
+ ### Hyperparameters
64
+
65
+ - **Optimizer**: AdamW
66
+ - **Learning rate**: 6e-4
67
+ - **Warmup steps**: 1000
68
+ - **Batch size**: 16 (effective: 64)
69
+ - **Training steps**: 50000
70
+ - **Gradient clipping**: 1.0
71
+
72
+ ## Performance
73
+
74
+ ### Perplexity (TinyStories validation)
75
+
76
+ | Exit Layer | Perplexity | Speed (tok/s) |
77
+ |------------|------------|---------------|
78
+ | All layers | TBD | TBD |
79
+ | Layer 18 | TBD | TBD |
80
+ | Layer 12 | TBD | TBD |
81
+ | Layer 6 | TBD | TBD |
82
+
83
+ ### Training Stability
84
+
85
+ - **Gradient norms**: 2-5
86
+ - **Final loss**: TBD
87
+
88
+ ## Usage
89
+
90
+ ### Installation
91
+
92
+ ```bash
93
+ pip install transformers torch
94
+ ```
95
+
96
+ ### Basic Inference
97
+
98
+ ```python
99
+ from transformers import AutoTokenizer, AutoModelForCausalLM
100
+
101
+ # Load model
102
+ model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v1-earlyexit")
103
+ tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v1-earlyexit")
104
+
105
+ # Generate text
106
+ inputs = tokenizer("Once upon a time", return_tensors="pt")
107
+ outputs = model.generate(**inputs, max_length=100)
108
+ print(tokenizer.decode(outputs[0]))
109
+ ```
110
+
111
+ ### Early Exit Inference
112
+
113
+ ```python
114
+ # Exit at layer 12 for faster inference
115
+ model.set_exit_layer(12)
116
+ outputs = model.generate(**inputs, max_length=100)
117
+ # 1.5-2x faster with minimal quality loss
118
+ ```
119
+
120
+ ### Benchmark Different Exit Layers
121
+
122
+ ```python
123
+ for exit_layer in [6, 12, 18, 24]:
124
+ model.set_exit_layer(exit_layer)
125
+ outputs = model.generate(**inputs, max_length=100)
126
+ print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
127
+ ```
128
+
129
+ ## Limitations
130
+
131
+ - **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width
132
+ - **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
133
+ - **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
134
+
135
+ ## Citation
136
+
137
+ If you use this model, please cite:
138
+
139
+ ```bibtex
140
+ @article{bitnet,
141
+ title={BitNet: Scaling 1-bit Transformers for Large Language Models},
142
+ author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
143
+ journal={arXiv preprint arXiv:2310.11453},
144
+ year={2023}
145
+ }
146
+
147
+ @article{layerskip,
148
+ title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
149
+ author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
150
+ journal={arXiv preprint arXiv:2404.16710},
151
+ year={2024}
152
+ }
153
+ ```
154
+
155
+ ## License
156
+
157
+ MIT License
158
+
159
+ ## Contact
160
+
161
+ For questions or issues, please open an issue on the model repository.
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BitSkipV1ForCausalLMWithEarlyExit"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model_v1_earlyexit.BitSkipV1EarlyExitConfig",
7
+ "AutoModelForCausalLM": "model_v1_earlyexit.BitSkipV1ForCausalLMWithEarlyExit"
8
+ },
9
+ "early_exit_loss_weight": 0.3,
10
+ "hidden_size": 2048,
11
+ "inference_exit_layer": null,
12
+ "intermediate_size": 4096,
13
+ "max_dropout_prob": 0.5,
14
+ "max_position_embeddings": 2048,
15
+ "model_type": "bitskip_v1_earlyexit",
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_key_value_heads": 8,
19
+ "rms_norm_eps": 1e-05,
20
+ "rope_theta": 10000.0,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.45.2",
23
+ "vocab_size": 50257
24
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.45.2"
4
+ }
inference.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for bitskip-v1-earlyexit
3
+ """
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ def main():
9
+ # Load from HuggingFace Hub or local path
10
+ model_path = "." # Current directory or specify repo_id
11
+
12
+ print("Loading model...")
13
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
15
+
16
+ model.eval()
17
+ print("Model loaded!")
18
+
19
+ # Example generation
20
+ prompt = "Once upon a time"
21
+ inputs = tokenizer(prompt, return_tensors="pt")
22
+
23
+ print(f"\nPrompt: {prompt}\n")
24
+
25
+ # Full model
26
+ print("Generating with all layers...")
27
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
28
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
29
+
30
+ # Early exit at layer 12
31
+ print("\nGenerating with early exit at layer 12...")
32
+ model.set_exit_layer(12)
33
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
34
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
35
+
36
+ if __name__ == "__main__":
37
+ main()
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:643aa06158d00a0975ba01c1b8abd6b35d32df0c1051cb20b1f9dae331d9e958
3
+ size 3834689608
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Model files for bitskip-v1-earlyexit"""
models/bitlinear.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standard BitLinear layer for BitSkip v1 (8-bit activations, NO Hadamard transform)
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class BitLinear(nn.Module):
11
+ """
12
+ Standard BitLinear: Ternary weights + 8-bit activations.
13
+ NO Hadamard transform - direct quantization.
14
+ """
15
+
16
+ def __init__(self, in_features, out_features, bias=False):
17
+ super().__init__()
18
+ self.in_features = in_features
19
+ self.out_features = out_features
20
+
21
+ # Standard weight initialization
22
+ self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
23
+ self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
24
+
25
+ def forward(self, x):
26
+ """
27
+ Forward with 8-bit activation quantization and ternary weights.
28
+ Uses STE (Straight-Through Estimator) for gradients.
29
+ """
30
+ # 8-bit activation quantization
31
+ x_scale = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)
32
+ x_quant = (x / x_scale * 127).round().clamp(-128, 127)
33
+ x_quant = x_quant / 127 * x_scale
34
+
35
+ # STE: quantized forward, full precision backward
36
+ if self.training:
37
+ x_quant = x + (x_quant - x).detach()
38
+
39
+ # Ternary weight quantization
40
+ w_scale = self.weight.abs().mean().clamp(min=1e-5)
41
+ w_quant = torch.zeros_like(self.weight)
42
+ w_quant[self.weight > 0.5 * w_scale] = 1.0
43
+ w_quant[self.weight < -0.5 * w_scale] = -1.0
44
+ w_quant = w_quant * w_scale
45
+
46
+ # STE for weights
47
+ if self.training:
48
+ w_quant = self.weight + (w_quant - self.weight).detach()
49
+
50
+ # Standard linear operation
51
+ return F.linear(x_quant, w_quant, self.bias)
models/model_v1_earlyexit.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BitSkip v1 with Early Exit Loss and Quadratic Dropout
3
+ - BitLinear quantization (8-bit)
4
+ - Quadratic layer dropout (0 to 0.5 progression, sum=1 constraint)
5
+ - Early exit loss from all layers
6
+ - HuggingFace compatible
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import math
13
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+ from typing import Optional, Tuple
16
+
17
+ from .bitlinear import BitLinear
18
+
19
+
20
+ class BitSkipV1EarlyExitConfig(PretrainedConfig):
21
+ model_type = "bitskip_v1_earlyexit"
22
+
23
+ def __init__(
24
+ self,
25
+ vocab_size=50257,
26
+ hidden_size=2048,
27
+ num_hidden_layers=24,
28
+ num_attention_heads=32,
29
+ num_key_value_heads=8,
30
+ intermediate_size=4096,
31
+ max_position_embeddings=2048,
32
+ rms_norm_eps=1e-5,
33
+ rope_theta=10000.0,
34
+ # Early exit parameters
35
+ early_exit_loss_weight=0.3,
36
+ # Quadratic dropout parameters
37
+ max_dropout_prob=0.5,
38
+ # Inference
39
+ inference_exit_layer=None,
40
+ **kwargs
41
+ ):
42
+ self.vocab_size = vocab_size
43
+ self.hidden_size = hidden_size
44
+ self.num_hidden_layers = num_hidden_layers
45
+ self.num_attention_heads = num_attention_heads
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.intermediate_size = intermediate_size
48
+ self.max_position_embeddings = max_position_embeddings
49
+ self.rms_norm_eps = rms_norm_eps
50
+ self.rope_theta = rope_theta
51
+ self.early_exit_loss_weight = early_exit_loss_weight
52
+ self.max_dropout_prob = max_dropout_prob
53
+ self.inference_exit_layer = inference_exit_layer
54
+ super().__init__(**kwargs)
55
+
56
+
57
+ class QuadraticLayerDropout(nn.Module):
58
+ """
59
+ Quadratic layer dropout: p_l = p_max * (l/L)^2
60
+ Normalized so sum of probabilities = 1
61
+ """
62
+
63
+ def __init__(self, num_layers, max_dropout_prob=0.5):
64
+ super().__init__()
65
+ self.num_layers = num_layers
66
+ self.max_dropout_prob = max_dropout_prob
67
+
68
+ # Compute quadratic dropout probabilities
69
+ dropout_probs = []
70
+ for i in range(num_layers):
71
+ # Quadratic: p_l = p_max * (l/L)^2
72
+ prob = max_dropout_prob * ((i / max(num_layers - 1, 1)) ** 2)
73
+ dropout_probs.append(prob)
74
+
75
+ # Normalize so sum = 1 (as per requirement)
76
+ total_prob = sum(dropout_probs)
77
+ if total_prob > 0:
78
+ dropout_probs = [p / total_prob for p in dropout_probs]
79
+
80
+ self.dropout_probs = dropout_probs
81
+
82
+ def should_drop_layer(self, layer_idx):
83
+ """Returns True if layer should be dropped during training."""
84
+ if not self.training or layer_idx >= self.num_layers - 1: # Never drop last layer
85
+ return False
86
+ return torch.rand(1).item() < self.dropout_probs[layer_idx]
87
+
88
+
89
+ class RMSNorm(nn.Module):
90
+ def __init__(self, hidden_size, eps=1e-6):
91
+ super().__init__()
92
+ self.weight = nn.Parameter(torch.ones(hidden_size))
93
+ self.variance_epsilon = eps
94
+
95
+ def forward(self, hidden_states):
96
+ input_dtype = hidden_states.dtype
97
+ hidden_states = hidden_states.to(torch.float32)
98
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
99
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
100
+ return self.weight * hidden_states.to(input_dtype)
101
+
102
+
103
+ class RotaryEmbedding(nn.Module):
104
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
105
+ super().__init__()
106
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
107
+ self.register_buffer("inv_freq", inv_freq)
108
+
109
+ def forward(self, x, position_ids):
110
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
111
+ position_ids_expanded = position_ids[:, None, :].float()
112
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
113
+ emb = torch.cat((freqs, freqs), dim=-1)
114
+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
115
+
116
+
117
+ def rotate_half(x):
118
+ x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
119
+ return torch.cat((-x2, x1), dim=-1)
120
+
121
+
122
+ def apply_rotary_pos_emb(q, k, cos, sin):
123
+ q_embed = (q * cos) + (rotate_half(q) * sin)
124
+ k_embed = (k * cos) + (rotate_half(k) * sin)
125
+ return q_embed, k_embed
126
+
127
+
128
+ class BitSkipV1Attention(nn.Module):
129
+ def __init__(self, config):
130
+ super().__init__()
131
+ self.hidden_size = config.hidden_size
132
+ self.num_heads = config.num_attention_heads
133
+ self.head_dim = self.hidden_size // self.num_heads
134
+ self.num_key_value_heads = config.num_key_value_heads
135
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
136
+
137
+ self.q_proj = BitLinear(self.hidden_size, self.num_heads * self.head_dim)
138
+ self.k_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim)
139
+ self.v_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim)
140
+ self.o_proj = BitLinear(self.hidden_size, self.hidden_size)
141
+
142
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
143
+
144
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
145
+ bsz, q_len, _ = hidden_states.size()
146
+
147
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
148
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
149
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
150
+
151
+ cos, sin = self.rotary_emb(value_states, position_ids)
152
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
153
+
154
+ if past_key_value is not None:
155
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
156
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
157
+
158
+ past_key_value = (key_states, value_states) if use_cache else None
159
+
160
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
161
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
162
+
163
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
164
+ if attention_mask is not None:
165
+ attn_weights = attn_weights + attention_mask
166
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
167
+ attn_output = torch.matmul(attn_weights, value_states)
168
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
169
+ attn_output = self.o_proj(attn_output)
170
+
171
+ return attn_output, None, past_key_value
172
+
173
+
174
+ class BitSkipV1MLP(nn.Module):
175
+ def __init__(self, config):
176
+ super().__init__()
177
+ self.gate_proj = BitLinear(config.hidden_size, config.intermediate_size)
178
+ self.up_proj = BitLinear(config.hidden_size, config.intermediate_size)
179
+ self.down_proj = BitLinear(config.intermediate_size, config.hidden_size)
180
+
181
+ def forward(self, x):
182
+ return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
183
+
184
+
185
+ class BitSkipV1DecoderLayer(nn.Module):
186
+ def __init__(self, config):
187
+ super().__init__()
188
+ self.self_attn = BitSkipV1Attention(config)
189
+ self.mlp = BitSkipV1MLP(config)
190
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
191
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
192
+
193
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
194
+ residual = hidden_states
195
+ hidden_states = self.input_layernorm(hidden_states)
196
+ hidden_states, _, present_key_value = self.self_attn(
197
+ hidden_states, attention_mask, position_ids, past_key_value, use_cache
198
+ )
199
+ hidden_states = residual + hidden_states
200
+
201
+ residual = hidden_states
202
+ hidden_states = self.post_attention_layernorm(hidden_states)
203
+ hidden_states = self.mlp(hidden_states)
204
+ hidden_states = residual + hidden_states
205
+
206
+ return (hidden_states,) + ((present_key_value,) if use_cache else ())
207
+
208
+
209
+ class BitSkipV1PreTrainedModel(PreTrainedModel):
210
+ config_class = BitSkipV1EarlyExitConfig
211
+ base_model_prefix = "model"
212
+ supports_gradient_checkpointing = True
213
+
214
+ def _init_weights(self, module):
215
+ if isinstance(module, (nn.Linear, BitLinear)):
216
+ if hasattr(module, 'weight'):
217
+ module.weight.data.normal_(mean=0.0, std=0.02)
218
+ if hasattr(module, 'bias') and module.bias is not None:
219
+ module.bias.data.zero_()
220
+ elif isinstance(module, nn.Embedding):
221
+ module.weight.data.normal_(mean=0.0, std=0.02)
222
+
223
+
224
+ class BitSkipV1Model(BitSkipV1PreTrainedModel):
225
+ def __init__(self, config):
226
+ super().__init__(config)
227
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
228
+ self.layers = nn.ModuleList([BitSkipV1DecoderLayer(config) for _ in range(config.num_hidden_layers)])
229
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
230
+ self.gradient_checkpointing = False
231
+
232
+ # Quadratic dropout module
233
+ self.layer_dropout = QuadraticLayerDropout(config.num_hidden_layers, config.max_dropout_prob)
234
+
235
+ self.post_init()
236
+
237
+ def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, output_hidden_states=False, return_all_layer_outputs=False):
238
+ hidden_states = self.embed_tokens(input_ids)
239
+
240
+ if position_ids is None:
241
+ position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)
242
+ position_ids = position_ids.unsqueeze(0)
243
+
244
+ next_decoder_cache = () if use_cache else None
245
+ all_layer_hidden_states = [] # Store ALL layer outputs for early exit loss
246
+
247
+ # Determine layers to run
248
+ num_layers_to_run = self.config.inference_exit_layer if self.config.inference_exit_layer else len(self.layers)
249
+ num_layers_to_run = min(num_layers_to_run, len(self.layers))
250
+
251
+ for idx in range(num_layers_to_run):
252
+ layer = self.layers[idx]
253
+ past_key_value = past_key_values[idx] if past_key_values else None
254
+
255
+ # Apply quadratic dropout during training
256
+ if self.training and self.layer_dropout.should_drop_layer(idx):
257
+ # Skip this layer - hidden_states pass through unchanged
258
+ all_layer_hidden_states.append(hidden_states)
259
+ continue
260
+
261
+ if self.gradient_checkpointing and self.training:
262
+ layer_outputs = self._gradient_checkpointing_func(
263
+ layer.__call__,
264
+ hidden_states,
265
+ attention_mask,
266
+ position_ids,
267
+ past_key_value,
268
+ use_cache,
269
+ )
270
+ else:
271
+ layer_outputs = layer(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
272
+
273
+ hidden_states = layer_outputs[0]
274
+ all_layer_hidden_states.append(hidden_states)
275
+
276
+ if use_cache:
277
+ next_decoder_cache += (layer_outputs[1],)
278
+
279
+ hidden_states = self.norm(hidden_states)
280
+ all_layer_hidden_states.append(hidden_states) # Add final normed output
281
+
282
+ if return_all_layer_outputs:
283
+ return hidden_states, next_decoder_cache, all_layer_hidden_states
284
+ else:
285
+ return hidden_states, next_decoder_cache, None
286
+
287
+
288
+ class BitSkipV1ForCausalLMWithEarlyExit(BitSkipV1PreTrainedModel, GenerationMixin):
289
+ """BitSkip v1 with early exit loss."""
290
+
291
+ _tied_weights_keys = ["lm_head.weight"]
292
+
293
+ def __init__(self, config):
294
+ super().__init__(config)
295
+ self.model = BitSkipV1Model(config)
296
+ self.vocab_size = config.vocab_size
297
+
298
+ # Shared LM head for all layers (LayerSkip approach)
299
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
300
+
301
+ self.post_init()
302
+
303
+ def get_input_embeddings(self):
304
+ return self.model.embed_tokens
305
+
306
+ def set_input_embeddings(self, value):
307
+ self.model.embed_tokens = value
308
+
309
+ def get_output_embeddings(self):
310
+ return self.lm_head
311
+
312
+ def set_output_embeddings(self, new_embeddings):
313
+ self.lm_head = new_embeddings
314
+
315
+ def compute_early_exit_loss(self, all_layer_hidden_states, labels):
316
+ """
317
+ Compute early exit loss from all layers.
318
+ Uses layer-proportional weighting: w_i = i/N
319
+ """
320
+ num_layers = len(all_layer_hidden_states)
321
+
322
+ # Compute weights: layer-proportional (deeper layers get more weight)
323
+ weights = [(i + 1) / num_layers for i in range(num_layers)]
324
+ weight_sum = sum(weights)
325
+ weights = [w / weight_sum for w in weights] # Normalize
326
+
327
+ total_exit_loss = 0.0
328
+
329
+ for i, hidden_states in enumerate(all_layer_hidden_states):
330
+ # Get logits for this layer
331
+ logits = self.lm_head(hidden_states)
332
+
333
+ # Compute cross-entropy loss
334
+ shift_logits = logits[..., :-1, :].contiguous()
335
+ shift_labels = labels[..., 1:].contiguous()
336
+
337
+ loss_fct = nn.CrossEntropyLoss()
338
+ layer_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
339
+
340
+ # Weight and accumulate
341
+ total_exit_loss += weights[i] * layer_loss
342
+
343
+ return total_exit_loss
344
+
345
+ def forward(
346
+ self,
347
+ input_ids=None,
348
+ attention_mask=None,
349
+ position_ids=None,
350
+ past_key_values=None,
351
+ inputs_embeds=None,
352
+ labels=None,
353
+ use_cache=None,
354
+ output_attentions=None,
355
+ output_hidden_states=None,
356
+ return_dict=None,
357
+ ):
358
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
359
+
360
+ # Get all layer outputs if training with labels (for early exit loss)
361
+ return_all = self.training and labels is not None
362
+
363
+ hidden_states, past_key_values_output, all_layer_hidden_states = self.model(
364
+ input_ids=input_ids,
365
+ attention_mask=attention_mask,
366
+ position_ids=position_ids,
367
+ past_key_values=past_key_values,
368
+ use_cache=use_cache,
369
+ output_hidden_states=output_hidden_states,
370
+ return_all_layer_outputs=return_all,
371
+ )
372
+
373
+ logits = self.lm_head(hidden_states)
374
+ logits = logits.float()
375
+
376
+ loss = None
377
+ if labels is not None:
378
+ # Main loss (final layer)
379
+ shift_logits = logits[..., :-1, :].contiguous()
380
+ shift_labels = labels[..., 1:].contiguous()
381
+ loss_fct = nn.CrossEntropyLoss()
382
+ main_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
383
+
384
+ # Early exit loss (all intermediate layers)
385
+ if all_layer_hidden_states is not None and len(all_layer_hidden_states) > 0:
386
+ early_exit_loss = self.compute_early_exit_loss(all_layer_hidden_states[:-1], labels) # Exclude final layer
387
+
388
+ # Combine: main + weighted early exit
389
+ loss = main_loss + self.config.early_exit_loss_weight * early_exit_loss
390
+ else:
391
+ loss = main_loss
392
+
393
+ if not return_dict:
394
+ output = (logits,) + (past_key_values_output,)
395
+ return (loss,) + output if loss is not None else output
396
+
397
+ return CausalLMOutputWithPast(
398
+ loss=loss,
399
+ logits=logits,
400
+ past_key_values=past_key_values_output,
401
+ hidden_states=None,
402
+ attentions=None,
403
+ )
404
+
405
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
406
+ if past_key_values is not None:
407
+ past_length = past_key_values[0][0].shape[2]
408
+ if input_ids.shape[1] > past_length:
409
+ remove_prefix_length = past_length
410
+ else:
411
+ remove_prefix_length = input_ids.shape[1] - 1
412
+ input_ids = input_ids[:, remove_prefix_length:]
413
+
414
+ position_ids = kwargs.get("position_ids", None)
415
+ if attention_mask is not None and position_ids is None:
416
+ position_ids = attention_mask.long().cumsum(-1) - 1
417
+ position_ids.masked_fill_(attention_mask == 0, 1)
418
+ if past_key_values:
419
+ position_ids = position_ids[:, -input_ids.shape[1] :]
420
+
421
+ if inputs_embeds is not None and past_key_values is None:
422
+ model_inputs = {"inputs_embeds": inputs_embeds}
423
+ else:
424
+ model_inputs = {"input_ids": input_ids}
425
+
426
+ model_inputs.update({
427
+ "position_ids": position_ids,
428
+ "past_key_values": past_key_values,
429
+ "use_cache": kwargs.get("use_cache"),
430
+ "attention_mask": attention_mask,
431
+ })
432
+ return model_inputs
433
+
434
+ @staticmethod
435
+ def _reorder_cache(past_key_values, beam_idx):
436
+ reordered_past = ()
437
+ for layer_past in past_key_values:
438
+ reordered_past += (
439
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
440
+ )
441
+ return reordered_past
442
+
443
+ def set_exit_layer(self, exit_layer):
444
+ """Set early exit layer for inference."""
445
+ self.config.inference_exit_layer = exit_layer
446
+ self.model.config.inference_exit_layer = exit_layer
447
+
448
+
449
+ BitSkipV1EarlyExitConfig.register_for_auto_class()
450
+ BitSkipV1ForCausalLMWithEarlyExit.register_for_auto_class("AutoModelForCausalLM")
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "model_max_length": 1024,
17
+ "pad_token": "<|endoftext|>",
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff