bookmark - Refactor

This commit is contained in:
2026-04-10 17:28:29 -04:00
parent 48ab8c28b9
commit 68422dd304
29 changed files with 1371 additions and 756 deletions

29
src/impakt/io/csv.py Normal file
View File

@@ -0,0 +1,29 @@
"""CSV data reader (stub).
Placeholder for a future CSV reader plugin.
Install or implement to enable reading CSV crash test data.
"""
from __future__ import annotations
from pathlib import Path
from impakt.channel.model import TestData, TestMetadata
class CSVReader:
"""Reader for CSV time-series data (not yet implemented)."""
@property
def format_name(self) -> str:
return "CSV"
def supports(self, path: Path) -> bool:
# Only claim support for .csv files with crash test structure
return False # Disabled until implemented
def metadata(self, path: Path) -> TestMetadata:
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
def read(self, path: Path) -> TestData:
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")

37
src/impakt/io/tdms.py Normal file
View File

@@ -0,0 +1,37 @@
"""NI TDMS data reader (stub).
Placeholder for a future TDMS reader plugin.
Install nptdms to enable: pip install impakt[tdms]
"""
from __future__ import annotations
from pathlib import Path
from impakt.channel.model import TestData, TestMetadata
class TDMSReader:
"""Reader for NI TDMS files (not yet implemented).
Requires the optional ``nptdms`` dependency::
pip install impakt[tdms]
"""
@property
def format_name(self) -> str:
return "NI TDMS"
def supports(self, path: Path) -> bool:
return False # Disabled until implemented
def metadata(self, path: Path) -> TestMetadata:
raise NotImplementedError(
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
)
def read(self, path: Path) -> TestData:
raise NotImplementedError(
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
)

View File

