yizheapple committed on
Commit 10d76e2 · verified · 1 Parent(s): acccd4b
README.md CHANGED
@@ -1,3 +1,56 @@
- ---
- license: apple-amlr
- ---
+ ---
+ license: unknown
+ base_model:
+ - mistralai/Mistral-7B-Instruct-v0.2
+ tags:
+ - rag
+ - compression
+ - retrieval
+ - end-to-end
+ - generation
+ ---
+
+ # CLaRa-7B-E2E (Compression-16 & 128)
+
+ **CLaRa-7B-E2E** is our fully end-to-end unified RAG model, jointly optimizing retrieval and generation under 16× and 128× document compression.
+
+ **Training recipe:** End-to-end fine-tuning with differentiable top-k retrieval and a unified language-modeling objective.
+ **Benchmarks:** Strong retrieval-augmented QA performance under aggressive compression.
+
+ ---
+
+ ## More details and usage examples
+
+ Paper: [CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning](https://arxiv.org/abs/2511.18659)
+ GitHub: https://github.com/apple/ml-clara
+
+ ---
+
+ ## Example Usage (End-to-End Inference)
+
+ ```python
+ from transformers import AutoModel
+
+ # Path to a local checkout of the compression-16 checkpoint; adjust to your setup.
+ unirag = AutoModel.from_pretrained(
+     "/mnt/ceph_rbd/model/CLaRa-7B-E2E/compression-16",
+     trust_remote_code=True
+ ).to("cuda")
+
+ # Example documents and question: one pool of 20 candidate documents for the question.
+ documents = [[
+     "Weldenia is a monotypic genus of flowering plant in the family Commelinaceae...",
+ ] * 20]
+
+ questions = [
+     "Which genus of plant grows originally in Mexico and Guatemala, Phylica or Weldenia?"
+ ]
+
+ # End-to-end usage (retrieval + generation).
+ # The effective top-k is controlled by `generation_top_k` in config.json.
+ answers, topk_idx = unirag.generate_from_questions(
+     questions=questions,
+     documents=documents,
+     max_new_tokens=64
+ )
+
+ print("Generated answer:", answers)
+ ```
compression-128/adapters.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf779cf29ab86f4a0592370e6b66664b7448c53c59b86de811c4c2849867b230
+ size 252096669
compression-128/chat_template.jinja ADDED
@@ -0,0 +1,24 @@
+ {%- if messages[0]['role'] == 'system' %}
+     {%- set system_message = messages[0]['content'] %}
+     {%- set loop_messages = messages[1:] %}
+ {%- else %}
+     {%- set loop_messages = messages %}
+ {%- endif %}
+
+ {{- bos_token }}
+ {%- for message in loop_messages %}
+     {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
+         {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
+     {%- endif %}
+     {%- if message['role'] == 'user' %}
+         {%- if loop.first and system_message is defined %}
+             {{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}
+         {%- else %}
+             {{- ' [INST] ' + message['content'] + ' [/INST]' }}
+         {%- endif %}
+     {%- elif message['role'] == 'assistant' %}
+         {{- ' ' + message['content'] + eos_token}}
+     {%- else %}
+         {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
+     {%- endif %}
+ {%- endfor %}
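For reference, a minimal sketch of rendering this chat template through the tokenizer (assuming a tokenizer that carries the template above, e.g. `mistralai/Mistral-7B-Instruct-v0.2`):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is Weldenia?"},
]
# The system message is folded into the first [INST] block, per the template above.
print(tok.apply_chat_template(messages, tokenize=False))
```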
compression-128/config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "ae_mode": "token",
+   "attn_implementation": null,
+   "auto_map": {
+     "AutoConfig": "modeling_clara.CLaRaConfig",
+     "AutoModel": "modeling_clara.CLaRa"
+   },
+   "compr_base_model_name": "/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2",
+   "compr_every_n_layer": null,
+   "compr_linear_type": "concat",
+   "compr_mlp_hidden_dim": 8096,
+   "compr_model_name": null,
+   "compr_n_layers": 5,
+   "compr_rate": 128,
+   "compr_rms_norm": false,
+   "compr_use_mlp": false,
+   "decoder_model_name": "/mnt/conductor_data/data/hf_models/Mistral-7B-Instruct-v0.2",
+   "device_map": null,
+   "different_mem_tokens": true,
+   "doc_max_length": 256,
+   "generation_top_k": 5,
+   "kbtc_training": false,
+   "load_adapters": true,
+   "load_pretrained_checkpoint": false,
+   "lora": true,
+   "lora_compressor": false,
+   "lora_r": 16,
+   "lora_r_compressor": 16,
+   "max_new_tokens": 128,
+   "model_type": "CLaRa",
+   "optimize_mem_tokens": true,
+   "pad_token_id": 2,
+   "pure_inference": false,
+   "quantization": "no",
+   "sep": true,
+   "stage2_retrieval_top_n": 1,
+   "training_form": "both_separately",
+   "training_stage": "stage2",
+   "transformers_version": "4.53.3"
+ }
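Fields such as `generation_top_k` and `max_new_tokens` can be overridden at load time: `CLaRa.from_pretrained` copies any matching keyword arguments onto the config before building the model (see `modeling_clara.py`). A minimal sketch, assuming a hypothetical local checkout of this folder:

```python
from transformers import AutoModel

unirag = AutoModel.from_pretrained(
    "/path/to/CLaRa-7B-E2E/compression-128",  # hypothetical path
    generation_top_k=3,  # overrides the value in config.json
    trust_remote_code=True,
)
```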
compression-128/decoder_first_last_layers.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3029ac143f3cc3a23462daebe83b2eddc4bba5117530a7e93ea28fb59ac44e06
+ size 524372021
compression-128/modeling_clara.py ADDED
@@ -0,0 +1,1712 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ import warnings
7
+ import os
8
+ import torch
9
+ import gc
10
+ import time
11
+ import json
12
+ import copy
13
+ import random
14
+ import requests
15
+ import re
16
+
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ from torch.nn.functional import gelu
20
+ from jinja2.exceptions import TemplateError
21
+ from peft import LoraConfig
22
+ from transformers import (
23
+ AutoModelForCausalLM,
24
+ AutoTokenizer,
25
+ BitsAndBytesConfig,
26
+ PreTrainedModel,
27
+ PretrainedConfig,
28
+ StoppingCriteria,
29
+ StoppingCriteriaList
30
+ )
31
+ from huggingface_hub import hf_hub_download
32
+ from typing import List, Dict, Any, Optional, Tuple
33
+
34
+ # Environment setup
35
+ torch.set_printoptions(threshold=float("inf"))
36
+ os.environ["NCCL_TIMEOUT"] = "5400"
37
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
38
+
39
+ # Constants
40
+ IGNORE_INDEX = -100
41
+ PARAPHRASE_INSTRUCTIONS = [
42
+ 'Background: {docs} means the same as',
43
+ "Background: {docs} Can you put the above sentences in your own terms?",
44
+ "Background: {docs} Please provide a reinterpretation of the preceding background text.",
45
+ "These two expressions are equivalent in essence:\n(1) {docs}\n(2)",
46
+ "Background: {docs} is a paraphrase of what?",
47
+ "Background: {docs} Could you give me a different version of the background sentences above?",
48
+ "In other words, background: {docs} is just another way of saying:",
49
+ "You're getting across the same point whether you say background: {docs} or",
50
+ "Background: {docs} After unpacking the ideas in the background information above, we got:",
51
+ "Background: {docs} Please offer a restatement of the background sentences I've just read.",
52
+ "Background: {docs}, which also means:",
53
+ "Strip away the mystery, and you'll find background: {docs} is simply another rendition of:",
54
+ "The essence of background: {docs} is captured again in the following statement:",
55
+ ]
56
+
57
+
58
+ class StopOnCriteria(StoppingCriteria):
59
+ """Custom stopping criteria for generation."""
60
+
61
+ def __init__(self, tokenizer, stop_strings: List[str] = None, stop_token_ids: List[int] = None):
62
+ self.tokenizer = tokenizer
63
+ self.stop_strings = stop_strings or []
64
+ self.stop_token_ids = stop_token_ids or []
65
+ self.reason = None
66
+
67
+ def __call__(self, input_ids, scores, **kwargs):
68
+ # Check if last token is in stop_token_ids
69
+ last_token = input_ids[0, -1].item()
70
+ if last_token in self.stop_token_ids:
71
+ self.reason = f"stop_token_{last_token}"
72
+ return True
73
+
74
+ # Check if any stop_strings appear in generated text
75
+ text = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
76
+ for stop_str in self.stop_strings:
77
+ if stop_str in text:
78
+ self.reason = f"stop_string_{stop_str}"
79
+ return True
80
+
81
+ return False
82
+
83
+
84
+ class LlamaRMSNorm(nn.Module):
85
+ """Llama-style RMS normalization layer."""
86
+
87
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
88
+ super().__init__()
89
+ self.weight = nn.Parameter(torch.ones(hidden_size))
90
+ self.variance_epsilon = eps
91
+
92
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
93
+ input_dtype = hidden_states.dtype
94
+ hidden_states = hidden_states.to(torch.float32)
95
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
96
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
97
+ return self.weight * hidden_states.to(input_dtype)
98
+
99
+
100
+ class Converter(nn.Module):
101
+ """Converter module for dimension transformation."""
102
+
103
+ def __init__(self, input_dim: int, output_dim: int):
104
+ super().__init__()
105
+ self.input_dim = input_dim
106
+ self.output_dim = output_dim
107
+
108
+ self.rms_norm = LlamaRMSNorm(input_dim)
109
+ self.dense_in = nn.Linear(input_dim, output_dim)
110
+ self.dense_out = nn.Linear(output_dim, output_dim)
111
+
112
+ self._print_trainable_parameters()
113
+
114
+ def _print_trainable_parameters(self):
115
+ """Print parameter statistics."""
116
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
117
+ total_params = sum(p.numel() for p in self.parameters())
118
+ print(f"Converter trainable parameters: {trainable_params}, Total parameters: {total_params}")
119
+
120
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
121
+ embeddings = self.rms_norm(embeddings)
122
+ x = self.dense_in(embeddings)
123
+ x = self.dense_out(gelu(x))
124
+ return x.to(torch.float32)
125
+
126
+
127
+ class CLaRaConfig(PretrainedConfig):
128
+ """Configuration class for CLaRa model."""
129
+
130
+ model_type = "CLaRa"
131
+
132
+ def __init__(self,
133
+ decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
134
+ doc_max_length: int = 128,
135
+ quantization: str = 'no',
136
+ sep: bool = False,
137
+ compr_model_name: str = "google-bert/bert-base-uncased",
138
+ compr_rate: int = 64,
139
+ compr_n_layers: int = None,
140
+ compr_every_n_layer: int = None,
141
+ compr_base_model_name: str = '/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2',
142
+ compr_rms_norm: bool = False,
143
+ compr_mlp_hidden_dim: int = 8096,
144
+ compr_use_mlp: bool = True,
145
+ compr_linear_type: str = "concat",
146
+ lora: bool = False,
147
+ lora_compressor: bool = False,
148
+ training_form: str = "both",
149
+ training_stage: str = "stage1",
150
+ generation_top_k: int = 1,
151
+ lora_r: int = 16,
152
+ lora_r_compressor: int = None,
153
+ load_adapters: bool = True,
154
+ kbtc_training: bool = False,
155
+ optimize_mem_tokens: bool = False,
156
+ different_mem_tokens: bool = False,
157
+ attn_implementation: str = None,
158
+ _attn_implementation_autoset: bool = True,
159
+ ae_mode: str = "token",
160
+ max_new_tokens: int = 128,
161
+ stage2_retrieval_top_n: int = 1,
162
+ load_pretrained_checkpoint: bool = False,
163
+ device_map=None,
164
+ auto_map: dict = {
165
+ "AutoConfig": "modeling_clara.CLaRaConfig",
166
+ "AutoModel": "modeling_clara.CLaRa"
167
+ },
168
+ **kwargs):
169
+ super().__init__(**kwargs)
170
+
171
+ self.decoder_model_name = decoder_model_name
172
+ self.doc_max_length = doc_max_length
173
+ self.quantization = quantization
174
+ self.sep = sep
175
+
176
+ self.compr_model_name = compr_model_name
177
+ self.compr_rate = compr_rate
178
+ self.compr_use_mlp = compr_use_mlp
179
+ self.compr_mlp_hidden_dim = compr_mlp_hidden_dim
180
+ self.compr_n_layers = compr_n_layers
181
+ self.compr_every_n_layer = compr_every_n_layer
182
+ self.compr_base_model_name = compr_base_model_name
183
+ self.compr_rms_norm = compr_rms_norm
184
+ self.compr_linear_type = compr_linear_type
185
+
186
+ self.lora = lora
187
+ self.lora_compressor = lora_compressor
188
+ self.training_form = training_form
189
+ self.lora_r = lora_r
190
+ self.lora_r_compressor = lora_r_compressor or lora_r
191
+ self.load_adapters = load_adapters
192
+ self.optimize_mem_tokens = optimize_mem_tokens
193
+ self.different_mem_tokens = different_mem_tokens
194
+ self.kbtc_training = kbtc_training
195
+ self.training_stage = training_stage
196
+ self.device_map = device_map
197
+ self.attn_implementation = attn_implementation
198
+ self._attn_implementation_autoset = _attn_implementation_autoset
199
+ self.ae_mode = ae_mode
200
+ self.max_new_tokens = max_new_tokens
201
+ self.auto_map = auto_map
202
+ self.load_pretrained_checkpoint = load_pretrained_checkpoint
203
+
204
+ self.generation_top_k = generation_top_k
205
+ self.stage2_retrieval_top_n = stage2_retrieval_top_n
206
+
207
+ if training_form == 'compressor':
208
+ assert compr_model_name is not None and not self.lora
209
+
210
+
211
+ # Utility functions
212
+ def remote_generate(docs: List[str], questions: List[str], api_url: str) -> List[str]:
213
+ """Generate responses using remote API."""
214
+ response = requests.post(
215
+ f"{api_url}/generate",
216
+ json={"docs": docs, "questions": questions}
217
+ )
218
+ return response.json()["texts"]
219
+
220
+
221
+ def add_memory_tokens_to_inputs(input_ids: torch.Tensor,
222
+ attention_mask: torch.Tensor,
223
+ n_mem_tokens: int,
224
+ tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
225
+ """Add memory tokens to input sequences."""
226
+ assert len(tokenizer.mem_tokens) == n_mem_tokens
227
+
228
+ mem_tokens = torch.stack([tokenizer.mem_token_ids_pt] * input_ids.size(0), 0)
229
+ assert len(mem_tokens) == input_ids.size(0)
230
+ assert len(mem_tokens[0]) == n_mem_tokens
231
+
232
+ input_ids = torch.cat([input_ids, mem_tokens], dim=1)
233
+ attention_mask = torch.cat([attention_mask, torch.ones(input_ids.size(0), n_mem_tokens)], dim=1)
234
+
235
+ return input_ids, attention_mask
236
+
237
+
238
+ def build_pos_mask(pos_index: List[List[int]], N: int, device: torch.device) -> torch.Tensor:
239
+ """Build positive mask for retrieval training."""
240
+ if isinstance(pos_index, (list, tuple)):
241
+ B = len(pos_index)
242
+ mask = torch.zeros(B, N, dtype=torch.bool, device=device)
243
+ for b, idxs in enumerate(pos_index):
244
+ if len(idxs) > 0:
245
+ mask[b, torch.as_tensor(idxs, device=device, dtype=torch.long)] = True
246
+ return mask
247
+ else: # tensor [B, M]
248
+ B, M = pos_index.shape
249
+ mask = torch.zeros(B, N, dtype=torch.bool, device=device)
250
+ for m in range(M):
251
+ col = pos_index[:, m]
252
+ v = col >= 0
253
+ if v.any():
254
+ mask[v, col[v]] = True
255
+ return mask
256
+
257
+
258
+ def differentiable_topk_top_1(logits: torch.Tensor, k: int, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
259
+ """Implements differentiable top-1 selection using Gumbel-Softmax."""
260
+ y = logits / temperature
261
+ y_soft = F.softmax(y, dim=-1).float()
262
+
263
+ # Hard one-hot version
264
+ index = y_soft.argmax(dim=-1, keepdim=True)
265
+ y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
266
+
267
+ # Straight-through estimator
268
+ z = y_hard + y_soft - y_soft.detach()
269
+ z = z.unsqueeze(1).to(logits.dtype)
270
+
271
+ return z, index
272
+
273
+
274
+ def differentiable_topk(logits: torch.Tensor, k: int, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ """Differentiable top-k selection."""
276
+ B, N = logits.shape
277
+ perturbed = logits / max(temperature, 1e-6)
278
+
279
+ # Hard top-k indices
280
+ topk_vals, topk_idx = perturbed.topk(k, dim=-1)
281
+ K_hard = torch.zeros(B, k, N, device=logits.device, dtype=logits.dtype)
282
+ K_hard.scatter_(2, topk_idx.unsqueeze(-1), 1.0)
283
+
284
+ # Soft distributions for each slot
285
+ K_soft = torch.zeros_like(K_hard)
286
+ taken = torch.zeros(B, N, device=logits.device, dtype=logits.dtype)
287
+
288
+ for j in range(k):
289
+ mask = (1.0 - taken.detach())
290
+ masked = perturbed + (mask + 1e-8).log()
291
+ pj = F.softmax(masked, dim=-1).float()
292
+ K_soft[:, j, :] = pj
293
+ taken = torch.clamp(taken + K_hard[:, j, :], max=1.0)
294
+
295
+ # Straight-through estimator
296
+ W = K_hard + (K_soft - K_soft.detach())
297
+ return W, topk_idx
298
+
299
+
300
+ class CLaRa(PreTrainedModel):
301
+ """CLaRa: Unified Retrieval-Augmented Generation Model."""
302
+
303
+ config_class = CLaRaConfig
304
+
305
+ def __init__(self, cfg: CLaRaConfig):
306
+ super().__init__(cfg)
307
+ self.decoder_model_name = cfg.decoder_model_name
308
+ self.decoder = self._create_decoder(cfg)
309
+ self.doc_max_length = cfg.doc_max_length
310
+
311
+ print(f'Base decoder parameters: {self.decoder.num_parameters()}')
312
+
313
+ # Model configuration
314
+ self.compr_model_name = cfg.compr_model_name
315
+ self.training_form = cfg.training_form
316
+ self.lora = cfg.lora
317
+ self.adapter_keys = []
318
+ self.compr = None
319
+
320
+ # Initialize LoRA adapters if needed
321
+ if cfg.lora and not getattr(cfg, 'pure_inference', False):
322
+ self._setup_lora_adapters(cfg)
323
+
324
+ print(f'Model adapter keys: {self.adapter_keys}')
325
+
326
+ # Initialize tokenizer and resize embeddings
327
+ self.decoder_tokenizer = self._create_decoder_tokenizer(cfg)
328
+ self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
329
+ self._configure_generation_config()
330
+
331
+ # Model parameters
332
+ self.generation_top_k = cfg.generation_top_k
333
+ self.training_stage = cfg.training_stage
334
+ self.stage2_retrieval_top_n = cfg.stage2_retrieval_top_n
335
+ self.sep = cfg.sep
336
+ self.compr_rate = cfg.compr_rate
337
+ self.local_rank = os.getenv('LOCAL_RANK', '0')
338
+
339
+ self.n_mem_tokens = self.doc_max_length // self.compr_rate
340
+ self.hidden_size = self.decoder.config.hidden_size
341
+
342
+ # Setup adapters and memory token optimization
343
+ if self.lora:
344
+ self._setup_adapter_training()
345
+ else:
346
+ print(f'Total trainable parameters: {self.num_parameters(only_trainable=True)}')
347
+
348
+ self._prepare_mem_tokens_optimization()
349
+
350
+ # Retrieval configuration
351
+ self.url_retrieval = "http://127.0.0.1:5004/queries"
352
+
353
+ def _create_decoder(self, cfg: CLaRaConfig) -> AutoModelForCausalLM:
354
+ """Create and configure the decoder model."""
355
+ if not torch.cuda.is_available():
356
+ return AutoModelForCausalLM.from_pretrained(
357
+ cfg.decoder_model_name,
358
+ torch_dtype=torch.bfloat16,
359
+ resume_download=True,
360
+ trust_remote_code=True,
361
+ device_map=cfg.device_map
362
+ )
363
+
364
+ if cfg.quantization == "no":
365
+ return AutoModelForCausalLM.from_pretrained(
366
+ cfg.decoder_model_name,
367
+ torch_dtype=torch.bfloat16,
368
+ attn_implementation=cfg.attn_implementation,
369
+ device_map=cfg.device_map
370
+ )
371
+ elif cfg.quantization == "int4":
372
+ quant_config = BitsAndBytesConfig(
373
+ load_in_4bit=True,
374
+ bnb_4bit_quant_type='nf4',
375
+ bnb_4bit_compute_dtype='bfloat16',
376
+ )
377
+ return AutoModelForCausalLM.from_pretrained(
378
+ cfg.decoder_model_name,
379
+ quantization_config=quant_config,
380
+ attn_implementation=cfg.attn_implementation,
381
+ torch_dtype=torch.bfloat16,
382
+ resume_download=True,
383
+ trust_remote_code=True,
384
+ device_map=cfg.device_map
385
+ )
386
+ elif cfg.quantization == "int8":
387
+ quant_config = BitsAndBytesConfig(
388
+ load_in_8bit=True,
389
+ llm_int8_enable_fp32_cpu_offload=True,
390
+ bnb_4bit_compute_dtype='bfloat16',
391
+ )
392
+ return AutoModelForCausalLM.from_pretrained(
393
+ cfg.decoder_model_name,
394
+ quantization_config=quant_config,
395
+ attn_implementation=cfg.attn_implementation,
396
+ torch_dtype=torch.bfloat16,
397
+ resume_download=True,
398
+ trust_remote_code=True,
399
+ device_map=cfg.device_map
400
+ )
401
+ else:
402
+ raise NotImplementedError(f"Quantization {cfg.quantization} not supported")
403
+
404
+ def _setup_lora_adapters(self, cfg: CLaRaConfig):
405
+ """Setup LoRA adapters based on training stage."""
406
+ peft_config = self._get_peft_config(lora_r=cfg.lora_r)
407
+
408
+ if cfg.training_stage == "stage1" and cfg.load_adapters:
409
+ print('Loading encoder and decoder adapter for stage1')
410
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
411
+ self.adapter_keys.append('decoder_adapter')
412
+ self.decoder.add_adapter(peft_config, 'encoder_adapter')
413
+ self.adapter_keys.append('encoder_adapter')
414
+ elif cfg.training_stage == "stage2" and cfg.load_adapters:
415
+ if 'decoder_adapter' not in self.adapter_keys:
416
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
417
+ self.adapter_keys.append('decoder_adapter')
418
+ if 'query_reasoner_adapter' not in self.adapter_keys:
419
+ self.decoder.add_adapter(peft_config, 'query_reasoner_adapter')
420
+ self.adapter_keys.append('query_reasoner_adapter')
421
+ elif cfg.training_stage == 'stage1_2':
422
+ if not cfg.load_adapters:
423
+ print('Loading decoder adapter for stage1_2')
424
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
425
+ self.adapter_keys.append('decoder_adapter')
426
+ elif cfg.load_adapters:
427
+ print('Loading encoder and decoder adapter for stage1_2')
428
+ self.decoder.add_adapter(peft_config, 'encoder_adapter')
429
+ self.adapter_keys.append('encoder_adapter')
430
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
431
+ self.adapter_keys.append('decoder_adapter')
432
+ elif cfg.training_stage == 'stage2_reasoning':
433
+ if not cfg.load_adapters:
434
+ print('Loading decoder adapter for stage2_reasoning')
435
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
436
+ self.adapter_keys.append('decoder_adapter')
437
+
438
+ def _setup_adapter_training(self):
439
+ """Setup adapters for training."""
440
+ for adapter_key in self.adapter_keys:
441
+ self.decoder.set_adapter(adapter_key)
442
+ print(f'Adapter {adapter_key} trainable parameters: {self.num_parameters(only_trainable=True)}')
443
+ self._set_all_adapters()
444
+
445
+ def _configure_generation_config(self):
446
+ """Configure generation parameters."""
447
+ self.decoder.generation_config.top_p = None
448
+ self.decoder.generation_config.temperature = None
449
+ self.decoder.generation_config.pad_token_id = self.decoder_tokenizer.pad_token_id
450
+
451
+ @staticmethod
452
+ def _create_decoder_tokenizer(cfg: CLaRaConfig) -> AutoTokenizer:
453
+ """Create and configure the decoder tokenizer."""
454
+ tokenizer = AutoTokenizer.from_pretrained(
455
+ cfg.decoder_model_name,
456
+ use_fast=True,
457
+ padding_side='left'
458
+ )
459
+
460
+ # Define special tokens
461
+ n_mem_tokens = cfg.doc_max_length // cfg.compr_rate
462
+ existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", [])
463
+
464
+ if cfg.different_mem_tokens:
465
+ mem_tokens = [f'<MEM{i}>' for i in range(n_mem_tokens)]
466
+ tokenizer.add_special_tokens({
467
+ 'additional_special_tokens': existing_special_tokens + mem_tokens + ['<AE>', '<ENC>', '<SEP>']
468
+ })
469
+ tokenizer.mem_tokens = mem_tokens
470
+ else:
471
+ tokenizer.add_special_tokens({
472
+ 'additional_special_tokens': existing_special_tokens + ['<MEM>', '<AE>', '<ENC>', '<SEP>']
473
+ })
474
+ tokenizer.mem_tokens = ['<MEM>'] * n_mem_tokens
475
+
476
+ tokenizer.mem_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokenizer.mem_tokens]
477
+ tokenizer.mem_token_ids_pt = torch.LongTensor(tokenizer.mem_token_ids)
478
+
479
+ # Additional special tokens
480
+ tokenizer.ae_token = '<AE>'
481
+ tokenizer.ae_token_id = tokenizer.convert_tokens_to_ids('<AE>')
482
+ tokenizer.enc_token = '<ENC>'
483
+ tokenizer.sep_token = '<SEP>'
484
+ tokenizer.sep_token_id = tokenizer.convert_tokens_to_ids('<SEP>')
485
+
486
+ # Handle model-specific tokens
487
+ if tokenizer.bos_token is None and 'qwen' in cfg.decoder_model_name.lower():
488
+ tokenizer.bos_token = tokenizer.special_tokens_map['additional_special_tokens'][0]
489
+ tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
490
+
491
+ if tokenizer.eos_token is None and "qwen" in cfg.decoder_model_name.lower():
492
+ tokenizer.eos_token = tokenizer.special_tokens_map['additional_special_tokens'][1]
493
+ tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
494
+
495
+ # KBTC training tokens
496
+ if cfg.kbtc_training:
497
+ tokenizer.add_special_tokens({'additional_special_tokens': ['<KBTC>']})
498
+ tokenizer.kbtc_token = '<KBTC>'
499
+ tokenizer.kbtc_token_id = tokenizer.convert_tokens_to_ids('<KBTC>')
500
+
501
+ # Set pad token
502
+ if tokenizer.pad_token_id is None:
503
+ tokenizer.pad_token_id = tokenizer.bos_token_id
504
+
505
+ print(f'Memory token count: {n_mem_tokens}')
506
+ return tokenizer
507
+
508
+ def _get_peft_config(self, lora_r: int) -> LoraConfig:
509
+ """Build the PEFT configuration."""
510
+ return LoraConfig(
511
+ task_type="CAUSAL_LM",
512
+ r=lora_r,
513
+ lora_alpha=2*lora_r,
514
+ target_modules='all-linear',
515
+ lora_dropout=0.1
516
+ )
517
+
518
+ def _prepare_mem_tokens_optimization(self):
519
+ """Setup memory token optimization if enabled."""
520
+ if self.config.optimize_mem_tokens and self.compr is None:
521
+ # Enable gradients for input embeddings
522
+ self.decoder.get_input_embeddings().weight.requires_grad = True
523
+
524
+ # Apply hook to zero gradients except for memory tokens
525
+ def hook(grad):
526
+ mask = torch.zeros_like(grad)
527
+ mask[self.decoder_tokenizer.mem_token_ids] = 1.0
528
+ return grad * mask
529
+
530
+ self.decoder.get_input_embeddings().weight.register_hook(hook)
531
+
532
+ def _set_all_adapters(self):
533
+ """Activate all adapters for training."""
534
+ if len(self.adapter_keys) > 0:
535
+ self.decoder.set_adapter(self.adapter_keys)
536
+
537
+ # Core compression and generation methods
538
+ def compress(self, enc_input_ids: torch.Tensor, enc_attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
539
+ """Compress input documents."""
540
+ if self.compr:
541
+ return self.compr(enc_input_ids, enc_attention_mask)
542
+ else:
543
+ return self._compr_decoder(enc_input_ids, enc_attention_mask)
544
+
545
+ def _compr_decoder(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
546
+ """Use decoder as compressor."""
547
+ assert input_ids.size() == attention_mask.size()
548
+
549
+ if 'encoder_adapter' in self.adapter_keys:
550
+ self.decoder.set_adapter('encoder_adapter')
551
+ else:
552
+ raise ValueError(f"encoder_adapter not in adapter_keys: {self.adapter_keys}")
553
+
554
+ # Get embeddings from decoder
555
+ emb = self.decoder(
556
+ input_ids=input_ids,
557
+ attention_mask=attention_mask,
558
+ output_hidden_states=True
559
+ ).hidden_states[-1]
560
+
561
+ # Create mask for memory tokens
562
+ mask = torch.isin(
563
+ input_ids,
564
+ self.decoder_tokenizer.mem_token_ids_pt.to(input_ids.device)
565
+ )
566
+
567
+ # Calculate MSE loss between memory and non-memory regions
568
+ attn = attention_mask.bool()
569
+ mem_mask = mask & attn
570
+ non_mem_mask = (~mask) & attn
571
+
572
+ mem_len = mem_mask.sum(dim=1)
573
+ non_mem_len = non_mem_mask.sum(dim=1)
574
+
575
+ if (mem_len == 0).any():
576
+ raise ValueError("Some samples have no memory tokens")
577
+ if (non_mem_len == 0).any():
578
+ raise ValueError("Some samples have no non-memory tokens")
579
+
580
+ mem_sum = (emb * mem_mask.unsqueeze(-1)).sum(dim=1)
581
+ non_mem_sum = (emb * non_mem_mask.unsqueeze(-1)).sum(dim=1)
582
+
583
+ mem_mean = mem_sum / mem_len.unsqueeze(-1)
584
+ non_mem_mean = non_mem_sum / non_mem_len.unsqueeze(-1)
585
+
586
+ mse_loss = F.mse_loss(non_mem_mean, mem_mean, reduction='mean')
587
+
588
+ return emb[mask].reshape(emb.size(0), -1, emb.size(-1)), mse_loss
589
+
590
+ def _compr_query_reasoner_stage2(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
591
+ """Query reasoning compression for stage 2."""
592
+ assert input_ids.size() == attention_mask.size()
593
+
594
+ if 'query_reasoner_adapter' in self.adapter_keys:
595
+ self.decoder.set_adapter('query_reasoner_adapter')
596
+ else:
597
+ raise ValueError(f"query_reasoner_adapter not in adapter_keys: {self.adapter_keys}")
598
+
599
+ emb = self.decoder(
600
+ input_ids=input_ids,
601
+ attention_mask=attention_mask,
602
+ output_hidden_states=True
603
+ ).hidden_states[-1]
604
+
605
+ mask = torch.isin(
606
+ input_ids,
607
+ self.decoder_tokenizer.mem_token_ids_pt.to(input_ids.device)
608
+ )
609
+
610
+ return emb[mask].reshape(emb.size(0), -1)
611
+
612
+ # Generation methods
613
+ def generate_from_questions(self,
614
+ questions: List[str],
615
+ max_new_tokens: int = 128,
616
+ temperature: float = 0.5,
617
+ documents: List[List[str]] = None,
618
+ stage2_mips: bool = False,
619
+ stage2_retrieval_top_n: int = None,
620
+ time_count: bool = False) -> Tuple[List[str], torch.Tensor]:
621
+ """Generate answers from questions using query reasoning."""
622
+ if "query_reasoner_adapter" not in self.adapter_keys:
623
+ raise ValueError("Query reasoner adapter not found")
624
+
625
+ self.eval()
626
+
627
+ with torch.no_grad():
628
+ # Encode questions
629
+ self.decoder.set_adapter('query_reasoner_adapter')
630
+ flat_questions = [q for q in questions]
631
+
632
+ if time_count:
633
+ start_time = time.time()
634
+
635
+ q_tok = self._prepare_encoder_inputs(flat_questions, max_length=self.doc_max_length)
636
+ query_reps = self._compr_query_reasoner_stage2(
637
+ q_tok["input_ids"].to(self.decoder.device),
638
+ q_tok["attention_mask"].to(self.decoder.device)
639
+ )
640
+
641
+ # Document retrieval and selection
642
+ if stage2_mips:
643
+ retrieved_doc_embeddings = self._retrieve_embeddings(
644
+ query_reps, stage2_retrieval_top_n=stage2_retrieval_top_n
645
+ )
646
+ scores = torch.bmm(
647
+ query_reps.unsqueeze(1),
648
+ retrieved_doc_embeddings.transpose(1, 2)
649
+ ).squeeze(1)
650
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.5)
651
+ selected_doc_embeddings = torch.einsum('bkn,bnd->bkd', z, retrieved_doc_embeddings)
652
+ selected_doc_embeddings = selected_doc_embeddings.view(
653
+ selected_doc_embeddings.size(0) * selected_doc_embeddings.size(1),
654
+ -1, self.hidden_size
655
+ )
656
+ else:
657
+ # Use provided documents
658
+ flat_documents = sum(documents, [])
659
+
660
+ if time_count:
661
+ start_time1 = time.time()
662
+
663
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
664
+ device = self.decoder.device
665
+ enc_input_ids = input_encoder['input_ids'].to(device)
666
+ enc_attention_mask = input_encoder['attention_mask'].to(device)
667
+ retrieved_doc_embeddings, _ = self.compress(enc_input_ids, enc_attention_mask)
668
+
669
+ if time_count:
670
+ start_time2 = time.time()
671
+ compress_time = start_time2 - start_time1
672
+
673
+ B = len(questions)
674
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
675
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
676
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
677
+
678
+ if time_count:
679
+ start_time3 = time.time()
680
+
681
+ scores = torch.bmm(
682
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
683
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
684
+ ).squeeze(1)
685
+
686
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.02)
687
+ selected_doc_embeddings = torch.einsum('bkn,bnd->bkd', z.to(retrieved_doc_embeddings.dtype), retrieved_doc_embeddings)
688
+ selected_doc_embeddings = selected_doc_embeddings.view(
689
+ selected_doc_embeddings.size(0) * selected_doc_embeddings.size(1),
690
+ -1, self.hidden_size
691
+ )
692
+
693
+ if time_count:
694
+ start_time4 = time.time()
695
+ query_time = start_time4 - start_time3 + start_time1 - start_time
696
+
697
+ # Generate instructions and decode
698
+ if time_count:
699
+ start_time5 = time.time()
700
+
701
+ instructions = [
702
+ self._blend_prompt_and_selected_memory_tokens(query=q)[1]
703
+ for q in questions
704
+ ]
705
+
706
+ decoder_inputs = self.decoder_tokenizer(
707
+ instructions,
708
+ return_tensors='pt',
709
+ padding="longest",
710
+ add_special_tokens=False,
711
+ truncation=True,
712
+ max_length=1024,
713
+ )
714
+
715
+ dec_input_ids = decoder_inputs['input_ids'].to(self.decoder.device)
716
+ dec_attention_mask = decoder_inputs['attention_mask'].to(self.decoder.device)
717
+
718
+ # Replace memory token embeddings
719
+ inputs_embeds = self._replace_emb_stage2(selected_doc_embeddings, dec_input_ids)
720
+
721
+ # Switch to decoder adapter for generation
722
+ if 'decoder_adapter' in self.adapter_keys:
723
+ self.decoder.set_adapter('decoder_adapter')
724
+
725
+ # Generate answers
726
+ output_ids = self.decoder.generate(
727
+ inputs_embeds=inputs_embeds,
728
+ attention_mask=dec_attention_mask,
729
+ do_sample=False,
730
+ top_p=None,
731
+ temperature=None,
732
+ max_new_tokens=max_new_tokens,
733
+ pad_token_id=self.decoder_tokenizer.pad_token_id
734
+ )
735
+
736
+ if time_count:
737
+ start_time6 = time.time()
738
+ generate_time = start_time6 - start_time5
739
+
740
+ # Decode generated tokens
741
+ decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
742
+
743
+ if time_count:
744
+ return decoded, topk_idx, compress_time, query_time, generate_time, compress_time + query_time + generate_time
745
+ else:
746
+ return decoded, topk_idx
747
+ def generate_from_paraphrase(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
748
+ """
749
+ Generates answers from documents (via compression then decoding)
750
+ questions: list of string
751
+ documents: list of list of strings (they should all be of equal length: the nb of doc for each question)
752
+ """
753
+ self.generation_top_k = len(documents[0])
754
+ assert len(documents) == len(questions)
755
+ assert all([len(context) == len(documents[0]) for context in documents])
756
+ flat_documents = sum(documents, [])
757
+
758
+ model_input = {}
759
+
760
+ # Creating encoder inputs:
761
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
762
+ device = self.decoder.device
763
+ model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
764
+
765
+ # Creating decoder inputs
766
+ instr = [self._blend_prompt_and_memory_tokens(query="", stage = "stage1", paraphrase_loss = True) for q in questions]
767
+ inp_dec = self.decoder_tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=1024)
768
+ model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
769
+
770
+ # Generation
771
+ return self._generate(model_input, max_new_tokens=max_new_tokens)
772
+
773
+
774
+ def generate_from_text(self,
775
+ questions: List[str],
776
+ documents: List[List[str]],
777
+ max_new_tokens: int = 128) -> List[str]:
778
+ """Generate answers from documents via compression then decoding."""
779
+ self.generation_top_k = len(documents[0])
780
+ assert len(documents) == len(questions)
781
+ assert all(len(context) == len(documents[0]) for context in documents)
782
+
783
+ flat_documents = sum(documents, [])
784
+
785
+ # Create encoder inputs
786
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
787
+ device = self.decoder.device
788
+ enc_input_ids = input_encoder['input_ids'].to(device)
789
+ enc_attention_mask = input_encoder['attention_mask'].to(device)
790
+
791
+ # Create decoder inputs
792
+ instructions = [self._blend_prompt_and_memory_tokens(query=q, stage="stage1_2") for q in questions]
793
+ inp_dec = self.decoder_tokenizer(
794
+ instructions,
795
+ return_tensors='pt',
796
+ padding="longest",
797
+ add_special_tokens=False,
798
+ truncation=True,
799
+ max_length=1024
800
+ )
801
+ dec_input_ids = inp_dec['input_ids'].to(device)
802
+ dec_attention_mask = inp_dec['attention_mask'].to(device)
803
+
804
+ # Generate
805
+ return self._generate({
806
+ 'enc_input_ids': enc_input_ids,
807
+ 'enc_attention_mask': enc_attention_mask,
808
+ 'dec_input_ids': dec_input_ids,
809
+ 'dec_attention_mask': dec_attention_mask
810
+ }, max_new_tokens=max_new_tokens)
811
+
812
+ def generate_from_compressed_documents_and_questions(self,
813
+ questions: List[str],
814
+ compressed_documents: torch.Tensor,
815
+ max_new_tokens: int = 128) -> List[str]:
816
+ """Generate answers from compressed documents."""
817
+ self.generation_top_k = compressed_documents.size(0) // len(questions)
818
+ assert compressed_documents.size(0) % self.generation_top_k == 0
819
+
820
+ # Create decoder inputs
821
+ instructions = [self._blend_prompt_and_memory_tokens(query=q, stage="stage1_2") for q in questions]
822
+ inp_dec = self.decoder_tokenizer(
823
+ instructions,
824
+ return_tensors='pt',
825
+ padding="longest",
826
+ add_special_tokens=False,
827
+ truncation=True,
828
+ max_length=1024
829
+ )
830
+ device = self.decoder.device
831
+ dec_input_ids = inp_dec['input_ids'].to(device)
832
+ dec_attention_mask = inp_dec['attention_mask'].to(device)
833
+
834
+ # Create input decoder embeddings from prompt + compressed documents
835
+ inputs_embeds = self._replace_emb(compressed_documents, dec_input_ids)
836
+
837
+ # Activate decoder generator
838
+ if 'decoder_adapter' in self.adapter_keys:
839
+ self.decoder.set_adapter('decoder_adapter')
840
+
841
+ output_ids = self.decoder.generate(
842
+ inputs_embeds=inputs_embeds,
843
+ attention_mask=dec_attention_mask,
844
+ max_new_tokens=max_new_tokens
845
+ )
846
+
847
+ return self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
848
+
849
+ def compress_documents(self, documents: List[str]) -> torch.Tensor:
850
+ """Compress a list of documents."""
851
+ input_encoder = self._prepare_encoder_inputs(documents, max_length=self.doc_max_length)
852
+ enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
853
+ attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
854
+ return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
855
+
856
+ # Helper methods
857
+ def _prepare_encoder_inputs(self, texts: List[str], max_length: int, q_texts: List[str] = None) -> Dict[str, torch.Tensor]:
858
+ """Create inputs for the encoder."""
859
+ if q_texts is not None:
860
+ assert len(texts) == len(q_texts)
861
+
862
+ if self.compr is None:
863
+ return self._prepare_encoder_inputs_to_decoder(texts, max_length, q_texts)
864
+ else:
865
+ return self.compr.prepare_inputs(texts, max_length, q_texts)
866
+
867
+ def _prepare_encoder_inputs_to_decoder(self, texts: List[str], max_length: int, q_texts: List[str] = None) -> Dict[str, torch.Tensor]:
868
+ """Prepare encoder inputs when using decoder as compressor."""
869
+ if q_texts is not None:
870
+ texts_to_encode = [
871
+ self.decoder_tokenizer.enc_token +
872
+ self.decoder_tokenizer.bos_token +
873
+ '\nQuery:\n' + query +
874
+ 'Document:\n' + text +
875
+ self.decoder_tokenizer.eos_token
876
+ for text, query in zip(texts, q_texts)
877
+ ]
878
+ inp_enc = self.decoder_tokenizer(
879
+ texts_to_encode,
880
+ return_tensors='pt',
881
+ padding='max_length',
882
+ max_length=max_length + 8,
883
+ truncation=True,
884
+ add_special_tokens=False
885
+ )
886
+ else:
887
+ inp_enc = [
888
+ self.decoder_tokenizer.enc_token +
889
+ self.decoder_tokenizer.bos_token +
890
+ text +
891
+ self.decoder_tokenizer.eos_token
892
+ for text in texts
893
+ ]
894
+ inp_enc = self.decoder_tokenizer(
895
+ inp_enc,
896
+ return_tensors='pt',
897
+ padding="max_length",
898
+ max_length=max_length + 3,
899
+ truncation=True,
900
+ add_special_tokens=False
901
+ )
902
+
903
+ num_mem_tokens = self.doc_max_length // self.compr_rate
904
+ assert num_mem_tokens == len(self.decoder_tokenizer.mem_tokens)
905
+
906
+ inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(
907
+ inp_enc['input_ids'],
908
+ inp_enc['attention_mask'],
909
+ num_mem_tokens,
910
+ tokenizer=self.decoder_tokenizer
911
+ )
912
+
913
+ return inp_enc
914
+
915
+ def _replace_emb(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor) -> torch.Tensor:
916
+ """Replace memory tokens in decoder input with compressed embeddings."""
917
+ indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
918
+ return self._replace_embeddings(compressed_embs, dec_input_ids, indices)
919
+
920
+ def _replace_emb_stage2(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor) -> torch.Tensor:
921
+ """Replace memory tokens for stage 2."""
922
+ indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
923
+ return self._replace_embeddings(compressed_embs, dec_input_ids, indices)
924
+
925
+ def _replace_embeddings(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor, indices: range) -> torch.Tensor:
926
+ """Replace memory tokens with compressed embeddings."""
927
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
928
+ num_embs = compressed_embs.size(1)
929
+ slot_len = num_embs + (1 if self.sep else 0)
930
+
931
+ # Get first memory token indices
932
+ first_mem_token_indices = torch.argmax(
933
+ (dec_input_ids == self.decoder_tokenizer.mem_token_ids[0]).int(), dim=1
934
+ )
935
+ batch_size = inputs_embeds.size(0)
936
+
937
+ # Replace with compressed embeddings
938
+ for i in range(batch_size):
939
+ for j in range(indices[i], indices[i + 1]):
940
+ start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
941
+ assert inputs_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size()
942
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
943
+
944
+ return inputs_embeds
945
+
946
+ def _retrieve_embeddings(self, questions: torch.Tensor, stage2_retrieval_top_n: int = 1) -> torch.Tensor:
947
+ """Retrieve embeddings of documents."""
948
+ response = requests.post(
949
+ self.url_retrieval,
950
+ json={
951
+ "queries": questions.detach().cpu().float().numpy().tolist(),
952
+ 'k': self.generation_top_k
953
+ }
954
+ )
955
+
956
+ if response.status_code != 200:
957
+ raise Exception(f"Error: {response.status_code} - {response.text}")
958
+
959
+ results = response.json()
960
+ retrieval_embeddings = results['retrieved_embeddings']
961
+ retrieval_embeddings = torch.tensor(
962
+ retrieval_embeddings,
963
+ dtype=torch.bfloat16,
964
+ device=questions.device
965
+ )
966
+
967
+ if len(retrieval_embeddings.shape) == 4:
968
+ retrieval_embeddings = retrieval_embeddings.reshape(
969
+ retrieval_embeddings.shape[0] * retrieval_embeddings.shape[1],
970
+ retrieval_embeddings.shape[2], -1
971
+ )
972
+
973
+ return retrieval_embeddings
974
+
975
+ def _blend_prompt_and_memory_tokens(self, query: str, answer: str = None, qa_loss: bool = False,
976
+ paraphrase_loss: bool = False, stage: str = "stage1") -> Tuple[int, str]:
977
+ """Blend prompt with memory tokens for different training stages."""
978
+ mem_tokens_str = ''.join(self.decoder_tokenizer.mem_tokens) + self.decoder_tokenizer.sep_token
979
+ docs = mem_tokens_str * self.generation_top_k
980
+
981
+ if stage == "stage1":
982
+ if qa_loss:
983
+ return self._blend_qa_prompt(docs, query, answer)
984
+ elif paraphrase_loss:
985
+ return self._blend_paraphrase_prompt(docs, answer)
986
+ elif stage == "stage1_2":
987
+ return self._blend_standard_prompt(docs, query, answer)
988
+
989
+ raise ValueError(f"Unknown stage: {stage}")
990
+
991
+ def _blend_qa_prompt(self, docs: str, query: List[str], answer: List[str]) -> Tuple[int, str]:
992
+ """Create QA prompt for stage 1."""
993
+ prompt_system = 'You are a helpful assistant. Given a document, your task is to generate some single questions to cover all key information of the document and answer them sequentially.'
994
+ prompt_user = f"Background:\n{docs}"
995
+
996
+ sys_prompt = [{"role": "system", "content": prompt_system}]
997
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
998
+
999
+ qa_lines = [f"Question: {q}\nAnswer: {a}" for q, a in zip(query, answer)]
1000
+ query_answer = "\n".join(qa_lines)
1001
+ assistant_prompt = [{"role": "assistant", "content": query_answer}]
1002
+
1003
+ try:
1004
+ prompt = self.decoder_tokenizer.apply_chat_template(
1005
+ sys_prompt + user_prompt,
1006
+ tokenize=False,
1007
+ add_generation_prompt=True,
1008
+ enable_thinking=False
1009
+ )
1010
+ response = self.decoder_tokenizer.apply_chat_template(
1011
+ sys_prompt + user_prompt + assistant_prompt,
1012
+ tokenize=False,
1013
+ add_generation_prompt=False,
1014
+ enable_thinking=False
1015
+ )
1016
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1017
+ except TemplateError as e:
1018
+ if "System role not supported" in str(e):
1019
+ messages = [{"role": "user", "content": sys_prompt[0]['content'] + '\n' + user_prompt[0]['content']}]
1020
+ prompt = self.decoder_tokenizer.apply_chat_template(
1021
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1022
+ )
1023
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1024
+ # Handle response for unsupported system role
1025
+ messages_with_answer = messages + assistant_prompt
1026
+ response = self.decoder_tokenizer.apply_chat_template(
1027
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1028
+ )
1029
+ else:
1030
+ raise e
1031
+
1032
+ return prompt_len, response
1033
+
1034
+ def _blend_paraphrase_prompt(self, docs: str, answer: str) -> Tuple[int, str]:
1035
+ """Create paraphrase prompt for stage 1."""
1036
+ prompt_system = 'You are a helpful assistant. Your task is follow the instructions to paraphrase the background information.'
1037
+ prompt_user = random.choice(PARAPHRASE_INSTRUCTIONS).format(docs=docs)
1038
+
1039
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1040
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1041
+
1042
+ try:
1043
+ prompt = self.decoder_tokenizer.apply_chat_template(
1044
+ sys_prompt + user_prompt,
1045
+ tokenize=False,
1046
+ add_generation_prompt=True,
1047
+ enable_thinking=False
1048
+ )
1049
+ if answer is None:
1050
+ return prompt
1051
+
1052
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1053
+ response = self.decoder_tokenizer.apply_chat_template(
1054
+ sys_prompt + user_prompt + assistant_prompt,
1055
+ tokenize=False,
1056
+ add_generation_prompt=False,
1057
+ enable_thinking=False
1058
+ )
1059
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1060
+ except TemplateError as e:
1061
+ if "System role not supported" in str(e):
1062
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1063
+ messages = [{"role": "user", "content": combined_content}]
1064
+ prompt = self.decoder_tokenizer.apply_chat_template(
1065
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1066
+ )
1067
+ if answer is None:
1068
+ return prompt
1069
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1070
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1071
+ response = self.decoder_tokenizer.apply_chat_template(
1072
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1073
+ )
1074
+ else:
1075
+ raise e
1076
+
1077
+ return prompt_len, response
1078
+
1079
+ def _blend_standard_prompt(self, docs: str, query: str, answer: str) -> Tuple[int, str]:
1080
+ """Create standard prompt for stage 1_2."""
1081
+ prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
1082
+ prompt_user = f"Background:\n{docs}\n\nQuestion:{query}"
1083
+
1084
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1085
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1086
+
1087
+ try:
1088
+ prompt = self.decoder_tokenizer.apply_chat_template(
1089
+ sys_prompt + user_prompt,
1090
+ tokenize=False,
1091
+ add_generation_prompt=True,
1092
+ enable_thinking=False
1093
+ )
1094
+ if answer is None:
1095
+ return prompt
1096
+
1097
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1098
+ response = self.decoder_tokenizer.apply_chat_template(
1099
+ sys_prompt + user_prompt + assistant_prompt,
1100
+ tokenize=False,
1101
+ add_generation_prompt=False,
1102
+ enable_thinking=False
1103
+ )
1104
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1105
+ except TemplateError as e:
1106
+ if "System role not supported" in str(e):
1107
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1108
+ messages = [{"role": "user", "content": combined_content}]
1109
+ prompt = self.decoder_tokenizer.apply_chat_template(
1110
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1111
+ )
1112
+ if answer is None:
1113
+ return prompt
1114
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1115
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1116
+ response = self.decoder_tokenizer.apply_chat_template(
1117
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1118
+ )
1119
+ else:
1120
+ raise e
1121
+
1122
+ return prompt_len, response
1123
+
1124
+ def _blend_prompt_and_selected_memory_tokens(self, query: str, answer: str = None) -> Tuple[int, str]:
1125
+ """Create prompt for stage 2 with selected memory tokens."""
1126
+ mem_tokens_str = ''.join(self.decoder_tokenizer.mem_tokens) + self.decoder_tokenizer.sep_token
1127
+ docs = mem_tokens_str * self.generation_top_k
1128
+
1129
+ prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
1130
+ prompt_user = f"Background:\n{docs}\n\nQuestion:{query}"
1131
+
1132
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1133
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1134
+
1135
+ try:
1136
+ prompt = self.decoder_tokenizer.apply_chat_template(
1137
+ sys_prompt + user_prompt,
1138
+ tokenize=False,
1139
+ add_generation_prompt=True,
1140
+ enable_thinking=False
1141
+ )
1142
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1143
+
1144
+ if answer is not None:
1145
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1146
+ response = self.decoder_tokenizer.apply_chat_template(
1147
+ sys_prompt + user_prompt + assistant_prompt,
1148
+ tokenize=False,
1149
+ add_generation_prompt=False,
1150
+ enable_thinking=False
1151
+ )
1152
+ else:
1153
+ response = prompt
1154
+
1155
+ except TemplateError as e:
1156
+ if "System role not supported" in str(e):
1157
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1158
+ messages = [{"role": "user", "content": combined_content}]
1159
+
1160
+ prompt = self.decoder_tokenizer.apply_chat_template(
1161
+ messages,
1162
+ tokenize=False,
1163
+ add_generation_prompt=True,
1164
+ enable_thinking=False
1165
+ )
1166
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1167
+
1168
+ if answer is not None:
1169
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1170
+ response = self.decoder_tokenizer.apply_chat_template(
1171
+ messages_with_answer,
1172
+ tokenize=False,
1173
+ add_generation_prompt=False,
1174
+ enable_thinking=False
1175
+ )
1176
+ else:
1177
+ response = prompt
1178
+ else:
1179
+ raise e
1180
+
1181
+ return prompt_len, response
1182
+
1183
+ # Model saving and loading methods
1184
+ def save_pretrained(self, save_directory: str, **kwargs):
1185
+ """Save only the LoRA adapters and their configurations."""
1186
+ if self.lora:
1187
+ if not os.path.exists(save_directory):
1188
+ os.makedirs(save_directory)
1189
+
1190
+ # Save LoRA adapter weights
1191
+ torch.save(
1192
+ self._get_all_adapters_state_dict(),
1193
+ os.path.join(save_directory, "adapters.pth")
1194
+ )
1195
+
1196
+ # Save first and last layers of decoder
1197
+ torch.save(
1198
+ self._get_decoder_first_and_last_layer_state_dict(),
1199
+ os.path.join(save_directory, "decoder_first_last_layers.pth")
1200
+ )
1201
+
1202
+ # Save configuration
1203
+ self.config.save_pretrained(save_directory)
1204
+ else:
1205
+ super().save_pretrained(save_directory, **kwargs)
1206
+
1207
+ def _get_all_adapters_state_dict(self) -> Dict[str, Dict[str, torch.Tensor]]:
1208
+ """Return the state dicts of all adapters."""
1209
+ return {
1210
+ key: {k: v.cpu() for k, v in self.decoder.get_adapter_state_dict(key).items()}
1211
+ for key in self.adapter_keys
1212
+ }
1213
+
1214
+ def _get_decoder_first_and_last_layer_state_dict(self) -> Dict[str, torch.Tensor]:
1215
+ """Get first and last layers that change when adding tokens."""
1216
+ out = {}
1217
+ for k, v in self.decoder.named_parameters():
1218
+ if 'lm_head.weight' in k or 'embed_tokens.weight' in k:
1219
+ out[k] = v.cpu()
1220
+ return out
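# A small sketch of what a LoRA checkpoint written by save_pretrained() above should
# contain; the file names match the compression-16 / compression-128 folders in this repo,
# and "test_ckpt" is just an illustrative path.
import os
ckpt_dir = "test_ckpt"
expected = ["adapters.pth", "decoder_first_last_layers.pth", "config.json"]
if os.path.isdir(ckpt_dir):
    missing = [name for name in expected if not os.path.exists(os.path.join(ckpt_dir, name))]
    print("missing checkpoint files:", missing or "none")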
1221
+
1222
+ @classmethod
1223
+ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
1224
+ """Load model from pretrained checkpoint."""
1225
+ # Load configuration
1226
+ config = CLaRaConfig.from_pretrained(pretrained_model_name_or_path)
1227
+
1228
+ # Update config with kwargs
1229
+ for key, value in kwargs.items():
1230
+ if hasattr(config, key):
1231
+ setattr(config, key, value)
1232
+
1233
+ map_location = torch.device("cpu") if not torch.cuda.is_available() else None
1234
+
1235
+ if config.lora:
1236
+ # Delay adapter construction
1237
+ config.load_adapters = False
1238
+ if 'device_map' in kwargs:
1239
+ config.device_map = kwargs['device_map']
1240
+
1241
+ # Initialize model
1242
+ print(f"Initializing model from trained checkpoint: {config}")
1243
+ model = cls(config)
1244
+
1245
+ # Load first and last layers
1246
+ try:
1247
+ first_and_last_layers_path = hf_hub_download(
1248
+ repo_id=pretrained_model_name_or_path,
1249
+ filename="decoder_first_last_layers.pth"
1250
+ )
1251
+ except Exception:
1252
+ first_and_last_layers_path = os.path.join(
1253
+ pretrained_model_name_or_path, "decoder_first_last_layers.pth"
1254
+ )
1255
+
1256
+ if os.path.exists(first_and_last_layers_path):
1257
+ first_and_last_decoder_state_dict = torch.load(
1258
+ first_and_last_layers_path, map_location=map_location, weights_only=True
1259
+ )
1260
+ for key in first_and_last_decoder_state_dict:
1261
+ assert key in model.decoder.state_dict()
1262
+ model.decoder.load_state_dict(first_and_last_decoder_state_dict, strict=False)
1263
+ else:
1264
+ print(f'First and last layers not found: {first_and_last_layers_path}')
1265
+
1266
+ peft_config = model._get_peft_config(lora_r=config.lora_r)
1267
+
1268
+ # Load LoRA adapters
1269
+ try:
1270
+ adapters_path = hf_hub_download(
1271
+ repo_id=pretrained_model_name_or_path,
1272
+ filename="adapters.pth"
1273
+ )
1274
+ except Exception:
1275
+ adapters_path = os.path.join(pretrained_model_name_or_path, "adapters.pth")
1276
+
1277
+ if os.path.exists(adapters_path):
1278
+ adapters_state_dict = torch.load(adapters_path, map_location=map_location, weights_only=True)
1279
+ model._load_adapters_from_state_dict(adapters_state_dict, peft_config, config)
1280
+ else:
1281
+ warnings.warn(f'Adapters not found at {adapters_path}')
1282
+
1283
+ model._set_all_adapters()
1284
+ config.load_adapters = True
1285
+ return model
1286
+ else:
1287
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
1288
+ def _load_adapters_from_state_dict(self, adapters_state_dict: Dict, peft_config: LoraConfig, config: CLaRaConfig):
1289
+ """Load adapters from state dict based on training stage."""
1290
+ if not getattr(config, 'pure_inference', False):
1291
+ for key, val in adapters_state_dict.items():
1292
+ # Skip certain adapters based on training stage
1293
+ if config.training_stage == 'stage1' and key == 'query_reasoner_adapter':
1294
+ continue
1295
+ elif config.training_stage == 'stage1_2' and key in ['query_reasoner_adapter', 'decoder_adapter']:
1296
+ continue
1297
+ elif config.training_stage == 'stage2_reasoning' and key == 'decoder_adapter':
1298
+ continue
1299
+
1300
+ self._load_adapter_from_state_dict(
1301
+ peft_config=peft_config,
1302
+ adapter_name=key,
1303
+ adapter_state_dict=val
1304
+ )
1305
+ else:
1306
+ # Load all adapters for pure inference
1307
+ for key, val in adapters_state_dict.items():
1308
+ self._load_adapter_from_state_dict(
1309
+ peft_config=peft_config,
1310
+ adapter_name=key,
1311
+ adapter_state_dict=val
1312
+ )
1313
+
1314
+ # Handle special cases for stage 2 training
1315
+ if config.training_stage == 'stage2' and 'query_reasoner_adapter' not in adapters_state_dict:
1316
+ self._handle_query_reasoner_adapter_loading(adapters_state_dict, peft_config)
1317
+
1318
+ def _load_adapter_from_state_dict(self, peft_config: LoraConfig, adapter_name: str, adapter_state_dict: Dict):
1319
+ """Create adapter from state dict."""
1320
+ print(f'Loading checkpoint adapter: {adapter_name}')
1321
+ self.decoder.load_adapter(
1322
+ peft_config=peft_config,
1323
+ adapter_name=adapter_name,
1324
+ adapter_state_dict=adapter_state_dict
1325
+ )
1326
+ self.adapter_keys.append(adapter_name)
1327
+
1328
+ def _handle_query_reasoner_adapter_loading(self, adapters_state_dict: Dict, peft_config: LoraConfig):
1329
+ """Handle special loading logic for query reasoner adapter."""
1330
+ if 'encoder_adapter' in adapters_state_dict and 'query_reasoner_adapter' not in adapters_state_dict:
1331
+ # Rename encoder adapter to query reasoner adapter
1332
+ renamed = {}
1333
+ for k, v in adapters_state_dict['encoder_adapter'].items():
1334
+ new_k = k.replace('encoder_adapter', 'query_reasoner_adapter')
1335
+ renamed[new_k] = v.detach().clone()
1336
+
1337
+ self._load_adapter_from_state_dict(
1338
+ peft_config=peft_config,
1339
+ adapter_name='query_reasoner_adapter',
1340
+ adapter_state_dict=renamed
1341
+ )
1342
+ print('Loaded query_reasoner_adapter from stage 1 compressor checkpoint')
1343
+ else:
1344
+ # Create new adapter randomly
1345
+ self.decoder.add_adapter(peft_config, 'query_reasoner_adapter')
1346
+ self.adapter_keys.append('query_reasoner_adapter')
1347
+ print('Loaded query_reasoner_adapter randomly for stage 2 training')
1348
+
1349
+ # Forward pass methods
1350
+ def forward(self,
1351
+ batch: Dict = None,
1352
+ questions: List[str] = None,
1353
+ documents: List[List[str]] = None,
1354
+ answers: List[str] = None,
1355
+ original_answer_gen_api: str = None,
1356
+ stage2_mips: bool = False,
1357
+ stage2_retrieval_top_n: int = None) -> Tuple[torch.Tensor, Dict]:
1358
+ """
1359
+ Forward pass with support for both batch and legacy interfaces.
1360
+
1361
+ Args:
1362
+ batch: Preprocessed batch dict (new interface)
1363
+ questions: List of questions (legacy interface)
1364
+ documents: List of document lists (legacy interface)
1365
+ answers: List of answers (legacy interface)
1366
+ original_answer_gen_api: API URL for generation (legacy interface)
1367
+ stage2_mips: Whether to use MIPS for stage2
1368
+ stage2_retrieval_top_n: Top-n for stage2 retrieval
1369
+
1370
+ Returns:
1371
+ Tuple of (loss, additional_outputs_dict)
1372
+ """
1373
+ if batch is not None:
1374
+ return self._forward_batch(batch, stage2_mips, stage2_retrieval_top_n)
1375
+ else:
1376
+ return self._forward_legacy(questions, documents, answers, original_answer_gen_api)
1377
+
1378
+ def _forward_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1379
+ """Handle batch-based forward pass."""
1380
+ stage = batch.get("stage", None)
1381
+
1382
+ if stage in ["stage1", "stage1_2"]:
1383
+ return self._forward_stage1_batch(batch)
1384
+ elif stage == "stage2":
1385
+ return self._forward_stage2_batch(batch, stage2_mips, stage2_retrieval_top_n)
1386
+ elif stage == "stage2_pretrain_retrieval":
1387
+ return self._forward_stage2_pretrain_batch(batch, stage2_mips, stage2_retrieval_top_n)
1388
+ elif stage == "stage2_reasoning":
1389
+ return self._forward_stage2_reasoning_batch(batch)
1390
+ else:
1391
+ raise ValueError(f"Unknown stage: {stage}")
1392
+
1393
+ def _forward_stage1_batch(self, batch: Dict) -> Tuple[torch.Tensor, Dict]:
1394
+ """Forward pass for stage 1 training."""
1395
+ # Move tensors to device
1396
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1397
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1398
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1399
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1400
+ labels = batch["labels"].to(self.decoder.device)
1401
+
1402
+ out = self._forward_stage_1(
1403
+ enc_input_ids=enc_input_ids,
1404
+ enc_attention_mask=enc_attention_mask,
1405
+ dec_input_ids=dec_input_ids,
1406
+ dec_attention_mask=dec_attention_mask,
1407
+ labels=labels,
1408
+ )
1409
+ return out["loss"], {"logits": out["logits"], "mse_loss": out["mse_loss"]}
1410
+
1411
+ def _forward_stage2_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1412
+ """Forward pass for stage 2 training."""
1413
+ self.decoder.set_adapter('query_reasoner_adapter')
1414
+
1415
+ B = batch["labels"].shape[0]
1416
+ query_reps = self._compr_query_reasoner_stage2(
1417
+ batch["query_input_ids"].to(self.decoder.device),
1418
+ batch["query_attention_mask"].to(self.decoder.device)
1419
+ )
1420
+
1421
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1422
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1423
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1424
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1425
+ labels = batch["labels"].to(self.decoder.device)
1426
+
1427
+ # Document retrieval and selection
1428
+ if stage2_mips:
1429
+ retrieved_doc_embeddings = self._retrieve_embeddings(
1430
+ query_reps, stage2_retrieval_top_n=stage2_retrieval_top_n
1431
+ )
1432
+ scores = torch.bmm(
1433
+ query_reps.unsqueeze(1),
1434
+ retrieved_doc_embeddings.transpose(1, 2)
1435
+ ).squeeze(1)
1436
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=1)
1437
+ selected = torch.einsum('bkn,bnd->bkd', z, retrieved_doc_embeddings)
1438
+ selected = selected.view(selected.size(0) * selected.size(1), -1, self.hidden_size)
+ mse_loss = torch.tensor(0.0, device=selected.device)  # no compression MSE is computed on the MIPS path
1439
+ else:
1440
+ with torch.no_grad():
1441
+ retrieved_doc_embeddings, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1442
+
1443
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
1444
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
1445
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
1446
+
1447
+ scores = torch.bmm(
1448
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
1449
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
1450
+ ).squeeze(1)
1451
+
1452
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.02)
1453
+ selected = torch.einsum('bkn,bnd->bkd', z.to(retrieved_doc_embeddings.dtype), retrieved_doc_embeddings)
1454
+ selected = selected.view(selected.size(0) * selected.size(1), -1, self.hidden_size)
1455
+
1456
+ inputs_embeds = self._replace_emb_stage2(selected, dec_input_ids)
1457
+
1458
+ if 'decoder_adapter' in self.adapter_keys:
1459
+ self.decoder.set_adapter('decoder_adapter')
1460
+
1461
+ dec_out = self.decoder(
1462
+ inputs_embeds=inputs_embeds,
1463
+ attention_mask=dec_attention_mask,
1464
+ labels=labels,
1465
+ )
1466
+
1467
+ self.decoder.set_adapter(['decoder_adapter', 'query_reasoner_adapter'])
1468
+ return dec_out.loss, {"logits": dec_out.logits, "topk_idx": topk_idx, "mse_loss": mse_loss}
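# A toy illustration of the scoring used above, with random embeddings: queries and
# compressed documents are L2-normalised, so the batched matmul yields cosine
# similarities; dividing by the temperature (0.02) sharpens them before top-k selection.
import torch
import torch.nn.functional as F
B, N, D = 2, 6, 8                                  # batch, candidate docs per query, dim
query_reps = torch.randn(B, D)
doc_embs = torch.randn(B, N, D)
scores = torch.bmm(
    F.normalize(query_reps, dim=-1, p=2).unsqueeze(1),
    F.normalize(doc_embs, dim=-1, p=2).transpose(1, 2),
).squeeze(1)                                       # [B, N], values in [-1, 1]
topk_idx = (scores / 0.02).topk(k=3, dim=-1).indices
print(scores.shape, topk_idx.shape)                # torch.Size([2, 6]) torch.Size([2, 3])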
1469
+
1470
+ def _forward_stage2_pretrain_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1471
+ """Forward pass for stage 2 pretraining with retrieval."""
1472
+ self.decoder.set_adapter('query_reasoner_adapter')
1473
+
1474
+ B = batch["labels"].shape[0]
1475
+ N = batch["enc_input_ids"].shape[0] // B
1476
+ device = self.decoder.device
1477
+
1478
+ query_reps = self._compr_query_reasoner_stage2(
1479
+ batch["query_input_ids"].to(device),
1480
+ batch["query_attention_mask"].to(device)
1481
+ )
1482
+
1483
+ enc_input_ids = batch["enc_input_ids"].to(device)
1484
+ enc_attention_mask = batch["enc_attention_mask"].to(device)
1485
+
1486
+ with torch.no_grad():
1487
+ retrieved_doc_embeddings, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1488
+
1489
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
1490
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
1491
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
1492
+
1493
+ scores = torch.bmm(
1494
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
1495
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
1496
+ ).squeeze(1)
1497
+
1498
+ pos_index = batch["pos_index"]
1499
+ pos_mask = build_pos_mask(pos_index, N, device)
1500
+ tau = 0.02
1501
+ logits = scores / tau
1502
+
1503
+ pos_logits = logits.masked_fill(~pos_mask, float('-inf'))
1504
+ num = torch.logsumexp(pos_logits, dim=-1)
1505
+ den = torch.logsumexp(logits, dim=-1)
1506
+ loss_vec = -(num - den)
1507
+ valid = pos_mask.any(dim=-1)
1508
+ loss = loss_vec[valid].mean()
1509
+
1510
+ topk = self.generation_top_k
1511
+ topk_idx = logits.topk(k=min(topk, N), dim=-1).indices
1512
+
1513
+ return loss, {"logits": [[]], "topk_idx": topk_idx, "mse_loss": mse_loss}
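# A toy version of the retrieval-pretraining loss above: log-sum-exp over the positive
# candidates minus log-sum-exp over all candidates (a multi-positive InfoNCE).
# Scores and the positive mask are made up for illustration.
import torch
tau = 0.02
logits = torch.tensor([[0.9, 0.1, -0.4, 0.8]]) / tau
pos_mask = torch.tensor([[True, False, False, True]])
num = torch.logsumexp(logits.masked_fill(~pos_mask, float('-inf')), dim=-1)
den = torch.logsumexp(logits, dim=-1)
loss = -(num - den)
print(loss)   # close to zero when the positives carry almost all of the probability mass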
1514
+
1515
+ def _forward_stage2_reasoning_batch(self, batch: Dict) -> Tuple[torch.Tensor, Dict]:
1516
+ """Forward pass for stage 2 reasoning training."""
1517
+ B = batch["labels"].shape[0]
1518
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1519
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1520
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1521
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1522
+ labels = batch["labels"].to(self.decoder.device)
1523
+
1524
+ if sum(batch["docs_num"]) != 0:
1525
+ with torch.no_grad():
1526
+ selected, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1527
+ indices = batch["docs_num"]
1528
+ inputs_embeds = self._replace_reasoning_embeddings(selected, dec_input_ids, indices)
1529
+ else:
1530
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
1531
+ mse_loss = 0
1532
+
1533
+ if 'decoder_adapter' in self.adapter_keys:
1534
+ self.decoder.set_adapter('decoder_adapter')
1535
+
1536
+ dec_out = self.decoder(
1537
+ inputs_embeds=inputs_embeds,
1538
+ attention_mask=dec_attention_mask,
1539
+ labels=labels,
1540
+ )
1541
+
1542
+ self.decoder.set_adapter(['decoder_adapter'])
1543
+ return dec_out.loss, {"logits": dec_out.logits, "mse_loss": mse_loss}
1544
+
1545
+ def _forward_stage_1(self,
1546
+ enc_input_ids: torch.LongTensor = None,
1547
+ enc_attention_mask: torch.LongTensor = None,
1548
+ dec_input_ids: torch.LongTensor = None,
1549
+ dec_attention_mask: torch.LongTensor = None,
1550
+ labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
1551
+ """Stage 1 forward pass for document compression and QA."""
1552
+ assert enc_input_ids.size() == enc_attention_mask.size()
1553
+
1554
+ # Flatten 3D inputs to 2D if needed
1555
+ if len(enc_input_ids.size()) == 3:
1556
+ batch_size, top_k, seq_length = enc_input_ids.size()
1557
+ enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
1558
+ enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
1559
+
1560
+ assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k
1561
+
1562
+ # Compress documents
1563
+ compressed_embs, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1564
+
1565
+ # Replace memory tokens with compressed embeddings
1566
+ inputs_embeds = self._replace_emb(compressed_embs, dec_input_ids)
1567
+
1568
+ # Detach if compressor-only training
1569
+ if (self.training_form == "compressor") and (self.compr is None):
1570
+ inputs_embeds = inputs_embeds.detach()
1571
+
1572
+ # Set decoder adapter
1573
+ if 'decoder_adapter' in self.adapter_keys:
1574
+ self.decoder.set_adapter('decoder_adapter')
1575
+
1576
+ # Forward through decoder
1577
+ decoder_outputs = self.decoder(
1578
+ inputs_embeds=inputs_embeds,
1579
+ attention_mask=dec_attention_mask,
1580
+ labels=labels
1581
+ )
1582
+
1583
+ # Reactivate all adapters
1584
+ self.decoder.set_adapter(['decoder_adapter', 'encoder_adapter'])
1585
+
1586
+ return {
1587
+ "loss": decoder_outputs.loss,
1588
+ "logits": decoder_outputs.logits,
1589
+ "mse_loss": mse_loss
1590
+ }
1591
+
1592
+ def _replace_reasoning_embeddings(self,
1593
+ compressed_embs: torch.Tensor,
1594
+ dec_input_ids: torch.LongTensor,
1595
+ docs_per_example: List[int]) -> torch.Tensor:
1596
+ """Replace memory slots with compressed embeddings for reasoning."""
1597
+ device = dec_input_ids.device
1598
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
1599
+
1600
+ num_embs = compressed_embs.size(1)
1601
+ slot_len = num_embs + (1 if getattr(self, "sep", False) else 0)
1602
+
1603
+ if not isinstance(docs_per_example, torch.Tensor):
1604
+ docs_per_example = torch.tensor(docs_per_example, device=device, dtype=torch.long)
1605
+ else:
1606
+ docs_per_example = docs_per_example.to(device=device, dtype=torch.long)
1607
+
1608
+ offsets = torch.zeros(docs_per_example.size(0) + 1, device=device, dtype=torch.long)
1609
+ offsets[1:] = torch.cumsum(docs_per_example, dim=0)
1610
+ total_docs = int(offsets[-1].item())
1611
+ assert total_docs == compressed_embs.size(0)
1612
+
1613
+ mem_id = self.decoder_tokenizer.mem_token_ids[0]
1614
+ B, L, H = inputs_embeds.size()
1615
+
1616
+ for i in range(B):
1617
+ # Find first memory token position
1618
+ mem_pos = (dec_input_ids[i] == mem_id).nonzero(as_tuple=True)[0]
1619
+ if mem_pos.numel() == 0:
1620
+ continue
1621
+ first_mem_idx = int(mem_pos[0].item())
1622
+
1623
+ n_docs_i = int(docs_per_example[i].item())
1624
+ base = int(offsets[i].item())
1625
+
1626
+ needed_len = first_mem_idx + n_docs_i * slot_len
1627
+ assert needed_len <= L
1628
+
1629
+ for local_j in range(n_docs_i):
1630
+ global_j = base + local_j
1631
+ start_idx = first_mem_idx + local_j * slot_len
1632
+ target_slice = inputs_embeds[i, start_idx:start_idx + num_embs, :]
1633
+ src = compressed_embs[global_j]
1634
+ assert target_slice.size() == src.size()
1635
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = src
1636
+
1637
+ return inputs_embeds
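# Slot-layout note for the replacement above, with illustrative numbers: with 16
# compressed embeddings per document and sep=True, each document occupies a slot of
# 17 consecutive positions (16 memory tokens + 1 separator), so a prompt carrying
# 3 documents must reserve first_mem_idx + 3 * 17 positions for memory slots.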
1638
+
1639
+ def _generate(self, model_input: Dict[str, torch.Tensor], max_new_tokens: int = 128,
1640
+ return_doc_embeddings: bool = False) -> List[str]:
1641
+ """Generate text from model inputs."""
1642
+ enc_input_ids = model_input['enc_input_ids']
1643
+ enc_attention_mask = model_input['enc_attention_mask']
1644
+ dec_input_ids = model_input['dec_input_ids']
1645
+ dec_attention_mask = model_input['dec_attention_mask']
1646
+
1647
+ assert enc_input_ids.size() == enc_attention_mask.size()
1648
+
1649
+ if len(enc_input_ids.size()) == 3:
1650
+ batch_size, top_k, seq_length = enc_input_ids.size()
1651
+ enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
1652
+ enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
1653
+
1654
+ assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k
1655
+
1656
+ compressed_embs, _ = self.compress(enc_input_ids.to('cuda'), enc_attention_mask.to('cuda'))
1657
+ inputs_embeds = self._replace_emb(compressed_embs, dec_input_ids.to('cuda'))
1658
+
1659
+ if 'decoder_adapter' in self.adapter_keys:
1660
+ self.decoder.set_adapter('decoder_adapter')
1661
+
1662
+ output_ids = self.decoder.generate(
1663
+ inputs_embeds=inputs_embeds.to("cuda"),
1664
+ attention_mask=dec_attention_mask.to("cuda"),
1665
+ do_sample=False,
1666
+ top_p=None,
1667
+ max_new_tokens=max_new_tokens
1668
+ )
1669
+
1670
+ decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
1671
+
1672
+ if return_doc_embeddings:
1673
+ assert 'batch_size' in locals() and 'top_k' in locals()
1674
+ compressed_embs = compressed_embs.view(batch_size, top_k, compressed_embs.size(1), compressed_embs.size(2))
1675
+ return decoded, compressed_embs
1676
+ else:
1677
+ return decoded
1678
+
1679
+
1680
+ # Example usage and testing
1681
+ if __name__ == '__main__':
1682
+ # Example configuration
1683
+ cfg = CLaRaConfig(
1684
+ decoder_model_name='/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2',
1685
+ compr_model_name="mistral_trimmed",
1686
+ compr_rate=64,
1687
+ compr_n_layers=5,
1688
+ compr_mlp_hidden_dim=8096,
1689
+ compr_use_mlp=False,
1690
+ lora=True,
1691
+ lora_compressor=True,
1692
+ training_form="both",
1693
+ load_adapters=True,
1694
+ kbtc_training=False,
1695
+ optimize_mem_tokens=True,
1696
+ different_mem_tokens=True,
1697
+ attn_implementation='flash_attention_2'
1698
+ )
1699
+
1700
+ # Initialize model
1701
+ clara = CLaRa(cfg)
1702
+
1703
+ # Save and reload test
1704
+ clara.save_pretrained('test_ckpt')
1705
+
1706
+ del clara
1707
+ torch.cuda.empty_cache()
1708
+ gc.collect()
1709
+
1710
+ # Reload model
1711
+ clara = CLaRa.from_pretrained('test_ckpt')
1712
+ print("Model successfully loaded!")
compression-128/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
compression-128/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
compression-128/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
compression-128/tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<s>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "extra_special_tokens": {},
36
+ "legacy": false,
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": "</s>",
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "LlamaTokenizer",
42
+ "unk_token": "<unk>",
43
+ "use_default_system_prompt": false
44
+ }
compression-16/adapters.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f328d35c099fe367727256b3c66ee18c5c1e68b5eb9b8b515e1cfa8ee8023d48
3
+ size 252096669
compression-16/chat_template.jinja ADDED
@@ -0,0 +1,24 @@
1
+ {%- if messages[0]['role'] == 'system' %}
2
+ {%- set system_message = messages[0]['content'] %}
3
+ {%- set loop_messages = messages[1:] %}
4
+ {%- else %}
5
+ {%- set loop_messages = messages %}
6
+ {%- endif %}
7
+
8
+ {{- bos_token }}
9
+ {%- for message in loop_messages %}
10
+ {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
11
+ {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
12
+ {%- endif %}
13
+ {%- if message['role'] == 'user' %}
14
+ {%- if loop.first and system_message is defined %}
15
+ {{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}
16
+ {%- else %}
17
+ {{- ' [INST] ' + message['content'] + ' [/INST]' }}
18
+ {%- endif %}
19
+ {%- elif message['role'] == 'assistant' %}
20
+ {{- ' ' + message['content'] + eos_token}}
21
+ {%- else %}
22
+ {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
23
+ {%- endif %}
24
+ {%- endfor %}
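The Jinja template above is the standard Mistral-instruct layout: an optional system message is folded into the first user turn, and every user/assistant exchange is wrapped in `[INST] ... [/INST]` markers. A minimal sketch of rendering it with `transformers` (the base-model tokenizer is an assumption; any compatible tokenizer works):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
template = open("compression-16/chat_template.jinja").read()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarise the background documents in one sentence."},
]
text = tok.apply_chat_template(
    messages, chat_template=template, tokenize=False, add_generation_prompt=True
)
print(text)  # "<s> [INST] You are a helpful assistant.\n\nSummarise ... [/INST]"
```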
compression-16/config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "ae_mode": "token",
3
+ "attn_implementation": null,
4
+ "auto_map": {
5
+ "AutoConfig": "modeling_clara.CLaRaConfig",
6
+ "AutoModel": "modeling_clara.CLaRa"
7
+ },
8
+ "compr_base_model_name": "/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2",
9
+ "compr_every_n_layer": null,
10
+ "compr_linear_type": "concat",
11
+ "compr_mlp_hidden_dim": 8096,
12
+ "compr_model_name": null,
13
+ "compr_n_layers": 5,
14
+ "compr_rate": 16,
15
+ "compr_rms_norm": false,
16
+ "compr_use_mlp": false,
17
+ "decoder_model_name": "/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2",
18
+ "device_map": null,
19
+ "different_mem_tokens": true,
20
+ "doc_max_length": 256,
21
+ "generation_top_k": 5,
22
+ "kbtc_training": false,
23
+ "load_adapters": true,
24
+ "load_pretrained_checkpoint": false,
25
+ "lora": true,
26
+ "lora_compressor": false,
27
+ "lora_r": 16,
28
+ "lora_r_compressor": 16,
29
+ "max_new_tokens": 128,
30
+ "model_type": "CLaRa",
31
+ "optimize_mem_tokens": true,
32
+ "pad_token_id": 2,
33
+ "pure_inference": false,
34
+ "quantization": "no",
35
+ "sep": true,
36
+ "stage2_retrieval_top_n": 1,
37
+ "training_form": "both_separately",
38
+ "training_stage": "stage2",
39
+ "transformers_version": "4.53.3"
40
+ }
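A few of the fields above interact: `doc_max_length` and `compr_rate` fix how many memory-token embeddings represent each document, and `generation_top_k` sets how many compressed documents reach the decoder. A quick sanity check of that arithmetic, with values copied from this config:

```python
doc_max_length = 256
compr_rate = 16
generation_top_k = 5

n_mem_tokens = doc_max_length // compr_rate            # 16 embeddings per compressed document
prompt_memory_slots = n_mem_tokens * generation_top_k  # 80 compressed embeddings in the prompt
print(n_mem_tokens, prompt_memory_slots)                # 16 80
```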
compression-16/decoder_first_last_layers.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5618c932bf93fb9e092d1a6eac240497d129c7aa7d4b87fb0b25065bdfa6c62f
3
+ size 524601397
compression-16/modeling_clara.py ADDED
@@ -0,0 +1,1712 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ import warnings
7
+ import os
8
+ import torch
9
+ import gc
10
+ import time
11
+ import json
12
+ import copy
13
+ import random
14
+ import requests
15
+ import re
16
+
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ from torch.nn.functional import gelu
20
+ from jinja2.exceptions import TemplateError
21
+ from peft import LoraConfig
22
+ from transformers import (
23
+ AutoModelForCausalLM,
24
+ AutoTokenizer,
25
+ BitsAndBytesConfig,
26
+ PreTrainedModel,
27
+ PretrainedConfig,
28
+ StoppingCriteria,
29
+ StoppingCriteriaList
30
+ )
31
+ from huggingface_hub import hf_hub_download
32
+ from typing import List, Dict, Any, Optional, Tuple
33
+
34
+ # Environment setup
35
+ torch.set_printoptions(threshold=float("inf"))
36
+ os.environ["NCCL_TIMEOUT"] = "5400"
37
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
38
+
39
+ # Constants
40
+ IGNORE_INDEX = -100
41
+ PARAPHRASE_INSTRUCTIONS = [
42
+ 'Background: {docs} means the same as',
43
+ "Background: {docs} Can you put the above sentences in your own terms?",
44
+ "Background: {docs} Please provide a reinterpretation of the preceding background text.",
45
+ "These two expressions are equivalent in essence:\n(1) {docs}\n(2)",
46
+ "Background: {docs} is a paraphrase of what?",
47
+ "Background: {docs} Could you give me a different version of the background sentences above?",
48
+ "In other words, background: {docs} is just another way of saying:",
49
+ "You're getting across the same point whether you say background: {docs} or",
50
+ "Background: {docs} After unpacking the ideas in the background information above, we got:",
51
+ "Background: {docs} Please offer a restatement of the background sentences I've just read.",
52
+ "Background: {docs}, which also means:",
53
+ "Strip away the mystery, and you'll find background: {docs} is simply another rendition of:",
54
+ "The essence of background: {docs} is captured again in the following statement:",
55
+ ]
56
+
57
+
58
+ class StopOnCriteria(StoppingCriteria):
59
+ """Custom stopping criteria for generation."""
60
+
61
+ def __init__(self, tokenizer, stop_strings: List[str] = None, stop_token_ids: List[int] = None):
62
+ self.tokenizer = tokenizer
63
+ self.stop_strings = stop_strings or []
64
+ self.stop_token_ids = stop_token_ids or []
65
+ self.reason = None
66
+
67
+ def __call__(self, input_ids, scores, **kwargs):
68
+ # Check if last token is in stop_token_ids
69
+ last_token = input_ids[0, -1].item()
70
+ if last_token in self.stop_token_ids:
71
+ self.reason = f"stop_token_{last_token}"
72
+ return True
73
+
74
+ # Check if any stop_strings appear in generated text
75
+ text = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
76
+ for stop_str in self.stop_strings:
77
+ if stop_str in text:
78
+ self.reason = f"stop_string_{stop_str}"
79
+ return True
80
+
81
+ return False
82
+
83
+
84
+ class LlamaRMSNorm(nn.Module):
85
+ """Llama-style RMS normalization layer."""
86
+
87
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
88
+ super().__init__()
89
+ self.weight = nn.Parameter(torch.ones(hidden_size))
90
+ self.variance_epsilon = eps
91
+
92
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
93
+ input_dtype = hidden_states.dtype
94
+ hidden_states = hidden_states.to(torch.float32)
95
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
96
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
97
+ return self.weight * hidden_states.to(input_dtype)
98
+
99
+
100
+ class Converter(nn.Module):
101
+ """Converter module for dimension transformation."""
102
+
103
+ def __init__(self, input_dim: int, output_dim: int):
104
+ super().__init__()
105
+ self.input_dim = input_dim
106
+ self.output_dim = output_dim
107
+
108
+ self.rms_norm = LlamaRMSNorm(input_dim)
109
+ self.dense_in = nn.Linear(input_dim, output_dim)
110
+ self.dense_out = nn.Linear(output_dim, output_dim)
111
+
112
+ self._print_trainable_parameters()
113
+
114
+ def _print_trainable_parameters(self):
115
+ """Print parameter statistics."""
116
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
117
+ total_params = sum(p.numel() for p in self.parameters())
118
+ print(f"Converter trainable parameters: {trainable_params}, Total parameters: {total_params}")
119
+
120
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
121
+ embeddings = self.rms_norm(embeddings)
122
+ x = self.dense_in(embeddings)
123
+ x = self.dense_out(gelu(x))
124
+ return x.to(torch.float32)
125
+
126
+
127
+ class CLaRaConfig(PretrainedConfig):
128
+ """Configuration class for CLaRa model."""
129
+
130
+ model_type = "CLaRa"
131
+
132
+ def __init__(self,
133
+ decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
134
+ doc_max_length: int = 128,
135
+ quantization: str = 'no',
136
+ sep: bool = False,
137
+ compr_model_name: str = "google-bert/bert-base-uncased",
138
+ compr_rate: int = 64,
139
+ compr_n_layers: int = None,
140
+ compr_every_n_layer: int = None,
141
+ compr_base_model_name: str = '/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2',
142
+ compr_rms_norm: bool = False,
143
+ compr_mlp_hidden_dim: int = 8096,
144
+ compr_use_mlp: bool = True,
145
+ compr_linear_type: str = "concat",
146
+ lora: bool = False,
147
+ lora_compressor: bool = False,
148
+ training_form: str = "both",
149
+ training_stage: str = "stage1",
150
+ generation_top_k: int = 1,
151
+ lora_r: int = 16,
152
+ lora_r_compressor: int = None,
153
+ load_adapters: bool = True,
154
+ kbtc_training: bool = False,
155
+ optimize_mem_tokens: bool = False,
156
+ different_mem_tokens: bool = False,
157
+ attn_implementation: str = None,
158
+ _attn_implementation_autoset: bool = True,
159
+ ae_mode: str = "token",
160
+ max_new_tokens: int = 128,
161
+ stage2_retrieval_top_n: int = 1,
162
+ load_pretrained_checkpoint: bool = False,
163
+ device_map=None,
164
+ auto_map: dict = {
165
+ "AutoConfig": "modeling_clara.CLaRaConfig",
166
+ "AutoModel": "modeling_clara.CLaRa"
167
+ },
168
+ **kwargs):
169
+ super().__init__(**kwargs)
170
+
171
+ self.decoder_model_name = decoder_model_name
172
+ self.doc_max_length = doc_max_length
173
+ self.quantization = quantization
174
+ self.sep = sep
175
+
176
+ self.compr_model_name = compr_model_name
177
+ self.compr_rate = compr_rate
178
+ self.compr_use_mlp = compr_use_mlp
179
+ self.compr_mlp_hidden_dim = compr_mlp_hidden_dim
180
+ self.compr_n_layers = compr_n_layers
181
+ self.compr_every_n_layer = compr_every_n_layer
182
+ self.compr_base_model_name = compr_base_model_name
183
+ self.compr_rms_norm = compr_rms_norm
184
+ self.compr_linear_type = compr_linear_type
185
+
186
+ self.lora = lora
187
+ self.lora_compressor = lora_compressor
188
+ self.training_form = training_form
189
+ self.lora_r = lora_r
190
+ self.lora_r_compressor = lora_r_compressor or lora_r
191
+ self.load_adapters = load_adapters
192
+ self.optimize_mem_tokens = optimize_mem_tokens
193
+ self.different_mem_tokens = different_mem_tokens
194
+ self.kbtc_training = kbtc_training
195
+ self.training_stage = training_stage
196
+ self.device_map = device_map
197
+ self.attn_implementation = attn_implementation
198
+ self._attn_implementation_autoset = _attn_implementation_autoset
199
+ self.ae_mode = ae_mode
200
+ self.max_new_tokens = max_new_tokens
201
+ self.auto_map = auto_map
202
+ self.load_pretrained_checkpoint = load_pretrained_checkpoint
203
+
204
+ self.generation_top_k = generation_top_k
205
+ self.stage2_retrieval_top_n = stage2_retrieval_top_n
206
+
207
+ if training_form == 'compressor':
208
+ assert compr_model_name is not None and not self.lora
209
+
210
+
211
+ # Utility functions
212
+ def remote_generate(docs: List[str], questions: List[str], api_url: str) -> List[str]:
213
+ """Generate responses using remote API."""
214
+ response = requests.post(
215
+ f"{api_url}/generate",
216
+ json={"docs": docs, "questions": questions}
217
+ )
218
+ return response.json()["texts"]
219
+
220
+
221
+ def add_memory_tokens_to_inputs(input_ids: torch.Tensor,
222
+ attention_mask: torch.Tensor,
223
+ n_mem_tokens: int,
224
+ tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
225
+ """Add memory tokens to input sequences."""
226
+ assert len(tokenizer.mem_tokens) == n_mem_tokens
227
+
228
+ mem_tokens = torch.stack([tokenizer.mem_token_ids_pt] * input_ids.size(0), 0)
229
+ assert len(mem_tokens) == input_ids.size(0)
230
+ assert len(mem_tokens[0]) == n_mem_tokens
231
+
232
+ input_ids = torch.cat([input_ids, mem_tokens], dim=1)
233
+ attention_mask = torch.cat([attention_mask, torch.ones(input_ids.size(0), n_mem_tokens)], dim=1)
234
+
235
+ return input_ids, attention_mask
236
+
237
+
238
+ def build_pos_mask(pos_index: List[List[int]], N: int, device: torch.device) -> torch.Tensor:
239
+ """Build positive mask for retrieval training."""
240
+ if isinstance(pos_index, (list, tuple)):
241
+ B = len(pos_index)
242
+ mask = torch.zeros(B, N, dtype=torch.bool, device=device)
243
+ for b, idxs in enumerate(pos_index):
244
+ if len(idxs) > 0:
245
+ mask[b, torch.as_tensor(idxs, device=device, dtype=torch.long)] = True
246
+ return mask
247
+ else: # tensor [B, M]
248
+ B, M = pos_index.shape
249
+ mask = torch.zeros(B, N, dtype=torch.bool, device=device)
250
+ for m in range(M):
251
+ col = pos_index[:, m]
252
+ v = col >= 0
253
+ if v.any():
254
+ mask[v, col[v]] = True
255
+ return mask
256
+
257
+
258
+ def differentiable_topk_top_1(logits: torch.Tensor, k: int, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
259
+ """Implements differentiable top-1 selection using Gumbel-Softmax."""
260
+ y = logits / temperature
261
+ y_soft = F.softmax(y, dim=-1).float()
262
+
263
+ # Hard one-hot version
264
+ index = y_soft.argmax(dim=-1, keepdim=True)
265
+ y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
266
+
267
+ # Straight-through estimator
268
+ z = y_hard + y_soft - y_soft.detach()
269
+ z = z.unsqueeze(1).to(logits.dtype)
270
+
271
+ return z, index
272
+
273
+
274
+ def differentiable_topk(logits: torch.Tensor, k: int, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ """Differentiable top-k selection."""
276
+ B, N = logits.shape
277
+ perturbed = logits / max(temperature, 1e-6)
278
+
279
+ # Hard top-k indices
280
+ topk_vals, topk_idx = perturbed.topk(k, dim=-1)
281
+ K_hard = torch.zeros(B, k, N, device=logits.device, dtype=logits.dtype)
282
+ K_hard.scatter_(2, topk_idx.unsqueeze(-1), 1.0)
283
+
284
+ # Soft distributions for each slot
285
+ K_soft = torch.zeros_like(K_hard)
286
+ taken = torch.zeros(B, N, device=logits.device, dtype=logits.dtype)
287
+
288
+ for j in range(k):
289
+ mask = (1.0 - taken.detach())
290
+ masked = perturbed + (mask + 1e-8).log()
291
+ pj = F.softmax(masked, dim=-1).float()
292
+ K_soft[:, j, :] = pj
293
+ taken = torch.clamp(taken + K_hard[:, j, :], max=1.0)
294
+
295
+ # Straight-through estimator
296
+ W = K_hard + (K_soft - K_soft.detach())
297
+ return W, topk_idx
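# A quick check of the straight-through behaviour above, assuming this module is imported:
# in the forward pass W equals the hard one-hot selection, while gradients flow back to
# the logits through the per-slot softmax distributions.
import torch
logits = torch.tensor([[0.1, 2.0, -0.5, 1.2]], requires_grad=True)
W, idx = differentiable_topk(logits, k=2, temperature=1.0)
print(W.shape, idx)                                    # torch.Size([1, 2, 4]); indices of the two largest logits
loss = (W * torch.arange(4.0).view(1, 1, 4)).sum()     # any downstream use of W
loss.backward()                                        # gradients reach `logits` via the soft terms
print(logits.grad)                                     # non-zero: the hard selection stays differentiable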
298
+
299
+
300
+ class CLaRa(PreTrainedModel):
301
+ """CLaRa: Unified Retrieval-Augmented Generation Model."""
302
+
303
+ config_class = CLaRaConfig
304
+
305
+ def __init__(self, cfg: CLaRaConfig):
306
+ super().__init__(cfg)
307
+ self.decoder_model_name = cfg.decoder_model_name
308
+ self.decoder = self._create_decoder(cfg)
309
+ self.doc_max_length = cfg.doc_max_length
310
+
311
+ print(f'Base decoder parameters: {self.decoder.num_parameters()}')
312
+
313
+ # Model configuration
314
+ self.compr_model_name = cfg.compr_model_name
315
+ self.training_form = cfg.training_form
316
+ self.lora = cfg.lora
317
+ self.adapter_keys = []
318
+ self.compr = None
319
+
320
+ # Initialize LoRA adapters if needed
321
+ if cfg.lora and not getattr(cfg, 'pure_inference', False):
322
+ self._setup_lora_adapters(cfg)
323
+
324
+ print(f'Model adapter keys: {self.adapter_keys}')
325
+
326
+ # Initialize tokenizer and resize embeddings
327
+ self.decoder_tokenizer = self._create_decoder_tokenizer(cfg)
328
+ self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
329
+ self._configure_generation_config()
330
+
331
+ # Model parameters
332
+ self.generation_top_k = cfg.generation_top_k
333
+ self.training_stage = cfg.training_stage
334
+ self.stage2_retrieval_top_n = cfg.stage2_retrieval_top_n
335
+ self.sep = cfg.sep
336
+ self.compr_rate = cfg.compr_rate
337
+ self.local_rank = os.getenv('LOCAL_RANK', '0')
338
+
339
+ self.n_mem_tokens = self.doc_max_length // self.compr_rate
340
+ self.hidden_size = self.decoder.config.hidden_size
341
+
342
+ # Setup adapters and memory token optimization
343
+ if self.lora:
344
+ self._setup_adapter_training()
345
+ else:
346
+ print(f'Total trainable parameters: {self.num_parameters(only_trainable=True)}')
347
+
348
+ self._prepare_mem_tokens_optimization()
349
+
350
+ # Retrieval configuration
351
+ self.url_retrieval = "http://127.0.0.1:5004/queries"
352
+
353
+ def _create_decoder(self, cfg: CLaRaConfig) -> AutoModelForCausalLM:
354
+ """Create and configure the decoder model."""
355
+ if not torch.cuda.is_available():
356
+ return AutoModelForCausalLM.from_pretrained(
357
+ cfg.decoder_model_name,
358
+ torch_dtype=torch.bfloat16,
359
+ resume_download=True,
360
+ trust_remote_code=True,
361
+ device_map=cfg.device_map
362
+ )
363
+
364
+ if cfg.quantization == "no":
365
+ return AutoModelForCausalLM.from_pretrained(
366
+ cfg.decoder_model_name,
367
+ torch_dtype=torch.bfloat16,
368
+ attn_implementation=cfg.attn_implementation,
369
+ device_map=cfg.device_map
370
+ )
371
+ elif cfg.quantization == "int4":
372
+ quant_config = BitsAndBytesConfig(
373
+ load_in_4bit=True,
374
+ bnb_4bit_quant_type='nf4',
375
+ bnb_4bit_compute_dtype='bfloat16',
376
+ )
377
+ return AutoModelForCausalLM.from_pretrained(
378
+ cfg.decoder_model_name,
379
+ quantization_config=quant_config,
380
+ attn_implementation=cfg.attn_implementation,
381
+ torch_dtype=torch.bfloat16,
382
+ resume_download=True,
383
+ trust_remote_code=True,
384
+ device_map=cfg.device_map
385
+ )
386
+ elif cfg.quantization == "int8":
387
+ quant_config = BitsAndBytesConfig(
388
+ load_in_8bit=True,
389
+ llm_int8_enable_fp32_cpu_offload=True,
390
+ bnb_4bit_compute_dtype='bfloat16',
391
+ )
392
+ return AutoModelForCausalLM.from_pretrained(
393
+ cfg.decoder_model_name,
394
+ quantization_config=quant_config,
395
+ attn_implementation=cfg.attn_implementation,
396
+ torch_dtype=torch.bfloat16,
397
+ resume_download=True,
398
+ trust_remote_code=True,
399
+ device_map=cfg.device_map
400
+ )
401
+ else:
402
+ raise NotImplementedError(f"Quantization {cfg.quantization} not supported")
403
+
404
+ def _setup_lora_adapters(self, cfg: CLaRaConfig):
405
+ """Setup LoRA adapters based on training stage."""
406
+ peft_config = self._get_peft_config(lora_r=cfg.lora_r)
407
+
408
+ if cfg.training_stage == "stage1" and cfg.load_adapters:
409
+ print('Loading encoder and decoder adapter for stage1')
410
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
411
+ self.adapter_keys.append('decoder_adapter')
412
+ self.decoder.add_adapter(peft_config, 'encoder_adapter')
413
+ self.adapter_keys.append('encoder_adapter')
414
+ elif cfg.training_stage == "stage2" and cfg.load_adapters:
415
+ if 'decoder_adapter' not in self.adapter_keys:
416
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
417
+ self.adapter_keys.append('decoder_adapter')
418
+ if 'query_reasoner_adapter' not in self.adapter_keys:
419
+ self.decoder.add_adapter(peft_config, 'query_reasoner_adapter')
420
+ self.adapter_keys.append('query_reasoner_adapter')
421
+ elif cfg.training_stage == 'stage1_2':
422
+ if not cfg.load_adapters:
423
+ print('Loading decoder adapter for stage1_2')
424
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
425
+ self.adapter_keys.append('decoder_adapter')
426
+ elif cfg.load_adapters:
427
+ print('Loading encoder and decoder adapter for stage1_2')
428
+ self.decoder.add_adapter(peft_config, 'encoder_adapter')
429
+ self.adapter_keys.append('encoder_adapter')
430
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
431
+ self.adapter_keys.append('decoder_adapter')
432
+ elif cfg.training_stage == 'stage2_reasoning':
433
+ if not cfg.load_adapters:
434
+ print('Loading decoder adapter for stage2_reasoning')
435
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
436
+ self.adapter_keys.append('decoder_adapter')
437
+
438
+ def _setup_adapter_training(self):
439
+ """Setup adapters for training."""
440
+ for adapter_key in self.adapter_keys:
441
+ self.decoder.set_adapter(adapter_key)
442
+ print(f'Adapter {adapter_key} trainable parameters: {self.num_parameters(only_trainable=True)}')
443
+ self._set_all_adapters()
444
+
445
+ def _configure_generation_config(self):
446
+ """Configure generation parameters."""
447
+ self.decoder.generation_config.top_p = None
448
+ self.decoder.generation_config.temperature = None
449
+ self.decoder.generation_config.pad_token_id = self.decoder_tokenizer.pad_token_id
450
+
451
+ @staticmethod
452
+ def _create_decoder_tokenizer(cfg: CLaRaConfig) -> AutoTokenizer:
453
+ """Create and configure the decoder tokenizer."""
454
+ tokenizer = AutoTokenizer.from_pretrained(
455
+ cfg.decoder_model_name,
456
+ use_fast=True,
457
+ padding_side='left'
458
+ )
459
+
460
+ # Define special tokens
461
+ n_mem_tokens = cfg.doc_max_length // cfg.compr_rate
462
+ existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", [])
463
+
464
+ if cfg.different_mem_tokens:
465
+ mem_tokens = [f'<MEM{i}>' for i in range(n_mem_tokens)]
466
+ tokenizer.add_special_tokens({
467
+ 'additional_special_tokens': existing_special_tokens + mem_tokens + ['<AE>', '<ENC>', '<SEP>']
468
+ })
469
+ tokenizer.mem_tokens = mem_tokens
470
+ else:
471
+ tokenizer.add_special_tokens({
472
+ 'additional_special_tokens': existing_special_tokens + ['<MEM>', '<AE>', '<ENC>', '<SEP>']
473
+ })
474
+ tokenizer.mem_tokens = ['<MEM>'] * n_mem_tokens
475
+
476
+ tokenizer.mem_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokenizer.mem_tokens]
477
+ tokenizer.mem_token_ids_pt = torch.LongTensor(tokenizer.mem_token_ids)
478
+
479
+ # Additional special tokens
480
+ tokenizer.ae_token = '<AE>'
481
+ tokenizer.ae_token_id = tokenizer.convert_tokens_to_ids('<AE>')
482
+ tokenizer.enc_token = '<ENC>'
483
+ tokenizer.sep_token = '<SEP>'
484
+ tokenizer.sep_token_id = tokenizer.convert_tokens_to_ids('<SEP>')
485
+
486
+ # Handle model-specific tokens
487
+ if tokenizer.bos_token is None and 'qwen' in cfg.decoder_model_name.lower():
488
+ tokenizer.bos_token = tokenizer.special_tokens_map['additional_special_tokens'][0]
489
+ tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
490
+
491
+ if tokenizer.eos_token is None and "qwen" in cfg.decoder_model_name.lower():
492
+ tokenizer.eos_token = tokenizer.special_tokens_map['additional_special_tokens'][1]
493
+ tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
494
+
495
+ # KBTC training tokens
496
+ if cfg.kbtc_training:
497
+ tokenizer.add_special_tokens({'additional_special_tokens': ['<KBTC>']})
498
+ tokenizer.kbtc_token = '<KBTC>'
499
+ tokenizer.kbtc_token_id = tokenizer.convert_tokens_to_ids('<KBTC>')
500
+
501
+ # Set pad token
502
+ if tokenizer.pad_token_id is None:
503
+ tokenizer.pad_token_id = tokenizer.bos_token_id
504
+
505
+ print(f'Memory token count: {n_mem_tokens}')
506
+ return tokenizer
507
+
508
+ def _get_peft_config(self, lora_r: int) -> LoraConfig:
509
+ """Build the PEFT configuration."""
510
+ return LoraConfig(
511
+ task_type="CAUSAL_LM",
512
+ r=lora_r,
513
+ lora_alpha=2*lora_r,
514
+ target_modules='all-linear',
515
+ lora_dropout=0.1
516
+ )
517
+
518
+ def _prepare_mem_tokens_optimization(self):
519
+ """Setup memory token optimization if enabled."""
520
+ if self.config.optimize_mem_tokens and self.compr is None:
521
+ # Enable gradients for input embeddings
522
+ self.decoder.get_input_embeddings().weight.requires_grad = True
523
+
524
+ # Apply hook to zero gradients except for memory tokens
525
+ def hook(grad):
526
+ mask = torch.zeros_like(grad)
527
+ mask[self.decoder_tokenizer.mem_token_ids] = 1.0
528
+ return grad * mask
529
+
530
+ self.decoder.get_input_embeddings().weight.register_hook(hook)
531
+
532
+ def _set_all_adapters(self):
533
+ """Activate all adapters for training."""
534
+ if len(self.adapter_keys) > 0:
535
+ self.decoder.set_adapter(self.adapter_keys)
536
+
537
+ # Core compression and generation methods
538
+ def compress(self, enc_input_ids: torch.Tensor, enc_attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
539
+ """Compress input documents."""
540
+ if self.compr:
541
+ return self.compr(enc_input_ids, enc_attention_mask)
542
+ else:
543
+ return self._compr_decoder(enc_input_ids, enc_attention_mask)
544
+
545
+ def _compr_decoder(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
546
+ """Use decoder as compressor."""
547
+ assert input_ids.size() == attention_mask.size()
548
+
549
+ if 'encoder_adapter' in self.adapter_keys:
550
+ self.decoder.set_adapter('encoder_adapter')
551
+ else:
552
+ raise ValueError(f"encoder_adapter not in adapter_keys: {self.adapter_keys}")
553
+
554
+ # Get embeddings from decoder
555
+ emb = self.decoder(
556
+ input_ids=input_ids,
557
+ attention_mask=attention_mask,
558
+ output_hidden_states=True
559
+ ).hidden_states[-1]
560
+
561
+ # Create mask for memory tokens
562
+ mask = torch.isin(
563
+ input_ids,
564
+ self.decoder_tokenizer.mem_token_ids_pt.to(input_ids.device)
565
+ )
566
+
567
+ # Calculate MSE loss between memory and non-memory regions
568
+ attn = attention_mask.bool()
569
+ mem_mask = mask & attn
570
+ non_mem_mask = (~mask) & attn
571
+
572
+ mem_len = mem_mask.sum(dim=1)
573
+ non_mem_len = non_mem_mask.sum(dim=1)
574
+
575
+ if (mem_len == 0).any():
576
+ raise ValueError("Some samples have no memory tokens")
577
+ if (non_mem_len == 0).any():
578
+ raise ValueError("Some samples have no non-memory tokens")
579
+
580
+ mem_sum = (emb * mem_mask.unsqueeze(-1)).sum(dim=1)
581
+ non_mem_sum = (emb * non_mem_mask.unsqueeze(-1)).sum(dim=1)
582
+
583
+ mem_mean = mem_sum / mem_len.unsqueeze(-1)
584
+ non_mem_mean = non_mem_sum / non_mem_len.unsqueeze(-1)
585
+
586
+ mse_loss = F.mse_loss(non_mem_mean, mem_mean, reduction='mean')
587
+
588
+ return emb[mask].reshape(emb.size(0), -1, emb.size(-1)), mse_loss
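# Shape note for the compressor above, with illustrative numbers: with doc_max_length 256
# and compr_rate 16, each document sequence carries 16 memory tokens, so for B questions
# with top_k documents each the returned embeddings have shape [B * top_k, 16, hidden_size],
# and mse_loss is a scalar aligning memory-token and ordinary-token mean hidden states.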
589
+
590
+ def _compr_query_reasoner_stage2(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
591
+ """Query reasoning compression for stage 2."""
592
+ assert input_ids.size() == attention_mask.size()
593
+
594
+ if 'query_reasoner_adapter' in self.adapter_keys:
595
+ self.decoder.set_adapter('query_reasoner_adapter')
596
+ else:
597
+ raise ValueError(f"query_reasoner_adapter not in adapter_keys: {self.adapter_keys}")
598
+
599
+ emb = self.decoder(
600
+ input_ids=input_ids,
601
+ attention_mask=attention_mask,
602
+ output_hidden_states=True
603
+ ).hidden_states[-1]
604
+
605
+ mask = torch.isin(
606
+ input_ids,
607
+ self.decoder_tokenizer.mem_token_ids_pt.to(input_ids.device)
608
+ )
609
+
610
+ return emb[mask].reshape(emb.size(0), -1)
611
+
612
+ # Generation methods
613
+ def generate_from_questions(self,
614
+ questions: List[str],
615
+ max_new_tokens: int = 128,
616
+ temperature: float = 0.5,
617
+ documents: List[List[str]] = None,
618
+ stage2_mips: bool = False,
619
+ stage2_retrieval_top_n: int = None,
620
+ time_count: bool = False) -> Tuple[List[str], torch.Tensor]:
621
+ """Generate answers from questions using query reasoning."""
622
+ if "query_reasoner_adapter" not in self.adapter_keys:
623
+ raise ValueError("Query reasoner adapter not found")
624
+
625
+ self.eval()
626
+
627
+ with torch.no_grad():
628
+ # Encode questions
629
+ self.decoder.set_adapter('query_reasoner_adapter')
630
+ flat_questions = [q for q in questions]
631
+
632
+ if time_count:
633
+ start_time = time.time()
634
+
635
+ q_tok = self._prepare_encoder_inputs(flat_questions, max_length=self.doc_max_length)
636
+ query_reps = self._compr_query_reasoner_stage2(
637
+ q_tok["input_ids"].to(self.decoder.device),
638
+ q_tok["attention_mask"].to(self.decoder.device)
639
+ )
640
+
641
+ # Document retrieval and selection
642
+ if stage2_mips:
643
+ retrieved_doc_embeddings = self._retrieve_embeddings(
644
+ query_reps, stage2_retrieval_top_n=stage2_retrieval_top_n
645
+ )
646
+ scores = torch.bmm(
647
+ query_reps.unsqueeze(1),
648
+ retrieved_doc_embeddings.transpose(1, 2)
649
+ ).squeeze(1)
650
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.5)
651
+ selected_doc_embeddings = torch.einsum('bkn,bnd->bkd', z, retrieved_doc_embeddings)
652
+ selected_doc_embeddings = selected_doc_embeddings.view(
653
+ selected_doc_embeddings.size(0) * selected_doc_embeddings.size(1),
654
+ -1, self.hidden_size
655
+ )
656
+ else:
657
+ # Use provided documents
658
+ flat_documents = sum(documents, [])
659
+
660
+ if time_count:
661
+ start_time1 = time.time()
662
+
663
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
664
+ device = self.decoder.device
665
+ enc_input_ids = input_encoder['input_ids'].to(device)
666
+ enc_attention_mask = input_encoder['attention_mask'].to(device)
667
+ retrieved_doc_embeddings, _ = self.compress(enc_input_ids, enc_attention_mask)
668
+
669
+ if time_count:
670
+ start_time2 = time.time()
671
+ compress_time = start_time2 - start_time1
672
+
673
+ B = len(questions)
674
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
675
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
676
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
677
+
678
+ if time_count:
679
+ start_time3 = time.time()
680
+
681
+ scores = torch.bmm(
682
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
683
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
684
+ ).squeeze(1)
685
+
686
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.02)
687
+ selected_doc_embeddings = torch.einsum('bkn,bnd->bkd', z.to(retrieved_doc_embeddings.dtype), retrieved_doc_embeddings)
688
+ selected_doc_embeddings = selected_doc_embeddings.view(
689
+ selected_doc_embeddings.size(0) * selected_doc_embeddings.size(1),
690
+ -1, self.hidden_size
691
+ )
692
+
693
+ if time_count:
694
+ start_time4 = time.time()
695
+ query_time = start_time4 - start_time3 + start_time1 - start_time
696
+
697
+ # Generate instructions and decode
698
+ if time_count:
699
+ start_time5 = time.time()
700
+
701
+ instructions = [
702
+ self._blend_prompt_and_selected_memory_tokens(query=q)[1]
703
+ for q in questions
704
+ ]
705
+
706
+ decoder_inputs = self.decoder_tokenizer(
707
+ instructions,
708
+ return_tensors='pt',
709
+ padding="longest",
710
+ add_special_tokens=False,
711
+ truncation=True,
712
+ max_length=1024,
713
+ )
714
+
715
+ dec_input_ids = decoder_inputs['input_ids'].to(self.decoder.device)
716
+ dec_attention_mask = decoder_inputs['attention_mask'].to(self.decoder.device)
717
+
718
+ # Replace memory token embeddings
719
+ inputs_embeds = self._replace_emb_stage2(selected_doc_embeddings, dec_input_ids)
720
+
721
+ # Switch to decoder adapter for generation
722
+ if 'decoder_adapter' in self.adapter_keys:
723
+ self.decoder.set_adapter('decoder_adapter')
724
+
725
+ # Generate answers
726
+ output_ids = self.decoder.generate(
727
+ inputs_embeds=inputs_embeds,
728
+ attention_mask=dec_attention_mask,
729
+ do_sample=False,
730
+ top_p=None,
731
+ temperature=None,
732
+ max_new_tokens=max_new_tokens,
733
+ pad_token_id=self.decoder_tokenizer.pad_token_id
734
+ )
735
+
736
+ if time_count:
737
+ start_time6 = time.time()
738
+ generate_time = start_time6 - start_time5
739
+
740
+ # Decode generated tokens
741
+ decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
742
+
743
+ if time_count:
744
+ return decoded, topk_idx, compress_time, query_time, generate_time, compress_time + query_time + generate_time
745
+ else:
746
+ return decoded, topk_idx
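The soft document selection above goes through `differentiable_topk(scores, k, temperature)`, which is defined elsewhere in this repository. As a rough illustration of the idea (a common iterative-softmax relaxation, assumed here rather than copied from the repo), the helper below returns a soft selection matrix `z` of shape `[B, k, N]` plus hard indices, matching how `z` is consumed by the `einsum` calls above:

```python
import torch
import torch.nn.functional as F

def differentiable_topk_sketch(scores: torch.Tensor, k: int, temperature: float = 1.0):
    """Hypothetical relaxation: k softmax rounds, masking out previous picks.

    scores: [B, N] similarity scores.
    Returns z: [B, k, N] soft one-hot rows, topk_idx: [B, k] hard indices.
    """
    masked = scores.clone()
    rows, idxs = [], []
    for _ in range(k):
        probs = F.softmax(masked / temperature, dim=-1)              # soft pick over remaining docs
        idx = probs.argmax(dim=-1)                                   # hard pick for bookkeeping
        rows.append(probs)
        idxs.append(idx)
        masked = masked.scatter(1, idx.unsqueeze(1), float("-inf"))  # exclude the picked doc
    return torch.stack(rows, dim=1), torch.stack(idxs, dim=1)

# Toy check: z mixes document embeddings exactly as in the model code above.
scores = torch.randn(2, 5)                        # B=2 queries, N=5 candidate docs
docs = torch.randn(2, 5, 8)                       # per-doc embeddings of size 8
z, topk_idx = differentiable_topk_sketch(scores, k=2, temperature=0.02)
selected = torch.einsum("bkn,bnd->bkd", z, docs)  # [2, 2, 8]
print(selected.shape, topk_idx)
```

At low temperature the soft rows approach one-hot vectors, so the weighted mixture degenerates to picking the top-k compressed documents while keeping the selection differentiable.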
747
+ def generate_from_paraphrase(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
748
+ """
749
+ Generate answers from documents via compression then decoding.
750
+ questions: list of strings
751
+ documents: list of lists of strings; all inner lists must be the same length (the number of documents per question)
752
+ """
753
+ self.generation_top_k = len(documents[0])
754
+ assert len(documents) == len(questions)
755
+ assert all([len(context) == len(documents[0]) for context in documents])
756
+ flat_documents = sum(documents, [])
757
+
758
+ model_input = {}
759
+
760
+ # Creating encoder inputs:
761
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
762
+ device = self.decoder.device
763
+ model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
764
+
765
+ # Creating decoder inputs
766
+ instr = [self._blend_prompt_and_memory_tokens(query="", stage="stage1", paraphrase_loss=True) for q in questions]
767
+ inp_dec = self.decoder_tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=1024)
768
+ model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
769
+
770
+ # Generation
771
+ return self._generate(model_input, max_new_tokens=max_new_tokens)
772
+
773
+
774
+ def generate_from_text(self,
775
+ questions: List[str],
776
+ documents: List[List[str]],
777
+ max_new_tokens: int = 128) -> List[str]:
778
+ """Generate answers from documents via compression then decoding."""
779
+ self.generation_top_k = len(documents[0])
780
+ assert len(documents) == len(questions)
781
+ assert all(len(context) == len(documents[0]) for context in documents)
782
+
783
+ flat_documents = sum(documents, [])
784
+
785
+ # Create encoder inputs
786
+ input_encoder = self._prepare_encoder_inputs(flat_documents, max_length=self.doc_max_length)
787
+ device = self.decoder.device
788
+ enc_input_ids = input_encoder['input_ids'].to(device)
789
+ enc_attention_mask = input_encoder['attention_mask'].to(device)
790
+
791
+ # Create decoder inputs
792
+ instructions = [self._blend_prompt_and_memory_tokens(query=q, stage="stage1_2") for q in questions]
793
+ inp_dec = self.decoder_tokenizer(
794
+ instructions,
795
+ return_tensors='pt',
796
+ padding="longest",
797
+ add_special_tokens=False,
798
+ truncation=True,
799
+ max_length=1024
800
+ )
801
+ dec_input_ids = inp_dec['input_ids'].to(device)
802
+ dec_attention_mask = inp_dec['attention_mask'].to(device)
803
+
804
+ # Generate
805
+ return self._generate({
806
+ 'enc_input_ids': enc_input_ids,
807
+ 'enc_attention_mask': enc_attention_mask,
808
+ 'dec_input_ids': dec_input_ids,
809
+ 'dec_attention_mask': dec_attention_mask
810
+ }, max_new_tokens=max_new_tokens)
811
+
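For completeness, a minimal sketch of calling `generate_from_text` directly, with the caller supplying the candidate documents (no retrieval step). The checkpoint path is a placeholder; loading follows the standard `trust_remote_code` pattern:

```python
from transformers import AutoModel

# Placeholder path: point this at a CLaRa-7B-E2E checkpoint directory.
model = AutoModel.from_pretrained(
    "path/to/CLaRa-7B-E2E/compression-16", trust_remote_code=True
).to("cuda")

questions = ["Who wrote 'The Old Man and the Sea'?"]
documents = [[
    "Ernest Hemingway wrote 'The Old Man and the Sea' in 1951.",
    "The novella won the Pulitzer Prize for Fiction in 1953.",
]]  # inner lists must all have the same length; it sets generation_top_k

answers = model.generate_from_text(questions=questions, documents=documents, max_new_tokens=64)
print(answers)
```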
812
+ def generate_from_compressed_documents_and_questions(self,
813
+ questions: List[str],
814
+ compressed_documents: torch.Tensor,
815
+ max_new_tokens: int = 128) -> List[str]:
816
+ """Generate answers from compressed documents."""
817
+ self.generation_top_k = compressed_documents.size(0) // len(questions)
818
+ assert compressed_documents.size(0) % self.generation_top_k == 0
819
+
820
+ # Create decoder inputs
821
+ instructions = [self._blend_prompt_and_memory_tokens(query=q, stage="stage1_2") for q in questions]
822
+ inp_dec = self.decoder_tokenizer(
823
+ instructions,
824
+ return_tensors='pt',
825
+ padding="longest",
826
+ add_special_tokens=False,
827
+ truncation=True,
828
+ max_length=1024
829
+ )
830
+ device = self.decoder.device
831
+ dec_input_ids = inp_dec['input_ids'].to(device)
832
+ dec_attention_mask = inp_dec['attention_mask'].to(device)
833
+
834
+ # Create input decoder embeddings from prompt + compressed documents
835
+ inputs_embeds = self._replace_emb(compressed_documents, dec_input_ids)
836
+
837
+ # Activate decoder generator
838
+ if 'decoder_adapter' in self.adapter_keys:
839
+ self.decoder.set_adapter('decoder_adapter')
840
+
841
+ output_ids = self.decoder.generate(
842
+ inputs_embeds=inputs_embeds,
843
+ attention_mask=dec_attention_mask,
844
+ max_new_tokens=max_new_tokens
845
+ )
846
+
847
+ return self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
848
+
849
+ def compress_documents(self, documents: List[str]) -> torch.Tensor:
850
+ """Compress a list of documents."""
851
+ input_encoder = self._prepare_encoder_inputs(documents, max_length=self.doc_max_length)
852
+ enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
853
+ attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
854
+ return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
855
+
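`compress_documents` pairs naturally with `generate_from_compressed_documents_and_questions`: a corpus can be compressed once and the resulting embeddings reused across questions. Since `compress_documents` forwards the `(embeddings, mse_loss)` pair returned by `compress`, the embeddings are its first element. A short sketch, reusing a loaded `model` as above:

```python
docs = [
    "First background passage ...",
    "Second background passage ...",
]
compressed, _ = model.compress_documents(docs)  # compress() returns (embeddings, mse_loss)

answers = model.generate_from_compressed_documents_and_questions(
    questions=["A question about the passages?"],
    compressed_documents=compressed,
    max_new_tokens=64,
)
print(answers)
```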
856
+ # Helper methods
857
+ def _prepare_encoder_inputs(self, texts: List[str], max_length: int, q_texts: List[str] = None) -> Dict[str, torch.Tensor]:
858
+ """Create inputs for the encoder."""
859
+ if q_texts is not None:
860
+ assert len(texts) == len(q_texts)
861
+
862
+ if self.compr is None:
863
+ return self._prepare_encoder_inputs_to_decoder(texts, max_length, q_texts)
864
+ else:
865
+ return self.compr.prepare_inputs(texts, max_length, q_texts)
866
+
867
+ def _prepare_encoder_inputs_to_decoder(self, texts: List[str], max_length: int, q_texts: List[str] = None) -> Dict[str, torch.Tensor]:
868
+ """Prepare encoder inputs when using decoder as compressor."""
869
+ if q_texts is not None:
870
+ texts_to_encode = [
871
+ self.decoder_tokenizer.enc_token +
872
+ self.decoder_tokenizer.bos_token +
873
+ '\nQuery:\n' + query +
874
+ 'Document:\n' + text +
875
+ self.decoder_tokenizer.eos_token
876
+ for text, query in zip(texts, q_texts)
877
+ ]
878
+ inp_enc = self.decoder_tokenizer(
879
+ texts_to_encode,
880
+ return_tensors='pt',
881
+ padding='max_length',
882
+ max_length=max_length + 8,
883
+ truncation=True,
884
+ add_special_tokens=False
885
+ )
886
+ else:
887
+ inp_enc = [
888
+ self.decoder_tokenizer.enc_token +
889
+ self.decoder_tokenizer.bos_token +
890
+ text +
891
+ self.decoder_tokenizer.eos_token
892
+ for text in texts
893
+ ]
894
+ inp_enc = self.decoder_tokenizer(
895
+ inp_enc,
896
+ return_tensors='pt',
897
+ padding="max_length",
898
+ max_length=max_length + 3,
899
+ truncation=True,
900
+ add_special_tokens=False
901
+ )
902
+
903
+ num_mem_tokens = self.doc_max_length // self.compr_rate
904
+ assert num_mem_tokens == len(self.decoder_tokenizer.mem_tokens)
905
+
906
+ inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(
907
+ inp_enc['input_ids'],
908
+ inp_enc['attention_mask'],
909
+ num_mem_tokens,
910
+ tokenizer=self.decoder_tokenizer
911
+ )
912
+
913
+ return inp_enc
914
+
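The number of memory tokens appended per document is fixed by the compression rate: `num_mem_tokens = doc_max_length // compr_rate`, and the assertion above requires the tokenizer to carry exactly that many `mem_tokens`. A quick illustration with a hypothetical `doc_max_length` of 128:

```python
doc_max_length = 128  # illustrative value; the real one comes from the model config
for compr_rate in (16, 128):
    print(compr_rate, doc_max_length // compr_rate)  # 16x -> 8 memory tokens, 128x -> 1
```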
915
+ def _replace_emb(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor) -> torch.Tensor:
916
+ """Replace memory tokens in decoder input with compressed embeddings."""
917
+ indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
918
+ return self._replace_embeddings(compressed_embs, dec_input_ids, indices)
919
+
920
+ def _replace_emb_stage2(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor) -> torch.Tensor:
921
+ """Replace memory tokens for stage 2."""
922
+ indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
923
+ return self._replace_embeddings(compressed_embs, dec_input_ids, indices)
924
+
925
+ def _replace_embeddings(self, compressed_embs: torch.Tensor, dec_input_ids: torch.Tensor, indices: range) -> torch.Tensor:
926
+ """Replace memory tokens with compressed embeddings."""
927
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
928
+ num_embs = compressed_embs.size(1)
929
+ slot_len = num_embs + (1 if self.sep else 0)
930
+
931
+ # Get first memory token indices
932
+ first_mem_token_indices = torch.argmax(
933
+ (dec_input_ids == self.decoder_tokenizer.mem_token_ids[0]).int(), dim=1
934
+ )
935
+ batch_size = inputs_embeds.size(0)
936
+
937
+ # Replace with compressed embeddings
938
+ for i in range(batch_size):
939
+ for j in range(indices[i], indices[i + 1]):
940
+ start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
941
+ assert inputs_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size()
942
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
943
+
944
+ return inputs_embeds
945
+
946
+ def _retrieve_embeddings(self, questions: torch.Tensor, stage2_retrieval_top_n: int = 1) -> torch.Tensor:
947
+ """Retrieve embeddings of documents."""
948
+ response = requests.post(
949
+ self.url_retrieval,
950
+ json={
951
+ "queries": questions.detach().cpu().float().numpy().tolist(),
952
+ 'k': self.generation_top_k
953
+ }
954
+ )
955
+
956
+ if response.status_code != 200:
957
+ raise Exception(f"Error: {response.status_code} - {response.text}")
958
+
959
+ results = response.json()
960
+ retrieval_embeddings = results['retrieved_embeddings']
961
+ retrieval_embeddings = torch.tensor(
962
+ retrieval_embeddings,
963
+ dtype=torch.bfloat16,
964
+ device=questions.device
965
+ )
966
+
967
+ if len(retrieval_embeddings.shape) == 4:
968
+ retrieval_embeddings = retrieval_embeddings.reshape(
969
+ retrieval_embeddings.shape[0] * retrieval_embeddings.shape[1],
970
+ retrieval_embeddings.shape[2], -1
971
+ )
972
+
973
+ return retrieval_embeddings
974
+
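`_retrieve_embeddings` treats retrieval as an external HTTP service: it POSTs `{"queries": [...], "k": ...}` to `self.url_retrieval` and expects a JSON response containing `retrieved_embeddings`. The stub below is purely illustrative (the real index service is not part of this repo; the route, port, and embedding dimension are assumptions) but satisfies that contract with random embeddings:

```python
# Hypothetical stand-in for the retrieval endpoint expected by _retrieve_embeddings.
import numpy as np
from flask import Flask, request, jsonify

app = Flask(__name__)
EMB_DIM = 4096  # must match the flattened compressed-document dimension; illustrative here

@app.route("/retrieve", methods=["POST"])
def retrieve():
    body = request.get_json()
    queries = np.asarray(body["queries"], dtype=np.float32)  # [B, EMB_DIM]
    k = int(body["k"])
    # A real server would run MIPS over a compressed-document index here.
    embs = np.random.randn(queries.shape[0], k, EMB_DIM).astype(np.float32)
    return jsonify({"retrieved_embeddings": embs.tolist()})

if __name__ == "__main__":
    app.run(port=8000)
```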
975
+ def _blend_prompt_and_memory_tokens(self, query: str, answer: str = None, qa_loss: bool = False,
976
+ paraphrase_loss: bool = False, stage: str = "stage1") -> Tuple[int, str]:
977
+ """Blend prompt with memory tokens for different training stages."""
978
+ mem_tokens_str = ''.join(self.decoder_tokenizer.mem_tokens) + self.decoder_tokenizer.sep_token
979
+ docs = mem_tokens_str * self.generation_top_k
980
+
981
+ if stage == "stage1":
982
+ if qa_loss:
983
+ return self._blend_qa_prompt(docs, query, answer)
984
+ elif paraphrase_loss:
985
+ return self._blend_paraphrase_prompt(docs, answer)
986
+ elif stage == "stage1_2":
987
+ return self._blend_standard_prompt(docs, query, answer)
988
+
989
+ raise ValueError(f"Unknown stage: {stage}")
990
+
991
+ def _blend_qa_prompt(self, docs: str, query: List[str], answer: List[str]) -> Tuple[int, str]:
992
+ """Create QA prompt for stage 1."""
993
+ prompt_system = 'You are a helpful assistant. Given a document, your task is to generate some single questions to cover all key information of the document and answer them sequentially.'
994
+ prompt_user = f"Background:\n{docs}"
995
+
996
+ sys_prompt = [{"role": "system", "content": prompt_system}]
997
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
998
+
999
+ qa_lines = [f"Question: {q}\nAnswer: {a}" for q, a in zip(query, answer)]
1000
+ query_answer = "\n".join(qa_lines)
1001
+ assistant_prompt = [{"role": "assistant", "content": query_answer}]
1002
+
1003
+ try:
1004
+ prompt = self.decoder_tokenizer.apply_chat_template(
1005
+ sys_prompt + user_prompt,
1006
+ tokenize=False,
1007
+ add_generation_prompt=True,
1008
+ enable_thinking=False
1009
+ )
1010
+ response = self.decoder_tokenizer.apply_chat_template(
1011
+ sys_prompt + user_prompt + assistant_prompt,
1012
+ tokenize=False,
1013
+ add_generation_prompt=False,
1014
+ enable_thinking=False
1015
+ )
1016
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1017
+ except TemplateError as e:
1018
+ if "System role not supported" in str(e):
1019
+ messages = [{"role": "user", "content": sys_prompt[0]['content'] + '\n' + user_prompt[0]['content']}]
1020
+ prompt = self.decoder_tokenizer.apply_chat_template(
1021
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1022
+ )
1023
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1024
+ # Handle response for unsupported system role
1025
+ messages_with_answer = messages + assistant_prompt
1026
+ response = self.decoder_tokenizer.apply_chat_template(
1027
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1028
+ )
1029
+ else:
1030
+ raise e
1031
+
1032
+ return prompt_len, response
1033
+
1034
+ def _blend_paraphrase_prompt(self, docs: str, answer: str) -> Tuple[int, str]:
1035
+ """Create paraphrase prompt for stage 1."""
1036
+ prompt_system = 'You are a helpful assistant. Your task is follow the instructions to paraphrase the background information.'
1037
+ prompt_user = random.choice(PARAPHRASE_INSTRUCTIONS).format(docs=docs)
1038
+
1039
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1040
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1041
+
1042
+ try:
1043
+ prompt = self.decoder_tokenizer.apply_chat_template(
1044
+ sys_prompt + user_prompt,
1045
+ tokenize=False,
1046
+ add_generation_prompt=True,
1047
+ enable_thinking=False
1048
+ )
1049
+ if answer is None:
1050
+ return prompt
1051
+
1052
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1053
+ response = self.decoder_tokenizer.apply_chat_template(
1054
+ sys_prompt + user_prompt + assistant_prompt,
1055
+ tokenize=False,
1056
+ add_generation_prompt=False,
1057
+ enable_thinking=False
1058
+ )
1059
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1060
+ except TemplateError as e:
1061
+ if "System role not supported" in str(e):
1062
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1063
+ messages = [{"role": "user", "content": combined_content}]
1064
+ prompt = self.decoder_tokenizer.apply_chat_template(
1065
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1066
+ )
1067
+ if answer is None:
1068
+ return prompt
1069
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1070
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1071
+ response = self.decoder_tokenizer.apply_chat_template(
1072
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1073
+ )
1074
+ else:
1075
+ raise e
1076
+
1077
+ return prompt_len, response
1078
+
1079
+ def _blend_standard_prompt(self, docs: str, query: str, answer: str) -> Tuple[int, str]:
1080
+ """Create standard prompt for stage 1_2."""
1081
+ prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
1082
+ prompt_user = f"Background:\n{docs}\n\nQuestion:{query}"
1083
+
1084
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1085
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1086
+
1087
+ try:
1088
+ prompt = self.decoder_tokenizer.apply_chat_template(
1089
+ sys_prompt + user_prompt,
1090
+ tokenize=False,
1091
+ add_generation_prompt=True,
1092
+ enable_thinking=False
1093
+ )
1094
+ if answer is None:
1095
+ return prompt
1096
+
1097
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1098
+ response = self.decoder_tokenizer.apply_chat_template(
1099
+ sys_prompt + user_prompt + assistant_prompt,
1100
+ tokenize=False,
1101
+ add_generation_prompt=False,
1102
+ enable_thinking=False
1103
+ )
1104
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1105
+ except TemplateError as e:
1106
+ if "System role not supported" in str(e):
1107
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1108
+ messages = [{"role": "user", "content": combined_content}]
1109
+ prompt = self.decoder_tokenizer.apply_chat_template(
1110
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
1111
+ )
1112
+ if answer is None:
1113
+ return prompt
1114
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1115
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1116
+ response = self.decoder_tokenizer.apply_chat_template(
1117
+ messages_with_answer, tokenize=False, add_generation_prompt=False, enable_thinking=False
1118
+ )
1119
+ else:
1120
+ raise e
1121
+
1122
+ return prompt_len, response
1123
+
1124
+ def _blend_prompt_and_selected_memory_tokens(self, query: str, answer: str = None) -> Tuple[int, str]:
1125
+ """Create prompt for stage 2 with selected memory tokens."""
1126
+ mem_tokens_str = ''.join(self.decoder_tokenizer.mem_tokens) + self.decoder_tokenizer.sep_token
1127
+ docs = mem_tokens_str * self.generation_top_k
1128
+
1129
+ prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
1130
+ prompt_user = f"Background:\n{docs}\n\nQuestion:{query}"
1131
+
1132
+ sys_prompt = [{"role": "system", "content": prompt_system}]
1133
+ user_prompt = [{"role": "user", "content": prompt_user.replace(':\ ', ': ')}]
1134
+
1135
+ try:
1136
+ prompt = self.decoder_tokenizer.apply_chat_template(
1137
+ sys_prompt + user_prompt,
1138
+ tokenize=False,
1139
+ add_generation_prompt=True,
1140
+ enable_thinking=False
1141
+ )
1142
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1143
+
1144
+ if answer is not None:
1145
+ assistant_prompt = [{"role": "assistant", "content": answer}]
1146
+ response = self.decoder_tokenizer.apply_chat_template(
1147
+ sys_prompt + user_prompt + assistant_prompt,
1148
+ tokenize=False,
1149
+ add_generation_prompt=False,
1150
+ enable_thinking=False
1151
+ )
1152
+ else:
1153
+ response = prompt
1154
+
1155
+ except TemplateError as e:
1156
+ if "System role not supported" in str(e):
1157
+ combined_content = prompt_system + '\n' + prompt_user.replace(':\ ', ': ')
1158
+ messages = [{"role": "user", "content": combined_content}]
1159
+
1160
+ prompt = self.decoder_tokenizer.apply_chat_template(
1161
+ messages,
1162
+ tokenize=False,
1163
+ add_generation_prompt=True,
1164
+ enable_thinking=False
1165
+ )
1166
+ prompt_len = len(self.decoder_tokenizer.encode(prompt, add_special_tokens=False))
1167
+
1168
+ if answer is not None:
1169
+ messages_with_answer = messages + [{"role": "assistant", "content": answer}]
1170
+ response = self.decoder_tokenizer.apply_chat_template(
1171
+ messages_with_answer,
1172
+ tokenize=False,
1173
+ add_generation_prompt=False,
1174
+ enable_thinking=False
1175
+ )
1176
+ else:
1177
+ response = prompt
1178
+ else:
1179
+ raise e
1180
+
1181
+ return prompt_len, response
1182
+
1183
+ # Model saving and loading methods
1184
+ def save_pretrained(self, save_directory: str, **kwargs):
1185
+ """Save only the LoRA adapters and their configurations."""
1186
+ if self.lora:
1187
+ if not os.path.exists(save_directory):
1188
+ os.makedirs(save_directory)
1189
+
1190
+ # Save LoRA adapter weights
1191
+ torch.save(
1192
+ self._get_all_adapters_state_dict(),
1193
+ os.path.join(save_directory, "adapters.pth")
1194
+ )
1195
+
1196
+ # Save first and last layers of decoder
1197
+ torch.save(
1198
+ self._get_decoder_first_and_last_layer_state_dict(),
1199
+ os.path.join(save_directory, "decoder_first_last_layers.pth")
1200
+ )
1201
+
1202
+ # Save configuration
1203
+ self.config.save_pretrained(save_directory)
1204
+ else:
1205
+ super().save_pretrained(save_directory, **kwargs)
1206
+
1207
+ def _get_all_adapters_state_dict(self) -> Dict[str, Dict[str, torch.Tensor]]:
1208
+ """Return the state dicts of all adapters."""
1209
+ return {
1210
+ key: {k: v.cpu() for k, v in self.decoder.get_adapter_state_dict(key).items()}
1211
+ for key in self.adapter_keys
1212
+ }
1213
+
1214
+ def _get_decoder_first_and_last_layer_state_dict(self) -> Dict[str, torch.Tensor]:
1215
+ """Get first and last layers that change when adding tokens."""
1216
+ out = {}
1217
+ for k, v in self.decoder.named_parameters():
1218
+ if 'lm_head.weight' in k or 'embed_tokens.weight' in k:
1219
+ out[k] = v.cpu()
1220
+ return out
1221
+
1222
+ @classmethod
1223
+ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
1224
+ """Load model from pretrained checkpoint."""
1225
+ # Load configuration
1226
+ config = CLaRaConfig.from_pretrained(pretrained_model_name_or_path)
1227
+
1228
+ # Update config with kwargs
1229
+ for key, value in kwargs.items():
1230
+ if hasattr(config, key):
1231
+ setattr(config, key, value)
1232
+
1233
+ map_location = torch.device("cpu") if not torch.cuda.is_available() else None
1234
+
1235
+ if config.lora:
1236
+ # Delay adapter construction
1237
+ config.load_adapters = False
1238
+ if 'device_map' in kwargs:
1239
+ config.device_map = kwargs['device_map']
1240
+
1241
+ # Initialize model
1242
+ print(f"Initializing model from trained checkpoint: {config}")
1243
+ model = cls(config)
1244
+
1245
+ # Load first and last layers
1246
+ try:
1247
+ first_and_last_layers_path = hf_hub_download(
1248
+ repo_id=pretrained_model_name_or_path,
1249
+ filename="decoder_first_last_layers.pth"
1250
+ )
1251
+ except Exception:
1252
+ first_and_last_layers_path = os.path.join(
1253
+ pretrained_model_name_or_path, "decoder_first_last_layers.pth"
1254
+ )
1255
+
1256
+ if os.path.exists(first_and_last_layers_path):
1257
+ first_and_last_decoder_state_dict = torch.load(
1258
+ first_and_last_layers_path, map_location=map_location, weights_only=True
1259
+ )
1260
+ for key in first_and_last_decoder_state_dict:
1261
+ assert key in model.decoder.state_dict()
1262
+ model.decoder.load_state_dict(first_and_last_decoder_state_dict, strict=False)
1263
+ else:
1264
+ print(f'First and last layers not found: {first_and_last_layers_path}')
1265
+
1266
+ peft_config = model._get_peft_config(lora_r=config.lora_r)
1267
+
1268
+ # Load LoRA adapters
1269
+ try:
1270
+ adapters_path = hf_hub_download(
1271
+ repo_id=pretrained_model_name_or_path,
1272
+ filename="adapters.pth"
1273
+ )
1274
+ except Exception:
1275
+ adapters_path = os.path.join(pretrained_model_name_or_path, "adapters.pth")
1276
+
1277
+ if os.path.exists(adapters_path):
1278
+ adapters_state_dict = torch.load(adapters_path, map_location=map_location, weights_only=True)
1279
+ model._load_adapters_from_state_dict(adapters_state_dict, peft_config, config)
1280
+ else:
1281
+ warnings.warn(f'Adapters not found at {adapters_path}')
1282
+
1283
+ model._set_all_adapters()
1284
+ config.load_adapters = True
1285
+ return model
1286
+ else:
1287
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
1288
+ def _load_adapters_from_state_dict(self, adapters_state_dict: Dict, peft_config: LoraConfig, config: CLaRaConfig):
1289
+ """Load adapters from state dict based on training stage."""
1290
+ if not getattr(config, 'pure_inference', False):
1291
+ for key, val in adapters_state_dict.items():
1292
+ # Skip certain adapters based on training stage
1293
+ if config.training_stage == 'stage1' and key == 'query_reasoner_adapter':
1294
+ continue
1295
+ elif config.training_stage == 'stage1_2' and key in ['query_reasoner_adapter', 'decoder_adapter']:
1296
+ continue
1297
+ elif config.training_stage == 'stage2_reasoning' and key == 'decoder_adapter':
1298
+ continue
1299
+
1300
+ self._load_adapter_from_state_dict(
1301
+ peft_config=peft_config,
1302
+ adapter_name=key,
1303
+ adapter_state_dict=val
1304
+ )
1305
+ else:
1306
+ # Load all adapters for pure inference
1307
+ for key, val in adapters_state_dict.items():
1308
+ self._load_adapter_from_state_dict(
1309
+ peft_config=peft_config,
1310
+ adapter_name=key,
1311
+ adapter_state_dict=val
1312
+ )
1313
+
1314
+ # Handle special cases for stage 2 training
1315
+ if config.training_stage == 'stage2' and 'query_reasoner_adapter' not in adapters_state_dict:
1316
+ self._handle_query_reasoner_adapter_loading(adapters_state_dict, peft_config)
1317
+
1318
+ def _load_adapter_from_state_dict(self, peft_config: LoraConfig, adapter_name: str, adapter_state_dict: Dict):
1319
+ """Create adapter from state dict."""
1320
+ print(f'Loading checkpoint adapter: {adapter_name}')
1321
+ self.decoder.load_adapter(
1322
+ peft_config=peft_config,
1323
+ adapter_name=adapter_name,
1324
+ adapter_state_dict=adapter_state_dict
1325
+ )
1326
+ self.adapter_keys.append(adapter_name)
1327
+
1328
+ def _handle_query_reasoner_adapter_loading(self, adapters_state_dict: Dict, peft_config: LoraConfig):
1329
+ """Handle special loading logic for query reasoner adapter."""
1330
+ if 'encoder_adapter' in adapters_state_dict and 'query_reasoner_adapter' not in adapters_state_dict:
1331
+ # Rename encoder adapter to query reasoner adapter
1332
+ renamed = {}
1333
+ for k, v in adapters_state_dict['encoder_adapter'].items():
1334
+ new_k = k.replace('encoder_adapter', 'query_reasoner_adapter')
1335
+ renamed[new_k] = v.detach().clone()
1336
+
1337
+ self._load_adapter_from_state_dict(
1338
+ peft_config=peft_config,
1339
+ adapter_name='query_reasoner_adapter',
1340
+ adapter_state_dict=renamed
1341
+ )
1342
+ print('Loaded query_reasoner_adapter from stage 1 compressor checkpoint')
1343
+ else:
1344
+ # Create new adapter randomly
1345
+ self.decoder.add_adapter(peft_config, 'query_reasoner_adapter')
1346
+ self.adapter_keys.append('query_reasoner_adapter')
1347
+ print('Initialized query_reasoner_adapter randomly for stage 2 training')
1348
+
1349
+ # Forward pass methods
1350
+ def forward(self,
1351
+ batch: Dict = None,
1352
+ questions: List[str] = None,
1353
+ documents: List[List[str]] = None,
1354
+ answers: List[str] = None,
1355
+ original_answer_gen_api: str = None,
1356
+ stage2_mips: bool = False,
1357
+ stage2_retrieval_top_n: int = None) -> Tuple[torch.Tensor, Dict]:
1358
+ """
1359
+ Forward pass with support for both batch and legacy interfaces.
1360
+
1361
+ Args:
1362
+ batch: Preprocessed batch dict (new interface)
1363
+ questions: List of questions (legacy interface)
1364
+ documents: List of document lists (legacy interface)
1365
+ answers: List of answers (legacy interface)
1366
+ original_answer_gen_api: API URL for generation (legacy interface)
1367
+ stage2_mips: Whether to use MIPS for stage2
1368
+ stage2_retrieval_top_n: Top-n for stage2 retrieval
1369
+
1370
+ Returns:
1371
+ Tuple of (loss, additional_outputs_dict)
1372
+ """
1373
+ if batch is not None:
1374
+ return self._forward_batch(batch, stage2_mips, stage2_retrieval_top_n)
1375
+ else:
1376
+ return self._forward_legacy(questions, documents, answers, original_answer_gen_api)
1377
+
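For the batch interface, the keys consumed in stage 1 (see `_forward_stage1_batch` below) are `stage`, `enc_input_ids`, `enc_attention_mask`, `dec_input_ids`, `dec_attention_mask`, and `labels`. The sketch below only shows the expected structure with placeholder shapes; real batches must come from the training collator so that memory tokens and lengths line up:

```python
import torch

B, top_k, enc_len, dec_len = 2, 5, 136, 256  # placeholder sizes

batch = {
    "stage": "stage1",
    "enc_input_ids": torch.zeros(B * top_k, enc_len, dtype=torch.long),
    "enc_attention_mask": torch.ones(B * top_k, enc_len, dtype=torch.long),
    "dec_input_ids": torch.zeros(B, dec_len, dtype=torch.long),
    "dec_attention_mask": torch.ones(B, dec_len, dtype=torch.long),
    "labels": torch.full((B, dec_len), -100, dtype=torch.long),  # -100 masks prompt tokens
}
# loss, extras = model(batch=batch)   # extras carries "logits" and "mse_loss"
```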
1378
+ def _forward_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1379
+ """Handle batch-based forward pass."""
1380
+ stage = batch.get("stage", None)
1381
+
1382
+ if stage in ["stage1", "stage1_2"]:
1383
+ return self._forward_stage1_batch(batch)
1384
+ elif stage == "stage2":
1385
+ return self._forward_stage2_batch(batch, stage2_mips, stage2_retrieval_top_n)
1386
+ elif stage == "stage2_pretrain_retrieval":
1387
+ return self._forward_stage2_pretrain_batch(batch, stage2_mips, stage2_retrieval_top_n)
1388
+ elif stage == "stage2_reasoning":
1389
+ return self._forward_stage2_reasoning_batch(batch)
1390
+ else:
1391
+ raise ValueError(f"Unknown stage: {stage}")
1392
+
1393
+ def _forward_stage1_batch(self, batch: Dict) -> Tuple[torch.Tensor, Dict]:
1394
+ """Forward pass for stage 1 training."""
1395
+ # Move tensors to device
1396
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1397
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1398
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1399
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1400
+ labels = batch["labels"].to(self.decoder.device)
1401
+
1402
+ out = self._forward_stage_1(
1403
+ enc_input_ids=enc_input_ids,
1404
+ enc_attention_mask=enc_attention_mask,
1405
+ dec_input_ids=dec_input_ids,
1406
+ dec_attention_mask=dec_attention_mask,
1407
+ labels=labels,
1408
+ )
1409
+ return out["loss"], {"logits": out["logits"], "mse_loss": out["mse_loss"]}
1410
+
1411
+ def _forward_stage2_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1412
+ """Forward pass for stage 2 training."""
1413
+ self.decoder.set_adapter('query_reasoner_adapter')
1414
+
1415
+ B = batch["labels"].shape[0]
1416
+ query_reps = self._compr_query_reasoner_stage2(
1417
+ batch["query_input_ids"].to(self.decoder.device),
1418
+ batch["query_attention_mask"].to(self.decoder.device)
1419
+ )
1420
+
1421
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1422
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1423
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1424
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1425
+ labels = batch["labels"].to(self.decoder.device)
1426
+
1427
+ # Document retrieval and selection
1428
+ if stage2_mips:
1429
+ retrieved_doc_embeddings = self._retrieve_embeddings(
1430
+ query_reps, stage2_retrieval_top_n=stage2_retrieval_top_n
1431
+ )
1432
+ scores = torch.bmm(
1433
+ query_reps.unsqueeze(1),
1434
+ retrieved_doc_embeddings.transpose(1, 2)
1435
+ ).squeeze(1)
1436
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=1)
1437
+ selected = torch.einsum('bkn,bnd->bkd', z, retrieved_doc_embeddings)
1438
+ selected = selected.view(selected.size(0) * selected.size(1), -1, self.hidden_size)
1439
+ else:
1440
+ with torch.no_grad():
1441
+ retrieved_doc_embeddings, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1442
+
1443
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
1444
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
1445
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
1446
+
1447
+ scores = torch.bmm(
1448
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
1449
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
1450
+ ).squeeze(1)
1451
+
1452
+ z, topk_idx = differentiable_topk(scores, self.generation_top_k, temperature=0.02)
1453
+ selected = torch.einsum('bkn,bnd->bkd', z.to(retrieved_doc_embeddings.dtype), retrieved_doc_embeddings)
1454
+ selected = selected.view(selected.size(0) * selected.size(1), -1, self.hidden_size)
1455
+
1456
+ inputs_embeds = self._replace_emb_stage2(selected, dec_input_ids)
1457
+
1458
+ if 'decoder_adapter' in self.adapter_keys:
1459
+ self.decoder.set_adapter('decoder_adapter')
1460
+
1461
+ dec_out = self.decoder(
1462
+ inputs_embeds=inputs_embeds,
1463
+ attention_mask=dec_attention_mask,
1464
+ labels=labels,
1465
+ )
1466
+
1467
+ self.decoder.set_adapter(['decoder_adapter', 'query_reasoner_adapter'])
1468
+ return dec_out.loss, {"logits": dec_out.logits, "topk_idx": topk_idx, "mse_loss": mse_loss}
1469
+
1470
+ def _forward_stage2_pretrain_batch(self, batch: Dict, stage2_mips: bool, stage2_retrieval_top_n: int) -> Tuple[torch.Tensor, Dict]:
1471
+ """Forward pass for stage 2 pretraining with retrieval."""
1472
+ self.decoder.set_adapter('query_reasoner_adapter')
1473
+
1474
+ B = batch["labels"].shape[0]
1475
+ N = batch["enc_input_ids"].shape[0] // B
1476
+ device = self.decoder.device
1477
+
1478
+ query_reps = self._compr_query_reasoner_stage2(
1479
+ batch["query_input_ids"].to(device),
1480
+ batch["query_attention_mask"].to(device)
1481
+ )
1482
+
1483
+ enc_input_ids = batch["enc_input_ids"].to(device)
1484
+ enc_attention_mask = batch["enc_attention_mask"].to(device)
1485
+
1486
+ with torch.no_grad():
1487
+ retrieved_doc_embeddings, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1488
+
1489
+ stage2_retrieval_top_n = retrieved_doc_embeddings.shape[0] // B
1490
+ retrieved_doc_embeddings = retrieved_doc_embeddings.reshape(B, stage2_retrieval_top_n, -1)
1491
+ query_reps = query_reps.to(retrieved_doc_embeddings.dtype)
1492
+
1493
+ scores = torch.bmm(
1494
+ F.normalize(query_reps, dim=-1, p=2).unsqueeze(1).float(),
1495
+ F.normalize(retrieved_doc_embeddings, dim=-1, p=2).float().transpose(1, 2)
1496
+ ).squeeze(1)
1497
+
1498
+ pos_index = batch["pos_index"]
1499
+ pos_mask = build_pos_mask(pos_index, N, device)
1500
+ tau = 0.02
1501
+ logits = scores / tau
1502
+
1503
+ pos_logits = logits.masked_fill(~pos_mask, float('-inf'))
1504
+ num = torch.logsumexp(pos_logits, dim=-1)
1505
+ den = torch.logsumexp(logits, dim=-1)
1506
+ loss_vec = -(num - den)
1507
+ valid = pos_mask.any(dim=-1)
1508
+ loss = loss_vec[valid].mean()
1509
+
1510
+ topk = self.generation_top_k
1511
+ topk_idx = logits.topk(k=min(topk, N), dim=-1).indices
1512
+
1513
+ return loss, {"logits": [[]], "topk_idx": topk_idx, "mse_loss": mse_loss}
1514
+
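The retrieval pretraining objective above is a multi-positive InfoNCE: with temperature τ, each query contributes loss_i = -log( Σ_{j∈pos(i)} exp(s_ij/τ) / Σ_j exp(s_ij/τ) ), averaged over queries that have at least one positive. A tiny reproduction on toy scores (values are illustrative only):

```python
import torch

scores = torch.tensor([[0.9, 0.1, 0.2], [0.3, 0.8, 0.7]])             # [B=2, N=3] cosine scores
pos_mask = torch.tensor([[True, False, False], [False, True, True]])  # positives per query
tau = 0.02

logits = scores / tau
num = torch.logsumexp(logits.masked_fill(~pos_mask, float("-inf")), dim=-1)
den = torch.logsumexp(logits, dim=-1)
loss = -(num - den)[pos_mask.any(dim=-1)].mean()
print(loss)  # small here, since the positives already score highest
```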
1515
+ def _forward_stage2_reasoning_batch(self, batch: Dict) -> Tuple[torch.Tensor, Dict]:
1516
+ """Forward pass for stage 2 reasoning training."""
1517
+ B = batch["labels"].shape[0]
1518
+ enc_input_ids = batch["enc_input_ids"].to(self.decoder.device)
1519
+ enc_attention_mask = batch["enc_attention_mask"].to(self.decoder.device)
1520
+ dec_input_ids = batch["dec_input_ids"].to(self.decoder.device)
1521
+ dec_attention_mask = batch["dec_attention_mask"].to(self.decoder.device)
1522
+ labels = batch["labels"].to(self.decoder.device)
1523
+
1524
+ if sum(batch["docs_num"]) != 0:
1525
+ with torch.no_grad():
1526
+ selected, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1527
+ indices = batch["docs_num"]
1528
+ inputs_embeds = self._replace_reasoning_embeddings(selected, dec_input_ids, indices)
1529
+ else:
1530
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
1531
+ mse_loss = 0
1532
+
1533
+ if 'decoder_adapter' in self.adapter_keys:
1534
+ self.decoder.set_adapter('decoder_adapter')
1535
+
1536
+ dec_out = self.decoder(
1537
+ inputs_embeds=inputs_embeds,
1538
+ attention_mask=dec_attention_mask,
1539
+ labels=labels,
1540
+ )
1541
+
1542
+ self.decoder.set_adapter(['decoder_adapter'])
1543
+ return dec_out.loss, {"logits": dec_out.logits, "mse_loss": mse_loss}
1544
+
1545
+ def _forward_stage_1(self,
1546
+ enc_input_ids: torch.LongTensor = None,
1547
+ enc_attention_mask: torch.LongTensor = None,
1548
+ dec_input_ids: torch.LongTensor = None,
1549
+ dec_attention_mask: torch.LongTensor = None,
1550
+ labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
1551
+ """Stage 1 forward pass for document compression and QA."""
1552
+ assert enc_input_ids.size() == enc_attention_mask.size()
1553
+
1554
+ # Flatten 3D inputs to 2D if needed
1555
+ if len(enc_input_ids.size()) == 3:
1556
+ batch_size, top_k, seq_length = enc_input_ids.size()
1557
+ enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
1558
+ enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
1559
+
1560
+ assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k
1561
+
1562
+ # Compress documents
1563
+ compressed_embs, mse_loss = self.compress(enc_input_ids, enc_attention_mask)
1564
+
1565
+ # Replace memory tokens with compressed embeddings
1566
+ inputs_embeds = self._replace_emb(compressed_embs, dec_input_ids)
1567
+
1568
+ # Detach if compressor-only training
1569
+ if (self.training_form == "compressor") and (self.compr is None):
1570
+ inputs_embeds = inputs_embeds.detach()
1571
+
1572
+ # Set decoder adapter
1573
+ if 'decoder_adapter' in self.adapter_keys:
1574
+ self.decoder.set_adapter('decoder_adapter')
1575
+
1576
+ # Forward through decoder
1577
+ decoder_outputs = self.decoder(
1578
+ inputs_embeds=inputs_embeds,
1579
+ attention_mask=dec_attention_mask,
1580
+ labels=labels
1581
+ )
1582
+
1583
+ # Reactivate all adapters
1584
+ self.decoder.set_adapter(['decoder_adapter', 'encoder_adapter'])
1585
+
1586
+ return {
1587
+ "loss": decoder_outputs.loss,
1588
+ "logits": decoder_outputs.logits,
1589
+ "mse_loss": mse_loss
1590
+ }
1591
+
1592
+ def _replace_reasoning_embeddings(self,
1593
+ compressed_embs: torch.Tensor,
1594
+ dec_input_ids: torch.LongTensor,
1595
+ docs_per_example: List[int]) -> torch.Tensor:
1596
+ """Replace memory slots with compressed embeddings for reasoning."""
1597
+ device = dec_input_ids.device
1598
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
1599
+
1600
+ num_embs = compressed_embs.size(1)
1601
+ slot_len = num_embs + (1 if getattr(self, "sep", False) else 0)
1602
+
1603
+ if not isinstance(docs_per_example, torch.Tensor):
1604
+ docs_per_example = torch.tensor(docs_per_example, device=device, dtype=torch.long)
1605
+ else:
1606
+ docs_per_example = docs_per_example.to(device=device, dtype=torch.long)
1607
+
1608
+ offsets = torch.zeros(docs_per_example.size(0) + 1, device=device, dtype=torch.long)
1609
+ offsets[1:] = torch.cumsum(docs_per_example, dim=0)
1610
+ total_docs = int(offsets[-1].item())
1611
+ assert total_docs == compressed_embs.size(0)
1612
+
1613
+ mem_id = self.decoder_tokenizer.mem_token_ids[0]
1614
+ B, L, H = inputs_embeds.size()
1615
+
1616
+ for i in range(B):
1617
+ # Find first memory token position
1618
+ mem_pos = (dec_input_ids[i] == mem_id).nonzero(as_tuple=True)[0]
1619
+ if mem_pos.numel() == 0:
1620
+ continue
1621
+ first_mem_idx = int(mem_pos[0].item())
1622
+
1623
+ n_docs_i = int(docs_per_example[i].item())
1624
+ base = int(offsets[i].item())
1625
+
1626
+ needed_len = first_mem_idx + n_docs_i * slot_len
1627
+ assert needed_len <= L
1628
+
1629
+ for local_j in range(n_docs_i):
1630
+ global_j = base + local_j
1631
+ start_idx = first_mem_idx + local_j * slot_len
1632
+ target_slice = inputs_embeds[i, start_idx:start_idx + num_embs, :]
1633
+ src = compressed_embs[global_j]
1634
+ assert target_slice.size() == src.size()
1635
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = src
1636
+
1637
+ return inputs_embeds
1638
+
1639
+ def _generate(self, model_input: Dict[str, torch.Tensor], max_new_tokens: int = 128,
1640
+ return_doc_embeddings: bool = False) -> List[str]:
1641
+ """Generate text from model inputs."""
1642
+ enc_input_ids = model_input['enc_input_ids']
1643
+ enc_attention_mask = model_input['enc_attention_mask']
1644
+ dec_input_ids = model_input['dec_input_ids']
1645
+ dec_attention_mask = model_input['dec_attention_mask']
1646
+
1647
+ assert enc_input_ids.size() == enc_attention_mask.size()
1648
+
1649
+ if len(enc_input_ids.size()) == 3:
1650
+ batch_size, top_k, seq_length = enc_input_ids.size()
1651
+ enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
1652
+ enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
1653
+
1654
+ assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k
1655
+
1656
+ compressed_embs, _ = self.compress(enc_input_ids.to('cuda'), enc_attention_mask.to('cuda'))
1657
+ inputs_embeds = self._replace_emb(compressed_embs, dec_input_ids.to('cuda'))
1658
+
1659
+ if 'decoder_adapter' in self.adapter_keys:
1660
+ self.decoder.set_adapter('decoder_adapter')
1661
+
1662
+ output_ids = self.decoder.generate(
1663
+ inputs_embeds=inputs_embeds.to("cuda"),
1664
+ attention_mask=dec_attention_mask.to("cuda"),
1665
+ do_sample=False,
1666
+ top_p=None,
1667
+ max_new_tokens=max_new_tokens
1668
+ )
1669
+
1670
+ decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
1671
+
1672
+ if return_doc_embeddings:
1673
+ assert 'batch_size' in locals() and 'top_k' in locals()
1674
+ compressed_embs = compressed_embs.view(batch_size, top_k, compressed_embs.size(1), compressed_embs.size(2))
1675
+ return decoded, compressed_embs
1676
+ else:
1677
+ return decoded
1678
+
1679
+
1680
+ # Example usage and testing
1681
+ if __name__ == '__main__':
1682
+ # Example configuration
1683
+ cfg = CLaRaConfig(
1684
+ decoder_model_name='/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2',
1685
+ compr_model_name="mistral_trimmed",
1686
+ compr_rate=64,
1687
+ compr_n_layers=5,
1688
+ compr_mlp_hidden_dim=8096,
1689
+ compr_use_mlp=False,
1690
+ lora=True,
1691
+ lora_compressor=True,
1692
+ training_form="both",
1693
+ load_adapters=True,
1694
+ kbtc_training=False,
1695
+ optimize_mem_tokens=True,
1696
+ different_mem_tokens=True,
1697
+ attn_implementation='flash_attention_2'
1698
+ )
1699
+
1700
+ # Initialize model
1701
+ clara = CLaRa(cfg)
1702
+
1703
+ # Save and reload test
1704
+ clara.save_pretrained('test_ckpt')
1705
+
1706
+ del clara
1707
+ torch.cuda.empty_cache()
1708
+ gc.collect()
1709
+
1710
+ # Reload model
1711
+ clara = CLaRa.from_pretrained('test_ckpt')
1712
+ print("Model successfully loaded!")
compression-16/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
compression-16/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
compression-16/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
compression-16/tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<s>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "extra_special_tokens": {},
36
+ "legacy": false,
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": "</s>",
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "LlamaTokenizer",
42
+ "unk_token": "<unk>",
43
+ "use_default_system_prompt": false
44
+ }
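The tokenizer files added under each compression subfolder are a standard Mistral/Llama SentencePiece tokenizer with `pad_token` set to `</s>`; the memory, separator, and encoder tokens referenced by the modeling code appear to be added on top at load time. Loading the tokenizer on its own would look roughly like this (the repo id is a placeholder):

```python
from transformers import AutoTokenizer

# Placeholder repo id / local path; pick the subfolder matching the checkpoint variant.
tok = AutoTokenizer.from_pretrained("path/or/hub-id/CLaRa-7B-E2E", subfolder="compression-16")
print(tok.bos_token, tok.eos_token, tok.pad_token)  # "<s>", "</s>", "</s>" per tokenizer_config.json
```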