sentinel / tests /test_probability_aggregation.py
jeuko's picture
Sync from GitHub (main)
94a0f4c verified
"""Tests for probability aggregation utilities."""
import pytest
from sentinel.models import RiskScore
from sentinel.probability_aggregation import (
AggregatedRisk,
aggregate_probabilities,
categorize_risk,
get_display_cancer_type,
normalize_cancer_type,
separate_score_types,
)
class TestAggregateProbabilities:
    """Tests for ``aggregate_probabilities``.

    The cases below expect one aggregate per (cancer type, time horizon)
    pair, built only from scores whose ``score_type`` is ``"probability"``,
    with cancer types compared case-insensitively.
    """

    @staticmethod
    def _probability_score(name, cancer_type, probability_percent, time_horizon_years):
        """Build a probability-type ``RiskScore``.

        Args:
            name: Model name stored on the score.
            cancer_type: Cancer type label, passed through unchanged.
            probability_percent: Risk as a percentage (e.g. ``1.5``).
            time_horizon_years: Horizon the probability refers to.

        Returns:
            A ``RiskScore`` whose ``score`` text is the percentage followed
            by ``%`` (``1.5`` -> ``"1.5%"``).
        """
        return RiskScore(
            name=name,
            score=f"{probability_percent}%",
            cancer_type=cancer_type,
            probability_percent=probability_percent,
            time_horizon_years=time_horizon_years,
            score_type="probability",
        )

    @staticmethod
    def _only_aggregate_for(aggregated, cancer_type):
        """Return the single aggregate for ``cancer_type``.

        Fails with a readable assertion (instead of ``StopIteration``) when
        the expected group is missing or duplicated.
        """
        matches = [agg for agg in aggregated if agg.cancer_type == cancer_type]
        assert len(matches) == 1, (
            f"expected exactly one aggregate for {cancer_type!r}, "
            f"found {len(matches)}"
        )
        return matches[0]

    def test_single_model_per_cancer_type(self):
        """Test aggregation with one model per cancer type."""
        scores = [
            self._probability_score("Gail", "breast", 1.5, 5.0),
            self._probability_score("PLCOm2012", "lung", 2.3, 6.0),
        ]

        aggregated = aggregate_probabilities(scores)

        assert len(aggregated) == 2

        # Check breast cancer aggregation
        breast_agg = self._only_aggregate_for(aggregated, "breast")
        assert breast_agg.time_horizon_years == 5.0
        # approx: the average is a computed float, as in the other tests here
        assert breast_agg.avg_probability_percent == pytest.approx(1.5)
        assert breast_agg.risk_category == "Moderate"  # 1.5% for 5-year horizon
        assert breast_agg.model_count == 1
        assert len(breast_agg.individual_scores) == 1
        assert breast_agg.individual_scores[0].name == "Gail"

        # Check lung cancer aggregation
        lung_agg = self._only_aggregate_for(aggregated, "lung")
        assert lung_agg.time_horizon_years == 6.0
        assert lung_agg.avg_probability_percent == pytest.approx(2.3)
        assert lung_agg.risk_category == "Moderate"  # 2.3% for 6-year horizon
        assert lung_agg.model_count == 1

    def test_multiple_models_same_cancer_same_horizon(self):
        """Test aggregation when multiple models assess same cancer with same time horizon."""
        scores = [
            self._probability_score("Model1", "breast", 2.0, 10.0),
            self._probability_score("Model2", "breast", 3.0, 10.0),
            self._probability_score("Model3", "breast", 4.0, 10.0),
        ]

        aggregated = aggregate_probabilities(scores)

        assert len(aggregated) == 1
        agg = aggregated[0]
        assert agg.cancer_type == "breast"
        assert agg.time_horizon_years == 10.0
        assert agg.avg_probability_percent == pytest.approx(3.0)  # (2+3+4)/3
        assert agg.risk_category == "Moderate"  # 3.0% for 10-year horizon
        assert agg.model_count == 3
        assert len(agg.individual_scores) == 3

    def test_multiple_models_same_cancer_different_horizons(self):
        """Test aggregation with same cancer type but different time horizons."""
        scores = [
            self._probability_score("Model1", "breast", 1.5, 5.0),
            self._probability_score("Model2", "breast", 3.0, 10.0),
            self._probability_score("Model3", "breast", 15.0, 79.0),
        ]

        aggregated = aggregate_probabilities(scores)

        assert len(aggregated) == 3  # Three different time horizons

        # Verify each time horizon is separate
        horizons = {agg.time_horizon_years for agg in aggregated}
        assert horizons == {5.0, 10.0, 79.0}

        # Verify each has single model
        for agg in aggregated:
            assert agg.model_count == 1

    def test_excludes_non_probability_scores(self):
        """Test that non-probability scores are excluded from aggregation."""
        scores = [
            self._probability_score("Gail", "breast", 1.5, 5.0),
            RiskScore(
                name="PCPT",
                score="No Cancer: 45%, Low Grade: 30%, High Grade: 25%",
                cancer_type="prostate",
                probability_percent=None,
                time_horizon_years=None,
                score_type="categorical",
            ),
            RiskScore(
                name="Model",
                score="N/A: Age out of range",
                cancer_type="lung",
                probability_percent=None,
                time_horizon_years=None,
                score_type="not_applicable",
            ),
        ]

        aggregated = aggregate_probabilities(scores)

        assert len(aggregated) == 1
        assert aggregated[0].cancer_type == "breast"

    def test_empty_list(self):
        """Test aggregation with empty score list."""
        aggregated = aggregate_probabilities([])
        assert aggregated == []

    def test_all_non_probability_scores(self):
        """Test aggregation when all scores are non-probability."""
        scores = [
            RiskScore(
                name="PCPT",
                score="Results",
                cancer_type="prostate",
                score_type="categorical",
            ),
            RiskScore(
                name="Model",
                score="N/A",
                cancer_type="lung",
                score_type="not_applicable",
            ),
        ]

        aggregated = aggregate_probabilities(scores)

        assert aggregated == []

    def test_case_insensitive_cancer_type_grouping(self):
        """Test that cancer types are grouped case-insensitively."""
        scores = [
            self._probability_score("Model1", "Breast", 1.5, 5.0),
            self._probability_score("Model2", "breast", 1.8, 5.0),
            self._probability_score("Model3", "BREAST", 1.7, 5.0),
        ]

        aggregated = aggregate_probabilities(scores)

        assert len(aggregated) == 1
        assert aggregated[0].cancer_type == "breast"  # normalized to lowercase
        assert aggregated[0].model_count == 3
        assert aggregated[0].avg_probability_percent == pytest.approx(1.6667, abs=0.001)
