Spaces:
Running
Running
Upload 5 files
Browse files- main.sh +24 -0
- model_merger.py +187 -0
- questioner_train.sh +35 -0
- questioner_train_penalty.sh +46 -0
- solver_train.sh +39 -0
main.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Entry point for the iterative questioner/solver training loop.
#
# Usage: bash main.sh <base_model> <model_abbr>
#   base_model  model passed to the first questioner and solver runs and to
#               the final evaluation
#   model_abbr  short tag used to build every experiment/checkpoint name
#
# Requires STORAGE_PATH in the environment: merged checkpoints are read from
# ${STORAGE_PATH}/models/<experiment>/global_step_<N>/actor/huggingface.
Base_model=${1:?usage: main.sh <base_model> <model_abbr>}
Model_abbr=${2:?usage: main.sh <base_model> <model_abbr>}
# Fail early: an unset STORAGE_PATH would silently produce paths like /models/...
: "${STORAGE_PATH:?STORAGE_PATH must be set}"
echo "Model_abbr: $Model_abbr"

# Initialize first iteration with base model
bash scripts/questioner_train_penalty.sh "$Base_model" "$Base_model" "${Model_abbr}_questioner_v1"
bash scripts/solver_train.sh "$Base_model" "${STORAGE_PATH}/models/${Model_abbr}_questioner_v1/global_step_5/actor/huggingface" "${Model_abbr}_solver_v1"

# Iterations 2..5: each round starts from the previous round's checkpoints.
for i in {2..5}; do
    prev=$((i-1))

    # Train questioner (solver v{prev} is passed as the solver model)
    bash scripts/questioner_train_penalty.sh \
        "${STORAGE_PATH}/models/${Model_abbr}_solver_v${prev}/global_step_15/actor/huggingface" \
        "${STORAGE_PATH}/models/${Model_abbr}_questioner_v${prev}/global_step_5/actor/huggingface" \
        "${Model_abbr}_questioner_v${i}"

    # Train solver
    bash scripts/solver_train.sh \
        "${STORAGE_PATH}/models/${Model_abbr}_solver_v${prev}/global_step_15/actor/huggingface" \
        "${STORAGE_PATH}/models/${Model_abbr}_questioner_v${i}/global_step_5/actor/huggingface" \
        "${Model_abbr}_solver_v${i}"
done

