# File: sentinel/tests/test_risk_models/test_amap_model.py
# Uploaded by jeuko - "Sync from GitHub (main)", commit 0ba176c (verified).
"""Tests for the aMAP Liver Cancer Risk Model.

Ground truth values were collected from the web calculator at:
https://mdac.cuhk.edu.hk/calculators/amap/
"""
import pytest
from sentinel.risk_models import AMAPRiskModel
from sentinel.user_input import (
AlbuminTest,
Anthropometrics,
BilirubinTest,
ChronicCondition,
ClinicalTests,
Demographics,
Lifestyle,
PersonalMedicalHistory,
PlateletTest,
Sex,
SmokingHistory,
SmokingStatus,
UserInput,
)
def _hbv_case(
    name,
    *,
    age_years,
    sex,
    height_cm,
    weight_kg,
    albumin,
    bilirubin,
    platelets,
    expected,
):
    """Assemble one ground-truth entry.

    Every reference patient in this table is a never-smoker with chronic
    hepatitis B; only demographics, body measurements and the three lab
    values vary between entries.

    Args:
        name: Case identifier, used as the pytest parametrization id.
        age_years: Patient age in years.
        sex: Patient sex (``Sex`` member).
        height_cm: Height passed to ``Anthropometrics``.
        weight_kg: Weight passed to ``Anthropometrics``.
        albumin: Albumin value (g/L).
        bilirubin: Bilirubin value (umol/L).
        platelets: Platelet count (10^9/L).
        expected: Expected 5-year HCC risk, in percent.

    Returns:
        Dict with keys ``"name"``, ``"input"`` and ``"expected"``.
    """
    patient = UserInput(
        demographics=Demographics(
            age_years=age_years,
            sex=sex,
            anthropometrics=Anthropometrics(height_cm=height_cm, weight_kg=weight_kg),
        ),
        lifestyle=Lifestyle(
            smoking=SmokingHistory(status=SmokingStatus.NEVER),
        ),
        personal_medical_history=PersonalMedicalHistory(
            chronic_conditions=[ChronicCondition.CHRONIC_HEPATITIS_B]
        ),
        clinical_tests=ClinicalTests(
            albumin=AlbuminTest(value_g_per_L=albumin),
            bilirubin=BilirubinTest(value_umol_per_L=bilirubin),
            platelets=PlateletTest(value_10e9_per_L=platelets),
        ),
    )
    return {"name": name, "input": patient, "expected": expected}


