diff --git a/src/impakt/io/csv.py b/src/impakt/io/csv.py
new file mode 100644
index 0000000..4dd5e8f
--- /dev/null
+++ b/src/impakt/io/csv.py
@@ -0,0 +1,29 @@
+"""CSV data reader (stub).
+
+Placeholder for a future CSV reader plugin.
+Install or implement to enable reading CSV crash test data.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from impakt.channel.model import TestData, TestMetadata
+
+
+class CSVReader:
+ """Reader for CSV time-series data (not yet implemented)."""
+
+ @property
+ def format_name(self) -> str:
+ return "CSV"
+
+ def supports(self, path: Path) -> bool:
+ # Only claim support for .csv files with crash test structure
+ return False # Disabled until implemented
+
+ def metadata(self, path: Path) -> TestMetadata:
+ raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
+
+ def read(self, path: Path) -> TestData:
+ raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
diff --git a/src/impakt/io/tdms.py b/src/impakt/io/tdms.py
new file mode 100644
index 0000000..e0ce04d
--- /dev/null
+++ b/src/impakt/io/tdms.py
@@ -0,0 +1,37 @@
+"""NI TDMS data reader (stub).
+
+Placeholder for a future TDMS reader plugin.
+Install nptdms to enable: pip install impakt[tdms]
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from impakt.channel.model import TestData, TestMetadata
+
+
+class TDMSReader:
+ """Reader for NI TDMS files (not yet implemented).
+
+ Requires the optional ``nptdms`` dependency::
+
+ pip install impakt[tdms]
+ """
+
+ @property
+ def format_name(self) -> str:
+ return "NI TDMS"
+
+ def supports(self, path: Path) -> bool:
+ return False # Disabled until implemented
+
+ def metadata(self, path: Path) -> TestMetadata:
+ raise NotImplementedError(
+ "TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
+ )
+
+ def read(self, path: Path) -> TestData:
+ raise NotImplementedError(
+ "TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
+ )
diff --git a/src/impakt/plot/__pycache__/engine.cpython-312.pyc b/src/impakt/plot/__pycache__/engine.cpython-312.pyc
index 4644c3e..ca36e3e 100644
Binary files a/src/impakt/plot/__pycache__/engine.cpython-312.pyc and b/src/impakt/plot/__pycache__/engine.cpython-312.pyc differ
diff --git a/src/impakt/plot/__pycache__/spec.cpython-312.pyc b/src/impakt/plot/__pycache__/spec.cpython-312.pyc
index f694a6f..49848d9 100644
Binary files a/src/impakt/plot/__pycache__/spec.cpython-312.pyc and b/src/impakt/plot/__pycache__/spec.cpython-312.pyc differ
diff --git a/src/impakt/plot/engine.py b/src/impakt/plot/engine.py
index c8fdbbd..aa6f555 100644
--- a/src/impakt/plot/engine.py
+++ b/src/impakt/plot/engine.py
@@ -2,6 +2,9 @@
Renders PlotSpec objects into interactive Plotly figures with
support for corridors, dual X-cursors, and export.
+
+The PlotEngine is the **single rendering path** — both the scripting
+API and the web UI construct PlotSpec objects and delegate here.
"""
from __future__ import annotations
@@ -30,12 +33,39 @@ DEFAULT_COLORS = [
class PlotEngine:
- """Renders PlotSpec into Plotly figures."""
+ """Renders PlotSpec into Plotly figures.
+
+ Supports two modes:
+ - **Default**: Full interactive plot with hover tooltips and legend.
+ - **Compact** (``spec.compact=True``): Web UI mode with no legend, no
+ hover tooltips, tight margins, smaller axis labels. Cursor tracking
+ is handled by external JS.
+ """
def render(self, spec: PlotSpec) -> go.Figure:
"""Render a PlotSpec into an interactive Plotly figure."""
fig = go.Figure()
+ if not spec.channels:
+ fig.update_layout(
+ template="plotly_white",
+ annotations=[
+ {
+ "text": "Select channels to plot",
+ "xref": "paper",
+ "yref": "paper",
+ "x": 0.5,
+ "y": 0.5,
+ "showarrow": False,
+ "font": {"size": 14, "color": "#bbb"},
+ }
+ ],
+ xaxis={"visible": False},
+ yaxis={"visible": False},
+ margin={"l": 40, "r": 20, "t": 30, "b": 40},
+ )
+ return fig
+
# Add corridor fills first (behind data traces)
for corridor in spec.corridors:
self._add_corridor(fig, corridor)
@@ -54,53 +84,92 @@ class PlotEngine:
color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
label = style.label or ch_ref.label
- fig.add_trace(
- go.Scatter(
- x=ch.time,
- y=ch.data,
- mode="lines",
- name=label,
- line=dict(
- color=color,
- width=style.line_width,
- dash=style.line_dash,
- ),
- opacity=style.opacity,
- hovertemplate=f"{label}
t=%{{x:.6f}}s
%{{y:.4f}} {ch.unit}",
+ trace_kwargs: dict[str, Any] = {
+ "x": ch.time,
+ "y": ch.data,
+ "mode": "lines",
+ "name": label,
+ "line": dict(
+ color=color,
+ width=style.line_width,
+ dash=style.line_dash,
+ ),
+ "opacity": style.opacity,
+ }
+
+ if not spec.compact:
+ trace_kwargs["hovertemplate"] = (
+ f"{label}
t=%{{x:.6f}}s
%{{y:.4f}} {ch.unit}"
)
- )
+
+ fig.add_trace(go.Scatter(**trace_kwargs))
# Add cursor lines
if spec.x_cursors:
x1, x2 = spec.x_cursors
- for x_val, label in [(x1, "x1"), (x2, "x2")]:
+ if spec.compact:
+ # Color-coded cursor lines for compact mode
fig.add_vline(
- x=x_val,
+ x=x1,
line_dash="dash",
- line_color="gray",
+ line_color="rgba(220,53,69,0.6)",
line_width=1,
- annotation_text=f"{label}={x_val:.6f}s",
- annotation_position="top",
+ annotation_text=f"X1={x1:.4f}s",
+ annotation_font_size=9,
+ annotation_font_color="rgba(220,53,69,0.8)",
)
+ fig.add_vline(
+ x=x2,
+ line_dash="dash",
+ line_color="rgba(13,110,253,0.6)",
+ line_width=1,
+ annotation_text=f"X2={x2:.4f}s",
+ annotation_font_size=9,
+ annotation_font_color="rgba(13,110,253,0.8)",
+ )
+ else:
+ for x_val, lbl in [(x1, "x1"), (x2, "x2")]:
+ fig.add_vline(
+ x=x_val,
+ line_dash="dash",
+ line_color="gray",
+ line_width=1,
+ annotation_text=f"{lbl}={x_val:.6f}s",
+ annotation_position="top",
+ )
# Layout
- fig.update_layout(
- title=spec.title,
- xaxis_title=spec.x_label,
- yaxis_title=spec.y_label,
- showlegend=spec.show_legend,
- height=spec.height,
- width=spec.width,
- template="plotly_white",
- hovermode="x unified",
- legend=dict(
- orientation="h",
- yanchor="bottom",
- y=-0.3,
- xanchor="center",
- x=0.5,
- ),
- )
+ if spec.compact:
+ margin = spec.margin or {"l": 45, "r": 8, "t": 4, "b": 28}
+ fig.update_layout(
+ xaxis_title=dict(text=spec.x_label, font=dict(size=10, color="#999")),
+ yaxis_title=dict(text=spec.y_label, font=dict(size=10, color="#999")),
+ template="plotly_white",
+ hovermode=False,
+ showlegend=False,
+ margin=margin,
+ )
+ else:
+ margin = spec.margin or {"l": 60, "r": 20, "t": 40 if spec.title else 10, "b": 60}
+ hovermode = spec.hovermode if spec.hovermode is not None else "x unified"
+ fig.update_layout(
+ title=spec.title,
+ xaxis_title=spec.x_label,
+ yaxis_title=spec.y_label,
+ showlegend=spec.show_legend,
+ height=spec.height,
+ width=spec.width,
+ template="plotly_white",
+ hovermode=hovermode,
+ legend=dict(
+ orientation="h",
+ yanchor="bottom",
+ y=-0.3,
+ xanchor="center",
+ x=0.5,
+ ),
+ margin=margin,
+ )
if spec.show_grid:
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)")
@@ -117,7 +186,6 @@ class PlotEngine:
"""Add a corridor (tolerance band) to the figure."""
style = corridor.style
- # Upper bound
fig.add_trace(
go.Scatter(
x=corridor.time,
@@ -128,8 +196,6 @@ class PlotEngine:
showlegend=False,
)
)
-
- # Lower bound with fill to upper
fig.add_trace(
go.Scatter(
x=corridor.time,
@@ -144,16 +210,7 @@ class PlotEngine:
)
def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes:
- """Render to a static image.
-
- Args:
- spec: Plot specification.
- format: Image format ('png', 'svg', 'pdf', 'jpeg').
- scale: Resolution multiplier.
-
- Returns:
- Image bytes.
- """
+ """Render to a static image."""
fig = self.render(spec)
return fig.to_image(format=format, scale=scale)
@@ -168,16 +225,7 @@ def cursor_values(
x1: float,
x2: float,
) -> CursorValues:
- """Compute interpolated values at two X-axis positions.
-
- Args:
- spec_or_channels: PlotSpec or list of Channels.
- x1: First cursor position (time).
- x2: Second cursor position (time).
-
- Returns:
- CursorValues with interpolated values for each channel.
- """
+ """Compute interpolated values at two X-axis positions."""
channels: list[tuple[str, Channel]] = []
if isinstance(spec_or_channels, PlotSpec):
diff --git a/src/impakt/plot/spec.py b/src/impakt/plot/spec.py
index f8772c1..abb7039 100644
--- a/src/impakt/plot/spec.py
+++ b/src/impakt/plot/spec.py
@@ -168,3 +168,8 @@ class PlotSpec:
show_grid: bool = True
height: int = 500
width: int = 900
+ # Web UI compact mode: disables hover tooltips, removes legend,
+ # uses tight margins, and renders axis labels in a smaller font.
+ compact: bool = False
+ hovermode: str | bool = "x unified"
+ margin: dict[str, int] | None = None
diff --git a/src/impakt/plugin/registry.py b/src/impakt/plugin/registry.py
index 72f2115..d73da9b 100644
--- a/src/impakt/plugin/registry.py
+++ b/src/impakt/plugin/registry.py
@@ -42,8 +42,15 @@ class PluginRegistry:
self._plugins: list[ImpaktPlugin] = []
def register_reader(self, reader: Any) -> None:
- """Register a data reader."""
+ """Register a data reader.
+
+ Forwards to the IO reader registry so plugin readers are
+ discoverable by Session.open() and the web UI.
+ """
self._readers.append(reader)
+ from impakt.io.reader import register_reader
+
+ register_reader(reader)
logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader))
def register_transform(self, name: str, transform_cls: type) -> None:
diff --git a/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc b/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc
index 4873670..18deb8c 100644
Binary files a/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc and b/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc differ
diff --git a/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc b/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc
index fe686f9..d62f7c9 100644
Binary files a/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc and b/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc differ
diff --git a/src/impakt/protocol/euro_ncap.py b/src/impakt/protocol/euro_ncap.py
index 19dd79c..265504b 100644
--- a/src/impakt/protocol/euro_ncap.py
+++ b/src/impakt/protocol/euro_ncap.py
@@ -9,6 +9,7 @@ Threshold values are versioned — this module supports multiple protocol years.
from __future__ import annotations
+from pathlib import Path
from typing import Any
from impakt.criteria.base import CriterionResult
@@ -64,9 +65,44 @@ def _points_from_color(color: Color, max_points: float) -> float:
return max_points * color_fractions[color]
-# Threshold sets by year
-# Format: {criterion: (green, yellow, orange, brown, red, higher_is_worse, max_points)}
-THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]] = {
+# ---------------------------------------------------------------------------
+# Threshold loading — from YAML files or hardcoded fallback
+# ---------------------------------------------------------------------------
+
+_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
+
+
+def _load_yaml_thresholds(
+ version: str,
+) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
+ """Load thresholds from a YAML file for a given version."""
+ yaml_path = _THRESHOLDS_DIR / f"euro_ncap_{version}.yaml"
+ if not yaml_path.exists():
+ return {}
+
+ import yaml
+
+ data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
+ if not isinstance(data, dict):
+ return {}
+
+ result: dict[str, tuple[float, float, float, float, float, bool, float]] = {}
+ for name, vals in data.items():
+ if isinstance(vals, dict):
+ result[name] = (
+ float(vals["green"]),
+ float(vals["yellow"]),
+ float(vals["orange"]),
+ float(vals["brown"]),
+ float(vals["red"]),
+ bool(vals.get("higher_is_worse", True)),
+ float(vals.get("max_points", 0)),
+ )
+ return result
+
+
+# Hardcoded fallback (used if YAML files are missing)
+_THRESHOLDS_2024_FALLBACK: dict[str, tuple[float, float, float, float, float, bool, float]] = {
"HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0),
"3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0),
"Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0),
@@ -77,9 +113,17 @@ THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]
"Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0),
}
-THRESHOLDS: dict[str, dict[str, tuple[float, float, float, float, float, bool, float]]] = {
- "2024": THRESHOLDS_2024,
-}
+
+def _get_thresholds(
+ version: str,
+) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
+ """Get thresholds for a version, preferring YAML over hardcoded."""
+ thresholds = _load_yaml_thresholds(version)
+ if thresholds:
+ return thresholds
+ if version == "2024":
+ return _THRESHOLDS_2024_FALLBACK
+ return {}
class EuroNCAP:
@@ -87,11 +131,12 @@ class EuroNCAP:
def __init__(self, version: str = "2024") -> None:
self._version = version
- if version not in THRESHOLDS:
+ self._thresholds = _get_thresholds(version)
+ if not self._thresholds:
raise ValueError(
- f"Unknown Euro NCAP version: {version}. Available: {list(THRESHOLDS.keys())}"
+ f"No thresholds found for Euro NCAP version: {version}. "
+ f"Check {_THRESHOLDS_DIR} for YAML files."
)
- self._thresholds = THRESHOLDS[version]
@property
def protocol_name(self) -> str:
diff --git a/src/impakt/protocol/iihs.py b/src/impakt/protocol/iihs.py
index 0ce26ec..48b33ab 100644
--- a/src/impakt/protocol/iihs.py
+++ b/src/impakt/protocol/iihs.py
@@ -6,27 +6,55 @@ The overall rating is determined by the worst sub-rating.
from __future__ import annotations
+from pathlib import Path
from typing import Any
from impakt.criteria.base import CriterionResult
from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating
+_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
-# Thresholds: (good_limit, acceptable_limit, marginal_limit)
-# Values above marginal_limit = Poor
-# higher_is_worse indicates that higher values are worse
-IIHS_THRESHOLDS_2024: dict[str, tuple[float, float, float, bool]] = {
+
+def _load_iihs_yaml(version: str) -> dict[str, tuple[float, float, float, bool]]:
+ """Load IIHS thresholds from YAML."""
+ yaml_path = _THRESHOLDS_DIR / f"iihs_{version}.yaml"
+ if not yaml_path.exists():
+ return {}
+ import yaml
+
+ data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
+ if not isinstance(data, dict):
+ return {}
+ result: dict[str, tuple[float, float, float, bool]] = {}
+ for name, vals in data.items():
+ if isinstance(vals, dict):
+ result[name] = (
+ float(vals["good"]),
+ float(vals["acceptable"]),
+ float(vals["marginal"]),
+ bool(vals.get("higher_is_worse", True)),
+ )
+ return result
+
+
+# Hardcoded fallback
+_IIHS_2024_FALLBACK: dict[str, tuple[float, float, float, bool]] = {
"HIC15": (250.0, 500.0, 700.0, True),
- "Chest Deflection": (38.0, 50.0, 63.0, True), # mm
- "Femur Load Left": (3.8, 6.2, 10.0, True), # kN
- "Femur Load Right": (3.8, 6.2, 10.0, True), # kN
+ "Chest Deflection": (38.0, 50.0, 63.0, True),
+ "Femur Load Left": (3.8, 6.2, 10.0, True),
+ "Femur Load Right": (3.8, 6.2, 10.0, True),
"Nij": (0.52, 0.78, 1.0, True),
"Tibia Index": (0.5, 0.8, 1.3, True),
}
-IIHS_THRESHOLDS: dict[str, dict[str, tuple[float, float, float, bool]]] = {
- "2024": IIHS_THRESHOLDS_2024,
-}
+
+def _get_iihs_thresholds(version: str) -> dict[str, tuple[float, float, float, bool]]:
+ thresholds = _load_iihs_yaml(version)
+ if thresholds:
+ return thresholds
+ if version == "2024":
+ return _IIHS_2024_FALLBACK
+ return {}
def _rate_value(
@@ -70,11 +98,12 @@ class IIHS:
def __init__(self, version: str = "2024") -> None:
self._version = version
- if version not in IIHS_THRESHOLDS:
+ self._thresholds = _get_iihs_thresholds(version)
+ if not self._thresholds:
raise ValueError(
- f"Unknown IIHS version: {version}. Available: {list(IIHS_THRESHOLDS.keys())}"
+ f"No thresholds found for IIHS version: {version}. "
+ f"Check {_THRESHOLDS_DIR} for YAML files."
)
- self._thresholds = IIHS_THRESHOLDS[version]
@property
def protocol_name(self) -> str:
diff --git a/src/impakt/protocol/thresholds/euro_ncap_2024.yaml b/src/impakt/protocol/thresholds/euro_ncap_2024.yaml
new file mode 100644
index 0000000..fb54bcd
--- /dev/null
+++ b/src/impakt/protocol/thresholds/euro_ncap_2024.yaml
@@ -0,0 +1,76 @@
+# Euro NCAP 2024 Adult Occupant Frontal Impact Thresholds
+#
+# Each criterion: [green, yellow, orange, brown, red, higher_is_worse, max_points]
+# Sliding-scale: values at or below green = full points, at or above red = zero.
+
+HIC15:
+ green: 500.0
+ yellow: 620.0
+ orange: 700.0
+ brown: 850.0
+ red: 1000.0
+ higher_is_worse: true
+ max_points: 4.0
+
+3ms Clip:
+ green: 42.0
+ yellow: 48.0
+ orange: 54.0
+ brown: 57.0
+ red: 60.0
+ higher_is_worse: true
+ max_points: 4.0
+
+Chest Deflection:
+ green: 22.0
+ yellow: 34.0
+ orange: 42.0
+ brown: 50.0
+ red: 63.0
+ higher_is_worse: true
+ max_points: 4.0
+
+Nij:
+ green: 0.5
+ yellow: 0.65
+ orange: 0.8
+ brown: 0.9
+ red: 1.0
+ higher_is_worse: true
+ max_points: 2.0
+
+Femur Load Left:
+ green: 3.8
+ yellow: 5.4
+ orange: 7.0
+ brown: 8.5
+ red: 10.0
+ higher_is_worse: true
+ max_points: 2.0
+
+Femur Load Right:
+ green: 3.8
+ yellow: 5.4
+ orange: 7.0
+ brown: 8.5
+ red: 10.0
+ higher_is_worse: true
+ max_points: 2.0
+
+Tibia Index:
+ green: 0.4
+ yellow: 0.7
+ orange: 1.0
+ brown: 1.15
+ red: 1.3
+ higher_is_worse: true
+ max_points: 2.0
+
+Viscous Criterion:
+ green: 0.32
+ yellow: 0.56
+ orange: 0.8
+ brown: 0.9
+ red: 1.0
+ higher_is_worse: true
+ max_points: 2.0
diff --git a/src/impakt/protocol/thresholds/iihs_2024.yaml b/src/impakt/protocol/thresholds/iihs_2024.yaml
new file mode 100644
index 0000000..3c28d17
--- /dev/null
+++ b/src/impakt/protocol/thresholds/iihs_2024.yaml
@@ -0,0 +1,40 @@
+# IIHS 2024 Crashworthiness Thresholds
+#
+# Each criterion: [good, acceptable, marginal, higher_is_worse]
+# Values above marginal = Poor.
+
+HIC15:
+ good: 250.0
+ acceptable: 500.0
+ marginal: 700.0
+ higher_is_worse: true
+
+Chest Deflection:
+ good: 38.0
+ acceptable: 50.0
+ marginal: 63.0
+ higher_is_worse: true
+
+Femur Load Left:
+ good: 3.8
+ acceptable: 6.2
+ marginal: 10.0
+ higher_is_worse: true
+
+Femur Load Right:
+ good: 3.8
+ acceptable: 6.2
+ marginal: 10.0
+ higher_is_worse: true
+
+Nij:
+ good: 0.52
+ acceptable: 0.78
+ marginal: 1.0
+ higher_is_worse: true
+
+Tibia Index:
+ good: 0.5
+ acceptable: 0.8
+ marginal: 1.3
+ higher_is_worse: true
diff --git a/src/impakt/script/__pycache__/api.cpython-312.pyc b/src/impakt/script/__pycache__/api.cpython-312.pyc
index 9a1ec10..4abac9f 100644
Binary files a/src/impakt/script/__pycache__/api.cpython-312.pyc and b/src/impakt/script/__pycache__/api.cpython-312.pyc differ
diff --git a/src/impakt/script/__pycache__/cli.cpython-312.pyc b/src/impakt/script/__pycache__/cli.cpython-312.pyc
deleted file mode 100644
index d435223..0000000
Binary files a/src/impakt/script/__pycache__/cli.cpython-312.pyc and /dev/null differ
diff --git a/src/impakt/script/api.py b/src/impakt/script/api.py
index 3f6120e..3325059 100644
--- a/src/impakt/script/api.py
+++ b/src/impakt/script/api.py
@@ -2,6 +2,23 @@
Provides the ``Session`` and ``Template`` classes that serve as the
primary entry points for both scripting and the web UI.
+
+Usage::
+
+ from impakt import Session
+
+ test = Session.open("tests/mme_data/3239")
+ ch = test.channel("11HEAD0000H3ACXP")
+
+ # Fluent chaining — each transform returns a ChannelHandle
+ filtered = ch.transform.cfc(1000).transform.y_align()
+
+ # Compute all injury criteria (auto-detects channels)
+ criteria = test.compute_criteria()
+
+ # Score against a protocol
+ result = test.evaluate("euro_ncap")
+ result.to_pdf("report.pdf")
"""
from __future__ import annotations
@@ -36,12 +53,19 @@ class Session:
Wraps TestData with session state, template binding, and
convenience methods for transforms, criteria, and plotting.
- Usage:
+ Usage::
+
test = Session.open("/path/to/test_001")
ch = test.channel("11HEAD0000ACXA")
- filtered = ch.transform.cfc(1000)
+ filtered = ch.transform.cfc(1000) # returns ChannelHandle, chainable
+
+ criteria = test.compute_criteria() # auto-detect channels
+ result = test.evaluate("euro_ncap") # score against protocol
+ result.to_pdf("report.pdf")
"""
+ _plugins_discovered = False
+
def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None:
self._data = test_data
self._session_mgr = session_mgr or (
@@ -49,12 +73,29 @@ class Session:
)
self._template: TemplateSpec | None = None
+ @classmethod
+ def _discover_plugins(cls) -> None:
+ """Discover and register plugins. Called once on first Session.open()."""
+ if cls._plugins_discovered:
+ return
+ cls._plugins_discovered = True
+ try:
+ from impakt.plugin.registry import discover_all
+
+ discover_all()
+ except Exception as e:
+ logger.debug("Plugin discovery failed (non-fatal): %s", e)
+
@classmethod
def open(cls, path: str | Path) -> Session:
"""Open a crash test from a path.
Auto-detects the format using the reader registry.
+ Discovers plugins on first call.
"""
+ # Discover plugins (idempotent — safe to call multiple times)
+ cls._discover_plugins()
+
path = Path(path).resolve()
registry = get_registry()
test_data = registry.read(path)
@@ -87,6 +128,11 @@ class Session:
"""Underlying TestData object."""
return self._data
+ @property
+ def path(self) -> Path | None:
+ """Path to the test data directory."""
+ return self._data.path
+
@property
def channel_names(self) -> list[str]:
return self._data.channel_names
@@ -98,7 +144,10 @@ class Session:
# ----- Channel access -----
def channel(self, name: str) -> ChannelHandle:
- """Get a channel by name, wrapped with transform convenience methods."""
+ """Get a channel by name, wrapped with transform convenience methods.
+
+ Returns a ChannelHandle with a fluent ``.transform`` interface.
+ """
ch = self._data.get(name)
return ChannelHandle(ch)
@@ -118,6 +167,47 @@ class Session:
"""Hierarchical channel tree for UI display."""
return self._data.channel_tree()
+ # ----- Criteria & Protocol -----
+
+ def compute_criteria(self) -> dict[str, CriterionResult]:
+ """Auto-detect channels and compute all applicable injury criteria.
+
+ Uses ISO channel naming to find the right channels for each criterion.
+ Returns a dict of criterion name -> CriterionResult.
+ """
+ from impakt.web.components.criteria import auto_compute_criteria
+
+ return auto_compute_criteria(self._data)
+
+ def evaluate(self, protocol: str = "euro_ncap", version: str = "") -> ProtocolResult:
+ """Compute criteria and score against a rating protocol.
+
+ Args:
+ protocol: Protocol name ('euro_ncap', 'us_ncap', 'iihs').
+ version: Protocol version (optional, uses latest if empty).
+
+ Returns:
+ ProtocolResult with stars/rating and per-region scores.
+ """
+ criteria = self.compute_criteria()
+
+ if protocol == "euro_ncap":
+ from impakt.protocol.euro_ncap import evaluate
+
+ return evaluate(criteria, version=version or "2024")
+ elif protocol == "us_ncap":
+ from impakt.protocol.us_ncap import evaluate
+
+ return evaluate(criteria, version=version or "2023")
+ elif protocol == "iihs":
+ from impakt.protocol.iihs import evaluate
+
+ return evaluate(criteria, version=version or "2024")
+ else:
+ raise ValueError(
+ f"Unknown protocol: {protocol}. Use 'euro_ncap', 'us_ncap', or 'iihs'."
+ )
+
# ----- Template -----
def apply_template(self, name_or_spec: str | TemplateSpec) -> None:
@@ -188,15 +278,16 @@ class Session:
class ChannelHandle:
"""Wrapper around a Channel providing fluent transform access.
- Example:
+ Each transform method on ``.transform`` returns a new ``ChannelHandle``,
+ enabling chaining::
+
ch = session.channel("11HEAD0000ACXA")
- filtered = ch.transform.cfc(1000)
- aligned = ch.transform.x_align(method="threshold", threshold_value=5.0)
+ result = ch.transform.cfc(1000).transform.y_align().transform.trim(t_end=0.1)
"""
def __init__(self, channel: Channel) -> None:
self._channel = channel
- self.transform = TransformProxy(channel)
+ self.transform = TransformProxy(self)
@property
def raw(self) -> Channel:
@@ -207,6 +298,10 @@ class ChannelHandle:
def name(self) -> str:
return self._channel.name
+ @property
+ def code(self):
+ return self._channel.code
+
@property
def data(self) -> np.ndarray:
return self._channel.data
@@ -219,6 +314,14 @@ class ChannelHandle:
def unit(self) -> str:
return self._channel.unit
+ @property
+ def peak(self) -> float:
+ return self._channel.peak
+
+ @property
+ def sample_rate(self) -> float:
+ return self._channel.sample_rate
+
def value_at(self, t: float) -> float:
return self._channel.value_at(t)
@@ -238,44 +341,51 @@ class ChannelHandle:
class TransformProxy:
"""Fluent transform interface for a channel.
- Each method returns a new Channel (non-destructive).
+ Each method returns a new ``ChannelHandle`` wrapping the transformed
+ channel, enabling chaining.
"""
- def __init__(self, channel: Channel) -> None:
- self._channel = channel
+ def __init__(self, handle: ChannelHandle) -> None:
+ self._handle = handle
- def cfc(self, cfc_class: int) -> Channel:
+ @property
+ def _channel(self) -> Channel:
+ return self._handle._channel
+
+ def cfc(self, cfc_class: int) -> ChannelHandle:
"""Apply CFC filter."""
from impakt.transform.cfc import CFCFilter
- return CFCFilter(cfc_class=cfc_class).apply(self._channel)
+ return ChannelHandle(CFCFilter(cfc_class=cfc_class).apply(self._channel))
def x_align(
self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any
- ) -> Channel:
+ ) -> ChannelHandle:
"""Apply time-zero alignment."""
from impakt.transform.align import XAlign
- return XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
+ return ChannelHandle(
+ XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
+ )
- def y_align(self, window: tuple[float, float] | None = None) -> Channel:
+ def y_align(self, window: tuple[float, float] | None = None) -> ChannelHandle:
"""Apply Y-axis zero correction."""
from impakt.transform.align import YAlign
start, end = window if window else (None, None)
- return YAlign(window_start=start, window_end=end).apply(self._channel)
+ return ChannelHandle(YAlign(window_start=start, window_end=end).apply(self._channel))
- def trim(self, t_start: float | None = None, t_end: float | None = None) -> Channel:
+ def trim(self, t_start: float | None = None, t_end: float | None = None) -> ChannelHandle:
"""Trim to a time range."""
from impakt.transform.resample import Trim
- return Trim(t_start=t_start, t_end=t_end).apply(self._channel)
+ return ChannelHandle(Trim(t_start=t_start, t_end=t_end).apply(self._channel))
- def resample(self, target_rate: float) -> Channel:
+ def resample(self, target_rate: float) -> ChannelHandle:
"""Resample to a new rate."""
from impakt.transform.resample import Resample
- return Resample(target_rate=target_rate).apply(self._channel)
+ return ChannelHandle(Resample(target_rate=target_rate).apply(self._channel))
class Template:
diff --git a/src/impakt/web/app.py b/src/impakt/web/app.py
index a724bc6..a6af628 100644
--- a/src/impakt/web/app.py
+++ b/src/impakt/web/app.py
@@ -1,7 +1,7 @@
"""Dash web application factory.
Creates the Dash app with all layout components and callbacks registered.
-The AppState holds server-side data; Dash stores hold lightweight UI state.
+The AppState holds Session objects server-side; Dash stores hold lightweight UI state.
"""
from __future__ import annotations
@@ -13,14 +13,12 @@ import dash
import dash_bootstrap_components as dbc
from impakt.channel.model import TestData
+from impakt.script.api import Session
from impakt.template.library import TemplateLibrary
from impakt.web.callbacks import register_callbacks
from impakt.web.layout import build_layout
from impakt.web.state import AppState
-if TYPE_CHECKING:
- from impakt.script.api import Session
-
def create_app(
session_or_data: Session | TestData | None = None,
@@ -37,21 +35,15 @@ def create_app(
Returns:
Configured Dash app ready to run.
"""
- # Build or use provided AppState
if app_state is None:
app_state = AppState()
if session_or_data is not None:
- if hasattr(session_or_data, "data"):
- test_data: TestData = session_or_data.data # type: ignore[union-attr]
- else:
- test_data = session_or_data # type: ignore[assignment]
-
- # Create a LoadedTest from TestData directly
- from impakt.web.state import LoadedTest
-
- loaded = LoadedTest(test_data)
- app_state._tests[test_data.test_id] = loaded
- app_state._test_order.append(test_data.test_id)
+ if isinstance(session_or_data, Session):
+ app_state.add_session(session_or_data)
+ elif isinstance(session_or_data, TestData):
+ # Wrap raw TestData in a Session
+ session = Session(session_or_data)
+ app_state.add_session(session)
# Discover templates
if template_names is None:
@@ -64,14 +56,13 @@ def create_app(
# Title
title = "Impakt"
if app_state.primary_test:
- title = f"Impakt — {app_state.primary_test.test_id}"
+ title = f"Impakt \u2014 {app_state.primary_test.test_id}"
app = dash.Dash(
__name__,
external_stylesheets=[dbc.themes.FLATLY],
title=title,
suppress_callback_exceptions=True,
- # Prevent browser from caching old layouts
serve_locally=True,
meta_tags=[
{"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"},
@@ -92,16 +83,9 @@ def serve(
port: int = 8050,
debug: bool = False,
) -> None:
- """Convenience function to create and run the web UI.
-
- Args:
- session_or_data: Session or TestData to visualize, or None for empty app.
- template: Template name to pre-apply.
- port: Server port.
- debug: Enable Dash debug mode.
- """
+ """Convenience function to create and run the web UI."""
if template and session_or_data and hasattr(session_or_data, "apply_template"):
- session_or_data.apply_template(template) # type: ignore[union-attr]
+ session_or_data.apply_template(template)
app = create_app(session_or_data)
print(f"Impakt running at http://localhost:{port}")
diff --git a/src/impakt/web/callbacks/criteria_callbacks.py b/src/impakt/web/callbacks/criteria_callbacks.py
index 9456bef..9d1fecc 100644
--- a/src/impakt/web/callbacks/criteria_callbacks.py
+++ b/src/impakt/web/callbacks/criteria_callbacks.py
@@ -1,9 +1,7 @@
"""Criteria computation callbacks.
-Handles:
-- Compute All button -> runs auto_compute_criteria
-- Protocol scoring
-- Results display
+Uses Session.compute_criteria() and Session.evaluate() from the
+scripting API — the same path used by programmatic scripts.
"""
from __future__ import annotations
@@ -14,11 +12,7 @@ import dash
from dash import Input, Output, State, html
from dash.exceptions import PreventUpdate
-from impakt.web.components.criteria import (
- auto_compute_criteria,
- build_criteria_results_display,
- score_protocol,
-)
+from impakt.web.components.criteria import build_criteria_results_display
from impakt.web.state import AppState
@@ -41,8 +35,8 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
if primary is None:
return html.Div("No test loaded.", className="text-danger small")
- # Compute criteria
- criteria = auto_compute_criteria(primary.data)
+ # Use Session.compute_criteria() — single code path for script + web
+ criteria = primary.compute_criteria()
if not criteria:
return html.Div(
@@ -51,7 +45,10 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
className="text-warning small",
)
- # Score against protocol
- protocol_result = score_protocol(criteria, protocol)
+ # Score against protocol — use Session.evaluate()
+ try:
+ protocol_result = primary.evaluate(protocol)
+ except Exception:
+ protocol_result = None
return build_criteria_results_display(criteria, protocol_result)
diff --git a/src/impakt/web/callbacks/plot_callbacks.py b/src/impakt/web/callbacks/plot_callbacks.py
index 9add249..643e69e 100644
--- a/src/impakt/web/callbacks/plot_callbacks.py
+++ b/src/impakt/web/callbacks/plot_callbacks.py
@@ -1,10 +1,11 @@
"""Plot rendering callbacks.
-Handles:
-- Updating plot figures when channels/transforms change
-- Cursor line rendering (X1/X2 vertical lines)
-- Hover data is NOT shown as a Plotly tooltip — instead the cursor grid
- picks it up via the hoverData callback
+Builds a PlotSpec from the UI state and delegates to PlotEngine.render().
+This is the single rendering path — the same PlotEngine used by the
+scripting API renders the web UI plots.
+
+Transform application uses TransformChain, making the pipeline
+serializable and reproducible.
"""
from __future__ import annotations
@@ -15,16 +16,55 @@ from typing import Any
import dash
import numpy as np
import plotly.graph_objects as go
-from dash import ALL, Input, Output, State, ctx
+from dash import Input, Output, State
from dash.exceptions import PreventUpdate
from impakt.channel.model import Channel
-from impakt.plot.engine import DEFAULT_COLORS
+from impakt.plot.engine import DEFAULT_COLORS, PlotEngine
+from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
from impakt.transform.align import XAlign, YAlign
+from impakt.transform.base import TransformChain
from impakt.transform.cfc import CFCFilter
from impakt.transform.resultant import resultant_from_channels
from impakt.web.state import AppState
+# Module-level engine instance (stateless, safe to reuse)
+_engine = PlotEngine()
+
+
+def _build_transform_chain(
+ cfc_value: str,
+ y_align: bool,
+ x_align_method: str,
+ x_align_value: float | None,
+ per_channel_cfc: str | None = None,
+) -> TransformChain:
+ """Build a TransformChain from the current UI state.
+
+ Per-channel CFC override takes precedence over the global CFC setting.
+ """
+ chain = TransformChain()
+
+ # CFC filter
+ effective_cfc = per_channel_cfc if per_channel_cfc else cfc_value
+ if effective_cfc and effective_cfc != "none":
+ try:
+ chain = chain.append(CFCFilter(cfc_class=int(effective_cfc)))
+ except (ValueError, Exception):
+ pass
+
+ # Y-align
+ if y_align:
+ chain = chain.append(YAlign())
+
+ # X-align
+ if x_align_method == "manual" and x_align_value is not None:
+ chain = chain.append(XAlign(method="manual", reference_time=x_align_value))
+ elif x_align_method == "threshold" and x_align_value is not None:
+ chain = chain.append(XAlign(method="threshold", threshold_value=x_align_value))
+
+ return chain
+
def _resolve_channels(
selected_keys: list[str],
@@ -35,12 +75,15 @@ def _resolve_channels(
x_align_value: float | None,
show_resultant: bool,
) -> list[tuple[str, Channel]]:
- """Resolve selected channel keys to Channel objects with transforms applied.
+ """Resolve selected channel keys to transformed Channel objects.
- Per-channel overrides (from app_state.channel_overrides) take precedence
- over the global CFC setting.
+ Uses TransformChain for each channel. Per-channel overrides from
+ app_state.channel_overrides take precedence over the global CFC.
+
+ Returns list of (label, transformed_channel) tuples.
"""
channels: list[tuple[str, Channel]] = []
+ multi_test = len(app_state.tests) > 1
for key in selected_keys:
if "::" in key:
@@ -55,29 +98,22 @@ def _resolve_channels(
if ch is None:
continue
- # Determine CFC: per-channel override takes precedence over global
+ # Build per-channel transform chain
override = app_state.channel_overrides.get(key, {})
- ch_cfc = override.get("cfc", "")
- effective_cfc = ch_cfc if ch_cfc else cfc_value
+ per_ch_cfc = override.get("cfc", "")
+ chain = _build_transform_chain(
+ cfc_value,
+ y_align,
+ x_align_method,
+ x_align_value,
+ per_channel_cfc=per_ch_cfc,
+ )
- if effective_cfc and effective_cfc != "none":
- try:
- ch = CFCFilter(cfc_class=int(effective_cfc)).apply(ch)
- except (ValueError, Exception):
- pass
-
- # Apply Y-align
- if y_align:
- ch = YAlign().apply(ch)
-
- # Apply X-align
- if x_align_method == "manual" and x_align_value is not None:
- ch = XAlign(method="manual", reference_time=x_align_value).apply(ch)
- elif x_align_method == "threshold" and x_align_value is not None:
- ch = XAlign(method="threshold", threshold_value=x_align_value).apply(ch)
+ # Apply the chain
+ if len(chain) > 0:
+ ch = chain.apply(ch)
# Build label
- multi_test = len(app_state.tests) > 1
label = ch.code.short_label if ch.code.is_valid else ch.name
if multi_test:
label = f"[{test_id}] {label}"
@@ -99,7 +135,7 @@ def _resolve_channels(
res_label = (
f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant"
)
- if len(app_state.tests) > 1:
+ if multi_test:
res_label = f"[{comps[0].source_test_id}] {res_label}"
channels.append((res_label, res))
except Exception:
@@ -108,123 +144,61 @@ def _resolve_channels(
return channels
-def _build_figure(
+def _build_plot_spec(
channels: list[tuple[str, Channel]],
cursor_x1: float | None,
cursor_x2: float | None,
- cfc_value: str,
corridors: list[dict] | None = None,
-) -> go.Figure:
- """Build a Plotly figure from resolved channels."""
- fig = go.Figure()
+) -> PlotSpec:
+ """Build a PlotSpec from resolved channels and UI state.
- if not channels:
- fig.update_layout(
- template="plotly_white",
- annotations=[
- {
- "text": "Select channels from the grid to plot",
- "xref": "paper",
- "yref": "paper",
- "x": 0.5,
- "y": 0.5,
- "showarrow": False,
- "font": {"size": 14, "color": "#bbb"},
- }
- ],
- xaxis={"visible": False},
- yaxis={"visible": False},
- margin={"l": 40, "r": 20, "t": 30, "b": 40},
- )
- return fig
-
- # Add traces — hovermode is disabled at the layout level; cursor tracking
- # is handled by our own JS mousemove handler (cursor_tracker.js).
+ Channels are already transformed — they are wrapped in ChannelRef
+ objects with no additional transform chain (transforms were applied
+ during resolution).
+ """
+ # Build ChannelRef objects
+ refs: list[ChannelRef] = []
for i, (label, ch) in enumerate(channels):
color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
- fig.add_trace(
- go.Scatter(
- x=ch.time.tolist(),
- y=ch.data.tolist(),
- mode="lines",
- name=label,
- line=dict(color=color, width=1.5),
+ refs.append(
+ ChannelRef(
+ channel=ch,
+ style=PlotStyle(label=label, color=color),
)
)
- # Add corridor fills
+ # Build Corridor objects from raw dicts
+ corridor_objs: list[Corridor] = []
if corridors:
- for corridor in corridors:
- if not corridor.get("visible", True):
+ for c in corridors:
+ if not c.get("visible", True):
continue
- c_time = corridor["time"]
- c_upper = corridor["upper"]
- c_lower = corridor["lower"]
- c_name = corridor.get("name", "Corridor")
-
- # Upper bound
- fig.add_trace(
- go.Scatter(
- x=c_time,
- y=c_upper,
- mode="lines",
- line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
- showlegend=False,
- name=f"{c_name} upper",
- )
- )
- # Lower bound with fill to upper
- fig.add_trace(
- go.Scatter(
- x=c_time,
- y=c_lower,
- mode="lines",
- line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
- fill="tonexty",
- fillcolor="rgba(100,100,255,0.1)",
- showlegend=True,
- name=c_name,
+ corridor_objs.append(
+ Corridor(
+ name=c.get("name", "Corridor"),
+ time=np.array(c["time"]),
+ lower=np.array(c["lower"]),
+ upper=np.array(c["upper"]),
+ style=CorridorStyle(),
)
)
- # Add X1/X2 cursor lines
- if cursor_x1 is not None:
- fig.add_vline(
- x=cursor_x1,
- line_dash="dash",
- line_color="rgba(220,53,69,0.6)",
- line_width=1,
- annotation_text=f"X1={cursor_x1:.4f}s",
- annotation_font_size=9,
- annotation_font_color="rgba(220,53,69,0.8)",
- )
- if cursor_x2 is not None:
- fig.add_vline(
- x=cursor_x2,
- line_dash="dash",
- line_color="rgba(13,110,253,0.6)",
- line_width=1,
- annotation_text=f"X2={cursor_x2:.4f}s",
- annotation_font_size=9,
- annotation_font_color="rgba(13,110,253,0.8)",
- )
+ # Cursor positions
+ x_cursors = None
+ if cursor_x1 is not None and cursor_x2 is not None:
+ x_cursors = (cursor_x1, cursor_x2)
- # Layout — hovermode is disabled; cursor tracking is handled entirely
- # by our JS (cursor_tracker.js) which reads pixel positions from mousemove
- # events and converts to data coordinates via Plotly's axis internals.
+ # Y-axis label from first channel
y_label = channels[0][1].unit if channels else ""
- fig.update_layout(
- xaxis_title=dict(text="Time (s)", font=dict(size=10, color="#999")),
- yaxis_title=dict(text=y_label, font=dict(size=10, color="#999")),
- template="plotly_white",
- hovermode=False,
- showlegend=False,
- margin=dict(l=45, r=8, t=4, b=28),
+ return PlotSpec(
+ channels=refs,
+ corridors=corridor_objs,
+ x_cursors=x_cursors,
+ y_label=y_label,
+ compact=True, # Web UI always uses compact mode
)
- return fig
-
def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
"""Register all plot-related callbacks."""
@@ -271,10 +245,5 @@ def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
show_resultant,
)
- return _build_figure(
- channels,
- cursor_x1,
- cursor_x2,
- cfc_value,
- corridors=corridors_data,
- )
+ spec = _build_plot_spec(channels, cursor_x1, cursor_x2, corridors_data)
+ return _engine.render(spec)
diff --git a/src/impakt/web/components/channel_tree.py b/src/impakt/web/components/channel_tree.py
deleted file mode 100644
index e4b68db..0000000
--- a/src/impakt/web/components/channel_tree.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""Collapsible channel tree component.
-
-Renders channels in a hierarchical accordion:
- Test Object > Body Region > Measurement Type > [channels]
-
-Supports search filtering, select-all per group, and channel preview
-(peak, unit, sample rate) on hover/expand.
-"""
-
-from __future__ import annotations
-
-from typing import Any
-
-import dash_bootstrap_components as dbc
-from dash import dcc, html
-
-from impakt.web.state import AppState
-
-
-def _make_channel_item(ch_info: dict[str, str], show_test_prefix: bool = False) -> html.Div:
- """Build a single channel item with checkbox and preview info."""
- label = ch_info["label"]
- key = ch_info["key"]
- peak = ch_info.get("peak", "")
- unit = ch_info.get("unit", "")
-
- preview = f" ({peak} {unit})" if peak else ""
-
- return html.Div(
- [
- dbc.Checkbox(
- id={"type": "channel-check", "index": key},
- label=html.Span(
- [
- html.Span(label, style={"fontSize": "12px"}),
- html.Span(
- preview,
- style={"fontSize": "10px", "color": "#999", "marginLeft": "4px"},
- ),
- ]
- ),
- value=False,
- style={"marginBottom": "1px"},
- ),
- ],
- className="channel-item",
- )
-
-
-def _make_group_header(
- group_label: str,
- group_id: str,
- channel_keys: list[str],
-) -> html.Div:
- """Build a measurement group header with select-all button."""
- return html.Div(
- [
- html.Span(group_label, style={"fontSize": "12px", "fontWeight": "500"}),
- dbc.Button(
- "All",
- id={"type": "select-group-btn", "index": group_id},
- color="link",
- size="sm",
- style={"fontSize": "10px", "padding": "0 4px", "textDecoration": "none"},
- ),
- # Hidden store for this group's channel keys
- dcc.Store(
- id={"type": "group-channels", "index": group_id},
- data=channel_keys,
- ),
- ],
- className="d-flex align-items-center justify-content-between",
- )
-
-
-def build_channel_tree(app_state: AppState) -> dbc.Card:
- """Build the full channel tree panel with search and accordion."""
- if app_state.is_empty:
- return dbc.Card(
- [
- dbc.CardHeader("Channels", className="fw-bold py-2"),
- dbc.CardBody(
- html.Div("No test loaded", className="text-muted small text-center py-3"),
- ),
- ]
- )
-
- multi_test = len(app_state.tests) > 1
- full_tree = app_state.build_channel_tree()
-
- # Build accordion items
- accordion_items = []
- group_counter = 0
-
- for test_id, test_tree in full_tree.items():
- test_label = ""
- loaded = app_state.get_test(test_id)
- if loaded and multi_test:
- test_label = f"[{test_id}] "
-
- for obj_label, locations in test_tree.items():
- # Top-level accordion: test object (Driver, Vehicle Structure, etc.)
- obj_content = []
-
- for loc_label, measurements in locations.items():
- for meas_label, channels in measurements.items():
- group_id = f"grp_{group_counter}"
- group_counter += 1
-
- channel_keys = [ch["key"] for ch in channels]
- header_label = (
- f"{loc_label} — {meas_label}" if loc_label != meas_label else meas_label
- )
-
- obj_content.append(_make_group_header(header_label, group_id, channel_keys))
- for ch in channels:
- obj_content.append(_make_channel_item(ch, show_test_prefix=multi_test))
-
- obj_content.append(html.Hr(style={"margin": "4px 0"}))
-
- if obj_content:
- # Remove trailing hr
- if obj_content and isinstance(obj_content[-1], html.Hr):
- obj_content = obj_content[:-1]
-
- display_label = f"{test_label}{obj_label}" if test_label else obj_label
- accordion_items.append(
- dbc.AccordionItem(
- html.Div(obj_content),
- title=display_label,
- style={"fontSize": "12px"},
- )
- )
-
- return dbc.Card(
- [
- dbc.CardHeader(
- [
- html.Span("Channels", className="fw-bold"),
- html.Span(
- f" ({app_state.total_channels})",
- style={"fontSize": "11px", "color": "#999"},
- ),
- ],
- className="py-2",
- ),
- dbc.CardBody(
- [
- # Search
- dbc.Input(
- id="channel-search",
- placeholder="Search channels...",
- type="text",
- size="sm",
- className="mb-2",
- debounce=True,
- ),
- # Selected channels display
- html.Div(id="selected-channels-badges", className="mb-2"),
- # Tree
- html.Div(
- dbc.Accordion(
- accordion_items,
- flush=True,
- always_open=True,
- start_collapsed=True,
- id="channel-accordion",
- ),
- style={"maxHeight": "450px", "overflowY": "auto"},
- id="channel-tree-container",
- ),
- ],
- className="py-2",
- ),
- ]
- )
-
-
-def build_selected_channels_badges(selected_keys: list[str], app_state: AppState) -> list:
- """Build badge pills showing currently selected channels."""
- if not selected_keys:
- return [
- html.Span("No channels selected", className="text-muted", style={"fontSize": "11px"})
- ]
-
- badges = []
- for key in selected_keys[:20]: # Limit display to 20
- # Extract label
- if "::" in key:
- test_id, ch_name = key.split("::", 1)
- ch = app_state.get_channel(test_id, ch_name)
- if ch and ch.code.is_valid:
- label = ch.code.short_label
- else:
- label = ch_name
- if len(app_state.tests) > 1:
- label = f"{test_id}: {label}"
- else:
- label = key
-
- badges.append(
- dbc.Badge(
- label,
- id={"type": "channel-badge", "index": key},
- color="primary",
- pill=True,
- className="me-1 mb-1",
- style={"fontSize": "10px", "cursor": "pointer"},
- )
- )
-
- if len(selected_keys) > 20:
- badges.append(
- html.Span(
- f"+{len(selected_keys) - 20} more",
- className="text-muted",
- style={"fontSize": "10px"},
- )
- )
-
- return badges
diff --git a/src/impakt/web/components/header.py b/src/impakt/web/components/header.py
index 479e978..e7a36b8 100644
--- a/src/impakt/web/components/header.py
+++ b/src/impakt/web/components/header.py
@@ -129,7 +129,7 @@ def build_test_info_panel(app_state: AppState) -> dbc.Card:
f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "—",
style={"fontSize": "12px"},
),
- html.Td(str(loaded.channel_count), style={"fontSize": "12px"}),
+ html.Td(str(len(loaded)), style={"fontSize": "12px"}),
html.Td(
dbc.Button(
"x",
diff --git a/src/impakt/web/state.py b/src/impakt/web/state.py
index cf80de2..c7b92a1 100644
--- a/src/impakt/web/state.py
+++ b/src/impakt/web/state.py
@@ -1,14 +1,12 @@
"""Application state management.
-AppState is the central data store for the web UI. It holds all loaded tests,
-manages channel transforms, and provides the data that callbacks need to
-render plots, compute criteria, and generate reports.
+AppState is the central data store for the web UI. It holds Session
+objects (from the scripting API) and delegates to them for channel
+access, transforms, criteria, and protocol scoring.
-AppState is NOT a Dash Store — it lives server-side in Python memory. The Dash
-stores hold lightweight references (test IDs, selected channels, plot config)
+AppState is NOT a Dash Store — it lives server-side in Python memory.
+Dash stores hold lightweight references (test IDs, selected channels)
that callbacks use to look up data from AppState.
-
-This design keeps large numpy arrays out of the browser and avoids serialization.
"""
from __future__ import annotations
@@ -17,220 +15,95 @@ import logging
from pathlib import Path
from typing import Any
-from impakt.channel.model import Channel, ChannelGroup, TestData
-from impakt.io.reader import get_registry
+from impakt.channel.model import Channel, TestData
+from impakt.script.api import Session
from impakt.template.library import TemplateLibrary
-from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
-from impakt.template.session import SessionManager
+from impakt.template.model import PlotDefinition, TemplateSpec
from impakt.transform.align import YAlign
from impakt.transform.cfc import CFCFilter
logger = logging.getLogger(__name__)
-class LoadedTest:
- """A single loaded test with its metadata and channel access."""
-
- def __init__(self, test_data: TestData, color_offset: int = 0) -> None:
- self.data = test_data
- self.color_offset = color_offset
-
- @property
- def test_id(self) -> str:
- return self.data.test_id
-
- @property
- def label(self) -> str:
- """Short label for UI display."""
- meta = self.data.metadata
- parts = [meta.test_number]
- if meta.vehicle.make:
- parts.append(f"{meta.vehicle.make} {meta.vehicle.model}".strip())
- return " — ".join(parts) if len(parts) > 1 else parts[0]
-
- @property
- def channel_count(self) -> int:
- return len(self.data)
-
- def __len__(self) -> int:
- return len(self.data)
-
-
class AppState:
"""Server-side application state.
- Manages multiple loaded tests and provides channel resolution
- for the UI callbacks.
+ Manages multiple loaded test sessions and provides channel resolution
+ for the UI callbacks. All test data access goes through Session objects.
"""
def __init__(self) -> None:
- self._tests: dict[str, LoadedTest] = {}
- self._test_order: list[str] = []
- self._color_counter: int = 0
+ self._sessions: dict[str, Session] = {}
+ self._session_order: list[str] = []
self._template_library = TemplateLibrary()
self._active_template: TemplateSpec | None = None
- self._session_managers: dict[str, SessionManager] = {}
- # Per-channel transform overrides: {channel_key: {"cfc": "600", "y_align": True}}
+ # Per-channel transform overrides: {channel_key: {"cfc": "600"}}
self.channel_overrides: dict[str, dict[str, Any]] = {}
- # Active corridors: list of {name, time, lower, upper, visible}
+ # Active corridors
self.corridors: list[dict[str, Any]] = []
- def load_test(self, path: str | Path) -> LoadedTest:
+ def load_test(self, path: str | Path) -> Session:
"""Load a test from a path and add it to the state.
- If a test with the same ID is already loaded, replaces it.
-
- Returns:
- The loaded test.
-
- Raises:
- ValueError: If the path is not a valid test directory.
+ Returns the Session object.
"""
- path = Path(path).resolve()
- registry = get_registry()
- test_data = registry.read(path)
+ session = Session.open(path)
+ test_id = session.test_id
- loaded = LoadedTest(test_data, color_offset=self._color_counter)
- self._color_counter += len(test_data)
+ self._sessions[test_id] = session
+ if test_id not in self._session_order:
+ self._session_order.append(test_id)
- test_id = test_data.test_id
- if test_id in self._tests:
- self._tests[test_id] = loaded
- else:
- self._tests[test_id] = loaded
- self._test_order.append(test_id)
+ logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(session))
+ return session
- logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(test_data))
- return loaded
+ def add_session(self, session: Session) -> None:
+ """Add an already-created Session to the state."""
+ test_id = session.test_id
+ self._sessions[test_id] = session
+ if test_id not in self._session_order:
+ self._session_order.append(test_id)
def remove_test(self, test_id: str) -> None:
"""Remove a loaded test."""
- if test_id in self._tests:
- del self._tests[test_id]
- self._test_order.remove(test_id)
+ self._sessions.pop(test_id, None)
+ if test_id in self._session_order:
+ self._session_order.remove(test_id)
@property
- def tests(self) -> list[LoadedTest]:
- """All loaded tests in load order."""
- return [self._tests[tid] for tid in self._test_order if tid in self._tests]
+ def sessions(self) -> list[Session]:
+ """All loaded sessions in load order."""
+ return [self._sessions[tid] for tid in self._session_order if tid in self._sessions]
+
+ @property
+ def tests(self) -> list[Session]:
+ """Alias for sessions (backward compat for components)."""
+ return self.sessions
@property
def test_ids(self) -> list[str]:
- return list(self._test_order)
+ return list(self._session_order)
@property
- def primary_test(self) -> LoadedTest | None:
- """The first loaded test (primary for criteria, etc.)."""
- if self._test_order:
- return self._tests.get(self._test_order[0])
+ def primary_test(self) -> Session | None:
+ """The first loaded session (primary for criteria, etc.)."""
+ if self._session_order:
+ return self._sessions.get(self._session_order[0])
return None
- def get_test(self, test_id: str) -> LoadedTest | None:
- return self._tests.get(test_id)
+ def get_test(self, test_id: str) -> Session | None:
+ return self._sessions.get(test_id)
def get_channel(self, test_id: str, channel_name: str) -> Channel | None:
"""Resolve a channel from a test ID and channel name."""
- test = self._tests.get(test_id)
- if test is None:
+ session = self._sessions.get(test_id)
+ if session is None:
return None
try:
- return test.data.get(channel_name)
+ return session.data.get(channel_name)
except KeyError:
return None
- def resolve_channel(
- self,
- channel_key: str,
- cfc_class: int | None = None,
- y_align: bool = False,
- ) -> Channel | None:
- """Resolve a channel key (test_id::channel_name) and apply transforms.
-
- Channel keys use the format 'test_id::channel_name'. If no '::'
- separator, assumes the primary test.
- """
- if "::" in channel_key:
- test_id, ch_name = channel_key.split("::", 1)
- else:
- if self.primary_test is None:
- return None
- test_id = self.primary_test.test_id
- ch_name = channel_key
-
- ch = self.get_channel(test_id, ch_name)
- if ch is None:
- return None
-
- # Apply transforms
- if cfc_class is not None:
- try:
- ch = CFCFilter(cfc_class=cfc_class).apply(ch)
- except (ValueError, Exception):
- pass
-
- if y_align:
- ch = YAlign().apply(ch)
-
- return ch
-
- def build_channel_tree(
- self,
- ) -> dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]]:
- """Build a hierarchical channel tree across all loaded tests.
-
- Returns:
- {test_id: {object_label: {location_label: {measurement_label: [{name, label, key}]}}}}
- """
- result: dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]] = {}
-
- for loaded in self.tests:
- tree = loaded.data.channel_tree()
- test_tree: dict[str, dict[str, dict[str, list[dict[str, str]]]]] = {}
-
- for obj, locations in sorted(tree.items()):
- obj_tree: dict[str, dict[str, list[dict[str, str]]]] = {}
- for loc, measurements in sorted(locations.items()):
- loc_tree: dict[str, list[dict[str, str]]] = {}
- for meas, channels in sorted(measurements.items()):
- ch_list = []
- for ch in channels:
- label = ch.code.short_label if ch.code.is_valid else ch.name
- ch_list.append(
- {
- "name": ch.name,
- "label": label,
- "key": f"{loaded.test_id}::{ch.name}",
- "unit": ch.unit,
- "peak": f"{ch.peak:.2f}",
- "samples": str(ch.n_samples),
- "rate": f"{ch.sample_rate:.0f}",
- }
- )
- loc_tree[meas] = ch_list
- obj_tree[loc] = loc_tree
- test_tree[obj] = obj_tree
- result[loaded.test_id] = test_tree
-
- return result
-
- def flat_channel_list(self) -> list[dict[str, str]]:
- """Flat list of all channels across all tests, for dropdown/checklist options."""
- items: list[dict[str, str]] = []
- multi = len(self._tests) > 1
-
- for loaded in self.tests:
- prefix = f"[{loaded.test_id}] " if multi else ""
- for ch in sorted(loaded.data, key=lambda c: c.name):
- label = ch.code.short_label if ch.code.is_valid else ch.name
- items.append(
- {
- "label": f"{prefix}{label}",
- "value": f"{loaded.test_id}::{ch.name}",
- }
- )
-
- return items
-
# ----- Template & Session -----
@property
@@ -245,52 +118,37 @@ class AppState:
def template_names(self) -> list[str]:
return self._template_library.list()
- def apply_template(
- self,
- name: str,
- selected_keys: list[str] | None = None,
- ) -> tuple[list[str], dict[str, str]]:
+ def apply_template(self, name: str) -> tuple[list[str], dict[str, str]]:
"""Apply a template by name.
- Resolves the template's channel patterns against the primary test,
- sets the active template, and returns the resolved channel keys and
- transform settings.
-
- Returns:
- (selected_channel_keys, transform_settings)
+ Resolves channel patterns against the primary test.
+ Returns (selected_channel_keys, transform_settings).
"""
template = self._template_library.get(name)
self._active_template = template
- # Persist to session if primary test has a path
primary = self.primary_test
- if primary and primary.data.path:
- mgr = self._get_session_manager(primary.data.path)
- mgr.apply_template(template)
+ if primary:
+ primary.apply_template(template)
- # Resolve channel patterns from the template
+ # Resolve channel patterns
resolved_keys: list[str] = []
if primary:
for plot_def in template.plots:
for pattern in plot_def.channel_patterns:
- matches = primary.data.find(pattern)
+ matches = primary.find(pattern)
for ch in matches:
key = f"{primary.test_id}::{ch.name}"
if key not in resolved_keys:
resolved_keys.append(key)
- # Transform settings
transforms: dict[str, str] = {}
if template.default_cfc is not None:
transforms["cfc"] = str(template.default_cfc)
else:
transforms["cfc"] = "none"
- logger.info(
- "Applied template '%s' — resolved %d channels",
- name,
- len(resolved_keys),
- )
+ logger.info("Applied template '%s' — %d channels", name, len(resolved_keys))
return resolved_keys, transforms
def save_as_template(
@@ -303,30 +161,19 @@ class AppState:
x2: float | None,
protocol: str = "",
) -> TemplateSpec:
- """Capture the current UI state as a new template.
-
- Converts selected channels back to patterns and stores the
- current filter/cursor/protocol settings.
- """
- # Convert selected keys to channel patterns
+ """Capture current UI state as a new template."""
patterns: list[str] = []
for key in selected_keys:
- if "::" in key:
- _, ch_name = key.split("::", 1)
- else:
- ch_name = key
- # Use the raw channel name as a pattern (exact match)
+ ch_name = key.split("::", 1)[-1] if "::" in key else key
if ch_name not in patterns:
patterns.append(ch_name)
- # Build plot definition
plot = PlotDefinition(
title=name,
channel_patterns=patterns,
x_cursors=(x1, x2) if x1 is not None and x2 is not None else None,
)
- # Build transforms list
transforms: list[dict[str, Any]] = []
if cfc_value and cfc_value != "none":
transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)})
@@ -343,17 +190,18 @@ class AppState:
self._template_library.save(template)
self._active_template = template
-
- logger.info("Saved template '%s' with %d channel patterns", name, len(patterns))
+ logger.info("Saved template '%s' with %d patterns", name, len(patterns))
return template
def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None:
"""Auto-save current state to the session for the primary test."""
primary = self.primary_test
- if not primary or not primary.data.path:
+ if not primary or not primary.path:
return
- mgr = self._get_session_manager(primary.data.path)
+ from impakt.template.session import SessionManager
+
+ mgr = SessionManager(primary.path)
mgr.state.overrides = {
"selected_channels": selected_keys,
"cfc": cfc_value,
@@ -362,12 +210,14 @@ class AppState:
mgr.save()
def load_session_state(self) -> dict[str, Any] | None:
- """Load saved session state for the primary test, if any."""
+ """Load saved session state for the primary test."""
primary = self.primary_test
- if not primary or not primary.data.path:
+ if not primary or not primary.path:
return None
- mgr = self._get_session_manager(primary.data.path)
+ from impakt.template.session import SessionManager
+
+ mgr = SessionManager(primary.path)
if not mgr.has_session:
return None
@@ -377,21 +227,48 @@ class AppState:
"template": mgr.state.template_name,
}
- def _get_session_manager(self, path: Path) -> SessionManager:
- key = str(path)
- if key not in self._session_managers:
- self._session_managers[key] = SessionManager(path)
- return self._session_managers[key]
+ # ----- Channel tree / grid helpers -----
+
+ def build_channel_tree(self) -> dict[str, Any]:
+ """Build hierarchical channel tree across all loaded tests."""
+ result: dict[str, Any] = {}
+ for session in self.sessions:
+ tree = session.channel_tree()
+ test_tree: dict[str, Any] = {}
+ for obj, locations in sorted(tree.items()):
+ obj_tree: dict[str, Any] = {}
+ for loc, measurements in sorted(locations.items()):
+ loc_tree: dict[str, Any] = {}
+ for meas, channels in sorted(measurements.items()):
+ ch_list = []
+ for ch in channels:
+ label = ch.code.short_label if ch.code.is_valid else ch.name
+ ch_list.append(
+ {
+ "name": ch.name,
+ "label": label,
+ "key": f"{session.test_id}::{ch.name}",
+ "unit": ch.unit,
+ "peak": f"{ch.peak:.2f}",
+ "samples": str(ch.n_samples),
+ "rate": f"{ch.sample_rate:.0f}",
+ }
+ )
+ loc_tree[meas] = ch_list
+ obj_tree[loc] = loc_tree
+ test_tree[obj] = obj_tree
+ result[session.test_id] = test_tree
+ return result
@property
def is_empty(self) -> bool:
- return len(self._tests) == 0
+ return len(self._sessions) == 0
@property
def total_channels(self) -> int:
- return sum(t.channel_count for t in self.tests)
+ return sum(len(s) for s in self.sessions)
def __repr__(self) -> str:
- test_info = ", ".join(f"{t.test_id}({t.channel_count}ch)" for t in self.tests)
+ test_info = ", ".join(f"{s.test_id}({len(s)}ch)" for s in self.sessions)
tmpl = f", template={self._active_template.name}" if self._active_template else ""
return f"AppState([{test_info}]{tmpl})"
diff --git a/tests/test_criteria/test_chest_femur_tibia.py b/tests/test_criteria/test_chest_femur_tibia.py
new file mode 100644
index 0000000..9450997
--- /dev/null
+++ b/tests/test_criteria/test_chest_femur_tibia.py
@@ -0,0 +1,93 @@
+"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
+
+import numpy as np
+import pytest
+
+from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
+
+
+class TestChestDeflection:
+ def test_basic(self, chest_deflection_channel):
+ result = chest_deflection(channel=chest_deflection_channel)
+ assert result.criterion == "Chest Deflection"
+ assert result.unit == "mm"
+ assert 30.0 < result.value < 40.0 # ~35mm in fixture
+
+ def test_peak_time(self, chest_deflection_channel):
+ result = chest_deflection(channel=chest_deflection_channel)
+ assert result.time_of_peak is not None
+ assert 0.0 < result.time_of_peak < 0.1
+
+
+class TestClip3ms:
+ def test_basic(self, head_accel_x):
+ result = clip_3ms(head_accel_x)
+ assert result.criterion == "3ms Clip"
+ assert result.value > 0
+
+ def test_body_region(self, head_accel_x):
+ result = clip_3ms(head_accel_x)
+ assert result.body_region == "Chest"
+
+
+class TestFemurLoad:
+ def test_single_channel(self, femur_left_channel):
+ result = femur_load(channel=femur_left_channel, side="left")
+ assert result.criterion == "Femur Load Left"
+ assert result.unit == "kN"
+ assert result.value > 0
+
+ def test_peak_time(self, femur_left_channel):
+ result = femur_load(channel=femur_left_channel, side="left")
+ assert result.time_of_peak is not None
+
+
+class TestTibiaIndex:
+ def test_with_all_components(self, time_array, sample_rate):
+ from impakt.channel.code import ChannelCode
+ from impakt.channel.model import Channel
+
+ t = time_array
+ fz = np.zeros_like(t)
+ mx = np.zeros_like(t)
+ my = np.zeros_like(t)
+ mask = (t >= 0.02) & (t <= 0.08)
+ fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
+ mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
+ my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
+
+ fz_ch = Channel(
+ name="TIBFZ",
+ code=ChannelCode.parse("TIBFZ"),
+ data=fz,
+ time=t,
+ unit="N",
+ sample_rate=sample_rate,
+ )
+ mx_ch = Channel(
+ name="TIBMX",
+ code=ChannelCode.parse("TIBMX"),
+ data=mx,
+ time=t,
+ unit="N·m",
+ sample_rate=sample_rate,
+ )
+ my_ch = Channel(
+ name="TIBMY",
+ code=ChannelCode.parse("TIBMY"),
+ data=my,
+ time=t,
+ unit="N·m",
+ sample_rate=sample_rate,
+ )
+
+ result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
+ assert result.value > 0
+
+
+class TestViscousCriterion:
+ def test_basic(self, chest_deflection_channel):
+ result = viscous_criterion(channel=chest_deflection_channel)
+ assert result.criterion == "Viscous Criterion"
+ assert result.unit == "m/s"
+ assert result.value >= 0
diff --git a/tests/test_plot/test_engine.py b/tests/test_plot/test_engine.py
new file mode 100644
index 0000000..5f37100
--- /dev/null
+++ b/tests/test_plot/test_engine.py
@@ -0,0 +1,81 @@
+"""Tests for plot engine."""
+
+import numpy as np
+import pytest
+
+from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
+from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
+
+
+class TestPlotEngine:
+ def test_render_empty(self):
+ engine = PlotEngine()
+ spec = PlotSpec()
+ fig = engine.render(spec)
+ assert fig is not None
+
+ def test_render_with_channels(self, head_accel_x, head_accel_y):
+ engine = PlotEngine()
+ spec = PlotSpec(
+ channels=[
+ ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
+ ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
+ ],
+ y_label="g",
+ )
+ fig = engine.render(spec)
+ assert len(fig.data) == 2
+
+ def test_render_compact_mode(self, head_accel_x):
+ engine = PlotEngine()
+ spec = PlotSpec(
+ channels=[ChannelRef(channel=head_accel_x)],
+ compact=True,
+ )
+ fig = engine.render(spec)
+ assert fig.layout.hovermode is False
+ assert fig.layout.showlegend is False
+
+ def test_render_with_corridors(self, head_accel_x):
+ engine = PlotEngine()
+ corridor = Corridor(
+ name="Test Corridor",
+ time=np.linspace(0, 0.1, 100),
+ lower=np.full(100, -50.0),
+ upper=np.full(100, 50.0),
+ )
+ spec = PlotSpec(
+ channels=[ChannelRef(channel=head_accel_x)],
+ corridors=[corridor],
+ )
+ fig = engine.render(spec)
+ # 1 data trace + 2 corridor traces (upper + lower)
+ assert len(fig.data) == 3
+
+ def test_render_with_cursors(self, head_accel_x):
+ engine = PlotEngine()
+ spec = PlotSpec(
+ channels=[ChannelRef(channel=head_accel_x)],
+ x_cursors=(0.02, 0.06),
+ )
+ fig = engine.render(spec)
+ # Should have vertical lines (as shapes)
+ assert len(fig.layout.shapes) >= 2
+
+ def test_compact_cursor_annotations(self, head_accel_x):
+ engine = PlotEngine()
+ spec = PlotSpec(
+ channels=[ChannelRef(channel=head_accel_x)],
+ x_cursors=(0.02, 0.06),
+ compact=True,
+ )
+ fig = engine.render(spec)
+ # Should have X1/X2 annotations
+ annotations = [
+ a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
+ ]
+ assert len(annotations) == 2
+
+ def test_default_colors(self):
+ assert len(DEFAULT_COLORS) == 10
+ assert all(c.startswith("#") for c in DEFAULT_COLORS)
diff --git a/tests/test_protocol/test_iihs_usncap.py b/tests/test_protocol/test_iihs_usncap.py
new file mode 100644
index 0000000..4bb9904
--- /dev/null
+++ b/tests/test_protocol/test_iihs_usncap.py
@@ -0,0 +1,68 @@
+"""Tests for IIHS and US NCAP scoring."""
+
+import pytest
+
+from impakt.criteria.base import CriterionResult
+from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
+from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
+
+
+@pytest.fixture
+def sample_criteria():
+ return {
+ "HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
+ "Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
+ "Chest Deflection": CriterionResult(
+ criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
+ ),
+ "Femur Load Left": CriterionResult(
+ criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
+ ),
+ "Femur Load Right": CriterionResult(
+ criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
+ ),
+ }
+
+
+class TestIIHS:
+ def test_good_results(self, sample_criteria):
+ result = iihs_evaluate(sample_criteria)
+ assert result.protocol == "IIHS"
+ assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
+
+ def test_region_scores(self, sample_criteria):
+ result = iihs_evaluate(sample_criteria)
+ assert len(result.region_scores) > 0
+ for rs in result.region_scores:
+ assert rs.rating is not None
+
+ def test_poor_hic(self):
+ result = iihs_evaluate(
+ {
+ "HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
+ }
+ )
+ assert result.overall_rating == "POOR"
+
+ def test_summary(self, sample_criteria):
+ result = iihs_evaluate(sample_criteria)
+ summary = result.summary()
+ assert "IIHS" in summary
+
+
+class TestUSNCAP:
+ def test_basic(self, sample_criteria):
+ result = us_evaluate(sample_criteria)
+ assert result.protocol == "US NCAP"
+ assert result.stars is not None
+ assert 1 <= result.stars <= 5
+
+ def test_injury_probabilities(self, sample_criteria):
+ result = us_evaluate(sample_criteria)
+ assert "combined_injury_probability" in result.details
+ p = result.details["combined_injury_probability"]
+ assert 0.0 <= p <= 1.0
+
+ def test_region_scores(self, sample_criteria):
+ result = us_evaluate(sample_criteria)
+ assert len(result.region_scores) > 0
diff --git a/tests/test_scripting_api.py b/tests/test_scripting_api.py
new file mode 100644
index 0000000..573ed57
--- /dev/null
+++ b/tests/test_scripting_api.py
@@ -0,0 +1,126 @@
+"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
+
+from pathlib import Path
+
+import pytest
+
+from impakt import Session, Template
+
+FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
+MME_DATA = Path(__file__).parent / "mme_data"
+
+
+class TestSession:
+ def test_open(self):
+ s = Session.open(FIXTURE_DATA)
+ assert s.test_id == "IMPAKT_SYNTH_001"
+ assert len(s) == 26
+
+ def test_channel_access(self):
+ s = Session.open(FIXTURE_DATA)
+ ch = s.channel("11HEAD0000ACXA")
+ assert ch.name == "11HEAD0000ACXA"
+ assert ch.peak > 0
+
+ def test_find(self):
+ s = Session.open(FIXTURE_DATA)
+ channels = s.find("*HEAD*AC*")
+ assert len(channels) == 3
+
+ def test_group(self):
+ s = Session.open(FIXTURE_DATA)
+ group = s.group("HEAD0000AC")
+ assert group.x is not None
+
+ def test_compute_criteria(self):
+ s = Session.open(FIXTURE_DATA)
+ criteria = s.compute_criteria()
+ assert len(criteria) > 0
+
+ def test_evaluate(self):
+ s = Session.open(FIXTURE_DATA)
+ result = s.evaluate("euro_ncap")
+ assert result.stars is not None
+ assert result.protocol == "Euro NCAP"
+
+ def test_evaluate_us_ncap(self):
+ s = Session.open(FIXTURE_DATA)
+ result = s.evaluate("us_ncap")
+ assert result.stars is not None
+
+ def test_evaluate_iihs(self):
+ s = Session.open(FIXTURE_DATA)
+ result = s.evaluate("iihs")
+ assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
+
+ def test_evaluate_invalid_protocol(self):
+ s = Session.open(FIXTURE_DATA)
+ with pytest.raises(ValueError, match="Unknown protocol"):
+ s.evaluate("invalid")
+
+ def test_contains(self):
+ s = Session.open(FIXTURE_DATA)
+ assert "11HEAD0000ACXA" in s
+ assert "NONEXISTENT" not in s
+
+
+class TestChannelHandleChaining:
+ """The fluent API must support chaining — each transform returns ChannelHandle."""
+
+ def test_single_transform(self):
+ s = Session.open(FIXTURE_DATA)
+ ch = s.channel("11HEAD0000ACXA")
+ filtered = ch.transform.cfc(600)
+ assert type(filtered).__name__ == "ChannelHandle"
+ assert filtered.raw.cfc_class == 600
+
+ def test_double_chain(self):
+ s = Session.open(FIXTURE_DATA)
+ result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
+ assert type(result).__name__ == "ChannelHandle"
+ assert len(result.raw.transform_history) == 2
+
+ def test_triple_chain(self):
+ s = Session.open(FIXTURE_DATA)
+ result = (
+ s.channel("11HEAD0000ACXA")
+ .transform.cfc(1000)
+ .transform.y_align()
+ .transform.trim(t_start=0.0, t_end=0.1)
+ )
+ assert type(result).__name__ == "ChannelHandle"
+ assert len(result.raw.transform_history) == 3
+
+ def test_chain_preserves_data(self):
+ s = Session.open(FIXTURE_DATA)
+ original = s.channel("11HEAD0000ACXA")
+ original_peak = original.peak
+ filtered = original.transform.cfc(600)
+ # Original should be unchanged — peak should be the same
+ assert original.peak == original_peak
+ # Filtered should have different CFC and lower peak (smoothed)
+ assert filtered.raw.cfc_class == 600
+ assert filtered.peak <= original_peak
+
+
+@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
+class TestSessionRealData:
+ def test_open_real(self):
+ s = Session.open(MME_DATA / "3239")
+ assert s.test_id == "3239"
+ assert len(s) == 133
+
+ def test_full_pipeline(self):
+ s = Session.open(MME_DATA / "3239")
+ # Chain: get channel -> filter -> check
+ ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
+ assert ch.peak > 100 # Significant head acceleration
+
+ # Compute criteria
+ criteria = s.compute_criteria()
+ assert "HIC15" in criteria
+
+ # Evaluate
+ result = s.evaluate("euro_ncap")
+ assert result.stars is not None
+ assert result.stars >= 0
diff --git a/tests/test_template.py b/tests/test_template.py
new file mode 100644
index 0000000..756ba58
--- /dev/null
+++ b/tests/test_template.py
@@ -0,0 +1,132 @@
+"""Tests for template model and session persistence."""
+
+from pathlib import Path
+
+import pytest
+
+from impakt.template.library import TemplateLibrary
+from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
+from impakt.template.session import SessionManager
+
+
+class TestTemplateSpec:
+ def test_yaml_round_trip(self):
+ template = TemplateSpec(
+ name="Test Template",
+ version=2,
+ description="A test template",
+ plots=[
+ PlotDefinition(
+ title="Head Acceleration",
+ channel_patterns=["*HEAD*AC*"],
+ x_cursors=(0.0, 0.1),
+ )
+ ],
+ default_cfc=1000,
+ criteria=["hic15", "nij"],
+ protocol="euro_ncap",
+ )
+
+ yaml_str = template.to_yaml()
+ restored = TemplateSpec.from_yaml(yaml_str)
+
+ assert restored.name == "Test Template"
+ assert restored.version == 2
+ assert restored.default_cfc == 1000
+ assert len(restored.plots) == 1
+ assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
+ assert restored.criteria == ["hic15", "nij"]
+
+ def test_save_and_load(self, tmp_path):
+ template = TemplateSpec(name="Saved Test", version=1)
+ path = tmp_path / "test.yaml"
+ template.save(path)
+ assert path.exists()
+
+ loaded = TemplateSpec.load(path)
+ assert loaded.name == "Saved Test"
+
+
+class TestSessionState:
+ def test_yaml_round_trip(self):
+ state = SessionState(
+ template_name="my_template",
+ template_version=3,
+ test_path="/data/test_001",
+ overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
+ )
+ yaml_str = state.to_yaml()
+ restored = SessionState.from_yaml(yaml_str)
+
+ assert restored.template_name == "my_template"
+ assert restored.template_version == 3
+ assert restored.overrides["cfc"] == "600"
+
+ def test_save_and_load(self, tmp_path):
+ state = SessionState(template_name="test")
+ path = tmp_path / "session.yaml"
+ state.save(path)
+ assert path.exists()
+
+ loaded = SessionState.load(path)
+ assert loaded.template_name == "test"
+
+
+class TestTemplateLibrary:
+ def test_empty_library(self, tmp_path):
+ lib = TemplateLibrary(tmp_path / "templates")
+ assert lib.list() == []
+ assert len(lib) == 0
+
+ def test_save_and_list(self, tmp_path):
+ lib = TemplateLibrary(tmp_path / "templates")
+ template = TemplateSpec(name="My Template")
+ lib.save(template)
+ assert "my_template" in lib.list()
+ assert len(lib) == 1
+
+ def test_get(self, tmp_path):
+ lib = TemplateLibrary(tmp_path / "templates")
+ lib.save(TemplateSpec(name="Getter Test", version=5))
+ loaded = lib.get("getter_test")
+ assert loaded.name == "Getter Test"
+ assert loaded.version == 5
+
+ def test_delete(self, tmp_path):
+ lib = TemplateLibrary(tmp_path / "templates")
+ lib.save(TemplateSpec(name="To Delete"))
+ assert lib.delete("to_delete")
+ assert "to_delete" not in lib.list()
+
+ def test_get_missing_raises(self, tmp_path):
+ lib = TemplateLibrary(tmp_path / "templates")
+ with pytest.raises(FileNotFoundError):
+ lib.get("nonexistent")
+
+
+class TestSessionManager:
+ def test_create_and_save(self, tmp_path):
+ mgr = SessionManager(tmp_path)
+ mgr.state.template_name = "test_tmpl"
+ mgr.save()
+ assert mgr.has_session
+ assert (tmp_path / ".impakt" / "session.yaml").exists()
+
+ def test_load_existing(self, tmp_path):
+ # Save
+ mgr1 = SessionManager(tmp_path)
+ mgr1.state.template_name = "saved_tmpl"
+ mgr1.state.overrides = {"key": "value"}
+ mgr1.save()
+
+ # Load
+ mgr2 = SessionManager(tmp_path)
+ assert mgr2.state.template_name == "saved_tmpl"
+ assert mgr2.state.overrides["key"] == "value"
+
+ def test_clear(self, tmp_path):
+ mgr = SessionManager(tmp_path)
+ mgr.save()
+ assert mgr.has_session
+ mgr.clear()
+ assert not mgr.has_session
diff --git a/tests/test_transform/test_math_resultant.py b/tests/test_transform/test_math_resultant.py
new file mode 100644
index 0000000..c395b7f
--- /dev/null
+++ b/tests/test_transform/test_math_resultant.py
@@ -0,0 +1,75 @@
+"""Tests for math expressions and resultant computation."""
+
+import numpy as np
+import pytest
+
+from impakt.transform.math_expr import math_expr
+from impakt.transform.resultant import resultant_from_channels
+from impakt.transform.resample import trim, resample
+
+
+class TestMathExpr:
+ def test_simple_expression(self, head_accel_x, head_accel_z):
+ result = math_expr(
+ expression="sqrt(a**2 + b**2)",
+ channels={"a": head_accel_x, "b": head_accel_z},
+ name="resultant_xz",
+ unit="g",
+ )
+ assert result.name == "resultant_xz"
+ assert result.unit == "g"
+ assert result.peak > 0
+ assert len(result.data) == len(head_accel_x.data)
+
+ def test_constant_expression(self, head_accel_x):
+ result = math_expr(
+ expression="a * 0 + 42.0",
+ channels={"a": head_accel_x},
+ name="constant",
+ )
+ assert np.allclose(result.data, 42.0)
+
+ def test_invalid_expression(self, head_accel_x):
+ with pytest.raises(ValueError, match="Error evaluating"):
+ math_expr(
+ expression="invalid_func(a)",
+ channels={"a": head_accel_x},
+ )
+
+ def test_forbidden_expression(self, head_accel_x):
+ with pytest.raises(ValueError, match="Forbidden"):
+ math_expr(
+ expression="__import__('os')",
+ channels={"a": head_accel_x},
+ )
+
+
+class TestResultant:
+ def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
+ result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
+ assert result.code.direction == "R"
+ # Resultant >= any component
+ assert result.peak >= head_accel_x.peak
+ assert result.peak >= head_accel_y.peak
+
+ def test_from_two_channels(self, head_accel_x, head_accel_z):
+ result = resultant_from_channels(head_accel_x, head_accel_z)
+ assert result.peak > 0
+
+ def test_single_channel_raises(self):
+ with pytest.raises(ValueError, match="At least one"):
+ resultant_from_channels()
+
+
+class TestTrimResample:
+ def test_trim(self, head_accel_x):
+ trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
+ assert trimmed.time[0] >= 0.0
+ assert trimmed.time[-1] <= 0.05
+ assert len(trimmed.data) < len(head_accel_x.data)
+
+ def test_resample(self, head_accel_x):
+ resampled = resample(head_accel_x, target_rate=5000.0)
+ expected_samples = int(head_accel_x.duration * 5000.0)
+ assert abs(len(resampled.data) - expected_samples) <= 2
+ assert resampled.sample_rate == 5000.0
diff --git a/tests/test_web/test_state.py b/tests/test_web/test_state.py
index 56d24c7..93c950b 100644
--- a/tests/test_web/test_state.py
+++ b/tests/test_web/test_state.py
@@ -19,11 +19,11 @@ class TestAppState:
def test_load_test(self):
state = AppState()
- loaded = state.load_test(FIXTURE_DATA)
+ session = state.load_test(FIXTURE_DATA)
assert not state.is_empty
- assert loaded.test_id == "IMPAKT_SYNTH_001"
- assert loaded.channel_count == 26
- assert state.primary_test is loaded
+ assert session.test_id == "IMPAKT_SYNTH_001"
+ assert len(session) == 26
+ assert state.primary_test is session
def test_load_multiple_tests(self):
state = AppState()
@@ -32,13 +32,13 @@ class TestAppState:
if (MME_DATA / "VW1FGS15").exists():
t2 = state.load_test(MME_DATA / "VW1FGS15")
assert len(state.tests) == 2
- assert state.primary_test is t1 # First loaded is primary
+ assert state.primary_test is t1
assert state.total_channels == 26 + 10
def test_remove_test(self):
state = AppState()
- loaded = state.load_test(FIXTURE_DATA)
- state.remove_test(loaded.test_id)
+ session = state.load_test(FIXTURE_DATA)
+ state.remove_test(session.test_id)
assert state.is_empty
def test_get_channel(self):
@@ -54,38 +54,38 @@ class TestAppState:
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
assert ch is None
- def test_resolve_channel_with_key(self):
+ def test_get_channel_via_session(self):
+ """Channels can be accessed through the Session scripting API."""
state = AppState()
state.load_test(FIXTURE_DATA)
- ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA")
- assert ch is not None
+ session = state.primary_test
+ assert session is not None
+ ch_handle = session.channel("11HEAD0000ACXA")
+ assert ch_handle.name == "11HEAD0000ACXA"
- def test_resolve_channel_primary_default(self):
+ def test_session_fluent_transforms(self):
+ """Fluent transform chaining works through the Session API."""
state = AppState()
state.load_test(FIXTURE_DATA)
- ch = state.resolve_channel("11HEAD0000ACXA")
- assert ch is not None
+ session = state.primary_test
+ ch = session.channel("11HEAD0000ACXA")
+ filtered = ch.transform.cfc(600).transform.y_align()
+ assert filtered.raw.cfc_class == 600
+ assert len(filtered.raw.transform_history) == 2
- def test_resolve_channel_with_cfc(self):
+ def test_session_compute_criteria(self):
+ """Session.compute_criteria() auto-detects channels."""
state = AppState()
state.load_test(FIXTURE_DATA)
- ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600)
- assert ch is not None
- assert ch.cfc_class == 600
-
- def test_flat_channel_list(self):
- state = AppState()
- state.load_test(FIXTURE_DATA)
- items = state.flat_channel_list()
- assert len(items) == 26
- assert all("value" in item and "label" in item for item in items)
+ criteria = state.primary_test.compute_criteria()
+ assert len(criteria) > 0
+ assert "HIC15" in criteria or "Chest Deflection" in criteria
def test_build_channel_tree(self):
state = AppState()
state.load_test(FIXTURE_DATA)
tree = state.build_channel_tree()
assert "IMPAKT_SYNTH_001" in tree
- # Should have hierarchical structure
test_tree = tree["IMPAKT_SYNTH_001"]
assert len(test_tree) > 0
@@ -94,15 +94,23 @@ class TestAppState:
class TestAppStateRealData:
def test_load_real_mme(self):
state = AppState()
- loaded = state.load_test(MME_DATA / "3239")
- assert loaded.test_id == "3239"
- assert loaded.channel_count == 133
+ session = state.load_test(MME_DATA / "3239")
+ assert session.test_id == "3239"
+ assert len(session) == 133
def test_channel_tree_real_data(self):
state = AppState()
state.load_test(MME_DATA / "3239")
tree = state.build_channel_tree()
assert "3239" in tree
- # Should contain "Driver" in some key
test_tree = tree["3239"]
assert any("Driver" in k for k in test_tree)
+
+ def test_session_evaluate_real_data(self):
+ """Full pipeline through Session API on real data."""
+ state = AppState()
+ state.load_test(MME_DATA / "3239")
+ result = state.primary_test.evaluate("euro_ncap")
+ assert result.stars is not None
+ assert result.stars >= 0
+ assert len(result.region_scores) > 0