Files

110 lines
3.6 KiB
Python
Raw Permalink Normal View History

2026-04-10 14:37:34 -04:00
"""Tests for channel and test data models."""
import numpy as np
import pytest
from impakt.channel.model import Channel, ChannelGroup, TestData
class TestChannel:
    """Tests for single-channel metadata, peak detection, interpolation and copying."""

    def test_channel_properties(self, head_accel_x):
        """Basic metadata of the head X-acceleration fixture is exposed correctly."""
        ch = head_accel_x
        assert ch.name == "11HEAD0000ACXA"
        assert ch.unit == "g"
        # Tolerant comparison: a rate derived from a float time vector need not
        # be bit-exact, and the test should not depend on how it is computed.
        assert ch.sample_rate == pytest.approx(20000.0)
        assert ch.n_samples == len(ch.data)
        assert ch.duration > 0

    def test_channel_peak(self, head_accel_x):
        """The fixture's ~40 g pulse is detected, peaking within the first 0.1 s."""
        ch = head_accel_x
        assert ch.peak > 35.0  # ~40g peak
        assert 0.0 <= ch.peak_time <= 0.1

    def test_value_at_interpolation(self, head_accel_x):
        """value_at() returns near-zero before the pulse and a large value near its peak."""
        ch = head_accel_x
        # At t = -0.005 s (pre-trigger, before the pulse) only small noise is expected.
        v = ch.value_at(-0.005)
        assert abs(v) < 2.0
        # Near the peak (~0.05 s) the signal should be close to 40 g.
        v_peak = ch.value_at(0.05)
        assert v_peak > 30.0

    def test_with_data_creates_new_channel(self, head_accel_x):
        """with_data() returns a distinct Channel carrying the new data and a history note."""
        ch = head_accel_x
        new_data = ch.data * 2.0
        ch2 = ch.with_data(data=new_data, transform_note="doubled")
        assert ch2 is not ch
        assert np.allclose(ch2.data, new_data)
        assert ch2.name == ch.name
        assert len(ch2.transform_history) == 1
        assert "doubled" in ch2.transform_history[0]

    def test_data_time_length_mismatch_raises(self):
        """Constructing a Channel whose data and time arrays differ in length raises."""
        from impakt.channel.code import ChannelCode

        # Build the code outside the raises-block so that only the Channel
        # constructor is under test; a ValueError from the parser must not be
        # able to satisfy this assertion by accident.
        code = ChannelCode.parse("test")
        with pytest.raises(ValueError, match="data length"):
            Channel(
                name="test",
                code=code,
                data=np.zeros(100),
                time=np.zeros(50),
            )
class TestChannelGroup:
    """Tests for grouping axis channels: components, resultant and description."""

    def test_group_components(self, head_group):
        """The head acceleration group exposes all three axis components."""
        comps = head_group.components()
        assert len(comps) == 3

    def test_group_resultant(self, head_group):
        """The resultant is an 'R'-direction channel at least as large as every axis."""
        resultant = head_group.resultant()
        assert resultant.code.direction == "R"
        assert resultant.n_samples == head_group.x.n_samples
        # Resultant should be >= any single component. Assuming the resultant
        # is the vector magnitude sqrt(x^2 + y^2 + z^2), this holds pointwise
        # for every axis, so check all of them rather than only X.
        for component in (head_group.x, head_group.y, head_group.z):
            if component is None:
                # Missing axes are covered by test_group_components.
                continue
            assert resultant.peak >= component.peak, component.name

    def test_group_description(self, head_group):
        """The human-readable description mentions the body region."""
        desc = head_group.description
        assert "Head" in desc
class TestTestDataContainer:
    """Lookup, pattern search, grouping and container-protocol behaviour of TestData."""

    def test_get_channel(self, sample_test_data):
        """An exact channel name resolves to the channel of that name."""
        wanted = "11HEAD0000ACXA"
        found = sample_test_data.get(wanted)
        assert found.name == wanted

    def test_get_channel_case_insensitive(self, sample_test_data):
        """Lookup ignores letter case and returns the canonically named channel."""
        found = sample_test_data.get("11head0000acxa")
        assert found.name == "11HEAD0000ACXA"

    def test_find_channels(self, sample_test_data):
        """A trailing wildcard matches every head acceleration axis."""
        matches = sample_test_data.find("11HEAD0000AC*")
        assert len(matches) == 3  # one per axis: X, Y, Z

    def test_groups(self, sample_test_data):
        """Grouping yields at least one group and exactly one head-acceleration group."""
        all_groups = sample_test_data.groups()
        assert len(all_groups) > 0
        n_head_accel = sum(1 for key in all_groups if "HEAD" in key and "AC" in key)
        assert n_head_accel == 1

    def test_group_lookup(self, sample_test_data):
        """Looking up a group by its base code returns all three axis components."""
        group = sample_test_data.group("HEAD0000AC")
        for axis in ("x", "y", "z"):
            assert getattr(group, axis) is not None

    def test_channel_tree(self, sample_test_data):
        """The channel tree is non-empty and contains a 'Driver' entry."""
        tree = sample_test_data.channel_tree()
        assert len(tree) > 0
        driver_keys = [key for key in tree if "Driver" in key]
        assert driver_keys

    def test_len(self, sample_test_data):
        """len() reports the number of channels in the sample data set."""
        expected_count = 7
        assert len(sample_test_data) == expected_count

    def test_contains(self, sample_test_data):
        """Membership tests by channel name work for present and absent names."""
        present, absent = "11HEAD0000ACXA", "NONEXISTENT"
        assert present in sample_test_data
        assert absent not in sample_test_data