rishabhsetiya committed (verified)
Commit ffcb97c · Parent: 3cced9f

Update fine_tuning.py
Files changed (1)
  1. fine_tuning.py +1 -7
fine_tuning.py CHANGED
@@ -10,7 +10,6 @@ import transformers
 from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
 from peft import LoraConfig, get_peft_model
 from sentence_transformers import SentenceTransformer, util
-import spaces
 
 # -----------------------------
 # ENVIRONMENT / CACHE
@@ -51,7 +50,6 @@ class LoraLinear(nn.Module):
         else:
             self.lora_A, self.lora_B, self.lora_dropout = None, None, None
 
-    @spaces.GPU
     def forward(self, x):
         result = F.linear(x, self.weight, self.bias)
         if self.r > 0:
@@ -71,7+69,6 @@ class MoELoRALinear(nn.Module):
         ])
         self.gate = nn.Linear(base_linear.in_features, num_experts)
 
-    @spaces.GPU
     def forward(self, x):
         base_out = self.base_linear(x)
         gate_scores = torch.softmax(self.gate(x), dim=-1)
@@ -80,7 +77,6 @@ class MoELoRALinear(nn.Module):
             expert_out += gate_scores[..., i:i+1] * expert(x)
         return base_out + expert_out
 
-@spaces.GPU
 def replace_proj_with_moe_lora(model, r=8, num_experts=2, k=1, lora_alpha=16, lora_dropout=0.05):
     for layer in model.model.layers:
         for proj_name in ["up_proj", "down_proj"]:
@@ -113,7 +109,6 @@ def preprocess(example):
 # -----------------------------
 # LOAD & TRAIN MODEL
 # -----------------------------
-@spaces.GPU
 def load_and_train(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
     global model
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -144,7 +139,7 @@ def load_and_train(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     training_args = TrainingArguments(
-        learning_rate=5e-5,
+        learning_rate=1e-4,
         lr_scheduler_type="constant",
         output_dir="./results",
         num_train_epochs=4,
@@ -206,7 +201,6 @@ def validate_query(query: str, threshold: float = 0.5) -> bool:
 # -----------------------------
 # GENERATE ANSWER
 # -----------------------------
-@spaces.GPU
 def generate_answer(prompt, max_tokens=200):
     if prompt.strip() == "":
         return "Please enter a prompt!"
 
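For context on the second hunk: `LoraLinear` adds a trainable low-rank update on top of a frozen linear projection. A minimal self-contained sketch consistent with the lines visible in the diff; the initialization details and the `scaling` attribute are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoraLinearSketch(nn.Module):
    # Illustrative reconstruction; only the lines shown in the hunk come
    # from the commit itself. Freezing the base weights is assumed to
    # happen elsewhere.
    def __init__(self, base: nn.Linear, r=8, lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.weight, self.bias, self.r = base.weight, base.bias, r
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: starts as a no-op
            self.lora_dropout = nn.Dropout(lora_dropout)
            self.scaling = lora_alpha / r  # assumed LoRA scaling convention
        else:
            self.lora_A, self.lora_B, self.lora_dropout = None, None, None

    def forward(self, x):
        result = F.linear(x, self.weight, self.bias)  # frozen base projection
        if self.r > 0:
            # Low-rank residual: (x A^T) B^T, scaled by alpha / r.
            result = result + self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
        return result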
 
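The `MoELoRALinear` hunks show a dense mixture-of-experts variant: a softmax gate produces one weight per expert per token, every expert's output is computed, and the weighted sum is added to the frozen base projection. Despite the `k=1` parameter of `replace_proj_with_moe_lora`, the visible forward pass routes softly over all experts rather than picking a top-k subset. Continuing the sketch above; the expert module never appears in the diff, so `LoRAExpertSketch` is an assumption, written as a pure low-rank delta so that `base_out + expert_out` does not double-count the base layer:

class LoRAExpertSketch(nn.Module):
    # Hypothetical expert: a pure low-rank delta with no base projection.
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.A = nn.Linear(in_features, r, bias=False)
        self.B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.B.weight)           # each expert starts as a no-op
        self.dropout = nn.Dropout(lora_dropout)
        self.scaling = lora_alpha / r

    def forward(self, x):
        return self.B(self.A(self.dropout(x))) * self.scaling

class MoELoRALinearSketch(nn.Module):
    def __init__(self, base_linear: nn.Linear, r=8, num_experts=2,
                 lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.base_linear = base_linear
        self.experts = nn.ModuleList([
            LoRAExpertSketch(base_linear.in_features, base_linear.out_features,
                             r, lora_alpha, lora_dropout)
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(base_linear.in_features, num_experts)

    def forward(self, x):
        base_out = self.base_linear(x)
        gate_scores = torch.softmax(self.gate(x), dim=-1)  # (..., num_experts)
        expert_out = torch.zeros_like(base_out)
        for i, expert in enumerate(self.experts):
            # Weight each expert's low-rank delta by its gate probability.
            expert_out = expert_out + gate_scores[..., i:i+1] * expert(x)
        return base_out + expert_out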
 
 
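`replace_proj_with_moe_lora` (now undecorated) rewires the model in place, wrapping each MLP projection in the MoE layer. Only the two loop headers are visible in the hunk; locating the projections under `layer.mlp` matches the Llama architecture TinyLlama uses, but the body here is a sketch:

def replace_proj_with_moe_lora_sketch(model, r=8, num_experts=2,
                                      lora_alpha=16, lora_dropout=0.05):
    for layer in model.model.layers:
        for proj_name in ["up_proj", "down_proj"]:
            base = getattr(layer.mlp, proj_name)  # assumed attribute path
            setattr(layer.mlp, proj_name,
                    MoELoRALinearSketch(base, r=r, num_experts=num_experts,
                                        lora_alpha=lora_alpha,
                                        lora_dropout=lora_dropout))
    return model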
 
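The one functional change to training is the learning rate, doubled from 5e-5 to 1e-4. Since `lr_scheduler_type="constant"` applies neither warmup nor decay, the new rate holds unchanged through all four epochs; a quick illustration of that schedule (not part of the commit):

import torch
from torch.optim import AdamW
from transformers import get_constant_schedule

# With a constant schedule, the LR never moves off its initial value.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=1e-4)
scheduler = get_constant_schedule(optimizer)
for _ in range(3):
    optimizer.step()
    scheduler.step()
assert optimizer.param_groups[0]["lr"] == 1e-4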
 