Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForTokenClassification, | |
| AutoModelForSequenceClassification, | |
| AutoModelForSeq2SeqLM, | |
| pipeline | |
| ) | |
| import re | |
| import os | |
| import json | |
| from typing import Dict, List, Tuple, Any | |
class SymptomExtractor:
    """Model for extracting symptoms from patient descriptions using BioBERT.

    Wraps a Hugging Face token-classification ``pipeline`` and groups its
    per-token ``B-``/``I-`` tagged output into symptom spans. Duration phrases
    ("2 weeks", "since Monday", "for about ...") are found with regexes.

    NOTE(review): the default checkpoint ``dmis-lab/biobert-v1.1`` is loaded
    through ``AutoModelForTokenClassification``. Unless a checkpoint fine-tuned
    for NER is supplied, its labels presumably do not follow the ``B-``/``I-``
    scheme this class expects, and no symptoms will be extracted. Confirm.
    """

    # Duration phrases, scanned in this order.
    # - The leading word boundary keeps "for"/"since" from matching inside
    #   longer words.
    # - The optional plural ("days?") lets "2 weeks" match in full. The previous
    #   "day|days" alternation matched the shorter alternative first.
    DURATION_PATTERNS = (
        r"(\d+)\s*(days?|weeks?|months?|years?)\b",
        r"\bsince\s+(\w+)",
        r"\bfor\s+(\w+)",
    )

    def __init__(self, model_name="dmis-lab/biobert-v1.1", device=None):
        """Load the tokenizer, token-classification model and NER pipeline.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Symptom Extractor model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name).to(self.device)
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, device=0 if self.device == "cuda" else -1)
        print("Symptom Extractor model loaded successfully.")

    def extract_symptoms(self, text: str) -> Dict[str, Any]:
        """Extract symptom spans and duration phrases from ``text``.

        Args:
            text: Free-text patient description.

        Returns:
            ``{"symptoms": [...], "duration": [...]}``.
            - Each symptom has ``text``, ``start``, ``end`` and ``score``, where
              ``score`` is the mean score of the tokens in the span.
            - Each duration has ``text``, ``start`` and ``end``, which are
              character offsets into ``text``.
        """
        results = self.nlp(text)

        # Group B-/I- tagged tokens into contiguous symptom spans.
        symptoms = []
        current_symptom = None
        token_scores = []
        for entity in results:
            tag = entity["entity"]
            word = entity["word"]
            if tag.startswith("B-"):  # Beginning of a new symptom
                if current_symptom is not None:
                    current_symptom["score"] = sum(token_scores) / len(token_scores)
                    symptoms.append(current_symptom)
                current_symptom = {
                    "text": word[2:] if word.startswith("##") else word,
                    "start": entity["start"],
                    "end": entity["end"],
                    "score": entity["score"],
                }
                token_scores = [entity["score"]]
            elif tag.startswith("I-") and current_symptom is not None:  # Inside a symptom
                if word.startswith("##"):
                    # WordPiece continuation: glue directly onto the previous piece.
                    current_symptom["text"] += word[2:]
                else:
                    current_symptom["text"] += " " + word
                current_symptom["end"] = entity["end"]
                token_scores.append(entity["score"])
        if current_symptom is not None:
            current_symptom["score"] = sum(token_scores) / len(token_scores)
            symptoms.append(current_symptom)

        # Extract duration information; results are ordered by pattern, then position.
        duration_info = []
        for pattern in self.DURATION_PATTERNS:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                duration_info.append({
                    "text": match.group(0),
                    "start": match.start(),
                    "end": match.end(),
                })

        return {
            "symptoms": symptoms,
            "duration": duration_info,
        }
class RiskClassifier:
    """Model for classifying patient risk level using PubMedBERT.

    The sequence-classification head is created with three labels
    (Low / Medium / High). The checkpoint has not been fine-tuned for this
    task, so the raw model prediction is post-processed with keyword rules.
    """

    # Keywords that push the prediction toward High risk.
    HIGH_RISK_KEYWORDS = (
        "severe", "extreme", "intense", "unbearable", "emergency",
        "chest pain", "difficulty breathing", "can't breathe",
        "losing consciousness", "fainted", "seizure", "stroke", "heart attack",
        "allergic reaction", "bleeding heavily", "blood", "poisoning",
        "overdose", "suicide", "self-harm", "hallucinations",
    )
    # Keywords that push the prediction toward Medium risk.
    MEDIUM_RISK_KEYWORDS = (
        "worsening", "spreading", "persistent", "chronic", "recurring",
        "infection", "fever", "swelling", "rash", "pain", "vomiting",
        "diarrhea", "dizzy", "headache", "concerning", "worried",
        "weeks", "days", "increasing", "progressing",
    )
    # Keywords that push the prediction toward Low risk.
    LOW_RISK_KEYWORDS = (
        "mild", "slight", "minor", "occasional", "intermittent",
        "improving", "better", "sometimes", "rarely", "manageable",
    )
    # Minimum probability reported for the rule-adjusted label.
    MIN_ADJUSTED_CONFIDENCE = 0.6

    def __init__(self, model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", device=None):
        """Load the tokenizer and a 3-label sequence-classification model.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Risk Classifier model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3  # Low, Medium, High
        ).to(self.device)
        self.id2label = {0: "Low", 1: "Medium", 2: "High"}
        print("Risk Classifier model loaded successfully.")

    @staticmethod
    def _count_keyword_matches(text_lower, keywords):
        """Return how many of ``keywords`` occur in ``text_lower`` at the start of a word.

        Anchoring at a leading word boundary avoids false hits inside
        unrelated words ("severe" in "persevere", "days" in "holidays"). It
        still matches inflections such as "headaches" or "painful".
        """
        return sum(
            1 for keyword in keywords
            if re.search(r"\b" + re.escape(keyword), text_lower)
        )

    def _boost_probability(self, probabilities, index):
        """Return a copy of ``probabilities`` with entry ``index`` raised to at least
        ``MIN_ADJUSTED_CONFIDENCE``.

        The remaining entries are rescaled proportionally so the result sums
        to 1. The boosted entry therefore keeps its floor and stays the
        largest value. Naive renormalization after raising one entry can
        drop it below the floor and below another label.
        """
        adjusted = np.asarray(probabilities, dtype=np.float64).copy()
        target = max(self.MIN_ADJUSTED_CONFIDENCE, float(adjusted[index]))
        rest = float(adjusted.sum() - adjusted[index])
        if rest > 0:
            adjusted *= (1.0 - target) / rest
        else:
            # All probability mass was already on ``index``.
            adjusted[:] = 0.0
        adjusted[index] = target
        return adjusted

    def classify_risk(self, text: str) -> Dict[str, Any]:
        """Classify the risk level based on the input text.

        Args:
            text: Free-text patient description. It is truncated to 512
                tokens for the model.

        Returns:
            Dict with:
            - ``risk_level``: the rule-adjusted label.
            - ``confidence``: the probability of that label, at least
              ``MIN_ADJUSTED_CONFIDENCE``.
            - ``all_probabilities``: label -> adjusted probability, summing to 1.
            - ``original_prediction``: the model's raw argmax label.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
        model_prediction = torch.argmax(logits, dim=1).item()

        # The model is not fine-tuned, so adjust its prediction with keyword rules.
        text_lower = text.lower()
        high_risk_matches = self._count_keyword_matches(text_lower, self.HIGH_RISK_KEYWORDS)
        medium_risk_matches = self._count_keyword_matches(text_lower, self.MEDIUM_RISK_KEYWORDS)
        low_risk_matches = self._count_keyword_matches(text_lower, self.LOW_RISK_KEYWORDS)

        # NOTE(review): a single high-risk keyword with fewer than two
        # medium-risk matches is not forced to High. With zero medium matches
        # (e.g. a lone mention of "suicide") it falls through to the untuned
        # model prediction. Confirm this triage policy is intended.
        adjusted_prediction = model_prediction
        if high_risk_matches >= 2:
            adjusted_prediction = 2  # High risk
        elif high_risk_matches == 1 and medium_risk_matches >= 2:
            adjusted_prediction = 2  # High risk
        elif medium_risk_matches >= 3:
            adjusted_prediction = 1  # Medium risk
        elif medium_risk_matches >= 1 and low_risk_matches <= 1:
            adjusted_prediction = 1  # Medium risk
        elif low_risk_matches >= 2 and high_risk_matches == 0:
            adjusted_prediction = 0  # Low risk

        # A long, detailed description may indicate a more complex case, so
        # never report Low for it.
        if len(text.split()) > 40 and adjusted_prediction == 0:
            adjusted_prediction = 1  # Upgrade to Medium risk

        adjusted_probabilities = self._boost_probability(probabilities, adjusted_prediction)

        return {
            "risk_level": self.id2label[adjusted_prediction],
            "confidence": float(adjusted_probabilities[adjusted_prediction]),
            "all_probabilities": {
                self.id2label[i]: float(prob)
                for i, prob in enumerate(adjusted_probabilities)
            },
            "original_prediction": self.id2label[model_prediction]
        }
class RecommendationGenerator:
    """Model for generating medical recommendations using fine-tuned t5-small.

    Combines three parts:
    - free text generated by a seq2seq model;
    - department suggestions from a keyword table;
    - self-care tips selected by risk level.
    """

    # Local directories checked, in order, for a fine-tuned checkpoint.
    POSSIBLE_LOCAL_PATHS = (
        "./finetuned_t5-small",  # User-specified fine-tuned model path
        "./t5-small-medical-recommendation",
        "./models/t5-small-medical-recommendation",
        "./fine_tuned_models/t5-small",
        "./output",
        "./fine_tuning_output",
    )
    # ``model_path`` values meaning "prefer a local fine-tuned checkpoint if present".
    DEFAULT_MODEL_PATHS = ("t5-small", "t5-small-medical-recommendation")

    def __init__(self, model_path="t5-small", device=None):
        """Load the recommendation model and build the lookup tables.

        Args:
            model_path: Hugging Face model id or local path.
                - For a default value (see ``DEFAULT_MODEL_PATHS``), a
                  fine-tuned checkpoint found in ``POSSIBLE_LOCAL_PATHS`` is
                  used instead.
                - Any other value is loaded as given.
                If loading fails, the base "t5-small" model is loaded.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Recommendation Generator model on {self.device}...")

        # Only substitute a local fine-tuned checkpoint when the caller asked
        # for a default. An explicitly chosen model/path is always honoured.
        if model_path in self.DEFAULT_MODEL_PATHS:
            local_path = self._find_local_checkpoint()
            if local_path is not None:
                model_path = local_path
                print(f"Found fine-tuned model at: {model_path}")
            elif model_path == "t5-small-medical-recommendation":
                print("Fine-tuned model not found locally. Falling back to base t5-small...")
                model_path = "t5-small"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
            print(f"Recommendation Generator model '{model_path}' loaded successfully.")
        except Exception as e:
            # Best-effort fallback: report the failure and use the base model.
            print(f"Error loading model from {model_path}: {str(e)}")
            print("Falling back to base t5-small model...")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(self.device)
            print("Base t5-small model loaded successfully as fallback.")

        # Department mapping: symptom keyword -> department. Keywords match at
        # the start of a word (see _extract_departments_from_symptoms).
        self.symptom_to_department = {
            "headache": "Neurology",
            "dizziness": "Neurology",
            "confusion": "Neurology",
            "memory": "Neurology",
            "numbness": "Neurology",
            "tingling": "Neurology",
            "seizure": "Neurology",
            "nerve": "Neurology",
            "chest pain": "Cardiology",
            "heart": "Cardiology",
            "palpitation": "Cardiology",
            "arrhythmia": "Cardiology",
            "high blood pressure": "Cardiology",
            "hypertension": "Cardiology",
            "heart attack": "Cardiology",
            "cardiovascular": "Cardiology",
            "cough": "Pulmonology",
            "breathing": "Pulmonology",
            "shortness": "Pulmonology",
            "lung": "Pulmonology",
            "respiratory": "Pulmonology",
            "asthma": "Pulmonology",
            "pneumonia": "Pulmonology",
            "copd": "Pulmonology",
            "stomach": "Gastroenterology",
            "abdomen": "Gastroenterology",
            "nausea": "Gastroenterology",
            "vomit": "Gastroenterology",
            "diarrhea": "Gastroenterology",
            "constipation": "Gastroenterology",
            "heartburn": "Gastroenterology",
            "liver": "Gastroenterology",
            "digestive": "Gastroenterology",
            "joint": "Orthopedics",
            "bone": "Orthopedics",
            "muscle": "Orthopedics",
            "pain": "Orthopedics",
            "back": "Orthopedics",
            "arthritis": "Orthopedics",
            "fracture": "Orthopedics",
            "sprain": "Orthopedics",
            "rash": "Dermatology",
            "skin": "Dermatology",
            "itching": "Dermatology",
            "itch": "Dermatology",
            "acne": "Dermatology",
            "eczema": "Dermatology",
            "psoriasis": "Dermatology",
            "fever": "General Medicine / Primary Care",
            "infection": "General Medicine / Primary Care",
            "sore throat": "General Medicine / Primary Care",
            "flu": "General Medicine / Primary Care",
            # Word-start matching no longer finds "flu" inside "influenza".
            "influenza": "General Medicine / Primary Care",
            "cold": "General Medicine / Primary Care",
            "fatigue": "General Medicine / Primary Care",
            "pregnancy": "Obstetrics / Gynecology",
            "menstruation": "Obstetrics / Gynecology",
            "period": "Obstetrics / Gynecology",
            "vaginal": "Obstetrics / Gynecology",
            "menopause": "Obstetrics / Gynecology",
            "depression": "Psychiatry",
            "anxiety": "Psychiatry",
            "mood": "Psychiatry",
            "stress": "Psychiatry",
            "sleep": "Psychiatry",
            "insomnia": "Psychiatry",
            "mental": "Psychiatry",
            "ear": "Otolaryngology (ENT)",
            "nose": "Otolaryngology (ENT)",
            "throat": "Otolaryngology (ENT)",
            "hearing": "Otolaryngology (ENT)",
            "sinus": "Otolaryngology (ENT)",
            "eye": "Ophthalmology",
            "vision": "Ophthalmology",
            "blindness": "Ophthalmology",
            "blurry": "Ophthalmology",
            "urination": "Urology",
            "kidney": "Urology",
            "bladder": "Urology",
            "urine": "Urology",
            "prostate": "Urology"
        }

        # Self-care suggestions per risk level.
        self.self_care_by_risk = {
            "Low": [
                "Ensure you're getting adequate rest",
                "Stay hydrated by drinking plenty of water",
                "Monitor your symptoms and note any changes",
                "Consider over-the-counter medications appropriate for your symptoms",
                "Maintain a balanced diet to support your immune system",
                "Try gentle exercises if appropriate for your condition",
                "Avoid activities that worsen your symptoms",
                "Keep track of any patterns in your symptoms"
            ],
            "Medium": [
                "Rest and avoid strenuous activities",
                "Stay hydrated and maintain proper nutrition",
                "Take your temperature and other vital signs if possible",
                "Write down any changes in symptoms and when they occur",
                "Have someone stay with you if your symptoms are concerning",
                "Prepare a list of your symptoms and medications for your doctor",
                "Avoid self-medicating beyond basic over-the-counter remedies",
                "Consider arranging transportation to your medical appointment"
            ],
            "High": [
                "Don't wait - seek medical attention immediately",
                "Have someone drive you to the emergency room if safe to do so",
                "Call emergency services if symptoms are severe",
                "Bring a list of your current medications if possible",
                "Follow any first aid protocols appropriate for your symptoms",
                "Don't eat or drink anything if you might need surgery",
                "Take prescribed emergency medications if applicable (like an inhaler for asthma)",
                "Try to remain calm and focused on getting help"
            ]
        }

    @classmethod
    def _find_local_checkpoint(cls):
        """Return the first ``POSSIBLE_LOCAL_PATHS`` entry that holds a saved model, else None.

        A directory counts as a saved model only if it contains a
        ``config.json``. Unrelated folders such as an empty ``./output`` are
        skipped instead of being picked and then failing to load.
        """
        for path in cls.POSSIBLE_LOCAL_PATHS:
            if os.path.isfile(os.path.join(path, "config.json")):
                return path
        return None

    def _extract_departments_from_symptoms(self, symptoms_text: str) -> List[str]:
        """Map a symptom description to relevant departments.

        A keyword matches only at the start of a word. This avoids
        misrouting such as "ear" inside "heart"/"years" or "nose" inside
        "diagnosed", while still matching inflections ("vomiting", "ears").

        Args:
            symptoms_text: Symptom description text.

        Returns:
            Department names without duplicates, in keyword-table order.
            Returns ``["General Medicine / Primary Care"]`` when nothing matches.
        """
        symptoms_lower = symptoms_text.lower()
        matched = [
            department
            for keyword, department in self.symptom_to_department.items()
            if re.search(r"\b" + re.escape(keyword), symptoms_lower)
        ]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(matched)) or ["General Medicine / Primary Care"]

    def _get_self_care_suggestions(self, risk_level: str) -> List[str]:
        """Return self-care suggestions for a risk level.

        Args:
            risk_level: Risk level ("Low", "Medium", "High"). Unknown values
                fall back to "Medium".

        Returns:
            Up to 5 suggestions. If more are available, 5 are chosen at random
            so repeated calls vary. The result is always a new list.
        """
        import random

        if risk_level not in self.self_care_by_risk:
            risk_level = "Medium"  # Default to medium-risk advice
        suggestions = self.self_care_by_risk[risk_level]
        if len(suggestions) > 5:
            return random.sample(suggestions, 5)
        # Copy so callers cannot mutate the shared table.
        return list(suggestions)

    def _format_structured_recommendation(self, medical_advice: str, departments: List[str], self_care: List[str], risk_level: str) -> str:
        """Render the structured recommendation as plain text.

        Args:
            medical_advice: Main generated medical advice.
            departments: Recommended departments.
            self_care: Self-care suggestions.
            risk_level: Risk level ("Low", "Medium", "High").

        Returns:
            Three parts: the advice, a "RECOMMENDED DEPARTMENTS" paragraph and
            a "SELF-CARE SUGGESTIONS" bullet list. Each bullet ends with a newline.
        """
        if risk_level == 'High':
            urgency = 'immediate attention'
        elif risk_level == 'Medium':
            urgency = 'medical care soon'
        else:
            urgency = 'monitoring'
        parts = [
            medical_advice.strip() + "\n\n",
            f"RECOMMENDED DEPARTMENTS: Based on your symptoms, consider consulting the following departments: {', '.join(departments)}.\n\n",
            f"SELF-CARE SUGGESTIONS: While {risk_level.lower()} risk level requires {urgency}, you can also:\n",
        ]
        parts.extend(f"- {suggestion}\n" for suggestion in self_care)
        return "".join(parts)

    def generate_recommendation(self,
                                symptoms: str,
                                risk_level: str,
                                max_length: int = 150) -> Dict[str, Any]:
        """
        Generate a comprehensive medical recommendation based on symptoms and risk level.

        Args:
            symptoms: Symptom description text
            risk_level: Risk level (Low, Medium, High)
            max_length: Maximum length for generated text

        Returns:
            Dictionary with ``text`` (formatted recommendation) and
            ``structured`` (``medical_advice``, ``departments``, ``self_care``).
            For "High" risk, "Emergency Medicine" is placed first in
            ``departments``.
        """
        # Build the model prompt.
        input_text = f"Symptoms: {symptoms} Risk: {risk_level}"

        # Generate the main medical advice with the seq2seq model.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        medical_advice = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Departments suggested by the symptom keywords.
        departments = self._extract_departments_from_symptoms(symptoms)
        # High risk: put Emergency Medicine first.
        if risk_level == "High" and "Emergency Medicine" not in departments:
            departments.insert(0, "Emergency Medicine")

        self_care_suggestions = self._get_self_care_suggestions(risk_level)

        structured_recommendation = {
            "medical_advice": medical_advice,
            "departments": departments,
            "self_care": self_care_suggestions
        }
        formatted_text = self._format_structured_recommendation(
            medical_advice,
            departments,
            self_care_suggestions,
            risk_level
        )
        return {
            "text": formatted_text,
            "structured": structured_recommendation
        }
