"""Load the causal LM named by ``MODEL_NAME`` and print its parameter counts.

The model identifier is read from the ``MODEL_NAME`` environment variable.
The script prints the identifier followed by two parameter counts: one
summed manually over ``model.parameters()`` and one reported by
``model.num_parameters()``.
"""
import os

from transformers import AutoModelForCausalLM

# os.getenv returns None when the variable is unset; fail early with a clear
# message instead of letting from_pretrained fail on a None/empty identifier.
model_name = os.getenv('MODEL_NAME')
if not model_name:
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to the model to load."
    )

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): per the transformers docs, "auto" delegates weight
    # placement across available devices to Accelerate - confirm that
    # Accelerate is installed where this script runs.
    device_map="auto",
    torch_dtype="bfloat16",
)

# Print both counts side by side so the manual sum can be compared with the
# library-reported figure.
print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())