bookmark - Refactor

This commit is contained in:
2026-04-10 17:28:29 -04:00
parent 48ab8c28b9
commit 68422dd304
29 changed files with 1371 additions and 756 deletions

29
src/impakt/io/csv.py Normal file
View File

@@ -0,0 +1,29 @@
"""CSV data reader (stub).
Placeholder for a future CSV reader plugin.
Install or implement to enable reading CSV crash test data.
"""
from __future__ import annotations
from pathlib import Path
from impakt.channel.model import TestData, TestMetadata
class CSVReader:
"""Reader for CSV time-series data (not yet implemented)."""
@property
def format_name(self) -> str:
return "CSV"
def supports(self, path: Path) -> bool:
# Only claim support for .csv files with crash test structure
return False # Disabled until implemented
def metadata(self, path: Path) -> TestMetadata:
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
def read(self, path: Path) -> TestData:
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")

37
src/impakt/io/tdms.py Normal file
View File

@@ -0,0 +1,37 @@
"""NI TDMS data reader (stub).
Placeholder for a future TDMS reader plugin.
Install nptdms to enable: pip install impakt[tdms]
"""
from __future__ import annotations
from pathlib import Path
from impakt.channel.model import TestData, TestMetadata
class TDMSReader:
"""Reader for NI TDMS files (not yet implemented).
Requires the optional ``nptdms`` dependency::
pip install impakt[tdms]
"""
@property
def format_name(self) -> str:
return "NI TDMS"
def supports(self, path: Path) -> bool:
return False # Disabled until implemented
def metadata(self, path: Path) -> TestMetadata:
raise NotImplementedError(
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
)
def read(self, path: Path) -> TestData:
raise NotImplementedError(
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
)

View File

@@ -2,6 +2,9 @@
Renders PlotSpec objects into interactive Plotly figures with Renders PlotSpec objects into interactive Plotly figures with
support for corridors, dual X-cursors, and export. support for corridors, dual X-cursors, and export.
The PlotEngine is the **single rendering path** — both the scripting
API and the web UI construct PlotSpec objects and delegate here.
""" """
from __future__ import annotations from __future__ import annotations
@@ -30,12 +33,39 @@ DEFAULT_COLORS = [
class PlotEngine: class PlotEngine:
"""Renders PlotSpec into Plotly figures.""" """Renders PlotSpec into Plotly figures.
Supports two modes:
- **Default**: Full interactive plot with hover tooltips and legend.
- **Compact** (``spec.compact=True``): Web UI mode with no legend, no
hover tooltips, tight margins, smaller axis labels. Cursor tracking
is handled by external JS.
"""
def render(self, spec: PlotSpec) -> go.Figure: def render(self, spec: PlotSpec) -> go.Figure:
"""Render a PlotSpec into an interactive Plotly figure.""" """Render a PlotSpec into an interactive Plotly figure."""
fig = go.Figure() fig = go.Figure()
if not spec.channels:
fig.update_layout(
template="plotly_white",
annotations=[
{
"text": "Select channels to plot",
"xref": "paper",
"yref": "paper",
"x": 0.5,
"y": 0.5,
"showarrow": False,
"font": {"size": 14, "color": "#bbb"},
}
],
xaxis={"visible": False},
yaxis={"visible": False},
margin={"l": 40, "r": 20, "t": 30, "b": 40},
)
return fig
# Add corridor fills first (behind data traces) # Add corridor fills first (behind data traces)
for corridor in spec.corridors: for corridor in spec.corridors:
self._add_corridor(fig, corridor) self._add_corridor(fig, corridor)
@@ -54,53 +84,92 @@ class PlotEngine:
color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)] color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
label = style.label or ch_ref.label label = style.label or ch_ref.label
fig.add_trace( trace_kwargs: dict[str, Any] = {
go.Scatter( "x": ch.time,
x=ch.time, "y": ch.data,
y=ch.data, "mode": "lines",
mode="lines", "name": label,
name=label, "line": dict(
line=dict( color=color,
color=color, width=style.line_width,
width=style.line_width, dash=style.line_dash,
dash=style.line_dash, ),
), "opacity": style.opacity,
opacity=style.opacity, }
hovertemplate=f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>",
if not spec.compact:
trace_kwargs["hovertemplate"] = (
f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>"
) )
)
fig.add_trace(go.Scatter(**trace_kwargs))
# Add cursor lines # Add cursor lines
if spec.x_cursors: if spec.x_cursors:
x1, x2 = spec.x_cursors x1, x2 = spec.x_cursors
for x_val, label in [(x1, "x1"), (x2, "x2")]: if spec.compact:
# Color-coded cursor lines for compact mode
fig.add_vline( fig.add_vline(
x=x_val, x=x1,
line_dash="dash", line_dash="dash",
line_color="gray", line_color="rgba(220,53,69,0.6)",
line_width=1, line_width=1,
annotation_text=f"{label}={x_val:.6f}s", annotation_text=f"X1={x1:.4f}s",
annotation_position="top", annotation_font_size=9,
annotation_font_color="rgba(220,53,69,0.8)",
) )
fig.add_vline(
x=x2,
line_dash="dash",
line_color="rgba(13,110,253,0.6)",
line_width=1,
annotation_text=f"X2={x2:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(13,110,253,0.8)",
)
else:
for x_val, lbl in [(x1, "x1"), (x2, "x2")]:
fig.add_vline(
x=x_val,
line_dash="dash",
line_color="gray",
line_width=1,
annotation_text=f"{lbl}={x_val:.6f}s",
annotation_position="top",
)
# Layout # Layout
fig.update_layout( if spec.compact:
title=spec.title, margin = spec.margin or {"l": 45, "r": 8, "t": 4, "b": 28}
xaxis_title=spec.x_label, fig.update_layout(
yaxis_title=spec.y_label, xaxis_title=dict(text=spec.x_label, font=dict(size=10, color="#999")),
showlegend=spec.show_legend, yaxis_title=dict(text=spec.y_label, font=dict(size=10, color="#999")),
height=spec.height, template="plotly_white",
width=spec.width, hovermode=False,
template="plotly_white", showlegend=False,
hovermode="x unified", margin=margin,
legend=dict( )
orientation="h", else:
yanchor="bottom", margin = spec.margin or {"l": 60, "r": 20, "t": 40 if spec.title else 10, "b": 60}
y=-0.3, hovermode = spec.hovermode if spec.hovermode is not None else "x unified"
xanchor="center", fig.update_layout(
x=0.5, title=spec.title,
), xaxis_title=spec.x_label,
) yaxis_title=spec.y_label,
showlegend=spec.show_legend,
height=spec.height,
width=spec.width,
template="plotly_white",
hovermode=hovermode,
legend=dict(
orientation="h",
yanchor="bottom",
y=-0.3,
xanchor="center",
x=0.5,
),
margin=margin,
)
if spec.show_grid: if spec.show_grid:
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)") fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)")
@@ -117,7 +186,6 @@ class PlotEngine:
"""Add a corridor (tolerance band) to the figure.""" """Add a corridor (tolerance band) to the figure."""
style = corridor.style style = corridor.style
# Upper bound
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=corridor.time, x=corridor.time,
@@ -128,8 +196,6 @@ class PlotEngine:
showlegend=False, showlegend=False,
) )
) )
# Lower bound with fill to upper
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=corridor.time, x=corridor.time,
@@ -144,16 +210,7 @@ class PlotEngine:
) )
def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes: def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes:
"""Render to a static image. """Render to a static image."""
Args:
spec: Plot specification.
format: Image format ('png', 'svg', 'pdf', 'jpeg').
scale: Resolution multiplier.
Returns:
Image bytes.
"""
fig = self.render(spec) fig = self.render(spec)
return fig.to_image(format=format, scale=scale) return fig.to_image(format=format, scale=scale)
@@ -168,16 +225,7 @@ def cursor_values(
x1: float, x1: float,
x2: float, x2: float,
) -> CursorValues: ) -> CursorValues:
"""Compute interpolated values at two X-axis positions. """Compute interpolated values at two X-axis positions."""
Args:
spec_or_channels: PlotSpec or list of Channels.
x1: First cursor position (time).
x2: Second cursor position (time).
Returns:
CursorValues with interpolated values for each channel.
"""
channels: list[tuple[str, Channel]] = [] channels: list[tuple[str, Channel]] = []
if isinstance(spec_or_channels, PlotSpec): if isinstance(spec_or_channels, PlotSpec):

View File

@@ -168,3 +168,8 @@ class PlotSpec:
show_grid: bool = True show_grid: bool = True
height: int = 500 height: int = 500
width: int = 900 width: int = 900
# Web UI compact mode: disables hover tooltips, removes legend,
# uses tight margins, and renders axis labels in a smaller font.
compact: bool = False
hovermode: str | bool = "x unified"
margin: dict[str, int] | None = None

View File

@@ -42,8 +42,15 @@ class PluginRegistry:
self._plugins: list[ImpaktPlugin] = [] self._plugins: list[ImpaktPlugin] = []
def register_reader(self, reader: Any) -> None: def register_reader(self, reader: Any) -> None:
"""Register a data reader.""" """Register a data reader.
Forwards to the IO reader registry so plugin readers are
discoverable by Session.open() and the web UI.
"""
self._readers.append(reader) self._readers.append(reader)
from impakt.io.reader import register_reader
register_reader(reader)
logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader)) logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader))
def register_transform(self, name: str, transform_cls: type) -> None: def register_transform(self, name: str, transform_cls: type) -> None:

View File

