"""Tests for probability aggregation utilities.""" import pytest from sentinel.models import RiskScore from sentinel.probability_aggregation import ( AggregatedRisk, aggregate_probabilities, categorize_risk, get_display_cancer_type, normalize_cancer_type, separate_score_types, ) class TestAggregateProbabilities: """Test probability aggregation functionality.""" def test_single_model_per_cancer_type(self): """Test aggregation with one model per cancer type.""" scores = [ RiskScore( name="Gail", score="1.5%", cancer_type="breast", probability_percent=1.5, time_horizon_years=5.0, score_type="probability", ), RiskScore( name="PLCOm2012", score="2.3%", cancer_type="lung", probability_percent=2.3, time_horizon_years=6.0, score_type="probability", ), ] aggregated = aggregate_probabilities(scores) assert len(aggregated) == 2 # Check breast cancer aggregation breast_agg = next(agg for agg in aggregated if agg.cancer_type == "breast") assert breast_agg.time_horizon_years == 5.0 assert breast_agg.avg_probability_percent == 1.5 assert breast_agg.risk_category == "Moderate" # 1.5% for 5-year horizon assert breast_agg.model_count == 1 assert len(breast_agg.individual_scores) == 1 assert breast_agg.individual_scores[0].name == "Gail" # Check lung cancer aggregation lung_agg = next(agg for agg in aggregated if agg.cancer_type == "lung") assert lung_agg.time_horizon_years == 6.0 assert lung_agg.avg_probability_percent == 2.3 assert lung_agg.risk_category == "Moderate" # 2.3% for 6-year horizon assert lung_agg.model_count == 1 def test_multiple_models_same_cancer_same_horizon(self): """Test aggregation when multiple models assess same cancer with same time horizon.""" scores = [ RiskScore( name="Model1", score="2.0%", cancer_type="breast", probability_percent=2.0, time_horizon_years=10.0, score_type="probability", ), RiskScore( name="Model2", score="3.0%", cancer_type="breast", probability_percent=3.0, time_horizon_years=10.0, score_type="probability", ), RiskScore( name="Model3", score="4.0%", cancer_type="breast", probability_percent=4.0, time_horizon_years=10.0, score_type="probability", ), ] aggregated = aggregate_probabilities(scores) assert len(aggregated) == 1 agg = aggregated[0] assert agg.cancer_type == "breast" assert agg.time_horizon_years == 10.0 assert agg.avg_probability_percent == pytest.approx(3.0) # (2+3+4)/3 assert agg.risk_category == "Moderate" # 3.0% for 10-year horizon assert agg.model_count == 3 assert len(agg.individual_scores) == 3 def test_multiple_models_same_cancer_different_horizons(self): """Test aggregation with same cancer type but different time horizons.""" scores = [ RiskScore( name="Model1", score="1.5%", cancer_type="breast", probability_percent=1.5, time_horizon_years=5.0, score_type="probability", ), RiskScore( name="Model2", score="3.0%", cancer_type="breast", probability_percent=3.0, time_horizon_years=10.0, score_type="probability", ), RiskScore( name="Model3", score="15.0%", cancer_type="breast", probability_percent=15.0, time_horizon_years=79.0, score_type="probability", ), ] aggregated = aggregate_probabilities(scores) assert len(aggregated) == 3 # Three different time horizons # Verify each time horizon is separate horizons = {agg.time_horizon_years for agg in aggregated} assert horizons == {5.0, 10.0, 79.0} # Verify each has single model for agg in aggregated: assert agg.model_count == 1 def test_excludes_non_probability_scores(self): """Test that non-probability scores are excluded from aggregation.""" scores = [ RiskScore( name="Gail", score="1.5%", cancer_type="breast", probability_percent=1.5, time_horizon_years=5.0, score_type="probability", ), RiskScore( name="PCPT", score="No Cancer: 45%, Low Grade: 30%, High Grade: 25%", cancer_type="prostate", probability_percent=None, time_horizon_years=None, score_type="categorical", ), RiskScore( name="Model", score="N/A: Age out of range", cancer_type="lung", probability_percent=None, time_horizon_years=None, score_type="not_applicable", ), ] aggregated = aggregate_probabilities(scores) assert len(aggregated) == 1 assert aggregated[0].cancer_type == "breast" def test_empty_list(self): """Test aggregation with empty score list.""" aggregated = aggregate_probabilities([]) assert aggregated == [] def test_all_non_probability_scores(self): """Test aggregation when all scores are non-probability.""" scores = [ RiskScore( name="PCPT", score="Results", cancer_type="prostate", score_type="categorical", ), RiskScore( name="Model", score="N/A", cancer_type="lung", score_type="not_applicable", ), ] aggregated = aggregate_probabilities(scores) assert aggregated == [] def test_case_insensitive_cancer_type_grouping(self): """Test that cancer types are grouped case-insensitively.""" scores = [ RiskScore( name="Model1", score="1.5%", cancer_type="Breast", probability_percent=1.5, time_horizon_years=5.0, score_type="probability", ), RiskScore( name="Model2", score="1.8%", cancer_type="breast", probability_percent=1.8, time_horizon_years=5.0, score_type="probability", ), RiskScore( name="Model3", score="1.7%", cancer_type="BREAST", probability_percent=1.7, time_horizon_years=5.0, score_type="probability", ), ] aggregated = aggregate_probabilities(scores) assert len(aggregated) == 1 assert aggregated[0].cancer_type == "breast" # normalized to lowercase assert aggregated[0].model_count == 3 assert aggregated[0].avg_probability_percent == pytest.approx(1.6667, abs=0.001) class TestSeparateScoreTypes: """Test score type separation functionality.""" def test_separate_all_types(self): """Test separation of all three score types.""" scores = [ RiskScore( name="Gail", score="1.5%", score_type="probability", ), RiskScore( name="BOADICEA", score="2.0%", score_type="probability", ), RiskScore( name="PCPT", score="No Cancer: 45%", score_type="categorical", ), RiskScore( name="Model", score="N/A: Age out of range", score_type="not_applicable", ), RiskScore( name="Model2", score="N/A: Invalid", score_type="not_applicable", ), ] separated = separate_score_types(scores) assert len(separated["probability"]) == 2 assert len(separated["categorical"]) == 1 assert len(separated["not_applicable"]) == 2 def test_empty_list(self): """Test separation with empty list.""" separated = separate_score_types([]) assert separated["probability"] == [] assert separated["categorical"] == [] assert separated["not_applicable"] == [] def test_only_probabilities(self): """Test separation when all scores are probabilities.""" scores = [ RiskScore(name="Model1", score="1%", score_type="probability"), RiskScore(name="Model2", score="2%", score_type="probability"), ] separated = separate_score_types(scores) assert len(separated["probability"]) == 2 assert separated["categorical"] == [] assert separated["not_applicable"] == [] class TestFilterFunctions: """Test individual filter functions.""" def test_separate_score_types_for_probability(self): """Test using separate_score_types to get probability scores.""" scores = [ RiskScore(name="Model1", score="1%", score_type="probability"), RiskScore(name="Model2", score="Result", score_type="categorical"), RiskScore(name="Model3", score="2%", score_type="probability"), ] separated = separate_score_types(scores) assert len(separated["probability"]) == 2 assert all( score.score_type == "probability" for score in separated["probability"] ) def test_separate_score_types_for_categorical(self): """Test using separate_score_types to get categorical scores.""" scores = [ RiskScore(name="Model1", score="1%", score_type="probability"), RiskScore(name="Model2", score="Result", score_type="categorical"), RiskScore(name="Model3", score="N/A", score_type="not_applicable"), ] separated = separate_score_types(scores) assert len(separated["categorical"]) == 1 assert separated["categorical"][0].score_type == "categorical" def test_separate_score_types_for_not_applicable(self): """Test using separate_score_types to get not_applicable scores.""" scores = [ RiskScore(name="Model1", score="1%", score_type="probability"), RiskScore( name="Model2", score="N/A: Reason 1", score_type="not_applicable" ), RiskScore( name="Model3", score="N/A: Reason 2", score_type="not_applicable" ), ] separated = separate_score_types(scores) assert len(separated["not_applicable"]) == 2 assert all( score.score_type == "not_applicable" for score in separated["not_applicable"] ) def test_separate_score_types_for_all_types(self): """Test using separate_score_types to get all score types at once.""" scores = [ RiskScore(name="Model1", score="1%", score_type="probability"), RiskScore(name="Model2", score="2%", score_type="probability"), RiskScore(name="Model3", score="Result", score_type="categorical"), RiskScore(name="Model4", score="N/A: Age", score_type="not_applicable"), ] separated = separate_score_types(scores) assert len(separated["probability"]) == 2 assert len(separated["categorical"]) == 1 assert len(separated["not_applicable"]) == 1 class TestAggregatedRiskDataclass: """Test the AggregatedRisk dataclass.""" def test_dataclass_creation(self): """Test creating an AggregatedRisk object.""" score = RiskScore( name="Gail", score="1.5%", cancer_type="breast", probability_percent=1.5, time_horizon_years=5.0, score_type="probability", ) agg = AggregatedRisk( cancer_type="breast", time_horizon_years=5.0, avg_probability_percent=1.5, risk_category="Low", model_count=1, individual_scores=[score], ) assert agg.cancer_type == "breast" assert agg.time_horizon_years == 5.0 assert agg.avg_probability_percent == 1.5 assert agg.risk_category == "Low" assert agg.model_count == 1 assert len(agg.individual_scores) == 1 class TestNormalizeCancerType: """Test cancer type normalization.""" def test_normalize_with_cancer_suffix(self): """Test removing 'cancer' suffix.""" assert normalize_cancer_type("Breast Cancer") == "breast" assert normalize_cancer_type("Lung cancer") == "lung" assert normalize_cancer_type("PROSTATE CANCER") == "prostate" def test_normalize_without_cancer_suffix(self): """Test normalization without 'cancer' suffix.""" assert normalize_cancer_type("Breast") == "breast" assert normalize_cancer_type("LUNG") == "lung" assert normalize_cancer_type("Prostate") == "prostate" def test_normalize_with_whitespace(self): """Test trimming whitespace.""" assert normalize_cancer_type(" Breast Cancer ") == "breast" assert normalize_cancer_type("Lung cancer") == "lung" def test_normalize_empty_string(self): """Test empty string.""" assert normalize_cancer_type("") == "" def test_display_cancer_type(self): """Test display-friendly cancer type names.""" assert get_display_cancer_type("breast") == "Breast" assert get_display_cancer_type("lung") == "Lung" assert get_display_cancer_type("prostate") == "Prostate" class TestCategorizeRisk: """Test risk categorization.""" def test_categorize_short_horizon_very_low(self): """Test very low risk for short time horizon.""" assert categorize_risk(0.3, 5.0) == "Very Low" def test_categorize_short_horizon_low(self): """Test low risk for short time horizon.""" assert categorize_risk(1.0, 5.0) == "Low" def test_categorize_short_horizon_moderate(self): """Test moderate risk for short time horizon.""" assert categorize_risk(2.0, 5.0) == "Moderate" def test_categorize_short_horizon_moderately_high(self): """Test moderately high risk for short time horizon.""" assert categorize_risk(4.0, 5.0) == "Moderately High" def test_categorize_short_horizon_high(self): """Test high risk for short time horizon.""" assert categorize_risk(6.0, 5.0) == "High" def test_categorize_long_horizon_very_low(self): """Test very low risk for long time horizon.""" assert categorize_risk(0.5, 10.0) == "Very Low" def test_categorize_long_horizon_low(self): """Test low risk for long time horizon.""" assert categorize_risk(2.0, 10.0) == "Low" def test_categorize_long_horizon_moderate(self): """Test moderate risk for long time horizon.""" assert categorize_risk(5.0, 10.0) == "Moderate" def test_categorize_long_horizon_moderately_high(self): """Test moderately high risk for long time horizon.""" assert categorize_risk(10.0, 10.0) == "Moderately High" def test_categorize_long_horizon_high(self): """Test high risk for long time horizon.""" assert categorize_risk(20.0, 10.0) == "High" def test_categorize_lifetime_risk(self): """Test categorization for lifetime risk.""" assert categorize_risk(0.5, 79.0) == "Very Low" assert categorize_risk(2.0, 79.0) == "Low" assert categorize_risk(5.0, 79.0) == "Moderate" assert categorize_risk(12.0, 79.0) == "Moderately High" assert categorize_risk(20.0, 79.0) == "High"