bookmark - Refactor

This commit is contained in:
2026-04-10 17:28:29 -04:00
parent 48ab8c28b9
commit 68422dd304
29 changed files with 1371 additions and 756 deletions

View File

@@ -0,0 +1,93 @@
"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
import numpy as np
import pytest
from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
class TestChestDeflection:
def test_basic(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.criterion == "Chest Deflection"
assert result.unit == "mm"
assert 30.0 < result.value < 40.0 # ~35mm in fixture
def test_peak_time(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.time_of_peak is not None
assert 0.0 < result.time_of_peak < 0.1
class TestClip3ms:
def test_basic(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.criterion == "3ms Clip"
assert result.value > 0
def test_body_region(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.body_region == "Chest"
class TestFemurLoad:
def test_single_channel(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.criterion == "Femur Load Left"
assert result.unit == "kN"
assert result.value > 0
def test_peak_time(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.time_of_peak is not None
class TestTibiaIndex:
def test_with_all_components(self, time_array, sample_rate):
from impakt.channel.code import ChannelCode
from impakt.channel.model import Channel
t = time_array
fz = np.zeros_like(t)
mx = np.zeros_like(t)
my = np.zeros_like(t)
mask = (t >= 0.02) & (t <= 0.08)
fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
fz_ch = Channel(
name="TIBFZ",
code=ChannelCode.parse("TIBFZ"),
data=fz,
time=t,
unit="N",
sample_rate=sample_rate,
)
mx_ch = Channel(
name="TIBMX",
code=ChannelCode.parse("TIBMX"),
data=mx,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
my_ch = Channel(
name="TIBMY",
code=ChannelCode.parse("TIBMY"),
data=my,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
assert result.value > 0
class TestViscousCriterion:
def test_basic(self, chest_deflection_channel):
result = viscous_criterion(channel=chest_deflection_channel)
assert result.criterion == "Viscous Criterion"
assert result.unit == "m/s"
assert result.value >= 0

View File

@@ -0,0 +1,81 @@
"""Tests for plot engine."""
import numpy as np
import pytest
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
class TestPlotEngine:
def test_render_empty(self):
engine = PlotEngine()
spec = PlotSpec()
fig = engine.render(spec)
assert fig is not None
def test_render_with_channels(self, head_accel_x, head_accel_y):
engine = PlotEngine()
spec = PlotSpec(
channels=[
ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
],
y_label="g",
)
fig = engine.render(spec)
assert len(fig.data) == 2
def test_render_compact_mode(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
compact=True,
)
fig = engine.render(spec)
assert fig.layout.hovermode is False
assert fig.layout.showlegend is False
def test_render_with_corridors(self, head_accel_x):
engine = PlotEngine()
corridor = Corridor(
name="Test Corridor",
time=np.linspace(0, 0.1, 100),
lower=np.full(100, -50.0),
upper=np.full(100, 50.0),
)
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
corridors=[corridor],
)
fig = engine.render(spec)
# 1 data trace + 2 corridor traces (upper + lower)
assert len(fig.data) == 3
def test_render_with_cursors(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
)
fig = engine.render(spec)
# Should have vertical lines (as shapes)
assert len(fig.layout.shapes) >= 2
def test_compact_cursor_annotations(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
compact=True,
)
fig = engine.render(spec)
# Should have X1/X2 annotations
annotations = [
a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
]
assert len(annotations) == 2
def test_default_colors(self):
assert len(DEFAULT_COLORS) == 10
assert all(c.startswith("#") for c in DEFAULT_COLORS)

View File

@@ -0,0 +1,68 @@
"""Tests for IIHS and US NCAP scoring."""
import pytest
from impakt.criteria.base import CriterionResult
from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
@pytest.fixture
def sample_criteria():
return {
"HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
"Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
"Chest Deflection": CriterionResult(
criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
),
"Femur Load Left": CriterionResult(
criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
),
"Femur Load Right": CriterionResult(
criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
),
}
class TestIIHS:
def test_good_results(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert result.protocol == "IIHS"
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_region_scores(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert len(result.region_scores) > 0
for rs in result.region_scores:
assert rs.rating is not None
def test_poor_hic(self):
result = iihs_evaluate(
{
"HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
}
)
assert result.overall_rating == "POOR"
def test_summary(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
summary = result.summary()
assert "IIHS" in summary
class TestUSNCAP:
def test_basic(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert result.protocol == "US NCAP"
assert result.stars is not None
assert 1 <= result.stars <= 5
def test_injury_probabilities(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert "combined_injury_probability" in result.details
p = result.details["combined_injury_probability"]
assert 0.0 <= p <= 1.0
def test_region_scores(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert len(result.region_scores) > 0

126
tests/test_scripting_api.py Normal file
View File

@@ -0,0 +1,126 @@
"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
from pathlib import Path
import pytest
from impakt import Session, Template
FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
MME_DATA = Path(__file__).parent / "mme_data"
class TestSession:
def test_open(self):
s = Session.open(FIXTURE_DATA)
assert s.test_id == "IMPAKT_SYNTH_001"
assert len(s) == 26
def test_channel_access(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
assert ch.name == "11HEAD0000ACXA"
assert ch.peak > 0
def test_find(self):
s = Session.open(FIXTURE_DATA)
channels = s.find("*HEAD*AC*")
assert len(channels) == 3
def test_group(self):
s = Session.open(FIXTURE_DATA)
group = s.group("HEAD0000AC")
assert group.x is not None
def test_compute_criteria(self):
s = Session.open(FIXTURE_DATA)
criteria = s.compute_criteria()
assert len(criteria) > 0
def test_evaluate(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.protocol == "Euro NCAP"
def test_evaluate_us_ncap(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("us_ncap")
assert result.stars is not None
def test_evaluate_iihs(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("iihs")
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_evaluate_invalid_protocol(self):
s = Session.open(FIXTURE_DATA)
with pytest.raises(ValueError, match="Unknown protocol"):
s.evaluate("invalid")
def test_contains(self):
s = Session.open(FIXTURE_DATA)
assert "11HEAD0000ACXA" in s
assert "NONEXISTENT" not in s
class TestChannelHandleChaining:
"""The fluent API must support chaining — each transform returns ChannelHandle."""
def test_single_transform(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600)
assert type(filtered).__name__ == "ChannelHandle"
assert filtered.raw.cfc_class == 600
def test_double_chain(self):
s = Session.open(FIXTURE_DATA)
result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 2
def test_triple_chain(self):
s = Session.open(FIXTURE_DATA)
result = (
s.channel("11HEAD0000ACXA")
.transform.cfc(1000)
.transform.y_align()
.transform.trim(t_start=0.0, t_end=0.1)
)
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 3
def test_chain_preserves_data(self):
s = Session.open(FIXTURE_DATA)
original = s.channel("11HEAD0000ACXA")
original_peak = original.peak
filtered = original.transform.cfc(600)
# Original should be unchanged — peak should be the same
assert original.peak == original_peak
# Filtered should have different CFC and lower peak (smoothed)
assert filtered.raw.cfc_class == 600
assert filtered.peak <= original_peak
@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
class TestSessionRealData:
def test_open_real(self):
s = Session.open(MME_DATA / "3239")
assert s.test_id == "3239"
assert len(s) == 133
def test_full_pipeline(self):
s = Session.open(MME_DATA / "3239")
# Chain: get channel -> filter -> check
ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
assert ch.peak > 100 # Significant head acceleration
# Compute criteria
criteria = s.compute_criteria()
assert "HIC15" in criteria
# Evaluate
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0

132
tests/test_template.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for template model and session persistence."""
from pathlib import Path
import pytest
from impakt.template.library import TemplateLibrary
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
from impakt.template.session import SessionManager
class TestTemplateSpec:
def test_yaml_round_trip(self):
template = TemplateSpec(
name="Test Template",
version=2,
description="A test template",
plots=[
PlotDefinition(
title="Head Acceleration",
channel_patterns=["*HEAD*AC*"],
x_cursors=(0.0, 0.1),
)
],
default_cfc=1000,
criteria=["hic15", "nij"],
protocol="euro_ncap",
)
yaml_str = template.to_yaml()
restored = TemplateSpec.from_yaml(yaml_str)
assert restored.name == "Test Template"
assert restored.version == 2
assert restored.default_cfc == 1000
assert len(restored.plots) == 1
assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
assert restored.criteria == ["hic15", "nij"]
def test_save_and_load(self, tmp_path):
template = TemplateSpec(name="Saved Test", version=1)
path = tmp_path / "test.yaml"
template.save(path)
assert path.exists()
loaded = TemplateSpec.load(path)
assert loaded.name == "Saved Test"
class TestSessionState:
def test_yaml_round_trip(self):
state = SessionState(
template_name="my_template",
template_version=3,
test_path="/data/test_001",
overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
)
yaml_str = state.to_yaml()
restored = SessionState.from_yaml(yaml_str)
assert restored.template_name == "my_template"
assert restored.template_version == 3
assert restored.overrides["cfc"] == "600"
def test_save_and_load(self, tmp_path):
state = SessionState(template_name="test")
path = tmp_path / "session.yaml"
state.save(path)
assert path.exists()
loaded = SessionState.load(path)
assert loaded.template_name == "test"
class TestTemplateLibrary:
def test_empty_library(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
assert lib.list() == []
assert len(lib) == 0
def test_save_and_list(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
template = TemplateSpec(name="My Template")
lib.save(template)
assert "my_template" in lib.list()
assert len(lib) == 1
def test_get(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="Getter Test", version=5))
loaded = lib.get("getter_test")
assert loaded.name == "Getter Test"
assert loaded.version == 5
def test_delete(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="To Delete"))
assert lib.delete("to_delete")
assert "to_delete" not in lib.list()
def test_get_missing_raises(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
with pytest.raises(FileNotFoundError):
lib.get("nonexistent")
class TestSessionManager:
def test_create_and_save(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.state.template_name = "test_tmpl"
mgr.save()
assert mgr.has_session
assert (tmp_path / ".impakt" / "session.yaml").exists()
def test_load_existing(self, tmp_path):
# Save
mgr1 = SessionManager(tmp_path)
mgr1.state.template_name = "saved_tmpl"
mgr1.state.overrides = {"key": "value"}
mgr1.save()
# Load
mgr2 = SessionManager(tmp_path)
assert mgr2.state.template_name == "saved_tmpl"
assert mgr2.state.overrides["key"] == "value"
def test_clear(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.save()
assert mgr.has_session
mgr.clear()
assert not mgr.has_session

View File

@@ -0,0 +1,75 @@
"""Tests for math expressions and resultant computation."""
import numpy as np
import pytest
from impakt.transform.math_expr import math_expr
from impakt.transform.resultant import resultant_from_channels
from impakt.transform.resample import trim, resample
class TestMathExpr:
def test_simple_expression(self, head_accel_x, head_accel_z):
result = math_expr(
expression="sqrt(a**2 + b**2)",
channels={"a": head_accel_x, "b": head_accel_z},
name="resultant_xz",
unit="g",
)
assert result.name == "resultant_xz"
assert result.unit == "g"
assert result.peak > 0
assert len(result.data) == len(head_accel_x.data)
def test_constant_expression(self, head_accel_x):
result = math_expr(
expression="a * 0 + 42.0",
channels={"a": head_accel_x},
name="constant",
)
assert np.allclose(result.data, 42.0)
def test_invalid_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Error evaluating"):
math_expr(
expression="invalid_func(a)",
channels={"a": head_accel_x},
)
def test_forbidden_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Forbidden"):
math_expr(
expression="__import__('os')",
channels={"a": head_accel_x},
)
class TestResultant:
def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
assert result.code.direction == "R"
# Resultant >= any component
assert result.peak >= head_accel_x.peak
assert result.peak >= head_accel_y.peak
def test_from_two_channels(self, head_accel_x, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_z)
assert result.peak > 0
def test_single_channel_raises(self):
with pytest.raises(ValueError, match="At least one"):
resultant_from_channels()
class TestTrimResample:
def test_trim(self, head_accel_x):
trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
assert trimmed.time[0] >= 0.0
assert trimmed.time[-1] <= 0.05
assert len(trimmed.data) < len(head_accel_x.data)
def test_resample(self, head_accel_x):
resampled = resample(head_accel_x, target_rate=5000.0)
expected_samples = int(head_accel_x.duration * 5000.0)
assert abs(len(resampled.data) - expected_samples) <= 2
assert resampled.sample_rate == 5000.0

View File

@@ -19,11 +19,11 @@ class TestAppState:
def test_load_test(self):
state = AppState()
loaded = state.load_test(FIXTURE_DATA)
session = state.load_test(FIXTURE_DATA)
assert not state.is_empty
assert loaded.test_id == "IMPAKT_SYNTH_001"
assert loaded.channel_count == 26
assert state.primary_test is loaded
assert session.test_id == "IMPAKT_SYNTH_001"
assert len(session) == 26
assert state.primary_test is session
def test_load_multiple_tests(self):
state = AppState()
@@ -32,13 +32,13 @@ class TestAppState:
if (MME_DATA / "VW1FGS15").exists():
t2 = state.load_test(MME_DATA / "VW1FGS15")
assert len(state.tests) == 2
assert state.primary_test is t1 # First loaded is primary
assert state.primary_test is t1
assert state.total_channels == 26 + 10
def test_remove_test(self):
state = AppState()
loaded = state.load_test(FIXTURE_DATA)
state.remove_test(loaded.test_id)
session = state.load_test(FIXTURE_DATA)
state.remove_test(session.test_id)
assert state.is_empty
def test_get_channel(self):
@@ -54,38 +54,38 @@ class TestAppState:
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
assert ch is None
def test_resolve_channel_with_key(self):
def test_get_channel_via_session(self):
"""Channels can be accessed through the Session scripting API."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA")
assert ch is not None
session = state.primary_test
assert session is not None
ch_handle = session.channel("11HEAD0000ACXA")
assert ch_handle.name == "11HEAD0000ACXA"
def test_resolve_channel_primary_default(self):
def test_session_fluent_transforms(self):
"""Fluent transform chaining works through the Session API."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA")
assert ch is not None
session = state.primary_test
ch = session.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600).transform.y_align()
assert filtered.raw.cfc_class == 600
assert len(filtered.raw.transform_history) == 2
def test_resolve_channel_with_cfc(self):
def test_session_compute_criteria(self):
"""Session.compute_criteria() auto-detects channels."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600)
assert ch is not None
assert ch.cfc_class == 600
def test_flat_channel_list(self):
state = AppState()
state.load_test(FIXTURE_DATA)
items = state.flat_channel_list()
assert len(items) == 26
assert all("value" in item and "label" in item for item in items)
criteria = state.primary_test.compute_criteria()
assert len(criteria) > 0
assert "HIC15" in criteria or "Chest Deflection" in criteria
def test_build_channel_tree(self):
state = AppState()
state.load_test(FIXTURE_DATA)
tree = state.build_channel_tree()
assert "IMPAKT_SYNTH_001" in tree
# Should have hierarchical structure
test_tree = tree["IMPAKT_SYNTH_001"]
assert len(test_tree) > 0
@@ -94,15 +94,23 @@ class TestAppState:
class TestAppStateRealData:
def test_load_real_mme(self):
state = AppState()
loaded = state.load_test(MME_DATA / "3239")
assert loaded.test_id == "3239"
assert loaded.channel_count == 133
session = state.load_test(MME_DATA / "3239")
assert session.test_id == "3239"
assert len(session) == 133
def test_channel_tree_real_data(self):
state = AppState()
state.load_test(MME_DATA / "3239")
tree = state.build_channel_tree()
assert "3239" in tree
# Should contain "Driver" in some key
test_tree = tree["3239"]
assert any("Driver" in k for k in test_tree)
def test_session_evaluate_real_data(self):
"""Full pipeline through Session API on real data."""
state = AppState()
state.load_test(MME_DATA / "3239")
result = state.primary_test.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0
assert len(result.region_scores) > 0