# Reference cases; the trailing comment on each entry records what the web
# calculator reported for the same inputs.
GROUND_TRUTH_CASES = [
    # Web calculator: score 48, risk low, 5-year HCC risk 0.8%
    _hbv_case(
        "young_male_low_risk",
        age_years=35,
        sex=Sex.MALE,
        height_cm=175.0,
        weight_kg=70.0,
        albumin=42.0,
        bilirubin=12.0,
        platelets=200.0,
        expected=0.8,
    ),
    # Web calculator: score 55, risk intermediate, 5-year HCC risk 4.2%
    _hbv_case(
        "middle_aged_female_medium_risk",
        age_years=50,
        sex=Sex.FEMALE,
        height_cm=165.0,
        weight_kg=65.0,
        albumin=35.0,
        bilirubin=20.0,
        platelets=120.0,
        expected=4.2,
    ),
    # Web calculator: score 75, risk high, 5-year HCC risk 19.9%
    _hbv_case(
        "elderly_male_high_risk",
        age_years=70,
        sex=Sex.MALE,
        height_cm=175.0,
        weight_kg=70.0,
        albumin=28.0,
        bilirubin=40.0,
        platelets=80.0,
        expected=19.9,
    ),
    # Web calculator: score 57, risk intermediate, 5-year HCC risk 4.2%
    _hbv_case(
        "edge_case_low_platelets",
        age_years=45,
        sex=Sex.FEMALE,
        height_cm=165.0,
        weight_kg=65.0,
        albumin=38.0,
        bilirubin=15.0,
        platelets=50.0,
        expected=4.2,
    ),
    # Web calculator: score 69, risk high, 5-year HCC risk 19.9%
    _hbv_case(
        "edge_case_high_bilirubin",
        age_years=60,
        sex=Sex.MALE,
        height_cm=175.0,
        weight_kg=70.0,
        albumin=32.0,
        bilirubin=60.0,
        platelets=100.0,
        expected=19.9,
    ),
    # Actual score 49.6 (web shows rounded 50) -> low risk (<50) -> 0.8%
    _hbv_case(
        "boundary_case_low_medium",
        age_years=40,
        sex=Sex.FEMALE,
        height_cm=165.0,
        weight_kg=65.0,
        albumin=36.0,
        bilirubin=18.0,
        platelets=140.0,
        expected=0.8,
    ),
    # Web calculator: score 65, risk high, 5-year HCC risk 19.9%
    _hbv_case(
        "boundary_case_medium_high",
        age_years=55,
        sex=Sex.MALE,
        height_cm=175.0,
        weight_kg=70.0,
        albumin=33.0,
        bilirubin=25.0,
        platelets=110.0,
        expected=19.9,
    ),
]
class TestAMAPModel:
    """Test suite for AMAPRiskModel."""

    def setup_method(self) -> None:
        """Create a fresh AMAPRiskModel instance before each test."""
        self.model = AMAPRiskModel()

    @staticmethod
    def _make_user(
        *,
        age_years,
        sex,
        chronic_conditions=None,
        albumin=None,
        bilirubin=None,
        platelets=None,
    ):
        """Build a never-smoker ``UserInput`` for the tests in this class.

        Body measurements follow the fixtures used throughout this file:
        175 cm / 70 kg for ``Sex.MALE`` and 165 cm / 65 kg otherwise.

        Args:
            age_years: Patient age in years.
            sex: Patient sex (``Sex`` member).
            chronic_conditions: Conditions to record. ``None`` builds a
                default ``PersonalMedicalHistory()`` with no arguments.
            albumin: Albumin value (g/L); ``None`` leaves the test unset.
            bilirubin: Bilirubin value (umol/L); ``None`` leaves it unset.
            platelets: Platelet count (10^9/L); ``None`` leaves it unset.

        Returns:
            The assembled ``UserInput``.
        """
        if sex == Sex.MALE:
            body = Anthropometrics(height_cm=175.0, weight_kg=70.0)
        else:
            body = Anthropometrics(height_cm=165.0, weight_kg=65.0)

        if chronic_conditions is None:
            history = PersonalMedicalHistory()
        else:
            history = PersonalMedicalHistory(
                chronic_conditions=list(chronic_conditions)
            )

        # Omit missing labs entirely (rather than passing None) so the
        # "missing input" tests construct ClinicalTests exactly as if the
        # field had never been supplied.
        labs = {}
        if albumin is not None:
            labs["albumin"] = AlbuminTest(value_g_per_L=albumin)
        if bilirubin is not None:
            labs["bilirubin"] = BilirubinTest(value_umol_per_L=bilirubin)
        if platelets is not None:
            labs["platelets"] = PlateletTest(value_10e9_per_L=platelets)

        return UserInput(
            demographics=Demographics(
                age_years=age_years,
                sex=sex,
                anthropometrics=body,
            ),
            lifestyle=Lifestyle(
                smoking=SmokingHistory(status=SmokingStatus.NEVER),
            ),
            personal_medical_history=history,
            clinical_tests=ClinicalTests(**labs),
        )

    @pytest.mark.parametrize("case", GROUND_TRUTH_CASES, ids=lambda x: x["name"])
    def test_ground_truth_placeholders(self, case):
        """Validate compute_score against the recorded reference values.

        The test name is historical: the expected values in
        GROUND_TRUTH_CASES have been filled in, so this is a real
        validation rather than a placeholder.

        Args:
            case: Parameterized ground truth case dict with keys
                "name", "input" and "expected".
        """
        score_str = self.model.compute_score(case["input"])

        # Output must be a percentage string such as "4.2%".
        assert isinstance(score_str, str)
        assert score_str.endswith("%")

        expected_percent = case["expected"]
        if expected_percent is not None:
            actual_percent = float(score_str.rstrip("%"))
            # Match within 0.1 percentage points of the reference value.
            assert actual_percent == pytest.approx(expected_percent, abs=0.1)

    def test_compute_score_male_with_hep_b(self):
        """Test male patient with chronic hepatitis B."""
        user = self._make_user(
            age_years=55,
            sex=Sex.MALE,
            chronic_conditions=[ChronicCondition.CHRONIC_HEPATITIS_B],
            albumin=40.0,
            bilirubin=14.0,
            platelets=180.0,
        )
        score_str = self.model.compute_score(user)
        assert score_str.endswith("%")
        # Should be a valid, strictly positive percentage.
        assert float(score_str.rstrip("%")) > 0

    def test_compute_score_female_without_hep_b(self):
        """Test female patient without chronic hepatitis B."""
        user = self._make_user(
            age_years=45,
            sex=Sex.FEMALE,
            chronic_conditions=[],  # No hepatitis B
            albumin=38.0,
            bilirubin=15.0,
            platelets=150.0,
        )
        score_str = self.model.compute_score(user)
        assert score_str.endswith("%")
        # Should be a valid, strictly positive percentage.
        assert float(score_str.rstrip("%")) > 0

    def test_missing_albumin(self):
        """Test that missing albumin raises ValueError."""
        user = self._make_user(
            age_years=50,
            sex=Sex.MALE,
            # albumin missing
            bilirubin=15.0,
            platelets=150.0,
        )
        with pytest.raises(ValueError, match=r"Invalid inputs for amap.*albumin"):
            self.model.compute_score(user)

    def test_missing_bilirubin(self):
        """Test that missing bilirubin raises ValueError."""
        user = self._make_user(
            age_years=50,
            sex=Sex.MALE,
            albumin=38.0,
            # bilirubin missing
            platelets=150.0,
        )
        with pytest.raises(ValueError, match=r"Invalid inputs for amap.*bilirubin"):
            self.model.compute_score(user)

    def test_missing_platelets(self):
        """Test that missing platelets raises ValueError."""
        user = self._make_user(
            age_years=50,
            sex=Sex.MALE,
            albumin=38.0,
            bilirubin=15.0,
            # platelets missing
        )
        with pytest.raises(ValueError, match=r"Invalid inputs for amap.*platelets"):
            self.model.compute_score(user)

    def test_amap_score_calculation(self):
        """Test direct aMAP score calculation method."""
        score = self.model.amap_score(
            age_years=55,
            sex=Sex.MALE,
            albumin_g_per_L=40.0,
            bilirubin_umol_per_L=14.0,
            platelets_10e9_per_L=180.0,
        )
        # Check the type first so a wrong return type is reported as such
        # rather than as a failed range comparison.
        assert isinstance(score, float)
        assert 0.0 <= score <= 100.0

    def test_amap_risk_band_low(self):
        """Test risk band classification for low risk."""
        assert self.model.amap_risk_band(30.0) == "low"
        assert self.model.amap_risk_band(49.9) == "low"

    def test_amap_risk_band_medium(self):
        """Test risk band classification for medium risk (50 to 60 inclusive)."""
        assert self.model.amap_risk_band(50.0) == "medium"
        assert self.model.amap_risk_band(55.0) == "medium"
        assert self.model.amap_risk_band(60.0) == "medium"

    def test_amap_risk_band_high(self):
        """Test risk band classification for high risk."""
        assert self.model.amap_risk_band(60.1) == "high"
        assert self.model.amap_risk_band(80.0) == "high"

    def test_model_metadata(self):
        """Test model metadata methods."""
        assert self.model.name == "amap"
        assert self.model.cancer_type() == "liver"

        description = self.model.description()
        assert "aMAP" in description
        assert "hepatocellular carcinoma" in description.lower()

        assert "low risk" in self.model.interpretation().lower()

        references = self.model.references()
        assert isinstance(references, list)
        assert len(references) > 0

        assert self.model.time_horizon_years() == 5.0

    def test_invalid_bilirubin_zero(self):
        """Test that zero bilirubin raises ValueError (needed for log10)."""
        with pytest.raises(ValueError, match=r"bilirubin must be > 0"):
            self.model.amap_score(
                age_years=50,
                sex=Sex.MALE,
                albumin_g_per_L=40.0,
                bilirubin_umol_per_L=0.0,  # Invalid: zero
                platelets_10e9_per_L=150.0,
            )

    def test_score_clipping(self):
        """Test that the clipped aMAP score stays in 0-100 and agrees with
        the unclipped score after clamping."""
        # NOTE(review): under the published aMAP formula these inputs score
        # about 86, i.e. already inside 0-100, so the clip branch may not be
        # exercised here - confirm and consider adding out-of-range inputs.
        extreme_inputs = {
            "age_years": 80,
            "sex": Sex.MALE,
            "albumin_g_per_L": 20.0,  # Very low
            "bilirubin_umol_per_L": 100.0,  # Very high
            "platelets_10e9_per_L": 20.0,  # Very low
        }

        score_clipped = self.model.amap_score(**extreme_inputs, clip_0_100=True)
        assert 0.0 <= score_clipped <= 100.0

        score_unclipped = self.model.amap_score(**extreme_inputs, clip_0_100=False)
        # Unclipped might be outside the range, but must still be a float.
        assert isinstance(score_unclipped, float)

        # Clipping must be a pure clamp of the raw score: previously the two
        # results were never compared, so a clip flag that changed the value
        # in any other way would have gone unnoticed.
        assert score_clipped == pytest.approx(
            min(100.0, max(0.0, score_unclipped))
        )