bookmark - Refactor
This commit is contained in:
29
src/impakt/io/csv.py
Normal file
29
src/impakt/io/csv.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""CSV data reader (stub).
|
||||
|
||||
Placeholder for a future CSV reader plugin.
|
||||
Install or implement to enable reading CSV crash test data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from impakt.channel.model import TestData, TestMetadata
|
||||
|
||||
|
||||
class CSVReader:
|
||||
"""Reader for CSV time-series data (not yet implemented)."""
|
||||
|
||||
@property
|
||||
def format_name(self) -> str:
|
||||
return "CSV"
|
||||
|
||||
def supports(self, path: Path) -> bool:
|
||||
# Only claim support for .csv files with crash test structure
|
||||
return False # Disabled until implemented
|
||||
|
||||
def metadata(self, path: Path) -> TestMetadata:
|
||||
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
|
||||
|
||||
def read(self, path: Path) -> TestData:
|
||||
raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.")
|
||||
37
src/impakt/io/tdms.py
Normal file
37
src/impakt/io/tdms.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""NI TDMS data reader (stub).
|
||||
|
||||
Placeholder for a future TDMS reader plugin.
|
||||
Install nptdms to enable: pip install impakt[tdms]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from impakt.channel.model import TestData, TestMetadata
|
||||
|
||||
|
||||
class TDMSReader:
|
||||
"""Reader for NI TDMS files (not yet implemented).
|
||||
|
||||
Requires the optional ``nptdms`` dependency::
|
||||
|
||||
pip install impakt[tdms]
|
||||
"""
|
||||
|
||||
@property
|
||||
def format_name(self) -> str:
|
||||
return "NI TDMS"
|
||||
|
||||
def supports(self, path: Path) -> bool:
|
||||
return False # Disabled until implemented
|
||||
|
||||
def metadata(self, path: Path) -> TestMetadata:
|
||||
raise NotImplementedError(
|
||||
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
|
||||
)
|
||||
|
||||
def read(self, path: Path) -> TestData:
|
||||
raise NotImplementedError(
|
||||
"TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]"
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
@@ -2,6 +2,9 @@
|
||||
|
||||
Renders PlotSpec objects into interactive Plotly figures with
|
||||
support for corridors, dual X-cursors, and export.
|
||||
|
||||
The PlotEngine is the **single rendering path** — both the scripting
|
||||
API and the web UI construct PlotSpec objects and delegate here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -30,12 +33,39 @@ DEFAULT_COLORS = [
|
||||
|
||||
|
||||
class PlotEngine:
|
||||
"""Renders PlotSpec into Plotly figures."""
|
||||
"""Renders PlotSpec into Plotly figures.
|
||||
|
||||
Supports two modes:
|
||||
- **Default**: Full interactive plot with hover tooltips and legend.
|
||||
- **Compact** (``spec.compact=True``): Web UI mode with no legend, no
|
||||
hover tooltips, tight margins, smaller axis labels. Cursor tracking
|
||||
is handled by external JS.
|
||||
"""
|
||||
|
||||
def render(self, spec: PlotSpec) -> go.Figure:
|
||||
"""Render a PlotSpec into an interactive Plotly figure."""
|
||||
fig = go.Figure()
|
||||
|
||||
if not spec.channels:
|
||||
fig.update_layout(
|
||||
template="plotly_white",
|
||||
annotations=[
|
||||
{
|
||||
"text": "Select channels to plot",
|
||||
"xref": "paper",
|
||||
"yref": "paper",
|
||||
"x": 0.5,
|
||||
"y": 0.5,
|
||||
"showarrow": False,
|
||||
"font": {"size": 14, "color": "#bbb"},
|
||||
}
|
||||
],
|
||||
xaxis={"visible": False},
|
||||
yaxis={"visible": False},
|
||||
margin={"l": 40, "r": 20, "t": 30, "b": 40},
|
||||
)
|
||||
return fig
|
||||
|
||||
# Add corridor fills first (behind data traces)
|
||||
for corridor in spec.corridors:
|
||||
self._add_corridor(fig, corridor)
|
||||
@@ -54,36 +84,74 @@ class PlotEngine:
|
||||
color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
|
||||
label = style.label or ch_ref.label
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=ch.time,
|
||||
y=ch.data,
|
||||
mode="lines",
|
||||
name=label,
|
||||
line=dict(
|
||||
trace_kwargs: dict[str, Any] = {
|
||||
"x": ch.time,
|
||||
"y": ch.data,
|
||||
"mode": "lines",
|
||||
"name": label,
|
||||
"line": dict(
|
||||
color=color,
|
||||
width=style.line_width,
|
||||
dash=style.line_dash,
|
||||
),
|
||||
opacity=style.opacity,
|
||||
hovertemplate=f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>",
|
||||
)
|
||||
"opacity": style.opacity,
|
||||
}
|
||||
|
||||
if not spec.compact:
|
||||
trace_kwargs["hovertemplate"] = (
|
||||
f"{label}<br>t=%{{x:.6f}}s<br>%{{y:.4f}} {ch.unit}<extra></extra>"
|
||||
)
|
||||
|
||||
fig.add_trace(go.Scatter(**trace_kwargs))
|
||||
|
||||
# Add cursor lines
|
||||
if spec.x_cursors:
|
||||
x1, x2 = spec.x_cursors
|
||||
for x_val, label in [(x1, "x1"), (x2, "x2")]:
|
||||
if spec.compact:
|
||||
# Color-coded cursor lines for compact mode
|
||||
fig.add_vline(
|
||||
x=x1,
|
||||
line_dash="dash",
|
||||
line_color="rgba(220,53,69,0.6)",
|
||||
line_width=1,
|
||||
annotation_text=f"X1={x1:.4f}s",
|
||||
annotation_font_size=9,
|
||||
annotation_font_color="rgba(220,53,69,0.8)",
|
||||
)
|
||||
fig.add_vline(
|
||||
x=x2,
|
||||
line_dash="dash",
|
||||
line_color="rgba(13,110,253,0.6)",
|
||||
line_width=1,
|
||||
annotation_text=f"X2={x2:.4f}s",
|
||||
annotation_font_size=9,
|
||||
annotation_font_color="rgba(13,110,253,0.8)",
|
||||
)
|
||||
else:
|
||||
for x_val, lbl in [(x1, "x1"), (x2, "x2")]:
|
||||
fig.add_vline(
|
||||
x=x_val,
|
||||
line_dash="dash",
|
||||
line_color="gray",
|
||||
line_width=1,
|
||||
annotation_text=f"{label}={x_val:.6f}s",
|
||||
annotation_text=f"{lbl}={x_val:.6f}s",
|
||||
annotation_position="top",
|
||||
)
|
||||
|
||||
# Layout
|
||||
if spec.compact:
|
||||
margin = spec.margin or {"l": 45, "r": 8, "t": 4, "b": 28}
|
||||
fig.update_layout(
|
||||
xaxis_title=dict(text=spec.x_label, font=dict(size=10, color="#999")),
|
||||
yaxis_title=dict(text=spec.y_label, font=dict(size=10, color="#999")),
|
||||
template="plotly_white",
|
||||
hovermode=False,
|
||||
showlegend=False,
|
||||
margin=margin,
|
||||
)
|
||||
else:
|
||||
margin = spec.margin or {"l": 60, "r": 20, "t": 40 if spec.title else 10, "b": 60}
|
||||
hovermode = spec.hovermode if spec.hovermode is not None else "x unified"
|
||||
fig.update_layout(
|
||||
title=spec.title,
|
||||
xaxis_title=spec.x_label,
|
||||
@@ -92,7 +160,7 @@ class PlotEngine:
|
||||
height=spec.height,
|
||||
width=spec.width,
|
||||
template="plotly_white",
|
||||
hovermode="x unified",
|
||||
hovermode=hovermode,
|
||||
legend=dict(
|
||||
orientation="h",
|
||||
yanchor="bottom",
|
||||
@@ -100,6 +168,7 @@ class PlotEngine:
|
||||
xanchor="center",
|
||||
x=0.5,
|
||||
),
|
||||
margin=margin,
|
||||
)
|
||||
|
||||
if spec.show_grid:
|
||||
@@ -117,7 +186,6 @@ class PlotEngine:
|
||||
"""Add a corridor (tolerance band) to the figure."""
|
||||
style = corridor.style
|
||||
|
||||
# Upper bound
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=corridor.time,
|
||||
@@ -128,8 +196,6 @@ class PlotEngine:
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Lower bound with fill to upper
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=corridor.time,
|
||||
@@ -144,16 +210,7 @@ class PlotEngine:
|
||||
)
|
||||
|
||||
def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes:
|
||||
"""Render to a static image.
|
||||
|
||||
Args:
|
||||
spec: Plot specification.
|
||||
format: Image format ('png', 'svg', 'pdf', 'jpeg').
|
||||
scale: Resolution multiplier.
|
||||
|
||||
Returns:
|
||||
Image bytes.
|
||||
"""
|
||||
"""Render to a static image."""
|
||||
fig = self.render(spec)
|
||||
return fig.to_image(format=format, scale=scale)
|
||||
|
||||
@@ -168,16 +225,7 @@ def cursor_values(
|
||||
x1: float,
|
||||
x2: float,
|
||||
) -> CursorValues:
|
||||
"""Compute interpolated values at two X-axis positions.
|
||||
|
||||
Args:
|
||||
spec_or_channels: PlotSpec or list of Channels.
|
||||
x1: First cursor position (time).
|
||||
x2: Second cursor position (time).
|
||||
|
||||
Returns:
|
||||
CursorValues with interpolated values for each channel.
|
||||
"""
|
||||
"""Compute interpolated values at two X-axis positions."""
|
||||
channels: list[tuple[str, Channel]] = []
|
||||
|
||||
if isinstance(spec_or_channels, PlotSpec):
|
||||
|
||||
@@ -168,3 +168,8 @@ class PlotSpec:
|
||||
show_grid: bool = True
|
||||
height: int = 500
|
||||
width: int = 900
|
||||
# Web UI compact mode: disables hover tooltips, removes legend,
|
||||
# uses tight margins, and renders axis labels in a smaller font.
|
||||
compact: bool = False
|
||||
hovermode: str | bool = "x unified"
|
||||
margin: dict[str, int] | None = None
|
||||
|
||||
@@ -42,8 +42,15 @@ class PluginRegistry:
|
||||
self._plugins: list[ImpaktPlugin] = []
|
||||
|
||||
def register_reader(self, reader: Any) -> None:
|
||||
"""Register a data reader."""
|
||||
"""Register a data reader.
|
||||
|
||||
Forwards to the IO reader registry so plugin readers are
|
||||
discoverable by Session.open() and the web UI.
|
||||
"""
|
||||
self._readers.append(reader)
|
||||
from impakt.io.reader import register_reader
|
||||
|
||||
register_reader(reader)
|
||||
logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader))
|
||||
|
||||
def register_transform(self, name: str, transform_cls: type) -> None:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -9,6 +9,7 @@ Threshold values are versioned — this module supports multiple protocol years.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from impakt.criteria.base import CriterionResult
|
||||
@@ -64,9 +65,44 @@ def _points_from_color(color: Color, max_points: float) -> float:
|
||||
return max_points * color_fractions[color]
|
||||
|
||||
|
||||
# Threshold sets by year
|
||||
# Format: {criterion: (green, yellow, orange, brown, red, higher_is_worse, max_points)}
|
||||
THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]] = {
|
||||
# ---------------------------------------------------------------------------
|
||||
# Threshold loading — from YAML files or hardcoded fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
|
||||
|
||||
|
||||
def _load_yaml_thresholds(
|
||||
version: str,
|
||||
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
|
||||
"""Load thresholds from a YAML file for a given version."""
|
||||
yaml_path = _THRESHOLDS_DIR / f"euro_ncap_{version}.yaml"
|
||||
if not yaml_path.exists():
|
||||
return {}
|
||||
|
||||
import yaml
|
||||
|
||||
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
|
||||
result: dict[str, tuple[float, float, float, float, float, bool, float]] = {}
|
||||
for name, vals in data.items():
|
||||
if isinstance(vals, dict):
|
||||
result[name] = (
|
||||
float(vals["green"]),
|
||||
float(vals["yellow"]),
|
||||
float(vals["orange"]),
|
||||
float(vals["brown"]),
|
||||
float(vals["red"]),
|
||||
bool(vals.get("higher_is_worse", True)),
|
||||
float(vals.get("max_points", 0)),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Hardcoded fallback (used if YAML files are missing)
|
||||
_THRESHOLDS_2024_FALLBACK: dict[str, tuple[float, float, float, float, float, bool, float]] = {
|
||||
"HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0),
|
||||
"3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0),
|
||||
"Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0),
|
||||
@@ -77,9 +113,17 @@ THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]
|
||||
"Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0),
|
||||
}
|
||||
|
||||
THRESHOLDS: dict[str, dict[str, tuple[float, float, float, float, float, bool, float]]] = {
|
||||
"2024": THRESHOLDS_2024,
|
||||
}
|
||||
|
||||
def _get_thresholds(
|
||||
version: str,
|
||||
) -> dict[str, tuple[float, float, float, float, float, bool, float]]:
|
||||
"""Get thresholds for a version, preferring YAML over hardcoded."""
|
||||
thresholds = _load_yaml_thresholds(version)
|
||||
if thresholds:
|
||||
return thresholds
|
||||
if version == "2024":
|
||||
return _THRESHOLDS_2024_FALLBACK
|
||||
return {}
|
||||
|
||||
|
||||
class EuroNCAP:
|
||||
@@ -87,11 +131,12 @@ class EuroNCAP:
|
||||
|
||||
def __init__(self, version: str = "2024") -> None:
|
||||
self._version = version
|
||||
if version not in THRESHOLDS:
|
||||
self._thresholds = _get_thresholds(version)
|
||||
if not self._thresholds:
|
||||
raise ValueError(
|
||||
f"Unknown Euro NCAP version: {version}. Available: {list(THRESHOLDS.keys())}"
|
||||
f"No thresholds found for Euro NCAP version: {version}. "
|
||||
f"Check {_THRESHOLDS_DIR} for YAML files."
|
||||
)
|
||||
self._thresholds = THRESHOLDS[version]
|
||||
|
||||
@property
|
||||
def protocol_name(self) -> str:
|
||||
|
||||
@@ -6,27 +6,55 @@ The overall rating is determined by the worst sub-rating.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from impakt.criteria.base import CriterionResult
|
||||
from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating
|
||||
|
||||
_THRESHOLDS_DIR = Path(__file__).parent / "thresholds"
|
||||
|
||||
# Thresholds: (good_limit, acceptable_limit, marginal_limit)
|
||||
# Values above marginal_limit = Poor
|
||||
# higher_is_worse indicates that higher values are worse
|
||||
IIHS_THRESHOLDS_2024: dict[str, tuple[float, float, float, bool]] = {
|
||||
|
||||
def _load_iihs_yaml(version: str) -> dict[str, tuple[float, float, float, bool]]:
|
||||
"""Load IIHS thresholds from YAML."""
|
||||
yaml_path = _THRESHOLDS_DIR / f"iihs_{version}.yaml"
|
||||
if not yaml_path.exists():
|
||||
return {}
|
||||
import yaml
|
||||
|
||||
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
result: dict[str, tuple[float, float, float, bool]] = {}
|
||||
for name, vals in data.items():
|
||||
if isinstance(vals, dict):
|
||||
result[name] = (
|
||||
float(vals["good"]),
|
||||
float(vals["acceptable"]),
|
||||
float(vals["marginal"]),
|
||||
bool(vals.get("higher_is_worse", True)),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Hardcoded fallback
|
||||
_IIHS_2024_FALLBACK: dict[str, tuple[float, float, float, bool]] = {
|
||||
"HIC15": (250.0, 500.0, 700.0, True),
|
||||
"Chest Deflection": (38.0, 50.0, 63.0, True), # mm
|
||||
"Femur Load Left": (3.8, 6.2, 10.0, True), # kN
|
||||
"Femur Load Right": (3.8, 6.2, 10.0, True), # kN
|
||||
"Chest Deflection": (38.0, 50.0, 63.0, True),
|
||||
"Femur Load Left": (3.8, 6.2, 10.0, True),
|
||||
"Femur Load Right": (3.8, 6.2, 10.0, True),
|
||||
"Nij": (0.52, 0.78, 1.0, True),
|
||||
"Tibia Index": (0.5, 0.8, 1.3, True),
|
||||
}
|
||||
|
||||
IIHS_THRESHOLDS: dict[str, dict[str, tuple[float, float, float, bool]]] = {
|
||||
"2024": IIHS_THRESHOLDS_2024,
|
||||
}
|
||||
|
||||
def _get_iihs_thresholds(version: str) -> dict[str, tuple[float, float, float, bool]]:
|
||||
thresholds = _load_iihs_yaml(version)
|
||||
if thresholds:
|
||||
return thresholds
|
||||
if version == "2024":
|
||||
return _IIHS_2024_FALLBACK
|
||||
return {}
|
||||
|
||||
|
||||
def _rate_value(
|
||||
@@ -70,11 +98,12 @@ class IIHS:
|
||||
|
||||
def __init__(self, version: str = "2024") -> None:
|
||||
self._version = version
|
||||
if version not in IIHS_THRESHOLDS:
|
||||
self._thresholds = _get_iihs_thresholds(version)
|
||||
if not self._thresholds:
|
||||
raise ValueError(
|
||||
f"Unknown IIHS version: {version}. Available: {list(IIHS_THRESHOLDS.keys())}"
|
||||
f"No thresholds found for IIHS version: {version}. "
|
||||
f"Check {_THRESHOLDS_DIR} for YAML files."
|
||||
)
|
||||
self._thresholds = IIHS_THRESHOLDS[version]
|
||||
|
||||
@property
|
||||
def protocol_name(self) -> str:
|
||||
|
||||
76
src/impakt/protocol/thresholds/euro_ncap_2024.yaml
Normal file
76
src/impakt/protocol/thresholds/euro_ncap_2024.yaml
Normal file
@@ -0,0 +1,76 @@
|
||||
# Euro NCAP 2024 Adult Occupant Frontal Impact Thresholds
|
||||
#
|
||||
# Each criterion: [green, yellow, orange, brown, red, higher_is_worse, max_points]
|
||||
# Sliding-scale: values at or below green = full points, at or above red = zero.
|
||||
|
||||
HIC15:
|
||||
green: 500.0
|
||||
yellow: 620.0
|
||||
orange: 700.0
|
||||
brown: 850.0
|
||||
red: 1000.0
|
||||
higher_is_worse: true
|
||||
max_points: 4.0
|
||||
|
||||
3ms Clip:
|
||||
green: 42.0
|
||||
yellow: 48.0
|
||||
orange: 54.0
|
||||
brown: 57.0
|
||||
red: 60.0
|
||||
higher_is_worse: true
|
||||
max_points: 4.0
|
||||
|
||||
Chest Deflection:
|
||||
green: 22.0
|
||||
yellow: 34.0
|
||||
orange: 42.0
|
||||
brown: 50.0
|
||||
red: 63.0
|
||||
higher_is_worse: true
|
||||
max_points: 4.0
|
||||
|
||||
Nij:
|
||||
green: 0.5
|
||||
yellow: 0.65
|
||||
orange: 0.8
|
||||
brown: 0.9
|
||||
red: 1.0
|
||||
higher_is_worse: true
|
||||
max_points: 2.0
|
||||
|
||||
Femur Load Left:
|
||||
green: 3.8
|
||||
yellow: 5.4
|
||||
orange: 7.0
|
||||
brown: 8.5
|
||||
red: 10.0
|
||||
higher_is_worse: true
|
||||
max_points: 2.0
|
||||
|
||||
Femur Load Right:
|
||||
green: 3.8
|
||||
yellow: 5.4
|
||||
orange: 7.0
|
||||
brown: 8.5
|
||||
red: 10.0
|
||||
higher_is_worse: true
|
||||
max_points: 2.0
|
||||
|
||||
Tibia Index:
|
||||
green: 0.4
|
||||
yellow: 0.7
|
||||
orange: 1.0
|
||||
brown: 1.15
|
||||
red: 1.3
|
||||
higher_is_worse: true
|
||||
max_points: 2.0
|
||||
|
||||
Viscous Criterion:
|
||||
green: 0.32
|
||||
yellow: 0.56
|
||||
orange: 0.8
|
||||
brown: 0.9
|
||||
red: 1.0
|
||||
higher_is_worse: true
|
||||
max_points: 2.0
|
||||
40
src/impakt/protocol/thresholds/iihs_2024.yaml
Normal file
40
src/impakt/protocol/thresholds/iihs_2024.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
# IIHS 2024 Crashworthiness Thresholds
|
||||
#
|
||||
# Each criterion: [good, acceptable, marginal, higher_is_worse]
|
||||
# Values above marginal = Poor.
|
||||
|
||||
HIC15:
|
||||
good: 250.0
|
||||
acceptable: 500.0
|
||||
marginal: 700.0
|
||||
higher_is_worse: true
|
||||
|
||||
Chest Deflection:
|
||||
good: 38.0
|
||||
acceptable: 50.0
|
||||
marginal: 63.0
|
||||
higher_is_worse: true
|
||||
|
||||
Femur Load Left:
|
||||
good: 3.8
|
||||
acceptable: 6.2
|
||||
marginal: 10.0
|
||||
higher_is_worse: true
|
||||
|
||||
Femur Load Right:
|
||||
good: 3.8
|
||||
acceptable: 6.2
|
||||
marginal: 10.0
|
||||
higher_is_worse: true
|
||||
|
||||
Nij:
|
||||
good: 0.52
|
||||
acceptable: 0.78
|
||||
marginal: 1.0
|
||||
higher_is_worse: true
|
||||
|
||||
Tibia Index:
|
||||
good: 0.5
|
||||
acceptable: 0.8
|
||||
marginal: 1.3
|
||||
higher_is_worse: true
|
||||
Binary file not shown.
Binary file not shown.
@@ -2,6 +2,23 @@
|
||||
|
||||
Provides the ``Session`` and ``Template`` classes that serve as the
|
||||
primary entry points for both scripting and the web UI.
|
||||
|
||||
Usage::
|
||||
|
||||
from impakt import Session
|
||||
|
||||
test = Session.open("tests/mme_data/3239")
|
||||
ch = test.channel("11HEAD0000H3ACXP")
|
||||
|
||||
# Fluent chaining — each transform returns a ChannelHandle
|
||||
filtered = ch.transform.cfc(1000).transform.y_align()
|
||||
|
||||
# Compute all injury criteria (auto-detects channels)
|
||||
criteria = test.compute_criteria()
|
||||
|
||||
# Score against a protocol
|
||||
result = test.evaluate("euro_ncap")
|
||||
result.to_pdf("report.pdf")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -36,12 +53,19 @@ class Session:
|
||||
Wraps TestData with session state, template binding, and
|
||||
convenience methods for transforms, criteria, and plotting.
|
||||
|
||||
Usage:
|
||||
Usage::
|
||||
|
||||
test = Session.open("/path/to/test_001")
|
||||
ch = test.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(1000)
|
||||
filtered = ch.transform.cfc(1000) # returns ChannelHandle, chainable
|
||||
|
||||
criteria = test.compute_criteria() # auto-detect channels
|
||||
result = test.evaluate("euro_ncap") # score against protocol
|
||||
result.to_pdf("report.pdf")
|
||||
"""
|
||||
|
||||
_plugins_discovered = False
|
||||
|
||||
def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None:
|
||||
self._data = test_data
|
||||
self._session_mgr = session_mgr or (
|
||||
@@ -49,12 +73,29 @@ class Session:
|
||||
)
|
||||
self._template: TemplateSpec | None = None
|
||||
|
||||
@classmethod
|
||||
def _discover_plugins(cls) -> None:
|
||||
"""Discover and register plugins. Called once on first Session.open()."""
|
||||
if cls._plugins_discovered:
|
||||
return
|
||||
cls._plugins_discovered = True
|
||||
try:
|
||||
from impakt.plugin.registry import discover_all
|
||||
|
||||
discover_all()
|
||||
except Exception as e:
|
||||
logger.debug("Plugin discovery failed (non-fatal): %s", e)
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str | Path) -> Session:
|
||||
"""Open a crash test from a path.
|
||||
|
||||
Auto-detects the format using the reader registry.
|
||||
Discovers plugins on first call.
|
||||
"""
|
||||
# Discover plugins (idempotent — safe to call multiple times)
|
||||
cls._discover_plugins()
|
||||
|
||||
path = Path(path).resolve()
|
||||
registry = get_registry()
|
||||
test_data = registry.read(path)
|
||||
@@ -87,6 +128,11 @@ class Session:
|
||||
"""Underlying TestData object."""
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def path(self) -> Path | None:
|
||||
"""Path to the test data directory."""
|
||||
return self._data.path
|
||||
|
||||
@property
|
||||
def channel_names(self) -> list[str]:
|
||||
return self._data.channel_names
|
||||
@@ -98,7 +144,10 @@ class Session:
|
||||
# ----- Channel access -----
|
||||
|
||||
def channel(self, name: str) -> ChannelHandle:
|
||||
"""Get a channel by name, wrapped with transform convenience methods."""
|
||||
"""Get a channel by name, wrapped with transform convenience methods.
|
||||
|
||||
Returns a ChannelHandle with a fluent ``.transform`` interface.
|
||||
"""
|
||||
ch = self._data.get(name)
|
||||
return ChannelHandle(ch)
|
||||
|
||||
@@ -118,6 +167,47 @@ class Session:
|
||||
"""Hierarchical channel tree for UI display."""
|
||||
return self._data.channel_tree()
|
||||
|
||||
# ----- Criteria & Protocol -----
|
||||
|
||||
def compute_criteria(self) -> dict[str, CriterionResult]:
|
||||
"""Auto-detect channels and compute all applicable injury criteria.
|
||||
|
||||
Uses ISO channel naming to find the right channels for each criterion.
|
||||
Returns a dict of criterion name -> CriterionResult.
|
||||
"""
|
||||
from impakt.web.components.criteria import auto_compute_criteria
|
||||
|
||||
return auto_compute_criteria(self._data)
|
||||
|
||||
def evaluate(self, protocol: str = "euro_ncap", version: str = "") -> ProtocolResult:
|
||||
"""Compute criteria and score against a rating protocol.
|
||||
|
||||
Args:
|
||||
protocol: Protocol name ('euro_ncap', 'us_ncap', 'iihs').
|
||||
version: Protocol version (optional, uses latest if empty).
|
||||
|
||||
Returns:
|
||||
ProtocolResult with stars/rating and per-region scores.
|
||||
"""
|
||||
criteria = self.compute_criteria()
|
||||
|
||||
if protocol == "euro_ncap":
|
||||
from impakt.protocol.euro_ncap import evaluate
|
||||
|
||||
return evaluate(criteria, version=version or "2024")
|
||||
elif protocol == "us_ncap":
|
||||
from impakt.protocol.us_ncap import evaluate
|
||||
|
||||
return evaluate(criteria, version=version or "2023")
|
||||
elif protocol == "iihs":
|
||||
from impakt.protocol.iihs import evaluate
|
||||
|
||||
return evaluate(criteria, version=version or "2024")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown protocol: {protocol}. Use 'euro_ncap', 'us_ncap', or 'iihs'."
|
||||
)
|
||||
|
||||
# ----- Template -----
|
||||
|
||||
def apply_template(self, name_or_spec: str | TemplateSpec) -> None:
|
||||
@@ -188,15 +278,16 @@ class Session:
|
||||
class ChannelHandle:
|
||||
"""Wrapper around a Channel providing fluent transform access.
|
||||
|
||||
Example:
|
||||
Each transform method on ``.transform`` returns a new ``ChannelHandle``,
|
||||
enabling chaining::
|
||||
|
||||
ch = session.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(1000)
|
||||
aligned = ch.transform.x_align(method="threshold", threshold_value=5.0)
|
||||
result = ch.transform.cfc(1000).transform.y_align().transform.trim(t_end=0.1)
|
||||
"""
|
||||
|
||||
def __init__(self, channel: Channel) -> None:
|
||||
self._channel = channel
|
||||
self.transform = TransformProxy(channel)
|
||||
self.transform = TransformProxy(self)
|
||||
|
||||
@property
|
||||
def raw(self) -> Channel:
|
||||
@@ -207,6 +298,10 @@ class ChannelHandle:
|
||||
def name(self) -> str:
|
||||
return self._channel.name
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
return self._channel.code
|
||||
|
||||
@property
|
||||
def data(self) -> np.ndarray:
|
||||
return self._channel.data
|
||||
@@ -219,6 +314,14 @@ class ChannelHandle:
|
||||
def unit(self) -> str:
|
||||
return self._channel.unit
|
||||
|
||||
@property
|
||||
def peak(self) -> float:
|
||||
return self._channel.peak
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> float:
|
||||
return self._channel.sample_rate
|
||||
|
||||
def value_at(self, t: float) -> float:
|
||||
return self._channel.value_at(t)
|
||||
|
||||
@@ -238,44 +341,51 @@ class ChannelHandle:
|
||||
class TransformProxy:
|
||||
"""Fluent transform interface for a channel.
|
||||
|
||||
Each method returns a new Channel (non-destructive).
|
||||
Each method returns a new ``ChannelHandle`` wrapping the transformed
|
||||
channel, enabling chaining.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: Channel) -> None:
|
||||
self._channel = channel
|
||||
def __init__(self, handle: ChannelHandle) -> None:
|
||||
self._handle = handle
|
||||
|
||||
def cfc(self, cfc_class: int) -> Channel:
|
||||
@property
|
||||
def _channel(self) -> Channel:
|
||||
return self._handle._channel
|
||||
|
||||
def cfc(self, cfc_class: int) -> ChannelHandle:
|
||||
"""Apply CFC filter."""
|
||||
from impakt.transform.cfc import CFCFilter
|
||||
|
||||
return CFCFilter(cfc_class=cfc_class).apply(self._channel)
|
||||
return ChannelHandle(CFCFilter(cfc_class=cfc_class).apply(self._channel))
|
||||
|
||||
def x_align(
|
||||
self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any
|
||||
) -> Channel:
|
||||
) -> ChannelHandle:
|
||||
"""Apply time-zero alignment."""
|
||||
from impakt.transform.align import XAlign
|
||||
|
||||
return XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
|
||||
return ChannelHandle(
|
||||
XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel)
|
||||
)
|
||||
|
||||
def y_align(self, window: tuple[float, float] | None = None) -> Channel:
|
||||
def y_align(self, window: tuple[float, float] | None = None) -> ChannelHandle:
|
||||
"""Apply Y-axis zero correction."""
|
||||
from impakt.transform.align import YAlign
|
||||
|
||||
start, end = window if window else (None, None)
|
||||
return YAlign(window_start=start, window_end=end).apply(self._channel)
|
||||
return ChannelHandle(YAlign(window_start=start, window_end=end).apply(self._channel))
|
||||
|
||||
def trim(self, t_start: float | None = None, t_end: float | None = None) -> Channel:
|
||||
def trim(self, t_start: float | None = None, t_end: float | None = None) -> ChannelHandle:
|
||||
"""Trim to a time range."""
|
||||
from impakt.transform.resample import Trim
|
||||
|
||||
return Trim(t_start=t_start, t_end=t_end).apply(self._channel)
|
||||
return ChannelHandle(Trim(t_start=t_start, t_end=t_end).apply(self._channel))
|
||||
|
||||
def resample(self, target_rate: float) -> Channel:
|
||||
def resample(self, target_rate: float) -> ChannelHandle:
|
||||
"""Resample to a new rate."""
|
||||
from impakt.transform.resample import Resample
|
||||
|
||||
return Resample(target_rate=target_rate).apply(self._channel)
|
||||
return ChannelHandle(Resample(target_rate=target_rate).apply(self._channel))
|
||||
|
||||
|
||||
class Template:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Dash web application factory.
|
||||
|
||||
Creates the Dash app with all layout components and callbacks registered.
|
||||
The AppState holds server-side data; Dash stores hold lightweight UI state.
|
||||
The AppState holds Session objects server-side; Dash stores hold lightweight UI state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -13,14 +13,12 @@ import dash
|
||||
import dash_bootstrap_components as dbc
|
||||
|
||||
from impakt.channel.model import TestData
|
||||
from impakt.script.api import Session
|
||||
from impakt.template.library import TemplateLibrary
|
||||
from impakt.web.callbacks import register_callbacks
|
||||
from impakt.web.layout import build_layout
|
||||
from impakt.web.state import AppState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from impakt.script.api import Session
|
||||
|
||||
|
||||
def create_app(
|
||||
session_or_data: Session | TestData | None = None,
|
||||
@@ -37,21 +35,15 @@ def create_app(
|
||||
Returns:
|
||||
Configured Dash app ready to run.
|
||||
"""
|
||||
# Build or use provided AppState
|
||||
if app_state is None:
|
||||
app_state = AppState()
|
||||
if session_or_data is not None:
|
||||
if hasattr(session_or_data, "data"):
|
||||
test_data: TestData = session_or_data.data # type: ignore[union-attr]
|
||||
else:
|
||||
test_data = session_or_data # type: ignore[assignment]
|
||||
|
||||
# Create a LoadedTest from TestData directly
|
||||
from impakt.web.state import LoadedTest
|
||||
|
||||
loaded = LoadedTest(test_data)
|
||||
app_state._tests[test_data.test_id] = loaded
|
||||
app_state._test_order.append(test_data.test_id)
|
||||
if isinstance(session_or_data, Session):
|
||||
app_state.add_session(session_or_data)
|
||||
elif isinstance(session_or_data, TestData):
|
||||
# Wrap raw TestData in a Session
|
||||
session = Session(session_or_data)
|
||||
app_state.add_session(session)
|
||||
|
||||
# Discover templates
|
||||
if template_names is None:
|
||||
@@ -64,14 +56,13 @@ def create_app(
|
||||
# Title
|
||||
title = "Impakt"
|
||||
if app_state.primary_test:
|
||||
title = f"Impakt — {app_state.primary_test.test_id}"
|
||||
title = f"Impakt \u2014 {app_state.primary_test.test_id}"
|
||||
|
||||
app = dash.Dash(
|
||||
__name__,
|
||||
external_stylesheets=[dbc.themes.FLATLY],
|
||||
title=title,
|
||||
suppress_callback_exceptions=True,
|
||||
# Prevent browser from caching old layouts
|
||||
serve_locally=True,
|
||||
meta_tags=[
|
||||
{"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"},
|
||||
@@ -92,16 +83,9 @@ def serve(
|
||||
port: int = 8050,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
"""Convenience function to create and run the web UI.
|
||||
|
||||
Args:
|
||||
session_or_data: Session or TestData to visualize, or None for empty app.
|
||||
template: Template name to pre-apply.
|
||||
port: Server port.
|
||||
debug: Enable Dash debug mode.
|
||||
"""
|
||||
"""Convenience function to create and run the web UI."""
|
||||
if template and session_or_data and hasattr(session_or_data, "apply_template"):
|
||||
session_or_data.apply_template(template) # type: ignore[union-attr]
|
||||
session_or_data.apply_template(template)
|
||||
|
||||
app = create_app(session_or_data)
|
||||
print(f"Impakt running at http://localhost:{port}")
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""Criteria computation callbacks.
|
||||
|
||||
Handles:
|
||||
- Compute All button -> runs auto_compute_criteria
|
||||
- Protocol scoring
|
||||
- Results display
|
||||
Uses Session.compute_criteria() and Session.evaluate() from the
|
||||
scripting API — the same path used by programmatic scripts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,11 +12,7 @@ import dash
|
||||
from dash import Input, Output, State, html
|
||||
from dash.exceptions import PreventUpdate
|
||||
|
||||
from impakt.web.components.criteria import (
|
||||
auto_compute_criteria,
|
||||
build_criteria_results_display,
|
||||
score_protocol,
|
||||
)
|
||||
from impakt.web.components.criteria import build_criteria_results_display
|
||||
from impakt.web.state import AppState
|
||||
|
||||
|
||||
@@ -41,8 +35,8 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
|
||||
if primary is None:
|
||||
return html.Div("No test loaded.", className="text-danger small")
|
||||
|
||||
# Compute criteria
|
||||
criteria = auto_compute_criteria(primary.data)
|
||||
# Use Session.compute_criteria() — single code path for script + web
|
||||
criteria = primary.compute_criteria()
|
||||
|
||||
if not criteria:
|
||||
return html.Div(
|
||||
@@ -51,7 +45,10 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None:
|
||||
className="text-warning small",
|
||||
)
|
||||
|
||||
# Score against protocol
|
||||
protocol_result = score_protocol(criteria, protocol)
|
||||
# Score against protocol — use Session.evaluate()
|
||||
try:
|
||||
protocol_result = primary.evaluate(protocol)
|
||||
except Exception:
|
||||
protocol_result = None
|
||||
|
||||
return build_criteria_results_display(criteria, protocol_result)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Plot rendering callbacks.
|
||||
|
||||
Handles:
|
||||
- Updating plot figures when channels/transforms change
|
||||
- Cursor line rendering (X1/X2 vertical lines)
|
||||
- Hover data is NOT shown as a Plotly tooltip — instead the cursor grid
|
||||
picks it up via the hoverData callback
|
||||
Builds a PlotSpec from the UI state and delegates to PlotEngine.render().
|
||||
This is the single rendering path — the same PlotEngine used by the
|
||||
scripting API renders the web UI plots.
|
||||
|
||||
Transform application uses TransformChain, making the pipeline
|
||||
serializable and reproducible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,16 +16,55 @@ from typing import Any
|
||||
import dash
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from dash import ALL, Input, Output, State, ctx
|
||||
from dash import Input, Output, State
|
||||
from dash.exceptions import PreventUpdate
|
||||
|
||||
from impakt.channel.model import Channel
|
||||
from impakt.plot.engine import DEFAULT_COLORS
|
||||
from impakt.plot.engine import DEFAULT_COLORS, PlotEngine
|
||||
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
|
||||
from impakt.transform.align import XAlign, YAlign
|
||||
from impakt.transform.base import TransformChain
|
||||
from impakt.transform.cfc import CFCFilter
|
||||
from impakt.transform.resultant import resultant_from_channels
|
||||
from impakt.web.state import AppState
|
||||
|
||||
# Module-level engine instance (stateless, safe to reuse)
|
||||
_engine = PlotEngine()
|
||||
|
||||
|
||||
def _build_transform_chain(
|
||||
cfc_value: str,
|
||||
y_align: bool,
|
||||
x_align_method: str,
|
||||
x_align_value: float | None,
|
||||
per_channel_cfc: str | None = None,
|
||||
) -> TransformChain:
|
||||
"""Build a TransformChain from the current UI state.
|
||||
|
||||
Per-channel CFC override takes precedence over the global CFC setting.
|
||||
"""
|
||||
chain = TransformChain()
|
||||
|
||||
# CFC filter
|
||||
effective_cfc = per_channel_cfc if per_channel_cfc else cfc_value
|
||||
if effective_cfc and effective_cfc != "none":
|
||||
try:
|
||||
chain = chain.append(CFCFilter(cfc_class=int(effective_cfc)))
|
||||
except (ValueError, Exception):
|
||||
pass
|
||||
|
||||
# Y-align
|
||||
if y_align:
|
||||
chain = chain.append(YAlign())
|
||||
|
||||
# X-align
|
||||
if x_align_method == "manual" and x_align_value is not None:
|
||||
chain = chain.append(XAlign(method="manual", reference_time=x_align_value))
|
||||
elif x_align_method == "threshold" and x_align_value is not None:
|
||||
chain = chain.append(XAlign(method="threshold", threshold_value=x_align_value))
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
def _resolve_channels(
|
||||
selected_keys: list[str],
|
||||
@@ -35,12 +75,15 @@ def _resolve_channels(
|
||||
x_align_value: float | None,
|
||||
show_resultant: bool,
|
||||
) -> list[tuple[str, Channel]]:
|
||||
"""Resolve selected channel keys to Channel objects with transforms applied.
|
||||
"""Resolve selected channel keys to transformed Channel objects.
|
||||
|
||||
Per-channel overrides (from app_state.channel_overrides) take precedence
|
||||
over the global CFC setting.
|
||||
Uses TransformChain for each channel. Per-channel overrides from
|
||||
app_state.channel_overrides take precedence over the global CFC.
|
||||
|
||||
Returns list of (label, transformed_channel) tuples.
|
||||
"""
|
||||
channels: list[tuple[str, Channel]] = []
|
||||
multi_test = len(app_state.tests) > 1
|
||||
|
||||
for key in selected_keys:
|
||||
if "::" in key:
|
||||
@@ -55,29 +98,22 @@ def _resolve_channels(
|
||||
if ch is None:
|
||||
continue
|
||||
|
||||
# Determine CFC: per-channel override takes precedence over global
|
||||
# Build per-channel transform chain
|
||||
override = app_state.channel_overrides.get(key, {})
|
||||
ch_cfc = override.get("cfc", "")
|
||||
effective_cfc = ch_cfc if ch_cfc else cfc_value
|
||||
per_ch_cfc = override.get("cfc", "")
|
||||
chain = _build_transform_chain(
|
||||
cfc_value,
|
||||
y_align,
|
||||
x_align_method,
|
||||
x_align_value,
|
||||
per_channel_cfc=per_ch_cfc,
|
||||
)
|
||||
|
||||
if effective_cfc and effective_cfc != "none":
|
||||
try:
|
||||
ch = CFCFilter(cfc_class=int(effective_cfc)).apply(ch)
|
||||
except (ValueError, Exception):
|
||||
pass
|
||||
|
||||
# Apply Y-align
|
||||
if y_align:
|
||||
ch = YAlign().apply(ch)
|
||||
|
||||
# Apply X-align
|
||||
if x_align_method == "manual" and x_align_value is not None:
|
||||
ch = XAlign(method="manual", reference_time=x_align_value).apply(ch)
|
||||
elif x_align_method == "threshold" and x_align_value is not None:
|
||||
ch = XAlign(method="threshold", threshold_value=x_align_value).apply(ch)
|
||||
# Apply the chain
|
||||
if len(chain) > 0:
|
||||
ch = chain.apply(ch)
|
||||
|
||||
# Build label
|
||||
multi_test = len(app_state.tests) > 1
|
||||
label = ch.code.short_label if ch.code.is_valid else ch.name
|
||||
if multi_test:
|
||||
label = f"[{test_id}] {label}"
|
||||
@@ -99,7 +135,7 @@ def _resolve_channels(
|
||||
res_label = (
|
||||
f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant"
|
||||
)
|
||||
if len(app_state.tests) > 1:
|
||||
if multi_test:
|
||||
res_label = f"[{comps[0].source_test_id}] {res_label}"
|
||||
channels.append((res_label, res))
|
||||
except Exception:
|
||||
@@ -108,123 +144,61 @@ def _resolve_channels(
|
||||
return channels
|
||||
|
||||
|
||||
def _build_figure(
|
||||
def _build_plot_spec(
|
||||
channels: list[tuple[str, Channel]],
|
||||
cursor_x1: float | None,
|
||||
cursor_x2: float | None,
|
||||
cfc_value: str,
|
||||
corridors: list[dict] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Build a Plotly figure from resolved channels."""
|
||||
fig = go.Figure()
|
||||
) -> PlotSpec:
|
||||
"""Build a PlotSpec from resolved channels and UI state.
|
||||
|
||||
if not channels:
|
||||
fig.update_layout(
|
||||
template="plotly_white",
|
||||
annotations=[
|
||||
{
|
||||
"text": "Select channels from the grid to plot",
|
||||
"xref": "paper",
|
||||
"yref": "paper",
|
||||
"x": 0.5,
|
||||
"y": 0.5,
|
||||
"showarrow": False,
|
||||
"font": {"size": 14, "color": "#bbb"},
|
||||
}
|
||||
],
|
||||
xaxis={"visible": False},
|
||||
yaxis={"visible": False},
|
||||
margin={"l": 40, "r": 20, "t": 30, "b": 40},
|
||||
)
|
||||
return fig
|
||||
|
||||
# Add traces — hovermode is disabled at the layout level; cursor tracking
|
||||
# is handled by our own JS mousemove handler (cursor_tracker.js).
|
||||
Channels are already transformed — they are wrapped in ChannelRef
|
||||
objects with no additional transform chain (transforms were applied
|
||||
during resolution).
|
||||
"""
|
||||
# Build ChannelRef objects
|
||||
refs: list[ChannelRef] = []
|
||||
for i, (label, ch) in enumerate(channels):
|
||||
color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)]
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=ch.time.tolist(),
|
||||
y=ch.data.tolist(),
|
||||
mode="lines",
|
||||
name=label,
|
||||
line=dict(color=color, width=1.5),
|
||||
refs.append(
|
||||
ChannelRef(
|
||||
channel=ch,
|
||||
style=PlotStyle(label=label, color=color),
|
||||
)
|
||||
)
|
||||
|
||||
# Add corridor fills
|
||||
# Build Corridor objects from raw dicts
|
||||
corridor_objs: list[Corridor] = []
|
||||
if corridors:
|
||||
for corridor in corridors:
|
||||
if not corridor.get("visible", True):
|
||||
for c in corridors:
|
||||
if not c.get("visible", True):
|
||||
continue
|
||||
c_time = corridor["time"]
|
||||
c_upper = corridor["upper"]
|
||||
c_lower = corridor["lower"]
|
||||
c_name = corridor.get("name", "Corridor")
|
||||
|
||||
# Upper bound
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=c_time,
|
||||
y=c_upper,
|
||||
mode="lines",
|
||||
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
|
||||
showlegend=False,
|
||||
name=f"{c_name} upper",
|
||||
)
|
||||
)
|
||||
# Lower bound with fill to upper
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=c_time,
|
||||
y=c_lower,
|
||||
mode="lines",
|
||||
line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"),
|
||||
fill="tonexty",
|
||||
fillcolor="rgba(100,100,255,0.1)",
|
||||
showlegend=True,
|
||||
name=c_name,
|
||||
corridor_objs.append(
|
||||
Corridor(
|
||||
name=c.get("name", "Corridor"),
|
||||
time=np.array(c["time"]),
|
||||
lower=np.array(c["lower"]),
|
||||
upper=np.array(c["upper"]),
|
||||
style=CorridorStyle(),
|
||||
)
|
||||
)
|
||||
|
||||
# Add X1/X2 cursor lines
|
||||
if cursor_x1 is not None:
|
||||
fig.add_vline(
|
||||
x=cursor_x1,
|
||||
line_dash="dash",
|
||||
line_color="rgba(220,53,69,0.6)",
|
||||
line_width=1,
|
||||
annotation_text=f"X1={cursor_x1:.4f}s",
|
||||
annotation_font_size=9,
|
||||
annotation_font_color="rgba(220,53,69,0.8)",
|
||||
)
|
||||
if cursor_x2 is not None:
|
||||
fig.add_vline(
|
||||
x=cursor_x2,
|
||||
line_dash="dash",
|
||||
line_color="rgba(13,110,253,0.6)",
|
||||
line_width=1,
|
||||
annotation_text=f"X2={cursor_x2:.4f}s",
|
||||
annotation_font_size=9,
|
||||
annotation_font_color="rgba(13,110,253,0.8)",
|
||||
)
|
||||
# Cursor positions
|
||||
x_cursors = None
|
||||
if cursor_x1 is not None and cursor_x2 is not None:
|
||||
x_cursors = (cursor_x1, cursor_x2)
|
||||
|
||||
# Layout — hovermode is disabled; cursor tracking is handled entirely
|
||||
# by our JS (cursor_tracker.js) which reads pixel positions from mousemove
|
||||
# events and converts to data coordinates via Plotly's axis internals.
|
||||
# Y-axis label from first channel
|
||||
y_label = channels[0][1].unit if channels else ""
|
||||
|
||||
fig.update_layout(
|
||||
xaxis_title=dict(text="Time (s)", font=dict(size=10, color="#999")),
|
||||
yaxis_title=dict(text=y_label, font=dict(size=10, color="#999")),
|
||||
template="plotly_white",
|
||||
hovermode=False,
|
||||
showlegend=False,
|
||||
margin=dict(l=45, r=8, t=4, b=28),
|
||||
return PlotSpec(
|
||||
channels=refs,
|
||||
corridors=corridor_objs,
|
||||
x_cursors=x_cursors,
|
||||
y_label=y_label,
|
||||
compact=True, # Web UI always uses compact mode
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
|
||||
"""Register all plot-related callbacks."""
|
||||
@@ -271,10 +245,5 @@ def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None:
|
||||
show_resultant,
|
||||
)
|
||||
|
||||
return _build_figure(
|
||||
channels,
|
||||
cursor_x1,
|
||||
cursor_x2,
|
||||
cfc_value,
|
||||
corridors=corridors_data,
|
||||
)
|
||||
spec = _build_plot_spec(channels, cursor_x1, cursor_x2, corridors_data)
|
||||
return _engine.render(spec)
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
"""Collapsible channel tree component.
|
||||
|
||||
Renders channels in a hierarchical accordion:
|
||||
Test Object > Body Region > Measurement Type > [channels]
|
||||
|
||||
Supports search filtering, select-all per group, and channel preview
|
||||
(peak, unit, sample rate) on hover/expand.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import dash_bootstrap_components as dbc
|
||||
from dash import dcc, html
|
||||
|
||||
from impakt.web.state import AppState
|
||||
|
||||
|
||||
def _make_channel_item(ch_info: dict[str, str], show_test_prefix: bool = False) -> html.Div:
|
||||
"""Build a single channel item with checkbox and preview info."""
|
||||
label = ch_info["label"]
|
||||
key = ch_info["key"]
|
||||
peak = ch_info.get("peak", "")
|
||||
unit = ch_info.get("unit", "")
|
||||
|
||||
preview = f" ({peak} {unit})" if peak else ""
|
||||
|
||||
return html.Div(
|
||||
[
|
||||
dbc.Checkbox(
|
||||
id={"type": "channel-check", "index": key},
|
||||
label=html.Span(
|
||||
[
|
||||
html.Span(label, style={"fontSize": "12px"}),
|
||||
html.Span(
|
||||
preview,
|
||||
style={"fontSize": "10px", "color": "#999", "marginLeft": "4px"},
|
||||
),
|
||||
]
|
||||
),
|
||||
value=False,
|
||||
style={"marginBottom": "1px"},
|
||||
),
|
||||
],
|
||||
className="channel-item",
|
||||
)
|
||||
|
||||
|
||||
def _make_group_header(
|
||||
group_label: str,
|
||||
group_id: str,
|
||||
channel_keys: list[str],
|
||||
) -> html.Div:
|
||||
"""Build a measurement group header with select-all button."""
|
||||
return html.Div(
|
||||
[
|
||||
html.Span(group_label, style={"fontSize": "12px", "fontWeight": "500"}),
|
||||
dbc.Button(
|
||||
"All",
|
||||
id={"type": "select-group-btn", "index": group_id},
|
||||
color="link",
|
||||
size="sm",
|
||||
style={"fontSize": "10px", "padding": "0 4px", "textDecoration": "none"},
|
||||
),
|
||||
# Hidden store for this group's channel keys
|
||||
dcc.Store(
|
||||
id={"type": "group-channels", "index": group_id},
|
||||
data=channel_keys,
|
||||
),
|
||||
],
|
||||
className="d-flex align-items-center justify-content-between",
|
||||
)
|
||||
|
||||
|
||||
def build_channel_tree(app_state: AppState) -> dbc.Card:
|
||||
"""Build the full channel tree panel with search and accordion."""
|
||||
if app_state.is_empty:
|
||||
return dbc.Card(
|
||||
[
|
||||
dbc.CardHeader("Channels", className="fw-bold py-2"),
|
||||
dbc.CardBody(
|
||||
html.Div("No test loaded", className="text-muted small text-center py-3"),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
multi_test = len(app_state.tests) > 1
|
||||
full_tree = app_state.build_channel_tree()
|
||||
|
||||
# Build accordion items
|
||||
accordion_items = []
|
||||
group_counter = 0
|
||||
|
||||
for test_id, test_tree in full_tree.items():
|
||||
test_label = ""
|
||||
loaded = app_state.get_test(test_id)
|
||||
if loaded and multi_test:
|
||||
test_label = f"[{test_id}] "
|
||||
|
||||
for obj_label, locations in test_tree.items():
|
||||
# Top-level accordion: test object (Driver, Vehicle Structure, etc.)
|
||||
obj_content = []
|
||||
|
||||
for loc_label, measurements in locations.items():
|
||||
for meas_label, channels in measurements.items():
|
||||
group_id = f"grp_{group_counter}"
|
||||
group_counter += 1
|
||||
|
||||
channel_keys = [ch["key"] for ch in channels]
|
||||
header_label = (
|
||||
f"{loc_label} — {meas_label}" if loc_label != meas_label else meas_label
|
||||
)
|
||||
|
||||
obj_content.append(_make_group_header(header_label, group_id, channel_keys))
|
||||
for ch in channels:
|
||||
obj_content.append(_make_channel_item(ch, show_test_prefix=multi_test))
|
||||
|
||||
obj_content.append(html.Hr(style={"margin": "4px 0"}))
|
||||
|
||||
if obj_content:
|
||||
# Remove trailing hr
|
||||
if obj_content and isinstance(obj_content[-1], html.Hr):
|
||||
obj_content = obj_content[:-1]
|
||||
|
||||
display_label = f"{test_label}{obj_label}" if test_label else obj_label
|
||||
accordion_items.append(
|
||||
dbc.AccordionItem(
|
||||
html.Div(obj_content),
|
||||
title=display_label,
|
||||
style={"fontSize": "12px"},
|
||||
)
|
||||
)
|
||||
|
||||
return dbc.Card(
|
||||
[
|
||||
dbc.CardHeader(
|
||||
[
|
||||
html.Span("Channels", className="fw-bold"),
|
||||
html.Span(
|
||||
f" ({app_state.total_channels})",
|
||||
style={"fontSize": "11px", "color": "#999"},
|
||||
),
|
||||
],
|
||||
className="py-2",
|
||||
),
|
||||
dbc.CardBody(
|
||||
[
|
||||
# Search
|
||||
dbc.Input(
|
||||
id="channel-search",
|
||||
placeholder="Search channels...",
|
||||
type="text",
|
||||
size="sm",
|
||||
className="mb-2",
|
||||
debounce=True,
|
||||
),
|
||||
# Selected channels display
|
||||
html.Div(id="selected-channels-badges", className="mb-2"),
|
||||
# Tree
|
||||
html.Div(
|
||||
dbc.Accordion(
|
||||
accordion_items,
|
||||
flush=True,
|
||||
always_open=True,
|
||||
start_collapsed=True,
|
||||
id="channel-accordion",
|
||||
),
|
||||
style={"maxHeight": "450px", "overflowY": "auto"},
|
||||
id="channel-tree-container",
|
||||
),
|
||||
],
|
||||
className="py-2",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_selected_channels_badges(selected_keys: list[str], app_state: AppState) -> list:
|
||||
"""Build badge pills showing currently selected channels."""
|
||||
if not selected_keys:
|
||||
return [
|
||||
html.Span("No channels selected", className="text-muted", style={"fontSize": "11px"})
|
||||
]
|
||||
|
||||
badges = []
|
||||
for key in selected_keys[:20]: # Limit display to 20
|
||||
# Extract label
|
||||
if "::" in key:
|
||||
test_id, ch_name = key.split("::", 1)
|
||||
ch = app_state.get_channel(test_id, ch_name)
|
||||
if ch and ch.code.is_valid:
|
||||
label = ch.code.short_label
|
||||
else:
|
||||
label = ch_name
|
||||
if len(app_state.tests) > 1:
|
||||
label = f"{test_id}: {label}"
|
||||
else:
|
||||
label = key
|
||||
|
||||
badges.append(
|
||||
dbc.Badge(
|
||||
label,
|
||||
id={"type": "channel-badge", "index": key},
|
||||
color="primary",
|
||||
pill=True,
|
||||
className="me-1 mb-1",
|
||||
style={"fontSize": "10px", "cursor": "pointer"},
|
||||
)
|
||||
)
|
||||
|
||||
if len(selected_keys) > 20:
|
||||
badges.append(
|
||||
html.Span(
|
||||
f"+{len(selected_keys) - 20} more",
|
||||
className="text-muted",
|
||||
style={"fontSize": "10px"},
|
||||
)
|
||||
)
|
||||
|
||||
return badges
|
||||
@@ -129,7 +129,7 @@ def build_test_info_panel(app_state: AppState) -> dbc.Card:
|
||||
f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "—",
|
||||
style={"fontSize": "12px"},
|
||||
),
|
||||
html.Td(str(loaded.channel_count), style={"fontSize": "12px"}),
|
||||
html.Td(str(len(loaded)), style={"fontSize": "12px"}),
|
||||
html.Td(
|
||||
dbc.Button(
|
||||
"x",
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Application state management.
|
||||
|
||||
AppState is the central data store for the web UI. It holds all loaded tests,
|
||||
manages channel transforms, and provides the data that callbacks need to
|
||||
render plots, compute criteria, and generate reports.
|
||||
AppState is the central data store for the web UI. It holds Session
|
||||
objects (from the scripting API) and delegates to them for channel
|
||||
access, transforms, criteria, and protocol scoring.
|
||||
|
||||
AppState is NOT a Dash Store — it lives server-side in Python memory. The Dash
|
||||
stores hold lightweight references (test IDs, selected channels, plot config)
|
||||
AppState is NOT a Dash Store — it lives server-side in Python memory.
|
||||
Dash stores hold lightweight references (test IDs, selected channels)
|
||||
that callbacks use to look up data from AppState.
|
||||
|
||||
This design keeps large numpy arrays out of the browser and avoids serialization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -17,220 +15,95 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from impakt.channel.model import Channel, ChannelGroup, TestData
|
||||
from impakt.io.reader import get_registry
|
||||
from impakt.channel.model import Channel, TestData
|
||||
from impakt.script.api import Session
|
||||
from impakt.template.library import TemplateLibrary
|
||||
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
|
||||
from impakt.template.session import SessionManager
|
||||
from impakt.template.model import PlotDefinition, TemplateSpec
|
||||
from impakt.transform.align import YAlign
|
||||
from impakt.transform.cfc import CFCFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadedTest:
|
||||
"""A single loaded test with its metadata and channel access."""
|
||||
|
||||
def __init__(self, test_data: TestData, color_offset: int = 0) -> None:
|
||||
self.data = test_data
|
||||
self.color_offset = color_offset
|
||||
|
||||
@property
|
||||
def test_id(self) -> str:
|
||||
return self.data.test_id
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
"""Short label for UI display."""
|
||||
meta = self.data.metadata
|
||||
parts = [meta.test_number]
|
||||
if meta.vehicle.make:
|
||||
parts.append(f"{meta.vehicle.make} {meta.vehicle.model}".strip())
|
||||
return " — ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
@property
|
||||
def channel_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class AppState:
|
||||
"""Server-side application state.
|
||||
|
||||
Manages multiple loaded tests and provides channel resolution
|
||||
for the UI callbacks.
|
||||
Manages multiple loaded test sessions and provides channel resolution
|
||||
for the UI callbacks. All test data access goes through Session objects.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tests: dict[str, LoadedTest] = {}
|
||||
self._test_order: list[str] = []
|
||||
self._color_counter: int = 0
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._session_order: list[str] = []
|
||||
self._template_library = TemplateLibrary()
|
||||
self._active_template: TemplateSpec | None = None
|
||||
self._session_managers: dict[str, SessionManager] = {}
|
||||
# Per-channel transform overrides: {channel_key: {"cfc": "600", "y_align": True}}
|
||||
# Per-channel transform overrides: {channel_key: {"cfc": "600"}}
|
||||
self.channel_overrides: dict[str, dict[str, Any]] = {}
|
||||
# Active corridors: list of {name, time, lower, upper, visible}
|
||||
# Active corridors
|
||||
self.corridors: list[dict[str, Any]] = []
|
||||
|
||||
def load_test(self, path: str | Path) -> LoadedTest:
|
||||
def load_test(self, path: str | Path) -> Session:
|
||||
"""Load a test from a path and add it to the state.
|
||||
|
||||
If a test with the same ID is already loaded, replaces it.
|
||||
|
||||
Returns:
|
||||
The loaded test.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path is not a valid test directory.
|
||||
Returns the Session object.
|
||||
"""
|
||||
path = Path(path).resolve()
|
||||
registry = get_registry()
|
||||
test_data = registry.read(path)
|
||||
session = Session.open(path)
|
||||
test_id = session.test_id
|
||||
|
||||
loaded = LoadedTest(test_data, color_offset=self._color_counter)
|
||||
self._color_counter += len(test_data)
|
||||
self._sessions[test_id] = session
|
||||
if test_id not in self._session_order:
|
||||
self._session_order.append(test_id)
|
||||
|
||||
test_id = test_data.test_id
|
||||
if test_id in self._tests:
|
||||
self._tests[test_id] = loaded
|
||||
else:
|
||||
self._tests[test_id] = loaded
|
||||
self._test_order.append(test_id)
|
||||
logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(session))
|
||||
return session
|
||||
|
||||
logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(test_data))
|
||||
return loaded
|
||||
def add_session(self, session: Session) -> None:
|
||||
"""Add an already-created Session to the state."""
|
||||
test_id = session.test_id
|
||||
self._sessions[test_id] = session
|
||||
if test_id not in self._session_order:
|
||||
self._session_order.append(test_id)
|
||||
|
||||
def remove_test(self, test_id: str) -> None:
|
||||
"""Remove a loaded test."""
|
||||
if test_id in self._tests:
|
||||
del self._tests[test_id]
|
||||
self._test_order.remove(test_id)
|
||||
self._sessions.pop(test_id, None)
|
||||
if test_id in self._session_order:
|
||||
self._session_order.remove(test_id)
|
||||
|
||||
@property
|
||||
def tests(self) -> list[LoadedTest]:
|
||||
"""All loaded tests in load order."""
|
||||
return [self._tests[tid] for tid in self._test_order if tid in self._tests]
|
||||
def sessions(self) -> list[Session]:
|
||||
"""All loaded sessions in load order."""
|
||||
return [self._sessions[tid] for tid in self._session_order if tid in self._sessions]
|
||||
|
||||
@property
|
||||
def tests(self) -> list[Session]:
|
||||
"""Alias for sessions (backward compat for components)."""
|
||||
return self.sessions
|
||||
|
||||
@property
|
||||
def test_ids(self) -> list[str]:
|
||||
return list(self._test_order)
|
||||
return list(self._session_order)
|
||||
|
||||
@property
|
||||
def primary_test(self) -> LoadedTest | None:
|
||||
"""The first loaded test (primary for criteria, etc.)."""
|
||||
if self._test_order:
|
||||
return self._tests.get(self._test_order[0])
|
||||
def primary_test(self) -> Session | None:
|
||||
"""The first loaded session (primary for criteria, etc.)."""
|
||||
if self._session_order:
|
||||
return self._sessions.get(self._session_order[0])
|
||||
return None
|
||||
|
||||
def get_test(self, test_id: str) -> LoadedTest | None:
|
||||
return self._tests.get(test_id)
|
||||
def get_test(self, test_id: str) -> Session | None:
|
||||
return self._sessions.get(test_id)
|
||||
|
||||
def get_channel(self, test_id: str, channel_name: str) -> Channel | None:
|
||||
"""Resolve a channel from a test ID and channel name."""
|
||||
test = self._tests.get(test_id)
|
||||
if test is None:
|
||||
session = self._sessions.get(test_id)
|
||||
if session is None:
|
||||
return None
|
||||
try:
|
||||
return test.data.get(channel_name)
|
||||
return session.data.get(channel_name)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def resolve_channel(
|
||||
self,
|
||||
channel_key: str,
|
||||
cfc_class: int | None = None,
|
||||
y_align: bool = False,
|
||||
) -> Channel | None:
|
||||
"""Resolve a channel key (test_id::channel_name) and apply transforms.
|
||||
|
||||
Channel keys use the format 'test_id::channel_name'. If no '::'
|
||||
separator, assumes the primary test.
|
||||
"""
|
||||
if "::" in channel_key:
|
||||
test_id, ch_name = channel_key.split("::", 1)
|
||||
else:
|
||||
if self.primary_test is None:
|
||||
return None
|
||||
test_id = self.primary_test.test_id
|
||||
ch_name = channel_key
|
||||
|
||||
ch = self.get_channel(test_id, ch_name)
|
||||
if ch is None:
|
||||
return None
|
||||
|
||||
# Apply transforms
|
||||
if cfc_class is not None:
|
||||
try:
|
||||
ch = CFCFilter(cfc_class=cfc_class).apply(ch)
|
||||
except (ValueError, Exception):
|
||||
pass
|
||||
|
||||
if y_align:
|
||||
ch = YAlign().apply(ch)
|
||||
|
||||
return ch
|
||||
|
||||
def build_channel_tree(
|
||||
self,
|
||||
) -> dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]]:
|
||||
"""Build a hierarchical channel tree across all loaded tests.
|
||||
|
||||
Returns:
|
||||
{test_id: {object_label: {location_label: {measurement_label: [{name, label, key}]}}}}
|
||||
"""
|
||||
result: dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]] = {}
|
||||
|
||||
for loaded in self.tests:
|
||||
tree = loaded.data.channel_tree()
|
||||
test_tree: dict[str, dict[str, dict[str, list[dict[str, str]]]]] = {}
|
||||
|
||||
for obj, locations in sorted(tree.items()):
|
||||
obj_tree: dict[str, dict[str, list[dict[str, str]]]] = {}
|
||||
for loc, measurements in sorted(locations.items()):
|
||||
loc_tree: dict[str, list[dict[str, str]]] = {}
|
||||
for meas, channels in sorted(measurements.items()):
|
||||
ch_list = []
|
||||
for ch in channels:
|
||||
label = ch.code.short_label if ch.code.is_valid else ch.name
|
||||
ch_list.append(
|
||||
{
|
||||
"name": ch.name,
|
||||
"label": label,
|
||||
"key": f"{loaded.test_id}::{ch.name}",
|
||||
"unit": ch.unit,
|
||||
"peak": f"{ch.peak:.2f}",
|
||||
"samples": str(ch.n_samples),
|
||||
"rate": f"{ch.sample_rate:.0f}",
|
||||
}
|
||||
)
|
||||
loc_tree[meas] = ch_list
|
||||
obj_tree[loc] = loc_tree
|
||||
test_tree[obj] = obj_tree
|
||||
result[loaded.test_id] = test_tree
|
||||
|
||||
return result
|
||||
|
||||
def flat_channel_list(self) -> list[dict[str, str]]:
|
||||
"""Flat list of all channels across all tests, for dropdown/checklist options."""
|
||||
items: list[dict[str, str]] = []
|
||||
multi = len(self._tests) > 1
|
||||
|
||||
for loaded in self.tests:
|
||||
prefix = f"[{loaded.test_id}] " if multi else ""
|
||||
for ch in sorted(loaded.data, key=lambda c: c.name):
|
||||
label = ch.code.short_label if ch.code.is_valid else ch.name
|
||||
items.append(
|
||||
{
|
||||
"label": f"{prefix}{label}",
|
||||
"value": f"{loaded.test_id}::{ch.name}",
|
||||
}
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
# ----- Template & Session -----
|
||||
|
||||
@property
|
||||
@@ -245,52 +118,37 @@ class AppState:
|
||||
def template_names(self) -> list[str]:
|
||||
return self._template_library.list()
|
||||
|
||||
def apply_template(
|
||||
self,
|
||||
name: str,
|
||||
selected_keys: list[str] | None = None,
|
||||
) -> tuple[list[str], dict[str, str]]:
|
||||
def apply_template(self, name: str) -> tuple[list[str], dict[str, str]]:
|
||||
"""Apply a template by name.
|
||||
|
||||
Resolves the template's channel patterns against the primary test,
|
||||
sets the active template, and returns the resolved channel keys and
|
||||
transform settings.
|
||||
|
||||
Returns:
|
||||
(selected_channel_keys, transform_settings)
|
||||
Resolves channel patterns against the primary test.
|
||||
Returns (selected_channel_keys, transform_settings).
|
||||
"""
|
||||
template = self._template_library.get(name)
|
||||
self._active_template = template
|
||||
|
||||
# Persist to session if primary test has a path
|
||||
primary = self.primary_test
|
||||
if primary and primary.data.path:
|
||||
mgr = self._get_session_manager(primary.data.path)
|
||||
mgr.apply_template(template)
|
||||
if primary:
|
||||
primary.apply_template(template)
|
||||
|
||||
# Resolve channel patterns from the template
|
||||
# Resolve channel patterns
|
||||
resolved_keys: list[str] = []
|
||||
if primary:
|
||||
for plot_def in template.plots:
|
||||
for pattern in plot_def.channel_patterns:
|
||||
matches = primary.data.find(pattern)
|
||||
matches = primary.find(pattern)
|
||||
for ch in matches:
|
||||
key = f"{primary.test_id}::{ch.name}"
|
||||
if key not in resolved_keys:
|
||||
resolved_keys.append(key)
|
||||
|
||||
# Transform settings
|
||||
transforms: dict[str, str] = {}
|
||||
if template.default_cfc is not None:
|
||||
transforms["cfc"] = str(template.default_cfc)
|
||||
else:
|
||||
transforms["cfc"] = "none"
|
||||
|
||||
logger.info(
|
||||
"Applied template '%s' — resolved %d channels",
|
||||
name,
|
||||
len(resolved_keys),
|
||||
)
|
||||
logger.info("Applied template '%s' — %d channels", name, len(resolved_keys))
|
||||
return resolved_keys, transforms
|
||||
|
||||
def save_as_template(
|
||||
@@ -303,30 +161,19 @@ class AppState:
|
||||
x2: float | None,
|
||||
protocol: str = "",
|
||||
) -> TemplateSpec:
|
||||
"""Capture the current UI state as a new template.
|
||||
|
||||
Converts selected channels back to patterns and stores the
|
||||
current filter/cursor/protocol settings.
|
||||
"""
|
||||
# Convert selected keys to channel patterns
|
||||
"""Capture current UI state as a new template."""
|
||||
patterns: list[str] = []
|
||||
for key in selected_keys:
|
||||
if "::" in key:
|
||||
_, ch_name = key.split("::", 1)
|
||||
else:
|
||||
ch_name = key
|
||||
# Use the raw channel name as a pattern (exact match)
|
||||
ch_name = key.split("::", 1)[-1] if "::" in key else key
|
||||
if ch_name not in patterns:
|
||||
patterns.append(ch_name)
|
||||
|
||||
# Build plot definition
|
||||
plot = PlotDefinition(
|
||||
title=name,
|
||||
channel_patterns=patterns,
|
||||
x_cursors=(x1, x2) if x1 is not None and x2 is not None else None,
|
||||
)
|
||||
|
||||
# Build transforms list
|
||||
transforms: list[dict[str, Any]] = []
|
||||
if cfc_value and cfc_value != "none":
|
||||
transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)})
|
||||
@@ -343,17 +190,18 @@ class AppState:
|
||||
|
||||
self._template_library.save(template)
|
||||
self._active_template = template
|
||||
|
||||
logger.info("Saved template '%s' with %d channel patterns", name, len(patterns))
|
||||
logger.info("Saved template '%s' with %d patterns", name, len(patterns))
|
||||
return template
|
||||
|
||||
def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None:
|
||||
"""Auto-save current state to the session for the primary test."""
|
||||
primary = self.primary_test
|
||||
if not primary or not primary.data.path:
|
||||
if not primary or not primary.path:
|
||||
return
|
||||
|
||||
mgr = self._get_session_manager(primary.data.path)
|
||||
from impakt.template.session import SessionManager
|
||||
|
||||
mgr = SessionManager(primary.path)
|
||||
mgr.state.overrides = {
|
||||
"selected_channels": selected_keys,
|
||||
"cfc": cfc_value,
|
||||
@@ -362,12 +210,14 @@ class AppState:
|
||||
mgr.save()
|
||||
|
||||
def load_session_state(self) -> dict[str, Any] | None:
|
||||
"""Load saved session state for the primary test, if any."""
|
||||
"""Load saved session state for the primary test."""
|
||||
primary = self.primary_test
|
||||
if not primary or not primary.data.path:
|
||||
if not primary or not primary.path:
|
||||
return None
|
||||
|
||||
mgr = self._get_session_manager(primary.data.path)
|
||||
from impakt.template.session import SessionManager
|
||||
|
||||
mgr = SessionManager(primary.path)
|
||||
if not mgr.has_session:
|
||||
return None
|
||||
|
||||
@@ -377,21 +227,48 @@ class AppState:
|
||||
"template": mgr.state.template_name,
|
||||
}
|
||||
|
||||
def _get_session_manager(self, path: Path) -> SessionManager:
|
||||
key = str(path)
|
||||
if key not in self._session_managers:
|
||||
self._session_managers[key] = SessionManager(path)
|
||||
return self._session_managers[key]
|
||||
# ----- Channel tree / grid helpers -----
|
||||
|
||||
def build_channel_tree(self) -> dict[str, Any]:
|
||||
"""Build hierarchical channel tree across all loaded tests."""
|
||||
result: dict[str, Any] = {}
|
||||
for session in self.sessions:
|
||||
tree = session.channel_tree()
|
||||
test_tree: dict[str, Any] = {}
|
||||
for obj, locations in sorted(tree.items()):
|
||||
obj_tree: dict[str, Any] = {}
|
||||
for loc, measurements in sorted(locations.items()):
|
||||
loc_tree: dict[str, Any] = {}
|
||||
for meas, channels in sorted(measurements.items()):
|
||||
ch_list = []
|
||||
for ch in channels:
|
||||
label = ch.code.short_label if ch.code.is_valid else ch.name
|
||||
ch_list.append(
|
||||
{
|
||||
"name": ch.name,
|
||||
"label": label,
|
||||
"key": f"{session.test_id}::{ch.name}",
|
||||
"unit": ch.unit,
|
||||
"peak": f"{ch.peak:.2f}",
|
||||
"samples": str(ch.n_samples),
|
||||
"rate": f"{ch.sample_rate:.0f}",
|
||||
}
|
||||
)
|
||||
loc_tree[meas] = ch_list
|
||||
obj_tree[loc] = loc_tree
|
||||
test_tree[obj] = obj_tree
|
||||
result[session.test_id] = test_tree
|
||||
return result
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return len(self._tests) == 0
|
||||
return len(self._sessions) == 0
|
||||
|
||||
@property
|
||||
def total_channels(self) -> int:
|
||||
return sum(t.channel_count for t in self.tests)
|
||||
return sum(len(s) for s in self.sessions)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
test_info = ", ".join(f"{t.test_id}({t.channel_count}ch)" for t in self.tests)
|
||||
test_info = ", ".join(f"{s.test_id}({len(s)}ch)" for s in self.sessions)
|
||||
tmpl = f", template={self._active_template.name}" if self._active_template else ""
|
||||
return f"AppState([{test_info}]{tmpl})"
|
||||
|
||||
93
tests/test_criteria/test_chest_femur_tibia.py
Normal file
93
tests/test_criteria/test_chest_femur_tibia.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion
|
||||
|
||||
|
||||
class TestChestDeflection:
|
||||
def test_basic(self, chest_deflection_channel):
|
||||
result = chest_deflection(channel=chest_deflection_channel)
|
||||
assert result.criterion == "Chest Deflection"
|
||||
assert result.unit == "mm"
|
||||
assert 30.0 < result.value < 40.0 # ~35mm in fixture
|
||||
|
||||
def test_peak_time(self, chest_deflection_channel):
|
||||
result = chest_deflection(channel=chest_deflection_channel)
|
||||
assert result.time_of_peak is not None
|
||||
assert 0.0 < result.time_of_peak < 0.1
|
||||
|
||||
|
||||
class TestClip3ms:
|
||||
def test_basic(self, head_accel_x):
|
||||
result = clip_3ms(head_accel_x)
|
||||
assert result.criterion == "3ms Clip"
|
||||
assert result.value > 0
|
||||
|
||||
def test_body_region(self, head_accel_x):
|
||||
result = clip_3ms(head_accel_x)
|
||||
assert result.body_region == "Chest"
|
||||
|
||||
|
||||
class TestFemurLoad:
|
||||
def test_single_channel(self, femur_left_channel):
|
||||
result = femur_load(channel=femur_left_channel, side="left")
|
||||
assert result.criterion == "Femur Load Left"
|
||||
assert result.unit == "kN"
|
||||
assert result.value > 0
|
||||
|
||||
def test_peak_time(self, femur_left_channel):
|
||||
result = femur_load(channel=femur_left_channel, side="left")
|
||||
assert result.time_of_peak is not None
|
||||
|
||||
|
||||
class TestTibiaIndex:
|
||||
def test_with_all_components(self, time_array, sample_rate):
|
||||
from impakt.channel.code import ChannelCode
|
||||
from impakt.channel.model import Channel
|
||||
|
||||
t = time_array
|
||||
fz = np.zeros_like(t)
|
||||
mx = np.zeros_like(t)
|
||||
my = np.zeros_like(t)
|
||||
mask = (t >= 0.02) & (t <= 0.08)
|
||||
fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06)
|
||||
|
||||
fz_ch = Channel(
|
||||
name="TIBFZ",
|
||||
code=ChannelCode.parse("TIBFZ"),
|
||||
data=fz,
|
||||
time=t,
|
||||
unit="N",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
mx_ch = Channel(
|
||||
name="TIBMX",
|
||||
code=ChannelCode.parse("TIBMX"),
|
||||
data=mx,
|
||||
time=t,
|
||||
unit="N·m",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
my_ch = Channel(
|
||||
name="TIBMY",
|
||||
code=ChannelCode.parse("TIBMY"),
|
||||
data=my,
|
||||
time=t,
|
||||
unit="N·m",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch)
|
||||
assert result.value > 0
|
||||
|
||||
|
||||
class TestViscousCriterion:
|
||||
def test_basic(self, chest_deflection_channel):
|
||||
result = viscous_criterion(channel=chest_deflection_channel)
|
||||
assert result.criterion == "Viscous Criterion"
|
||||
assert result.unit == "m/s"
|
||||
assert result.value >= 0
|
||||
81
tests/test_plot/test_engine.py
Normal file
81
tests/test_plot/test_engine.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for plot engine."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
|
||||
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
|
||||
|
||||
|
||||
class TestPlotEngine:
|
||||
def test_render_empty(self):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec()
|
||||
fig = engine.render(spec)
|
||||
assert fig is not None
|
||||
|
||||
def test_render_with_channels(self, head_accel_x, head_accel_y):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[
|
||||
ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")),
|
||||
ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")),
|
||||
],
|
||||
y_label="g",
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
assert len(fig.data) == 2
|
||||
|
||||
def test_render_compact_mode(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
compact=True,
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
assert fig.layout.hovermode is False
|
||||
assert fig.layout.showlegend is False
|
||||
|
||||
def test_render_with_corridors(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
corridor = Corridor(
|
||||
name="Test Corridor",
|
||||
time=np.linspace(0, 0.1, 100),
|
||||
lower=np.full(100, -50.0),
|
||||
upper=np.full(100, 50.0),
|
||||
)
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
corridors=[corridor],
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# 1 data trace + 2 corridor traces (upper + lower)
|
||||
assert len(fig.data) == 3
|
||||
|
||||
def test_render_with_cursors(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
x_cursors=(0.02, 0.06),
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# Should have vertical lines (as shapes)
|
||||
assert len(fig.layout.shapes) >= 2
|
||||
|
||||
def test_compact_cursor_annotations(self, head_accel_x):
|
||||
engine = PlotEngine()
|
||||
spec = PlotSpec(
|
||||
channels=[ChannelRef(channel=head_accel_x)],
|
||||
x_cursors=(0.02, 0.06),
|
||||
compact=True,
|
||||
)
|
||||
fig = engine.render(spec)
|
||||
# Should have X1/X2 annotations
|
||||
annotations = [
|
||||
a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text)
|
||||
]
|
||||
assert len(annotations) == 2
|
||||
|
||||
def test_default_colors(self):
|
||||
assert len(DEFAULT_COLORS) == 10
|
||||
assert all(c.startswith("#") for c in DEFAULT_COLORS)
|
||||
68
tests/test_protocol/test_iihs_usncap.py
Normal file
68
tests/test_protocol/test_iihs_usncap.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for IIHS and US NCAP scoring."""
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt.criteria.base import CriterionResult
|
||||
from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate
|
||||
from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_criteria():
|
||||
return {
|
||||
"HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"),
|
||||
"Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"),
|
||||
"Chest Deflection": CriterionResult(
|
||||
criterion="Chest Deflection", value=35, unit="mm", body_region="Chest"
|
||||
),
|
||||
"Femur Load Left": CriterionResult(
|
||||
criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left"
|
||||
),
|
||||
"Femur Load Right": CriterionResult(
|
||||
criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestIIHS:
|
||||
def test_good_results(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
assert result.protocol == "IIHS"
|
||||
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
|
||||
|
||||
def test_region_scores(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
assert len(result.region_scores) > 0
|
||||
for rs in result.region_scores:
|
||||
assert rs.rating is not None
|
||||
|
||||
def test_poor_hic(self):
|
||||
result = iihs_evaluate(
|
||||
{
|
||||
"HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"),
|
||||
}
|
||||
)
|
||||
assert result.overall_rating == "POOR"
|
||||
|
||||
def test_summary(self, sample_criteria):
|
||||
result = iihs_evaluate(sample_criteria)
|
||||
summary = result.summary()
|
||||
assert "IIHS" in summary
|
||||
|
||||
|
||||
class TestUSNCAP:
|
||||
def test_basic(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert result.protocol == "US NCAP"
|
||||
assert result.stars is not None
|
||||
assert 1 <= result.stars <= 5
|
||||
|
||||
def test_injury_probabilities(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert "combined_injury_probability" in result.details
|
||||
p = result.details["combined_injury_probability"]
|
||||
assert 0.0 <= p <= 1.0
|
||||
|
||||
def test_region_scores(self, sample_criteria):
|
||||
result = us_evaluate(sample_criteria)
|
||||
assert len(result.region_scores) > 0
|
||||
126
tests/test_scripting_api.py
Normal file
126
tests/test_scripting_api.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Tests for the scripting API (Session, ChannelHandle, TransformProxy)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt import Session, Template
|
||||
|
||||
FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme"
|
||||
MME_DATA = Path(__file__).parent / "mme_data"
|
||||
|
||||
|
||||
class TestSession:
|
||||
def test_open(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
assert s.test_id == "IMPAKT_SYNTH_001"
|
||||
assert len(s) == 26
|
||||
|
||||
def test_channel_access(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
ch = s.channel("11HEAD0000ACXA")
|
||||
assert ch.name == "11HEAD0000ACXA"
|
||||
assert ch.peak > 0
|
||||
|
||||
def test_find(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
channels = s.find("*HEAD*AC*")
|
||||
assert len(channels) == 3
|
||||
|
||||
def test_group(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
group = s.group("HEAD0000AC")
|
||||
assert group.x is not None
|
||||
|
||||
def test_compute_criteria(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
criteria = s.compute_criteria()
|
||||
assert len(criteria) > 0
|
||||
|
||||
def test_evaluate(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.protocol == "Euro NCAP"
|
||||
|
||||
def test_evaluate_us_ncap(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("us_ncap")
|
||||
assert result.stars is not None
|
||||
|
||||
def test_evaluate_iihs(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.evaluate("iihs")
|
||||
assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR")
|
||||
|
||||
def test_evaluate_invalid_protocol(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
with pytest.raises(ValueError, match="Unknown protocol"):
|
||||
s.evaluate("invalid")
|
||||
|
||||
def test_contains(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
assert "11HEAD0000ACXA" in s
|
||||
assert "NONEXISTENT" not in s
|
||||
|
||||
|
||||
class TestChannelHandleChaining:
|
||||
"""The fluent API must support chaining — each transform returns ChannelHandle."""
|
||||
|
||||
def test_single_transform(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
ch = s.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(600)
|
||||
assert type(filtered).__name__ == "ChannelHandle"
|
||||
assert filtered.raw.cfc_class == 600
|
||||
|
||||
def test_double_chain(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align()
|
||||
assert type(result).__name__ == "ChannelHandle"
|
||||
assert len(result.raw.transform_history) == 2
|
||||
|
||||
def test_triple_chain(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
result = (
|
||||
s.channel("11HEAD0000ACXA")
|
||||
.transform.cfc(1000)
|
||||
.transform.y_align()
|
||||
.transform.trim(t_start=0.0, t_end=0.1)
|
||||
)
|
||||
assert type(result).__name__ == "ChannelHandle"
|
||||
assert len(result.raw.transform_history) == 3
|
||||
|
||||
def test_chain_preserves_data(self):
|
||||
s = Session.open(FIXTURE_DATA)
|
||||
original = s.channel("11HEAD0000ACXA")
|
||||
original_peak = original.peak
|
||||
filtered = original.transform.cfc(600)
|
||||
# Original should be unchanged — peak should be the same
|
||||
assert original.peak == original_peak
|
||||
# Filtered should have different CFC and lower peak (smoothed)
|
||||
assert filtered.raw.cfc_class == 600
|
||||
assert filtered.peak <= original_peak
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available")
|
||||
class TestSessionRealData:
|
||||
def test_open_real(self):
|
||||
s = Session.open(MME_DATA / "3239")
|
||||
assert s.test_id == "3239"
|
||||
assert len(s) == 133
|
||||
|
||||
def test_full_pipeline(self):
|
||||
s = Session.open(MME_DATA / "3239")
|
||||
# Chain: get channel -> filter -> check
|
||||
ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000)
|
||||
assert ch.peak > 100 # Significant head acceleration
|
||||
|
||||
# Compute criteria
|
||||
criteria = s.compute_criteria()
|
||||
assert "HIC15" in criteria
|
||||
|
||||
# Evaluate
|
||||
result = s.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.stars >= 0
|
||||
132
tests/test_template.py
Normal file
132
tests/test_template.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for template model and session persistence."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from impakt.template.library import TemplateLibrary
|
||||
from impakt.template.model import PlotDefinition, SessionState, TemplateSpec
|
||||
from impakt.template.session import SessionManager
|
||||
|
||||
|
||||
class TestTemplateSpec:
|
||||
def test_yaml_round_trip(self):
|
||||
template = TemplateSpec(
|
||||
name="Test Template",
|
||||
version=2,
|
||||
description="A test template",
|
||||
plots=[
|
||||
PlotDefinition(
|
||||
title="Head Acceleration",
|
||||
channel_patterns=["*HEAD*AC*"],
|
||||
x_cursors=(0.0, 0.1),
|
||||
)
|
||||
],
|
||||
default_cfc=1000,
|
||||
criteria=["hic15", "nij"],
|
||||
protocol="euro_ncap",
|
||||
)
|
||||
|
||||
yaml_str = template.to_yaml()
|
||||
restored = TemplateSpec.from_yaml(yaml_str)
|
||||
|
||||
assert restored.name == "Test Template"
|
||||
assert restored.version == 2
|
||||
assert restored.default_cfc == 1000
|
||||
assert len(restored.plots) == 1
|
||||
assert restored.plots[0].channel_patterns == ["*HEAD*AC*"]
|
||||
assert restored.criteria == ["hic15", "nij"]
|
||||
|
||||
def test_save_and_load(self, tmp_path):
|
||||
template = TemplateSpec(name="Saved Test", version=1)
|
||||
path = tmp_path / "test.yaml"
|
||||
template.save(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded = TemplateSpec.load(path)
|
||||
assert loaded.name == "Saved Test"
|
||||
|
||||
|
||||
class TestSessionState:
|
||||
def test_yaml_round_trip(self):
|
||||
state = SessionState(
|
||||
template_name="my_template",
|
||||
template_version=3,
|
||||
test_path="/data/test_001",
|
||||
overrides={"cfc": "600", "selected": ["ch1", "ch2"]},
|
||||
)
|
||||
yaml_str = state.to_yaml()
|
||||
restored = SessionState.from_yaml(yaml_str)
|
||||
|
||||
assert restored.template_name == "my_template"
|
||||
assert restored.template_version == 3
|
||||
assert restored.overrides["cfc"] == "600"
|
||||
|
||||
def test_save_and_load(self, tmp_path):
|
||||
state = SessionState(template_name="test")
|
||||
path = tmp_path / "session.yaml"
|
||||
state.save(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded = SessionState.load(path)
|
||||
assert loaded.template_name == "test"
|
||||
|
||||
|
||||
class TestTemplateLibrary:
|
||||
def test_empty_library(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
assert lib.list() == []
|
||||
assert len(lib) == 0
|
||||
|
||||
def test_save_and_list(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
template = TemplateSpec(name="My Template")
|
||||
lib.save(template)
|
||||
assert "my_template" in lib.list()
|
||||
assert len(lib) == 1
|
||||
|
||||
def test_get(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
lib.save(TemplateSpec(name="Getter Test", version=5))
|
||||
loaded = lib.get("getter_test")
|
||||
assert loaded.name == "Getter Test"
|
||||
assert loaded.version == 5
|
||||
|
||||
def test_delete(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
lib.save(TemplateSpec(name="To Delete"))
|
||||
assert lib.delete("to_delete")
|
||||
assert "to_delete" not in lib.list()
|
||||
|
||||
def test_get_missing_raises(self, tmp_path):
|
||||
lib = TemplateLibrary(tmp_path / "templates")
|
||||
with pytest.raises(FileNotFoundError):
|
||||
lib.get("nonexistent")
|
||||
|
||||
|
||||
class TestSessionManager:
|
||||
def test_create_and_save(self, tmp_path):
|
||||
mgr = SessionManager(tmp_path)
|
||||
mgr.state.template_name = "test_tmpl"
|
||||
mgr.save()
|
||||
assert mgr.has_session
|
||||
assert (tmp_path / ".impakt" / "session.yaml").exists()
|
||||
|
||||
def test_load_existing(self, tmp_path):
|
||||
# Save
|
||||
mgr1 = SessionManager(tmp_path)
|
||||
mgr1.state.template_name = "saved_tmpl"
|
||||
mgr1.state.overrides = {"key": "value"}
|
||||
mgr1.save()
|
||||
|
||||
# Load
|
||||
mgr2 = SessionManager(tmp_path)
|
||||
assert mgr2.state.template_name == "saved_tmpl"
|
||||
assert mgr2.state.overrides["key"] == "value"
|
||||
|
||||
def test_clear(self, tmp_path):
|
||||
mgr = SessionManager(tmp_path)
|
||||
mgr.save()
|
||||
assert mgr.has_session
|
||||
mgr.clear()
|
||||
assert not mgr.has_session
|
||||
75
tests/test_transform/test_math_resultant.py
Normal file
75
tests/test_transform/test_math_resultant.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Tests for math expressions and resultant computation."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from impakt.transform.math_expr import math_expr
|
||||
from impakt.transform.resultant import resultant_from_channels
|
||||
from impakt.transform.resample import trim, resample
|
||||
|
||||
|
||||
class TestMathExpr:
|
||||
def test_simple_expression(self, head_accel_x, head_accel_z):
|
||||
result = math_expr(
|
||||
expression="sqrt(a**2 + b**2)",
|
||||
channels={"a": head_accel_x, "b": head_accel_z},
|
||||
name="resultant_xz",
|
||||
unit="g",
|
||||
)
|
||||
assert result.name == "resultant_xz"
|
||||
assert result.unit == "g"
|
||||
assert result.peak > 0
|
||||
assert len(result.data) == len(head_accel_x.data)
|
||||
|
||||
def test_constant_expression(self, head_accel_x):
|
||||
result = math_expr(
|
||||
expression="a * 0 + 42.0",
|
||||
channels={"a": head_accel_x},
|
||||
name="constant",
|
||||
)
|
||||
assert np.allclose(result.data, 42.0)
|
||||
|
||||
def test_invalid_expression(self, head_accel_x):
|
||||
with pytest.raises(ValueError, match="Error evaluating"):
|
||||
math_expr(
|
||||
expression="invalid_func(a)",
|
||||
channels={"a": head_accel_x},
|
||||
)
|
||||
|
||||
def test_forbidden_expression(self, head_accel_x):
|
||||
with pytest.raises(ValueError, match="Forbidden"):
|
||||
math_expr(
|
||||
expression="__import__('os')",
|
||||
channels={"a": head_accel_x},
|
||||
)
|
||||
|
||||
|
||||
class TestResultant:
|
||||
def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
|
||||
result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
|
||||
assert result.code.direction == "R"
|
||||
# Resultant >= any component
|
||||
assert result.peak >= head_accel_x.peak
|
||||
assert result.peak >= head_accel_y.peak
|
||||
|
||||
def test_from_two_channels(self, head_accel_x, head_accel_z):
|
||||
result = resultant_from_channels(head_accel_x, head_accel_z)
|
||||
assert result.peak > 0
|
||||
|
||||
def test_single_channel_raises(self):
|
||||
with pytest.raises(ValueError, match="At least one"):
|
||||
resultant_from_channels()
|
||||
|
||||
|
||||
class TestTrimResample:
|
||||
def test_trim(self, head_accel_x):
|
||||
trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05)
|
||||
assert trimmed.time[0] >= 0.0
|
||||
assert trimmed.time[-1] <= 0.05
|
||||
assert len(trimmed.data) < len(head_accel_x.data)
|
||||
|
||||
def test_resample(self, head_accel_x):
|
||||
resampled = resample(head_accel_x, target_rate=5000.0)
|
||||
expected_samples = int(head_accel_x.duration * 5000.0)
|
||||
assert abs(len(resampled.data) - expected_samples) <= 2
|
||||
assert resampled.sample_rate == 5000.0
|
||||
@@ -19,11 +19,11 @@ class TestAppState:
|
||||
|
||||
def test_load_test(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(FIXTURE_DATA)
|
||||
session = state.load_test(FIXTURE_DATA)
|
||||
assert not state.is_empty
|
||||
assert loaded.test_id == "IMPAKT_SYNTH_001"
|
||||
assert loaded.channel_count == 26
|
||||
assert state.primary_test is loaded
|
||||
assert session.test_id == "IMPAKT_SYNTH_001"
|
||||
assert len(session) == 26
|
||||
assert state.primary_test is session
|
||||
|
||||
def test_load_multiple_tests(self):
|
||||
state = AppState()
|
||||
@@ -32,13 +32,13 @@ class TestAppState:
|
||||
if (MME_DATA / "VW1FGS15").exists():
|
||||
t2 = state.load_test(MME_DATA / "VW1FGS15")
|
||||
assert len(state.tests) == 2
|
||||
assert state.primary_test is t1 # First loaded is primary
|
||||
assert state.primary_test is t1
|
||||
assert state.total_channels == 26 + 10
|
||||
|
||||
def test_remove_test(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(FIXTURE_DATA)
|
||||
state.remove_test(loaded.test_id)
|
||||
session = state.load_test(FIXTURE_DATA)
|
||||
state.remove_test(session.test_id)
|
||||
assert state.is_empty
|
||||
|
||||
def test_get_channel(self):
|
||||
@@ -54,38 +54,38 @@ class TestAppState:
|
||||
ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT")
|
||||
assert ch is None
|
||||
|
||||
def test_resolve_channel_with_key(self):
|
||||
def test_get_channel_via_session(self):
|
||||
"""Channels can be accessed through the Session scripting API."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA")
|
||||
assert ch is not None
|
||||
session = state.primary_test
|
||||
assert session is not None
|
||||
ch_handle = session.channel("11HEAD0000ACXA")
|
||||
assert ch_handle.name == "11HEAD0000ACXA"
|
||||
|
||||
def test_resolve_channel_primary_default(self):
|
||||
def test_session_fluent_transforms(self):
|
||||
"""Fluent transform chaining works through the Session API."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("11HEAD0000ACXA")
|
||||
assert ch is not None
|
||||
session = state.primary_test
|
||||
ch = session.channel("11HEAD0000ACXA")
|
||||
filtered = ch.transform.cfc(600).transform.y_align()
|
||||
assert filtered.raw.cfc_class == 600
|
||||
assert len(filtered.raw.transform_history) == 2
|
||||
|
||||
def test_resolve_channel_with_cfc(self):
|
||||
def test_session_compute_criteria(self):
|
||||
"""Session.compute_criteria() auto-detects channels."""
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600)
|
||||
assert ch is not None
|
||||
assert ch.cfc_class == 600
|
||||
|
||||
def test_flat_channel_list(self):
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
items = state.flat_channel_list()
|
||||
assert len(items) == 26
|
||||
assert all("value" in item and "label" in item for item in items)
|
||||
criteria = state.primary_test.compute_criteria()
|
||||
assert len(criteria) > 0
|
||||
assert "HIC15" in criteria or "Chest Deflection" in criteria
|
||||
|
||||
def test_build_channel_tree(self):
|
||||
state = AppState()
|
||||
state.load_test(FIXTURE_DATA)
|
||||
tree = state.build_channel_tree()
|
||||
assert "IMPAKT_SYNTH_001" in tree
|
||||
# Should have hierarchical structure
|
||||
test_tree = tree["IMPAKT_SYNTH_001"]
|
||||
assert len(test_tree) > 0
|
||||
|
||||
@@ -94,15 +94,23 @@ class TestAppState:
|
||||
class TestAppStateRealData:
|
||||
def test_load_real_mme(self):
|
||||
state = AppState()
|
||||
loaded = state.load_test(MME_DATA / "3239")
|
||||
assert loaded.test_id == "3239"
|
||||
assert loaded.channel_count == 133
|
||||
session = state.load_test(MME_DATA / "3239")
|
||||
assert session.test_id == "3239"
|
||||
assert len(session) == 133
|
||||
|
||||
def test_channel_tree_real_data(self):
|
||||
state = AppState()
|
||||
state.load_test(MME_DATA / "3239")
|
||||
tree = state.build_channel_tree()
|
||||
assert "3239" in tree
|
||||
# Should contain "Driver" in some key
|
||||
test_tree = tree["3239"]
|
||||
assert any("Driver" in k for k in test_tree)
|
||||
|
||||
def test_session_evaluate_real_data(self):
|
||||
"""Full pipeline through Session API on real data."""
|
||||
state = AppState()
|
||||
state.load_test(MME_DATA / "3239")
|
||||
result = state.primary_test.evaluate("euro_ncap")
|
||||
assert result.stars is not None
|
||||
assert result.stars >= 0
|
||||
assert len(result.region_scores) > 0
|
||||
|
||||
Reference in New Issue
Block a user