@@ -2,6 +2,9 @@
Renders PlotSpec objects into interactive Plotly figures with
support for corridors, dual X-cursors, and export.
The PlotEngine is the **single rendering path** — both the scripting
API and the web UI construct PlotSpec objects and delegate here.
"""
from __future__ import annotations
@@ -30,12 +33,39 @@ DEFAULT_COLORS = [
class PlotEngine:
"""Renders PlotSpec into Plotly figures."""
"""Renders PlotSpec into Plotly figures.
Supports two modes:
- **Default**: Full interactive plot with hover tooltips and legend.
- **Compact** (``spec.compact=True``): Web UI mode with no legend, no
hover tooltips, tight margins, smaller axis labels. Cursor tracking
is handled by external JS.
"""
def render(self, spec: PlotSpec) -> go.Figure:
"""Render a PlotSpec into an interactive Plotly figure."""
fig = go.Figure()
if not spec.channels:
fig.update_layout(
template="plotly_white",
annotations=[
{
"text": "Select channels to plot",
"xref": "paper",
"yref": "paper",
"x": 0.5,
"y": 0.5,
"showarrow": False,
"font": {"size": 14, "color": "#bbb"},
}
],
xaxis={"visible": False},
yaxis={"visible": False},
margin={"l": 40, "r": 20, "t": 30, "b": 40},
)
return fig
# Add corridor fills first (behind data traces)
for corridor in spec.corridors:
self._add_corridor(fig, corridor)
@@ -54,36 +84,74 @@ class PlotEngine:
color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
label = style.label or ch_ref.label
fig.add_trace(
go.Scatter(
x=ch.time,
y=ch.data,
mode="lines",
name=label,
line=dict(
trace_kwargs: dict[str, Any] = {
"x": ch.time,
"y": ch.data,
"mode": "lines",
"name": label,
"line": dict(
color=color,
width=style.line_width,
dash=style.line_dash,
),
opacity=style.opacity,
hovertemplate=f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>",
)
"opacity": style.opacity,
}
if not spec.compact:
trace_kwargs["hovertemplate"] = (
f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>"
)
fig.add_trace(go.Scatter(**trace_kwargs))
# Add cursor lines
if spec.x_cursors:
x1, x2 = spec.x_cursors
for x_val, label in [(x1, "x1"), (x2, "x2")]:
if spec.compact:
# Color-coded cursor lines for compact mode
fig.add_vline(
x=x1,
line_dash="dash",
line_color="rgba(220,53,69,0.6)",
line_width=1,
annotation_text=f"X1={x1:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(220,53,69,0.8)",
)
fig.add_vline(
x=x2,
line_dash="dash",
line_color="rgba(13,110,253,0.6)",
line_width=1,
annotation_text=f"X2={x2:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(13,110,253,0.8)",
)
else:
for x_val, lbl in [(x1, "x1"), (x2, "x2")]:
fig.add_vline(
x=x_val,
line_dash="dash",
line_color="gray",
line_width=1,
annotation_text=f"{label}={x_val:.6f}s",
annotation_text=f"{lbl}={x_val:.6f}s",
annotation_position="top",
)
# Layout
if spec.compact:
margin = spec.margin or {"l": 45, "r": 8, "t": 4, "b": 28}
fig.update_layout(
xaxis_title=dict(text=spec.x_label, font=dict(size=10, color="#999")),
yaxis_title=dict(text=spec.y_label, font=dict(size=10, color="#999")),
template="plotly_white",
hovermode=False,
showlegend=False,
margin=margin,
)
else:
margin = spec.margin or {"l": 60, "r": 20, "t": 40 if spec.title else 10, "b": 60}
hovermode = spec.hovermode if spec.hovermode is not None else "x unified"
fig.update_layout(
title=spec.title,
xaxis_title=spec.x_label,
@@ -92,7 +160,7 @@ class PlotEngine:
height=spec.height,
width=spec.width,
template="plotly_white",
hovermode="x unified",
hovermode=hovermode,
legend=dict(
orientation="h",
yanchor="bottom",
@@ -100,6 +168,7 @@ class PlotEngine:
xanchor="center",
x=0.5,
),
margin=margin,
)
if spec.show_grid:
@@ -117,7 +186,6 @@ class PlotEngine:
"""Add a corridor (tolerance band) to the figure."""
style = corridor.style
# Upper bound
fig.add_trace(
go.Scatter(
x=corridor.time,
@@ -128,8 +196,6 @@ class PlotEngine:
showlegend=False,
)
)
# Lower bound with fill to upper
fig.add_trace(
go.Scatter(
x=corridor.time,
@@ -144,16 +210,7 @@ class PlotEngine:
)
def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes:
"""Render to a static image.
Args:
spec: Plot specification.
format: Image format ('png', 'svg', 'pdf', 'jpeg').
scale: Resolution multiplier.
Returns:
Image bytes.
"""
"""Render to a static image."""
fig = self.render(spec)
return fig.to_image(format=format, scale=scale)
@@ -168,16 +225,7 @@ def cursor_values(
x1: float,
x2: float,
) -> CursorValues:
"""Compute interpolated values at two X-axis positions.
Args:
spec_or_channels: PlotSpec or list of Channels.
x1: First cursor position (time).
x2: Second cursor position (time).
Returns:
CursorValues with interpolated values for each channel.
"""
"""Compute interpolated values at two X-axis positions."""
channels: list[tuple[str, Channel]] = []
if isinstance(spec_or_channels, PlotSpec):

View File

@@ -168,3 +168,8 @@ class PlotSpec:
show_grid: bool = True
height: int = 500
width: int = 900
# Web UI compact mode: disables hover tooltips, removes legend,
# uses tight margins, and renders axis labels in a smaller font.
compact: bool = False
hovermode: str | bool = "x unified"
margin: dict[str, int] | None = None

View File

@@ -42,8 +42,15 @@ class PluginRegistry:
self._plugins: list[ImpaktPlugin] = []
def register_reader(self, reader: Any) -> None:
"""Register a data reader."""
"""Register a data reader.
Forwards to the IO reader registry so plugin readers are
discoverable by Session.open() and the web UI.
"""
self._readers.append(reader)
from impakt.io.reader import register_reader
register_reader(reader)
logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader))
def register_transform(self, name: str, transform_cls: type) -> None:

View File

@@ -9,6 +9,7 @@ Threshold values are versioned — this module supports multiple protocol years.
from __future__ import annotations
from pathlib import Path
from typing import Any
from impakt.criteria.base import CriterionResult
@@ -64,9 +65,44 @@ def _points_from_color(color: Color, max_points: float) -> float:
return max_points * color_fractions[color]
# Threshold sets by year
# Format: {criterion: (green, yellow, orange, brown, red, higher_is_worse, max_points)}
THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]] = {
# ---------------------------------------------------------------------------
# Threshold loading — from YAML files or hardcoded fallback
# ---------------------------------------------------------------------------
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
def _load_yaml_thresholds(
version: str,
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
"""Load thresholds from a YAML file for a given version."""
yaml_path = _THRESHOLDS_DIR / f"euro_ncap_{version}.yaml"
if not yaml_path.exists():
return {}
import yaml
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return {}
result: dict[str, tuple[float, float, float, float, float, bool, float]] = {}
for name, vals in data.items():
if isinstance(vals, dict):
result[name] = (
float(vals["green"]),
float(vals["yellow"]),
float(vals["orange"]),
float(vals["brown"]),
float(vals["red"]),
bool(vals.get("higher_is_worse", True)),
float(vals.get("max_points", 0)),
)
return result
# Hardcoded fallback (used if YAML files are missing)
_THRESHOLDS_2024_FALLBACK: dict[str, tuple[float, float, float, float, float, bool, float]] = {
"HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0),
"3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0),
"Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0),
@@ -77,9 +113,17 @@ THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]
"Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0),
}
THRESHOLDS: dict[str, dict[str, tuple[float, float, float, float, float, bool, float]]] = {
"2024": THRESHOLDS_2024,
}
def _get_thresholds(
version: str,
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
"""Get thresholds for a version, preferring YAML over hardcoded."""
thresholds = _load_yaml_thresholds(version)
if thresholds:
return thresholds
if version == "2024":
return _THRESHOLDS_2024_FALLBACK
return {}
class EuroNCAP:
@@ -87,11 +131,12 @@ class EuroNCAP:
def __init__(self, version: str = "2024") -> None:
self._version = version
if version not in THRESHOLDS:
self._thresholds = _get_thresholds(version)
if not self._thresholds:
raise ValueError(
f"Unknown Euro NCAP version: {version}. Available: {list(THRESHOLDS.keys())}"
f"No thresholds found for Euro NCAP version: {version}. "
f"Check {_THRESHOLDS_DIR} for YAML files."
)
self._thresholds = THRESHOLDS[version]
@property
def protocol_name(self) -> str:

View File

@@ -6,27 +6,55 @@ The overall rating is determined by the worst sub-rating.
from __future__ import annotations
from pathlib import Path
from typing import Any
from impakt.criteria.base import CriterionResult
from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
# Thresholds: (good_limit, acceptable_limit, marginal_limit)
# Values above marginal_limit = Poor
# higher_is_worse indicates that higher values are worse
IIHS_THRESHOLDS_2024: dict[str, tuple[float, float, float, bool]] = {
def _load_iihs_yaml(version: str) -> dict[str, tuple[float, float, float, bool]]:
"""Load IIHS thresholds from YAML."""
yaml_path = _THRESHOLDS_DIR / f"iihs_{version}.yaml"
if not yaml_path.exists():
return {}
import yaml
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return {}
result: dict[str, tuple[float, float, float, bool]] = {}
for name, vals in data.items():
if isinstance(vals, dict):
result[name] = (
float(vals["good"]),
float(vals["acceptable"]),
float(vals["marginal"]),
bool(vals.get("higher_is_worse", True)),
)
return result
# Hardcoded fallback
_IIHS_2024_FALLBACK: dict[str, tuple[float, float, float, bool]] = {
"HIC15": (250.0, 500.0, 700.0, True),
"Chest Deflection": (38.0, 50.0, 63.0, True), # mm
"Femur Load Left": (3.8, 6.2, 10.0, True), # kN
"Femur Load Right": (3.8, 6.2, 10.0, True), # kN
"Chest Deflection": (38.0, 50.0, 63.0, True),
"Femur Load Left": (3.8, 6.2, 10.0, True),
"Femur Load Right": (3.8, 6.2, 10.0, True),
"Nij": (0.52, 0.78, 1.0, True),
"Tibia Index": (0.5, 0.8, 1.3, True),
}
IIHS_THRESHOLDS: dict[str, dict[str, tuple[float, float, float, bool]]] = {
"2024": IIHS_THRESHOLDS_2024,
}
def _get_iihs_thresholds(version: str) -> dict[str, tuple[float, float, float, bool]]:
thresholds = _load_iihs_yaml(version)
if thresholds:
return thresholds
if version == "2024":
return _IIHS_2024_FALLBACK
return {}
def _rate_value(
@@ -70,11 +98,12 @@ class IIHS:
def __init__(self, version: str = "2024") -> None:
self._version = version
if version not in IIHS_THRESHOLDS:
self._thresholds = _get_iihs_thresholds(version)
if not self._thresholds:
raise ValueError(
f"Unknown IIHS version: {version}. Available: {list(IIHS_THRESHOLDS.keys())}"
f"No thresholds found for IIHS version: {version}. "
f"Check {_THRESHOLDS_DIR} for YAML files."
)
self._thresholds = IIHS_THRESHOLDS[version]
@property
def protocol_name(self) -> str:

View File

@@ -0,0 +1,76 @@
# Euro NCAP 2024 Adult Occupant Frontal Impact Thresholds
#
# Each criterion: [green, yellow, orange, brown, red, higher_is_worse, max_points]
# Sliding-scale: values at or below green = full points, at or above red = zero.
HIC15:
green: 500.0
yellow: 620.0
orange: 700.0
brown: 850.0
red: 1000.0
higher_is_worse: true
max_points: 4.0
3ms Clip:
green: 42.0
yellow: 48.0
orange: 54.0
brown: 57.0
red: 60.0
higher_is_worse: true
max_points: 4.0
Chest Deflection:
green: 22.0
yellow: 34.0
orange: 42.0
brown: 50.0
red: 63.0
higher_is_worse: true
max_points: 4.0
Nij:
green: 0.5
yellow: 0.65
orange: 0.8
brown: 0.9
red: 1.0
higher_is_worse: true
max_points: 2.0
Femur Load Left:
green: 3.8
yellow: 5.4
orange: 7.0
brown: 8.5
red: 10.0
higher_is_worse: true
max_points: 2.0
Femur Load Right:
green: 3.8
yellow: 5.4
orange: 7.0
brown: 8.5
red: 10.0
higher_is_worse: true
max_points: 2.0
Tibia Index:
green: 0.4
yellow: 0.7
orange: 1.0
brown: 1.15
red: 1.3
higher_is_worse: true
max_points: 2.0
Viscous Criterion:
green: 0.32
yellow: 0.56
orange: 0.8
brown: 0.9
red: 1.0
higher_is_worse: true
max_points: 2.0

View File

@@ -0,0 +1,40 @@
# IIHS 2024 Crashworthiness Thresholds
#
# Each criterion: [good, acceptable, marginal, higher_is_worse]
# Values above marginal = Poor.
HIC15:
good: 250.0
acceptable: 500.0
marginal: 700.0
higher_is_worse: true
Chest Deflection:
good: 38.0
acceptable: 50.0
marginal: 63.0
higher_is_worse: true
Femur Load Left:
good: 3.8
acceptable: 6.2
marginal: 10.0
higher_is_worse: true
Femur Load Right:
good: 3.8
acceptable: 6.2
marginal: 10.0
higher_is_worse: true
Nij:
good: 0.52
acceptable: 0.78
marginal: 1.0
higher_is_worse: true
Tibia Index:
good: 0.5
acceptable: 0.8
marginal: 1.3
higher_is_worse: true

View File

@@ -2,6 +2,23 @@
Provides the ``Session`` and ``Template`` classes that serve as the
primary entry points for both scripting and the web UI.
Usage::
from impakt import Session
test = Session.open("tests/mme_data/3239")
ch = test.channel("11HEAD0000H3ACXP")
# Fluent chaining — each transform returns a ChannelHandle
filtered = ch.transform.cfc(1000).transform.y_align()
# Compute all injury criteria (auto-detects channels)
criteria = test.compute_criteria()
# Score against a protocol
result = test.evaluate("euro_ncap")
result.to_pdf("report.pdf")
"""
from __future__ import annotations
@@ -36,12 +53,19 @@ class Session:
Wraps TestData with session state, template binding, and
convenience methods for transforms, criteria, and plotting.
Usage:
Usage::
test = Session.open("/path/to/test_001")
ch = test.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(1000)
filtered = ch.transform.cfc(1000) # returns ChannelHandle, chainable
criteria = test.compute_criteria() # auto-detect channels
result = test.evaluate("euro_ncap") # score against protocol
result.to_pdf("report.pdf")
"""
_plugins_discovered = False
def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None:
self._data = test_data
self._session_mgr = session_mgr or (
@@ -49,12 +73,29 @@ class Session:
)
self._template: TemplateSpec | None = None
@classmethod
def _discover_plugins(cls) -> None:
"""Discover and register plugins. Called once on first Session.open()."""
if cls._plugins_discovered:
return
cls._plugins_discovered = True
try:
from impakt.plugin.registry import discover_all
discover_all()
except Exception as e:
logger.debug("Plugin discovery failed (non-fatal): %s", e)
@classmethod
def open(cls, path: str | Path) -> Session:
"""Open a crash test from a path.
Auto-detects the format using the reader registry.
Discovers plugins on first call.
"""
# Discover plugins (idempotent — safe to call multiple times)
cls._discover_plugins()
path = Path(path).resolve()
registry = get_registry()
test_data = registry.read(path)
@@ -87,6 +128,11 @@ class Session:
"""Underlying TestData object."""
return self._data
@property
def path(self) -> Path | None:
"""Path to the test data directory."""
return self._data.path
@property
def channel_names(self) -> list[str]:
return self._data.channel_names
@@ -98,7 +144,10 @@ class Session:
# ----- Channel access -----
def channel(self, name: str) -> ChannelHandle:
"""Get a channel by name, wrapped with transform convenience methods."""
"""Get a channel by name, wrapped with transform convenience methods.
Returns a ChannelHandle with a fluent ``.transform`` interface.
"""
ch = self._data.get(name)
return ChannelHandle(ch)
@@ -118,6 +167,47 @@ class Session:
"""Hierarchical channel tree for UI display."""
return self._data.channel_tree()
# ----- Criteria & Protocol -----
def compute_criteria(self) -> dict[str, CriterionResult]:
"""Auto-detect channels and compute all applicable injury criteria.
Uses ISO channel naming to find the right channels for each criterion.
Returns a dict of criterion name -> CriterionResult.
"""
from impakt.web.components.criteria import auto_compute_criteria
return auto_compute_criteria(self._data)
def evaluate(self, protocol: str = "euro_ncap", version: str = "") -> ProtocolResult:
"""Compute criteria and score against a rating protocol.
Args:
protocol: Protocol name ('euro_ncap', 'us_ncap', 'iihs').
version: Protocol version (optional, uses latest if empty).
Returns:
ProtocolResult with stars/rating and per-region scores.
"""
criteria = self.compute_criteria()
if protocol == "euro_ncap":
from impakt.protocol.euro_ncap import evaluate
return evaluate(criteria, version=version or "2024")
elif protocol == "us_ncap":
from impakt.protocol.us_ncap import evaluate
return evaluate(criteria, version=version or "2023")
elif protocol == "iihs":
from impakt.protocol.iihs import evaluate
return evaluate(criteria, version=version or "2024")
else:
raise ValueError(
f"Unknown protocol: {protocol}. Use 'euro_ncap', 'us_ncap', or 'iihs'."
)
# ----- Template -----
def apply_template(self, name_or_spec: str | TemplateSpec) -> None:
@@ -188,15 +278,16 @@ class Session:
class ChannelHandle:
"""Wrapper around a Channel providing fluent transform access.
Example:
Each transform method on ``.transform`` returns a new ``ChannelHandle``,
enabling chaining::
ch = session.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(1000)
aligned = ch.transform.x_align(method="threshold", threshold_value=5.0)
result = ch.transform.cfc(1000).transform.y_align().transform.trim(t_end=0.1)
"""
def __init__(self, channel: Channel) -> None:
self._channel = channel
self.transform = TransformProxy(channel)
self.transform = TransformProxy(self)
@property
def raw(self) -> Channel:
@@ -207,6 +298,10 @@ class ChannelHandle:
def name(self) -> str:
return self._channel.name
@property
def code(self):
return self._channel.code
@property
def data(self) -> np.ndarray:
return self._channel.data
@@ -219,6 +314,14 @@ class ChannelHandle:
def unit(self) -> str:
return self._channel.unit
@property
def peak(self) -> float:
return self._channel.peak
@property
def sample_rate(self) -> float:
return self._channel.sample_rate
def value_at(self, t: float) -> float:
return self._channel.value_at(t)
@@ -238,44 +341,51 @@ class ChannelHandle:
class TransformProxy:
"""Fluent transform interface for a channel.
Each method returns a new Channel (non-destructive).
Each method returns a new ``ChannelHandle`` wrapping the transformed
channel, enabling chaining.
"""
def __init__(self, channel: Channel) -> None:
self._channel = channel
def __init__(self, handle: ChannelHandle) -> None:
self._handle = handle
def cfc(self, cfc_class: int) -> Channel:
@property
def _channel(self) -> Channel:
return self._handle._channel
def cfc(self, cfc_class: int) -> ChannelHandle:
"""Apply CFC filter."""
from impakt.transform.cfc import CFCFilter
return CFCFilter(cfc_class=cfc_class).apply(self._channel)
return ChannelHandle(CFCFilter(cfc_class=cfc_class).apply(self._channel))
def x_align(
self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any
) -> Channel:
) -> ChannelHandle:
"""Apply time-zero alignment."""
from impakt.transform.align import XAlign
return XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
return ChannelHandle(
XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
)
def y_align(self, window: tuple[float, float] | None = None) -> Channel:
def y_align(self, window: tuple[float, float] | None = None) -> ChannelHandle:
"""Apply Y-axis zero correction."""
from impakt.transform.align import YAlign
start, end = window if window else (None, None)
return YAlign(window_start=start, window_end=end).apply(self._channel)
return ChannelHandle(YAlign(window_start=start, window_end=end).apply(self._channel))
def trim(self, t_start: float | None = None, t_end: float | None = None) -> Channel:
def trim(self, t_start: float | None = None, t_end: float | None = None) -> ChannelHandle:
"""Trim to a time range."""
from impakt.transform.resample import Trim
return Trim(t_start=t_start, t_end=t_end).apply(self._channel)
return ChannelHandle(Trim(t_start=t_start, t_end=t_end).apply(self._channel))
def resample(self, target_rate: float) -> Channel:
def resample(self, target_rate: float) -> ChannelHandle:
"""Resample to a new rate."""
from impakt.transform.resample import Resample
return Resample(target_rate=target_rate).apply(self._channel)
return ChannelHandle(Resample(target_rate=target_rate).apply(self._channel))
class Template:

View File

@@ -1,7 +1,7 @@
"""Dash web application factory.
Creates the Dash app with all layout components and callbacks registered.
The AppState holds server-side data; Dash stores hold lightweight UI state.
The AppState holds Session objects server-side; Dash stores hold lightweight UI state.
"""
from __future__ import annotations
@@ -13,14 +13,12 @@ import dash
import dash_bootstrap_components as dbc
from impakt.channel.model import TestData
from impakt.script.api import Session
from impakt.template.library import TemplateLibrary
from impakt.web.callbacks import register_callbacks
from impakt.web.layout import build_layout
from impakt.web.state import AppState
if TYPE_CHECKING:
from impakt.script.api import Session
def create_app(
session_or_data: Session | TestData | None = None,
@@ -37,21 +35,15 @@ def create_app(
Returns:
Configured Dash app ready to run.
"""
# Build or use provided AppState
if app_state is None:
app_state = AppState()
if session_or_data is not None:
if hasattr(session_or_data, "data"):
test_data: TestData = session_or_data.data # type: ignore[union-attr]
else:
test_data = session_or_data # type: ignore[assignment]
# Create a LoadedTest from TestData directly
from impakt.web.state import LoadedTest
loaded = LoadedTest(test_data)
app_state._tests[test_data.test_id] = loaded
app_state._test_order.append(test_data.test_id)
if isinstance(session_or_data, Session):
app_state.add_session(session_or_data)
elif isinstance(session_or_data, TestData):
# Wrap raw TestData in a Session
session = Session(session_or_data)
app_state.add_session(session)
# Discover templates
if template_names is None:
@@ -64,14 +56,13 @@ def create_app(
# Title
title = "Impakt"
if app_state.primary_test:
title = f"Impakt {app_state.primary_test.test_id}"
title = f"Impakt \u2014 {app_state.primary_test.test_id}"
app = dash.Dash(
__name__,
external_stylesheets=[dbc.themes.FLATLY],
title=title,
suppress_callback_exceptions=True,
# Prevent browser from caching old layouts
serve_locally=True,
meta_tags=[
{"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"},
@@ -92,16 +83,9 @@ def serve(
port: int = 8050,
debug: bool = False,
) -> None:
"""Convenience function to create and run the web UI.
Args:
session_or_data: Session or TestData to visualize, or None for empty app.
template: Template name to pre-apply.
port: Server port.
debug: Enable Dash debug mode.
"""
"""Convenience function to create and run the web UI."""
if template and session_or_data and hasattr(session_or_data, "apply_template"):
session_or_data.apply_template(template) # type: ignore[union-attr]
session_or_data.apply_template(template)
app = create_app(session_or_data)
print(f"Impakt running at http://localhost:{port}")

View File

@@ -1,9 +1,7 @@
"""Criteria computation callbacks.
Handles:
- Compute All button -> runs auto_compute_criteria
- Protocol scoring
- Results display
Uses Session.compute_criteria() and Session.evaluate() from the
scripting API — the same path used by programmatic scripts.
"""
from __future__ import annotations
@@ -14,11 +12,7 @@ import dash
from dash import Input, Output, State, html
from dash.exceptions import PreventUpdate
from impakt.web.components.criteria import (
auto_compute_criteria,
build_criteria_results_display,
score_protocol,
)
from impakt.web.components.criteria import build_criteria_results_display
from impakt.web.state import AppState
@@ -41,8 +35,8 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
if primary is None:
return html.Div("No test loaded.", className="text-danger small")
# Compute criteria
criteria = auto_compute_criteria(primary.data)
# Use Session.compute_criteria() — single code path for script + web
criteria = primary.compute_criteria()
if not criteria:
return html.Div(
@@ -51,7 +45,10 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
className="text-warning small",
)
# Score against protocol
protocol_result = score_protocol(criteria, protocol)
# Score against protocol — use Session.evaluate()
try:
protocol_result = primary.evaluate(protocol)
except Exception:
protocol_result = None
return build_criteria_results_display(criteria, protocol_result)

View File

@@ -1,10 +1,11 @@
"""Plot rendering callbacks.
Handles:
- Updating plot figures when channels/transforms change
- Cursor line rendering (X1/X2 vertical lines)
- Hover data is NOT shown as a Plotly tooltip — instead the cursor grid
picks it up via the hoverData callback
Builds a PlotSpec from the UI state and delegates to PlotEngine.render().
This is the single rendering path — the same PlotEngine used by the
scripting API renders the web UI plots.
Transform application uses TransformChain, making the pipeline
serializable and reproducible.
"""
from __future__ import annotations
@@ -15,16 +16,55 @@ from typing import Any
import dash
import numpy as np
import plotly.graph_objects as go
from dash import ALL, Input, Output, State, ctx
from dash import Input, Output, State
from dash.exceptions import PreventUpdate
from impakt.channel.model import Channel
from impakt.plot.engine import DEFAULT_COLORS
from impakt.plot.engine import DEFAULT_COLORS, PlotEngine
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
from impakt.transform.align import XAlign, YAlign
from impakt.transform.base import TransformChain
from impakt.transform.cfc import CFCFilter
from impakt.transform.resultant import resultant_from_channels
from impakt.web.state import AppState
# Module-level engine instance (stateless, safe to reuse)
_engine = PlotEngine()
def _build_transform_chain(
cfc_value: str,
y_align: bool,
x_align_method: str,
x_align_value: float | None,
per_channel_cfc: str | None = None,
) -> TransformChain:
"""Build a TransformChain from the current UI state.
Per-channel CFC override takes precedence over the global CFC setting.
"""
chain = TransformChain()
# CFC filter
effective_cfc = per_channel_cfc if per_channel_cfc else cfc_value
if effective_cfc and effective_cfc != "none":
try:
chain = chain.append(CFCFilter(cfc_class=int(effective_cfc)))
except (ValueError, Exception):
pass
# Y-align
if y_align:
chain = chain.append(YAlign())
# X-align
if x_align_method == "manual" and x_align_value is not None:
chain = chain.append(XAlign(method="manual", reference_time=x_align_value))
elif x_align_method == "threshold" and x_align_value is not None:
chain = chain.append(XAlign(method="threshold", threshold_value=x_align_value))
return chain
def _resolve_channels(
selected_keys: list[str],
@@ -35,12 +75,15 @@ def _resolve_channels(
x_align_value: float | None,
show_resultant: bool,
) -> list[tuple[str, Channel]]:
"""Resolve selected channel keys to Channel objects with transforms applied.
"""Resolve selected channel keys to transformed Channel objects.
Per-channel overrides (from app_state.channel_overrides) take precedence
over the global CFC setting.
Uses TransformChain for each channel. Per-channel overrides from
app_state.channel_overrides take precedence over the global CFC.
Returns list of (label, transformed_channel) tuples.
"""
channels: list[tuple[str, Channel]] = []
multi_test = len(app_state.tests) > 1
for key in selected_keys:
if "::" in key:
@@ -55,29 +98,22 @@ def _resolve_channels(
if ch is None:
continue
# Determine CFC: per-channel override takes precedence over global
# Build per-channel transform chain
override = app_state.channel_overrides.get(key, {})
ch_cfc = override.get("cfc", "")
effective_cfc = ch_cfc if ch_cfc else cfc_value
per_ch_cfc = override.get("cfc", "")
chain = _build_transform_chain(
cfc_value,
y_align,
x_align_method,
x_align_value,
per_channel_cfc=per_ch_cfc,
)
if effective_cfc and effective_cfc != "none":
try:
ch = CFCFilter(cfc_class=int(effective_cfc)).apply(ch)
except (ValueError, Exception):
pass
# Apply Y-align
if y_align:
ch = YAlign().apply(ch)
# Apply X-align
if x_align_method == "manual" and x_align_value is not None:
ch = XAlign(method="manual", reference_time=x_align_value).apply(ch)
elif x_align_method == "threshold" and x_align_value is not None:
ch = XAlign(method="threshold", threshold_value=x_align_value).apply(ch)
# Apply the chain
if len(chain) > 0:
ch = chain.apply(ch)
# Build label
multi_test = len(app_state.tests) > 1
label = ch.code.short_label if ch.code.is_valid else ch.name
if multi_test:
label = f"[{test_id}] {label}"
@@ -99,7 +135,7 @@ def _resolve_channels(
res_label = (
f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant"
)
if len(app_state.tests) > 1:
if multi_test:
res_label = f"[{comps[0].source_test_id}] {res_label}"
channels.append((res_label, res))
except Exception:
@@ -108,123 +144,61 @@ def _resolve_channels(
return channels
def _build_figure(
def _build_plot_spec(
channels: list[tuple[str, Channel]],
cursor_x1: float | None,
cursor_x2: float | None,
cfc_value: str,
corridors: list[dict] | None = None,
) -> go.Figure:
"""Build a Plotly figure from resolved channels."""
fig = go.Figure()
) -> PlotSpec:
"""Build a PlotSpec from resolved channels and UI state.
if not channels:
fig.update_layout(
template="plotly_white",
annotations=[
{
"text": "Select channels from the grid to plot",
"xref": "paper",
"yref": "paper",
"x": 0.5,
"y": 0.5,
"showarrow": False,
"font": {"size": 14, "color": "#bbb"},
}
],
xaxis={"visible": False},
yaxis={"visible": False},
margin={"l": 40, "r": 20, "t": 30, "b": 40},
)
return fig
# Add traces — hovermode is disabled at the layout level; cursor tracking
# is handled by our own JS mousemove handler (cursor_tracker.js).
Channels are already transformed — they are wrapped in ChannelRef
objects with no additional transform chain (transforms were applied
during resolution).
"""
# Build ChannelRef objects
refs: list[ChannelRef] = []
for i, (label, ch) in enumerate(channels):
color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
fig.add_trace(
go.Scatter(
x=ch.time.tolist(),
y=ch.data.tolist(),
mode="lines",
name=label,
line=dict(color=color, width=1.5),
refs.append(
ChannelRef(
channel=ch,
style=PlotStyle(label=label, color=color),
)
)
# Add corridor fills
# Build Corridor objects from raw dicts
corridor_objs: list[Corridor] = []
if corridors:
for corridor in corridors:
if not corridor.get("visible", True):
for c in corridors:
if not c.get("visible", True):
continue
c_time = corridor["time"]
c_upper = corridor["upper"]
c_lower = corridor["lower"]
c_name = corridor.get("name", "Corridor")
# Upper bound
fig.add_trace(
go.Scatter(
x=c_time,
y=c_upper,
mode="lines",
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
showlegend=False,
name=f"{c_name} upper",
)
)
# Lower bound with fill to upper
fig.add_trace(
go.Scatter(
x=c_time,
y=c_lower,
mode="lines",
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
fill="tonexty",
fillcolor="rgba(100,100,255,0.1)",
showlegend=True,
name=c_name,
corridor_objs.append(
Corridor(
name=c.get("name", "Corridor"),
time=np.array(c["time"]),
lower=np.array(c["lower"]),
upper=np.array(c["upper"]),
style=CorridorStyle(),
)
)
# Add X1/X2 cursor lines
if cursor_x1 is not None:
fig.add_vline(
x=cursor_x1,
line_dash="dash",
line_color="rgba(220,53,69,0.6)",
line_width=1,
annotation_text=f"X1={cursor_x1:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(220,53,69,0.8)",
)
if cursor_x2 is not None:
fig.add_vline(
x=cursor_x2,
line_dash="dash",
line_color="rgba(13,110,253,0.6)",
line_width=1,
annotation_text=f"X2={cursor_x2:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(13,110,253,0.8)",
)
# Cursor positions
x_cursors = None
if cursor_x1 is not None and cursor_x2 is not None:
x_cursors = (cursor_x1, cursor_x2)
# Layout — hovermode is disabled; cursor tracking is handled entirely
# by our JS (cursor_tracker.js) which reads pixel positions from mousemove
# events and converts to data coordinates via Plotly's axis internals.
# Y-axis label from first channel
y_label = channels[0][1].unit if channels else ""
fig.update_layout(
xaxis_title=dict(text="Time (s)", font=dict(size=10, color="#999")),
yaxis_title=dict(text=y_label, font=dict(size=10, color="#999")),
template="plotly_white",
hovermode=False,
showlegend=False,
margin=dict(l=45, r=8, t=4, b=28),
return PlotSpec(
channels=refs,
corridors=corridor_objs,
x_cursors=x_cursors,
y_label=y_label,
compact=True, # Web UI always uses compact mode
)
return fig
def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
"""Register all plot-related callbacks."""
@@ -271,10 +245,5 @@ def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
show_resultant,
)
return _build_figure(
channels,
cursor_x1,
cursor_x2,
cfc_value,
corridors=corridors_data,
)
spec = _build_plot_spec(channels, cursor_x1, cursor_x2, corridors_data)
return _engine.render(spec)

View File

@@ -1,221 +0,0 @@
"""Collapsible channel tree component.
Renders channels in a hierarchical accordion:
Test Object > Body Region > Measurement Type > [channels]
Supports search filtering, select-all per group, and channel preview
(peak, unit, sample rate) on hover/expand.
"""
from __future__ import annotations
from typing import Any
import dash_bootstrap_components as dbc
from dash import dcc, html
from impakt.web.state import AppState
def _make_channel_item(ch_info: dict[str, str], show_test_prefix: bool = False) -> html.Div:
"""Build a single channel item with checkbox and preview info."""
label = ch_info["label"]
key = ch_info["key"]
peak = ch_info.get("peak", "")
unit = ch_info.get("unit", "")
preview = f" ({peak} {unit})" if peak else ""
return html.Div(
[
dbc.Checkbox(
id={"type": "channel-check", "index": key},
label=html.Span(
[
html.Span(label, style={"fontSize": "12px"}),
html.Span(
preview,
style={"fontSize": "10px", "color": "#999", "marginLeft": "4px"},
),
]
),
value=False,
style={"marginBottom": "1px"},
),
],
className="channel-item",
)
def _make_group_header(
group_label: str,
group_id: str,
channel_keys: list[str],
) -> html.Div:
"""Build a measurement group header with select-all button."""
return html.Div(
[
html.Span(group_label, style={"fontSize": "12px", "fontWeight": "500"}),
dbc.Button(
"All",
id={"type": "select-group-btn", "index": group_id},
color="link",
size="sm",
style={"fontSize": "10px", "padding": "0 4px", "textDecoration": "none"},
),
# Hidden store for this group's channel keys
dcc.Store(
id={"type": "group-channels", "index": group_id},
data=channel_keys,
),
],
className="d-flex align-items-center justify-content-between",
)
def build_channel_tree(app_state: AppState) -> dbc.Card:
"""Build the full channel tree panel with search and accordion."""
if app_state.is_empty:
return dbc.Card(
[
dbc.CardHeader("Channels", className="fw-bold py-2"),
dbc.CardBody(
html.Div("No test loaded", className="text-muted small text-center py-3"),
),
]
)
multi_test = len(app_state.tests) > 1
full_tree = app_state.build_channel_tree()
# Build accordion items
accordion_items = []
group_counter = 0
for test_id, test_tree in full_tree.items():
test_label = ""
loaded = app_state.get_test(test_id)
if loaded and multi_test:
test_label = f"[{test_id}] "
for obj_label, locations in test_tree.items():
# Top-level accordion: test object (Driver, Vehicle Structure, etc.)
obj_content = []
for loc_label, measurements in locations.items():
for meas_label, channels in measurements.items():
group_id = f"grp_{group_counter}"
group_counter += 1
channel_keys = [ch["key"] for ch in channels]
header_label = (
f"{loc_label}{meas_label}" if loc_label != meas_label else meas_label
)
obj_content.append(_make_group_header(header_label, group_id, channel_keys))
for ch in channels:
obj_content.append(_make_channel_item(ch, show_test_prefix=multi_test))
obj_content.append(html.Hr(style={"margin": "4px 0"}))
if obj_content:
# Remove trailing hr
if obj_content and isinstance(obj_content[-1], html.Hr):
obj_content = obj_content[:-1]
display_label = f"{test_label}{obj_label}" if test_label else obj_label
accordion_items.append(
dbc.AccordionItem(
html.Div(obj_content),
title=display_label,
style={"fontSize": "12px"},
)
)
return dbc.Card(
[
dbc.CardHeader(
[
html.Span("Channels", className="fw-bold"),
html.Span(
f" ({app_state.total_channels})",
style={"fontSize": "11px", "color": "#999"},
),
],
className="py-2",
),
dbc.CardBody(
[
# Search
dbc.Input(
id="channel-search",
placeholder="Search channels...",
type="text",
size="sm",
className="mb-2",
debounce=True,
),
# Selected channels display
html.Div(id="selected-channels-badges", className="mb-2"),
# Tree
html.Div(
dbc.Accordion(
accordion_items,
flush=True,
always_open=True,
start_collapsed=True,
id="channel-accordion",
),
style={"maxHeight": "450px", "overflowY": "auto"},
id="channel-tree-container",
),
],
className="py-2",
),
]
)
def build_selected_channels_badges(selected_keys: list[str], app_state: AppState) -> list:
"""Build badge pills showing currently selected channels."""
if not selected_keys:
return [
html.Span("No channels selected", className="text-muted", style={"fontSize": "11px"})
]
badges = []
for key in selected_keys[:20]: # Limit display to 20
# Extract label
if "::" in key:
test_id, ch_name = key.split("::", 1)
ch = app_state.get_channel(test_id, ch_name)
if ch and ch.code.is_valid:
label = ch.code.short_label
else:
label = ch_name
if len(app_state.tests) > 1:
label = f"{test_id}: {label}"
else:
label = key
badges.append(
dbc.Badge(
label,
id={"type": "channel-badge", "index": key},
color="primary",
pill=True,
className="me-1 mb-1",
style={"fontSize": "10px", "cursor": "pointer"},
)
)
if len(selected_keys) > 20:
badges.append(
html.Span(
f"+{len(selected_keys) - 20} more",
className="text-muted",
style={"fontSize": "10px"},
)
)
return badges

View File

@@ -129,7 +129,7 @@ def build_test_info_panel(app_state: AppState) -> dbc.Card:
f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "",
style={"fontSize": "12px"},
),
html.Td(str(loaded.channel_count), style={"fontSize": "12px"}),
html.Td(str(len(loaded)), style={"fontSize": "12px"}),
html.Td(
dbc.Button(
"x",

View File

@@ -1,14 +1,12 @@
"""Application state management.
AppState is the central data store for the web UI. It holds all loaded tests,
manages channel transforms, and provides the data that callbacks need to
render plots, compute criteria, and generate reports.
AppState is the central data store for the web UI. It holds Session
objects (from the scripting API) and delegates to them for channel
access, transforms, criteria, and protocol scoring.
AppState is NOT a Dash Store — it lives server-side in Python memory. The Dash
stores hold lightweight references (test IDs, selected channels, plot config)
AppState is NOT a Dash Store — it lives server-side in Python memory.
Dash stores hold lightweight references (test IDs, selected channels)
that callbacks use to look up data from AppState.
This design keeps large numpy arrays out of the browser and avoids serialization.
"""
from __future__ import annotations
@@ -17,220 +15,95 @@ import logging
from pathlib import Path
from typing import Any
from impakt.channel.model import Channel, ChannelGroup, TestData
from impakt.io.reader import get_registry
from impakt.channel.model import Channel, TestData
from impakt.script.api import Session
from impakt.template.library import TemplateLibrary
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
from impakt.template.session import SessionManager
from impakt.template.model import PlotDefinition, TemplateSpec
from impakt.transform.align import YAlign
from impakt.transform.cfc import CFCFilter
logger = logging.getLogger(__name__)
class LoadedTest:
"""A single loaded test with its metadata and channel access."""
def __init__(self, test_data: TestData, color_offset: int = 0) -> None:
self.data = test_data
self.color_offset = color_offset
@property
def test_id(self) -> str:
return self.data.test_id
@property
def label(self) -> str:
"""Short label for UI display."""
meta = self.data.metadata
parts = [meta.test_number]
if meta.vehicle.make:
parts.append(f"{meta.vehicle.make} {meta.vehicle.model}".strip())
return "".join(parts) if len(parts) > 1 else parts[0]
@property
def channel_count(self) -> int:
return len(self.data)
def __len__(self) -> int:
return len(self.data)
class AppState:
"""Server-side application state.
Manages multiple loaded tests and provides channel resolution
for the UI callbacks.
Manages multiple loaded test sessions and provides channel resolution
for the UI callbacks. All test data access goes through Session objects.
"""
def __init__(self) -> None:
self._tests: dict[str, LoadedTest] = {}
self._test_order: list[str] = []
self._color_counter: int = 0
self._sessions: dict[str, Session] = {}
self._session_order: list[str] = []
self._template_library = TemplateLibrary()
self._active_template: TemplateSpec | None = None
self._session_managers: dict[str, SessionManager] = {}
# Per-channel transform overrides: {channel_key: {"cfc": "600", "y_align": True}}
# Per-channel transform overrides: {channel_key: {"cfc": "600"}}
self.channel_overrides: dict[str, dict[str, Any]] = {}
# Active corridors: list of {name, time, lower, upper, visible}
# Active corridors
self.corridors: list[dict[str, Any]] = []
def load_test(self, path: str | Path) -> LoadedTest:
def load_test(self, path: str | Path) -> Session:
"""Load a test from a path and add it to the state.
If a test with the same ID is already loaded, replaces it.
Returns:
The loaded test.
Raises:
ValueError: If the path is not a valid test directory.
Returns the Session object.
"""
path = Path(path).resolve()
registry = get_registry()
test_data = registry.read(path)
session = Session.open(path)
test_id = session.test_id
loaded = LoadedTest(test_data, color_offset=self._color_counter)
self._color_counter += len(test_data)
self._sessions[test_id] = session
if test_id not in self._session_order:
self._session_order.append(test_id)
test_id = test_data.test_id
if test_id in self._tests:
self._tests[test_id] = loaded
else:
self._tests[test_id] = loaded
self._test_order.append(test_id)
logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(session))
return session
logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(test_data))
return loaded
def add_session(self, session: Session) -> None:
"""Add an already-created Session to the state."""
test_id = session.test_id
self._sessions[test_id] = session
if test_id not in self._session_order:
self._session_order.append(test_id)
def remove_test(self, test_id: str) -> None:
"""Remove a loaded test."""
if test_id in self._tests:
del self._tests[test_id]
self._test_order.remove(test_id)
self._sessions.pop(test_id, None)
if test_id in self._session_order:
self._session_order.remove(test_id)
@property
def tests(self) -> list[LoadedTest]:
"""All loaded tests in load order."""
return [self._tests[tid] for tid in self._test_order if tid in self._tests]
def sessions(self) -> list[Session]:
"""All loaded sessions in load order."""
return [self._sessions[tid] for tid in self._session_order if tid in self._sessions]
@property
def tests(self) -> list[Session]:
"""Alias for sessions (backward compat for components)."""
return self.sessions
@property
def test_ids(self) -> list[str]:
return list(self._test_order)
return list(self._session_order)
@property
def primary_test(self) -> LoadedTest | None:
"""The first loaded test (primary for criteria, etc.)."""
if self._test_order:
return self._tests.get(self._test_order[0])
def primary_test(self) -> Session | None:
"""The first loaded session (primary for criteria, etc.)."""
if self._session_order:
return self._sessions.get(self._session_order[0])
return None
def get_test(self, test_id: str) -> LoadedTest | None:
return self._tests.get(test_id)
def get_test(self, test_id: str) -> Session | None:
return self._sessions.get(test_id)
def get_channel(self, test_id: str, channel_name: str) -> Channel | None:
"""Resolve a channel from a test ID and channel name."""
test = self._tests.get(test_id)
if test is None:
session = self._sessions.get(test_id)
if session is None:
return None
try:
return test.data.get(channel_name)
return session.data.get(channel_name)
except KeyError:
return None
def resolve_channel(
self,
channel_key: str,
cfc_class: int | None = None,
y_align: bool = False,
) -> Channel | None:
"""Resolve a channel key (test_id::channel_name) and apply transforms.
Channel keys use the format 'test_id::channel_name'. If no '::'
separator, assumes the primary test.
"""
if "::" in channel_key:
test_id, ch_name = channel_key.split("::", 1)
else:
if self.primary_test is None:
return None
test_id = self.primary_test.test_id
ch_name = channel_key
ch = self.get_channel(test_id, ch_name)
if ch is None:
return None
# Apply transforms
if cfc_class is not None:
try:
ch = CFCFilter(cfc_class=cfc_class).apply(ch)
except (ValueError, Exception):
pass
if y_align:
ch = YAlign().apply(ch)
return ch
def build_channel_tree(
self,
) -> dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]]:
"""Build a hierarchical channel tree across all loaded tests.
Returns:
{test_id: {object_label: {location_label: {measurement_label: [{name, label, key}]}}}}
"""
result: dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]] = {}
for loaded in self.tests:
tree = loaded.data.channel_tree()
test_tree: dict[str, dict[str, dict[str, list[dict[str, str]]]]] = {}
for obj, locations in sorted(tree.items()):
obj_tree: dict[str, dict[str, list[dict[str, str]]]] = {}
for loc, measurements in sorted(locations.items()):
loc_tree: dict[str, list[dict[str, str]]] = {}
for meas, channels in sorted(measurements.items()):
ch_list = []
for ch in channels:
label = ch.code.short_label if ch.code.is_valid else ch.name
ch_list.append(
{
"name": ch.name,
"label": label,
"key": f"{loaded.test_id}::{ch.name}",
"unit": ch.unit,
"peak": f"{ch.peak:.2f}",
"samples": str(ch.n_samples),
"rate": f"{ch.sample_rate:.0f}",
}
)
loc_tree[meas] = ch_list
obj_tree[loc] = loc_tree
test_tree[obj] = obj_tree
result[loaded.test_id] = test_tree
return result
def flat_channel_list(self) -> list[dict[str, str]]:
"""Flat list of all channels across all tests, for dropdown/checklist options."""
items: list[dict[str, str]] = []
multi = len(self._tests) > 1
for loaded in self.tests:
prefix = f"[{loaded.test_id}] " if multi else ""
for ch in sorted(loaded.data, key=lambda c: c.name):
label = ch.code.short_label if ch.code.is_valid else ch.name
items.append(
{
"label": f"{prefix}{label}",
"value": f"{loaded.test_id}::{ch.name}",
}
)
return items
# ----- Template & Session -----
@property
@@ -245,52 +118,37 @@ class AppState:
def template_names(self) -> list[str]:
return self._template_library.list()
def apply_template(
self,
name: str,
selected_keys: list[str] | None = None,
) -> tuple[list[str], dict[str, str]]:
def apply_template(self, name: str) -> tuple[list[str], dict[str, str]]:
"""Apply a template by name.
Resolves the template's channel patterns against the primary test,
sets the active template, and returns the resolved channel keys and
transform settings.
Returns:
(selected_channel_keys, transform_settings)
Resolves channel patterns against the primary test.
Returns (selected_channel_keys, transform_settings).
"""
template = self._template_library.get(name)
self._active_template = template
# Persist to session if primary test has a path
primary = self.primary_test
if primary and primary.data.path:
mgr = self._get_session_manager(primary.data.path)
mgr.apply_template(template)
if primary:
primary.apply_template(template)
# Resolve channel patterns from the template
# Resolve channel patterns
resolved_keys: list[str] = []
if primary:
for plot_def in template.plots:
for pattern in plot_def.channel_patterns:
matches = primary.data.find(pattern)
matches = primary.find(pattern)
for ch in matches:
key = f"{primary.test_id}::{ch.name}"
if key not in resolved_keys:
resolved_keys.append(key)
# Transform settings
transforms: dict[str, str] = {}
if template.default_cfc is not None:
transforms["cfc"] = str(template.default_cfc)
else:
transforms["cfc"] = "none"
logger.info(
"Applied template '%s' — resolved %d channels",
name,
len(resolved_keys),
)
logger.info("Applied template '%s'%d channels", name, len(resolved_keys))
return resolved_keys, transforms
def save_as_template(
@@ -303,30 +161,19 @@ class AppState:
x2: float | None,
protocol: str = "",
) -> TemplateSpec:
"""Capture the current UI state as a new template.
Converts selected channels back to patterns and stores the
current filter/cursor/protocol settings.
"""
# Convert selected keys to channel patterns
"""Capture current UI state as a new template."""
patterns: list[str] = []
for key in selected_keys:
if "::" in key:
_, ch_name = key.split("::", 1)
else:
ch_name = key
# Use the raw channel name as a pattern (exact match)
ch_name = key.split("::", 1)[-1] if "::" in key else key
if ch_name not in patterns:
patterns.append(ch_name)
# Build plot definition
plot = PlotDefinition(
title=name,
channel_patterns=patterns,
x_cursors=(x1, x2) if x1 is not None and x2 is not None else None,
)
# Build transforms list
transforms: list[dict[str, Any]] = []
if cfc_value and cfc_value != "none":
transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)})
@@ -343,17 +190,18 @@ class AppState:
self._template_library.save(template)
self._active_template = template
logger.info("Saved template '%s' with %d channel patterns", name, len(patterns))
logger.info("Saved template '%s' with %d patterns", name, len(patterns))
return template
def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None:
"""Auto-save current state to the session for the primary test."""
primary = self.primary_test
if not primary or not primary.data.path:
if not primary or not primary.path:
return
mgr = self._get_session_manager(primary.data.path)
from impakt.template.session import SessionManager
mgr = SessionManager(primary.path)
mgr.state.overrides = {
"selected_channels": selected_keys,
"cfc": cfc_value,
@@ -362,12 +210,14 @@ class AppState:
mgr.save()
def load_session_state(self) -> dict[str, Any] | None:
"""Load saved session state for the primary test, if any."""
"""Load saved session state for the primary test."""
primary = self.primary_test
if not primary or not primary.data.path:
if not primary or not primary.path:
return None
mgr = self._get_session_manager(primary.data.path)
from impakt.template.session import SessionManager
mgr = SessionManager(primary.path)
if not mgr.has_session:
return None
@@ -377,21 +227,48 @@ class AppState:
"template": mgr.state.template_name,
}
def _get_session_manager(self, path: Path) -> SessionManager:
key = str(path)
if key not in self._session_managers:
self._session_managers[key] = SessionManager(path)
return self._session_managers[key]
# ----- Channel tree / grid helpers -----
def build_channel_tree(self) -> dict[str, Any]:
"""Build hierarchical channel tree across all loaded tests."""
result: dict[str, Any] = {}
for session in self.sessions:
tree = session.channel_tree()
test_tree: dict[str, Any] = {}
for obj, locations in sorted(tree.items()):
obj_tree: dict[str, Any] = {}
for loc, measurements in sorted(locations.items()):
loc_tree: dict[str, Any] = {}
for meas, channels in sorted(measurements.items()):
ch_list = []
for ch in channels:
label = ch.code.short_label if ch.code.is_valid else ch.name
ch_list.append(
{
"name": ch.name,
"label": label,
"key": f"{session.test_id}::{ch.name}",
"unit": ch.unit,
"peak": f"{ch.peak:.2f}",
"samples": str(ch.n_samples),
"rate": f"{ch.sample_rate:.0f}",
}
)
loc_tree[meas] = ch_list
obj_tree[loc] = loc_tree
test_tree[obj] = obj_tree
result[session.test_id] = test_tree
return result
@property
def is_empty(self) -> bool:
return len(self._tests) == 0
return len(self._sessions) == 0
@property
def total_channels(self) -> int:
return sum(t.channel_count for t in self.tests)
return sum(len(s) for s in self.sessions)
def __repr__(self) -> str:
test_info = ", ".join(f"{t.test_id}({t.channel_count}ch)" for t in self.tests)
test_info = ", ".join(f"{s.test_id}({len(s)}ch)" for s in self.sessions)
tmpl = f", template={self._active_template.name}" if self._active_template else ""
return f"AppState([{test_info}]{tmpl})"

View File

@@ -0,0 +1,93 @@
"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
import numpy as np
import pytest
from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
class TestChestDeflection:
def test_basic(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.criterion == "Chest Deflection"
assert result.unit == "mm"
assert 30.0 < result.value < 40.0 # ~35mm in fixture
def test_peak_time(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.time_of_peak is not None
assert 0.0 < result.time_of_peak < 0.1
class TestClip3ms:
def test_basic(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.criterion == "3ms Clip"
assert result.value > 0
def test_body_region(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.body_region == "Chest"
class TestFemurLoad:
def test_single_channel(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.criterion == "Femur Load Left"
assert result.unit == "kN"
assert result.value > 0
def test_peak_time(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.time_of_peak is not None
class TestTibiaIndex:
def test_with_all_components(self, time_array, sample_rate):
from impakt.channel.code import ChannelCode
from impakt.channel.model import Channel
t = time_array
fz = np.zeros_like(t)
mx = np.zeros_like(t)
my = np.zeros_like(t)
mask = (t >= 0.02) & (t <= 0.08)
fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
fz_ch = Channel(
name="TIBFZ",
code=ChannelCode.parse("TIBFZ"),
data=fz,
time=t,
unit="N",
sample_rate=sample_rate,
)
mx_ch = Channel(
name="TIBMX",
code=ChannelCode.parse("TIBMX"),
data=mx,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
my_ch = Channel(
name="TIBMY",
code=ChannelCode.parse("TIBMY"),
data=my,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
assert result.value > 0
class TestViscousCriterion:
def test_basic(self, chest_deflection_channel):
result = viscous_criterion(channel=chest_deflection_channel)
assert result.criterion == "Viscous Criterion"
assert result.unit == "m/s"
assert result.value >= 0

View File

@@ -0,0 +1,81 @@
"""Tests for plot engine."""
import numpy as np
import pytest
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
class TestPlotEngine:
def test_render_empty(self):
engine = PlotEngine()
spec = PlotSpec()
fig = engine.render(spec)
assert fig is not None
def test_render_with_channels(self, head_accel_x, head_accel_y):
engine = PlotEngine()
spec = PlotSpec(
channels=[
ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
],
y_label="g",
)
fig = engine.render(spec)
assert len(fig.data) == 2
def test_render_compact_mode(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
compact=True,
)
fig = engine.render(spec)
assert fig.layout.hovermode is False
assert fig.layout.showlegend is False
def test_render_with_corridors(self, head_accel_x):
engine = PlotEngine()
corridor = Corridor(
name="Test Corridor",
time=np.linspace(0, 0.1, 100),
lower=np.full(100, -50.0),
upper=np.full(100, 50.0),
)
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
corridors=[corridor],
)
fig = engine.render(spec)
# 1 data trace + 2 corridor traces (upper + lower)
assert len(fig.data) == 3
def test_render_with_cursors(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
)
fig = engine.render(spec)
# Should have vertical lines (as shapes)
assert len(fig.layout.shapes) >= 2
def test_compact_cursor_annotations(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
compact=True,
)
fig = engine.render(spec)
# Should have X1/X2 annotations
annotations = [
a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
]
assert len(annotations) == 2
def test_default_colors(self):
assert len(DEFAULT_COLORS) == 10
assert all(c.startswith("#") for c in DEFAULT_COLORS)

View File

@@ -0,0 +1,68 @@
"""Tests for IIHS and US NCAP scoring."""
import pytest
from impakt.criteria.base import CriterionResult
from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
@pytest.fixture
def sample_criteria():
return {
"HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
"Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
"Chest Deflection": CriterionResult(
criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
),
"Femur Load Left": CriterionResult(
criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
),
"Femur Load Right": CriterionResult(
criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
),
}
class TestIIHS:
def test_good_results(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert result.protocol == "IIHS"
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_region_scores(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert len(result.region_scores) > 0
for rs in result.region_scores:
assert rs.rating is not None
def test_poor_hic(self):
result = iihs_evaluate(
{
"HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
}
)
assert result.overall_rating == "POOR"
def test_summary(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
summary = result.summary()
assert "IIHS" in summary
class TestUSNCAP:
def test_basic(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert result.protocol == "US NCAP"
assert result.stars is not None
assert 1 <= result.stars <= 5
def test_injury_probabilities(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert "combined_injury_probability" in result.details
p = result.details["combined_injury_probability"]
assert 0.0 <= p <= 1.0
def test_region_scores(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert len(result.region_scores) > 0

126
tests/test_scripting_api.py Normal file
View File

@@ -0,0 +1,126 @@
"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
from pathlib import Path
import pytest
from impakt import Session, Template
FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
MME_DATA = Path(__file__).parent / "mme_data"
class TestSession:
def test_open(self):
s = Session.open(FIXTURE_DATA)
assert s.test_id == "IMPAKT_SYNTH_001"
assert len(s) == 26
def test_channel_access(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
assert ch.name == "11HEAD0000ACXA"
assert ch.peak > 0
def test_find(self):
s = Session.open(FIXTURE_DATA)
channels = s.find("*HEAD*AC*")
assert len(channels) == 3
def test_group(self):
s = Session.open(FIXTURE_DATA)
group = s.group("HEAD0000AC")
assert group.x is not None
def test_compute_criteria(self):
s = Session.open(FIXTURE_DATA)
criteria = s.compute_criteria()
assert len(criteria) > 0
def test_evaluate(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.protocol == "Euro NCAP"
def test_evaluate_us_ncap(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("us_ncap")
assert result.stars is not None
def test_evaluate_iihs(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("iihs")
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_evaluate_invalid_protocol(self):
s = Session.open(FIXTURE_DATA)
with pytest.raises(ValueError, match="Unknown protocol"):
s.evaluate("invalid")
def test_contains(self):
s = Session.open(FIXTURE_DATA)
assert "11HEAD0000ACXA" in s
assert "NONEXISTENT" not in s
class TestChannelHandleChaining:
"""The fluent API must support chaining — each transform returns ChannelHandle."""
def test_single_transform(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600)
assert type(filtered).__name__ == "ChannelHandle"
assert filtered.raw.cfc_class == 600
def test_double_chain(self):
s = Session.open(FIXTURE_DATA)
result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 2
def test_triple_chain(self):
s = Session.open(FIXTURE_DATA)
result = (
s.channel("11HEAD0000ACXA")
.transform.cfc(1000)
.transform.y_align()
.transform.trim(t_start=0.0, t_end=0.1)
)
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 3
def test_chain_preserves_data(self):
s = Session.open(FIXTURE_DATA)
original = s.channel("11HEAD0000ACXA")
original_peak = original.peak
filtered = original.transform.cfc(600)
# Original should be unchanged — peak should be the same
assert original.peak == original_peak
# Filtered should have different CFC and lower peak (smoothed)
assert filtered.raw.cfc_class == 600
assert filtered.peak <= original_peak
@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
class TestSessionRealData:
def test_open_real(self):
s = Session.open(MME_DATA / "3239")
assert s.test_id == "3239"
assert len(s) == 133
def test_full_pipeline(self):
s = Session.open(MME_DATA / "3239")
# Chain: get channel -> filter -> check
ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
assert ch.peak > 100 # Significant head acceleration
# Compute criteria
criteria = s.compute_criteria()
assert "HIC15" in criteria
# Evaluate
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0

132
tests/test_template.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for template model and session persistence."""
from pathlib import Path
import pytest
from impakt.template.library import TemplateLibrary
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
from impakt.template.session import SessionManager
class TestTemplateSpec:
def test_yaml_round_trip(self):
template = TemplateSpec(
name="Test Template",
version=2,
description="A test template",
plots=[
PlotDefinition(
title="Head Acceleration",
channel_patterns=["*HEAD*AC*"],
x_cursors=(0.0, 0.1),
)
],
default_cfc=1000,
criteria=["hic15", "nij"],
protocol="euro_ncap",
)
yaml_str = template.to_yaml()
restored = TemplateSpec.from_yaml(yaml_str)
assert restored.name == "Test Template"
assert restored.version == 2
assert restored.default_cfc == 1000
assert len(restored.plots) == 1
assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
assert restored.criteria == ["hic15", "nij"]
def test_save_and_load(self, tmp_path):
template = TemplateSpec(name="Saved Test", version=1)
path = tmp_path / "test.yaml"
template.save(path)
assert path.exists()
loaded = TemplateSpec.load(path)
assert loaded.name == "Saved Test"
class TestSessionState:
def test_yaml_round_trip(self):
state = SessionState(
template_name="my_template",
template_version=3,
test_path="/data/test_001",
overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
)
yaml_str = state.to_yaml()
restored = SessionState.from_yaml(yaml_str)
assert restored.template_name == "my_template"
assert restored.template_version == 3
assert restored.overrides["cfc"] == "600"
def test_save_and_load(self, tmp_path):
state = SessionState(template_name="test")
path = tmp_path / "session.yaml"
state.save(path)
assert path.exists()
loaded = SessionState.load(path)
assert loaded.template_name == "test"
class TestTemplateLibrary:
def test_empty_library(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
assert lib.list() == []
assert len(lib) == 0
def test_save_and_list(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
template = TemplateSpec(name="My Template")
lib.save(template)
assert "my_template" in lib.list()
assert len(lib) == 1
def test_get(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="Getter Test", version=5))
loaded = lib.get("getter_test")
assert loaded.name == "Getter Test"
assert loaded.version == 5
def test_delete(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="To Delete"))
assert lib.delete("to_delete")
assert "to_delete" not in lib.list()
def test_get_missing_raises(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
with pytest.raises(FileNotFoundError):
lib.get("nonexistent")
class TestSessionManager:
def test_create_and_save(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.state.template_name = "test_tmpl"
mgr.save()
assert mgr.has_session
assert (tmp_path / ".impakt" / "session.yaml").exists()
def test_load_existing(self, tmp_path):
# Save
mgr1 = SessionManager(tmp_path)
mgr1.state.template_name = "saved_tmpl"
mgr1.state.overrides = {"key": "value"}
mgr1.save()
# Load
mgr2 = SessionManager(tmp_path)
assert mgr2.state.template_name == "saved_tmpl"
assert mgr2.state.overrides["key"] == "value"
def test_clear(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.save()
assert mgr.has_session
mgr.clear()
assert not mgr.has_session

View File

@@ -0,0 +1,75 @@
"""Tests for math expressions and resultant computation."""
import numpy as np
import pytest
from impakt.transform.math_expr import math_expr
from impakt.transform.resultant import resultant_from_channels
from impakt.transform.resample import trim, resample
class TestMathExpr:
def test_simple_expression(self, head_accel_x, head_accel_z):
result = math_expr(
expression="sqrt(a**2 + b**2)",
channels={"a": head_accel_x, "b": head_accel_z},
name="resultant_xz",
unit="g",
)
assert result.name == "resultant_xz"
assert result.unit == "g"
assert result.peak > 0
assert len(result.data) == len(head_accel_x.data)
def test_constant_expression(self, head_accel_x):
result = math_expr(
expression="a * 0 + 42.0",
channels={"a": head_accel_x},
name="constant",
)
assert np.allclose(result.data, 42.0)
def test_invalid_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Error evaluating"):
math_expr(
expression="invalid_func(a)",
channels={"a": head_accel_x},
)
def test_forbidden_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Forbidden"):
math_expr(
expression="__import__('os')",
channels={"a": head_accel_x},
)
class TestResultant:
def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
assert result.code.direction == "R"
# Resultant >= any component
assert result.peak >= head_accel_x.peak
assert result.peak >= head_accel_y.peak
def test_from_two_channels(self, head_accel_x, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_z)
assert result.peak > 0
def test_single_channel_raises(self):
with pytest.raises(ValueError, match="At least one"):
resultant_from_channels()
class TestTrimResample:
def test_trim(self, head_accel_x):
trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
assert trimmed.time[0] >= 0.0
assert trimmed.time[-1] <= 0.05
assert len(trimmed.data) < len(head_accel_x.data)
def test_resample(self, head_accel_x):
resampled = resample(head_accel_x, target_rate=5000.0)
expected_samples = int(head_accel_x.duration * 5000.0)
assert abs(len(resampled.data) - expected_samples) <= 2
assert resampled.sample_rate == 5000.0

View File

@@ -19,11 +19,11 @@ class TestAppState:
def test_load_test(self):
state = AppState()
loaded = state.load_test(FIXTURE_DATA)
session = state.load_test(FIXTURE_DATA)
assert not state.is_empty
assert loaded.test_id == "IMPAKT_SYNTH_001"
assert loaded.channel_count == 26
assert state.primary_test is loaded
assert session.test_id == "IMPAKT_SYNTH_001"
assert len(session) == 26
assert state.primary_test is session
def test_load_multiple_tests(self):
state = AppState()
@@ -32,13 +32,13 @@ class TestAppState:
if (MME_DATA / "VW1FGS15").exists():
t2 = state.load_test(MME_DATA / "VW1FGS15")
assert len(state.tests) == 2
assert state.primary_test is t1 # First loaded is primary
assert state.primary_test is t1
assert state.total_channels == 26 + 10
def test_remove_test(self):
state = AppState()
loaded = state.load_test(FIXTURE_DATA)
state.remove_test(loaded.test_id)
session = state.load_test(FIXTURE_DATA)
state.remove_test(session.test_id)
assert state.is_empty
def test_get_channel(self):
@@ -54,38 +54,38 @@ class TestAppState:
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
assert ch is None
def test_resolve_channel_with_key(self):
def test_get_channel_via_session(self):
"""Channels can be accessed through the Session scripting API."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA")
assert ch is not None
session = state.primary_test
assert session is not None
ch_handle = session.channel("11HEAD0000ACXA")
assert ch_handle.name == "11HEAD0000ACXA"
def test_resolve_channel_primary_default(self):
def test_session_fluent_transforms(self):
"""Fluent transform chaining works through the Session API."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA")
assert ch is not None
session = state.primary_test
ch = session.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600).transform.y_align()
assert filtered.raw.cfc_class == 600
assert len(filtered.raw.transform_history) == 2
def test_resolve_channel_with_cfc(self):
def test_session_compute_criteria(self):
"""Session.compute_criteria() auto-detects channels."""
state = AppState()
state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600)
assert ch is not None
assert ch.cfc_class == 600
def test_flat_channel_list(self):
state = AppState()
state.load_test(FIXTURE_DATA)
items = state.flat_channel_list()
assert len(items) == 26
assert all("value" in item and "label" in item for item in items)
criteria = state.primary_test.compute_criteria()
assert len(criteria) > 0
assert "HIC15" in criteria or "Chest Deflection" in criteria
def test_build_channel_tree(self):
state = AppState()
state.load_test(FIXTURE_DATA)
tree = state.build_channel_tree()
assert "IMPAKT_SYNTH_001" in tree
# Should have hierarchical structure
test_tree = tree["IMPAKT_SYNTH_001"]
assert len(test_tree) > 0
@@ -94,15 +94,23 @@ class TestAppState:
class TestAppStateRealData:
def test_load_real_mme(self):
state = AppState()
loaded = state.load_test(MME_DATA / "3239")
assert loaded.test_id == "3239"
assert loaded.channel_count == 133
session = state.load_test(MME_DATA / "3239")
assert session.test_id == "3239"
assert len(session) == 133
def test_channel_tree_real_data(self):
state = AppState()
state.load_test(MME_DATA / "3239")
tree = state.build_channel_tree()
assert "3239" in tree
# Should contain "Driver" in some key
test_tree = tree["3239"]
assert any("Driver" in k for k in test_tree)
def test_session_evaluate_real_data(self):
"""Full pipeline through Session API on real data."""
state = AppState()
state.load_test(MME_DATA / "3239")
result = state.primary_test.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0
assert len(result.region_scores) > 0