Spaces:
Runtime error
Runtime error
| from utils import * | |
| import torch | |
| from rxnfp.transformer_fingerprints import ( | |
| RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints | |
| ) | |
class Reaction_model:
    """Baseline predictors that answer each test reaction with a training reaction's actions.

    Every reaction is a mapping. The methods below read the keys ``'actions'``,
    ``'extracted_molecules'``, ``'index'``, ``'REACTANT'`` and ``'PRODUCT'``.
    """

    def __init__(self, train_list, test_list):
        """Store both datasets and build the fingerprint generator.

        Args:
            train_list: Training reactions (sequence of mappings).
            test_list: Test reactions (sequence of mappings).
        """
        self.train_list = train_list
        self.test_list = test_list
        model, tokenizer = get_default_model_and_tokenizer()
        self.rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)

    def _group_train_indices_by_molecule_count(self):
        """Map ``len(extracted_molecules) - 1`` to the ``'index'`` values of matching training reactions."""
        len_id_map = defaultdict(list)
        for train_rxn in self.train_list:
            len_id_map[len(train_rxn['extracted_molecules']) - 1].append(train_rxn['index'])
        return len_id_map

    @staticmethod
    def _reaction_string(rxn):
        """Join all reactants with '.' and append '>>' plus the first product."""
        return f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}"

    def generate_random(self):
        """Return the actions of randomly drawn training reactions, one per test reaction.

        Sampling is without replacement, so ``random.sample`` raises
        ``ValueError`` when the test set is larger than the training set.
        """
        sampled = random.sample(self.train_list, k=len(self.test_list))
        return [rxn['actions'] for rxn in sampled]

    def generate_random_compatible_old(self):
        """Older compatible-random baseline.

        For each test reaction, draw uniformly among the training reactions
        whose molecule-count key is less than or equal to the test reaction's key.

        Raises:
            KeyError: If no training reaction has exactly the test reaction's key
                (the cumulative-count table is only populated for keys seen in training).
        """
        pred_list = []
        len_id_map = self._group_train_indices_by_molecule_count()
        keys = sorted(len_id_map)

        # accumulated_counts[k] = number of training reactions with key <= k.
        accumulated_counts = {}
        count = 0
        for key in keys:
            count += len(len_id_map[key])
            accumulated_counts[key] = count

        for rxn in self.test_list:
            test_token_num = len(rxn['extracted_molecules']) - 1
            idx = random.randint(0, accumulated_counts[test_token_num] - 1)
            # Walk the buckets in ascending key order to locate the idx-th candidate.
            for key in keys:
                bucket = len_id_map[key]
                if idx < len(bucket):
                    # NOTE(review): 'index' is used as a position in train_list --
                    # confirm the two always coincide.
                    pred_list.append(self.train_list[bucket[idx]]['actions'])
                    break
                idx -= len(bucket)
        return pred_list

    def generate_random_compatible(self):
        """For each test reaction, return the actions of a random training reaction with the same key.

        The key is ``len(extracted_molecules) - 1``.

        Raises:
            IndexError: If no training reaction shares the test reaction's key
                (``random.choice`` is called on an empty list).
        """
        len_id_map = self._group_train_indices_by_molecule_count()
        pred_list = []
        for rxn in self.test_list:
            mole_num = len(rxn['extracted_molecules']) - 1
            # NOTE(review): 'index' is used as a position in train_list --
            # confirm the two always coincide.
            pred_list.append(self.train_list[random.choice(len_id_map[mole_num])]['actions'])
        return pred_list

    def generate_nn(self, batch_size=2048):
        """Return, for each test reaction, the actions of its most similar training reaction.

        Similarity is the cosine similarity between reaction fingerprints.
        Both the training and the test fingerprints are L2-normalised; scaling
        only the test side would leave the argmax over training rows unchanged
        and let large-norm training vectors dominate.

        Args:
            batch_size: Number of reactions fingerprinted per generator call.

        Returns:
            A list with one ``'actions'`` entry per test reaction.

        Raises:
            ValueError: If there are test reactions but no training reactions.
        """
        if not self.test_list:
            return []
        if not self.train_list:
            raise ValueError('generate_nn needs at least one training reaction')

        train_rxns = [self._reaction_string(rxn) for rxn in self.train_list]
        test_rxns = [self._reaction_string(rxn) for rxn in self.test_list]
        train_rxns_batches = [train_rxns[i:i + batch_size] for i in range(0, len(train_rxns), batch_size)]
        test_rxns_batches = [test_rxns[i:i + batch_size] for i in range(0, len(test_rxns), batch_size)]

        # Use the GPU when one is present, otherwise stay on the CPU instead of crashing.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        train_fps = []
        for batch in tqdm(train_rxns_batches, desc='Generating fingerprints for training reactions'):
            train_fps.extend(self.rxnfp_generator.convert_batch(batch))
        train_fps = torch.tensor(train_fps, device=device)  # N x D
        train_fps = torch.nn.functional.normalize(train_fps, dim=1)

        most_similar_indices = []
        for batch in tqdm(test_rxns_batches, desc='Generating fingerprints for test reactions'):
            batch_fps = torch.tensor(self.rxnfp_generator.convert_batch(batch), device=device)  # BS x D
            batch_fps = torch.nn.functional.normalize(batch_fps, dim=1)
            similarity_matrix = torch.matmul(train_fps, batch_fps.T)  # N x BS
            most_similar_indices.extend(torch.argmax(similarity_matrix, dim=0).tolist())
        return [self.train_list[i]['actions'] for i in most_similar_indices]

    def save_results(self, gt_list, pred_list, target_file):
        """Write paired targets and predictions to ``target_file`` as indented JSON.

        The parent directory is created when missing. Pairs are formed with
        ``zip``, so the output is truncated to the shorter of the two lists.

        Args:
            gt_list: Ground-truth entries.
            pred_list: Predicted entries, aligned with ``gt_list``.
            target_file: Destination path of the JSON file.
        """
        import os  # local import: the file's visible import block does not provide os

        text_dict_list = [
            {
                "targets": gt,
                "indices": i,
                "predictions": pred,
            }
            for i, (gt, pred) in enumerate(zip(gt_list, pred_list))
        ]
        parent_dir = os.path.dirname(target_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(target_file, 'w', encoding='utf-8') as f:
            json.dump(text_dict_list, f, indent=4)
def parse_args():
    """Parse the command-line options of the baseline script.

    Returns:
        argparse.Namespace with ``name`` (default ``'none'``), ``train_file``
        and ``test_file`` (default ``None``) and the boolean flag ``use_tok``.
    """
    cli = argparse.ArgumentParser(description="A simple argument parser")
    # String-valued options and their defaults, registered in declaration order.
    string_options = (
        ('--name', 'none'),
        ('--train_file', None),
        ('--test_file', None),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, default=fallback, type=str)
    cli.add_argument('--use_tok', default=False, action='store_true')
    return cli.parse_args()
def read_dataset(args):
    """Load the training and test datasets named by ``args``.

    Args:
        args: Object exposing ``train_file`` and ``test_file`` paths to JSON files.

    Returns:
        A ``(train_ds, test_ds)`` tuple holding the decoded JSON contents.
    """
    def _load_json(path):
        # Announce the file, decode it, then report how many samples it held.
        print(f'Reading {path}...')
        with open(path, 'r', encoding='utf-8') as handle:
            samples = json.load(handle)
        print(f'{len(samples)} samples read.')
        return samples

    # Tuple elements are evaluated left to right: training file first.
    return _load_json(args.train_file), _load_json(args.test_file)
def run_baselines(args):
    """Run every baseline, score it and save its predictions.

    The baselines run in a fixed order: random, random (compatible pattern),
    nearest neighbor. For each one the banner is printed, predictions are
    generated, the metric calculator is invoked, and the results are written
    to ``results/<args.name>/<baseline>.json``.

    Args:
        args: Parsed options; ``train_file``, ``test_file``, ``name`` and
            ``use_tok`` are read.
    """
    set_random_seed(0)
    train_ds, test_ds = read_dataset(args)
    model = Reaction_model(train_ds, test_ds)
    calculator = Metric_calculator()
    gt_list = [sample['actions'] for sample in test_ds]

    # (banner printed before the run, prediction method, output file stem)
    baselines = (
        ('Random:', model.generate_random, 'random'),
        ('Random (compatible pattern):', model.generate_random_compatible, 'random_compatible'),
        ('Nearest neighbor:', model.generate_nn, 'nn'),
    )
    for banner, predict, stem in baselines:
        print(banner)
        pred_list = predict()
        calculator(gt_list, pred_list, args.use_tok)
        model.save_results(gt_list, pred_list, f'results/{args.name}/{stem}.json')
if __name__ == "__main__":
    # Script entry point: parse the command line, then run all baselines.
    args = parse_args()
    run_baselines(args)