# NOTE(review): solver_train.sh invokes evaluation/eval_math.bash while this
# script invokes evaluation/eval_math.sh -- confirm which file actually exists.
bash evaluation/eval_math.sh "$Base_model"
|
model_merger.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
+
from typing import Dict, List, Tuple
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch.distributed._tensor import DTensor, Placement, Shard
|
| 24 |
+
from transformers import (
|
| 25 |
+
AutoConfig,
|
| 26 |
+
AutoModelForCausalLM,
|
| 27 |
+
AutoModelForTokenClassification,
|
| 28 |
+
AutoModelForVision2Seq,
|
| 29 |
+
PretrainedConfig,
|
| 30 |
+
PreTrainedModel,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
    """Combine per-shard local tensors into one tensor according to ``placement``.

    Args:
        tensors: Local tensors of one parameter, one entry per shard. For a
            sharded placement they are concatenated in the order given.
        placement: DTensor placement describing how the parameter is laid out.

    Returns:
        ``tensors[0]`` for a replicated placement, or the contiguous
        concatenation along ``placement.dim`` for a sharded placement.

    Raises:
        ValueError: If ``tensors`` is empty or the placement kind is unknown.
        NotImplementedError: If the placement is partial.
    """
    # Fail with a clear message instead of an IndexError / torch.cat error.
    if not tensors:
        raise ValueError("Cannot merge an empty list of tensors")

    if placement.is_replicate():
        # Every shard holds the full tensor, so any one of them is the result.
        return tensors[0]
    if placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")
    if placement.is_shard():
        return torch.cat(tensors, dim=placement.dim).contiguous()
    raise ValueError(f"Unsupported placement: {placement}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def upload_model_to_huggingface(local_path: str, remote_path: str):
    """Upload the folder ``local_path`` to the Hugging Face model repo ``remote_path``.

    The repo is created as a public repo first; an already existing repo is
    accepted (``exist_ok=True``).
    """
    # Imported here so huggingface_hub is only required when an upload is requested.
    from huggingface_hub import HfApi

    hub_client = HfApi()
    hub_client.create_repo(repo_id=remote_path, private=False, exist_ok=True)
    hub_client.upload_folder(
        repo_id=remote_path,
        folder_path=local_path,
        repo_type="model",
    )
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
    # Merge the per-rank checkpoint files ``model_world_size_<W>_rank_<R>.pt``
    # found in --local_dir into a single state dict and save it, in bfloat16,
    # to ``<local_dir>/huggingface`` (which must already contain the model
    # config). Optionally upload the result to the Hugging Face Hub.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=None, type=str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()
    local_dir: str = args.local_dir

    assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."

    # Find the rank-0 file; its name encodes the world size.
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            # Kept as a string: it is only used in file names and via int().
            world_size = match.group(1)
            break

    assert world_size, "No model file with the proper format."

    rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    if isinstance(weight, DTensor):
        # get sharding info
        device_mesh = weight.device_mesh
        mesh = device_mesh.mesh
        mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # for non-DTensor
        # NOTE(review): this array has shape (1,), so total_shards below is 1
        # and only the rank-0 file is merged -- confirm this is intended.
        mesh = np.array([int(world_size)], dtype=np.int64)
        mesh_dim_names = ("fsdp",)

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

    assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."

    if "tp" in mesh_dim_names:
        # fsdp * tp (currently unreachable: the assertion above rejects "tp")
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing {total_shards} model shards in total.")
    # Slot 0 is the already loaded rank-0 state dict; the rest are filled by the workers.
    model_state_dict_lst = [state_dict]
    model_state_dict_lst.extend([None] * (total_shards - 1))

    def process_one_shard(rank, model_state_dict_lst):
        """Load the checkpoint file of ``rank`` and store it at index ``rank``."""
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    # os.cpu_count() may return None, so fall back to a single worker.
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 1)) as executor:
        futures = [
            executor.submit(process_one_shard, shard_rank, model_state_dict_lst)
            for shard_rank in range(1, total_shards)
        ]
        # Retrieve every result so that a failed load raises here instead of
        # silently leaving an empty placeholder in the list.
        for future in futures:
            future.result()

    state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for shard_rank, model_state_dict in enumerate(model_state_dict_lst):
            try:
                tensor = model_state_dict.pop(key)
            except KeyError as err:
                # Continuing would reuse the previous shard's tensor and
                # produce a corrupted merge, so stop instead.
                raise KeyError(f"Cannot find key {key} in rank {shard_rank}.") from err

            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at ddp dimension can be discarded
                if mesh_dim_names[0] == "ddp":
                    placements = placements[1:]

                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    # All shards must agree on how this parameter is placed.
                    assert param_placements[key] == placements
            else:
                state_dict[key].append(tensor.bfloat16())

    del model_state_dict_lst

    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue

        if key in param_placements:
            # merge shards
            placements: Tuple[Shard] = param_placements[key]
            if len(mesh_shape) == 1:
                # 1-D list, FSDP without TP
                assert len(placements) == 1
                shards = state_dict[key]
                state_dict[key] = merge_by_placement(shards, placements[0])
            else:
                # 2-D list, FSDP + TP
                raise NotImplementedError("FSDP + TP is not supported yet.")
        else:
            # Plain tensors carry no placement info; they are concatenated on dim 0.
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    print("Merge completed.")
    hf_path = os.path.join(local_dir, "huggingface")
    config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    # The attribute may exist but be None, so also fall back in that case.
    architectures: List[str] = getattr(config, "architectures", None) or ["Unknown"]

    if "ForTokenClassification" in architectures[0]:
        AutoClass = AutoModelForTokenClassification
    elif "ForCausalLM" in architectures[0]:
        AutoClass = AutoModelForCausalLM
    elif "ForConditionalGeneration" in architectures[0]:
        AutoClass = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {architectures}.")

    # Build the model skeleton on the meta device; the real weights are
    # supplied through ``state_dict`` when saving.
    with torch.device("meta"):
        model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)

    assert isinstance(model, PreTrainedModel)
    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}...")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict, model

    if args.hf_upload_path:
        upload_model_to_huggingface(hf_path, args.hf_upload_path)
