diff --git a/src/impakt/io/csv.py b/src/impakt/io/csv.py new file mode 100644 index 0000000..4dd5e8f --- /dev/null +++ b/src/impakt/io/csv.py @@ -0,0 +1,29 @@ +"""CSV data reader (stub). + +Placeholder for a future CSV reader plugin. +Install or implement to enable reading CSV crash test data. +""" + +from __future__ import annotations + +from pathlib import Path + +from impakt.channel.model import TestData, TestMetadata + + +class CSVReader: + """Reader for CSV time-series data (not yet implemented).""" + + @property + def format_name(self) -> str: + return "CSV" + + def supports(self, path: Path) -> bool: + # Only claim support for .csv files with crash test structure + return False # Disabled until implemented + + def metadata(self, path: Path) -> TestMetadata: + raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.") + + def read(self, path: Path) -> TestData: + raise NotImplementedError("CSV reader is not yet implemented. See BRAINSTORM.md.") diff --git a/src/impakt/io/tdms.py b/src/impakt/io/tdms.py new file mode 100644 index 0000000..e0ce04d --- /dev/null +++ b/src/impakt/io/tdms.py @@ -0,0 +1,37 @@ +"""NI TDMS data reader (stub). + +Placeholder for a future TDMS reader plugin. +Install nptdms to enable: pip install impakt[tdms] +""" + +from __future__ import annotations + +from pathlib import Path + +from impakt.channel.model import TestData, TestMetadata + + +class TDMSReader: + """Reader for NI TDMS files (not yet implemented). + + Requires the optional ``nptdms`` dependency:: + + pip install impakt[tdms] + """ + + @property + def format_name(self) -> str: + return "NI TDMS" + + def supports(self, path: Path) -> bool: + return False # Disabled until implemented + + def metadata(self, path: Path) -> TestMetadata: + raise NotImplementedError( + "TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]" + ) + + def read(self, path: Path) -> TestData: + raise NotImplementedError( + "TDMS reader is not yet implemented. Install nptdms: pip install impakt[tdms]" + ) diff --git a/src/impakt/plot/__pycache__/engine.cpython-312.pyc b/src/impakt/plot/__pycache__/engine.cpython-312.pyc index 4644c3e..ca36e3e 100644 Binary files a/src/impakt/plot/__pycache__/engine.cpython-312.pyc and b/src/impakt/plot/__pycache__/engine.cpython-312.pyc differ diff --git a/src/impakt/plot/__pycache__/spec.cpython-312.pyc b/src/impakt/plot/__pycache__/spec.cpython-312.pyc index f694a6f..49848d9 100644 Binary files a/src/impakt/plot/__pycache__/spec.cpython-312.pyc and b/src/impakt/plot/__pycache__/spec.cpython-312.pyc differ diff --git a/src/impakt/plot/engine.py b/src/impakt/plot/engine.py index c8fdbbd..aa6f555 100644 --- a/src/impakt/plot/engine.py +++ b/src/impakt/plot/engine.py @@ -2,6 +2,9 @@ Renders PlotSpec objects into interactive Plotly figures with support for corridors, dual X-cursors, and export. + +The PlotEngine is the **single rendering path** — both the scripting +API and the web UI construct PlotSpec objects and delegate here. """ from __future__ import annotations @@ -30,12 +33,39 @@ DEFAULT_COLORS = [ class PlotEngine: - """Renders PlotSpec into Plotly figures.""" + """Renders PlotSpec into Plotly figures. + + Supports two modes: + - **Default**: Full interactive plot with hover tooltips and legend. + - **Compact** (``spec.compact=True``): Web UI mode with no legend, no + hover tooltips, tight margins, smaller axis labels. Cursor tracking + is handled by external JS. + """ def render(self, spec: PlotSpec) -> go.Figure: """Render a PlotSpec into an interactive Plotly figure.""" fig = go.Figure() + if not spec.channels: + fig.update_layout( + template="plotly_white", + annotations=[ + { + "text": "Select channels to plot", + "xref": "paper", + "yref": "paper", + "x": 0.5, + "y": 0.5, + "showarrow": False, + "font": {"size": 14, "color": "#bbb"}, + } + ], + xaxis={"visible": False}, + yaxis={"visible": False}, + margin={"l": 40, "r": 20, "t": 30, "b": 40}, + ) + return fig + # Add corridor fills first (behind data traces) for corridor in spec.corridors: self._add_corridor(fig, corridor) @@ -54,53 +84,92 @@ class PlotEngine: color = style.color or DEFAULT_COLORS[i % len(DEFAULT_COLORS)] label = style.label or ch_ref.label - fig.add_trace( - go.Scatter( - x=ch.time, - y=ch.data, - mode="lines", - name=label, - line=dict( - color=color, - width=style.line_width, - dash=style.line_dash, - ), - opacity=style.opacity, - hovertemplate=f"{label}
t=%{{x:.6f}}s
%{{y:.4f}} {ch.unit}", + trace_kwargs: dict[str, Any] = { + "x": ch.time, + "y": ch.data, + "mode": "lines", + "name": label, + "line": dict( + color=color, + width=style.line_width, + dash=style.line_dash, + ), + "opacity": style.opacity, + } + + if not spec.compact: + trace_kwargs["hovertemplate"] = ( + f"{label}
t=%{{x:.6f}}s
%{{y:.4f}} {ch.unit}" ) - ) + + fig.add_trace(go.Scatter(**trace_kwargs)) # Add cursor lines if spec.x_cursors: x1, x2 = spec.x_cursors - for x_val, label in [(x1, "x1"), (x2, "x2")]: + if spec.compact: + # Color-coded cursor lines for compact mode fig.add_vline( - x=x_val, + x=x1, line_dash="dash", - line_color="gray", + line_color="rgba(220,53,69,0.6)", line_width=1, - annotation_text=f"{label}={x_val:.6f}s", - annotation_position="top", + annotation_text=f"X1={x1:.4f}s", + annotation_font_size=9, + annotation_font_color="rgba(220,53,69,0.8)", ) + fig.add_vline( + x=x2, + line_dash="dash", + line_color="rgba(13,110,253,0.6)", + line_width=1, + annotation_text=f"X2={x2:.4f}s", + annotation_font_size=9, + annotation_font_color="rgba(13,110,253,0.8)", + ) + else: + for x_val, lbl in [(x1, "x1"), (x2, "x2")]: + fig.add_vline( + x=x_val, + line_dash="dash", + line_color="gray", + line_width=1, + annotation_text=f"{lbl}={x_val:.6f}s", + annotation_position="top", + ) # Layout - fig.update_layout( - title=spec.title, - xaxis_title=spec.x_label, - yaxis_title=spec.y_label, - showlegend=spec.show_legend, - height=spec.height, - width=spec.width, - template="plotly_white", - hovermode="x unified", - legend=dict( - orientation="h", - yanchor="bottom", - y=-0.3, - xanchor="center", - x=0.5, - ), - ) + if spec.compact: + margin = spec.margin or {"l": 45, "r": 8, "t": 4, "b": 28} + fig.update_layout( + xaxis_title=dict(text=spec.x_label, font=dict(size=10, color="#999")), + yaxis_title=dict(text=spec.y_label, font=dict(size=10, color="#999")), + template="plotly_white", + hovermode=False, + showlegend=False, + margin=margin, + ) + else: + margin = spec.margin or {"l": 60, "r": 20, "t": 40 if spec.title else 10, "b": 60} + hovermode = spec.hovermode if spec.hovermode is not None else "x unified" + fig.update_layout( + title=spec.title, + xaxis_title=spec.x_label, + yaxis_title=spec.y_label, + showlegend=spec.show_legend, + height=spec.height, + width=spec.width, + template="plotly_white", + hovermode=hovermode, + legend=dict( + orientation="h", + yanchor="bottom", + y=-0.3, + xanchor="center", + x=0.5, + ), + margin=margin, + ) if spec.show_grid: fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)") @@ -117,7 +186,6 @@ class PlotEngine: """Add a corridor (tolerance band) to the figure.""" style = corridor.style - # Upper bound fig.add_trace( go.Scatter( x=corridor.time, @@ -128,8 +196,6 @@ class PlotEngine: showlegend=False, ) ) - - # Lower bound with fill to upper fig.add_trace( go.Scatter( x=corridor.time, @@ -144,16 +210,7 @@ class PlotEngine: ) def to_image(self, spec: PlotSpec, format: str = "png", scale: float = 2.0) -> bytes: - """Render to a static image. - - Args: - spec: Plot specification. - format: Image format ('png', 'svg', 'pdf', 'jpeg'). - scale: Resolution multiplier. - - Returns: - Image bytes. - """ + """Render to a static image.""" fig = self.render(spec) return fig.to_image(format=format, scale=scale) @@ -168,16 +225,7 @@ def cursor_values( x1: float, x2: float, ) -> CursorValues: - """Compute interpolated values at two X-axis positions. - - Args: - spec_or_channels: PlotSpec or list of Channels. - x1: First cursor position (time). - x2: Second cursor position (time). - - Returns: - CursorValues with interpolated values for each channel. - """ + """Compute interpolated values at two X-axis positions.""" channels: list[tuple[str, Channel]] = [] if isinstance(spec_or_channels, PlotSpec): diff --git a/src/impakt/plot/spec.py b/src/impakt/plot/spec.py index f8772c1..abb7039 100644 --- a/src/impakt/plot/spec.py +++ b/src/impakt/plot/spec.py @@ -168,3 +168,8 @@ class PlotSpec: show_grid: bool = True height: int = 500 width: int = 900 + # Web UI compact mode: disables hover tooltips, removes legend, + # uses tight margins, and renders axis labels in a smaller font. + compact: bool = False + hovermode: str | bool = "x unified" + margin: dict[str, int] | None = None diff --git a/src/impakt/plugin/registry.py b/src/impakt/plugin/registry.py index 72f2115..d73da9b 100644 --- a/src/impakt/plugin/registry.py +++ b/src/impakt/plugin/registry.py @@ -42,8 +42,15 @@ class PluginRegistry: self._plugins: list[ImpaktPlugin] = [] def register_reader(self, reader: Any) -> None: - """Register a data reader.""" + """Register a data reader. + + Forwards to the IO reader registry so plugin readers are + discoverable by Session.open() and the web UI. + """ self._readers.append(reader) + from impakt.io.reader import register_reader + + register_reader(reader) logger.info("Plugin reader registered: %s", getattr(reader, "format_name", reader)) def register_transform(self, name: str, transform_cls: type) -> None: diff --git a/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc b/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc index 4873670..18deb8c 100644 Binary files a/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc and b/src/impakt/protocol/__pycache__/euro_ncap.cpython-312.pyc differ diff --git a/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc b/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc index fe686f9..d62f7c9 100644 Binary files a/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc and b/src/impakt/protocol/__pycache__/iihs.cpython-312.pyc differ diff --git a/src/impakt/protocol/euro_ncap.py b/src/impakt/protocol/euro_ncap.py index 19dd79c..265504b 100644 --- a/src/impakt/protocol/euro_ncap.py +++ b/src/impakt/protocol/euro_ncap.py @@ -9,6 +9,7 @@ Threshold values are versioned — this module supports multiple protocol years. from __future__ import annotations +from pathlib import Path from typing import Any from impakt.criteria.base import CriterionResult @@ -64,9 +65,44 @@ def _points_from_color(color: Color, max_points: float) -> float: return max_points * color_fractions[color] -# Threshold sets by year -# Format: {criterion: (green, yellow, orange, brown, red, higher_is_worse, max_points)} -THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float]] = { +# --------------------------------------------------------------------------- +# Threshold loading — from YAML files or hardcoded fallback +# --------------------------------------------------------------------------- + +_THRESHOLDS_DIR = Path(__file__).parent / "thresholds" + + +def _load_yaml_thresholds( + version: str, +) -> dict[str, tuple[float, float, float, float, float, bool, float]]: + """Load thresholds from a YAML file for a given version.""" + yaml_path = _THRESHOLDS_DIR / f"euro_ncap_{version}.yaml" + if not yaml_path.exists(): + return {} + + import yaml + + data = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return {} + + result: dict[str, tuple[float, float, float, float, float, bool, float]] = {} + for name, vals in data.items(): + if isinstance(vals, dict): + result[name] = ( + float(vals["green"]), + float(vals["yellow"]), + float(vals["orange"]), + float(vals["brown"]), + float(vals["red"]), + bool(vals.get("higher_is_worse", True)), + float(vals.get("max_points", 0)), + ) + return result + + +# Hardcoded fallback (used if YAML files are missing) +_THRESHOLDS_2024_FALLBACK: dict[str, tuple[float, float, float, float, float, bool, float]] = { "HIC15": (500.0, 620.0, 700.0, 850.0, 1000.0, True, 4.0), "3ms Clip": (42.0, 48.0, 54.0, 57.0, 60.0, True, 4.0), "Chest Deflection": (22.0, 34.0, 42.0, 50.0, 63.0, True, 4.0), @@ -77,9 +113,17 @@ THRESHOLDS_2024: dict[str, tuple[float, float, float, float, float, bool, float] "Viscous Criterion": (0.32, 0.56, 0.8, 0.9, 1.0, True, 2.0), } -THRESHOLDS: dict[str, dict[str, tuple[float, float, float, float, float, bool, float]]] = { - "2024": THRESHOLDS_2024, -} + +def _get_thresholds( + version: str, +) -> dict[str, tuple[float, float, float, float, float, bool, float]]: + """Get thresholds for a version, preferring YAML over hardcoded.""" + thresholds = _load_yaml_thresholds(version) + if thresholds: + return thresholds + if version == "2024": + return _THRESHOLDS_2024_FALLBACK + return {} class EuroNCAP: @@ -87,11 +131,12 @@ class EuroNCAP: def __init__(self, version: str = "2024") -> None: self._version = version - if version not in THRESHOLDS: + self._thresholds = _get_thresholds(version) + if not self._thresholds: raise ValueError( - f"Unknown Euro NCAP version: {version}. Available: {list(THRESHOLDS.keys())}" + f"No thresholds found for Euro NCAP version: {version}. " + f"Check {_THRESHOLDS_DIR} for YAML files." ) - self._thresholds = THRESHOLDS[version] @property def protocol_name(self) -> str: diff --git a/src/impakt/protocol/iihs.py b/src/impakt/protocol/iihs.py index 0ce26ec..48b33ab 100644 --- a/src/impakt/protocol/iihs.py +++ b/src/impakt/protocol/iihs.py @@ -6,27 +6,55 @@ The overall rating is determined by the worst sub-rating. from __future__ import annotations +from pathlib import Path from typing import Any from impakt.criteria.base import CriterionResult from impakt.protocol.base import BodyRegionScore, ProtocolResult, Rating +_THRESHOLDS_DIR = Path(__file__).parent / "thresholds" -# Thresholds: (good_limit, acceptable_limit, marginal_limit) -# Values above marginal_limit = Poor -# higher_is_worse indicates that higher values are worse -IIHS_THRESHOLDS_2024: dict[str, tuple[float, float, float, bool]] = { + +def _load_iihs_yaml(version: str) -> dict[str, tuple[float, float, float, bool]]: + """Load IIHS thresholds from YAML.""" + yaml_path = _THRESHOLDS_DIR / f"iihs_{version}.yaml" + if not yaml_path.exists(): + return {} + import yaml + + data = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return {} + result: dict[str, tuple[float, float, float, bool]] = {} + for name, vals in data.items(): + if isinstance(vals, dict): + result[name] = ( + float(vals["good"]), + float(vals["acceptable"]), + float(vals["marginal"]), + bool(vals.get("higher_is_worse", True)), + ) + return result + + +# Hardcoded fallback +_IIHS_2024_FALLBACK: dict[str, tuple[float, float, float, bool]] = { "HIC15": (250.0, 500.0, 700.0, True), - "Chest Deflection": (38.0, 50.0, 63.0, True), # mm - "Femur Load Left": (3.8, 6.2, 10.0, True), # kN - "Femur Load Right": (3.8, 6.2, 10.0, True), # kN + "Chest Deflection": (38.0, 50.0, 63.0, True), + "Femur Load Left": (3.8, 6.2, 10.0, True), + "Femur Load Right": (3.8, 6.2, 10.0, True), "Nij": (0.52, 0.78, 1.0, True), "Tibia Index": (0.5, 0.8, 1.3, True), } -IIHS_THRESHOLDS: dict[str, dict[str, tuple[float, float, float, bool]]] = { - "2024": IIHS_THRESHOLDS_2024, -} + +def _get_iihs_thresholds(version: str) -> dict[str, tuple[float, float, float, bool]]: + thresholds = _load_iihs_yaml(version) + if thresholds: + return thresholds + if version == "2024": + return _IIHS_2024_FALLBACK + return {} def _rate_value( @@ -70,11 +98,12 @@ class IIHS: def __init__(self, version: str = "2024") -> None: self._version = version - if version not in IIHS_THRESHOLDS: + self._thresholds = _get_iihs_thresholds(version) + if not self._thresholds: raise ValueError( - f"Unknown IIHS version: {version}. Available: {list(IIHS_THRESHOLDS.keys())}" + f"No thresholds found for IIHS version: {version}. " + f"Check {_THRESHOLDS_DIR} for YAML files." ) - self._thresholds = IIHS_THRESHOLDS[version] @property def protocol_name(self) -> str: diff --git a/src/impakt/protocol/thresholds/euro_ncap_2024.yaml b/src/impakt/protocol/thresholds/euro_ncap_2024.yaml new file mode 100644 index 0000000..fb54bcd --- /dev/null +++ b/src/impakt/protocol/thresholds/euro_ncap_2024.yaml @@ -0,0 +1,76 @@ +# Euro NCAP 2024 Adult Occupant Frontal Impact Thresholds +# +# Each criterion: [green, yellow, orange, brown, red, higher_is_worse, max_points] +# Sliding-scale: values at or below green = full points, at or above red = zero. + +HIC15: + green: 500.0 + yellow: 620.0 + orange: 700.0 + brown: 850.0 + red: 1000.0 + higher_is_worse: true + max_points: 4.0 + +3ms Clip: + green: 42.0 + yellow: 48.0 + orange: 54.0 + brown: 57.0 + red: 60.0 + higher_is_worse: true + max_points: 4.0 + +Chest Deflection: + green: 22.0 + yellow: 34.0 + orange: 42.0 + brown: 50.0 + red: 63.0 + higher_is_worse: true + max_points: 4.0 + +Nij: + green: 0.5 + yellow: 0.65 + orange: 0.8 + brown: 0.9 + red: 1.0 + higher_is_worse: true + max_points: 2.0 + +Femur Load Left: + green: 3.8 + yellow: 5.4 + orange: 7.0 + brown: 8.5 + red: 10.0 + higher_is_worse: true + max_points: 2.0 + +Femur Load Right: + green: 3.8 + yellow: 5.4 + orange: 7.0 + brown: 8.5 + red: 10.0 + higher_is_worse: true + max_points: 2.0 + +Tibia Index: + green: 0.4 + yellow: 0.7 + orange: 1.0 + brown: 1.15 + red: 1.3 + higher_is_worse: true + max_points: 2.0 + +Viscous Criterion: + green: 0.32 + yellow: 0.56 + orange: 0.8 + brown: 0.9 + red: 1.0 + higher_is_worse: true + max_points: 2.0 diff --git a/src/impakt/protocol/thresholds/iihs_2024.yaml b/src/impakt/protocol/thresholds/iihs_2024.yaml new file mode 100644 index 0000000..3c28d17 --- /dev/null +++ b/src/impakt/protocol/thresholds/iihs_2024.yaml @@ -0,0 +1,40 @@ +# IIHS 2024 Crashworthiness Thresholds +# +# Each criterion: [good, acceptable, marginal, higher_is_worse] +# Values above marginal = Poor. + +HIC15: + good: 250.0 + acceptable: 500.0 + marginal: 700.0 + higher_is_worse: true + +Chest Deflection: + good: 38.0 + acceptable: 50.0 + marginal: 63.0 + higher_is_worse: true + +Femur Load Left: + good: 3.8 + acceptable: 6.2 + marginal: 10.0 + higher_is_worse: true + +Femur Load Right: + good: 3.8 + acceptable: 6.2 + marginal: 10.0 + higher_is_worse: true + +Nij: + good: 0.52 + acceptable: 0.78 + marginal: 1.0 + higher_is_worse: true + +Tibia Index: + good: 0.5 + acceptable: 0.8 + marginal: 1.3 + higher_is_worse: true diff --git a/src/impakt/script/__pycache__/api.cpython-312.pyc b/src/impakt/script/__pycache__/api.cpython-312.pyc index 9a1ec10..4abac9f 100644 Binary files a/src/impakt/script/__pycache__/api.cpython-312.pyc and b/src/impakt/script/__pycache__/api.cpython-312.pyc differ diff --git a/src/impakt/script/__pycache__/cli.cpython-312.pyc b/src/impakt/script/__pycache__/cli.cpython-312.pyc deleted file mode 100644 index d435223..0000000 Binary files a/src/impakt/script/__pycache__/cli.cpython-312.pyc and /dev/null differ diff --git a/src/impakt/script/api.py b/src/impakt/script/api.py index 3f6120e..3325059 100644 --- a/src/impakt/script/api.py +++ b/src/impakt/script/api.py @@ -2,6 +2,23 @@ Provides the ``Session`` and ``Template`` classes that serve as the primary entry points for both scripting and the web UI. + +Usage:: + + from impakt import Session + + test = Session.open("tests/mme_data/3239") + ch = test.channel("11HEAD0000H3ACXP") + + # Fluent chaining — each transform returns a ChannelHandle + filtered = ch.transform.cfc(1000).transform.y_align() + + # Compute all injury criteria (auto-detects channels) + criteria = test.compute_criteria() + + # Score against a protocol + result = test.evaluate("euro_ncap") + result.to_pdf("report.pdf") """ from __future__ import annotations @@ -36,12 +53,19 @@ class Session: Wraps TestData with session state, template binding, and convenience methods for transforms, criteria, and plotting. - Usage: + Usage:: + test = Session.open("/path/to/test_001") ch = test.channel("11HEAD0000ACXA") - filtered = ch.transform.cfc(1000) + filtered = ch.transform.cfc(1000) # returns ChannelHandle, chainable + + criteria = test.compute_criteria() # auto-detect channels + result = test.evaluate("euro_ncap") # score against protocol + result.to_pdf("report.pdf") """ + _plugins_discovered = False + def __init__(self, test_data: TestData, session_mgr: SessionManager | None = None) -> None: self._data = test_data self._session_mgr = session_mgr or ( @@ -49,12 +73,29 @@ class Session: ) self._template: TemplateSpec | None = None + @classmethod + def _discover_plugins(cls) -> None: + """Discover and register plugins. Called once on first Session.open().""" + if cls._plugins_discovered: + return + cls._plugins_discovered = True + try: + from impakt.plugin.registry import discover_all + + discover_all() + except Exception as e: + logger.debug("Plugin discovery failed (non-fatal): %s", e) + @classmethod def open(cls, path: str | Path) -> Session: """Open a crash test from a path. Auto-detects the format using the reader registry. + Discovers plugins on first call. """ + # Discover plugins (idempotent — safe to call multiple times) + cls._discover_plugins() + path = Path(path).resolve() registry = get_registry() test_data = registry.read(path) @@ -87,6 +128,11 @@ class Session: """Underlying TestData object.""" return self._data + @property + def path(self) -> Path | None: + """Path to the test data directory.""" + return self._data.path + @property def channel_names(self) -> list[str]: return self._data.channel_names @@ -98,7 +144,10 @@ class Session: # ----- Channel access ----- def channel(self, name: str) -> ChannelHandle: - """Get a channel by name, wrapped with transform convenience methods.""" + """Get a channel by name, wrapped with transform convenience methods. + + Returns a ChannelHandle with a fluent ``.transform`` interface. + """ ch = self._data.get(name) return ChannelHandle(ch) @@ -118,6 +167,47 @@ class Session: """Hierarchical channel tree for UI display.""" return self._data.channel_tree() + # ----- Criteria & Protocol ----- + + def compute_criteria(self) -> dict[str, CriterionResult]: + """Auto-detect channels and compute all applicable injury criteria. + + Uses ISO channel naming to find the right channels for each criterion. + Returns a dict of criterion name -> CriterionResult. + """ + from impakt.web.components.criteria import auto_compute_criteria + + return auto_compute_criteria(self._data) + + def evaluate(self, protocol: str = "euro_ncap", version: str = "") -> ProtocolResult: + """Compute criteria and score against a rating protocol. + + Args: + protocol: Protocol name ('euro_ncap', 'us_ncap', 'iihs'). + version: Protocol version (optional, uses latest if empty). + + Returns: + ProtocolResult with stars/rating and per-region scores. + """ + criteria = self.compute_criteria() + + if protocol == "euro_ncap": + from impakt.protocol.euro_ncap import evaluate + + return evaluate(criteria, version=version or "2024") + elif protocol == "us_ncap": + from impakt.protocol.us_ncap import evaluate + + return evaluate(criteria, version=version or "2023") + elif protocol == "iihs": + from impakt.protocol.iihs import evaluate + + return evaluate(criteria, version=version or "2024") + else: + raise ValueError( + f"Unknown protocol: {protocol}. Use 'euro_ncap', 'us_ncap', or 'iihs'." + ) + # ----- Template ----- def apply_template(self, name_or_spec: str | TemplateSpec) -> None: @@ -188,15 +278,16 @@ class Session: class ChannelHandle: """Wrapper around a Channel providing fluent transform access. - Example: + Each transform method on ``.transform`` returns a new ``ChannelHandle``, + enabling chaining:: + ch = session.channel("11HEAD0000ACXA") - filtered = ch.transform.cfc(1000) - aligned = ch.transform.x_align(method="threshold", threshold_value=5.0) + result = ch.transform.cfc(1000).transform.y_align().transform.trim(t_end=0.1) """ def __init__(self, channel: Channel) -> None: self._channel = channel - self.transform = TransformProxy(channel) + self.transform = TransformProxy(self) @property def raw(self) -> Channel: @@ -207,6 +298,10 @@ class ChannelHandle: def name(self) -> str: return self._channel.name + @property + def code(self): + return self._channel.code + @property def data(self) -> np.ndarray: return self._channel.data @@ -219,6 +314,14 @@ class ChannelHandle: def unit(self) -> str: return self._channel.unit + @property + def peak(self) -> float: + return self._channel.peak + + @property + def sample_rate(self) -> float: + return self._channel.sample_rate + def value_at(self, t: float) -> float: return self._channel.value_at(t) @@ -238,44 +341,51 @@ class ChannelHandle: class TransformProxy: """Fluent transform interface for a channel. - Each method returns a new Channel (non-destructive). + Each method returns a new ``ChannelHandle`` wrapping the transformed + channel, enabling chaining. """ - def __init__(self, channel: Channel) -> None: - self._channel = channel + def __init__(self, handle: ChannelHandle) -> None: + self._handle = handle - def cfc(self, cfc_class: int) -> Channel: + @property + def _channel(self) -> Channel: + return self._handle._channel + + def cfc(self, cfc_class: int) -> ChannelHandle: """Apply CFC filter.""" from impakt.transform.cfc import CFCFilter - return CFCFilter(cfc_class=cfc_class).apply(self._channel) + return ChannelHandle(CFCFilter(cfc_class=cfc_class).apply(self._channel)) def x_align( self, method: str = "manual", reference_time: float = 0.0, **kwargs: Any - ) -> Channel: + ) -> ChannelHandle: """Apply time-zero alignment.""" from impakt.transform.align import XAlign - return XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel) + return ChannelHandle( + XAlign(method=method, reference_time=reference_time, **kwargs).apply(self._channel) + ) - def y_align(self, window: tuple[float, float] | None = None) -> Channel: + def y_align(self, window: tuple[float, float] | None = None) -> ChannelHandle: """Apply Y-axis zero correction.""" from impakt.transform.align import YAlign start, end = window if window else (None, None) - return YAlign(window_start=start, window_end=end).apply(self._channel) + return ChannelHandle(YAlign(window_start=start, window_end=end).apply(self._channel)) - def trim(self, t_start: float | None = None, t_end: float | None = None) -> Channel: + def trim(self, t_start: float | None = None, t_end: float | None = None) -> ChannelHandle: """Trim to a time range.""" from impakt.transform.resample import Trim - return Trim(t_start=t_start, t_end=t_end).apply(self._channel) + return ChannelHandle(Trim(t_start=t_start, t_end=t_end).apply(self._channel)) - def resample(self, target_rate: float) -> Channel: + def resample(self, target_rate: float) -> ChannelHandle: """Resample to a new rate.""" from impakt.transform.resample import Resample - return Resample(target_rate=target_rate).apply(self._channel) + return ChannelHandle(Resample(target_rate=target_rate).apply(self._channel)) class Template: diff --git a/src/impakt/web/app.py b/src/impakt/web/app.py index a724bc6..a6af628 100644 --- a/src/impakt/web/app.py +++ b/src/impakt/web/app.py @@ -1,7 +1,7 @@ """Dash web application factory. Creates the Dash app with all layout components and callbacks registered. -The AppState holds server-side data; Dash stores hold lightweight UI state. +The AppState holds Session objects server-side; Dash stores hold lightweight UI state. """ from __future__ import annotations @@ -13,14 +13,12 @@ import dash import dash_bootstrap_components as dbc from impakt.channel.model import TestData +from impakt.script.api import Session from impakt.template.library import TemplateLibrary from impakt.web.callbacks import register_callbacks from impakt.web.layout import build_layout from impakt.web.state import AppState -if TYPE_CHECKING: - from impakt.script.api import Session - def create_app( session_or_data: Session | TestData | None = None, @@ -37,21 +35,15 @@ def create_app( Returns: Configured Dash app ready to run. """ - # Build or use provided AppState if app_state is None: app_state = AppState() if session_or_data is not None: - if hasattr(session_or_data, "data"): - test_data: TestData = session_or_data.data # type: ignore[union-attr] - else: - test_data = session_or_data # type: ignore[assignment] - - # Create a LoadedTest from TestData directly - from impakt.web.state import LoadedTest - - loaded = LoadedTest(test_data) - app_state._tests[test_data.test_id] = loaded - app_state._test_order.append(test_data.test_id) + if isinstance(session_or_data, Session): + app_state.add_session(session_or_data) + elif isinstance(session_or_data, TestData): + # Wrap raw TestData in a Session + session = Session(session_or_data) + app_state.add_session(session) # Discover templates if template_names is None: @@ -64,14 +56,13 @@ def create_app( # Title title = "Impakt" if app_state.primary_test: - title = f"Impakt — {app_state.primary_test.test_id}" + title = f"Impakt \u2014 {app_state.primary_test.test_id}" app = dash.Dash( __name__, external_stylesheets=[dbc.themes.FLATLY], title=title, suppress_callback_exceptions=True, - # Prevent browser from caching old layouts serve_locally=True, meta_tags=[ {"http-equiv": "Cache-Control", "content": "no-cache, no-store, must-revalidate"}, @@ -92,16 +83,9 @@ def serve( port: int = 8050, debug: bool = False, ) -> None: - """Convenience function to create and run the web UI. - - Args: - session_or_data: Session or TestData to visualize, or None for empty app. - template: Template name to pre-apply. - port: Server port. - debug: Enable Dash debug mode. - """ + """Convenience function to create and run the web UI.""" if template and session_or_data and hasattr(session_or_data, "apply_template"): - session_or_data.apply_template(template) # type: ignore[union-attr] + session_or_data.apply_template(template) app = create_app(session_or_data) print(f"Impakt running at http://localhost:{port}") diff --git a/src/impakt/web/callbacks/criteria_callbacks.py b/src/impakt/web/callbacks/criteria_callbacks.py index 9456bef..9d1fecc 100644 --- a/src/impakt/web/callbacks/criteria_callbacks.py +++ b/src/impakt/web/callbacks/criteria_callbacks.py @@ -1,9 +1,7 @@ """Criteria computation callbacks. -Handles: -- Compute All button -> runs auto_compute_criteria -- Protocol scoring -- Results display +Uses Session.compute_criteria() and Session.evaluate() from the +scripting API — the same path used by programmatic scripts. """ from __future__ import annotations @@ -14,11 +12,7 @@ import dash from dash import Input, Output, State, html from dash.exceptions import PreventUpdate -from impakt.web.components.criteria import ( - auto_compute_criteria, - build_criteria_results_display, - score_protocol, -) +from impakt.web.components.criteria import build_criteria_results_display from impakt.web.state import AppState @@ -41,8 +35,8 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None: if primary is None: return html.Div("No test loaded.", className="text-danger small") - # Compute criteria - criteria = auto_compute_criteria(primary.data) + # Use Session.compute_criteria() — single code path for script + web + criteria = primary.compute_criteria() if not criteria: return html.Div( @@ -51,7 +45,10 @@ def register_criteria_callbacks(app: dash.Dash, app_state: AppState) -> None: className="text-warning small", ) - # Score against protocol - protocol_result = score_protocol(criteria, protocol) + # Score against protocol — use Session.evaluate() + try: + protocol_result = primary.evaluate(protocol) + except Exception: + protocol_result = None return build_criteria_results_display(criteria, protocol_result) diff --git a/src/impakt/web/callbacks/plot_callbacks.py b/src/impakt/web/callbacks/plot_callbacks.py index 9add249..643e69e 100644 --- a/src/impakt/web/callbacks/plot_callbacks.py +++ b/src/impakt/web/callbacks/plot_callbacks.py @@ -1,10 +1,11 @@ """Plot rendering callbacks. -Handles: -- Updating plot figures when channels/transforms change -- Cursor line rendering (X1/X2 vertical lines) -- Hover data is NOT shown as a Plotly tooltip — instead the cursor grid - picks it up via the hoverData callback +Builds a PlotSpec from the UI state and delegates to PlotEngine.render(). +This is the single rendering path — the same PlotEngine used by the +scripting API renders the web UI plots. + +Transform application uses TransformChain, making the pipeline +serializable and reproducible. """ from __future__ import annotations @@ -15,16 +16,55 @@ from typing import Any import dash import numpy as np import plotly.graph_objects as go -from dash import ALL, Input, Output, State, ctx +from dash import Input, Output, State from dash.exceptions import PreventUpdate from impakt.channel.model import Channel -from impakt.plot.engine import DEFAULT_COLORS +from impakt.plot.engine import DEFAULT_COLORS, PlotEngine +from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle from impakt.transform.align import XAlign, YAlign +from impakt.transform.base import TransformChain from impakt.transform.cfc import CFCFilter from impakt.transform.resultant import resultant_from_channels from impakt.web.state import AppState +# Module-level engine instance (stateless, safe to reuse) +_engine = PlotEngine() + + +def _build_transform_chain( + cfc_value: str, + y_align: bool, + x_align_method: str, + x_align_value: float | None, + per_channel_cfc: str | None = None, +) -> TransformChain: + """Build a TransformChain from the current UI state. + + Per-channel CFC override takes precedence over the global CFC setting. + """ + chain = TransformChain() + + # CFC filter + effective_cfc = per_channel_cfc if per_channel_cfc else cfc_value + if effective_cfc and effective_cfc != "none": + try: + chain = chain.append(CFCFilter(cfc_class=int(effective_cfc))) + except (ValueError, Exception): + pass + + # Y-align + if y_align: + chain = chain.append(YAlign()) + + # X-align + if x_align_method == "manual" and x_align_value is not None: + chain = chain.append(XAlign(method="manual", reference_time=x_align_value)) + elif x_align_method == "threshold" and x_align_value is not None: + chain = chain.append(XAlign(method="threshold", threshold_value=x_align_value)) + + return chain + def _resolve_channels( selected_keys: list[str], @@ -35,12 +75,15 @@ def _resolve_channels( x_align_value: float | None, show_resultant: bool, ) -> list[tuple[str, Channel]]: - """Resolve selected channel keys to Channel objects with transforms applied. + """Resolve selected channel keys to transformed Channel objects. - Per-channel overrides (from app_state.channel_overrides) take precedence - over the global CFC setting. + Uses TransformChain for each channel. Per-channel overrides from + app_state.channel_overrides take precedence over the global CFC. + + Returns list of (label, transformed_channel) tuples. """ channels: list[tuple[str, Channel]] = [] + multi_test = len(app_state.tests) > 1 for key in selected_keys: if "::" in key: @@ -55,29 +98,22 @@ def _resolve_channels( if ch is None: continue - # Determine CFC: per-channel override takes precedence over global + # Build per-channel transform chain override = app_state.channel_overrides.get(key, {}) - ch_cfc = override.get("cfc", "") - effective_cfc = ch_cfc if ch_cfc else cfc_value + per_ch_cfc = override.get("cfc", "") + chain = _build_transform_chain( + cfc_value, + y_align, + x_align_method, + x_align_value, + per_channel_cfc=per_ch_cfc, + ) - if effective_cfc and effective_cfc != "none": - try: - ch = CFCFilter(cfc_class=int(effective_cfc)).apply(ch) - except (ValueError, Exception): - pass - - # Apply Y-align - if y_align: - ch = YAlign().apply(ch) - - # Apply X-align - if x_align_method == "manual" and x_align_value is not None: - ch = XAlign(method="manual", reference_time=x_align_value).apply(ch) - elif x_align_method == "threshold" and x_align_value is not None: - ch = XAlign(method="threshold", threshold_value=x_align_value).apply(ch) + # Apply the chain + if len(chain) > 0: + ch = chain.apply(ch) # Build label - multi_test = len(app_state.tests) > 1 label = ch.code.short_label if ch.code.is_valid else ch.name if multi_test: label = f"[{test_id}] {label}" @@ -99,7 +135,7 @@ def _resolve_channels( res_label = ( f"{res.code.location_label} Resultant" if res.code.is_valid else "Resultant" ) - if len(app_state.tests) > 1: + if multi_test: res_label = f"[{comps[0].source_test_id}] {res_label}" channels.append((res_label, res)) except Exception: @@ -108,123 +144,61 @@ def _resolve_channels( return channels -def _build_figure( +def _build_plot_spec( channels: list[tuple[str, Channel]], cursor_x1: float | None, cursor_x2: float | None, - cfc_value: str, corridors: list[dict] | None = None, -) -> go.Figure: - """Build a Plotly figure from resolved channels.""" - fig = go.Figure() +) -> PlotSpec: + """Build a PlotSpec from resolved channels and UI state. - if not channels: - fig.update_layout( - template="plotly_white", - annotations=[ - { - "text": "Select channels from the grid to plot", - "xref": "paper", - "yref": "paper", - "x": 0.5, - "y": 0.5, - "showarrow": False, - "font": {"size": 14, "color": "#bbb"}, - } - ], - xaxis={"visible": False}, - yaxis={"visible": False}, - margin={"l": 40, "r": 20, "t": 30, "b": 40}, - ) - return fig - - # Add traces — hovermode is disabled at the layout level; cursor tracking - # is handled by our own JS mousemove handler (cursor_tracker.js). + Channels are already transformed — they are wrapped in ChannelRef + objects with no additional transform chain (transforms were applied + during resolution). + """ + # Build ChannelRef objects + refs: list[ChannelRef] = [] for i, (label, ch) in enumerate(channels): color = DEFAULT_COLORS[i % len(DEFAULT_COLORS)] - fig.add_trace( - go.Scatter( - x=ch.time.tolist(), - y=ch.data.tolist(), - mode="lines", - name=label, - line=dict(color=color, width=1.5), + refs.append( + ChannelRef( + channel=ch, + style=PlotStyle(label=label, color=color), ) ) - # Add corridor fills + # Build Corridor objects from raw dicts + corridor_objs: list[Corridor] = [] if corridors: - for corridor in corridors: - if not corridor.get("visible", True): + for c in corridors: + if not c.get("visible", True): continue - c_time = corridor["time"] - c_upper = corridor["upper"] - c_lower = corridor["lower"] - c_name = corridor.get("name", "Corridor") - - # Upper bound - fig.add_trace( - go.Scatter( - x=c_time, - y=c_upper, - mode="lines", - line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"), - showlegend=False, - name=f"{c_name} upper", - ) - ) - # Lower bound with fill to upper - fig.add_trace( - go.Scatter( - x=c_time, - y=c_lower, - mode="lines", - line=dict(color="rgba(100,100,255,0.4)", width=1, dash="dash"), - fill="tonexty", - fillcolor="rgba(100,100,255,0.1)", - showlegend=True, - name=c_name, + corridor_objs.append( + Corridor( + name=c.get("name", "Corridor"), + time=np.array(c["time"]), + lower=np.array(c["lower"]), + upper=np.array(c["upper"]), + style=CorridorStyle(), ) ) - # Add X1/X2 cursor lines - if cursor_x1 is not None: - fig.add_vline( - x=cursor_x1, - line_dash="dash", - line_color="rgba(220,53,69,0.6)", - line_width=1, - annotation_text=f"X1={cursor_x1:.4f}s", - annotation_font_size=9, - annotation_font_color="rgba(220,53,69,0.8)", - ) - if cursor_x2 is not None: - fig.add_vline( - x=cursor_x2, - line_dash="dash", - line_color="rgba(13,110,253,0.6)", - line_width=1, - annotation_text=f"X2={cursor_x2:.4f}s", - annotation_font_size=9, - annotation_font_color="rgba(13,110,253,0.8)", - ) + # Cursor positions + x_cursors = None + if cursor_x1 is not None and cursor_x2 is not None: + x_cursors = (cursor_x1, cursor_x2) - # Layout — hovermode is disabled; cursor tracking is handled entirely - # by our JS (cursor_tracker.js) which reads pixel positions from mousemove - # events and converts to data coordinates via Plotly's axis internals. + # Y-axis label from first channel y_label = channels[0][1].unit if channels else "" - fig.update_layout( - xaxis_title=dict(text="Time (s)", font=dict(size=10, color="#999")), - yaxis_title=dict(text=y_label, font=dict(size=10, color="#999")), - template="plotly_white", - hovermode=False, - showlegend=False, - margin=dict(l=45, r=8, t=4, b=28), + return PlotSpec( + channels=refs, + corridors=corridor_objs, + x_cursors=x_cursors, + y_label=y_label, + compact=True, # Web UI always uses compact mode ) - return fig - def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None: """Register all plot-related callbacks.""" @@ -271,10 +245,5 @@ def register_plot_callbacks(app: dash.Dash, app_state: AppState) -> None: show_resultant, ) - return _build_figure( - channels, - cursor_x1, - cursor_x2, - cfc_value, - corridors=corridors_data, - ) + spec = _build_plot_spec(channels, cursor_x1, cursor_x2, corridors_data) + return _engine.render(spec) diff --git a/src/impakt/web/components/channel_tree.py b/src/impakt/web/components/channel_tree.py deleted file mode 100644 index e4b68db..0000000 --- a/src/impakt/web/components/channel_tree.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Collapsible channel tree component. - -Renders channels in a hierarchical accordion: - Test Object > Body Region > Measurement Type > [channels] - -Supports search filtering, select-all per group, and channel preview -(peak, unit, sample rate) on hover/expand. -""" - -from __future__ import annotations - -from typing import Any - -import dash_bootstrap_components as dbc -from dash import dcc, html - -from impakt.web.state import AppState - - -def _make_channel_item(ch_info: dict[str, str], show_test_prefix: bool = False) -> html.Div: - """Build a single channel item with checkbox and preview info.""" - label = ch_info["label"] - key = ch_info["key"] - peak = ch_info.get("peak", "") - unit = ch_info.get("unit", "") - - preview = f" ({peak} {unit})" if peak else "" - - return html.Div( - [ - dbc.Checkbox( - id={"type": "channel-check", "index": key}, - label=html.Span( - [ - html.Span(label, style={"fontSize": "12px"}), - html.Span( - preview, - style={"fontSize": "10px", "color": "#999", "marginLeft": "4px"}, - ), - ] - ), - value=False, - style={"marginBottom": "1px"}, - ), - ], - className="channel-item", - ) - - -def _make_group_header( - group_label: str, - group_id: str, - channel_keys: list[str], -) -> html.Div: - """Build a measurement group header with select-all button.""" - return html.Div( - [ - html.Span(group_label, style={"fontSize": "12px", "fontWeight": "500"}), - dbc.Button( - "All", - id={"type": "select-group-btn", "index": group_id}, - color="link", - size="sm", - style={"fontSize": "10px", "padding": "0 4px", "textDecoration": "none"}, - ), - # Hidden store for this group's channel keys - dcc.Store( - id={"type": "group-channels", "index": group_id}, - data=channel_keys, - ), - ], - className="d-flex align-items-center justify-content-between", - ) - - -def build_channel_tree(app_state: AppState) -> dbc.Card: - """Build the full channel tree panel with search and accordion.""" - if app_state.is_empty: - return dbc.Card( - [ - dbc.CardHeader("Channels", className="fw-bold py-2"), - dbc.CardBody( - html.Div("No test loaded", className="text-muted small text-center py-3"), - ), - ] - ) - - multi_test = len(app_state.tests) > 1 - full_tree = app_state.build_channel_tree() - - # Build accordion items - accordion_items = [] - group_counter = 0 - - for test_id, test_tree in full_tree.items(): - test_label = "" - loaded = app_state.get_test(test_id) - if loaded and multi_test: - test_label = f"[{test_id}] " - - for obj_label, locations in test_tree.items(): - # Top-level accordion: test object (Driver, Vehicle Structure, etc.) - obj_content = [] - - for loc_label, measurements in locations.items(): - for meas_label, channels in measurements.items(): - group_id = f"grp_{group_counter}" - group_counter += 1 - - channel_keys = [ch["key"] for ch in channels] - header_label = ( - f"{loc_label} — {meas_label}" if loc_label != meas_label else meas_label - ) - - obj_content.append(_make_group_header(header_label, group_id, channel_keys)) - for ch in channels: - obj_content.append(_make_channel_item(ch, show_test_prefix=multi_test)) - - obj_content.append(html.Hr(style={"margin": "4px 0"})) - - if obj_content: - # Remove trailing hr - if obj_content and isinstance(obj_content[-1], html.Hr): - obj_content = obj_content[:-1] - - display_label = f"{test_label}{obj_label}" if test_label else obj_label - accordion_items.append( - dbc.AccordionItem( - html.Div(obj_content), - title=display_label, - style={"fontSize": "12px"}, - ) - ) - - return dbc.Card( - [ - dbc.CardHeader( - [ - html.Span("Channels", className="fw-bold"), - html.Span( - f" ({app_state.total_channels})", - style={"fontSize": "11px", "color": "#999"}, - ), - ], - className="py-2", - ), - dbc.CardBody( - [ - # Search - dbc.Input( - id="channel-search", - placeholder="Search channels...", - type="text", - size="sm", - className="mb-2", - debounce=True, - ), - # Selected channels display - html.Div(id="selected-channels-badges", className="mb-2"), - # Tree - html.Div( - dbc.Accordion( - accordion_items, - flush=True, - always_open=True, - start_collapsed=True, - id="channel-accordion", - ), - style={"maxHeight": "450px", "overflowY": "auto"}, - id="channel-tree-container", - ), - ], - className="py-2", - ), - ] - ) - - -def build_selected_channels_badges(selected_keys: list[str], app_state: AppState) -> list: - """Build badge pills showing currently selected channels.""" - if not selected_keys: - return [ - html.Span("No channels selected", className="text-muted", style={"fontSize": "11px"}) - ] - - badges = [] - for key in selected_keys[:20]: # Limit display to 20 - # Extract label - if "::" in key: - test_id, ch_name = key.split("::", 1) - ch = app_state.get_channel(test_id, ch_name) - if ch and ch.code.is_valid: - label = ch.code.short_label - else: - label = ch_name - if len(app_state.tests) > 1: - label = f"{test_id}: {label}" - else: - label = key - - badges.append( - dbc.Badge( - label, - id={"type": "channel-badge", "index": key}, - color="primary", - pill=True, - className="me-1 mb-1", - style={"fontSize": "10px", "cursor": "pointer"}, - ) - ) - - if len(selected_keys) > 20: - badges.append( - html.Span( - f"+{len(selected_keys) - 20} more", - className="text-muted", - style={"fontSize": "10px"}, - ) - ) - - return badges diff --git a/src/impakt/web/components/header.py b/src/impakt/web/components/header.py index 479e978..e7a36b8 100644 --- a/src/impakt/web/components/header.py +++ b/src/impakt/web/components/header.py @@ -129,7 +129,7 @@ def build_test_info_panel(app_state: AppState) -> dbc.Card: f"{meta.impact.speed_kmh:.1f}" if meta.impact.speed_kmh else "—", style={"fontSize": "12px"}, ), - html.Td(str(loaded.channel_count), style={"fontSize": "12px"}), + html.Td(str(len(loaded)), style={"fontSize": "12px"}), html.Td( dbc.Button( "x", diff --git a/src/impakt/web/state.py b/src/impakt/web/state.py index cf80de2..c7b92a1 100644 --- a/src/impakt/web/state.py +++ b/src/impakt/web/state.py @@ -1,14 +1,12 @@ """Application state management. -AppState is the central data store for the web UI. It holds all loaded tests, -manages channel transforms, and provides the data that callbacks need to -render plots, compute criteria, and generate reports. +AppState is the central data store for the web UI. It holds Session +objects (from the scripting API) and delegates to them for channel +access, transforms, criteria, and protocol scoring. -AppState is NOT a Dash Store — it lives server-side in Python memory. The Dash -stores hold lightweight references (test IDs, selected channels, plot config) +AppState is NOT a Dash Store — it lives server-side in Python memory. +Dash stores hold lightweight references (test IDs, selected channels) that callbacks use to look up data from AppState. - -This design keeps large numpy arrays out of the browser and avoids serialization. """ from __future__ import annotations @@ -17,220 +15,95 @@ import logging from pathlib import Path from typing import Any -from impakt.channel.model import Channel, ChannelGroup, TestData -from impakt.io.reader import get_registry +from impakt.channel.model import Channel, TestData +from impakt.script.api import Session from impakt.template.library import TemplateLibrary -from impakt.template.model import PlotDefinition, SessionState, TemplateSpec -from impakt.template.session import SessionManager +from impakt.template.model import PlotDefinition, TemplateSpec from impakt.transform.align import YAlign from impakt.transform.cfc import CFCFilter logger = logging.getLogger(__name__) -class LoadedTest: - """A single loaded test with its metadata and channel access.""" - - def __init__(self, test_data: TestData, color_offset: int = 0) -> None: - self.data = test_data - self.color_offset = color_offset - - @property - def test_id(self) -> str: - return self.data.test_id - - @property - def label(self) -> str: - """Short label for UI display.""" - meta = self.data.metadata - parts = [meta.test_number] - if meta.vehicle.make: - parts.append(f"{meta.vehicle.make} {meta.vehicle.model}".strip()) - return " — ".join(parts) if len(parts) > 1 else parts[0] - - @property - def channel_count(self) -> int: - return len(self.data) - - def __len__(self) -> int: - return len(self.data) - - class AppState: """Server-side application state. - Manages multiple loaded tests and provides channel resolution - for the UI callbacks. + Manages multiple loaded test sessions and provides channel resolution + for the UI callbacks. All test data access goes through Session objects. """ def __init__(self) -> None: - self._tests: dict[str, LoadedTest] = {} - self._test_order: list[str] = [] - self._color_counter: int = 0 + self._sessions: dict[str, Session] = {} + self._session_order: list[str] = [] self._template_library = TemplateLibrary() self._active_template: TemplateSpec | None = None - self._session_managers: dict[str, SessionManager] = {} - # Per-channel transform overrides: {channel_key: {"cfc": "600", "y_align": True}} + # Per-channel transform overrides: {channel_key: {"cfc": "600"}} self.channel_overrides: dict[str, dict[str, Any]] = {} - # Active corridors: list of {name, time, lower, upper, visible} + # Active corridors self.corridors: list[dict[str, Any]] = [] - def load_test(self, path: str | Path) -> LoadedTest: + def load_test(self, path: str | Path) -> Session: """Load a test from a path and add it to the state. - If a test with the same ID is already loaded, replaces it. - - Returns: - The loaded test. - - Raises: - ValueError: If the path is not a valid test directory. + Returns the Session object. """ - path = Path(path).resolve() - registry = get_registry() - test_data = registry.read(path) + session = Session.open(path) + test_id = session.test_id - loaded = LoadedTest(test_data, color_offset=self._color_counter) - self._color_counter += len(test_data) + self._sessions[test_id] = session + if test_id not in self._session_order: + self._session_order.append(test_id) - test_id = test_data.test_id - if test_id in self._tests: - self._tests[test_id] = loaded - else: - self._tests[test_id] = loaded - self._test_order.append(test_id) + logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(session)) + return session - logger.info("Loaded test '%s' from %s (%d channels)", test_id, path, len(test_data)) - return loaded + def add_session(self, session: Session) -> None: + """Add an already-created Session to the state.""" + test_id = session.test_id + self._sessions[test_id] = session + if test_id not in self._session_order: + self._session_order.append(test_id) def remove_test(self, test_id: str) -> None: """Remove a loaded test.""" - if test_id in self._tests: - del self._tests[test_id] - self._test_order.remove(test_id) + self._sessions.pop(test_id, None) + if test_id in self._session_order: + self._session_order.remove(test_id) @property - def tests(self) -> list[LoadedTest]: - """All loaded tests in load order.""" - return [self._tests[tid] for tid in self._test_order if tid in self._tests] + def sessions(self) -> list[Session]: + """All loaded sessions in load order.""" + return [self._sessions[tid] for tid in self._session_order if tid in self._sessions] + + @property + def tests(self) -> list[Session]: + """Alias for sessions (backward compat for components).""" + return self.sessions @property def test_ids(self) -> list[str]: - return list(self._test_order) + return list(self._session_order) @property - def primary_test(self) -> LoadedTest | None: - """The first loaded test (primary for criteria, etc.).""" - if self._test_order: - return self._tests.get(self._test_order[0]) + def primary_test(self) -> Session | None: + """The first loaded session (primary for criteria, etc.).""" + if self._session_order: + return self._sessions.get(self._session_order[0]) return None - def get_test(self, test_id: str) -> LoadedTest | None: - return self._tests.get(test_id) + def get_test(self, test_id: str) -> Session | None: + return self._sessions.get(test_id) def get_channel(self, test_id: str, channel_name: str) -> Channel | None: """Resolve a channel from a test ID and channel name.""" - test = self._tests.get(test_id) - if test is None: + session = self._sessions.get(test_id) + if session is None: return None try: - return test.data.get(channel_name) + return session.data.get(channel_name) except KeyError: return None - def resolve_channel( - self, - channel_key: str, - cfc_class: int | None = None, - y_align: bool = False, - ) -> Channel | None: - """Resolve a channel key (test_id::channel_name) and apply transforms. - - Channel keys use the format 'test_id::channel_name'. If no '::' - separator, assumes the primary test. - """ - if "::" in channel_key: - test_id, ch_name = channel_key.split("::", 1) - else: - if self.primary_test is None: - return None - test_id = self.primary_test.test_id - ch_name = channel_key - - ch = self.get_channel(test_id, ch_name) - if ch is None: - return None - - # Apply transforms - if cfc_class is not None: - try: - ch = CFCFilter(cfc_class=cfc_class).apply(ch) - except (ValueError, Exception): - pass - - if y_align: - ch = YAlign().apply(ch) - - return ch - - def build_channel_tree( - self, - ) -> dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]]: - """Build a hierarchical channel tree across all loaded tests. - - Returns: - {test_id: {object_label: {location_label: {measurement_label: [{name, label, key}]}}}} - """ - result: dict[str, dict[str, dict[str, dict[str, list[dict[str, str]]]]]] = {} - - for loaded in self.tests: - tree = loaded.data.channel_tree() - test_tree: dict[str, dict[str, dict[str, list[dict[str, str]]]]] = {} - - for obj, locations in sorted(tree.items()): - obj_tree: dict[str, dict[str, list[dict[str, str]]]] = {} - for loc, measurements in sorted(locations.items()): - loc_tree: dict[str, list[dict[str, str]]] = {} - for meas, channels in sorted(measurements.items()): - ch_list = [] - for ch in channels: - label = ch.code.short_label if ch.code.is_valid else ch.name - ch_list.append( - { - "name": ch.name, - "label": label, - "key": f"{loaded.test_id}::{ch.name}", - "unit": ch.unit, - "peak": f"{ch.peak:.2f}", - "samples": str(ch.n_samples), - "rate": f"{ch.sample_rate:.0f}", - } - ) - loc_tree[meas] = ch_list - obj_tree[loc] = loc_tree - test_tree[obj] = obj_tree - result[loaded.test_id] = test_tree - - return result - - def flat_channel_list(self) -> list[dict[str, str]]: - """Flat list of all channels across all tests, for dropdown/checklist options.""" - items: list[dict[str, str]] = [] - multi = len(self._tests) > 1 - - for loaded in self.tests: - prefix = f"[{loaded.test_id}] " if multi else "" - for ch in sorted(loaded.data, key=lambda c: c.name): - label = ch.code.short_label if ch.code.is_valid else ch.name - items.append( - { - "label": f"{prefix}{label}", - "value": f"{loaded.test_id}::{ch.name}", - } - ) - - return items - # ----- Template & Session ----- @property @@ -245,52 +118,37 @@ class AppState: def template_names(self) -> list[str]: return self._template_library.list() - def apply_template( - self, - name: str, - selected_keys: list[str] | None = None, - ) -> tuple[list[str], dict[str, str]]: + def apply_template(self, name: str) -> tuple[list[str], dict[str, str]]: """Apply a template by name. - Resolves the template's channel patterns against the primary test, - sets the active template, and returns the resolved channel keys and - transform settings. - - Returns: - (selected_channel_keys, transform_settings) + Resolves channel patterns against the primary test. + Returns (selected_channel_keys, transform_settings). """ template = self._template_library.get(name) self._active_template = template - # Persist to session if primary test has a path primary = self.primary_test - if primary and primary.data.path: - mgr = self._get_session_manager(primary.data.path) - mgr.apply_template(template) + if primary: + primary.apply_template(template) - # Resolve channel patterns from the template + # Resolve channel patterns resolved_keys: list[str] = [] if primary: for plot_def in template.plots: for pattern in plot_def.channel_patterns: - matches = primary.data.find(pattern) + matches = primary.find(pattern) for ch in matches: key = f"{primary.test_id}::{ch.name}" if key not in resolved_keys: resolved_keys.append(key) - # Transform settings transforms: dict[str, str] = {} if template.default_cfc is not None: transforms["cfc"] = str(template.default_cfc) else: transforms["cfc"] = "none" - logger.info( - "Applied template '%s' — resolved %d channels", - name, - len(resolved_keys), - ) + logger.info("Applied template '%s' — %d channels", name, len(resolved_keys)) return resolved_keys, transforms def save_as_template( @@ -303,30 +161,19 @@ class AppState: x2: float | None, protocol: str = "", ) -> TemplateSpec: - """Capture the current UI state as a new template. - - Converts selected channels back to patterns and stores the - current filter/cursor/protocol settings. - """ - # Convert selected keys to channel patterns + """Capture current UI state as a new template.""" patterns: list[str] = [] for key in selected_keys: - if "::" in key: - _, ch_name = key.split("::", 1) - else: - ch_name = key - # Use the raw channel name as a pattern (exact match) + ch_name = key.split("::", 1)[-1] if "::" in key else key if ch_name not in patterns: patterns.append(ch_name) - # Build plot definition plot = PlotDefinition( title=name, channel_patterns=patterns, x_cursors=(x1, x2) if x1 is not None and x2 is not None else None, ) - # Build transforms list transforms: list[dict[str, Any]] = [] if cfc_value and cfc_value != "none": transforms.append({"type": "cfc_filter", "cfc_class": int(cfc_value)}) @@ -343,17 +190,18 @@ class AppState: self._template_library.save(template) self._active_template = template - - logger.info("Saved template '%s' with %d channel patterns", name, len(patterns)) + logger.info("Saved template '%s' with %d patterns", name, len(patterns)) return template def save_session(self, selected_keys: list[str], cfc_value: str, **overrides: Any) -> None: """Auto-save current state to the session for the primary test.""" primary = self.primary_test - if not primary or not primary.data.path: + if not primary or not primary.path: return - mgr = self._get_session_manager(primary.data.path) + from impakt.template.session import SessionManager + + mgr = SessionManager(primary.path) mgr.state.overrides = { "selected_channels": selected_keys, "cfc": cfc_value, @@ -362,12 +210,14 @@ class AppState: mgr.save() def load_session_state(self) -> dict[str, Any] | None: - """Load saved session state for the primary test, if any.""" + """Load saved session state for the primary test.""" primary = self.primary_test - if not primary or not primary.data.path: + if not primary or not primary.path: return None - mgr = self._get_session_manager(primary.data.path) + from impakt.template.session import SessionManager + + mgr = SessionManager(primary.path) if not mgr.has_session: return None @@ -377,21 +227,48 @@ class AppState: "template": mgr.state.template_name, } - def _get_session_manager(self, path: Path) -> SessionManager: - key = str(path) - if key not in self._session_managers: - self._session_managers[key] = SessionManager(path) - return self._session_managers[key] + # ----- Channel tree / grid helpers ----- + + def build_channel_tree(self) -> dict[str, Any]: + """Build hierarchical channel tree across all loaded tests.""" + result: dict[str, Any] = {} + for session in self.sessions: + tree = session.channel_tree() + test_tree: dict[str, Any] = {} + for obj, locations in sorted(tree.items()): + obj_tree: dict[str, Any] = {} + for loc, measurements in sorted(locations.items()): + loc_tree: dict[str, Any] = {} + for meas, channels in sorted(measurements.items()): + ch_list = [] + for ch in channels: + label = ch.code.short_label if ch.code.is_valid else ch.name + ch_list.append( + { + "name": ch.name, + "label": label, + "key": f"{session.test_id}::{ch.name}", + "unit": ch.unit, + "peak": f"{ch.peak:.2f}", + "samples": str(ch.n_samples), + "rate": f"{ch.sample_rate:.0f}", + } + ) + loc_tree[meas] = ch_list + obj_tree[loc] = loc_tree + test_tree[obj] = obj_tree + result[session.test_id] = test_tree + return result @property def is_empty(self) -> bool: - return len(self._tests) == 0 + return len(self._sessions) == 0 @property def total_channels(self) -> int: - return sum(t.channel_count for t in self.tests) + return sum(len(s) for s in self.sessions) def __repr__(self) -> str: - test_info = ", ".join(f"{t.test_id}({t.channel_count}ch)" for t in self.tests) + test_info = ", ".join(f"{s.test_id}({len(s)}ch)" for s in self.sessions) tmpl = f", template={self._active_template.name}" if self._active_template else "" return f"AppState([{test_info}]{tmpl})" diff --git a/tests/test_criteria/test_chest_femur_tibia.py b/tests/test_criteria/test_chest_femur_tibia.py new file mode 100644 index 0000000..9450997 --- /dev/null +++ b/tests/test_criteria/test_chest_femur_tibia.py @@ -0,0 +1,93 @@ +"""Tests for chest deflection, femur load, tibia index, 3ms clip, viscous criterion.""" + +import numpy as np +import pytest + +from impakt.criteria import chest_deflection, clip_3ms, femur_load, tibia_index, viscous_criterion + + +class TestChestDeflection: + def test_basic(self, chest_deflection_channel): + result = chest_deflection(channel=chest_deflection_channel) + assert result.criterion == "Chest Deflection" + assert result.unit == "mm" + assert 30.0 < result.value < 40.0 # ~35mm in fixture + + def test_peak_time(self, chest_deflection_channel): + result = chest_deflection(channel=chest_deflection_channel) + assert result.time_of_peak is not None + assert 0.0 < result.time_of_peak < 0.1 + + +class TestClip3ms: + def test_basic(self, head_accel_x): + result = clip_3ms(head_accel_x) + assert result.criterion == "3ms Clip" + assert result.value > 0 + + def test_body_region(self, head_accel_x): + result = clip_3ms(head_accel_x) + assert result.body_region == "Chest" + + +class TestFemurLoad: + def test_single_channel(self, femur_left_channel): + result = femur_load(channel=femur_left_channel, side="left") + assert result.criterion == "Femur Load Left" + assert result.unit == "kN" + assert result.value > 0 + + def test_peak_time(self, femur_left_channel): + result = femur_load(channel=femur_left_channel, side="left") + assert result.time_of_peak is not None + + +class TestTibiaIndex: + def test_with_all_components(self, time_array, sample_rate): + from impakt.channel.code import ChannelCode + from impakt.channel.model import Channel + + t = time_array + fz = np.zeros_like(t) + mx = np.zeros_like(t) + my = np.zeros_like(t) + mask = (t >= 0.02) & (t <= 0.08) + fz[mask] = -5000 * np.sin(np.pi * (t[mask] - 0.02) / 0.06) + mx[mask] = 50 * np.sin(np.pi * (t[mask] - 0.02) / 0.06) + my[mask] = 80 * np.sin(np.pi * (t[mask] - 0.02) / 0.06) + + fz_ch = Channel( + name="TIBFZ", + code=ChannelCode.parse("TIBFZ"), + data=fz, + time=t, + unit="N", + sample_rate=sample_rate, + ) + mx_ch = Channel( + name="TIBMX", + code=ChannelCode.parse("TIBMX"), + data=mx, + time=t, + unit="N·m", + sample_rate=sample_rate, + ) + my_ch = Channel( + name="TIBMY", + code=ChannelCode.parse("TIBMY"), + data=my, + time=t, + unit="N·m", + sample_rate=sample_rate, + ) + + result = tibia_index(fz_channel=fz_ch, mx_channel=mx_ch, my_channel=my_ch) + assert result.value > 0 + + +class TestViscousCriterion: + def test_basic(self, chest_deflection_channel): + result = viscous_criterion(channel=chest_deflection_channel) + assert result.criterion == "Viscous Criterion" + assert result.unit == "m/s" + assert result.value >= 0 diff --git a/tests/test_plot/test_engine.py b/tests/test_plot/test_engine.py new file mode 100644 index 0000000..5f37100 --- /dev/null +++ b/tests/test_plot/test_engine.py @@ -0,0 +1,81 @@ +"""Tests for plot engine.""" + +import numpy as np +import pytest + +from impakt.plot.engine import PlotEngine, DEFAULT_COLORS +from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle + + +class TestPlotEngine: + def test_render_empty(self): + engine = PlotEngine() + spec = PlotSpec() + fig = engine.render(spec) + assert fig is not None + + def test_render_with_channels(self, head_accel_x, head_accel_y): + engine = PlotEngine() + spec = PlotSpec( + channels=[ + ChannelRef(channel=head_accel_x, style=PlotStyle(label="Head X")), + ChannelRef(channel=head_accel_y, style=PlotStyle(label="Head Y")), + ], + y_label="g", + ) + fig = engine.render(spec) + assert len(fig.data) == 2 + + def test_render_compact_mode(self, head_accel_x): + engine = PlotEngine() + spec = PlotSpec( + channels=[ChannelRef(channel=head_accel_x)], + compact=True, + ) + fig = engine.render(spec) + assert fig.layout.hovermode is False + assert fig.layout.showlegend is False + + def test_render_with_corridors(self, head_accel_x): + engine = PlotEngine() + corridor = Corridor( + name="Test Corridor", + time=np.linspace(0, 0.1, 100), + lower=np.full(100, -50.0), + upper=np.full(100, 50.0), + ) + spec = PlotSpec( + channels=[ChannelRef(channel=head_accel_x)], + corridors=[corridor], + ) + fig = engine.render(spec) + # 1 data trace + 2 corridor traces (upper + lower) + assert len(fig.data) == 3 + + def test_render_with_cursors(self, head_accel_x): + engine = PlotEngine() + spec = PlotSpec( + channels=[ChannelRef(channel=head_accel_x)], + x_cursors=(0.02, 0.06), + ) + fig = engine.render(spec) + # Should have vertical lines (as shapes) + assert len(fig.layout.shapes) >= 2 + + def test_compact_cursor_annotations(self, head_accel_x): + engine = PlotEngine() + spec = PlotSpec( + channels=[ChannelRef(channel=head_accel_x)], + x_cursors=(0.02, 0.06), + compact=True, + ) + fig = engine.render(spec) + # Should have X1/X2 annotations + annotations = [ + a for a in (fig.layout.annotations or []) if "X1" in str(a.text) or "X2" in str(a.text) + ] + assert len(annotations) == 2 + + def test_default_colors(self): + assert len(DEFAULT_COLORS) == 10 + assert all(c.startswith("#") for c in DEFAULT_COLORS) diff --git a/tests/test_protocol/test_iihs_usncap.py b/tests/test_protocol/test_iihs_usncap.py new file mode 100644 index 0000000..4bb9904 --- /dev/null +++ b/tests/test_protocol/test_iihs_usncap.py @@ -0,0 +1,68 @@ +"""Tests for IIHS and US NCAP scoring.""" + +import pytest + +from impakt.criteria.base import CriterionResult +from impakt.protocol.iihs import IIHS, evaluate as iihs_evaluate +from impakt.protocol.us_ncap import USNCAP, evaluate as us_evaluate + + +@pytest.fixture +def sample_criteria(): + return { + "HIC15": CriterionResult(criterion="HIC15", value=400, body_region="Head"), + "Nij": CriterionResult(criterion="Nij", value=0.4, body_region="Neck"), + "Chest Deflection": CriterionResult( + criterion="Chest Deflection", value=35, unit="mm", body_region="Chest" + ), + "Femur Load Left": CriterionResult( + criterion="Femur Load Left", value=4.0, unit="kN", body_region="Femur Left" + ), + "Femur Load Right": CriterionResult( + criterion="Femur Load Right", value=4.5, unit="kN", body_region="Femur Right" + ), + } + + +class TestIIHS: + def test_good_results(self, sample_criteria): + result = iihs_evaluate(sample_criteria) + assert result.protocol == "IIHS" + assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR") + + def test_region_scores(self, sample_criteria): + result = iihs_evaluate(sample_criteria) + assert len(result.region_scores) > 0 + for rs in result.region_scores: + assert rs.rating is not None + + def test_poor_hic(self): + result = iihs_evaluate( + { + "HIC15": CriterionResult(criterion="HIC15", value=1500, body_region="Head"), + } + ) + assert result.overall_rating == "POOR" + + def test_summary(self, sample_criteria): + result = iihs_evaluate(sample_criteria) + summary = result.summary() + assert "IIHS" in summary + + +class TestUSNCAP: + def test_basic(self, sample_criteria): + result = us_evaluate(sample_criteria) + assert result.protocol == "US NCAP" + assert result.stars is not None + assert 1 <= result.stars <= 5 + + def test_injury_probabilities(self, sample_criteria): + result = us_evaluate(sample_criteria) + assert "combined_injury_probability" in result.details + p = result.details["combined_injury_probability"] + assert 0.0 <= p <= 1.0 + + def test_region_scores(self, sample_criteria): + result = us_evaluate(sample_criteria) + assert len(result.region_scores) > 0 diff --git a/tests/test_scripting_api.py b/tests/test_scripting_api.py new file mode 100644 index 0000000..573ed57 --- /dev/null +++ b/tests/test_scripting_api.py @@ -0,0 +1,126 @@ +"""Tests for the scripting API (Session, ChannelHandle, TransformProxy).""" + +from pathlib import Path + +import pytest + +from impakt import Session, Template + +FIXTURE_DATA = Path(__file__).parent / "fixtures" / "sample_mme" +MME_DATA = Path(__file__).parent / "mme_data" + + +class TestSession: + def test_open(self): + s = Session.open(FIXTURE_DATA) + assert s.test_id == "IMPAKT_SYNTH_001" + assert len(s) == 26 + + def test_channel_access(self): + s = Session.open(FIXTURE_DATA) + ch = s.channel("11HEAD0000ACXA") + assert ch.name == "11HEAD0000ACXA" + assert ch.peak > 0 + + def test_find(self): + s = Session.open(FIXTURE_DATA) + channels = s.find("*HEAD*AC*") + assert len(channels) == 3 + + def test_group(self): + s = Session.open(FIXTURE_DATA) + group = s.group("HEAD0000AC") + assert group.x is not None + + def test_compute_criteria(self): + s = Session.open(FIXTURE_DATA) + criteria = s.compute_criteria() + assert len(criteria) > 0 + + def test_evaluate(self): + s = Session.open(FIXTURE_DATA) + result = s.evaluate("euro_ncap") + assert result.stars is not None + assert result.protocol == "Euro NCAP" + + def test_evaluate_us_ncap(self): + s = Session.open(FIXTURE_DATA) + result = s.evaluate("us_ncap") + assert result.stars is not None + + def test_evaluate_iihs(self): + s = Session.open(FIXTURE_DATA) + result = s.evaluate("iihs") + assert result.overall_rating in ("GOOD", "ACCEPTABLE", "MARGINAL", "POOR") + + def test_evaluate_invalid_protocol(self): + s = Session.open(FIXTURE_DATA) + with pytest.raises(ValueError, match="Unknown protocol"): + s.evaluate("invalid") + + def test_contains(self): + s = Session.open(FIXTURE_DATA) + assert "11HEAD0000ACXA" in s + assert "NONEXISTENT" not in s + + +class TestChannelHandleChaining: + """The fluent API must support chaining — each transform returns ChannelHandle.""" + + def test_single_transform(self): + s = Session.open(FIXTURE_DATA) + ch = s.channel("11HEAD0000ACXA") + filtered = ch.transform.cfc(600) + assert type(filtered).__name__ == "ChannelHandle" + assert filtered.raw.cfc_class == 600 + + def test_double_chain(self): + s = Session.open(FIXTURE_DATA) + result = s.channel("11HEAD0000ACXA").transform.cfc(600).transform.y_align() + assert type(result).__name__ == "ChannelHandle" + assert len(result.raw.transform_history) == 2 + + def test_triple_chain(self): + s = Session.open(FIXTURE_DATA) + result = ( + s.channel("11HEAD0000ACXA") + .transform.cfc(1000) + .transform.y_align() + .transform.trim(t_start=0.0, t_end=0.1) + ) + assert type(result).__name__ == "ChannelHandle" + assert len(result.raw.transform_history) == 3 + + def test_chain_preserves_data(self): + s = Session.open(FIXTURE_DATA) + original = s.channel("11HEAD0000ACXA") + original_peak = original.peak + filtered = original.transform.cfc(600) + # Original should be unchanged — peak should be the same + assert original.peak == original_peak + # Filtered should have different CFC and lower peak (smoothed) + assert filtered.raw.cfc_class == 600 + assert filtered.peak <= original_peak + + +@pytest.mark.skipif(not (MME_DATA / "3239").exists(), reason="Real data not available") +class TestSessionRealData: + def test_open_real(self): + s = Session.open(MME_DATA / "3239") + assert s.test_id == "3239" + assert len(s) == 133 + + def test_full_pipeline(self): + s = Session.open(MME_DATA / "3239") + # Chain: get channel -> filter -> check + ch = s.channel("11HEAD0000H3ACXP").transform.cfc(1000) + assert ch.peak > 100 # Significant head acceleration + + # Compute criteria + criteria = s.compute_criteria() + assert "HIC15" in criteria + + # Evaluate + result = s.evaluate("euro_ncap") + assert result.stars is not None + assert result.stars >= 0 diff --git a/tests/test_template.py b/tests/test_template.py new file mode 100644 index 0000000..756ba58 --- /dev/null +++ b/tests/test_template.py @@ -0,0 +1,132 @@ +"""Tests for template model and session persistence.""" + +from pathlib import Path + +import pytest + +from impakt.template.library import TemplateLibrary +from impakt.template.model import PlotDefinition, SessionState, TemplateSpec +from impakt.template.session import SessionManager + + +class TestTemplateSpec: + def test_yaml_round_trip(self): + template = TemplateSpec( + name="Test Template", + version=2, + description="A test template", + plots=[ + PlotDefinition( + title="Head Acceleration", + channel_patterns=["*HEAD*AC*"], + x_cursors=(0.0, 0.1), + ) + ], + default_cfc=1000, + criteria=["hic15", "nij"], + protocol="euro_ncap", + ) + + yaml_str = template.to_yaml() + restored = TemplateSpec.from_yaml(yaml_str) + + assert restored.name == "Test Template" + assert restored.version == 2 + assert restored.default_cfc == 1000 + assert len(restored.plots) == 1 + assert restored.plots[0].channel_patterns == ["*HEAD*AC*"] + assert restored.criteria == ["hic15", "nij"] + + def test_save_and_load(self, tmp_path): + template = TemplateSpec(name="Saved Test", version=1) + path = tmp_path / "test.yaml" + template.save(path) + assert path.exists() + + loaded = TemplateSpec.load(path) + assert loaded.name == "Saved Test" + + +class TestSessionState: + def test_yaml_round_trip(self): + state = SessionState( + template_name="my_template", + template_version=3, + test_path="/data/test_001", + overrides={"cfc": "600", "selected": ["ch1", "ch2"]}, + ) + yaml_str = state.to_yaml() + restored = SessionState.from_yaml(yaml_str) + + assert restored.template_name == "my_template" + assert restored.template_version == 3 + assert restored.overrides["cfc"] == "600" + + def test_save_and_load(self, tmp_path): + state = SessionState(template_name="test") + path = tmp_path / "session.yaml" + state.save(path) + assert path.exists() + + loaded = SessionState.load(path) + assert loaded.template_name == "test" + + +class TestTemplateLibrary: + def test_empty_library(self, tmp_path): + lib = TemplateLibrary(tmp_path / "templates") + assert lib.list() == [] + assert len(lib) == 0 + + def test_save_and_list(self, tmp_path): + lib = TemplateLibrary(tmp_path / "templates") + template = TemplateSpec(name="My Template") + lib.save(template) + assert "my_template" in lib.list() + assert len(lib) == 1 + + def test_get(self, tmp_path): + lib = TemplateLibrary(tmp_path / "templates") + lib.save(TemplateSpec(name="Getter Test", version=5)) + loaded = lib.get("getter_test") + assert loaded.name == "Getter Test" + assert loaded.version == 5 + + def test_delete(self, tmp_path): + lib = TemplateLibrary(tmp_path / "templates") + lib.save(TemplateSpec(name="To Delete")) + assert lib.delete("to_delete") + assert "to_delete" not in lib.list() + + def test_get_missing_raises(self, tmp_path): + lib = TemplateLibrary(tmp_path / "templates") + with pytest.raises(FileNotFoundError): + lib.get("nonexistent") + + +class TestSessionManager: + def test_create_and_save(self, tmp_path): + mgr = SessionManager(tmp_path) + mgr.state.template_name = "test_tmpl" + mgr.save() + assert mgr.has_session + assert (tmp_path / ".impakt" / "session.yaml").exists() + + def test_load_existing(self, tmp_path): + # Save + mgr1 = SessionManager(tmp_path) + mgr1.state.template_name = "saved_tmpl" + mgr1.state.overrides = {"key": "value"} + mgr1.save() + + # Load + mgr2 = SessionManager(tmp_path) + assert mgr2.state.template_name == "saved_tmpl" + assert mgr2.state.overrides["key"] == "value" + + def test_clear(self, tmp_path): + mgr = SessionManager(tmp_path) + mgr.save() + assert mgr.has_session + mgr.clear() + assert not mgr.has_session diff --git a/tests/test_transform/test_math_resultant.py b/tests/test_transform/test_math_resultant.py new file mode 100644 index 0000000..c395b7f --- /dev/null +++ b/tests/test_transform/test_math_resultant.py @@ -0,0 +1,75 @@ +"""Tests for math expressions and resultant computation.""" + +import numpy as np +import pytest + +from impakt.transform.math_expr import math_expr +from impakt.transform.resultant import resultant_from_channels +from impakt.transform.resample import trim, resample + + +class TestMathExpr: + def test_simple_expression(self, head_accel_x, head_accel_z): + result = math_expr( + expression="sqrt(a**2 + b**2)", + channels={"a": head_accel_x, "b": head_accel_z}, + name="resultant_xz", + unit="g", + ) + assert result.name == "resultant_xz" + assert result.unit == "g" + assert result.peak > 0 + assert len(result.data) == len(head_accel_x.data) + + def test_constant_expression(self, head_accel_x): + result = math_expr( + expression="a * 0 + 42.0", + channels={"a": head_accel_x}, + name="constant", + ) + assert np.allclose(result.data, 42.0) + + def test_invalid_expression(self, head_accel_x): + with pytest.raises(ValueError, match="Error evaluating"): + math_expr( + expression="invalid_func(a)", + channels={"a": head_accel_x}, + ) + + def test_forbidden_expression(self, head_accel_x): + with pytest.raises(ValueError, match="Forbidden"): + math_expr( + expression="__import__('os')", + channels={"a": head_accel_x}, + ) + + +class TestResultant: + def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z): + result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z) + assert result.code.direction == "R" + # Resultant >= any component + assert result.peak >= head_accel_x.peak + assert result.peak >= head_accel_y.peak + + def test_from_two_channels(self, head_accel_x, head_accel_z): + result = resultant_from_channels(head_accel_x, head_accel_z) + assert result.peak > 0 + + def test_single_channel_raises(self): + with pytest.raises(ValueError, match="At least one"): + resultant_from_channels() + + +class TestTrimResample: + def test_trim(self, head_accel_x): + trimmed = trim(head_accel_x, t_start=0.0, t_end=0.05) + assert trimmed.time[0] >= 0.0 + assert trimmed.time[-1] <= 0.05 + assert len(trimmed.data) < len(head_accel_x.data) + + def test_resample(self, head_accel_x): + resampled = resample(head_accel_x, target_rate=5000.0) + expected_samples = int(head_accel_x.duration * 5000.0) + assert abs(len(resampled.data) - expected_samples) <= 2 + assert resampled.sample_rate == 5000.0 diff --git a/tests/test_web/test_state.py b/tests/test_web/test_state.py index 56d24c7..93c950b 100644 --- a/tests/test_web/test_state.py +++ b/tests/test_web/test_state.py @@ -19,11 +19,11 @@ class TestAppState: def test_load_test(self): state = AppState() - loaded = state.load_test(FIXTURE_DATA) + session = state.load_test(FIXTURE_DATA) assert not state.is_empty - assert loaded.test_id == "IMPAKT_SYNTH_001" - assert loaded.channel_count == 26 - assert state.primary_test is loaded + assert session.test_id == "IMPAKT_SYNTH_001" + assert len(session) == 26 + assert state.primary_test is session def test_load_multiple_tests(self): state = AppState() @@ -32,13 +32,13 @@ class TestAppState: if (MME_DATA / "VW1FGS15").exists(): t2 = state.load_test(MME_DATA / "VW1FGS15") assert len(state.tests) == 2 - assert state.primary_test is t1 # First loaded is primary + assert state.primary_test is t1 assert state.total_channels == 26 + 10 def test_remove_test(self): state = AppState() - loaded = state.load_test(FIXTURE_DATA) - state.remove_test(loaded.test_id) + session = state.load_test(FIXTURE_DATA) + state.remove_test(session.test_id) assert state.is_empty def test_get_channel(self): @@ -54,38 +54,38 @@ class TestAppState: ch = state.get_channel("IMPAKT_SYNTH_001", "NONEXISTENT") assert ch is None - def test_resolve_channel_with_key(self): + def test_get_channel_via_session(self): + """Channels can be accessed through the Session scripting API.""" state = AppState() state.load_test(FIXTURE_DATA) - ch = state.resolve_channel("IMPAKT_SYNTH_001::11HEAD0000ACXA") - assert ch is not None + session = state.primary_test + assert session is not None + ch_handle = session.channel("11HEAD0000ACXA") + assert ch_handle.name == "11HEAD0000ACXA" - def test_resolve_channel_primary_default(self): + def test_session_fluent_transforms(self): + """Fluent transform chaining works through the Session API.""" state = AppState() state.load_test(FIXTURE_DATA) - ch = state.resolve_channel("11HEAD0000ACXA") - assert ch is not None + session = state.primary_test + ch = session.channel("11HEAD0000ACXA") + filtered = ch.transform.cfc(600).transform.y_align() + assert filtered.raw.cfc_class == 600 + assert len(filtered.raw.transform_history) == 2 - def test_resolve_channel_with_cfc(self): + def test_session_compute_criteria(self): + """Session.compute_criteria() auto-detects channels.""" state = AppState() state.load_test(FIXTURE_DATA) - ch = state.resolve_channel("11HEAD0000ACXA", cfc_class=600) - assert ch is not None - assert ch.cfc_class == 600 - - def test_flat_channel_list(self): - state = AppState() - state.load_test(FIXTURE_DATA) - items = state.flat_channel_list() - assert len(items) == 26 - assert all("value" in item and "label" in item for item in items) + criteria = state.primary_test.compute_criteria() + assert len(criteria) > 0 + assert "HIC15" in criteria or "Chest Deflection" in criteria def test_build_channel_tree(self): state = AppState() state.load_test(FIXTURE_DATA) tree = state.build_channel_tree() assert "IMPAKT_SYNTH_001" in tree - # Should have hierarchical structure test_tree = tree["IMPAKT_SYNTH_001"] assert len(test_tree) > 0 @@ -94,15 +94,23 @@ class TestAppState: class TestAppStateRealData: def test_load_real_mme(self): state = AppState() - loaded = state.load_test(MME_DATA / "3239") - assert loaded.test_id == "3239" - assert loaded.channel_count == 133 + session = state.load_test(MME_DATA / "3239") + assert session.test_id == "3239" + assert len(session) == 133 def test_channel_tree_real_data(self): state = AppState() state.load_test(MME_DATA / "3239") tree = state.build_channel_tree() assert "3239" in tree - # Should contain "Driver" in some key test_tree = tree["3239"] assert any("Driver" in k for k in test_tree) + + def test_session_evaluate_real_data(self): + """Full pipeline through Session API on real data.""" + state = AppState() + state.load_test(MME_DATA / "3239") + result = state.primary_test.evaluate("euro_ncap") + assert result.stars is not None + assert result.stars >= 0 + assert len(result.region_scores) > 0