class TestSeparateScoreTypes:
    """Tests for ``separate_score_types`` bucket sizes.

    The result is indexed by the keys ``"probability"``, ``"categorical"``
    and ``"not_applicable"``.
    """

    def test_separate_all_types(self):
        """A mixed input lands in all three buckets with the right counts."""
        fixtures = (
            ("Gail", "1.5%", "probability"),
            ("BOADICEA", "2.0%", "probability"),
            ("PCPT", "No Cancer: 45%", "categorical"),
            ("Model", "N/A: Age out of range", "not_applicable"),
            ("Model2", "N/A: Invalid", "not_applicable"),
        )
        scores = [
            RiskScore(name=model, score=text, score_type=kind)
            for model, text, kind in fixtures
        ]

        buckets = separate_score_types(scores)

        assert len(buckets["probability"]) == 2
        assert len(buckets["categorical"]) == 1
        assert len(buckets["not_applicable"]) == 2

    def test_empty_list(self):
        """An empty input produces three empty buckets."""
        buckets = separate_score_types([])

        for key in ("probability", "categorical", "not_applicable"):
            assert buckets[key] == []

    def test_only_probabilities(self):
        """Probability-only input leaves the other two buckets empty."""
        first = RiskScore(name="Model1", score="1%", score_type="probability")
        second = RiskScore(name="Model2", score="2%", score_type="probability")

        buckets = separate_score_types([first, second])

        assert len(buckets["probability"]) == 2
        for key in ("categorical", "not_applicable"):
            assert buckets[key] == []