class MedicalConsultationPipeline:
    """End-to-end consultation: symptom extraction, risk triage, then recommendation.

    Each stage is its own model wrapper. All three share one device.
    """

    def __init__(self,
                 symptom_model="dmis-lab/biobert-v1.1",
                 risk_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 recommendation_model="t5-small",
                 device=None):
        """Build the three stage models.

        Args:
            symptom_model: Model id/path for ``SymptomExtractor``.
            risk_model: Model id/path for ``RiskClassifier``.
            recommendation_model: Model id/path for ``RecommendationGenerator``.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing Medical Consultation Pipeline on {self.device}...")
        shared_device = self.device
        self.symptom_extractor = SymptomExtractor(model_name=symptom_model, device=shared_device)
        self.risk_classifier = RiskClassifier(model_name=risk_model, device=shared_device)
        self.recommendation_generator = RecommendationGenerator(model_path=recommendation_model, device=shared_device)
        print("Medical Consultation Pipeline initialized successfully.")

    def process(self, text: str) -> Dict[str, Any]:
        """Run a patient description through all three stages.

        Args:
            text: Free-text patient description.

        Returns:
            Dict with keys:
            - ``extraction``: symptom/duration results.
            - ``risk``: risk classification.
            - ``recommendation``: formatted recommendation text.
            - ``structured_recommendation``: structured recommendation parts.
            - ``input_text``: the original ``text``.
        """
        extraction = self.symptom_extractor.extract_symptoms(text)
        risk = self.risk_classifier.classify_risk(text)

        # Comma-separated symptom texts for the generator. If no symptoms were
        # extracted, fall back to the raw description.
        symptom_names = [item["text"] for item in extraction["symptoms"]]
        summary = ", ".join(symptom_names) or text

        generated = self.recommendation_generator.generate_recommendation(
            symptoms=summary,
            risk_level=risk["risk_level"],
        )

        return {
            "extraction": extraction,
            "risk": risk,
            "recommendation": generated["text"],
            "structured_recommendation": generated["structured"],
            "input_text": text,
        }
# Example usage
if __name__ == "__main__":
    # Manual smoke test. It only runs when this file is executed directly
    # (not when it is imported, e.g. by the Streamlit app).
    # The variable is named ``consultation`` rather than ``pipeline`` so the
    # module-level ``transformers.pipeline`` import is not rebound.
    consultation = MedicalConsultationPipeline()
    sample_text = "I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
    result = consultation.process(sample_text)
    print("Extracted symptoms:", [s["text"] for s in result["extraction"]["symptoms"]])
    print("Duration info:", [d["text"] for d in result["extraction"]["duration"]])
    print("Risk level:", result["risk"]["risk_level"], f"(Confidence: {result['risk']['confidence']:.2f})")
    print("Recommendation:", result["recommendation"])