sentinel / tests /test_generate_documentation.py
jeuko's picture
Sync from GitHub (main)
3fc6f6d verified
"""Tests for the documentation generation script."""
import pytest
from scripts.generate_documentation import (
_normalise_cancer_label,
_unique_qcancer_sites,
build_field_usage_map,
cancer_types_for_model,
discover_risk_models,
extract_field_attributes,
extract_model_requirements,
format_field_path,
gather_spec_details,
group_fields_by_requirements,
prettify_field_name,
traverse_user_input_structure,
)
class TestUtilityFunctions:
    """Checks for the small helper functions used by the documentation script."""

    def test_prettify_field_name(self):
        """Raw field names map to their expected display labels."""
        expected_labels = {
            "female_specific": "Female Specific",
            "family_history[]": "Family History",
            "age_years": "Age Years",
            "test": "Test",
        }
        for raw_name, label in expected_labels.items():
            assert prettify_field_name(raw_name) == label

    def test_format_field_path(self):
        """Dotted field paths are rendered as the expected display text."""
        cases = [
            ("demographics.age_years", "Demographics\n - Age Years"),
            ("family_history[].relation", "Family History\n - Relation"),
            ("simple_field", "Simple Field"),
        ]
        for dotted_path, rendered in cases:
            assert format_field_path(dotted_path) == rendered

    def test_normalise_cancer_label(self):
        """Cancer labels in several spellings normalise to the bare site name."""
        cases = [
            ("Lung Cancer", "Lung"),
            ("breast-cancer", "Breast"),
            ("colorectal_cancer", "Colorectal"),
            ("Prostate", "Prostate"),
        ]
        for raw_label, normalised in cases:
            assert _normalise_cancer_label(raw_label) == normalised

    def test_unique_qcancer_sites(self):
        """The QCancer site list is a non-empty list of normalised labels."""
        qcancer_sites = _unique_qcancer_sites()
        assert isinstance(qcancer_sites, list)
        assert qcancer_sites

        # A normalised label carries neither the word "cancer" nor the
        # separators that appear in raw labels.
        for label in qcancer_sites:
            assert "cancer" not in label.lower()
            for separator in ("_", "-"):
                assert separator not in label

    def test_cancer_types_for_model(self):
        """Cancer types are derived from a model's name and cancer_type()."""

        class MockModel:
            """Minimal stand-in exposing ``name`` and ``cancer_type()``."""

            def __init__(self, name, cancer_type):
                """Store the values the stand-in reports.

                Args:
                    name: Model name.
                    cancer_type: String returned by ``cancer_type()``.
                """
                self.name = name
                self._cancer_type = cancer_type

            def cancer_type(self):
                """Return the cancer type string given at construction.

                Returns:
                    str: Cancer type string.
                """
                return self._cancer_type

        # A regular model yields exactly its own normalised cancer type.
        assert cancer_types_for_model(MockModel("gail", "breast")) == ["Breast"]

        # A model named "qcancer" is only checked for a non-empty list here.
        qcancer_types = cancer_types_for_model(MockModel("qcancer", "multiple"))
        assert isinstance(qcancer_types, list)
        assert qcancer_types

    def test_group_fields_by_requirements(self):
        """Requirements sharing a parent path end up in the same group."""
        requirements = [
            ("demographics.age_years", int, True),
            ("demographics.sex", str, True),
            ("family_history.relation", str, False),
            ("family_history.cancer_type", str, False),
        ]

        grouped = group_fields_by_requirements(requirements)
        assert len(grouped) == 2

        # Each expected heading must be present and hold two fields.
        for heading in ("Demographics", "Family History"):
            matching_group = None
            for group in grouped:
                if group[0] == heading:
                    matching_group = group
                    break
            assert matching_group is not None
            assert len(matching_group[1]) == 2

    def test_gather_spec_details_regular(self):
        """A plain note is passed through and the other columns use defaults."""
        note, required, unit, value_range = gather_spec_details(
            None, None, "Test note"
        )
        assert note == "Test note"
        assert (required, unit, value_range) == ("Optional", "-", "-")

    def test_gather_spec_details_clinical_observation(self):
        """A "multivitamin - Yes/No" note yields its description and range."""
        note, required, unit, value_range = gather_spec_details(
            None, None, "multivitamin - Yes/No"
        )
        assert "Multivitamin usage status" in note
        assert (required, unit, value_range) == ("Optional", "-", "Yes/No")

    def test_gather_spec_details_unknown_observation(self):
        """An unrecognised observation name yields a generic description."""
        note, required, unit, value_range = gather_spec_details(
            None, None, "unknown_obs - Some values"
        )
        assert "Clinical observation: unknown_obs" in note
        assert (required, unit, value_range) == ("Optional", "-", "Some values")