class TestFilterFunctions:
    """Per-bucket checks, all performed through ``separate_score_types``."""

    def test_separate_score_types_for_probability(self):
        """The probability bucket holds only probability scores."""
        fixtures = (
            ("Model1", "1%", "probability"),
            ("Model2", "Result", "categorical"),
            ("Model3", "2%", "probability"),
        )
        scores = [
            RiskScore(name=model, score=text, score_type=kind)
            for model, text, kind in fixtures
        ]

        probability_bucket = separate_score_types(scores)["probability"]

        assert len(probability_bucket) == 2
        for entry in probability_bucket:
            assert entry.score_type == "probability"

    def test_separate_score_types_for_categorical(self):
        """The categorical bucket holds the single categorical score."""
        fixtures = (
            ("Model1", "1%", "probability"),
            ("Model2", "Result", "categorical"),
            ("Model3", "N/A", "not_applicable"),
        )
        scores = [
            RiskScore(name=model, score=text, score_type=kind)
            for model, text, kind in fixtures
        ]

        categorical_bucket = separate_score_types(scores)["categorical"]

        assert len(categorical_bucket) == 1
        assert categorical_bucket[0].score_type == "categorical"

    def test_separate_score_types_for_not_applicable(self):
        """The not_applicable bucket holds only not_applicable scores."""
        fixtures = (
            ("Model1", "1%", "probability"),
            ("Model2", "N/A: Reason 1", "not_applicable"),
            ("Model3", "N/A: Reason 2", "not_applicable"),
        )
        scores = [
            RiskScore(name=model, score=text, score_type=kind)
            for model, text, kind in fixtures
        ]

        not_applicable_bucket = separate_score_types(scores)["not_applicable"]

        assert len(not_applicable_bucket) == 2
        for entry in not_applicable_bucket:
            assert entry.score_type == "not_applicable"

    def test_separate_score_types_for_all_types(self):
        """One call yields correctly sized buckets for every score type."""
        fixtures = (
            ("Model1", "1%", "probability"),
            ("Model2", "2%", "probability"),
            ("Model3", "Result", "categorical"),
            ("Model4", "N/A: Age", "not_applicable"),
        )
        scores = [
            RiskScore(name=model, score=text, score_type=kind)
            for model, text, kind in fixtures
        ]

        buckets = separate_score_types(scores)

        assert len(buckets["probability"]) == 2
        assert len(buckets["categorical"]) == 1
        assert len(buckets["not_applicable"]) == 1
class TestAggregatedRiskDataclass:
    """Construction check for the ``AggregatedRisk`` container."""

    def test_dataclass_creation(self):
        """Every constructor argument is readable back as an attribute."""
        gail_score = RiskScore(
            name="Gail",
            score="1.5%",
            cancer_type="breast",
            probability_percent=1.5,
            time_horizon_years=5.0,
            score_type="probability",
        )
        # Scalar fields are declared once, then used for both construction
        # and verification.
        scalar_fields = {
            "cancer_type": "breast",
            "time_horizon_years": 5.0,
            "avg_probability_percent": 1.5,
            "risk_category": "Low",
            "model_count": 1,
        }

        agg = AggregatedRisk(individual_scores=[gail_score], **scalar_fields)

        for field_name, expected in scalar_fields.items():
            assert getattr(agg, field_name) == expected
        assert len(agg.individual_scores) == 1
