In [1]:
import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
from abc import ABC, abstractmethod


In [None]:
class MammalObjectBroker:
    """Lazily loads and caches a Mammal model and its tokenizer from one checkpoint.

    Args:
        model_path: checkpoint passed to ``Mammal.from_pretrained`` and
            ``ModularTokenizerOp.from_pretrained``.
        name: registry/display name; defaults to ``model_path``.
        task_list: names of the tasks this model serves (checked by the app with
            ``task_name in broker.tasks``); defaults to no tasks.
    """

    def __init__(self, model_path: str, name: str = None, task_list: list[str] = None) -> None:
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Always define ``tasks`` so membership checks work even with no task list.
        # Copy so later mutation of the caller's list does not leak into the broker.
        self.tasks = list(task_list) if task_list is not None else []
        # Heavy objects are only loaded on first access (see the properties below).
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> "Mammal":
        """The Mammal model, loaded on first access and switched to eval mode."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """The ModularTokenizerOp of the checkpoint, loaded on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
    
    
    
    

class MammalTask(ABC):
    """Base class for a Mammal demo task: prompt building, inference, decoding and UI.

    Subclasses implement the abstract hooks. ``demo()`` builds the gradio UI through
    ``create_demo`` once and caches the result.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.description = None  # human-readable task title, set by subclasses
        self._demo = None  # cached result of create_demo(), filled by demo()

    @abstractmethod
    def generate_prompt(self, **kwargs) -> str:
        """Format the task inputs into a prompt matching the pre-training syntax.

        Args:
            **kwargs: task-specific inputs (e.g. sequences).

        Returns:
            str: the prompt.
        """
        raise NotImplementedError()

    @abstractmethod
    def crate_sample_dict(self, prompt: str, **kwargs) -> dict:
        """Convert a prompt into a sample dict that can be fed to the model.

        Args:
            prompt (str): prompt produced by ``generate_prompt``.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    @abstractmethod
    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on ``sample_dict`` and return its raw (batch) output."""
        raise NotImplementedError()

    @abstractmethod
    def decode_output(self, batch_dict, model):
        """Turn the raw model output into the values shown in the UI.

        Args:
            batch_dict: output of ``run_model``.
            model: the model (or model holder) that produced it.
        """
        raise NotImplementedError()

    @abstractmethod
    def create_demo(self, model_dropdown=None):
        """Create a gradio demo group for this task.

        Args:
            model_dropdown: gradio dropdown whose value selects the model to run.

        Returns:
            The gradio component (e.g. ``gr.Group``) holding the demo.
        """
        raise NotImplementedError()

    def demo(self, model_dropdown=None):
        """Return the task's demo, building it with ``create_demo`` on first call."""
        if self._demo is None:
            self._demo = self.create_demo(model_dropdown)
        return self._demo
    



In [None]:
# Registries filled by later cells: task name -> task object, model name -> MammalObjectBroker.
all_tasks: dict = {}
all_models: dict = {}

In [None]:

class PpiTask(MammalTask):
    """Protein-Protein Interaction (PPI) demo: estimate whether two proteins interact."""

    def __init__(self):
        super().__init__(name="PPI")
        self.description = "Protein-Protein Interaction (PPI)"
        self.examples = {
            "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
            "protein_calcineurin": "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ",
        }
        # f-string so the description is interpolated. No leading indentation: markdown
        # renders 4-space-indented lines as a code block instead of a heading.
        self.markup_text = (
            f"# Mammal based {self.description} demonstration\n"
            "\n"
            "Given two protein sequences, estimate if the proteins interact or not."
        )

    @staticmethod
    def positive_token_id(model_holder: MammalObjectBroker):
        """Token id of the positive binding class token ``<1>``.

        Args:
            model_holder (MammalObjectBroker): holder providing the tokenizer.

        Returns:
            int: id of the positive binding token.
        """
        return model_holder.tokenizer_op.get_token_id("<1>")

    def generate_prompt(self, prot1, prot2):
        """Format the two sequences into a prompt matching the pre-training syntax.

        Args:
            prot1 (str): sequence of protein number 1
            prot2 (str): sequence of protein number 2

        Returns:
            str: prompt
        """
        prompt = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"\
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
            f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"\
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
            f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
        return prompt

    def crate_sample_dict(self, prompt: str, model_holder: MammalObjectBroker):
        """Tokenize ``prompt`` into a sample dict with tensor token ids and attention mask.

        Args:
            prompt (str): prompt from ``generate_prompt``.
            model_holder (MammalObjectBroker): provides the tokenizer.

        Returns:
            dict: sample dict with ENCODER_INPUTS_TOKENS / ENCODER_INPUTS_ATTENTION_MASK
            converted to ``torch`` tensors.
        """
        sample_dict = {ENCODER_INPUTS_STR: prompt}

        # Tokenize
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS]
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
        """Run generation on a single sample and return the generate() output dict."""
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, model_holder):
        """Decode the generated tokens and extract the positive-class score.

        Returns:
            tuple: (decoded output string of the first sample, score of the ``<1>``
            token at generation step 1 — presumably the class slot right after
            ``<SENTINEL_ID_0>``; TODO confirm).
        """
        generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
        score = batch_dict["model.out.scores"][0][1][self.positive_token_id(model_holder)].item()

        return generated_output, score

    def create_and_run_prompt(self, model_name, protein1, protein2):
        """Build the prompt, run the selected model and decode its answer.

        Args:
            model_name: key of the model in ``all_models`` (the model dropdown value).
            protein1 (str): first protein sequence.
            protein2 (str): second protein sequence.

        Returns:
            tuple: (prompt, decoded model output, positive-class score), matching the
            three outputs wired up in ``create_demo``.

        Raises:
            gr.Error: if no registered model is selected.
        """
        if model_name not in all_models:
            raise gr.Error(f"Unknown or no model selected: {model_name!r}")
        model_holder = all_models[model_name]
        prompt = self.generate_prompt(protein1, protein2)
        sample_dict = self.crate_sample_dict(prompt=prompt, model_holder=model_holder)
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        # The raw batch dict must be decoded; unpacking it directly yields its keys.
        generated_output, score = self.decode_output(batch_dict, model_holder)
        return prompt, generated_output, score

    def create_demo(self, model_name_dropdown):
        """Create the gradio group for the PPI demo.

        Args:
            model_name_dropdown: gradio dropdown whose value is the ``all_models`` key
                of the model to run.

        Returns:
            gr.Group: the demo group.
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                prot1 = gr.Textbox(
                    label="Protein 1 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calmodulin"],
                )
                prot2 = gr.Textbox(
                    label="Protein 2 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calcineurin"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                score_box = gr.Number(label="PPI score")
            with gr.Row():
                gr.Markdown(
                    "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
                )
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_dropdown, prot1, prot2],
                outputs=[prompt_box, decoded, score_box],
            )
            demo.visible = True
        return demo

# Instantiate the PPI task and register it in the task registry under its own name.
ppi_task = PpiTask()
all_tasks.update({ppi_task.name: ppi_task})
all_tasks

In [None]:


### DTI (drug-target binding affinity) demo — legacy code, not yet ported to MammalTask.
# NOTE(review): the functions in this cell still reference `tokenizer_op[dti]`,
# `models[dti]` and `model_paths[dti]`, none of which are defined in this notebook.
# TODO: port to a MammalTask subclass backed by a MammalObjectBroker.

# Task key / title
dti = "Drug-Target Binding Affinity"

# Default example inputs
target_seq = "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
drug_seq = "CC(=O)NCCC1=CNc2c1cc(OC)cc2"

# The positive-binding token id (`tokenizer_op[dti].get_token_id("<1>")`) is no longer
# computed eagerly here: it was unused, and it raised NameError on a fresh kernel,
# which stopped the rest of this cell from defining its functions.


def generate_prompt_dti(prot, drug):
    """Preprocess a target protein / drug SMILES pair into a DTI sample dict.

    Args:
        prot (str): target protein sequence.
        drug (str): drug SMILES string.

    Returns:
        dict: sample dict produced by ``DtiBindingdbKdTask.data_preprocessing``.
    """
    # Use the arguments. Reading the module-level example sequences here would
    # silently ignore whatever the user typed in the UI.
    sample_dict = {"target_seq": prot, "drug_seq": drug}
    # NOTE(review): `tokenizer_op[dti]` and `models[dti]` are legacy globals that are
    # not defined in this notebook — TODO source them from a MammalObjectBroker.
    sample_dict = DtiBindingdbKdTask.data_preprocessing(
        sample_dict=sample_dict,
        tokenizer_op=tokenizer_op[dti],
        target_sequence_key="target_seq",
        drug_sequence_key="drug_seq",
        norm_y_mean=None,
        norm_y_std=None,
        device=models[dti].device,
    )
    return sample_dict


def create_and_run_prompt_dtb(prot, drug):
    """Predict the binding affinity of (prot, drug) and return values for the demo UI.

    Returns:
        tuple: (encoder input string, output key name, predicted affinity as float).
    """
    out_key = "model.out.dti_bindingdb_kd"
    sample_dict = generate_prompt_dti(prot, drug)
    # NOTE(review): `models` is a legacy global that is not defined in this notebook.
    dti_model = models[dti]
    # Encoder-only forward pass, then post-process the scalar prediction with the
    # hard-coded normalization constants (presumably the training-set mean / std —
    # TODO confirm).
    batch_dict = DtiBindingdbKdTask.process_model_output(
        dti_model.forward_encoder_only([sample_dict]),
        scalars_preds_processed_key=out_key,
        norm_y_mean=5.79384684128215,
        norm_y_std=1.33808027428196,
    )
    prediction = float(batch_dict[out_key][0])
    return sample_dict["data.query.encoder_input"], out_key, prediction


def create_tdb_demo():
    """Build the gradio group for the DTI binding-affinity demo (hidden by default)."""
    # NOTE(review): `model_paths` is a legacy global that is not defined in this notebook.
    markup_text = f"""
