Spaces:
Sleeping
Sleeping
| import torch | |
| import lightning | |
| # from torch.utils.data import Dataset | |
| # from typing import Any, Dict | |
| # import argparse | |
| from pydantic import BaseModel | |
| # from get_dataset_dictionaries import get_dict_pair | |
| # import os | |
| # import shutil | |
| # import optuna | |
| # from optuna.integration import PyTorchLightningPruningCallback | |
| # from functools import partial | |
class FFNModule(torch.nn.Module):
    """
    A pytorch module that regresses from a hidden state representation of a word
    to its continuous linguistic feature norm vector.

    It is a FFN with the general structure of:
        input -> (linear -> ReLU -> dropout) x (num_layers - 1) -> linear -> output

    Args:
        input_size: width of the vectors fed to the first linear layer.
        output_size: width of the vectors produced by the last linear layer.
        hidden_size: width of every intermediate (hidden) layer.
        num_layers: total number of linear layers. With a value of 1 (or less)
            no hidden block is built and the network is a single linear map
            from ``input_size`` to ``output_size``.
        dropout: probability handed to each ``torch.nn.Dropout``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
    ):
        super().__init__()
        layers = []
        # Width of whatever feeds the next linear layer: the raw input at
        # first, the hidden width once at least one hidden block exists.
        in_features = input_size
        for _ in range(num_layers - 1):
            layers.extend(
                (
                    torch.nn.Linear(in_features, hidden_size),
                    torch.nn.ReLU(),
                    torch.nn.Dropout(dropout),
                )
            )
            in_features = hidden_size
        # Use the tracked width rather than hidden_size so that a network
        # without hidden blocks still accepts input_size-wide vectors.
        layers.append(torch.nn.Linear(in_features, output_size))
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.network(x)
class FFNParams(BaseModel):
    """
    Architecture settings for the feed-forward network.

    The field names mirror the keyword arguments of ``FFNModule.__init__``,
    so a dumped instance can be unpacked straight into that constructor.
    """

    # width of the vectors fed to the network
    input_size: int
    # width of the vectors the network produces
    output_size: int
    # width of each hidden layer
    hidden_size: int
    # total number of linear layers
    num_layers: int
    # dropout probability used after each hidden layer
    dropout: float
class TrainingParams(BaseModel):
    """
    Optimisation settings for a training run.

    ``learning_rate`` and ``weight_decay`` are read when the optimizer is
    configured; ``num_epochs`` and ``batch_size`` are stored alongside them
    (presumably for the training script and checkpoint metadata).
    """

    # number of epochs to train for
    num_epochs: int
    # number of examples per batch
    batch_size: int
    # optimizer step size
    learning_rate: float
    # optimizer weight decay coefficient
    weight_decay: float
class FeatureNormPredictor(lightning.LightningModule):
    """
    Lightning module that trains an ``FFNModule`` with a mean-squared-error
    loss and an Adam optimizer.

    Args:
        ffn_params: architecture settings, unpacked into ``FFNModule``.
        training_params: optimisation settings; ``learning_rate`` and
            ``weight_decay`` are used by ``configure_optimizers``.
    """

    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
        super().__init__()
        # Records both constructor arguments in the checkpoint hyperparameters.
        self.save_hyperparameters()
        self.ffn_params = ffn_params
        self.training_params = training_params
        self.model = FFNModule(**ffn_params.model_dump())
        self.loss_function = torch.nn.MSELoss()

    def forward(self, input):
        """
        Return the network's prediction for ``input``.

        Defined as ``forward`` (not ``__call__``) so that calling the module
        goes through ``torch.nn.Module.__call__`` and its hooks. The parameter
        keeps its original name so keyword callers continue to work.
        """
        return self.model(input)

    def _compute_loss(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return the MSE loss."""
        x, y = batch
        return self.loss_function(self.model(x), y)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """
        Return predictions for a test batch.

        Accepts either a bare input tensor or an ``(inputs, targets)`` pair as
        produced by the training/validation loaders; for a pair only the
        inputs are fed to the network.
        """
        inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self.model(inputs)

    def predict(self, batch):
        """Return the network's prediction for ``batch``."""
        return self.model(batch)

    def configure_optimizers(self):
        """Build the Adam optimizer from the stored training parameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.training_params.learning_rate,
            weight_decay=self.training_params.weight_decay,
        )

    def save_model(self, path: str):
        """Write only the inner network's state dict to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str):
        """
        Load the inner network's weights from a file written by ``save_model``.

        Tensors are first mapped to CPU so weights saved on an accelerator can
        be restored on any host; ``load_state_dict`` then copies them into the
        existing parameters. ``weights_only=True`` restricts unpickling to
        tensor data, since the file is expected to hold a plain state dict.
        """
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        self.model.load_state_dict(state_dict)
| # class HiddenStateFeatureNormDataset(Dataset): | |
| # def __init__( | |
| # self, | |
| # input_embeddings: Dict[str, torch.Tensor], | |
| # feature_norms: Dict[str, torch.Tensor], | |
| # ): | |
| # # Invariant: input_embeddings and target_feature_norms have exactly the same keys | |
| # # this should be done by the train/test split and upstream data processing | |
| # assert(input_embeddings.keys() == feature_norms.keys()) | |
| # self.words = list(input_embeddings.keys()) | |
| # self.input_embeddings = torch.stack([ | |
| # input_embeddings[word] for word in self.words | |
| # ]) | |
| # self.feature_norms = torch.stack([ | |
| # feature_norms[word] for word in self.words | |
| # ]) | |
| # def __len__(self): | |
| # return len(self.words) | |
| # def __getitem__(self, idx): | |
| # return self.input_embeddings[idx], self.feature_norms[idx] | |
| # # this is used when not optimizing | |
| # def train(args : Dict[str, Any]): | |
| # # input_embeddings = torch.load(args.input_embeddings) | |
| # # feature_norms = torch.load(args.feature_norms) | |
| # # words = list(input_embeddings.keys()) | |
| # input_embeddings, feature_norms, norm_list = get_dict_pair( | |
| # args.norm, | |
| # args.embedding_dir, | |
| # args.lm_layer, | |
| # translated= False if args.raw_buchanan else True, | |
| # normalized= True if args.normal_buchanan else False | |
| # ) | |
| # norms_file = open(args.save_dir+"/"+args.save_model_name+'.txt','w') | |
| # norms_file.write("\n".join(norm_list)) | |
| # norms_file.close() | |
| # words = list(input_embeddings.keys()) | |
| # model = FeatureNormPredictor( | |
| # FFNParams( | |
| # input_size=input_embeddings[words[0]].shape[0], | |
| # output_size=feature_norms[words[0]].shape[0], | |
| # hidden_size=args.hidden_size, | |
| # num_layers=args.num_layers, | |
| # dropout=args.dropout, | |
| # ), | |
| # TrainingParams( | |
| # num_epochs=args.num_epochs, | |
| # batch_size=args.batch_size, | |
| # learning_rate=args.learning_rate, | |
| # weight_decay=args.weight_decay, | |
| # ), | |
| # ) | |
| # # train/val split | |
| # train_size = int(len(words) * 0.8) | |
| # valid_size = len(words) - train_size | |
| # train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size]) | |
| # # TODO: Methodology Decision: should we be normalizing the hidden states/feature norms? | |
| # train_embeddings = {word: input_embeddings[word] for word in train_words} | |
| # train_feature_norms = {word: feature_norms[word] for word in train_words} | |
| # validation_embeddings = {word: input_embeddings[word] for word in validation_words} | |
| # validation_feature_norms = {word: feature_norms[word] for word in validation_words} | |
| # train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms) | |
| # train_dataloader = torch.utils.data.DataLoader( | |
| # train_dataset, | |
| # batch_size=args.batch_size, | |
| # shuffle=True, | |
| # ) | |
| # validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms) | |
| # validation_dataloader = torch.utils.data.DataLoader( | |
| # validation_dataset, | |
| # batch_size=args.batch_size, | |
| # shuffle=True, | |
| # ) | |
| # callbacks = [ | |
| # lightning.pytorch.callbacks.ModelCheckpoint( | |
| # save_last=True, | |
| # dirpath=args.save_dir, | |
| # filename=args.save_model_name, | |
| # ), | |
| # ] | |
| # if args.early_stopping is not None: | |
| # callbacks.append(lightning.pytorch.callbacks.EarlyStopping( | |
| # monitor="val_loss", | |
| # patience=args.early_stopping, | |
| # mode='min', | |
| # min_delta=0.0 | |
| # )) | |
| # #TODO Design Decision - other trainer args? Is device necessary? | |
| # # cpu is fine for the scale of this model - only a few layers and a few hundred words | |
| # trainer = lightning.Trainer( | |
| # max_epochs=args.num_epochs, | |
| # callbacks=callbacks, | |
| # accelerator="cpu", | |
| # log_every_n_steps=7 | |
| # ) | |
| # trainer.fit(model, train_dataloader, validation_dataloader) | |
| # trainer.validate(model, validation_dataloader) | |
| # return model | |
| # # this is used when optimizing | |
| # def objective(trial: optuna.trial.Trial, args: Dict[str, Any]) -> float: | |
| # # optimizing hidden size, batch size, and learning rate | |
| # input_embeddings, feature_norms, norm_list = get_dict_pair( | |
| # args.norm, | |
| # args.embedding_dir, | |
| # args.lm_layer, | |
| # translated= False if args.raw_buchanan else True, | |
| # normalized= True if args.normal_buchanan else False | |
| # ) | |
| # norms_file = open(args.save_dir+"/"+args.save_model_name+'.txt','w') | |
| # norms_file.write("\n".join(norm_list)) | |
| # norms_file.close() | |
| # words = list(input_embeddings.keys()) | |
| # input_size=input_embeddings[words[0]].shape[0] | |
| # output_size=feature_norms[words[0]].shape[0] | |
| # min_size = min(output_size, input_size) | |
| # max_size = min(output_size, 2*input_size)if min_size == input_size else min(2*output_size, input_size) | |
| # hidden_size = trial.suggest_int("hidden_size", min_size, max_size, log=True) | |
| # batch_size = trial.suggest_int("batch_size", 16, 128, log=True) | |
| # learning_rate = trial.suggest_float("learning_rate", 1e-6, 1, log=True) | |
| # model = FeatureNormPredictor( | |
| # FFNParams( | |
| # input_size=input_size, | |
| # output_size=output_size, | |
| # hidden_size=hidden_size, | |
| # num_layers=args.num_layers, | |
| # dropout=args.dropout, | |
| # ), | |
| # TrainingParams( | |
| # num_epochs=args.num_epochs, | |
| # batch_size=batch_size, | |
| # learning_rate=learning_rate, | |
| # weight_decay=args.weight_decay, | |
| # ), | |
| # ) | |
| # # train/val split | |
| # train_size = int(len(words) * 0.8) | |
| # valid_size = len(words) - train_size | |
| # train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size]) | |
| # train_embeddings = {word: input_embeddings[word] for word in train_words} | |
| # train_feature_norms = {word: feature_norms[word] for word in train_words} | |
| # validation_embeddings = {word: input_embeddings[word] for word in validation_words} | |
| # validation_feature_norms = {word: feature_norms[word] for word in validation_words} | |
| # train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms) | |
| # train_dataloader = torch.utils.data.DataLoader( | |
| # train_dataset, | |
| # batch_size=args.batch_size, | |
| # shuffle=True, | |
| # ) | |
| # validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms) | |
| # validation_dataloader = torch.utils.data.DataLoader( | |
| # validation_dataset, | |
| # batch_size=args.batch_size, | |
| # shuffle=True, | |
| # ) | |
| # callbacks = [ | |
| # # all trial models will be saved in temporary directory | |
| # lightning.pytorch.callbacks.ModelCheckpoint( | |
| # save_last=True, | |
| # dirpath=os.path.join(args.save_dir,'optuna_trials'), | |
| # filename="{}".format(trial.number) | |
| # ), | |
| # ] | |
| # if args.prune is not None: | |
| # callbacks.append(PyTorchLightningPruningCallback( | |
| # trial, | |
| # monitor='val_loss' | |
| # )) | |
| # if args.early_stopping is not None: | |
| # callbacks.append(lightning.pytorch.callbacks.EarlyStopping( | |
| # monitor="val_loss", | |
| # patience=args.early_stopping, | |
| # mode='min', | |
| # min_delta=0.0 | |
| # )) | |
| # # note that if optimizing is chosen, will automatically not implement vanilla early stopping | |
| # #TODO Design Decision - other trainer args? Is device necessary? | |
| # # cpu is fine for the scale of this model - only a few layers and a few hundred words | |
| # trainer = lightning.Trainer( | |
| # max_epochs=args.num_epochs, | |
| # callbacks=callbacks, | |
| # accelerator="cpu", | |
| # log_every_n_steps=7, | |
| # # enable_checkpointing=False | |
| # ) | |
| # trainer.fit(model, train_dataloader, validation_dataloader) | |
| # trainer.validate(model, validation_dataloader) | |
| # return trainer.callback_metrics['val_loss'].item() | |
| # if __name__ == "__main__": | |
| # # parse args | |
| # parser = argparse.ArgumentParser() | |
| # #TODO: Design Decision: should we take as input the paths to the pre-extracted layers, or the model/layer we want to generate them from? | |
| # # required inputs | |
| # parser.add_argument("--norm", type=str, required=True, help="feature norm set to use") | |
| # parser.add_argument("--embedding_dir", type=str, required=True, help=" directory containing embeddings") | |
| # parser.add_argument("--lm_layer", type=int, required=True, help="layer of embeddings to use") | |
| # # if user selects optimize, hidden_size, batch_size and learning_rate will be optimized. | |
| # parser.add_argument("--optimize", action="store_true", help="optimize hyperparameters for training") | |
| # parser.add_argument("--prune", action="store_true", help="prune unpromising trials when optimizing") | |
| # # optional hyperparameter specs | |
| # parser.add_argument("--num_layers", type=int, default=2, help="number of layers in FFN") | |
| # parser.add_argument("--hidden_size", type=int, default=100, help="hidden size of FFN") | |
| # parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate of FFN") | |
| # # set this to at least 100 if doing early stopping | |
| # parser.add_argument("--num_epochs", type=int, default=10, help="number of epochs to train for") | |
| # parser.add_argument("--batch_size", type=int, default=32, help="batch size for training") | |
| # parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate for training") | |
| # parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for training") | |
| # parser.add_argument("--early_stopping", type=int, default=None, help="number of epochs to wait for early stopping") | |
| # # optional dataset specs, for buchanan really | |
| # parser.add_argument('--raw_buchanan', action="store_true", help="do not use translated values for buchanan") | |
| # parser.add_argument('--normal_buchanan', action="store_true", help="use normalized features for buchanan") | |
| # # required for output | |
| # parser.add_argument("--save_dir", type=str, required=True, help="directory to save model to") | |
| # parser.add_argument("--save_model_name", type=str, required=True, help="name of model to save") | |
| # args = parser.parse_args() | |
| # if args.early_stopping is not None: | |
| # args.num_epochs = max(50, args.num_epochs) | |
| # torch.manual_seed(10) | |
| # if args.optimize: | |
| # # call optimizer code here | |
| # print("optimizing for learning rate, batch size, and hidden size") | |
| # pruner = optuna.pruners.MedianPruner() if args.prune else optuna.pruners.NopPruner() | |
| # sampler = optuna.samplers.TPESampler(seed=10) | |
| # study = optuna.create_study(direction='minimize', pruner=pruner, sampler=sampler) | |
| # study.optimize(partial(objective, args=args), n_trials = 100, timeout=600) | |
| # other_params = { | |
| # "num_layers": args.num_layers, | |
| # "num_epochs": args.num_epochs, | |
| # "dropout": args.dropout, | |
| # "weight_decay": args.weight_decay, | |
| # } | |
| # print("Number of finished trials: {}".format(len(study.trials))) | |
| # trial = study.best_trial | |
| # print("Best trial: "+str(trial.number)) | |
| # print(" Validation Loss: {}".format(trial.value)) | |
| # print(" Optimized Params: ") | |
| # for key, value in trial.params.items(): | |
| # print(" {}: {}".format(key, value)) | |
| # print(" User Defined Params: ") | |
| # for key, value in other_params.items(): | |
| # print(" {}: {}".format(key, value)) | |
| # print('saving best trial') | |
| # for filename in os.listdir(os.path.join(args.save_dir,'optuna_trials')): | |
| # if filename == "{}.ckpt".format(trial.number): | |
| # shutil.move(os.path.join(args.save_dir,'optuna_trials',filename), os.path.join(args.save_dir, "{}.ckpt".format(args.save_model_name))) | |
| # shutil.rmtree(os.path.join(args.save_dir,'optuna_trials')) | |
| # else: | |
| # model = train(args) | |