class TestMainFunctionality:
    """Checks for the top-level entry points of the documentation generator."""

    def test_discover_risk_models(self):
        """Discovery returns a non-empty list of models with the expected attributes."""
        expected_attributes = (
            "name",
            "cancer_type",
            "description",
            "interpretation",
            "references",
        )

        discovered = discover_risk_models()
        assert isinstance(discovered, list)
        assert discovered

        # Every discovered model must expose each expected attribute.
        for risk_model in discovered:
            for attribute in expected_attributes:
                assert hasattr(risk_model, attribute)

    def test_main_function_import(self):
        """The script's ``main`` can be imported and is callable."""
        from scripts.generate_documentation import main as entry_point

        assert callable(entry_point)
class TestEdgeCases:
    """Checks for boundary inputs and error behaviour."""

    def test_empty_field_grouping(self):
        """Grouping an empty requirements list gives an empty list."""
        assert group_fields_by_requirements([]) == []

    def test_single_segment_path(self):
        """A path with a single segment is formatted as one label."""
        assert format_field_path("single_field") == "Single Field"

    def test_empty_cancer_label(self):
        """An empty label normalises to an empty string."""
        assert _normalise_cancer_label("") == ""

    def test_none_cancer_label(self):
        """Passing None as the label raises AttributeError."""
        # None is not accepted: the call is expected to raise rather than
        # return a value.
        with pytest.raises(AttributeError):
            _normalise_cancer_label(None)

    def test_gather_spec_details_none_inputs(self):
        """None for the first two arguments and an empty note give defaults."""
        note, required, unit, value_range = gather_spec_details(None, None, "")
        assert (note, required, unit, value_range) == ("-", "Optional", "-", "-")

    def test_gather_spec_details_empty_note(self):
        """An empty note string is reported as the "-" placeholder."""
        details = gather_spec_details(None, None, "")
        note, required, unit, value_range = details
        assert note == "-"
        assert required == "Optional"
        assert (unit, value_range) == ("-", "-")
class TestUserInputStructureExtraction:
    """Checks for the helpers that walk and describe the UserInput model."""

    def test_traverse_user_input_structure(self):
        """Traversal yields (path, name, model_class) entries of both kinds."""
        from sentinel.user_input import UserInput

        entries = traverse_user_input_structure(UserInput)
        assert isinstance(entries, list)
        assert entries

        # Entries with a model class are parent models; the rest are leaves.
        # Both kinds must be present.
        parent_count = sum(1 for entry in entries if entry[2] is not None)
        leaf_count = len(entries) - parent_count
        assert parent_count > 0
        assert leaf_count > 0

        # Each entry has the shape (path, name, model_class).
        for field_path, field_name, nested_model in entries:
            assert isinstance(field_path, str)
            assert isinstance(field_name, str)
            if nested_model is not None:
                assert hasattr(nested_model, "model_fields")

    def test_extract_model_requirements(self):
        """Requirements come back as (field_path, field_type, is_required)."""
        from sentinel.risk_models.gail import GailRiskModel

        extracted = extract_model_requirements(GailRiskModel())
        assert isinstance(extracted, list)
        assert extracted

        for field_path, field_type, is_required in extracted:
            assert isinstance(field_path, str)
            # The type may be an Annotated form, so only its presence is checked.
            assert field_type is not None
            assert isinstance(is_required, bool)

    def test_build_field_usage_map(self):
        """The usage map links each field path to (model_name, is_required) pairs."""
        from sentinel.risk_models.claus import ClausRiskModel
        from sentinel.risk_models.gail import GailRiskModel

        usage_by_field = build_field_usage_map([GailRiskModel(), ClausRiskModel()])
        assert isinstance(usage_by_field, dict)
        assert usage_by_field

        for field_path in usage_by_field:
            usages = usage_by_field[field_path]
            assert isinstance(field_path, str)
            assert isinstance(usages, list)
            for model_name, is_required in usages:
                assert isinstance(model_name, str)
                assert isinstance(is_required, bool)

    def test_extract_field_attributes(self):
        """Field attributes come back as four strings plus an optional class."""
        from sentinel.user_input import UserInput

        # Use the "demographics" field of UserInput as the sample.
        demographics_field = UserInput.model_fields["demographics"]
        description, examples, constraints, used_by, enum_class = (
            extract_field_attributes(
                demographics_field, demographics_field.annotation
            )
        )

        for text_value in (description, examples, constraints, used_by):
            assert isinstance(text_value, str)
        if enum_class is not None:
            assert isinstance(enum_class, type)