bookmark - Refactor
This commit is contained in:
93
tests/test_criteria/test_chest_femur_tibia.py
Normal file
93
tests/test_criteria/test_chest_femur_tibia.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
|
||||
|
||||
|
||||
class TestChestDeflection:
|
||||
def test_basic(self, chest_deflection_channel):
|
||||
result = chest_deflection(channel=chest_deflection_channel)
|
||||
assert result.criterion == "Chest Deflection"
|
||||
assert result.unit == "mm"
|
||||
assert 30.0 < result.value < 40.0 # ~35mm in fixture
|
||||
|
||||
def test_peak_time(self, chest_deflection_channel):
|
||||
result = chest_deflection(channel=chest_deflection_channel)
|
||||
assert result.time_of_peak is not None
|
||||
assert 0.0 < result.time_of_peak < 0.1
|
||||
|
||||
|
||||
class TestClip3ms:
|
||||
def test_basic(self, head_accel_x):
|
||||
result = clip_3ms(head_accel_x)
|
||||
assert result.criterion == "3ms Clip"
|
||||
assert result.value > 0
|
||||
|
||||
def test_body_region(self, head_accel_x):
|
||||
result = clip_3ms(head_accel_x)
|
||||
assert result.body_region == "Chest"
|
||||
|
||||
|
||||
class TestFemurLoad:
|
||||
def test_single_channel(self, femur_left_channel):
|
||||
result = femur_load(channel=femur_left_channel, side="left")
|
||||
assert result.criterion == "Femur Load Left"
|
||||
assert result.unit == "kN"
|
||||
assert result.value > 0
|
||||
|
||||
def test_peak_time(self, femur_left_channel):
|
||||
result = femur_load(channel=femur_left_channel, side="left")
|
||||
assert result.time_of_peak is not None
|
||||
|
||||
|
||||
class TestTibiaIndex:
|
||||
def test_with_all_components(self, time_array, sample_rate):
|
||||
from impakt.channel.code import ChannelCode
|
||||
from impakt.channel.model import Channel
|
||||
|
||||
t = time_array
|
||||
fz = np.zeros_like(t)
|
||||
mx = np.zeros_like(t)
|
||||
my = np.zeros_like(t)
|
||||
mask = (t >= 0.02) & (t <= 0.08)
|
||||
fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
|
||||
fz_ch = Channel(
|
||||
name="TIBFZ",
|
||||
code=ChannelCode.parse("TIBFZ"),
|
||||
data=fz,
|
||||
time=t,
|
||||
unit="N",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
mx_ch = Channel(
|
||||
name="TIBMX",
|
||||
code=ChannelCode.parse("TIBMX"),
|
||||
data=mx,
|
||||
time=t,
|
||||
unit="N·m",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
my_ch = Channel(
|
||||
name="TIBMY",
|
||||
code=ChannelCode.parse("TIBMY"),
|
||||
data=my,
|
||||
time=t,
|
||||
unit="N·m",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
|
||||
assert result.value > 0
|
||||
|
||||
|
||||
class TestViscousCriterion:
|
||||
def test_basic(self, chest_deflection_channel):
|
||||
result = viscous_criterion(channel=chest_deflection_channel)
|
||||
assert result.criterion == "Viscous Criterion"
|
||||
assert result.unit == "m/s"
|
||||
assert result.value >= 0
|
||||
81
tests/test_plot/test_engine.py
Normal file
81
tests/test_plot/test_engine.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for plot engine."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
|
||||
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
|
||||
|
||||
|
||||
class TestPlotEngine:
|
||||
def test_render_empty(self):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec()
|
||||
fig = engine.render(spec)
|
||||
assert fig is not None
|
||||
|
||||
def test_render_with_channels(self, head_accel_x, head_accel_y):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[
|
||||
ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
|
||||
ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
|
||||
],
|
||||
y_label="g",
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
assert len(fig.data) == 2
|
||||
|
||||
def test_render_compact_mode(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
compact=True,
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
assert fig.layout.hovermode is False
|
||||
assert fig.layout.showlegend is False
|
||||
|
||||
def test_render_with_corridors(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
corridor = Corridor(
|
||||
name="Test Corridor",
|
||||
time=np.linspace(0, 0.1, 100),
|
||||
lower=np.full(100, -50.0),
|
||||
upper=np.full(100, 50.0),
|
||||
)
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
corridors=[corridor],
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# 1 data trace + 2 corridor traces (upper + lower)
|
||||
assert len(fig.data) == 3
|
||||
|
||||
def test_render_with_cursors(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
x_cursors=(0.02, 0.06),
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# Should have vertical lines (as shapes)
|
||||
assert len(fig.layout.shapes) >= 2
|
||||
|
||||
def test_compact_cursor_annotations(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
x_cursors=(0.02, 0.06),
|
||||
compact=True,
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# Should have X1/X2 annotations
|
||||
annotations = [
|
||||
a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
|
||||
]
|
||||
assert len(annotations) == 2
|
||||
|
||||
def test_default_colors(self):
|
||||
assert len(DEFAULT_COLORS) == 10
|
||||
assert all(c.startswith("#") for c in DEFAULT_COLORS)
|
||||
68
tests/test_protocol/test_iihs_usncap.py
Normal file
68
tests/test_protocol/test_iihs_usncap.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for IIHS and US NCAP scoring."""
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt.criteria.base import CriterionResult
|
||||
from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
|
||||
from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_criteria():
|
||||
return {
|
||||
"HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
|
||||
"Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
|
||||
"Chest Deflection": CriterionResult(
|
||||
criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
|
||||
),
|
||||
"Femur Load Left": CriterionResult(
|
||||
criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
|
||||
),
|
||||
"Femur Load Right": CriterionResult(
|
||||
criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestIIHS:
|
||||
def test_good_results(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
assert result.protocol == "IIHS"
|
||||
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
|
||||
|
||||
def test_region_scores(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
assert len(result.region_scores) > 0
|
||||
for rs in result.region_scores:
|
||||
assert rs.rating is not None
|
||||
|
||||
def test_poor_hic(self):
|
||||
result = iihs_evaluate(
|
||||
{
|
||||
"HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
|
||||
}
|
||||
)
|
||||
assert result.overall_rating == "POOR"
|
||||
|
||||
def test_summary(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
summary = result.summary()
|
||||
assert "IIHS" in summary
|
||||
|
||||
|
||||
class TestUSNCAP:
|
||||
def test_basic(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert result.protocol == "US NCAP"
|
||||
assert result.stars is not None
|
||||
assert 1 <= result.stars <= 5
|
||||
|
||||
def test_injury_probabilities(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert "combined_injury_probability" in result.details
|
||||
p = result.details["combined_injury_probability"]
|
||||
assert 0.0 <= p <= 1.0
|
||||
|
||||
def test_region_scores(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert len(result.region_scores) > 0
|
||||
126
tests/test_scripting_api.py
Normal file
126
tests/test_scripting_api.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt import Session, Template
|
||||
|
||||
FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
|
||||
MME_DATA = Path(__file__).parent / "mme_data"
|
||||
|
||||
|
||||
class TestSession:
|
||||
def test_open(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
assert s.test_id == "IMPAKT_SYNTH_001"
|
||||
assert len(s) == 26
|
||||
|
||||
def test_channel_access(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
ch = s.channel("11HEAD0000ACXA")
|
||||
assert ch.name == "11HEAD0000ACXA"
|
||||
assert ch.peak > 0
|
||||
|
||||
def test_find(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
channels = s.find("*HEAD*AC*")
|
||||
assert len(channels) == 3
|
||||
|
||||
def test_group(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
group = s.group("HEAD0000AC")
|
||||
assert group.x is not None
|
||||
|
||||
def test_compute_criteria(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
criteria = s.compute_criteria()
|
||||
assert len(criteria) > 0
|
||||
|
||||
def test_evaluate(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.protocol == "Euro NCAP"
|
||||
|
||||
def test_evaluate_us_ncap(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("us_ncap")
|
||||
assert result.stars is not None
|
||||
|
||||
def test_evaluate_iihs(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("iihs")
|
||||
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
|
||||
|
||||
def test_evaluate_invalid_protocol(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
with pytest.raises(ValueError, match="Unknown protocol"):
|
||||
s.evaluate("invalid")
|
||||
|
||||
def test_contains(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
assert "11HEAD0000ACXA" in s
|
||||
assert "NONEXISTENT" not in s
|
||||
|
||||
|
||||
class TestChannelHandleChaining:
|
||||
"""The fluent API must support chaining — each transform returns ChannelHandle."""
|
||||
|
||||
def test_single_transform(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
ch = s.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(600)
|
||||
assert type(filtered).__name__ == "ChannelHandle"
|
||||
assert filtered.raw.cfc_class == 600
|
||||
|
||||
def test_double_chain(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
|
||||
assert type(result).__name__ == "ChannelHandle"
|
||||
assert len(result.raw.transform_history) == 2
|
||||
|
||||
def test_triple_chain(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = (
|
||||
s.channel("11HEAD0000ACXA")
|
||||
.transform.cfc(1000)
|
||||
.transform.y_align()
|
||||
.transform.trim(t_start=0.0, t_end=0.1)
|
||||
)
|
||||
assert type(result).__name__ == "ChannelHandle"
|
||||
assert len(result.raw.transform_history) == 3
|
||||
|
||||
def test_chain_preserves_data(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
original = s.channel("11HEAD0000ACXA")
|
||||
original_peak = original.peak
|
||||
filtered = original.transform.cfc(600)
|
||||
# Original should be unchanged — peak should be the same
|
||||
assert original.peak == original_peak
|
||||
# Filtered should have different CFC and lower peak (smoothed)
|
||||
assert filtered.raw.cfc_class == 600
|
||||
assert filtered.peak <= original_peak
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
|
||||
class TestSessionRealData:
|
||||
def test_open_real(self):
|
||||
s = Session.open(MME_DATA / "3239")
|
||||
assert s.test_id == "3239"
|
||||
assert len(s) == 133
|
||||
|
||||
def test_full_pipeline(self):
|
||||
s = Session.open(MME_DATA / "3239")
|
||||
# Chain: get channel -> filter -> check
|
||||
ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
|
||||
assert ch.peak > 100 # Significant head acceleration
|
||||
|
||||
# Compute criteria
|
||||
criteria = s.compute_criteria()
|
||||
assert "HIC15" in criteria
|
||||
|
||||
# Evaluate
|
||||
result = s.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.stars >= 0
|
||||
132
tests/test_template.py
Normal file
132
tests/test_template.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for template model and session persistence."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt.template.library import TemplateLibrary
|
||||
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
|
||||
from impakt.template.session import SessionManager
|
||||
|
||||
|
||||
class TestTemplateSpec:
|
||||
def test_yaml_round_trip(self):
|
||||
template = TemplateSpec(
|
||||
name="Test Template",
|
||||
version=2,
|
||||
description="A test template",
|
||||
plots=[
|
||||
PlotDefinition(
|
||||
title="Head Acceleration",
|
||||
channel_patterns=["*HEAD*AC*"],
|
||||
x_cursors=(0.0, 0.1),
|
||||
)
|
||||
],
|
||||
default_cfc=1000,
|
||||
criteria=["hic15", "nij"],
|
||||
protocol="euro_ncap",
|
||||
)
|
||||
|
||||
yaml_str = template.to_yaml()
|
||||
restored = TemplateSpec.from_yaml(yaml_str)
|
||||
|
||||
assert restored.name == "Test Template"
|
||||
assert restored.version == 2
|
||||
assert restored.default_cfc == 1000
|
||||
assert len(restored.plots) == 1
|
||||
assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
|
||||
assert restored.criteria == ["hic15", "nij"]
|
||||
|
||||
def test_save_and_load(self, tmp_path):
|
||||
template = TemplateSpec(name="Saved Test", version=1)
|
||||
path = tmp_path / "test.yaml"
|
||||
template.save(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded = TemplateSpec.load(path)
|
||||
assert loaded.name == "Saved Test"
|
||||
|
||||
|
||||
class TestSessionState:
|
||||
def test_yaml_round_trip(self):
|
||||
state = SessionState(
|
||||
template_name="my_template",
|
||||
template_version=3,
|
||||
test_path="/data/test_001",
|
||||
overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
|
||||
)
|
||||
yaml_str = state.to_yaml()
|
||||
restored = SessionState.from_yaml(yaml_str)
|
||||
|
||||
assert restored.template_name == "my_template"
|
||||
assert restored.template_version == 3
|
||||
assert restored.overrides["cfc"] == "600"
|
||||
|
||||
def test_save_and_load(self, tmp_path):
|
||||
state = SessionState(template_name="test")
|
||||
path = tmp_path / "session.yaml"
|
||||
state.save(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded = SessionState.load(path)
|
||||
assert loaded.template_name == "test"
|
||||
|
||||
|
||||
class TestTemplateLibrary:
|
||||
def test_empty_library(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
assert lib.list() == []
|
||||
assert len(lib) == 0
|
||||
|
||||
def test_save_and_list(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
template = TemplateSpec(name="My Template")
|
||||
lib.save(template)
|
||||
assert "my_template" in lib.list()
|
||||
assert len(lib) == 1
|
||||
|
||||
def test_get(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
lib.save(TemplateSpec(name="Getter Test", version=5))
|
||||
loaded = lib.get("getter_test")
|
||||
assert loaded.name == "Getter Test"
|
||||
assert loaded.version == 5
|
||||
|
||||
def test_delete(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
lib.save(TemplateSpec(name="To Delete"))
|
||||
assert lib.delete("to_delete")
|
||||
assert "to_delete" not in lib.list()
|
||||
|
||||
def test_get_missing_raises(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
with pytest.raises(FileNotFoundError):
|
||||
lib.get("nonexistent")
|
||||
|
||||
|
||||
class TestSessionManager:
|
||||
def test_create_and_save(self, tmp_path):
|
||||
mgr = SessionManager(tmp_path)
|
||||
mgr.state.template_name = "test_tmpl"
|
||||
mgr.save()
|
||||
assert mgr.has_session
|
||||
assert (tmp_path / ".impakt" / "session.yaml").exists()
|
||||
|
||||
def test_load_existing(self, tmp_path):
|
||||
# Save
|
||||
mgr1 = SessionManager(tmp_path)
|
||||
mgr1.state.template_name = "saved_tmpl"
|
||||
mgr1.state.overrides = {"key": "value"}
|
||||
mgr1.save()
|
||||
|
||||
# Load
|
||||
mgr2 = SessionManager(tmp_path)
|
||||
assert mgr2.state.template_name == "saved_tmpl"
|
||||
assert mgr2.state.overrides["key"] == "value"
|
||||
|
||||
def test_clear(self, tmp_path):
|
||||
mgr = SessionManager(tmp_path)
|
||||
mgr.save()
|
||||
assert mgr.has_session
|
||||
mgr.clear()
|
||||
assert not mgr.has_session
|
||||
75
tests/test_transform/test_math_resultant.py
Normal file
75
tests/test_transform/test_math_resultant.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Tests for math expressions and resultant computation."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.transform.math_expr import math_expr
|
||||
from impakt.transform.resultant import resultant_from_channels
|
||||
from impakt.transform.resample import trim, resample
|
||||
|
||||
|
||||
class TestMathExpr:
|
||||
def test_simple_expression(self, head_accel_x, head_accel_z):
|
||||
result = math_expr(
|
||||
expression="sqrt(a**2 + b**2)",
|
||||
channels={"a": head_accel_x, "b": head_accel_z},
|
||||
name="resultant_xz",
|
||||
unit="g",
|
||||
)
|
||||
assert result.name == "resultant_xz"
|
||||
assert result.unit == "g"
|
||||
assert result.peak > 0
|
||||
assert len(result.data) == len(head_accel_x.data)
|
||||
|
||||
def test_constant_expression(self, head_accel_x):
|
||||
result = math_expr(
|
||||
expression="a * 0 + 42.0",
|
||||
channels={"a": head_accel_x},
|
||||
name="constant",
|
||||
)
|
||||
assert np.allclose(result.data, 42.0)
|
||||
|
||||
def test_invalid_expression(self, head_accel_x):
|
||||
with pytest.raises(ValueError, match="Error evaluating"):
|
||||
math_expr(
|
||||
expression="invalid_func(a)",
|
||||
channels={"a": head_accel_x},
|
||||
)
|
||||
|
||||
def test_forbidden_expression(self, head_accel_x):
|
||||
with pytest.raises(ValueError, match="Forbidden"):
|
||||
math_expr(
|
||||
expression="__import__('os')",
|
||||
channels={"a": head_accel_x},
|
||||
)
|
||||
|
||||
|
||||
class TestResultant:
|
||||
def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
|
||||
result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
|
||||
assert result.code.direction == "R"
|
||||
# Resultant >= any component
|
||||
assert result.peak >= head_accel_x.peak
|
||||
assert result.peak >= head_accel_y.peak
|
||||
|
||||
def test_from_two_channels(self, head_accel_x, head_accel_z):
|
||||
result = resultant_from_channels(head_accel_x, head_accel_z)
|
||||
assert result.peak > 0
|
||||
|
||||
def test_single_channel_raises(self):
|
||||
with pytest.raises(ValueError, match="At least one"):
|
||||
resultant_from_channels()
|
||||
|
||||
|
||||
class TestTrimResample:
|
||||
def test_trim(self, head_accel_x):
|
||||
trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
|
||||
assert trimmed.time[0] >= 0.0
|
||||
assert trimmed.time[-1] <= 0.05
|
||||
assert len(trimmed.data) < len(head_accel_x.data)
|
||||
|
||||
def test_resample(self, head_accel_x):
|
||||
resampled = resample(head_accel_x, target_rate=5000.0)
|
||||
expected_samples = int(head_accel_x.duration * 5000.0)
|
||||
assert abs(len(resampled.data) - expected_samples) <= 2
|
||||
assert resampled.sample_rate == 5000.0
|
||||
@@ -19,11 +19,11 @@ class TestAppState:
|
||||
|
||||
def test_load_test(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(FIXTURE_DATA)
|
||||
session = state.load_test(FIXTURE_DATA)
|
||||
assert not state.is_empty
|
||||
assert loaded.test_id == "IMPAKT_SYNTH_001"
|
||||
assert loaded.channel_count == 26
|
||||
assert state.primary_test is loaded
|
||||
assert session.test_id == "IMPAKT_SYNTH_001"
|
||||
assert len(session) == 26
|
||||
assert state.primary_test is session
|
||||
|
||||
def test_load_multiple_tests(self):
|
||||
state = AppState()
|
||||
@@ -32,13 +32,13 @@ class TestAppState:
|
||||
if (MME_DATA / "VW1FGS15").exists():
|
||||
t2 = state.load_test(MME_DATA / "VW1FGS15")
|
||||
assert len(state.tests) == 2
|
||||
assert state.primary_test is t1 # First loaded is primary
|
||||
assert state.primary_test is t1
|
||||
assert state.total_channels == 26 + 10
|
||||
|
||||
def test_remove_test(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(FIXTURE_DATA)
|
||||
state.remove_test(loaded.test_id)
|
||||
session = state.load_test(FIXTURE_DATA)
|
||||
state.remove_test(session.test_id)
|
||||
assert state.is_empty
|
||||
|
||||
def test_get_channel(self):
|
||||
@@ -54,38 +54,38 @@ class TestAppState:
|
||||
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
|
||||
assert ch is None
|
||||
|
||||
def test_resolve_channel_with_key(self):
|
||||
def test_get_channel_via_session(self):
|
||||
"""Channels can be accessed through the Session scripting API."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA")
|
||||
assert ch is not None
|
||||
session = state.primary_test
|
||||
assert session is not None
|
||||
ch_handle = session.channel("11HEAD0000ACXA")
|
||||
assert ch_handle.name == "11HEAD0000ACXA"
|
||||
|
||||
def test_resolve_channel_primary_default(self):
|
||||
def test_session_fluent_transforms(self):
|
||||
"""Fluent transform chaining works through the Session API."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("11HEAD0000ACXA")
|
||||
assert ch is not None
|
||||
session = state.primary_test
|
||||
ch = session.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(600).transform.y_align()
|
||||
assert filtered.raw.cfc_class == 600
|
||||
assert len(filtered.raw.transform_history) == 2
|
||||
|
||||
def test_resolve_channel_with_cfc(self):
|
||||
def test_session_compute_criteria(self):
|
||||
"""Session.compute_criteria() auto-detects channels."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600)
|
||||
assert ch is not None
|
||||
assert ch.cfc_class == 600
|
||||
|
||||
def test_flat_channel_list(self):
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
items = state.flat_channel_list()
|
||||
assert len(items) == 26
|
||||
assert all("value" in item and "label" in item for item in items)
|
||||
criteria = state.primary_test.compute_criteria()
|
||||
assert len(criteria) > 0
|
||||
assert "HIC15" in criteria or "Chest Deflection" in criteria
|
||||
|
||||
def test_build_channel_tree(self):
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
tree = state.build_channel_tree()
|
||||
assert "IMPAKT_SYNTH_001" in tree
|
||||
# Should have hierarchical structure
|
||||
test_tree = tree["IMPAKT_SYNTH_001"]
|
||||
assert len(test_tree) > 0
|
||||
|
||||
@@ -94,15 +94,23 @@ class TestAppState:
|
||||
class TestAppStateRealData:
|
||||
def test_load_real_mme(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(MME_DATA / "3239")
|
||||
assert loaded.test_id == "3239"
|
||||
assert loaded.channel_count == 133
|
||||
session = state.load_test(MME_DATA / "3239")
|
||||
assert session.test_id == "3239"
|
||||
assert len(session) == 133
|
||||
|
||||
def test_channel_tree_real_data(self):
|
||||
state = AppState()
|
||||
state.load_test(MME_DATA / "3239")
|
||||
tree = state.build_channel_tree()
|
||||
assert "3239" in tree
|
||||
# Should contain "Driver" in some key
|
||||
test_tree = tree["3239"]
|
||||
assert any("Driver" in k for k in test_tree)
|
||||
|
||||
def test_session_evaluate_real_data(self):
|
||||
"""Full pipeline through Session API on real data."""
|
||||
state = AppState()
|
||||
state.load_test(MME_DATA / "3239")
|
||||
result = state.primary_test.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.stars >= 0
|
||||
assert len(result.region_scores) > 0
|
||||
|
||||
Reference in New Issue
Block a user