|
questioner_train.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train the questioner model, then merge its step-10 checkpoint.
#
# Usage: bash questioner_train.sh <solver_model_path> <questioner_model_path> <save_path>
#   solver_model_path      model path handed to vllm_service_init/start.sh
#   questioner_model_path  model the actor is initialised from
#   save_path              experiment name; checkpoints are written to
#                          ${STORAGE_PATH}/models/<save_path>
#
# Requires STORAGE_PATH in the environment.
solver_model_path=${1:?usage: questioner_train.sh <solver_model_path> <questioner_model_path> <save_path>}
questioner_model_path=${2:?usage: questioner_train.sh <solver_model_path> <questioner_model_path> <save_path>}
save_path=${3:?usage: questioner_train.sh <solver_model_path> <questioner_model_path> <save_path>}
# Fail early: an unset STORAGE_PATH would silently produce paths like /models/...
: "${STORAGE_PATH:?STORAGE_PATH must be set}"

echo "$STORAGE_PATH"

echo "start train questioner $questioner_model_path $save_path"

# Launched in the background so the trainer below starts without waiting for it.
bash vllm_service_init/start.sh "$solver_model_path" &

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_response_length=4096 \
    worker.actor.model.model_path="$questioner_model_path" \
    trainer.experiment_name="$save_path" \
    trainer.save_checkpoint_path="${STORAGE_PATH}/models/$save_path" \
    trainer.total_epochs=1000 \
    worker.reward.reward_function=./examples/reward_function/caller.py:compute_score \
    trainer.val_freq=-1 \
    trainer.n_gpus_per_node=4 \
    data.format_prompt=./examples/format_prompt/questioner.jinja \
    worker.rollout.n=16 \
    worker.actor.global_batch_size=4 \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=1 \
    trainer.max_steps=11

# pkill matches by process name, so this stops every matching python process
# of this user -- presumably to shut down the services started above.
pkill python

sleep 1

# NOTE(review): assumes a checkpoint was saved at step 10; the save frequency
# is not set here, so it comes from examples/config.yaml -- confirm.
python scripts/model_merger.py --local_dir "${STORAGE_PATH}/models/$save_path/global_step_10/actor"
|
questioner_train_penalty.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train the questioner model with the penalty reward function
# (caller_penalty.py), then merge its step-5 checkpoint.
#
# Usage: bash questioner_train_penalty.sh <solver_model_path> <questioner_model_path> <save_path>
#   solver_model_path      model path handed to vllm_service_init/start.sh
#   questioner_model_path  model the actor is initialised from
#   save_path              experiment name; checkpoints are written to
#                          ${STORAGE_PATH}/models/<save_path>
#
# Requires STORAGE_PATH in the environment.

solver_model_path=${1:?usage: questioner_train_penalty.sh <solver_model_path> <questioner_model_path> <save_path>}
questioner_model_path=${2:?usage: questioner_train_penalty.sh <solver_model_path> <questioner_model_path> <save_path>}
save_path=${3:?usage: questioner_train_penalty.sh <solver_model_path> <questioner_model_path> <save_path>}
# Fail early: an unset STORAGE_PATH would silently produce paths like /models/...
: "${STORAGE_PATH:?STORAGE_PATH must be set}"
echo "save_path: $save_path"

# Generate a unique RUN_ID (epoch seconds followed by nanoseconds) and export
# it so child processes can read it.
RUN_ID=$(date +%s%N)
export RUN_ID

echo "RUN_ID=$RUN_ID"

# Start the vLLM services (record PID)
# NOTE(review): start.sh runs in the foreground here, whereas
# questioner_train.sh launches it with '&' -- presumably start.sh backgrounds
# its own services when given a RUN_ID; confirm.
bash vllm_service_init/start.sh "$solver_model_path" "$RUN_ID"
echo "vLLM services started with RUN_ID=$RUN_ID"