# Mammal based Target-Drug binding affinity demonstration

Given a protein sequence and a drug (in SMILES), estimate the binding affinity.

### Using the model from

 ```{model_paths[dti]} ```
"""
    # Options shared by both input text boxes.
    input_box_opts = {"interactive": True, "lines": 3}
    with gr.Group() as tdb_demo:
        gr.Markdown(markup_text)
        with gr.Row():
            prot = gr.Textbox(label="Protein sequence", value=target_seq, **input_box_opts)
            drug = gr.Textbox(label="drug sequence (SMILES)", value=drug_seq, **input_box_opts)
        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Target Drug Affinity", variant="primary"
            )
        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            dti_score = gr.Number(label="DTI score")
        run_mammal.click(
            fn=create_and_run_prompt_dtb,
            inputs=[prot, drug],
            outputs=[prompt_box, decoded, dti_score],
        )
        tdb_demo.visible = False
    return tdb_demo



In [None]:

# Register the base Mammal checkpoint as the model that serves the PPI task.
ppi_model = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m",
    task_list=["PPI"],
)
all_models[ppi_model.name] = ppi_model

# TODO: register the DTI checkpoint once a DTI task exists (its task list is still empty), e.g.
# tdi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", task_list=[...])
# all_models[tdi_model.name] = tdi_model

In [None]:

def create_application():
    """Build the gradio Blocks app: a task selector, a model selector and the PPI demo.

    Returns:
        gr.Blocks: the application (not launched).
    """

    def models_for_task(task_name):
        # Names of registered models whose task list contains ``task_name``.
        return [name for name, holder in all_models.items() if task_name in holder.tasks]

    def task_change(value):
        choices = models_for_task(value)
        # Always return an update. With no matching model, clear the dropdown
        # explicitly instead of returning None.
        return gr.update(choices=choices, value=choices[0] if choices else None)

    def set_ppi_vis(main_text):
        # The PPI demo is currently always shown.
        # TODO: show only for the matching task, e.g. visible=(main_text == "PPI").
        return gr.Group(visible=True)

    with gr.Blocks() as demo:
        task_dropdown = gr.Dropdown(
            choices=["select demo"] + list(all_tasks.keys()), interactive=True
        )
        model_dropdown = gr.Dropdown(
            choices=models_for_task(task_dropdown.value), interactive=True
        )
        task_dropdown.change(task_change, inputs=[task_dropdown], outputs=[model_dropdown])

        # Build a fresh demo group inside *this* Blocks. The task's cached ``demo()``
        # would return a group (wired to an old dropdown) that belongs to a
        # previously built app whenever this function is called again.
        ppi_demo = all_tasks["PPI"].create_demo(model_dropdown)
        ppi_demo.visible = True

        task_dropdown.change(set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo])
    return demo

# Handle on the running app. Look it up instead of resetting it to None, so that
# re-running this cell can still close the app started by the previous run.
full_demo = globals().get("full_demo")


def main():
    """Build the application and launch it, closing any app launched earlier."""
    global full_demo
    if full_demo is not None:
        # Otherwise the previous server keeps running alongside the new one.
        full_demo.close()
    full_demo = create_application()
    full_demo.launch(show_error=True, share=False)


if __name__ == "__main__":
    main()


In [None]:
# Sanity check: list each registered model, its tasks, and whether it serves PPI.
for model_name in all_models:
    model_holder = all_models[model_name]
    print(model_name)
    print(model_holder.tasks, "PPI" in model_holder.tasks)

In [None]:
# Debug: show the event names supported by one block of the running app.
# Block ids are assigned at build time and change between builds, so look the id up
# defensively instead of assuming it exists.
DEBUG_BLOCK_ID = 240
debug_block = full_demo.blocks.get(DEBUG_BLOCK_ID) if full_demo is not None else None
debug_block.EVENTS if debug_block is not None else None

In [5]:
from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer



In [6]:

# Example TCR beta-chain / epitope pair for the TCR-epitope binding inference below.
tcr_beta_seq, epitope_seq = (
    "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
    "LLQTGIHVRVSQPSL",
)


In [8]:
# Loads the base checkpoint.
# NOTE(review): the next cell loads the same checkpoint again (together with its
# tokenizer), so this cell looks redundant — consider removing it.
model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")

Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m'


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Attempting to load model from dir: pretrained_model_name_or_path='/Users/matann/.cache/huggingface/hub/models--ibm--biomed.omics.bl.sm.ma-ted-458m/snapshots/421daf3f8eae4ada57ffd3580f7347828b34d69a'


In [9]:
device="cpu"
path = "ibm/biomed.omics.bl.sm.ma-ted-458m"

# Load Model and set to evaluation mode
model = Mammal.from_pretrained(path)
model.eval()

# model.to(device=device)

# Load Tokenizer
tokenizer_op = ModularTokenizerOp.from_pretrained(path)
    
# model_inst, tokenizer_op = load_model(device="cpu")


Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m'


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Attempting to load model from dir: pretrained_model_name_or_path='/Users/matann/.cache/huggingface/hub/models--ibm--biomed.omics.bl.sm.ma-ted-458m/snapshots/421daf3f8eae4ada57ffd3580f7347828b34d69a'


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

The OrderedVocab you are attempting to save contains holes for indices [314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499

In [11]:
# TCR-epitope binding inference using the base model and tokenizer loaded above.
result = task_infer(
    tcr_beta_seq=tcr_beta_seq,
    epitope_seq=epitope_seq,
    model=model,
    tokenizer_op=tokenizer_op,
)
print(f"The prediction for {epitope_seq} and {tcr_beta_seq} is {result}")


The prediction for LLQTGIHVRVSQPSL and NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT is {'pred': 0, 'score': 0.1266935169696808}


In [12]:
import os

from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import CLS_PRED, SCORES
from mammal.model import Mammal


In [14]:

# Load the model checkpoint for the protein solubility task (replaces `model`).
model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility")


Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility'


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

(…)th_aug_4272372_samples_balanced_1_1.json:   0%|          | 0.00/277k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.83k [00:00<?, ?B/s]

tokenizer/config.yaml:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer/cell_attributes_tokenizer.json:   0%|          | 0.00/93.4k [00:00<?, ?B/s]

tokenizer/gene_tokenizer.json:   0%|          | 0.00/2.76M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.83G [00:00<?, ?B/s]

tokenizer/t5_tokenizer_AA_special.json:   0%|          | 0.00/70.9k [00:00<?, ?B/s]

Attempting to load model from dir: pretrained_model_name_or_path='/Users/matann/.cache/huggingface/hub/models--ibm--biomed.omics.bl.sm.ma-ted-458m.protein_solubility/snapshots/5644b80883d961ecae5cbb5773bee961b872869c'


In [15]:

model.eval()  # evaluation mode for the solubility model

# Tokenizer from the same solubility checkpoint
tokenizer_op = ModularTokenizerOp.from_pretrained(
    "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"
)


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

The OrderedVocab you are attempting to save contains holes for indices [314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499

In [16]:
protein_seq = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"

In [17]:

# Convert the raw sequence into a MAMMAL-style sample dict.
sample_dict = ProteinSolubilityTask.data_preprocessing(
    sample_dict={"protein_seq": protein_seq},
    tokenizer_op=tokenizer_op,
    protein_sequence_key="protein_seq",
    device=model.device,
)


In [18]:

# Generate mode: at most 5 new tokens; return a dict that includes the per-step scores.
batch_dict = model.generate(
    [sample_dict],
    max_new_tokens=5,
    return_dict_in_generate=True,
    output_scores=True,
)


In [None]:

# Post-process the model's output into a prediction plus scores.
ans = ProteinSolubilityTask.process_model_output(
    tokenizer_op=tokenizer_op,
    decoder_output=batch_dict[CLS_PRED][0],
    decoder_output_scores=batch_dict[SCORES][0],
)

# Print prediction (previously only this comment was present, so nothing was shown).
print(f"{ans=}")



ans={'pred': 1, 'not_normalized_scores': tensor(0.8730), 'normalized_scores': tensor(0.8730)}


In [20]:
# NOTE(review): reloads the solubility tokenizer that was already loaded earlier from
# the same checkpoint; this cell looks redundant — consider removing it.
tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility")

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

The OrderedVocab you are attempting to save contains holes for indices [314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499

In [21]:
# Assisted by watsonx Code Assistant 
all_models = {'model1': 'model1_path', 'model2': 'model2_path'}

def register_model(self, name):
    self.update({name: f'{name}_path'})

all_models.register_model = register_model.__get__(all_models, dict)
all_models.register_model("model3")
print(all_models)


AttributeError: 'dict' object has no attribute 'register_model'

In [22]:
class AllModels(dict):
    """A dict that can register a model name with a placeholder ``<name>_path`` value."""

    def register_model(self, name):
        placeholder_path = f'{name}_path'
        self[name] = placeholder_path


In [None]:
# NOTE(review): this rebinds `all_models`, replacing the app's registry of
# MammalObjectBroker instances with an empty AllModels. Re-run the registry and
# model-registration cells before launching the app again.
all_models=AllModels()

In [25]:
all_models.register_model("abc")

In [26]:
# Display the registry contents.
all_models

{'abc': 'abc_path'}