Spaces:
Sleeping
Sleeping
import html
import os
import re
import time

import pandas as pd
import streamlit as st
import torch

from models import MedicalConsultationPipeline
from utils import (
    highlight_text_with_entities,
    format_duration,
    create_risk_gauge,
    create_risk_probability_chart,
    save_consultation,
    load_consultation_history,
    init_session_state,
    RISK_COLORS,
)
# Page-level settings (browser tab title/icon, wide layout, sidebar open by default).
# Kept ahead of every other st.* call in the script.
_PAGE_CONFIG = {
    "page_title": "AI Medical Consultation",
    "page_icon": "🩺",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS
def load_css(path="style.css"):
    """Inject the app's custom stylesheet into the page.

    The file is read as UTF-8 and emitted inside a ``<style>`` tag through
    ``st.markdown``. A missing stylesheet only degrades the look of the app,
    so it produces a warning instead of crashing the whole script.

    Args:
        path: Location of the CSS file, relative to the working directory.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        st.warning(f"Custom stylesheet '{path}' not found; using default styling.")
        return
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
# Check whether a fine-tuned T5 model exists locally.
def find_fine_tuned_model():
    """Return the first local fine-tuned T5 checkpoint directory, else the base model name.

    Candidate directories are checked in priority order, relative to the
    current working directory. Only directories qualify: a plain *file* that
    happens to share one of these names (e.g. ``output``) is not a model
    checkpoint and would make pipeline loading fail later.

    Returns:
        The first matching relative path (e.g. ``"./finetuned_t5-small"``),
        or ``"t5-small"`` when none of the candidates is a directory.
    """
    possible_local_paths = [
        "./finetuned_t5-small",  # user-provided fine-tuned model path
        "./t5-small-medical-recommendation",
        "./models/t5-small-medical-recommendation",
        "./fine_tuned_models/t5-small",
        "./output",
        "./fine_tuning_output",
    ]
    # Fall back to the base model when no fine-tuned checkpoint is present.
    return next((path for path in possible_local_paths if os.path.isdir(path)), "t5-small")
# One-time page setup on every run: seed the session-state defaults first,
# then inject the custom stylesheet.
init_session_state()
load_css()
# Sidebar: model/device settings and the consultation-history browser.
with st.sidebar:
    st.image("https://img.icons8.com/fluency/96/000000/hospital-3.png", width=80)
    st.title("AI Medical Assistant")
    st.markdown("---")

    with st.expander("⚙️ Settings", expanded=False):
        # Model settings. Every settings widget is disabled once the models are
        # loaded, because the pipeline is built only once per session.
        st.subheader("Model Settings")
        symptom_model = st.selectbox(
            "Symptom Extraction Model",
            ["dmis-lab/biobert-v1.1"],
            index=0,
            disabled=st.session_state.loaded_models,
        )
        risk_model = st.selectbox(
            "Risk Classification Model",
            ["microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"],
            index=0,
            disabled=st.session_state.loaded_models,
        )

        # Look for a local fine-tuned T5 checkpoint; the base model is always offered.
        available_t5_model = find_fine_tuned_model()
        recommendation_model_options = ["t5-small (base model)"]
        if available_t5_model != "t5-small":
            # List the fine-tuned checkpoint first so it becomes the default choice.
            recommendation_model_options.insert(0, f"{available_t5_model} (fine-tuned)")
        recommendation_model_label = st.selectbox(
            "Recommendation Model",
            recommendation_model_options,
            index=0,
            disabled=st.session_state.loaded_models,
        )
        # Map the display label back to the model path/name handed to the pipeline.
        if "(fine-tuned)" in recommendation_model_label:
            recommendation_model = available_t5_model
        else:
            recommendation_model = "t5-small"

        # Device selection: the GPU choice only takes effect when CUDA is available.
        device = st.radio(
            "Compute Device",
            ["CPU", "GPU (if available)"],
            index=1 if torch.cuda.is_available() else 0,
            disabled=st.session_state.loaded_models,
        )
        device = "cuda" if device == "GPU (if available)" and torch.cuda.is_available() else "cpu"

        if st.session_state.loaded_models:
            st.info("注意:设置已锁定,因为模型已加载。要更改设置,请刷新页面。")

    # Consultation history section
    st.markdown("---")
    st.subheader("📋 Consultation History")

    if st.button("Refresh History"):
        st.session_state.consultation_history = load_consultation_history()
        st.success("History refreshed!")

    # Load the history lazily whenever it is still empty.
    if not st.session_state.consultation_history:
        st.session_state.consultation_history = load_consultation_history()

    if not st.session_state.consultation_history:
        st.info("No previous consultations found.")
    else:
        # Show only the first 10 entries (treated as the 10 most recent;
        # presumably load_consultation_history returns newest first - confirm).
        for i, consultation in enumerate(st.session_state.consultation_history[:10]):
            # A missing or malformed timestamp coerces to NaT, which cannot be
            # strftime-formatted, so show a placeholder instead of crashing.
            parsed_time = pd.to_datetime(consultation.get("timestamp", ""), errors="coerce")
            if pd.isna(parsed_time):
                timestamp = "Unknown"
            else:
                timestamp = parsed_time.strftime("%Y-%m-%d %H:%M")
            risk_level = (consultation.get("risk") or {}).get("risk_level", "Unknown")
            risk_color = RISK_COLORS.get(risk_level, "#6c757d")

            # Stored consultations hold raw user text: escape it before it is
            # rendered as HTML.
            input_text = str(consultation.get("input_text") or "")
            preview = input_text[:50] + ("..." if len(input_text) > 50 else "")
            history_item = f"""
<div class='history-item'>
<strong>Patient Input:</strong> {html.escape(preview)}<br>
<strong>Time:</strong> {timestamp}<br>
<strong>Risk Level:</strong> <span style='color:{risk_color};'>{html.escape(str(risk_level))}</span>
</div>
"""
            st.markdown(history_item, unsafe_allow_html=True)

            # st.markdown returns an element handle, not a click signal, so
            # selecting an entry needs a real, uniquely keyed button.
            if st.button("View this consultation", key=f"history_view_{i}"):
                st.session_state.current_result = consultation
# Main page header, introduction row and the current model-configuration card.
st.markdown("<h1 class='main-header'>AI-Powered Medical Consultation</h1>", unsafe_allow_html=True)

how_it_works_html = """
<div class="card">
<h2 class="card-header">How it Works</h2>
<p>This AI-powered medical consultation system helps you understand your symptoms and provides guidance on next steps.</p>
<p><strong>Simply describe your symptoms</strong> in natural language and the system will:</p>
<ol>
<li>Extract key symptoms and duration information</li>
<li>Assess your risk level</li>
<li>Generate personalized medical recommendations</li>
</ol>
<p><em>Note: This system is for informational purposes only and does not replace professional medical advice.</em></p>
</div>
"""
example_inputs_html = """
<div class="card">
<h2 class="card-header">Example Inputs</h2>
<ul>
<li>"I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."</li>
<li>"My child has had a high fever of 39°C since yesterday and is coughing a lot."</li>
<li>"I've noticed a persistent rash on my arm for the past 3 days, it's itchy and slightly swollen."</li>
</ul>
</div>
"""

# Introduction row: explanation on the left (2/3 width), examples on the right.
intro_left, intro_right = st.columns([2, 1])
with intro_left:
    st.markdown(how_it_works_html, unsafe_allow_html=True)
with intro_right:
    st.markdown(example_inputs_html, unsafe_allow_html=True)

# Show which models and device this session is configured to use;
# the recommendation model is tagged "(base model)" or "(fine-tuned model)".
recommendation_kind = "(基础模型)" if recommendation_model == "t5-small" else "(微调模型)"
model_config_html = f"""
<div class="card">
<h2 class="card-header">当前模型配置</h2>
<ul>
<li><strong>症状抽取模型:</strong> {symptom_model}</li>
<li><strong>风险分类模型:</strong> {risk_model}</li>
<li><strong>推荐生成模型:</strong> {recommendation_model} {recommendation_kind}</li>
<li><strong>计算设备:</strong> {device.upper()}</li>
</ul>
</div>
"""
st.markdown(model_config_html, unsafe_allow_html=True)
# Factory for the consultation pipeline (no caching happens here).
def load_pipeline(_symptom_model, _risk_model, _recommendation_model, _device):
    """Construct a ``MedicalConsultationPipeline`` for the chosen models and device.

    Args:
        _symptom_model: Model name for symptom extraction.
        _risk_model: Model name for risk classification.
        _recommendation_model: Model name or local checkpoint path for
            recommendation generation.
        _device: Device string chosen in the sidebar (``"cpu"`` or ``"cuda"``).

    Returns:
        A newly built pipeline instance.

    NOTE(review): the leading underscores look intended for
    ``st.cache_resource`` (which skips hashing underscore-prefixed arguments),
    but no caching decorator is applied, so every call builds a new pipeline.
    """
    pipeline_kwargs = dict(
        symptom_model=_symptom_model,
        risk_model=_risk_model,
        recommendation_model=_recommendation_model,
        device=_device,
    )
    return MedicalConsultationPipeline(**pipeline_kwargs)
# Load the models once per browser session; later reruns reuse the stored pipeline.
if not st.session_state.loaded_models:
    try:
        with st.spinner("Loading AI models... This may take a minute..."):
            pipeline = load_pipeline(symptom_model, risk_model, recommendation_model, device)
    except Exception as e:
        # Keep `pipeline` bound so later code sees "no pipeline" instead of
        # failing with a NameError; loading is retried on the next rerun.
        pipeline = None
        st.error(f"Error loading models: {str(e)}")
    else:
        st.session_state.pipeline = pipeline
        st.session_state.loaded_models = True
        st.success("✅ Models loaded successfully!")
else:
    pipeline = st.session_state.pipeline
# Input section
st.markdown("<h2 class='subheader'>Describe Your Symptoms</h2>", unsafe_allow_html=True)

# Free-text patient description
patient_input = st.text_area(
    "Please describe your symptoms, including when they started and any other relevant information:",
    height=150,
    placeholder="Example: I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
)

# Centre the process button in the middle of three equal columns.
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
    process_button = st.button("Analyze Symptoms", type="primary", use_container_width=True)

# Handle processing
if process_button and not st.session_state.is_processing:
    # The pipeline is stored in session state only after a successful load;
    # `.get` yields None when none was ever stored.
    active_pipeline = st.session_state.get("pipeline")
    if not patient_input.strip():
        st.warning("Please describe your symptoms before clicking Analyze Symptoms.")
    elif active_pipeline is None:
        st.error(
            "The AI models are not loaded, so your symptoms cannot be analyzed yet. "
            "Please refresh the page to retry loading them."
        )
    else:
        st.session_state.is_processing = True
        try:
            with st.spinner("Analyzing your symptoms..."):
                start_time = time.perf_counter()
                result = active_pipeline.process(patient_input)
                elapsed_time = time.perf_counter() - start_time
        except Exception as e:
            st.error(f"Error processing your input: {str(e)}")
        else:
            st.session_state.current_result = result
            # The analysis itself succeeded; a history-persistence failure
            # should not be reported as an analysis failure.
            try:
                save_consultation(result)
            except Exception as e:
                st.warning(f"Analysis succeeded, but it could not be saved to history: {str(e)}")
            st.success(f"Analysis completed in {elapsed_time:.2f} seconds!")
        finally:
            # Always clear the flag - even when something not caught above
            # interrupts the run - or the button would silently do nothing
            # for the rest of the session.
            st.session_state.is_processing = False
def _card_open(title):
    """Return the opening markup of a titled ``card`` box with an ``<h3>`` header.

    NOTE(review): each ``st.markdown`` call appears to render as its own
    element, so this ``<div>`` is likely closed by the renderer rather than
    wrapping the widgets emitted after it - confirm in the browser.
    """
    return f'\n<div class="card">\n<h3 class="card-header">{title}</h3>\n'


def _highlight_spans_html(text, spans, css_class):
    """Return *text* HTML-escaped, with each character span wrapped in a styled ``<span>``.

    Args:
        text: Raw text that the span offsets refer to.
        spans: Iterable of dicts with ``"start"``/``"end"`` offsets into
            *text* (end exclusive, as used for slicing).
        css_class: CSS class applied to every highlight ``<span>``.

    Returns:
        An HTML string. Offsets are applied to the raw text left to right and
        only the resulting pieces are escaped, so escaping never shifts an
        offset. A span that overlaps an earlier one, or is empty, is skipped
        instead of splicing markup into markup.
    """
    pieces = []
    cursor = 0
    for span in sorted(spans, key=lambda s: s["start"]):
        start, end = span["start"], span["end"]
        if start < cursor or end <= start:
            continue
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(f"<span class='{css_class}'>{html.escape(text[start:end])}</span>")
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def _format_recommendation_html(text, sentences_per_paragraph=2):
    """Return *text* as escaped HTML, grouped into short ``<br><br>``-separated paragraphs.

    A sentence ends at ``.``, ``!`` or ``?`` followed by whitespace. The
    punctuation stays attached to its sentence, so nothing is dropped or doubled.
    """
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    paragraphs = [
        " ".join(html.escape(s) for s in sentences[i:i + sentences_per_paragraph])
        for i in range(0, len(sentences), sentences_per_paragraph)
    ]
    return "<br><br>".join(paragraphs)


# Plain-language explanation of each risk level (Risk Assessment tab).
_RISK_DESCRIPTIONS = {
    "Low": """
<div style="border-left: 3px solid #7FD8BE; padding-left: 10px; margin: 10px 0;">
<strong style="color: #7FD8BE;">Low Risk</strong>: Your symptoms suggest a condition that is likely non-urgent.
While it's good to stay vigilant, these types of conditions typically don't require immediate medical attention
and can often be managed with self-care or a routine appointment within the next few days or weeks.
</div>
""",
    "Medium": """
<div style="border-left: 3px solid #FFC857; padding-left: 10px; margin: 10px 0;">
<strong style="color: #FFC857;">Medium Risk</strong>: Your symptoms indicate a condition that may need medical attention
soon, but may not be an emergency. Consider scheduling an appointment with your primary care provider within 24-48 hours,
or visit an urgent care facility if your symptoms worsen or if you cannot schedule a timely appointment.
</div>
""",
    "High": """
<div style="border-left: 3px solid #E84855; padding-left: 10px; margin: 10px 0;">
<strong style="color: #E84855;">High Risk</strong>: Your symptoms suggest a potentially serious condition that requires
prompt medical attention. Consider seeking emergency care or calling emergency services if symptoms are severe or rapidly
worsening, especially if they include difficulty breathing, severe pain, or altered consciousness.
</div>
""",
}

# Generic self-care advice, used only when the model produced no self-care tips.
# Any level other than Low/Medium falls back to the high-risk advice.
_DEFAULT_SELF_CARE_HTML = {
    "Low": """
<p>While waiting for your symptoms to improve:</p>
<ul>
<li>Ensure you're getting adequate rest</li>
<li>Stay hydrated by drinking plenty of water</li>
<li>Monitor your symptoms and note any changes</li>
<li>Consider over-the-counter medications appropriate for your symptoms</li>
<li>Maintain a balanced diet to support your immune system</li>
</ul>
""",
    "Medium": """
<p>While arranging medical care:</p>
<ul>
<li>Rest and avoid strenuous activities</li>
<li>Stay hydrated and maintain proper nutrition</li>
<li>Take your temperature and other vital signs if possible</li>
<li>Write down any changes in symptoms and when they occur</li>
<li>Have someone stay with you if your symptoms are concerning</li>
</ul>
""",
}
_HIGH_RISK_SELF_CARE_HTML = """
<p>While seeking emergency care:</p>
<ul>
<li>Don't wait - seek medical attention immediately</li>
<li>Have someone drive you to the emergency room if safe to do so</li>
<li>Call emergency services if symptoms are severe</li>
<li>Bring a list of your current medications if possible</li>
<li>Follow any first aid protocols appropriate for your symptoms</li>
</ul>
"""

# Results section - shown whenever there is a current result
# (a fresh analysis or an entry picked from the history).
if st.session_state.current_result:
    result = st.session_state.current_result

    # Pull every field once; `or` guards against keys stored with None values.
    input_text = result.get("input_text", "") or ""
    extraction = result.get("extraction", {}) or {}
    symptoms = extraction.get("symptoms", []) or []
    duration = extraction.get("duration", []) or []
    risk_data = result.get("risk", {}) or {}
    risk_level = risk_data.get("risk_level", "Unknown")
    # A None confidence would crash the ":.1%" formatting below.
    confidence = risk_data.get("confidence") or 0.0
    recommendation = str(result.get("recommendation") or "No recommendations available.")
    structured = result.get("structured_recommendation", {}) or {}

    st.markdown("<h2 class='subheader'>Consultation Results</h2>", unsafe_allow_html=True)

    # One tab per section of the results.
    tabs = st.tabs(["Overview", "Symptoms Analysis", "Risk Assessment", "Recommendations"])

    # Overview tab - summary of all results
    with tabs[0]:
        col1, col2 = st.columns([3, 2])

        with col1:
            st.markdown(_card_open("Patient Description"), unsafe_allow_html=True)
            # Highlight the extracted symptoms in the patient's own text.
            # NOTE(review): this helper's output is rendered as raw HTML;
            # confirm that highlight_text_with_entities escapes the user text.
            highlighted_text = highlight_text_with_entities(input_text, symptoms)
            st.markdown(f"<p>{highlighted_text}</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

            # Recommendations card
            st.markdown(
                _card_open("Medical Recommendations") + '<div class="recommendation-container">\n',
                unsafe_allow_html=True,
            )
            st.markdown(f"<p>{html.escape(recommendation)}</p>", unsafe_allow_html=True)
            st.markdown("""
</div>
<p><em>Note: This is AI-generated guidance and should not replace professional medical advice.</em></p>
</div>
""", unsafe_allow_html=True)

        with col2:
            # Risk level card
            risk_css = html.escape(str(risk_level).lower())
            risk_label = html.escape(str(risk_level))
            st.markdown(f"""
<div class="card">
<h3 class="card-header">Risk Assessment</h3>
<div style="text-align: center;">
<span class="risk-{risk_css}" style="font-size: 1.8rem;">{risk_label}</span>
<p>Confidence: {confidence:.1%}</p>
</div>
""", unsafe_allow_html=True)
            risk_gauge = create_risk_gauge(risk_level, confidence)
            st.plotly_chart(risk_gauge, use_container_width=True, key="overview_risk_gauge")
            st.markdown("</div>", unsafe_allow_html=True)

            # Extracted symptoms summary
            st.markdown(_card_open("Key Findings"), unsafe_allow_html=True)
            if symptoms:
                st.markdown("<strong>Identified Symptoms:</strong>", unsafe_allow_html=True)
                for symptom in symptoms:
                    st.markdown(
                        f"• {html.escape(str(symptom['text']))} ({symptom['score']:.1%} confidence)",
                        unsafe_allow_html=True,
                    )
            else:
                st.info("No specific symptoms identified")
            st.markdown("<br><strong>Duration Information:</strong>", unsafe_allow_html=True)
            st.markdown(f"<p>{format_duration(duration)}</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

    # Symptoms Analysis tab
    with tabs[1]:
        st.markdown(_card_open("Detailed Symptom Analysis"), unsafe_allow_html=True)
        if symptoms:
            # Table of symptoms, most confident first.
            symptom_df = pd.DataFrame([
                {
                    "Symptom": s["text"],
                    "Confidence": s["score"],
                    "Start Position": s["start"],
                    "End Position": s["end"],
                }
                for s in symptoms
            ]).sort_values("Confidence", ascending=False)
            st.dataframe(symptom_df, use_container_width=True)

            # A bar chart only adds information when there is more than one symptom.
            if len(symptoms) > 1:
                st.markdown("<h4>Symptom Confidence Scores</h4>", unsafe_allow_html=True)
                chart_data = symptom_df[["Symptom", "Confidence"]].set_index("Symptom")
                st.bar_chart(chart_data, use_container_width=True)
        else:
            st.info("No specific symptoms were detected in the input text.")
        st.markdown("</div>", unsafe_allow_html=True)

        # Duration information card
        st.markdown(_card_open("Duration Analysis"), unsafe_allow_html=True)
        if duration:
            duration_df = pd.DataFrame([
                {"Duration": d["text"], "Start Position": d["start"], "End Position": d["end"]}
                for d in duration
            ])
            st.dataframe(duration_df, use_container_width=True)

            st.markdown("<h4>Original Text with Duration Highlighted</h4>", unsafe_allow_html=True)
            duration_html = _highlight_spans_html(input_text, duration, "duration-highlight")
            st.markdown(f"<p>{duration_html}</p>", unsafe_allow_html=True)
        else:
            st.info("No specific duration information was detected in the input text.")
        st.markdown("</div>", unsafe_allow_html=True)

    # Risk Assessment tab
    with tabs[2]:
        st.markdown(_card_open("Risk Level Assessment"), unsafe_allow_html=True)
        probabilities = risk_data.get("all_probabilities", {})

        col1, col2 = st.columns(2)
        with col1:
            risk_gauge = create_risk_gauge(risk_level, confidence)
            st.plotly_chart(risk_gauge, use_container_width=True, key="risk_assessment_gauge")
        with col2:
            prob_chart = create_risk_probability_chart(probabilities)
            st.plotly_chart(prob_chart, use_container_width=True, key="risk_probability_chart")

        # Describe the assessed level first, then the others in Low/Medium/High order.
        st.markdown("<h4>Risk Levels Explained</h4>", unsafe_allow_html=True)
        if risk_level in _RISK_DESCRIPTIONS:
            st.markdown(_RISK_DESCRIPTIONS[risk_level], unsafe_allow_html=True)
        for level, desc in _RISK_DESCRIPTIONS.items():
            if level != risk_level:
                st.markdown(desc, unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        # Disclaimer
        st.warning("""
**Important Disclaimer**: This risk assessment is based on AI analysis and should be used as a guidance only.
It is not a definitive medical diagnosis. Always consult with a healthcare professional for proper evaluation,
especially if you experience severe symptoms, symptoms that persist or worsen, or if you're unsure about your condition.
""")

    # Recommendations tab
    with tabs[3]:
        st.markdown(_card_open("Detailed Recommendations"), unsafe_allow_html=True)
        # Break the recommendation into short paragraphs for readability.
        formatted_recommendation = _format_recommendation_html(recommendation)
        st.markdown(f"<p>{formatted_recommendation}</p>", unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        # Department suggestions come from the model's structured output,
        # falling back to primary care when none were produced.
        st.markdown(_card_open("Suggested Medical Departments"), unsafe_allow_html=True)
        departments = structured.get("departments", []) or ["General Medicine / Primary Care"]
        for dept in departments:
            st.markdown(f"• **{html.escape(str(dept))}**", unsafe_allow_html=True)
        st.markdown("<br><em>Note: Department suggestions are based on your symptoms and risk level. Consult with a healthcare provider for proper referral.</em>", unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        # Self-care suggestions: model-generated tips when available.
        st.markdown(_card_open("Self-Care Suggestions"), unsafe_allow_html=True)
        self_care_tips = structured.get("self_care", [])
        if self_care_tips:
            # Build the whole list in one markdown call: separate calls render as
            # separate elements, which would leave the <li> items outside the <ul>.
            list_items = "".join(f"<li>{html.escape(str(tip))}</li>" for tip in self_care_tips)
            st.markdown(f"<ul>{list_items}</ul>", unsafe_allow_html=True)
        else:
            # No model-generated tips: show generic advice for the risk level.
            fallback_level = risk_data.get("risk_level", "Medium")
            st.markdown(
                _DEFAULT_SELF_CARE_HTML.get(fallback_level, _HIGH_RISK_SELF_CARE_HTML),
                unsafe_allow_html=True,
            )
        st.markdown("</div>", unsafe_allow_html=True)
# Footer: credit the models this session is actually configured with, instead
# of always claiming a fine-tuned T5 (the base "t5-small" may be the one in use).
_footer_t5_label = "T5-small" if recommendation_model == "t5-small" else "fine-tuned T5-small"
st.markdown(f"""
<div class="footer">
<p>AI Medical Consultation System | Created with Streamlit | Not a substitute for professional medical advice</p>
<p>Powered by: {symptom_model}, {risk_model}, and {_footer_t5_label}</p>
</div>
""", unsafe_allow_html=True)