itarutomy committed
Commit 943506c · verified · 1 Parent(s): 8e20b57

Add custom modeling file

Files changed (1)
  1. modeling_gptscratch.py +38 -4
modeling_gptscratch.py CHANGED
@@ -1,18 +1,52 @@
-from transformers import PreTrainedModel, GPT2Config
+# modeling_gptscratch.py
+import torch
 import torch.nn as nn
-from .gpt_model import GPTModel  # ✅ adjust to match the actual location
+from transformers import PreTrainedModel, GPT2Config
+from transformers.modeling_outputs import CausalLMOutput
+from .gpt_model import GPTModel  # ← assumes gpt_model.py is bundled in this repo
 
 class GPTScratchForCausalLM(PreTrainedModel):
     config_class = GPT2Config
+
     def __init__(self, config):
         super().__init__(config)
+        # Build the inner model to match the cfg used at training time
         self.inner = GPTModel({
             "vocab_size": config.vocab_size,
             "emb_dim": config.n_embd,
             "n_heads": config.n_head,
             "n_layers": config.n_layer,
             "context_length": config.n_positions,
-            "drop_rate": 0.1
+            "drop_rate": 0.1,
         })
+
+        # Expose an lm_head for compatibility (it shares weights with inner.out_head)
+        self.lm_head = self.inner.out_head
+
     def forward(self, input_ids, **kwargs):
-        return self.inner(input_ids)
+        logits = self.inner(input_ids)
+        # Return a CausalLMOutput, following the HF convention
+        return CausalLMOutput(logits=logits)
+
+    # --- The key part: a minimal generate() (greedy decoding) ---
+    @torch.no_grad()
+    def generate(self, input_ids, max_new_tokens=32, **gen_kwargs):
+        # Minimal greedy version; pad/attention_mask handling is omitted
+        for _ in range(max_new_tokens):
+            out = self.forward(input_ids)
+            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
+            input_ids = torch.cat([input_ids, next_token], dim=1)
+        return input_ids
+
+    # (Optional) absorb key-name drift from older checkpoints
+    @classmethod
+    def _load_state_dict_into_model(cls, model, state_dict, *args, **kwargs):
+        # Rename inner.inner.* → inner.*
+        remap = {}
+        for k, v in list(state_dict.items()):
+            if k.startswith("inner.inner."):
+                remap[k.replace("inner.inner.", "inner.", 1)] = v
+                del state_dict[k]
+        state_dict.update(remap)
+        return super()._load_state_dict_into_model(model, state_dict, *args, **kwargs)
+
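For readers without the bundled gpt_model.py, the contract this file relies on can be inferred from the diff: GPTModel accepts a plain cfg dict with the six keys shown above, exposes an out_head projection, and its forward maps token ids [batch, seq] to logits [batch, seq, vocab_size]. The sketch below is a hypothetical placeholder that satisfies that contract, not the actual bundled implementation.

# Hypothetical stub of the interface modeling_gptscratch.py assumes from gpt_model.py.
# The trivial internals are placeholders; the real model would build
# cfg["n_layers"] transformer blocks with cfg["n_heads"] attention heads
# and positional embeddings up to cfg["context_length"].
import torch
import torch.nn as nn

class GPTModel(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.drop = nn.Dropout(cfg["drop_rate"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.drop(self.tok_emb(input_ids))
        return self.out_head(x)  # [batch, seq, vocab_size]

And a minimal loading sketch. It assumes the repo's config.json registers this class under auto_map (e.g. "AutoModelForCausalLM": "modeling_gptscratch.GPTScratchForCausalLM") and that a tokenizer is bundled; the repo id below is hypothetical.

# Minimal usage sketch for a custom modeling file loaded via trust_remote_code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "itarutomy/gpt-scratch"  # hypothetical; replace with the actual repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
# Dispatches to the custom greedy generate() defined in modeling_gptscratch.py
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0]))

Note that because the custom generate() is a plain greedy loop, sampling arguments such as temperature or top_k land in gen_kwargs and are silently ignored, and each step re-runs the full prefix (no KV cache), which keeps the implementation minimal at the cost of quadratic decoding time.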