Files
impakt/tests/test_plot/test_engine.py
2026-04-10 17:28:29 -04:00

82 lines
2.6 KiB
Python

"""Tests for plot engine."""
import numpy as np
import pytest
from impakt.plot.engine import PlotEngine, DEFAULT_COLORS
from impakt.plot.spec import ChannelRef, Corridor, CorridorStyle, PlotSpec, PlotStyle
class TestPlotEngine:
    """Rendering behavior of ``PlotEngine`` across different ``PlotSpec`` setups."""

    @staticmethod
    def _render(spec):
        """Render *spec* with a freshly built ``PlotEngine`` and return the figure."""
        return PlotEngine().render(spec)

    def test_render_empty(self):
        """A spec with no channels still produces a figure object."""
        figure = self._render(PlotSpec())
        assert figure is not None

    def test_render_with_channels(self, head_accel_x, head_accel_y):
        """Each ChannelRef in the spec becomes one trace in the figure."""
        labelled = ((head_accel_x, "Head X"), (head_accel_y, "Head Y"))
        refs = [
            ChannelRef(channel=source, style=PlotStyle(label=text))
            for source, text in labelled
        ]
        figure = self._render(PlotSpec(channels=refs, y_label="g"))
        assert len(figure.data) == 2

    def test_render_compact_mode(self, head_accel_x):
        """Compact mode turns off both hover and the legend."""
        spec = PlotSpec(channels=[ChannelRef(channel=head_accel_x)], compact=True)
        layout = self._render(spec).layout
        assert layout.hovermode is False
        assert layout.showlegend is False

    def test_render_with_corridors(self, head_accel_x):
        """A corridor contributes two traces next to the channel trace."""
        n_points = 100
        corridor = Corridor(
            name="Test Corridor",
            time=np.linspace(0, 0.1, n_points),
            lower=np.full(n_points, -50.0),
            upper=np.full(n_points, 50.0),
        )
        spec = PlotSpec(
            channels=[ChannelRef(channel=head_accel_x)],
            corridors=[corridor],
        )
        figure = self._render(spec)
        # One data trace, plus the corridor's upper and lower boundary traces.
        assert len(figure.data) == 3

    def test_render_with_cursors(self, head_accel_x):
        """Two x-cursors should appear as layout shapes (vertical lines)."""
        spec = PlotSpec(
            channels=[ChannelRef(channel=head_accel_x)],
            x_cursors=(0.02, 0.06),
        )
        shapes = self._render(spec).layout.shapes
        # Only a lower bound is checked; the exact shape count is left to the engine.
        assert len(shapes) >= 2

    def test_compact_cursor_annotations(self, head_accel_x):
        """In compact mode, each cursor gets a text annotation labelled X1 or X2."""
        spec = PlotSpec(
            channels=[ChannelRef(channel=head_accel_x)],
            x_cursors=(0.02, 0.06),
            compact=True,
        )
        figure = self._render(spec)
        cursor_tags = ("X1", "X2")
        # layout.annotations may be empty/None, hence the `or []` fallback.
        tagged = [
            note
            for note in (figure.layout.annotations or [])
            if any(tag in str(note.text) for tag in cursor_tags)
        ]
        assert len(tagged) == 2

    def test_default_colors(self):
        """The default palette has ten entries, each starting with '#'."""
        assert len(DEFAULT_COLORS) == 10
        for color in DEFAULT_COLORS:
            assert color.startswith("#")