@@ -9,6 +9,7 @@ Threshold values are versioned — this module supports multiple protocol years.
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from typing import Any from typing import Any
from impakt.criteria.base import CriterionResult from impakt.criteria.base import CriterionResult
@@ -64,9 +65,44 @@ def _points_from_color(color: Color, max_points: float) -> float:
return max_points * color_fractions[color] return max_points * color_fractions[color]
# Threshold sets by year # ---------------------------------------------------------------------------
# Format: {criterion: (green, yellow, orange, brown, red, higher_is_worse, max_points)} # Threshold loading — from YAML files or hardcoded fallback
THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]] = { # ---------------------------------------------------------------------------
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
def _load_yaml_thresholds(
version: str,
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
"""Load thresholds from a YAML file for a given version."""
yaml_path = _THRESHOLDS_DIR / f"euro_ncap_{version}.yaml"
if not yaml_path.exists():
return {}
import yaml
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return {}
result: dict[str, tuple[float, float, float, float, float, bool, float]] = {}
for name, vals in data.items():
if isinstance(vals, dict):
result[name] = (
float(vals["green"]),
float(vals["yellow"]),
float(vals["orange"]),
float(vals["brown"]),
float(vals["red"]),
bool(vals.get("higher_is_worse", True)),
float(vals.get("max_points", 0)),
)
return result
# Hardcoded fallback (used if YAML files are missing)
_THRESHOLDS_2024_FALLBACK: dict[str, tuple[float, float, float, float, float, bool, float]] = {
"HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0), "HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0),
"3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0), "3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0),
"Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0), "Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0),
@@ -77,9 +113,17 @@ THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]
"Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0), "Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0),
} }
THRESHOLDS: dict[str, dict[str, tuple[float, float, float, float, float, bool, float]]] = {
"2024": THRESHOLDS_2024, def _get_thresholds(
} version: str,
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
"""Get thresholds for a version, preferring YAML over hardcoded."""
thresholds = _load_yaml_thresholds(version)
if thresholds:
return thresholds
if version == "2024":
return _THRESHOLDS_2024_FALLBACK
return {}
class EuroNCAP: class EuroNCAP:
@@ -87,11 +131,12 @@ class EuroNCAP:
def __init__(self, version: str = "2024") -> None: def __init__(self, version: str = "2024") -> None:
self._version = version self._version = version
if version not in THRESHOLDS: self._thresholds = _get_thresholds(version)
if not self._thresholds:
raise ValueError( raise ValueError(
f"Unknown Euro NCAP version: {version}. Available: {list(THRESHOLDS.keys())}" f"No thresholds found for Euro NCAP version: {version}. "
f"Check {_THRESHOLDS_DIR} for YAML files."
) )
self._thresholds = THRESHOLDS[version]
@property @property
def protocol_name(self) -> str: def protocol_name(self) -> str:

View File

@@ -6,27 +6,55 @@ The overall rating is determined by the worst sub-rating.
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from typing import Any from typing import Any
from impakt.criteria.base import CriterionResult from impakt.criteria.base import CriterionResult
from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
# Thresholds: (good_limit, acceptable_limit, marginal_limit)
# Values above marginal_limit = Poor def _load_iihs_yaml(version: str) -> dict[str, tuple[float, float, float, bool]]:
# higher_is_worse indicates that higher values are worse """Load IIHS thresholds from YAML."""
IIHS_THRESHOLDS_2024: dict[str, tuple[float, float, float, bool]] = { yaml_path = _THRESHOLDS_DIR / f"iihs_{version}.yaml"
if not yaml_path.exists():
return {}
import yaml
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return {}
result: dict[str, tuple[float, float, float, bool]] = {}
for name, vals in data.items():
if isinstance(vals, dict):
result[name] = (
float(vals["good"]),
float(vals["acceptable"]),
float(vals["marginal"]),
bool(vals.get("higher_is_worse", True)),
)
return result
# Hardcoded fallback
_IIHS_2024_FALLBACK: dict[str, tuple[float, float, float, bool]] = {
"HIC15": (250.0, 500.0, 700.0, True), "HIC15": (250.0, 500.0, 700.0, True),
"Chest Deflection": (38.0, 50.0, 63.0, True), # mm "Chest Deflection": (38.0, 50.0, 63.0, True),
"Femur Load Left": (3.8, 6.2, 10.0, True), # kN "Femur Load Left": (3.8, 6.2, 10.0, True),
"Femur Load Right": (3.8, 6.2, 10.0, True), # kN "Femur Load Right": (3.8, 6.2, 10.0, True),
"Nij": (0.52, 0.78, 1.0, True), "Nij": (0.52, 0.78, 1.0, True),
"Tibia Index": (0.5, 0.8, 1.3, True), "Tibia Index": (0.5, 0.8, 1.3, True),
} }
IIHS_THRESHOLDS: dict[str, dict[str, tuple[float, float, float, bool]]] = {
"2024": IIHS_THRESHOLDS_2024, def _get_iihs_thresholds(version: str) -> dict[str, tuple[float, float, float, bool]]:
} thresholds = _load_iihs_yaml(version)
if thresholds:
return thresholds
if version == "2024":
return _IIHS_2024_FALLBACK
return {}
def _rate_value( def _rate_value(
@@ -70,11 +98,12 @@ class IIHS:
def __init__(self, version: str = "2024") -> None: def __init__(self, version: str = "2024") -> None:
self._version = version self._version = version
if version not in IIHS_THRESHOLDS: self._thresholds = _get_iihs_thresholds(version)
if not self._thresholds:
raise ValueError( raise ValueError(
f"Unknown IIHS version: {version}. Available: {list(IIHS_THRESHOLDS.keys())}" f"No thresholds found for IIHS version: {version}. "
f"Check {_THRESHOLDS_DIR} for YAML files."
) )
self._thresholds = IIHS_THRESHOLDS[version]
@property @property
def protocol_name(self) -> str: def protocol_name(self) -> str:

View File

@@ -0,0 +1,76 @@
# Euro NCAP 2024 Adult Occupant Frontal Impact Thresholds
#
# Each criterion: [green, yellow, orange, brown, red, higher_is_worse, max_points]
# Sliding-scale: values at or below green = full points, at or above red = zero.
HIC15:
green: 500.0
yellow: 620.0
orange: 700.0
brown: 850.0
red: 1000.0
higher_is_worse: true
max_points: 4.0
3ms Clip:
green: 42.0
yellow: 48.0
orange: 54.0
brown: 57.0
red: 60.0
higher_is_worse: true
max_points: 4.0
Chest Deflection:
green: 22.0
yellow: 34.0
orange: 42.0
brown: 50.0
red: 63.0
higher_is_worse: true
max_points: 4.0
Nij:
green: 0.5
yellow: 0.65
orange: 0.8
brown: 0.9
red: 1.0
higher_is_worse: true
max_points: 2.0
Femur Load Left:
green: 3.8
yellow: 5.4
orange: 7.0
brown: 8.5
red: 10.0
higher_is_worse: true
max_points: 2.0
Femur Load Right:
green: 3.8
yellow: 5.4
orange: 7.0
brown: 8.5
red: 10.0
higher_is_worse: true
max_points: 2.0
Tibia Index:
green: 0.4
yellow: 0.7
orange: 1.0
brown: 1.15
red: 1.3
higher_is_worse: true
max_points: 2.0
Viscous Criterion:
green: 0.32
yellow: 0.56
orange: 0.8
brown: 0.9
red: 1.0
higher_is_worse: true
max_points: 2.0

View File

@@ -0,0 +1,40 @@
# IIHS 2024 Crashworthiness Thresholds
#
# Each criterion: [good, acceptable, marginal, higher_is_worse]
# Values above marginal = Poor.
HIC15:
good: 250.0
acceptable: 500.0
marginal: 700.0
higher_is_worse: true
Chest Deflection:
good: 38.0
acceptable: 50.0
marginal: 63.0
higher_is_worse: true
Femur Load Left:
good: 3.8
acceptable: 6.2
marginal: 10.0
higher_is_worse: true
Femur Load Right:
good: 3.8
acceptable: 6.2
marginal: 10.0
higher_is_worse: true
Nij:
good: 0.52
acceptable: 0.78
marginal: 1.0
higher_is_worse: true
Tibia Index:
good: 0.5
acceptable: 0.8
marginal: 1.3
higher_is_worse: true

View File

@@ -2,6 +2,23 @@
Provides the ``Session`` and ``Template`` classes that serve as the Provides the ``Session`` and ``Template`` classes that serve as the
primary entry points for both scripting and the web UI. primary entry points for both scripting and the web UI.
Usage::
from impakt import Session
test = Session.open("tests/mme_data/3239")
ch = test.channel("11HEAD0000H3ACXP")
# Fluent chaining — each transform returns a ChannelHandle
filtered = ch.transform.cfc(1000).transform.y_align()
# Compute all injury criteria (auto-detects channels)
criteria = test.compute_criteria()
# Score against a protocol
result = test.evaluate("euro_ncap")
result.to_pdf("report.pdf")
""" """
from __future__ import annotations from __future__ import annotations
@@ -36,12 +53,19 @@ class Session:
Wraps TestData with session state, template binding, and Wraps TestData with session state, template binding, and
convenience methods for transforms, criteria, and plotting. convenience methods for transforms, criteria, and plotting.
Usage: Usage::
test = Session.open("/path/to/test_001") test = Session.open("/path/to/test_001")
ch = test.channel("11HEAD0000ACXA") ch = test.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(1000) filtered = ch.transform.cfc(1000) # returns ChannelHandle, chainable
criteria = test.compute_criteria() # auto-detect channels
result = test.evaluate("euro_ncap") # score against protocol
result.to_pdf("report.pdf")
""" """
_plugins_discovered = False
def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None: def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None:
self._data = test_data self._data = test_data
self._session_mgr = session_mgr or ( self._session_mgr = session_mgr or (
@@ -49,12 +73,29 @@ class Session:
) )
self._template: TemplateSpec | None = None self._template: TemplateSpec | None = None
@classmethod
def _discover_plugins(cls) -> None:
"""Discover and register plugins. Called once on first Session.open()."""
if cls._plugins_discovered:
return
cls._plugins_discovered = True
try:
from impakt.plugin.registry import discover_all
discover_all()
except Exception as e:
logger.debug("Plugin discovery failed (non-fatal): %s", e)
@classmethod @classmethod
def open(cls, path: str | Path) -> Session: def open(cls, path: str | Path) -> Session:
"""Open a crash test from a path. """Open a crash test from a path.
Auto-detects the format using the reader registry. Auto-detects the format using the reader registry.
Discovers plugins on first call.
""" """
# Discover plugins (idempotent — safe to call multiple times)
cls._discover_plugins()
path = Path(path).resolve() path = Path(path).resolve()
registry = get_registry() registry = get_registry()
test_data = registry.read(path) test_data = registry.read(path)
@@ -87,6 +128,11 @@ class Session:
"""Underlying TestData object.""" """Underlying TestData object."""
return self._data return self._data
@property
def path(self) -> Path | None:
"""Path to the test data directory."""
return self._data.path
@property @property
def channel_names(self) -> list[str]: def channel_names(self) -> list[str]:
return self._data.channel_names return self._data.channel_names
@@ -98,7 +144,10 @@ class Session:
# ----- Channel access ----- # ----- Channel access -----
def channel(self, name: str) -> ChannelHandle: def channel(self, name: str) -> ChannelHandle:
"""Get a channel by name, wrapped with transform convenience methods.""" """Get a channel by name, wrapped with transform convenience methods.
Returns a ChannelHandle with a fluent ``.transform`` interface.
"""
ch = self._data.get(name) ch = self._data.get(name)
return ChannelHandle(ch) return ChannelHandle(ch)
@@ -118,6 +167,47 @@ class Session:
"""Hierarchical channel tree for UI display.""" """Hierarchical channel tree for UI display."""
return self._data.channel_tree() return self._data.channel_tree()
# ----- Criteria & Protocol -----
def compute_criteria(self) -> dict[str, CriterionResult]:
"""Auto-detect channels and compute all applicable injury criteria.
Uses ISO channel naming to find the right channels for each criterion.
Returns a dict of criterion name -> CriterionResult.
"""
from impakt.web.components.criteria import auto_compute_criteria
return auto_compute_criteria(self._data)
def evaluate(self, protocol: str = "euro_ncap", version: str = "") -> ProtocolResult:
"""Compute criteria and score against a rating protocol.
Args:
protocol: Protocol name ('euro_ncap', 'us_ncap', 'iihs').
version: Protocol version (optional, uses latest if empty).
Returns:
ProtocolResult with stars/rating and per-region scores.
"""
criteria = self.compute_criteria()
if protocol == "euro_ncap":
from impakt.protocol.euro_ncap import evaluate
return evaluate(criteria, version=version or "2024")
elif protocol == "us_ncap":
from impakt.protocol.us_ncap import evaluate
return evaluate(criteria, version=version or "2023")
elif protocol == "iihs":
from impakt.protocol.iihs import evaluate
return evaluate(criteria, version=version or "2024")
else:
raise ValueError(
f"Unknown protocol: {protocol}. Use 'euro_ncap', 'us_ncap', or 'iihs'."
)
# ----- Template ----- # ----- Template -----
def apply_template(self, name_or_spec: str | TemplateSpec) -> None: def apply_template(self, name_or_spec: str | TemplateSpec) -> None:
@@ -188,15 +278,16 @@ class Session:
class ChannelHandle: class ChannelHandle:
"""Wrapper around a Channel providing fluent transform access. """Wrapper around a Channel providing fluent transform access.
Example: Each transform method on ``.transform`` returns a new ``ChannelHandle``,
enabling chaining::
ch = session.channel("11HEAD0000ACXA") ch = session.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(1000) result = ch.transform.cfc(1000).transform.y_align().transform.trim(t_end=0.1)
aligned = ch.transform.x_align(method="threshold", threshold_value=5.0)
""" """
def __init__(self, channel: Channel) -> None: def __init__(self, channel: Channel) -> None:
self._channel = channel self._channel = channel
self.transform = TransformProxy(channel) self.transform = TransformProxy(self)
@property @property
def raw(self) -> Channel: def raw(self) -> Channel:
@@ -207,6 +298,10 @@ class ChannelHandle:
def name(self) -> str: def name(self) -> str:
return self._channel.name return self._channel.name
@property
def code(self):
return self._channel.code
@property @property
def data(self) -> np.ndarray: def data(self) -> np.ndarray:
return self._channel.data return self._channel.data
@@ -219,6 +314,14 @@ class ChannelHandle:
def unit(self) -> str: def unit(self) -> str:
return self._channel.unit return self._channel.unit
@property
def peak(self) -> float:
return self._channel.peak
@property
def sample_rate(self) -> float:
return self._channel.sample_rate
def value_at(self, t: float) -> float: def value_at(self, t: float) -> float:
return self._channel.value_at(t) return self._channel.value_at(t)
@@ -238,44 +341,51 @@ class ChannelHandle:
class TransformProxy: class TransformProxy:
"""Fluent transform interface for a channel. """Fluent transform interface for a channel.
Each method returns a new Channel (non-destructive). Each method returns a new ``ChannelHandle`` wrapping the transformed
channel, enabling chaining.
""" """
def __init__(self, channel: Channel) -> None: def __init__(self, handle: ChannelHandle) -> None:
self._channel = channel self._handle = handle
def cfc(self, cfc_class: int) -> Channel: @property
def _channel(self) -> Channel:
return self._handle._channel
def cfc(self, cfc_class: int) -> ChannelHandle:
"""Apply CFC filter.""" """Apply CFC filter."""
from impakt.transform.cfc import CFCFilter from impakt.transform.cfc import CFCFilter
return CFCFilter(cfc_class=cfc_class).apply(self._channel) return ChannelHandle(CFCFilter(cfc_class=cfc_class).apply(self._channel))
def x_align( def x_align(
self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any
) -> Channel: ) -> ChannelHandle:
"""Apply time-zero alignment.""" """Apply time-zero alignment."""
from impakt.transform.align import XAlign from impakt.transform.align import XAlign
return XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel) return ChannelHandle(
XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
)
def y_align(self, window: tuple[float, float] | None = None) -> Channel: def y_align(self, window: tuple[float, float] | None = None) -> ChannelHandle:
"""Apply Y-axis zero correction.""" """Apply Y-axis zero correction."""
from impakt.transform.align import YAlign from impakt.transform.align import YAlign
start, end = window if window else (None, None) start, end = window if window else (None, None)
return YAlign(window_start=start, window_end=end).apply(self._channel) return ChannelHandle(YAlign(window_start=start, window_end=end).apply(self._channel))
def trim(self, t_start: float | None = None, t_end: float | None = None) -> Channel: def trim(self, t_start: float | None = None, t_end: float | None = None) -> ChannelHandle:
"""Trim to a time range.""" """Trim to a time range."""
from impakt.transform.resample import Trim from impakt.transform.resample import Trim
return Trim(t_start=t_start, t_end=t_end).apply(self._channel) return ChannelHandle(Trim(t_start=t_start, t_end=t_end).apply(self._channel))
def resample(self, target_rate: float) -> Channel: def resample(self, target_rate: float) -> ChannelHandle:
"""Resample to a new rate.""" """Resample to a new rate."""
from impakt.transform.resample import Resample from impakt.transform.resample import Resample
return Resample(target_rate=target_rate).apply(self._channel) return ChannelHandle(Resample(target_rate=target_rate).apply(self._channel))
class Template: class Template:

View File

@@ -1,7 +1,7 @@
"""Dash web application factory. """Dash web application factory.
Creates the Dash app with all layout components and callbacks registered. Creates the Dash app with all layout components and callbacks registered.
The AppState holds server-side data; Dash stores hold lightweight UI state. The AppState holds Session objects server-side; Dash stores hold lightweight UI state.
""" """
from __future__ import annotations from __future__ import annotations
@@ -13,14 +13,12 @@ import dash
import dash_bootstrap_components as dbc import dash_bootstrap_components as dbc
from impakt.channel.model import TestData from impakt.channel.model import TestData
from impakt.script.api import Session
from impakt.template.library import TemplateLibrary from impakt.template.library import TemplateLibrary
from impakt.web.callbacks import register_callbacks from impakt.web.callbacks import register_callbacks
from impakt.web.layout import build_layout from impakt.web.layout import build_layout
from impakt.web.state import AppState from impakt.web.state import AppState
if TYPE_CHECKING:
from impakt.script.api import Session
def create_app( def create_app(
session_or_data: Session | TestData | None = None, session_or_data: Session | TestData | None = None,
@@ -37,21 +35,15 @@ def create_app(
Returns: Returns:
Configured Dash app ready to run. Configured Dash app ready to run.
""" """
# Build or use provided AppState
if app_state is None: if app_state is None:
app_state = AppState() app_state = AppState()
if session_or_data is not None: if session_or_data is not None:
if hasattr(session_or_data, "data"): if isinstance(session_or_data, Session):
test_data: TestData = session_or_data.data # type: ignore[union-attr] app_state.add_session(session_or_data)
else: elif isinstance(session_or_data, TestData):
test_data = session_or_data # type: ignore[assignment] # Wrap raw TestData in a Session
session = Session(session_or_data)
# Create a LoadedTest from TestData directly app_state.add_session(session)
from impakt.web.state import LoadedTest
loaded = LoadedTest(test_data)
app_state._tests[test_data.test_id] = loaded
app_state._test_order.append(test_data.test_id)
# Discover templates # Discover templates
if template_names is None: if template_names is None:
@@ -64,14 +56,13 @@ def create_app(
# Title # Title
title = "Impakt" title = "Impakt"
if app_state.primary_test: if app_state.primary_test:
title = f"Impakt {app_state.primary_test.test_id}" title = f"Impakt \u2014 {app_state.primary_test.test_id}"
app = dash.Dash( app = dash.Dash(
__name__, __name__,
external_stylesheets=[dbc.themes.FLATLY], external_stylesheets=[dbc.themes.FLATLY],
title=title, title=title,
suppress_callback_exceptions=True, suppress_callback_exceptions=True,
# Prevent browser from caching old layouts
serve_locally=True, serve_locally=True,
meta_tags=[ meta_tags=[
{"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"}, {"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"},
@@ -92,16 +83,9 @@ def serve(
port: int = 8050, port: int = 8050,
debug: bool = False, debug: bool = False,
) -> None: ) -> None:
"""Convenience function to create and run the web UI. """Convenience function to create and run the web UI."""
Args:
session_or_data: Session or TestData to visualize, or None for empty app.
template: Template name to pre-apply.
port: Server port.
debug: Enable Dash debug mode.
"""
if template and session_or_data and hasattr(session_or_data, "apply_template"): if template and session_or_data and hasattr(session_or_data, "apply_template"):
session_or_data.apply_template(template) # type: ignore[union-attr] session_or_data.apply_template(template)
app = create_app(session_or_data) app = create_app(session_or_data)
print(f"Impakt running at http://localhost:{port}") print(f"Impakt running at http://localhost:{port}")

View File

@@ -1,9 +1,7 @@
"""Criteria computation callbacks. """Criteria computation callbacks.
Handles: Uses Session.compute_criteria() and Session.evaluate() from the
- Compute All button -> runs auto_compute_criteria scripting API — the same path used by programmatic scripts.
- Protocol scoring
- Results display
""" """
from __future__ import annotations from __future__ import annotations
@@ -14,11 +12,7 @@ import dash
from dash import Input, Output, State, html from dash import Input, Output, State, html
from dash.exceptions import PreventUpdate from dash.exceptions import PreventUpdate
from impakt.web.components.criteria import ( from impakt.web.components.criteria import build_criteria_results_display
auto_compute_criteria,
build_criteria_results_display,
score_protocol,
)
from impakt.web.state import AppState from impakt.web.state import AppState
@@ -41,8 +35,8 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
if primary is None: if primary is None:
return html.Div("No test loaded.", className="text-danger small") return html.Div("No test loaded.", className="text-danger small")
# Compute criteria # Use Session.compute_criteria() — single code path for script + web
criteria = auto_compute_criteria(primary.data) criteria = primary.compute_criteria()
if not criteria: if not criteria:
return html.Div( return html.Div(
@@ -51,7 +45,10 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
className="text-warning small", className="text-warning small",
) )
# Score against protocol # Score against protocol — use Session.evaluate()
protocol_result = score_protocol(criteria, protocol) try:
protocol_result = primary.evaluate(protocol)
except Exception:
protocol_result = None
return build_criteria_results_display(criteria, protocol_result) return build_criteria_results_display(criteria, protocol_result)

View File

@@ -1,10 +1,11 @@
"""Plot rendering callbacks. """Plot rendering callbacks.
Handles: Builds a PlotSpec from the UI state and delegates to PlotEngine.render().
- Updating plot figures when channels/transforms change This is the single rendering path — the same PlotEngine used by the
- Cursor line rendering (X1/X2 vertical lines) scripting API renders the web UI plots.
- Hover data is NOT shown as a Plotly tooltip — instead the cursor grid
picks it up via the hoverData callback Transform application uses TransformChain, making the pipeline
serializable and reproducible.
""" """
from __future__ import annotations from __future__ import annotations
@@ -15,16 +16,55 @@ from typing import Any
import dash import dash
import numpy as np import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
from dash import ALL, Input, Output, State, ctx from dash import Input, Output, State
from dash.exceptions import PreventUpdate from dash.exceptions import PreventUpdate
from impakt.channel.model import Channel from impakt.channel.model import Channel
from impakt.plot.engine import DEFAULT_COLORS from impakt.plot.engine import DEFAULT_COLORS, PlotEngine
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
from impakt.transform.align import XAlign, YAlign from impakt.transform.align import XAlign, YAlign
from impakt.transform.base import TransformChain
from impakt.transform.cfc import CFCFilter from impakt.transform.cfc import CFCFilter
from impakt.transform.resultant import resultant_from_channels from impakt.transform.resultant import resultant_from_channels
from impakt.web.state import AppState from impakt.web.state import AppState
# Module-level engine instance (stateless, safe to reuse)
_engine = PlotEngine()
def _build_transform_chain(
cfc_value: str,
y_align: bool,
x_align_method: str,
x_align_value: float | None,
per_channel_cfc: str | None = None,
) -> TransformChain:
"""Build a TransformChain from the current UI state.
Per-channel CFC override takes precedence over the global CFC setting.
"""
chain = TransformChain()
# CFC filter
effective_cfc = per_channel_cfc if per_channel_cfc else cfc_value
if effective_cfc and effective_cfc != "none":
try:
chain = chain.append(CFCFilter(cfc_class=int(effective_cfc)))
except (ValueError, Exception):
pass
# Y-align
if y_align:
chain = chain.append(YAlign())
# X-align
if x_align_method == "manual" and x_align_value is not None:
chain = chain.append(XAlign(method="manual", reference_time=x_align_value))
elif x_align_method == "threshold" and x_align_value is not None:
chain = chain.append(XAlign(method="threshold", threshold_value=x_align_value))
return chain
def _resolve_channels( def _resolve_channels(
selected_keys: list[str], selected_keys: list[str],
@@ -35,12 +75,15 @@ def _resolve_channels(
x_align_value: float | None, x_align_value: float | None,
show_resultant: bool, show_resultant: bool,
) -> list[tuple[str, Channel]]: ) -> list[tuple[str, Channel]]:
"""Resolve selected channel keys to Channel objects with transforms applied. """Resolve selected channel keys to transformed Channel objects.
Per-channel overrides (from app_state.channel_overrides) take precedence Uses TransformChain for each channel. Per-channel overrides from
over the global CFC setting. app_state.channel_overrides take precedence over the global CFC.
Returns list of (label, transformed_channel) tuples.
""" """
channels: list[tuple[str, Channel]] = [] channels: list[tuple[str, Channel]] = []
multi_test = len(app_state.tests) > 1
for key in selected_keys: for key in selected_keys:
if "::" in key: if "::" in key:
@@ -55,29 +98,22 @@ def _resolve_channels(
if ch is None: if ch is None:
continue continue
# Determine CFC: per-channel override takes precedence over global # Build per-channel transform chain
override = app_state.channel_overrides.get(key, {}) override = app_state.channel_overrides.get(key, {})
ch_cfc = override.get("cfc", "") per_ch_cfc = override.get("cfc", "")
effective_cfc = ch_cfc if ch_cfc else cfc_value chain = _build_transform_chain(
cfc_value,
y_align,
x_align_method,
x_align_value,
per_channel_cfc=per_ch_cfc,
)
if effective_cfc and effective_cfc != "none": # Apply the chain
try: if len(chain) > 0:
ch = CFCFilter(cfc_class=int(effective_cfc)).apply(ch) ch = chain.apply(ch)
except (ValueError, Exception):
pass
# Apply Y-align
if y_align:
ch = YAlign().apply(ch)
# Apply X-align
if x_align_method == "manual" and x_align_value is not None:
ch = XAlign(method="manual", reference_time=x_align_value).apply(ch)
elif x_align_method == "threshold" and x_align_value is not None:
ch = XAlign(method="threshold", threshold_value=x_align_value).apply(ch)
# Build label # Build label
multi_test = len(app_state.tests) > 1
label = ch.code.short_label if ch.code.is_valid else ch.name label = ch.code.short_label if ch.code.is_valid else ch.name
if multi_test: if multi_test:
label = f"[{test_id}] {label}" label = f"[{test_id}] {label}"
@@ -99,7 +135,7 @@ def _resolve_channels(
res_label = ( res_label = (
f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant" f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant"
) )
if len(app_state.tests) > 1: if multi_test:
res_label = f"[{comps[0].source_test_id}] {res_label}" res_label = f"[{comps[0].source_test_id}] {res_label}"
channels.append((res_label, res)) channels.append((res_label, res))
except Exception: except Exception:
@@ -108,123 +144,61 @@ def _resolve_channels(
return channels return channels
def _build_figure( def _build_plot_spec(
channels: list[tuple[str, Channel]], channels: list[tuple[str, Channel]],
cursor_x1: float | None, cursor_x1: float | None,
cursor_x2: float | None, cursor_x2: float | None,
cfc_value: str,
corridors: list[dict] | None = None, corridors: list[dict] | None = None,
) -> go.Figure: ) -> PlotSpec:
"""Build a Plotly figure from resolved channels.""" """Build a PlotSpec from resolved channels and UI state.
fig = go.Figure()
if not channels: Channels are already transformed — they are wrapped in ChannelRef
fig.update_layout( objects with no additional transform chain (transforms were applied
template="plotly_white", during resolution).
annotations=[ """
{ # Build ChannelRef objects
"text": "Select channels from the grid to plot", refs: list[ChannelRef] = []
"xref": "paper",
"yref": "paper",
"x": 0.5,
"y": 0.5,
"showarrow": False,
"font": {"size": 14, "color": "#bbb"},
}
],
xaxis={"visible": False},
yaxis={"visible": False},
margin={"l": 40, "r": 20, "t": 30, "b": 40},
)
return fig
# Add traces — hovermode is disabled at the layout level; cursor tracking
# is handled by our own JS mousemove handler (cursor_tracker.js).
for i, (label, ch) in enumerate(channels): for i, (label, ch) in enumerate(channels):
color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)] color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
fig.add_trace( refs.append(
go.Scatter( ChannelRef(
x=ch.time.tolist(), channel=ch,
y=ch.data.tolist(), style=PlotStyle(label=label, color=color),
mode="lines",
name=label,
line=dict(color=color, width=1.5),
) )
) )
# Add corridor fills # Build Corridor objects from raw dicts
corridor_objs: list[Corridor] = []
if corridors: if corridors:
for corridor in corridors: for c in corridors:
if not corridor.get("visible", True): if not c.get("visible", True):
continue continue
c_time = corridor["time"] corridor_objs.append(
c_upper = corridor["upper"] Corridor(
c_lower = corridor["lower"] name=c.get("name", "Corridor"),
c_name = corridor.get("name", "Corridor") time=np.array(c["time"]),
lower=np.array(c["lower"]),
# Upper bound upper=np.array(c["upper"]),
fig.add_trace( style=CorridorStyle(),
go.Scatter(
x=c_time,
y=c_upper,
mode="lines",
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
showlegend=False,
name=f"{c_name} upper",
)
)
# Lower bound with fill to upper
fig.add_trace(
go.Scatter(
x=c_time,
y=c_lower,
mode="lines",
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
fill="tonexty",
fillcolor="rgba(100,100,255,0.1)",
showlegend=True,
name=c_name,
) )
) )
# Add X1/X2 cursor lines # Cursor positions
if cursor_x1 is not None: x_cursors = None
fig.add_vline( if cursor_x1 is not None and cursor_x2 is not None:
x=cursor_x1, x_cursors = (cursor_x1, cursor_x2)
line_dash="dash",
line_color="rgba(220,53,69,0.6)",
line_width=1,
annotation_text=f"X1={cursor_x1:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(220,53,69,0.8)",
)
if cursor_x2 is not None:
fig.add_vline(
x=cursor_x2,
line_dash="dash",
line_color="rgba(13,110,253,0.6)",
line_width=1,
annotation_text=f"X2={cursor_x2:.4f}s",
annotation_font_size=9,
annotation_font_color="rgba(13,110,253,0.8)",
)
# Layout — hovermode is disabled; cursor tracking is handled entirely # Y-axis label from first channel
# by our JS (cursor_tracker.js) which reads pixel positions from mousemove
# events and converts to data coordinates via Plotly's axis internals.
y_label = channels[0][1].unit if channels else "" y_label = channels[0][1].unit if channels else ""
fig.update_layout( return PlotSpec(
xaxis_title=dict(text="Time (s)", font=dict(size=10, color="#999")), channels=refs,
yaxis_title=dict(text=y_label, font=dict(size=10, color="#999")), corridors=corridor_objs,
template="plotly_white", x_cursors=x_cursors,
hovermode=False, y_label=y_label,
showlegend=False, compact=True, # Web UI always uses compact mode
margin=dict(l=45, r=8, t=4, b=28),
) )
return fig
def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None: def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
"""Register all plot-related callbacks.""" """Register all plot-related callbacks."""
@@ -271,10 +245,5 @@ def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
show_resultant, show_resultant,
) )
return _build_figure( spec = _build_plot_spec(channels, cursor_x1, cursor_x2, corridors_data)
channels, return _engine.render(spec)
cursor_x1,
cursor_x2,
cfc_value,
corridors=corridors_data,
)

View File

@@ -1,221 +0,0 @@
"""Collapsible channel tree component.
Renders channels in a hierarchical accordion:
Test Object > Body Region > Measurement Type > [channels]
Supports search filtering, select-all per group, and channel preview
(peak, unit, sample rate) on hover/expand.
"""
from __future__ import annotations
from typing import Any
import dash_bootstrap_components as dbc
from dash import dcc, html
from impakt.web.state import AppState
def _make_channel_item(ch_info: dict[str, str], show_test_prefix: bool = False) -> html.Div:
"""Build a single channel item with checkbox and preview info."""
label = ch_info["label"]
key = ch_info["key"]
peak = ch_info.get("peak", "")
unit = ch_info.get("unit", "")
preview = f" ({peak} {unit})" if peak else ""
return html.Div(
[
dbc.Checkbox(
id={"type": "channel-check", "index": key},
label=html.Span(
[
html.Span(label, style={"fontSize": "12px"}),
html.Span(
preview,
style={"fontSize": "10px", "color": "#999", "marginLeft": "4px"},
),
]
),
value=False,
style={"marginBottom": "1px"},
),
],
className="channel-item",
)
def _make_group_header(
group_label: str,
group_id: str,
channel_keys: list[str],
) -> html.Div:
"""Build a measurement group header with select-all button."""
return html.Div(
[
html.Span(group_label, style={"fontSize": "12px", "fontWeight": "500"}),
dbc.Button(
"All",
id={"type": "select-group-btn", "index": group_id},
color="link",
size="sm",
style={"fontSize": "10px", "padding": "0 4px", "textDecoration": "none"},
),
# Hidden store for this group's channel keys
dcc.Store(
id={"type": "group-channels", "index": group_id},
data=channel_keys,
),
],
className="d-flex align-items-center justify-content-between",
)
def build_channel_tree(app_state: AppState) -> dbc.Card:
"""Build the full channel tree panel with search and accordion."""
if app_state.is_empty:
return dbc.Card(
[
dbc.CardHeader("Channels", className="fw-bold py-2"),
dbc.CardBody(
html.Div("No test loaded", className="text-muted small text-center py-3"),
),
]
)
multi_test = len(app_state.tests) > 1
full_tree = app_state.build_channel_tree()
# Build accordion items
accordion_items = []
group_counter = 0
for test_id, test_tree in full_tree.items():
test_label = ""
loaded = app_state.get_test(test_id)
if loaded and multi_test:
test_label = f"[{test_id}] "
for obj_label, locations in test_tree.items():
# Top-level accordion: test object (Driver, Vehicle Structure, etc.)
obj_content = []
for loc_label, measurements in locations.items():
for meas_label, channels in measurements.items():
group_id = f"grp_{group_counter}"
group_counter += 1
channel_keys = [ch["key"] for ch in channels]
header_label = (
f"{loc_label}{meas_label}" if loc_label != meas_label else meas_label
)
obj_content.append(_make_group_header(header_label, group_id, channel_keys))
for ch in channels:
obj_content.append(_make_channel_item(ch, show_test_prefix=multi_test))
obj_content.append(html.Hr(style={"margin": "4px 0"}))
if obj_content:
# Remove trailing hr
if obj_content and isinstance(obj_content[-1], html.Hr):
obj_content = obj_content[:-1]
display_label = f"{test_label}{obj_label}" if test_label else obj_label
accordion_items.append(
dbc.AccordionItem(
html.Div(obj_content),
title=display_label,
style={"fontSize": "12px"},
)
)
return dbc.Card(
[
dbc.CardHeader(
[
html.Span("Channels", className="fw-bold"),
html.Span(
f" ({app_state.total_channels})",
style={"fontSize": "11px", "color": "#999"},
),
],
className="py-2",
),
dbc.CardBody(
[
# Search
dbc.Input(
id="channel-search",
placeholder="Search channels...",
type="text",
size="sm",
className="mb-2",
debounce=True,
),
# Selected channels display
html.Div(id="selected-channels-badges", className="mb-2"),
# Tree
html.Div(
dbc.Accordion(
accordion_items,
flush=True,
always_open=True,
start_collapsed=True,
id="channel-accordion",
),
style={"maxHeight": "450px", "overflowY": "auto"},
id="channel-tree-container",
),
],
className="py-2",
),
]
)
def build_selected_channels_badges(selected_keys: list[str], app_state: AppState) -> list:
"""Build badge pills showing currently selected channels."""
if not selected_keys:
return [
html.Span("No channels selected", className="text-muted", style={"fontSize": "11px"})
]
badges = []
for key in selected_keys[:20]: # Limit display to 20
# Extract label
if "::" in key:
test_id, ch_name = key.split("::", 1)
ch = app_state.get_channel(test_id, ch_name)
if ch and ch.code.is_valid:
label = ch.code.short_label
else:
label = ch_name
if len(app_state.tests) > 1:
label = f"{test_id}: {label}"
else:
label = key
badges.append(
dbc.Badge(
label,
id={"type": "channel-badge", "index": key},
color="primary",
pill=True,
className="me-1 mb-1",
style={"fontSize": "10px", "cursor": "pointer"},
)
)
if len(selected_keys) > 20:
badges.append(
html.Span(
f"+{len(selected_keys) - 20} more",
className="text-muted",
style={"fontSize": "10px"},
)
)
return badges

View File

@@ -129,7 +129,7 @@ def build_test_info_panel(app_state: AppState) -> dbc.Card:
f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "", f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "",
style={"fontSize": "12px"}, style={"fontSize": "12px"},
), ),
html.Td(str(loaded.channel_count), style={"fontSize": "12px"}), html.Td(str(len(loaded)), style={"fontSize": "12px"}),
html.Td( html.Td(
dbc.Button( dbc.Button(
"x", "x",

View File

@@ -1,14 +1,12 @@
"""Application state management. """Application state management.
AppState is the central data store for the web UI. It holds all loaded tests, AppState is the central data store for the web UI. It holds Session
manages channel transforms, and provides the data that callbacks need to objects (from the scripting API) and delegates to them for channel
render plots, compute criteria, and generate reports. access, transforms, criteria, and protocol scoring.
AppState is NOT a Dash Store — it lives server-side in Python memory. The Dash AppState is NOT a Dash Store — it lives server-side in Python memory.
stores hold lightweight references (test IDs, selected channels, plot config) Dash stores hold lightweight references (test IDs, selected channels)
that callbacks use to look up data from AppState. that callbacks use to look up data from AppState.
This design keeps large numpy arrays out of the browser and avoids serialization.
""" """
from __future__ import annotations from __future__ import annotations
@@ -17,220 +15,95 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from impakt.channel.model import Channel, ChannelGroup, TestData from impakt.channel.model import Channel, TestData
from impakt.io.reader import get_registry from impakt.script.api import Session
from impakt.template.library import TemplateLibrary from impakt.template.library import TemplateLibrary
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec from impakt.template.model import PlotDefinition, TemplateSpec
from impakt.template.session import SessionManager
from impakt.transform.align import YAlign from impakt.transform.align import YAlign
from impakt.transform.cfc import CFCFilter from impakt.transform.cfc import CFCFilter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoadedTest:
"""A single loaded test with its metadata and channel access."""
def __init__(self, test_data: TestData, color_offset: int = 0) -> None:
self.data = test_data
self.color_offset = color_offset
@property
def test_id(self) -> str:
return self.data.test_id
@property
def label(self) -> str:
"""Short label for UI display."""
meta = self.data.metadata
parts = [meta.test_number]
if meta.vehicle.make:
parts.append(f"{meta.vehicle.make} {meta.vehicle.model}".strip())
return "".join(parts) if len(parts) > 1 else parts[0]
@property
def channel_count(self) -> int:
return len(self.data)
def __len__(self) -> int:
return len(self.data)
class AppState: class AppState:
"""Server-side application state. """Server-side application state.
Manages multiple loaded tests and provides channel resolution Manages multiple loaded test sessions and provides channel resolution
for the UI callbacks. for the UI callbacks. All test data access goes through Session objects.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._tests: dict[str, LoadedTest] = {} self._sessions: dict[str, Session] = {}
self._test_order: list[str] = [] self._session_order: list[str] = []
self._color_counter: int = 0
self._template_library = TemplateLibrary() self._template_library = TemplateLibrary()
self._active_template: TemplateSpec | None = None self._active_template: TemplateSpec | None = None
self._session_managers: dict[str, SessionManager] = {} # Per-channel transform overrides: {channel_key: {"cfc": "600"}}
# Per-channel transform overrides: {channel_key: {"cfc": "600", "y_align": True}}
self.channel_overrides: dict[str, dict[str, Any]] = {} self.channel_overrides: dict[str, dict[str, Any]] = {}
# Active corridors: list of {name, time, lower, upper, visible} # Active corridors
self.corridors: list[dict[str, Any]] = [] self.corridors: list[dict[str, Any]] = []
def load_test(self, path: str | Path) -> LoadedTest: def load_test(self, path: str | Path) -> Session:
"""Load a test from a path and add it to the state. """Load a test from a path and add it to the state.
If a test with the same ID is already loaded, replaces it. Returns the Session object.
Returns:
The loaded test.
Raises:
ValueError: If the path is not a valid test directory.
""" """
path = Path(path).resolve() session = Session.open(path)
registry = get_registry() test_id = session.test_id
test_data = registry.read(path)
loaded = LoadedTest(test_data, color_offset=self._color_counter) self._sessions[test_id] = session
self._color_counter += len(test_data) if test_id not in self._session_order:
self._session_order.append(test_id)
test_id = test_data.test_id logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(session))
if test_id in self._tests: return session
self._tests[test_id] = loaded
else:
self._tests[test_id] = loaded
self._test_order.append(test_id)
logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(test_data)) def add_session(self, session: Session) -> None:
return loaded """Add an already-created Session to the state."""
test_id = session.test_id
self._sessions[test_id] = session
if test_id not in self._session_order:
self._session_order.append(test_id)
def remove_test(self, test_id: str) -> None: def remove_test(self, test_id: str) -> None:
"""Remove a loaded test.""" """Remove a loaded test."""
if test_id in self._tests: self._sessions.pop(test_id, None)
del self._tests[test_id] if test_id in self._session_order:
self._test_order.remove(test_id) self._session_order.remove(test_id)
@property @property
def tests(self) -> list[LoadedTest]: def sessions(self) -> list[Session]:
"""All loaded tests in load order.""" """All loaded sessions in load order."""
return [self._tests[tid] for tid in self._test_order if tid in self._tests] return [self._sessions[tid] for tid in self._session_order if tid in self._sessions]
@property
def tests(self) -> list[Session]:
"""Alias for sessions (backward compat for components)."""
return self.sessions
@property @property
def test_ids(self) -> list[str]: def test_ids(self) -> list[str]:
return list(self._test_order) return list(self._session_order)
@property @property
def primary_test(self) -> LoadedTest | None: def primary_test(self) -> Session | None:
"""The first loaded test (primary for criteria, etc.).""" """The first loaded session (primary for criteria, etc.)."""
if self._test_order: if self._session_order:
return self._tests.get(self._test_order[0]) return self._sessions.get(self._session_order[0])
return None return None
def get_test(self, test_id: str) -> LoadedTest | None: def get_test(self, test_id: str) -> Session | None:
return self._tests.get(test_id) return self._sessions.get(test_id)
def get_channel(self, test_id: str, channel_name: str) -> Channel | None: def get_channel(self, test_id: str, channel_name: str) -> Channel | None:
"""Resolve a channel from a test ID and channel name.""" """Resolve a channel from a test ID and channel name."""
test = self._tests.get(test_id) session = self._sessions.get(test_id)
if test is None: if session is None:
return None return None
try: try:
return test.data.get(channel_name) return session.data.get(channel_name)
except KeyError: except KeyError:
return None return None
def resolve_channel(
self,
channel_key: str,
cfc_class: int | None = None,
y_align: bool = False,
) -> Channel | None:
"""Resolve a channel key (test_id::channel_name) and apply transforms.
Channel keys use the format 'test_id::channel_name'. If no '::'
separator, assumes the primary test.
"""
if "::" in channel_key:
test_id, ch_name = channel_key.split("::", 1)
else:
if self.primary_test is None:
return None
test_id = self.primary_test.test_id
ch_name = channel_key
ch = self.get_channel(test_id, ch_name)
if ch is None:
return None
# Apply transforms
if cfc_class is not None:
try:
ch = CFCFilter(cfc_class=cfc_class).apply(ch)
except (ValueError, Exception):
pass
if y_align:
ch = YAlign().apply(ch)
return ch
def build_channel_tree(
self,
) -> dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]]:
"""Build a hierarchical channel tree across all loaded tests.
Returns:
{test_id: {object_label: {location_label: {measurement_label: [{name, label, key}]}}}}
"""
result: dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]] = {}
for loaded in self.tests:
tree = loaded.data.channel_tree()
test_tree: dict[str, dict[str, dict[str, list[dict[str, str]]]]] = {}
for obj, locations in sorted(tree.items()):
obj_tree: dict[str, dict[str, list[dict[str, str]]]] = {}
for loc, measurements in sorted(locations.items()):
loc_tree: dict[str, list[dict[str, str]]] = {}
for meas, channels in sorted(measurements.items()):
ch_list = []
for ch in channels:
label = ch.code.short_label if ch.code.is_valid else ch.name
ch_list.append(
{
"name": ch.name,
"label": label,
"key": f"{loaded.test_id}::{ch.name}",
"unit": ch.unit,
"peak": f"{ch.peak:.2f}",
"samples": str(ch.n_samples),
"rate": f"{ch.sample_rate:.0f}",
}
)
loc_tree[meas] = ch_list
obj_tree[loc] = loc_tree
test_tree[obj] = obj_tree
result[loaded.test_id] = test_tree
return result
def flat_channel_list(self) -> list[dict[str, str]]:
"""Flat list of all channels across all tests, for dropdown/checklist options."""
items: list[dict[str, str]] = []
multi = len(self._tests) > 1
for loaded in self.tests:
prefix = f"[{loaded.test_id}] " if multi else ""
for ch in sorted(loaded.data, key=lambda c: c.name):
label = ch.code.short_label if ch.code.is_valid else ch.name
items.append(
{
"label": f"{prefix}{label}",
"value": f"{loaded.test_id}::{ch.name}",
}
)
return items
# ----- Template & Session ----- # ----- Template & Session -----
@property @property
@@ -245,52 +118,37 @@ class AppState:
def template_names(self) -> list[str]: def template_names(self) -> list[str]:
return self._template_library.list() return self._template_library.list()
def apply_template( def apply_template(self, name: str) -> tuple[list[str], dict[str, str]]:
self,
name: str,
selected_keys: list[str] | None = None,
) -> tuple[list[str], dict[str, str]]:
"""Apply a template by name. """Apply a template by name.
Resolves the template's channel patterns against the primary test, Resolves channel patterns against the primary test.
sets the active template, and returns the resolved channel keys and Returns (selected_channel_keys, transform_settings).
transform settings.
Returns:
(selected_channel_keys, transform_settings)
""" """
template = self._template_library.get(name) template = self._template_library.get(name)
self._active_template = template self._active_template = template
# Persist to session if primary test has a path
primary = self.primary_test primary = self.primary_test
if primary and primary.data.path: if primary:
mgr = self._get_session_manager(primary.data.path) primary.apply_template(template)
mgr.apply_template(template)
# Resolve channel patterns from the template # Resolve channel patterns
resolved_keys: list[str] = [] resolved_keys: list[str] = []
if primary: if primary:
for plot_def in template.plots: for plot_def in template.plots:
for pattern in plot_def.channel_patterns: for pattern in plot_def.channel_patterns:
matches = primary.data.find(pattern) matches = primary.find(pattern)
for ch in matches: for ch in matches:
key = f"{primary.test_id}::{ch.name}" key = f"{primary.test_id}::{ch.name}"
if key not in resolved_keys: if key not in resolved_keys:
resolved_keys.append(key) resolved_keys.append(key)
# Transform settings
transforms: dict[str, str] = {} transforms: dict[str, str] = {}
if template.default_cfc is not None: if template.default_cfc is not None:
transforms["cfc"] = str(template.default_cfc) transforms["cfc"] = str(template.default_cfc)
else: else:
transforms["cfc"] = "none" transforms["cfc"] = "none"
logger.info( logger.info("Applied template '%s'%d channels", name, len(resolved_keys))
"Applied template '%s' — resolved %d channels",
name,
len(resolved_keys),
)
return resolved_keys, transforms return resolved_keys, transforms
def save_as_template( def save_as_template(
@@ -303,30 +161,19 @@ class AppState:
x2: float | None, x2: float | None,
protocol: str = "", protocol: str = "",
) -> TemplateSpec: ) -> TemplateSpec:
"""Capture the current UI state as a new template. """Capture current UI state as a new template."""
Converts selected channels back to patterns and stores the
current filter/cursor/protocol settings.
"""
# Convert selected keys to channel patterns
patterns: list[str] = [] patterns: list[str] = []
for key in selected_keys: for key in selected_keys:
if "::" in key: ch_name = key.split("::", 1)[-1] if "::" in key else key
_, ch_name = key.split("::", 1)
else:
ch_name = key
# Use the raw channel name as a pattern (exact match)
if ch_name not in patterns: if ch_name not in patterns:
patterns.append(ch_name) patterns.append(ch_name)
# Build plot definition
plot = PlotDefinition( plot = PlotDefinition(
title=name, title=name,
channel_patterns=patterns, channel_patterns=patterns,
x_cursors=(x1, x2) if x1 is not None and x2 is not None else None, x_cursors=(x1, x2) if x1 is not None and x2 is not None else None,
) )
# Build transforms list
transforms: list[dict[str, Any]] = [] transforms: list[dict[str, Any]] = []
if cfc_value and cfc_value != "none": if cfc_value and cfc_value != "none":
transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)}) transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)})
@@ -343,17 +190,18 @@ class AppState:
self._template_library.save(template) self._template_library.save(template)
self._active_template = template self._active_template = template
logger.info("Saved template '%s' with %d patterns", name, len(patterns))
logger.info("Saved template '%s' with %d channel patterns", name, len(patterns))
return template return template
def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None: def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None:
"""Auto-save current state to the session for the primary test.""" """Auto-save current state to the session for the primary test."""
primary = self.primary_test primary = self.primary_test
if not primary or not primary.data.path: if not primary or not primary.path:
return return
mgr = self._get_session_manager(primary.data.path) from impakt.template.session import SessionManager
mgr = SessionManager(primary.path)
mgr.state.overrides = { mgr.state.overrides = {
"selected_channels": selected_keys, "selected_channels": selected_keys,
"cfc": cfc_value, "cfc": cfc_value,
@@ -362,12 +210,14 @@ class AppState:
mgr.save() mgr.save()
def load_session_state(self) -> dict[str, Any] | None: def load_session_state(self) -> dict[str, Any] | None:
"""Load saved session state for the primary test, if any.""" """Load saved session state for the primary test."""
primary = self.primary_test primary = self.primary_test
if not primary or not primary.data.path: if not primary or not primary.path:
return None return None
mgr = self._get_session_manager(primary.data.path) from impakt.template.session import SessionManager
mgr = SessionManager(primary.path)
if not mgr.has_session: if not mgr.has_session:
return None return None
@@ -377,21 +227,48 @@ class AppState:
"template": mgr.state.template_name, "template": mgr.state.template_name,
} }
def _get_session_manager(self, path: Path) -> SessionManager: # ----- Channel tree / grid helpers -----
key = str(path)
if key not in self._session_managers: def build_channel_tree(self) -> dict[str, Any]:
self._session_managers[key] = SessionManager(path) """Build hierarchical channel tree across all loaded tests."""
return self._session_managers[key] result: dict[str, Any] = {}
for session in self.sessions:
tree = session.channel_tree()
test_tree: dict[str, Any] = {}
for obj, locations in sorted(tree.items()):
obj_tree: dict[str, Any] = {}
for loc, measurements in sorted(locations.items()):
loc_tree: dict[str, Any] = {}
for meas, channels in sorted(measurements.items()):
ch_list = []
for ch in channels:
label = ch.code.short_label if ch.code.is_valid else ch.name
ch_list.append(
{
"name": ch.name,
"label": label,
"key": f"{session.test_id}::{ch.name}",
"unit": ch.unit,
"peak": f"{ch.peak:.2f}",
"samples": str(ch.n_samples),
"rate": f"{ch.sample_rate:.0f}",
}
)
loc_tree[meas] = ch_list
obj_tree[loc] = loc_tree
test_tree[obj] = obj_tree
result[session.test_id] = test_tree
return result
@property @property
def is_empty(self) -> bool: def is_empty(self) -> bool:
return len(self._tests) == 0 return len(self._sessions) == 0
@property @property
def total_channels(self) -> int: def total_channels(self) -> int:
return sum(t.channel_count for t in self.tests) return sum(len(s) for s in self.sessions)
def __repr__(self) -> str: def __repr__(self) -> str:
test_info = ", ".join(f"{t.test_id}({t.channel_count}ch)" for t in self.tests) test_info = ", ".join(f"{s.test_id}({len(s)}ch)" for s in self.sessions)
tmpl = f", template={self._active_template.name}" if self._active_template else "" tmpl = f", template={self._active_template.name}" if self._active_template else ""
return f"AppState([{test_info}]{tmpl})" return f"AppState([{test_info}]{tmpl})"