# Start training the questioner
echo "Start training questioner: $questioner_model_path -> $save_path"

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_response_length=4096 \
    worker.actor.model.model_path="$questioner_model_path" \
    trainer.experiment_name="$save_path" \
    trainer.save_checkpoint_path="${STORAGE_PATH}/models/$save_path" \
    trainer.total_epochs=1000 \
    worker.reward.reward_function=./examples/reward_function/caller_penalty.py:compute_score \
    trainer.val_freq=-1 \
    trainer.n_gpus_per_node=4 \
    data.format_prompt=./examples/format_prompt/questioner.jinja \
    worker.rollout.n=4 \
    worker.actor.global_batch_size=16 \
    trainer.max_steps=6 \
    trainer.save_freq=1

sleep 5

# Merge the model (save_freq=1 above means a step-5 checkpoint is written)
echo "merging model"
python scripts/model_merger.py --local_dir "${STORAGE_PATH}/models/$save_path/global_step_5/actor"

sleep 10

# pkill matches by process name, so this stops every matching python process
# of this user -- presumably to shut down the vLLM services started above.
pkill python

echo "questioner training finished"
|
solver_train.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generate questions with the questioner, score them with the solver, upload
# the filtered set, train the solver on it, then merge and evaluate the
# step-15 checkpoint.
#
# Usage: bash solver_train.sh <solver_model_path> <questioner_model_path> <experiment_name>
#   solver_model_path      model the solver actor is initialised from
#   questioner_model_path  model used to generate questions
#   experiment_name        run name; also used as dataset repo name and
#                          checkpoint directory name
#
# Requires STORAGE_PATH and HUGGINGFACENAME in the environment.
solver_model_path=${1:?usage: solver_train.sh <solver_model_path> <questioner_model_path> <experiment_name>}
questioner_model_path=${2:?usage: solver_train.sh <solver_model_path> <questioner_model_path> <experiment_name>}
experiment_name=${3:?usage: solver_train.sh <solver_model_path> <questioner_model_path> <experiment_name>}
# Fail early: unset values would silently produce paths like /models/... or
# a dataset id like /<experiment_name>@train.
: "${STORAGE_PATH:?STORAGE_PATH must be set}"
: "${HUGGINGFACENAME:?HUGGINGFACENAME must be set}"

echo "$STORAGE_PATH"

echo "start train solver $experiment_name $solver_model_path $questioner_model_path"

export VLLM_DISABLE_COMPILE_CACHE=1
echo 'start generate question'
bash question_generate/question_generate.bash "$questioner_model_path" 1000 "$experiment_name"
echo 'start evaluate generated question'
bash question_evaluate/evaluate.sh "$solver_model_path" "$experiment_name"
echo 'start upload'
python question_evaluate/upload.py --repo_name "${experiment_name}" --max_score 0.8 --min_score 0.3 --experiment_name "${experiment_name}"
echo 'start train'

# The last argument intentionally has no trailing backslash: the original one
# only terminated correctly because a blank line happened to follow it.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_response_length=4096 \
    worker.actor.model.model_path="$solver_model_path" \
    trainer.experiment_name="${experiment_name}" \
    trainer.save_checkpoint_path="${STORAGE_PATH}/models/${experiment_name}/" \
    data.train_files="${HUGGINGFACENAME}/${experiment_name}@train" \
    trainer.total_epochs=100 \
    trainer.max_steps=20 \
    data.format_prompt=./examples/format_prompt/solver.jinja \
    trainer.val_freq=4 \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=1

# NOTE(review): assumes a checkpoint was saved at step 15; the save frequency
# is not set here, so it comes from examples/config.yaml -- confirm.
echo "merging model"
python scripts/model_merger.py --local_dir "${STORAGE_PATH}/models/${experiment_name}/global_step_15/actor"

sleep 10

echo "solver training finished"

# NOTE(review): main.sh invokes evaluation/eval_math.sh while this script
# invokes evaluation/eval_math.bash -- confirm which file actually exists.
bash evaluation/eval_math.bash "${STORAGE_PATH}/models/${experiment_name}/global_step_15/actor/huggingface"
|