| | --- |
| | library_name: transformers |
| | tags: [] |
| | --- |
| | |
| | # Model Card for ProtST (Binary Localization) |
| |
|
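| | ProtST model for the binary localization task: sequence classification of protein sequences on the `Jiqing/ProtST-BinaryLocalization` dataset. |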
| |
|
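| | The following script shows how to fine-tune ProtST on Gaudi (HPU) with `optimum-habana`. It is written as a diff against the plain `transformers` CPU version: lines starting with `-` are the original code, and lines starting with `+` are the Gaudi replacements or additions. |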
| |
|
| | ## Running script |
| | ```diff |
| | from transformers import AutoModel, AutoTokenizer, HfArgumentParser, TrainingArguments, Trainer |
| | from transformers.data.data_collator import DataCollatorWithPadding |
| | from transformers.trainer_pt_utils import get_parameter_names |
| | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
| | from datasets import load_dataset |
| | import functools |
| | import numpy as np |
| | from sklearn.metrics import accuracy_score, matthews_corrcoef |
| | import sys |
| | import torch |
| | import logging |
| | import datasets |
| | import transformers |
| | + import habana_frameworks.torch |
| | + from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments |
| | |
| | |
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| | |
| | # Build the optimizer used for fine-tuning. |
| | # Every parameter whose name does not contain "classifier" is frozen; the remaining |
| | # parameters are split into four groups (head/backbone x weight-decay/no-weight-decay). |
| | # Head groups use training_args.learning_rate; backbone groups use |
| | # training_args.learning_rate * lr_ratio. |
| | # NOTE: training_args is not a parameter. It is read from module scope, so it must be |
| | # defined before this function is called (the __main__ section of this script does so). |
| | def create_optimizer(opt_model, lr_ratio=0.1): |
| | head_names = [] |
| | # collect the classifier-head parameter names and freeze everything else |
| | for n, p in opt_model.named_parameters(): |
| | if "classifier" in n: |
| | head_names.append(n) |
| | else: |
| | p.requires_grad = False |
| | # sanity check: every collected head parameter must still be trainable |
| | for n, p in opt_model.named_parameters(): |
| | if n in head_names: |
| | assert p.requires_grad |
| | backbone_names = [] |
| | # NOTE(review): all non-head parameters were frozen above, so as written this condition |
| | # never holds: backbone_names stays empty, the two backbone groups below get no |
| | # parameters, and lr_ratio has no effect. |
| | for n, p in opt_model.named_parameters(): |
| | if n not in head_names and p.requires_grad: |
| | backbone_names.append(n) |
| | # for weight_decay policy, see |
| | # https://github.com/huggingface/transformers/blob/50573c648ae953dcc1b94d663651f07fb02268f4/src/transformers/trainer.py#L947 |
| | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) # forbidden layer norm |
| | # bias parameters are excluded from weight decay as well |
| | decay_parameters = [name for name in decay_parameters if "bias" not in name] |
| | # head parameters: trained with training_args.learning_rate |
| | head_decay_parameters = [name for name in head_names if name in decay_parameters] |
| | head_not_decay_parameters = [name for name in head_names if name not in decay_parameters] |
| | # backbone parameters: trained with training_args.learning_rate * lr_ratio |
| | backbone_decay_parameters = [name for name in backbone_names if name in decay_parameters] |
| | backbone_not_decay_parameters = [name for name in backbone_names if name not in decay_parameters] |
| | # group order: head+decay, backbone+decay, head+no-decay, backbone+no-decay |
| | optimizer_grouped_parameters = [ |
| | { |
| | "params": [p for n, p in opt_model.named_parameters() if (n in head_decay_parameters and p.requires_grad)], |
| | "weight_decay": training_args.weight_decay, |
| | "lr": training_args.learning_rate |
| | }, |
| | { |
| | "params": [p for n, p in opt_model.named_parameters() if (n in backbone_decay_parameters and p.requires_grad)], |
| | "weight_decay": training_args.weight_decay, |
| | "lr": training_args.learning_rate * lr_ratio |
| | }, |
| | { |
| | "params": [p for n, p in opt_model.named_parameters() if (n in head_not_decay_parameters and p.requires_grad)], |
| | "weight_decay": 0.0, |
| | "lr": training_args.learning_rate |
| | }, |
| | { |
| | "params": [p for n, p in opt_model.named_parameters() if (n in backbone_not_decay_parameters and p.requires_grad)], |
| | "weight_decay": 0.0, |
| | "lr": training_args.learning_rate * lr_ratio |
| | }, |
| | ] |
| | # the optimizer class and its keyword arguments are derived from training_args |
| | - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) |
| | + optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args) |
| | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) |
| | |
| | return optimizer |
| | |
| | def create_scheduler(training_args, optimizer): |
| | from transformers.optimization import get_scheduler |
| | return get_scheduler( |
| | training_args.lr_scheduler_type, |
| | optimizer=optimizer if optimizer is None else optimizer, |
| | num_warmup_steps=training_args.get_warmup_steps(training_args.max_steps), |
| | num_training_steps=training_args.max_steps, |
| | ) |
| | |
| | # Compute evaluation metrics from a (predictions, labels) pair. |
| | # The predicted class is the argmax of the first element over its last axis. |
| | # Returns a dict with "accuracy" and "mcc" (Matthews correlation coefficient). |
| | def compute_metrics(eval_preds): |
| | probs, labels = eval_preds |
| | preds = np.argmax(probs, axis=-1) |
| | result = {"accuracy": accuracy_score(labels, preds), "mcc": matthews_corrcoef(labels, preds)} |
| | return result |
| | |
| | # Convert logits to probabilities with a softmax over the last dimension. |
| | # The labels argument is accepted but not used. |
| | def preprocess_logits_for_metrics(logits, labels): |
| | return torch.softmax(logits, dim=-1) |
| | |
| | |
| | # Script entry point: load the dataset, model and tokenizer, tokenize every split, |
| | # build the optimizer and scheduler, train, save, then evaluate on the test and |
| | # validation splits. |
| | if __name__ == "__main__": |
| | # target device: "cpu" in the original script, "hpu" in the Gaudi version |
| | - device = torch.device("cpu") |
| | + device = torch.device("hpu") |
| | raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization") |
| | # NOTE(review): trust_remote_code=True lets the model repository supply its own modeling code; only use it with repositories you trust. |
| | model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) |
| | tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S") |
| | |
| | # NOTE(review): hard-coded, machine-specific output path; change it before running. |
| | output_dir = "/home/jiqingfe/protst/protst_2/ProtST-HuggingFace/output_dir/ProtSTModel/default/ESM-1b_PubMedBERT-abs/240123_015856" |
| | # training configuration as a plain dict; it is parsed into a training-arguments object below. |
| | # The Gaudi version adds use_habana, use_lazy_mode and use_hpu_graphs_for_inference. |
| | training_args = {'output_dir': output_dir, 'overwrite_output_dir': True, 'do_train': True, 'per_device_train_batch_size': 32, 'gradient_accumulation_steps': 1, \ |
| | 'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \ |
| | 'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \ |
| | 'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \ |
| | - 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3} |
| | + 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True} |
| | # allow_extra_keys=False: keys in the dict that are not fields of the arguments class are rejected |
| | - training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0] |
| | + training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0] |
| | |
| | # Tokenize one example: writes input_ids and attention_mask from its "prot_seq" field |
| | # and copies its "localization" value into "labels". |
| | def tokenize_protein(example, tokenizer=None): |
| | protein_seq = example["prot_seq"] |
| | protein_seq_str = tokenizer(protein_seq, add_special_tokens=True) |
| | example["input_ids"] = protein_seq_str["input_ids"] |
| | example["attention_mask"] = protein_seq_str["attention_mask"] |
| | example["labels"] = example["localization"] |
| | |
| | return example |
| | |
| | # bind the tokenizer so the function can be passed to Dataset.map |
| | func_tokenize_protein = functools.partial(tokenize_protein, tokenizer=tokenizer) |
| | |
| | # tokenize every split one example at a time and drop the raw columns |
| | for split in ["train", "validation", "test"]: |
| | raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"]) |
| | |
| | # the Gaudi version pads every batch to a fixed length of 1024 |
| | # (presumably to keep tensor shapes static on HPU; TODO confirm) |
| | - data_collator = DataCollatorWithPadding(tokenizer=tokenizer) |
| | + data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024) |
| | |
| | transformers.utils.logging.set_verbosity_info() |
| | log_level = training_args.get_process_log_level() |
| | logger.setLevel(log_level) |
| | |
| | # create_optimizer reads the module-level training_args parsed above |
| | optimizer = create_optimizer(model) |
| | scheduler = create_scheduler(training_args, optimizer) |
| | |
| | # Gaudi-only configuration: sets use_fused_adam and use_fused_clip_norm |
| | + gaudi_config = GaudiConfig() |
| | + gaudi_config.use_fused_adam = True |
| | + gaudi_config.use_fused_clip_norm =True |
| | |
| | |
| | # build trainer |
| | - trainer = Trainer( |
| | + trainer = GaudiTrainer( |
| | model=model, |
| | + gaudi_config=gaudi_config, |
| | args=training_args, |
| | train_dataset=raw_dataset["train"], |
| | eval_dataset=raw_dataset["validation"], |
| | data_collator=data_collator, |
| | optimizers=(optimizer, scheduler), |
| | compute_metrics=compute_metrics, |
| | preprocess_logits_for_metrics=preprocess_logits_for_metrics, |
| | ) |
| | |
| | train_result = trainer.train() |
| | |
| | trainer.save_model() |
| | # Saves the tokenizer too for easy upload |
| | tokenizer.save_pretrained(training_args.output_dir) |
| | |
| | # record the training-set size alongside the metrics returned by train() |
| | metrics = train_result.metrics |
| | metrics["train_samples"] = len(raw_dataset["train"]) |
| | |
| | trainer.log_metrics("train", metrics) |
| | trainer.save_metrics("train", metrics) |
| | trainer.save_state() |
| | |
| | # final evaluation on the test split, with metric keys prefixed by "test" |
| | metric = trainer.evaluate(raw_dataset["test"], metric_key_prefix="test") |
| | print("test metric: ", metric) |
| | |
| | # final evaluation on the validation split, with metric keys prefixed by "valid" |
| | metric = trainer.evaluate(raw_dataset["validation"], metric_key_prefix="valid") |
| | print("valid metric: ", metric) |
| | ``` |
| |
|
| |
|
| |
|
| |
|