View File

@@ -0,0 +1,93 @@
"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
import numpy as np
import pytest
from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
class TestChestDeflection:
def test_basic(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.criterion == "Chest Deflection"
assert result.unit == "mm"
assert 30.0 < result.value < 40.0 # ~35mm in fixture
def test_peak_time(self, chest_deflection_channel):
result = chest_deflection(channel=chest_deflection_channel)
assert result.time_of_peak is not None
assert 0.0 < result.time_of_peak < 0.1
class TestClip3ms:
def test_basic(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.criterion == "3ms Clip"
assert result.value > 0
def test_body_region(self, head_accel_x):
result = clip_3ms(head_accel_x)
assert result.body_region == "Chest"
class TestFemurLoad:
def test_single_channel(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.criterion == "Femur Load Left"
assert result.unit == "kN"
assert result.value > 0
def test_peak_time(self, femur_left_channel):
result = femur_load(channel=femur_left_channel, side="left")
assert result.time_of_peak is not None
class TestTibiaIndex:
def test_with_all_components(self, time_array, sample_rate):
from impakt.channel.code import ChannelCode
from impakt.channel.model import Channel
t = time_array
fz = np.zeros_like(t)
mx = np.zeros_like(t)
my = np.zeros_like(t)
mask = (t >= 0.02) & (t <= 0.08)
fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
fz_ch = Channel(
name="TIBFZ",
code=ChannelCode.parse("TIBFZ"),
data=fz,
time=t,
unit="N",
sample_rate=sample_rate,
)
mx_ch = Channel(
name="TIBMX",
code=ChannelCode.parse("TIBMX"),
data=mx,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
my_ch = Channel(
name="TIBMY",
code=ChannelCode.parse("TIBMY"),
data=my,
time=t,
unit="N·m",
sample_rate=sample_rate,
)
result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
assert result.value > 0
class TestViscousCriterion:
def test_basic(self, chest_deflection_channel):
result = viscous_criterion(channel=chest_deflection_channel)
assert result.criterion == "Viscous Criterion"
assert result.unit == "m/s"
assert result.value >= 0

View File

@@ -0,0 +1,81 @@
"""Tests for plot engine."""
import numpy as np
import pytest
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
class TestPlotEngine:
def test_render_empty(self):
engine = PlotEngine()
spec = PlotSpec()
fig = engine.render(spec)
assert fig is not None
def test_render_with_channels(self, head_accel_x, head_accel_y):
engine = PlotEngine()
spec = PlotSpec(
channels=[
ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
],
y_label="g",
)
fig = engine.render(spec)
assert len(fig.data) == 2
def test_render_compact_mode(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
compact=True,
)
fig = engine.render(spec)
assert fig.layout.hovermode is False
assert fig.layout.showlegend is False
def test_render_with_corridors(self, head_accel_x):
engine = PlotEngine()
corridor = Corridor(
name="Test Corridor",
time=np.linspace(0, 0.1, 100),
lower=np.full(100, -50.0),
upper=np.full(100, 50.0),
)
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
corridors=[corridor],
)
fig = engine.render(spec)
# 1 data trace + 2 corridor traces (upper + lower)
assert len(fig.data) == 3
def test_render_with_cursors(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
)
fig = engine.render(spec)
# Should have vertical lines (as shapes)
assert len(fig.layout.shapes) >= 2
def test_compact_cursor_annotations(self, head_accel_x):
engine = PlotEngine()
spec = PlotSpec(
channels=[ChannelRef(channel=head_accel_x)],
x_cursors=(0.02, 0.06),
compact=True,
)
fig = engine.render(spec)
# Should have X1/X2 annotations
annotations = [
a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
]
assert len(annotations) == 2
def test_default_colors(self):
assert len(DEFAULT_COLORS) == 10
assert all(c.startswith("#") for c in DEFAULT_COLORS)

View File

@@ -0,0 +1,68 @@
"""Tests for IIHS and US NCAP scoring."""
import pytest
from impakt.criteria.base import CriterionResult
from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
@pytest.fixture
def sample_criteria():
return {
"HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
"Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
"Chest Deflection": CriterionResult(
criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
),
"Femur Load Left": CriterionResult(
criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
),
"Femur Load Right": CriterionResult(
criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
),
}
class TestIIHS:
def test_good_results(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert result.protocol == "IIHS"
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_region_scores(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
assert len(result.region_scores) > 0
for rs in result.region_scores:
assert rs.rating is not None
def test_poor_hic(self):
result = iihs_evaluate(
{
"HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
}
)
assert result.overall_rating == "POOR"
def test_summary(self, sample_criteria):
result = iihs_evaluate(sample_criteria)
summary = result.summary()
assert "IIHS" in summary
class TestUSNCAP:
def test_basic(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert result.protocol == "US NCAP"
assert result.stars is not None
assert 1 <= result.stars <= 5
def test_injury_probabilities(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert "combined_injury_probability" in result.details
p = result.details["combined_injury_probability"]
assert 0.0 <= p <= 1.0
def test_region_scores(self, sample_criteria):
result = us_evaluate(sample_criteria)
assert len(result.region_scores) > 0

126
tests/test_scripting_api.py Normal file
View File

@@ -0,0 +1,126 @@
"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
from pathlib import Path
import pytest
from impakt import Session, Template
FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
MME_DATA = Path(__file__).parent / "mme_data"
class TestSession:
def test_open(self):
s = Session.open(FIXTURE_DATA)
assert s.test_id == "IMPAKT_SYNTH_001"
assert len(s) == 26
def test_channel_access(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
assert ch.name == "11HEAD0000ACXA"
assert ch.peak > 0
def test_find(self):
s = Session.open(FIXTURE_DATA)
channels = s.find("*HEAD*AC*")
assert len(channels) == 3
def test_group(self):
s = Session.open(FIXTURE_DATA)
group = s.group("HEAD0000AC")
assert group.x is not None
def test_compute_criteria(self):
s = Session.open(FIXTURE_DATA)
criteria = s.compute_criteria()
assert len(criteria) > 0
def test_evaluate(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.protocol == "Euro NCAP"
def test_evaluate_us_ncap(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("us_ncap")
assert result.stars is not None
def test_evaluate_iihs(self):
s = Session.open(FIXTURE_DATA)
result = s.evaluate("iihs")
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
def test_evaluate_invalid_protocol(self):
s = Session.open(FIXTURE_DATA)
with pytest.raises(ValueError, match="Unknown protocol"):
s.evaluate("invalid")
def test_contains(self):
s = Session.open(FIXTURE_DATA)
assert "11HEAD0000ACXA" in s
assert "NONEXISTENT" not in s
class TestChannelHandleChaining:
"""The fluent API must support chaining — each transform returns ChannelHandle."""
def test_single_transform(self):
s = Session.open(FIXTURE_DATA)
ch = s.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600)
assert type(filtered).__name__ == "ChannelHandle"
assert filtered.raw.cfc_class == 600
def test_double_chain(self):
s = Session.open(FIXTURE_DATA)
result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 2
def test_triple_chain(self):
s = Session.open(FIXTURE_DATA)
result = (
s.channel("11HEAD0000ACXA")
.transform.cfc(1000)
.transform.y_align()
.transform.trim(t_start=0.0, t_end=0.1)
)
assert type(result).__name__ == "ChannelHandle"
assert len(result.raw.transform_history) == 3
def test_chain_preserves_data(self):
s = Session.open(FIXTURE_DATA)
original = s.channel("11HEAD0000ACXA")
original_peak = original.peak
filtered = original.transform.cfc(600)
# Original should be unchanged — peak should be the same
assert original.peak == original_peak
# Filtered should have different CFC and lower peak (smoothed)
assert filtered.raw.cfc_class == 600
assert filtered.peak <= original_peak
@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
class TestSessionRealData:
def test_open_real(self):
s = Session.open(MME_DATA / "3239")
assert s.test_id == "3239"
assert len(s) == 133
def test_full_pipeline(self):
s = Session.open(MME_DATA / "3239")
# Chain: get channel -> filter -> check
ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
assert ch.peak > 100 # Significant head acceleration
# Compute criteria
criteria = s.compute_criteria()
assert "HIC15" in criteria
# Evaluate
result = s.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0

132
tests/test_template.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for template model and session persistence."""
from pathlib import Path
import pytest
from impakt.template.library import TemplateLibrary
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
from impakt.template.session import SessionManager
class TestTemplateSpec:
def test_yaml_round_trip(self):
template = TemplateSpec(
name="Test Template",
version=2,
description="A test template",
plots=[
PlotDefinition(
title="Head Acceleration",
channel_patterns=["*HEAD*AC*"],
x_cursors=(0.0, 0.1),
)
],
default_cfc=1000,
criteria=["hic15", "nij"],
protocol="euro_ncap",
)
yaml_str = template.to_yaml()
restored = TemplateSpec.from_yaml(yaml_str)
assert restored.name == "Test Template"
assert restored.version == 2
assert restored.default_cfc == 1000
assert len(restored.plots) == 1
assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
assert restored.criteria == ["hic15", "nij"]
def test_save_and_load(self, tmp_path):
template = TemplateSpec(name="Saved Test", version=1)
path = tmp_path / "test.yaml"
template.save(path)
assert path.exists()
loaded = TemplateSpec.load(path)
assert loaded.name == "Saved Test"
class TestSessionState:
def test_yaml_round_trip(self):
state = SessionState(
template_name="my_template",
template_version=3,
test_path="/data/test_001",
overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
)
yaml_str = state.to_yaml()
restored = SessionState.from_yaml(yaml_str)
assert restored.template_name == "my_template"
assert restored.template_version == 3
assert restored.overrides["cfc"] == "600"
def test_save_and_load(self, tmp_path):
state = SessionState(template_name="test")
path = tmp_path / "session.yaml"
state.save(path)
assert path.exists()
loaded = SessionState.load(path)
assert loaded.template_name == "test"
class TestTemplateLibrary:
def test_empty_library(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
assert lib.list() == []
assert len(lib) == 0
def test_save_and_list(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
template = TemplateSpec(name="My Template")
lib.save(template)
assert "my_template" in lib.list()
assert len(lib) == 1
def test_get(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="Getter Test", version=5))
loaded = lib.get("getter_test")
assert loaded.name == "Getter Test"
assert loaded.version == 5
def test_delete(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
lib.save(TemplateSpec(name="To Delete"))
assert lib.delete("to_delete")
assert "to_delete" not in lib.list()
def test_get_missing_raises(self, tmp_path):
lib = TemplateLibrary(tmp_path / "templates")
with pytest.raises(FileNotFoundError):
lib.get("nonexistent")
class TestSessionManager:
def test_create_and_save(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.state.template_name = "test_tmpl"
mgr.save()
assert mgr.has_session
assert (tmp_path / ".impakt" / "session.yaml").exists()
def test_load_existing(self, tmp_path):
# Save
mgr1 = SessionManager(tmp_path)
mgr1.state.template_name = "saved_tmpl"
mgr1.state.overrides = {"key": "value"}
mgr1.save()
# Load
mgr2 = SessionManager(tmp_path)
assert mgr2.state.template_name == "saved_tmpl"
assert mgr2.state.overrides["key"] == "value"
def test_clear(self, tmp_path):
mgr = SessionManager(tmp_path)
mgr.save()
assert mgr.has_session
mgr.clear()
assert not mgr.has_session

View File

@@ -0,0 +1,75 @@
"""Tests for math expressions and resultant computation."""
import numpy as np
import pytest
from impakt.transform.math_expr import math_expr
from impakt.transform.resultant import resultant_from_channels
from impakt.transform.resample import trim, resample
class TestMathExpr:
def test_simple_expression(self, head_accel_x, head_accel_z):
result = math_expr(
expression="sqrt(a**2 + b**2)",
channels={"a": head_accel_x, "b": head_accel_z},
name="resultant_xz",
unit="g",
)
assert result.name == "resultant_xz"
assert result.unit == "g"
assert result.peak > 0
assert len(result.data) == len(head_accel_x.data)
def test_constant_expression(self, head_accel_x):
result = math_expr(
expression="a * 0 + 42.0",
channels={"a": head_accel_x},
name="constant",
)
assert np.allclose(result.data, 42.0)
def test_invalid_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Error evaluating"):
math_expr(
expression="invalid_func(a)",
channels={"a": head_accel_x},
)
def test_forbidden_expression(self, head_accel_x):
with pytest.raises(ValueError, match="Forbidden"):
math_expr(
expression="__import__('os')",
channels={"a": head_accel_x},
)
class TestResultant:
def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
assert result.code.direction == "R"
# Resultant >= any component
assert result.peak >= head_accel_x.peak
assert result.peak >= head_accel_y.peak
def test_from_two_channels(self, head_accel_x, head_accel_z):
result = resultant_from_channels(head_accel_x, head_accel_z)
assert result.peak > 0
def test_single_channel_raises(self):
with pytest.raises(ValueError, match="At least one"):
resultant_from_channels()
class TestTrimResample:
def test_trim(self, head_accel_x):
trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
assert trimmed.time[0] >= 0.0
assert trimmed.time[-1] <= 0.05
assert len(trimmed.data) < len(head_accel_x.data)
def test_resample(self, head_accel_x):
resampled = resample(head_accel_x, target_rate=5000.0)
expected_samples = int(head_accel_x.duration * 5000.0)
assert abs(len(resampled.data) - expected_samples) <= 2
assert resampled.sample_rate == 5000.0

View File

@@ -19,11 +19,11 @@ class TestAppState:
def test_load_test(self): def test_load_test(self):
state = AppState() state = AppState()
loaded = state.load_test(FIXTURE_DATA) session = state.load_test(FIXTURE_DATA)
assert not state.is_empty assert not state.is_empty
assert loaded.test_id == "IMPAKT_SYNTH_001" assert session.test_id == "IMPAKT_SYNTH_001"
assert loaded.channel_count == 26 assert len(session) == 26
assert state.primary_test is loaded assert state.primary_test is session
def test_load_multiple_tests(self): def test_load_multiple_tests(self):
state = AppState() state = AppState()
@@ -32,13 +32,13 @@ class TestAppState:
if (MME_DATA / "VW1FGS15").exists(): if (MME_DATA / "VW1FGS15").exists():
t2 = state.load_test(MME_DATA / "VW1FGS15") t2 = state.load_test(MME_DATA / "VW1FGS15")
assert len(state.tests) == 2 assert len(state.tests) == 2
assert state.primary_test is t1 # First loaded is primary assert state.primary_test is t1
assert state.total_channels == 26 + 10 assert state.total_channels == 26 + 10
def test_remove_test(self): def test_remove_test(self):
state = AppState() state = AppState()
loaded = state.load_test(FIXTURE_DATA) session = state.load_test(FIXTURE_DATA)
state.remove_test(loaded.test_id) state.remove_test(session.test_id)
assert state.is_empty assert state.is_empty
def test_get_channel(self): def test_get_channel(self):
@@ -54,38 +54,38 @@ class TestAppState:
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT") ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
assert ch is None assert ch is None
def test_resolve_channel_with_key(self): def test_get_channel_via_session(self):
"""Channels can be accessed through the Session scripting API."""
state = AppState() state = AppState()
state.load_test(FIXTURE_DATA) state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA") session = state.primary_test
assert ch is not None assert session is not None
ch_handle = session.channel("11HEAD0000ACXA")
assert ch_handle.name == "11HEAD0000ACXA"
def test_resolve_channel_primary_default(self): def test_session_fluent_transforms(self):
"""Fluent transform chaining works through the Session API."""
state = AppState() state = AppState()
state.load_test(FIXTURE_DATA) state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA") session = state.primary_test
assert ch is not None ch = session.channel("11HEAD0000ACXA")
filtered = ch.transform.cfc(600).transform.y_align()
assert filtered.raw.cfc_class == 600
assert len(filtered.raw.transform_history) == 2
def test_resolve_channel_with_cfc(self): def test_session_compute_criteria(self):
"""Session.compute_criteria() auto-detects channels."""
state = AppState() state = AppState()
state.load_test(FIXTURE_DATA) state.load_test(FIXTURE_DATA)
ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600) criteria = state.primary_test.compute_criteria()
assert ch is not None assert len(criteria) > 0
assert ch.cfc_class == 600 assert "HIC15" in criteria or "Chest Deflection" in criteria
def test_flat_channel_list(self):
state = AppState()
state.load_test(FIXTURE_DATA)
items = state.flat_channel_list()
assert len(items) == 26
assert all("value" in item and "label" in item for item in items)
def test_build_channel_tree(self): def test_build_channel_tree(self):
state = AppState() state = AppState()
state.load_test(FIXTURE_DATA) state.load_test(FIXTURE_DATA)
tree = state.build_channel_tree() tree = state.build_channel_tree()
assert "IMPAKT_SYNTH_001" in tree assert "IMPAKT_SYNTH_001" in tree
# Should have hierarchical structure
test_tree = tree["IMPAKT_SYNTH_001"] test_tree = tree["IMPAKT_SYNTH_001"]
assert len(test_tree) > 0 assert len(test_tree) > 0
@@ -94,15 +94,23 @@ class TestAppState:
class TestAppStateRealData: class TestAppStateRealData:
def test_load_real_mme(self): def test_load_real_mme(self):
state = AppState() state = AppState()
loaded = state.load_test(MME_DATA / "3239") session = state.load_test(MME_DATA / "3239")
assert loaded.test_id == "3239" assert session.test_id == "3239"
assert loaded.channel_count == 133 assert len(session) == 133
def test_channel_tree_real_data(self): def test_channel_tree_real_data(self):
state = AppState() state = AppState()
state.load_test(MME_DATA / "3239") state.load_test(MME_DATA / "3239")
tree = state.build_channel_tree() tree = state.build_channel_tree()
assert "3239" in tree assert "3239" in tree
# Should contain "Driver" in some key
test_tree = tree["3239"] test_tree = tree["3239"]
assert any("Driver" in k for k in test_tree) assert any("Driver" in k for k in test_tree)
def test_session_evaluate_real_data(self):
"""Full pipeline through Session API on real data."""
state = AppState()
state.load_test(MME_DATA / "3239")
result = state.primary_test.evaluate("euro_ncap")
assert result.stars is not None
assert result.stars >= 0
assert len(result.region_scores) > 0