Files
impakt/tests/test_transform/test_math_resultant.py
2026-04-10 17:28:29 -04:00

76 lines
2.6 KiB
Python

"""Tests for math expressions and resultant computation."""
import numpy as np
import pytest
from impakt.transform.math_expr import math_expr
from impakt.transform.resultant import resultant_from_channels
from impakt.transform.resample import trim, resample
class TestMathExpr:
    """Behaviour of ``math_expr``: evaluation, metadata, and error reporting."""

    def test_simple_expression(self, head_accel_x, head_accel_z):
        """A two-operand expression yields a channel carrying the requested name/unit."""
        operands = {"a": head_accel_x, "b": head_accel_z}
        combined = math_expr(
            expression="sqrt(a**2 + b**2)",
            channels=operands,
            name="resultant_xz",
            unit="g",
        )
        assert combined.name == "resultant_xz"
        assert combined.unit == "g"
        assert combined.peak > 0
        # Output length follows the input channel's length.
        assert len(combined.data) == len(head_accel_x.data)

    def test_constant_expression(self, head_accel_x):
        """An expression that zeroes its operand produces 42.0 across the data."""
        constant = math_expr(
            expression="a * 0 + 42.0",
            channels={"a": head_accel_x},
            name="constant",
        )
        assert np.allclose(constant.data, 42.0)

    def test_invalid_expression(self, head_accel_x):
        """Calling an unknown function surfaces as a ValueError ("Error evaluating")."""
        bad_source = "invalid_func(a)"
        with pytest.raises(ValueError, match="Error evaluating"):
            math_expr(expression=bad_source, channels={"a": head_accel_x})

    def test_forbidden_expression(self, head_accel_x):
        """An ``__import__`` call is rejected with a ValueError matching "Forbidden"."""
        hostile_source = "__import__('os')"
        with pytest.raises(ValueError, match="Forbidden"):
            math_expr(expression=hostile_source, channels={"a": head_accel_x})
class TestResultant:
    """Behaviour of ``resultant_from_channels`` (vector magnitude of components)."""

    def test_from_channels(self, head_accel_x, head_accel_y, head_accel_z):
        """Three components give an "R"-direction channel whose peak dominates each input."""
        result = resultant_from_channels(head_accel_x, head_accel_y, head_accel_z)
        assert result.code.direction == "R"
        # Resultant >= any component: sqrt(x^2 + y^2 + z^2) >= |c| at every sample,
        # so the resultant's peak must be at least every component's peak (z included).
        assert result.peak >= head_accel_x.peak
        assert result.peak >= head_accel_y.peak
        assert result.peak >= head_accel_z.peak

    def test_from_two_channels(self, head_accel_x, head_accel_z):
        """A two-component resultant is positive and dominates both inputs."""
        result = resultant_from_channels(head_accel_x, head_accel_z)
        assert result.peak > 0
        assert result.peak >= head_accel_x.peak
        assert result.peak >= head_accel_z.peak

    def test_no_channels_raises(self):
        """Calling with zero channels is rejected ("At least one ...")."""
        # NOTE: previously named ``test_single_channel_raises``, but the call passes
        # no channels at all; the name now matches what is exercised.
        with pytest.raises(ValueError, match="At least one"):
            resultant_from_channels()
class TestTrimResample:
    """Behaviour of the ``trim`` and ``resample`` transforms."""

    def test_trim(self, head_accel_x):
        """Trimming to [0.0, 0.05] keeps timestamps inside the window and drops samples."""
        window_start, window_end = 0.0, 0.05
        clipped = trim(head_accel_x, t_start=window_start, t_end=window_end)
        assert clipped.time[0] >= window_start
        assert clipped.time[-1] <= window_end
        assert len(clipped.data) < len(head_accel_x.data)

    def test_resample(self, head_accel_x):
        """Resampling to 5000.0 yields ~duration*rate samples (within +/-2) and that rate."""
        rate = 5000.0
        out = resample(head_accel_x, target_rate=rate)
        target_len = int(head_accel_x.duration * rate)
        assert abs(len(out.data) - target_len) <= 2
        assert out.sample_rate == rate