Spaces:
Running
Running
| import numpy as np | |
| import pandas as pd | |
| import re | |
| import selfies as sf | |
| import torch | |
| from rdkit import Chem | |
| from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw | |
| from rdkit.Chem.Crippen import MolLogP | |
| from rdkit.Contrib.SA_Score import sascorer | |
| from transformers import BartForConditionalGeneration, AutoTokenizer | |
| from transformers.modeling_outputs import BaseModelOutput | |
# Single checkpoint identifier shared by the tokenizer and the BART model.
_MODEL_ID = "ibm/materials.selfies-ted"

# Loaded once at import time; the encode/generate helpers in this file use
# these module-level objects directly.
gen_tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
gen_model = BartForConditionalGeneration.from_pretrained(_MODEL_ID)
def smiles_to_image(smiles):
    """Render a SMILES string as a molecule image.

    Returns whatever ``Draw.MolToImage`` produces for the parsed molecule,
    or ``None`` when ``Chem.MolFromSmiles`` does not yield a molecule.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None
    return Draw.MolToImage(parsed)
def calculate_properties(smiles):
    """Compute four descriptors for the molecule described by ``smiles``.

    Returns:
        A 4-tuple ``(qed, sa, logp, wt)`` holding the results of
        ``QED.qed``, ``sascorer.calculateScore``, ``MolLogP`` and
        ``Descriptors.MolWt`` respectively, or ``(None, None, None, None)``
        when the SMILES string cannot be parsed into a molecule.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None, None, None, None

    drug_likeness = QED.qed(parsed)
    log_p = MolLogP(parsed)
    synth_access = sascorer.calculateScore(parsed)
    mol_weight = Descriptors.MolWt(parsed)
    # Note the return order differs from the computation order above.
    return drug_likeness, synth_access, log_p, mol_weight
def calculate_tanimoto(smiles1, smiles2):
    """Return the fingerprint similarity of two SMILES strings.

    Both inputs are parsed with RDKit, converted to Morgan bit-vector
    fingerprints of radius 2, and compared with
    ``DataStructs.FingerprintSimilarity`` (called with its default metric).

    Returns:
        The similarity rounded to two decimals, or ``None`` if either
        SMILES string fails to parse.
    """
    parsed = [Chem.MolFromSmiles(text) for text in (smiles1, smiles2)]
    if not all(parsed):
        return None

    fp_a, fp_b = (
        AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in parsed
    )
    similarity = DataStructs.FingerprintSimilarity(fp_a, fp_b)
    return round(similarity, 2)
| def _perturb_latent(latent_vecs, noise_scale=0.5): | |
| return ( | |
| torch.tensor( | |
| np.random.uniform(0, 1, latent_vecs.shape) * noise_scale, | |
| dtype=torch.float32, | |
| ) | |
| + latent_vecs | |
| ) | |
def _encode(selfies):
    """Tokenize SELFIES text and run it through the BART encoder.

    Args:
        selfies: Input passed straight to ``gen_tokenizer``; the caller in
            this file supplies a one-element list of space-separated
            SELFIES strings.

    Returns:
        A ``(last_hidden_state, attention_mask)`` pair. Inputs are padded or
        truncated to 128 tokens by the tokenizer call below.
    """
    encoding = gen_tokenizer(
        selfies,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    attention_mask = encoding["attention_mask"]

    # Inference only: without no_grad the returned hidden state would carry
    # an autograd graph that nothing in this pipeline ever uses.
    with torch.no_grad():
        encoder_out = gen_model.model.encoder(
            input_ids=encoding["input_ids"],
            attention_mask=attention_mask,
        )
    return encoder_out.last_hidden_state, attention_mask
def _generate(latent_vector, mask):
    """Decode a latent tensor into strings via sampled generation.

    The latent is wrapped in ``BaseModelOutput`` and handed to
    ``gen_model.generate`` as precomputed encoder outputs. Each decoded
    string is whitespace-normalised and passed through ``sf.decoder``.

    Args:
        latent_vector: Encoder hidden states to decode from.
        mask: Attention mask forwarded to ``generate``.

    Returns:
        A list with one ``sf.decoder`` result per generated sequence.
    """
    sampled_ids = gen_model.generate(
        encoder_outputs=BaseModelOutput(latent_vector),
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded_texts = gen_tokenizer.batch_decode(
        sampled_ids, skip_special_tokens=True
    )

    results = []
    for text in decoded_texts:
        # Strip the whitespace that sits directly after a "]" and directly
        # before the following "[" before handing the text to sf.decoder.
        compact = re.sub(r'\]\s*(.*?)\s*\[', r']\1[', text)
        results.append(sf.decoder(compact))
    return results
def generate_canonical(smiles):
    """Generate a molecule from a perturbed latent of ``smiles`` and compare.

    The reference SMILES is converted to SELFIES, encoded, and its latent is
    perturbed with noise scales 0.5, 0.6, ..., 5.0. The search stops at the
    first generated molecule that RDKit can parse and whose canonical SMILES
    differs from the reference. If every parseable candidate equals the
    reference, the last parseable candidate is still used.

    Args:
        smiles: Reference molecule as a SMILES string.

    Returns:
        ``(df, gen_mol, mol_image)`` on success, where ``df`` is a DataFrame
        of QED, SA, LogP, molecular weight and similarity, ``gen_mol`` is the
        generated canonical SMILES and ``mol_image`` its rendering.
        ``("Invalid SMILES", None, None)`` when the reference cannot be
        parsed by RDKit or no parseable molecule was generated.

    Exceptions raised by ``sf.encoder`` are not caught here.
    """
    ref_mol = Chem.MolFromSmiles(smiles)
    if not ref_mol:
        # Reject unparseable input with the existing failure value instead
        # of letting Chem.MolToSmiles(None) raise inside the search loop.
        return "Invalid SMILES", None, None
    # Loop-invariant: canonicalise the reference once.
    ref_canonical = Chem.MolToSmiles(ref_mol)

    # The tokenizer input uses space-separated SELFIES symbols.
    selfie = sf.encoder(smiles).replace("][", "] [")
    latent_vec, mask = _encode([selfie])

    gen_mol = None
    for step in range(5, 51):
        print("Searching Latent space")
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=step / 10)
        gen = _generate(perturbed_latent, mask)
        mol = Chem.MolFromSmiles(gen[0])
        if not mol:
            print('Abnormal molecule:', gen[0])
            continue
        gen_mol = Chem.MolToSmiles(mol)
        if gen_mol != ref_canonical:
            break

    if not gen_mol:
        return "Invalid SMILES", None, None

    # Calculate properties for ref and gen molecules
    print("calculating properties")
    ref_properties = calculate_properties(smiles)
    gen_properties = calculate_properties(gen_mol)
    tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)

    # The similarity is a property of the pair; it is shown once, in the
    # reference column, with an empty cell in the generated column.
    df = pd.DataFrame(
        {
            "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
            "Reference Mol": [*ref_properties[:4], tanimoto_similarity],
            "Generated Mol": [*gen_properties[:4], ""],
        }
    )

    # Display molecule image of canonical smiles
    print("Getting image")
    mol_image = smiles_to_image(gen_mol)
    return df, gen_mol, mol_image