"""Tests for channel and test data models."""

import numpy as np
import pytest

from impakt.channel.code import ChannelCode

# NOTE(review): ``TestData`` matches pytest's default ``Test*`` class pattern, so
# importing it into a test module makes pytest attempt to collect it as a test
# class (emitting a PytestCollectionWarning if it defines ``__init__``).
# Consider setting ``__test__ = False`` on the model class.
from impakt.channel.model import Channel, ChannelGroup, TestData


class TestChannel:
    """Behaviour of a single ``Channel`` (via the ``head_accel_x`` fixture)."""

    def test_channel_properties(self, head_accel_x):
        """Metadata and derived size/duration attributes are exposed."""
        ch = head_accel_x
        assert ch.name == "11HEAD0000ACXA"
        assert ch.unit == "g"
        assert ch.sample_rate == 20000.0
        assert ch.n_samples == len(ch.data)
        assert ch.duration > 0

    def test_channel_peak(self, head_accel_x):
        """The peak value and the time it occurs fall in the expected ranges."""
        ch = head_accel_x
        assert ch.peak > 35.0  # ~40g peak
        assert 0.0 <= ch.peak_time <= 0.1

    def test_value_at_interpolation(self, head_accel_x):
        """``value_at`` returns sensible values before the pulse and near its peak."""
        ch = head_accel_x
        # At t = -5 ms (pre-trigger, before the pulse), value should be near zero.
        v = ch.value_at(-0.005)
        assert abs(v) < 2.0  # Small noise
        # At peak (~0.05s), should be near 40g
        v_peak = ch.value_at(0.05)
        assert v_peak > 30.0

    def test_with_data_creates_new_channel(self, head_accel_x):
        """``with_data`` returns a new channel and leaves the source channel untouched."""
        ch = head_accel_x
        # Snapshot the source so we can verify it is not mutated by the derivation.
        original_data = ch.data.copy()
        original_history = list(ch.transform_history)

        new_data = ch.data * 2.0
        ch2 = ch.with_data(data=new_data, transform_note="doubled")

        assert ch2 is not ch
        assert np.allclose(ch2.data, new_data)
        assert ch2.name == ch.name
        assert len(ch2.transform_history) == 1
        assert "doubled" in ch2.transform_history[0]

        # The original channel must be unchanged (fixtures may be shared).
        assert np.array_equal(ch.data, original_data, equal_nan=True)
        assert list(ch.transform_history) == original_history

    def test_data_time_length_mismatch_raises(self):
        """Constructing a channel whose data and time lengths differ raises ValueError."""
        # Build the code outside ``pytest.raises`` so only the Channel constructor
        # is under test.
        code = ChannelCode.parse("test")
        with pytest.raises(ValueError, match="data length"):
            Channel(
                name="test",
                code=code,
                data=np.zeros(100),
                time=np.zeros(50),
            )


class TestChannelGroup:
    """Behaviour of a tri-axial ``ChannelGroup`` (via the ``head_group`` fixture)."""

    def test_group_components(self, head_group):
        """The head group exposes three components."""
        comps = head_group.components()
        assert len(comps) == 3

    def test_group_resultant(self, head_group):
        """The resultant channel has direction ``R`` and matches component length."""
        resultant = head_group.resultant()
        assert resultant.code.direction == "R"
        assert resultant.n_samples == head_group.x.n_samples
        # Resultant should be >= any single component
        assert resultant.peak >= head_group.x.peak

    def test_group_description(self, head_group):
        """The human-readable description mentions the body region."""
        desc = head_group.description
        assert "Head" in desc


class TestTestDataContainer:
    """Lookup, search and grouping on the test-data container (``sample_test_data``)."""

    def test_get_channel(self, sample_test_data):
        """A channel can be fetched by its exact name."""
        ch = sample_test_data.get("11HEAD0000ACXA")
        assert ch.name == "11HEAD0000ACXA"

    def test_get_channel_case_insensitive(self, sample_test_data):
        """Name lookup ignores case."""
        ch = sample_test_data.get("11head0000acxa")
        assert ch.name == "11HEAD0000ACXA"

    def test_find_channels(self, sample_test_data):
        """Wildcard search returns every matching channel."""
        channels = sample_test_data.find("11HEAD0000AC*")
        assert len(channels) == 3  # X, Y, Z

    def test_groups(self, sample_test_data):
        """Channels are grouped, with exactly one head-acceleration group."""
        groups = sample_test_data.groups()
        assert len(groups) > 0
        # Head acceleration group should exist
        head_keys = [k for k in groups if "HEAD" in k and "AC" in k]
        assert len(head_keys) == 1

    def test_group_lookup(self, sample_test_data):
        """A group fetched by key has all three axes populated."""
        group = sample_test_data.group("HEAD0000AC")
        assert group.x is not None
        assert group.y is not None
        assert group.z is not None

    def test_channel_tree(self, sample_test_data):
        """The channel tree is non-empty and contains a "Driver" entry."""
        tree = sample_test_data.channel_tree()
        assert len(tree) > 0
        # Should have a "Driver" entry
        assert any("Driver" in key for key in tree)

    def test_len(self, sample_test_data):
        """``len()`` reports the number of channels in the sample data."""
        assert len(sample_test_data) == 7

    def test_contains(self, sample_test_data):
        """``in`` reports membership by channel name."""
        assert "11HEAD0000ACXA" in sample_test_data
        assert "NONEXISTENT" not in sample_test_data