class TestNormalizeCancerType:
    """Tests for ``normalize_cancer_type`` and ``get_display_cancer_type``."""

    def test_normalize_with_cancer_suffix(self):
        """A trailing 'cancer' word is dropped and the result is lowercased."""
        cases = (
            ("Breast Cancer", "breast"),
            ("Lung cancer", "lung"),
            ("PROSTATE CANCER", "prostate"),
        )
        for raw, expected in cases:
            assert normalize_cancer_type(raw) == expected

    def test_normalize_without_cancer_suffix(self):
        """Inputs without the suffix are simply lowercased."""
        cases = (
            ("Breast", "breast"),
            ("LUNG", "lung"),
            ("Prostate", "prostate"),
        )
        for raw, expected in cases:
            assert normalize_cancer_type(raw) == expected

    def test_normalize_with_whitespace(self):
        """Surrounding whitespace does not affect the normalized value."""
        # NOTE(review): the second input carries no surrounding whitespace;
        # confirm the intended literal against the upstream file.
        cases = (
            (" Breast Cancer ", "breast"),
            ("Lung cancer", "lung"),
        )
        for raw, expected in cases:
            assert normalize_cancer_type(raw) == expected

    def test_normalize_empty_string(self):
        """An empty string normalizes to an empty string."""
        normalized = normalize_cancer_type("")
        assert normalized == ""

    def test_display_cancer_type(self):
        """Lowercase cancer types are shown capitalized."""
        cases = (
            ("breast", "Breast"),
            ("lung", "Lung"),
            ("prostate", "Prostate"),
        )
        for normalized, display in cases:
            assert get_display_cancer_type(normalized) == display
class TestCategorizeRisk:
    """Tests for ``categorize_risk(probability_percent, horizon_years)`` labels."""

    def test_categorize_short_horizon_very_low(self):
        """0.3% over 5 years is 'Very Low'."""
        label = categorize_risk(0.3, 5.0)
        assert label == "Very Low"

    def test_categorize_short_horizon_low(self):
        """1.0% over 5 years is 'Low'."""
        label = categorize_risk(1.0, 5.0)
        assert label == "Low"

    def test_categorize_short_horizon_moderate(self):
        """2.0% over 5 years is 'Moderate'."""
        label = categorize_risk(2.0, 5.0)
        assert label == "Moderate"

    def test_categorize_short_horizon_moderately_high(self):
        """4.0% over 5 years is 'Moderately High'."""
        label = categorize_risk(4.0, 5.0)
        assert label == "Moderately High"

    def test_categorize_short_horizon_high(self):
        """6.0% over 5 years is 'High'."""
        label = categorize_risk(6.0, 5.0)
        assert label == "High"

    def test_categorize_long_horizon_very_low(self):
        """0.5% over 10 years is 'Very Low'."""
        label = categorize_risk(0.5, 10.0)
        assert label == "Very Low"

    def test_categorize_long_horizon_low(self):
        """2.0% over 10 years is 'Low'."""
        label = categorize_risk(2.0, 10.0)
        assert label == "Low"

    def test_categorize_long_horizon_moderate(self):
        """5.0% over 10 years is 'Moderate'."""
        label = categorize_risk(5.0, 10.0)
        assert label == "Moderate"

    def test_categorize_long_horizon_moderately_high(self):
        """10.0% over 10 years is 'Moderately High'."""
        label = categorize_risk(10.0, 10.0)
        assert label == "Moderately High"

    def test_categorize_long_horizon_high(self):
        """20.0% over 10 years is 'High'."""
        label = categorize_risk(20.0, 10.0)
        assert label == "High"

    def test_categorize_lifetime_risk(self):
        """Each label is reachable with a 79-year horizon."""
        lifetime_cases = (
            (0.5, "Very Low"),
            (2.0, "Low"),
            (5.0, "Moderate"),
            (12.0, "Moderately High"),
            (20.0, "High"),
        )
        for probability_percent, expected_label in lifetime_cases:
            assert categorize_risk(probability_percent, 79.0) == expected_label