copapy/tests/test_math.py

167 lines
7.0 KiB
Python
Raw Permalink Normal View History

2025-12-06 17:09:25 +00:00
from copapy import value, Target
import pytest
import copapy as cp
import math as ma
import warnings
def test_fine():
    """Compile a mix of arithmetic and math operations and check every result.

    Each operation is evaluated three ways and compared with an absolute
    tolerance of 1e-3:

    * ``ret_test``: expressions built from ``cp.value`` operands, compiled and
      run on a ``Target`` and read back with ``Target.read_value``;
    * ``re2_test``: the same copapy functions applied to plain Python numbers;
    * ``ret_refe``: a reference computed with Python operators, :mod:`math`
      and the builtin ``abs`` only, so it does not depend on copapy.
    """
    a_i = 9
    a_f = 2.5

    c_i = value(a_i)
    c_f = value(a_f)

    ret_test = (c_f ** 2,
                c_i ** -1,
                c_i ** 2.111,
                c_f ** 2.111,
                cp.sqrt(c_i),
                cp.sqrt(c_f),
                cp.sin(c_f),
                cp.cos(c_f),
                cp.tan(c_f),
                cp.abs(-c_i),
                cp.abs(-c_f))
    re2_test = (a_f ** 2,
                a_i ** -1,
                a_i ** 2.111,
                a_f ** 2.111,
                cp.sqrt(a_i),
                cp.sqrt(a_f),
                cp.sin(a_f),
                cp.cos(a_f),
                cp.tan(a_f),
                cp.abs(-a_i),
                cp.abs(-a_f))
    ret_refe = (a_f ** 2,
                a_i ** -1,
                a_i ** 2.111,
                a_f ** 2.111,
                ma.sqrt(a_i),
                ma.sqrt(a_f),
                ma.sin(a_f),
                ma.cos(a_f),
                ma.tan(a_f),
                abs(-a_i),
                abs(-a_f))
    # One label per operation, in the same order as the tuples above.
    names = ('^2', '**-1', 'int**2.111', 'float**2.111', 'sqrt_int', 'sqrt_float',
             'sin', 'cos', 'tan', 'abs_int', 'abs_float')
    # zip() stops at its shortest input, so a missing label would silently skip
    # the remaining checks; fail loudly instead.
    assert len(ret_test) == len(re2_test) == len(ret_refe) == len(names)

    tg = Target()
    print('* compile and copy ...')
    tg.compile(ret_test)
    print('* run and copy ...')
    tg.run()
    print('* finished')

    for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, names):
        assert isinstance(test, cp.value)
        val = tg.read_value(test)
        print('+', val, ref, type(val), test.dtype)
        assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
        assert val2 == pytest.approx(ref, abs=1e-3), f"Local result for {name} does not match: {val2} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
def test_trig_precision():
    """Check ``cp.sin``/``cp.cos``/``cp.tan`` against :mod:`math` over many inputs.

    Results are laid out as ``[sin(v0), cos(v0), tan(v0), sin(v1), ...]``, so
    result ``i`` belongs to input ``test_vals[i // 3]``. Every result must be
    within an absolute tolerance of 1e-3 of the reference.
    """
    test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0,
                 -0.0001, -0.1, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0, -5.5, -6.0, -6.28318530718, -100.0, -1000.0, -100000.0]

    ret_test = [r for v in test_vals for r in (cp.sin(value(v)), cp.cos(value(v)), cp.tan(value(v)))]
    ret_refe = [r for v in test_vals for r in (ma.sin(v), ma.cos(v), ma.tan(v))]

    tg = Target()
    tg.compile(ret_test)
    tg.run()

    # Iterate over all len(test_vals) * 3 results. Zipping test_vals in here
    # would truncate the loop to the first len(test_vals) results and pair
    # each result with the wrong input.
    for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
        func_name = ['sin', 'cos', 'tan'][i % 3]
        v = test_vals[i // 3]
        assert isinstance(test, cp.value)
        val = tg.read_value(test)
        print(f"+ Result of {func_name}: {val}; reference: {ref}")
        assert val == pytest.approx(ref, abs=1e-3), f"Result of {func_name} for input {v} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
def test_arcus_trig_precision():
    """Compare copapy inverse trig functions with :mod:`math` (warn-only).

    For each input ``v``, five results are produced: ``asin(v)``, ``acos(v)``,
    ``atan(v)``, ``atan2(v, 3)`` and ``atan2(v, -3)``, so result ``i`` belongs
    to ``test_vals[i // 5]``. Mismatches beyond an absolute tolerance of 1e-5
    emit a ``UserWarning`` instead of failing the test.
    """
    # Positive inputs mirrored by their negatives (all within asin/acos domain).
    test_vals = [0.0, 0.01, 0.1, 0.5, 0.7, 0.9, 0.95,
                 -0.01, -0.1, -0.5, -0.7, -0.9, -0.95]

    ret_test = [r for v in test_vals for r in (cp.asin(value(v)),
                                               cp.acos(value(v)),
                                               cp.atan(value(v)),
                                               cp.atan2(value(v), value(3)),
                                               cp.atan2(value(v), value(-3)),)]
    ret_refe = [r for v in test_vals for r in (ma.asin(v),
                                               ma.acos(v),
                                               ma.atan(v),
                                               ma.atan2(v, 3),
                                               ma.atan2(v, -3),)]

    tg = Target()
    tg.compile(ret_test)
    tg.run()

    for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
        func_name = ['asin', 'acos', 'atan', 'atan2[1]', 'atan2[2]'][i % 5]
        assert isinstance(test, cp.value)
        val = tg.read_value(test)
        print(f"+ Result of {func_name}: {val}; reference: {ref}")
        # Deliberately warn rather than assert for these functions.
        if not val == pytest.approx(ref, abs=1e-5):  # pyright: ignore[reportUnknownMemberType]
            warnings.warn(f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}", UserWarning)
def test_sqrt_precision():
    """Check ``cp.sqrt`` against ``math.sqrt`` with a relative tolerance of 1e-5.

    The reference is computed with :mod:`math` (not copapy) so that the test
    compares copapy against an independent implementation.
    """
    test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0]

    ret_test = [cp.sqrt(value(v)) for v in test_vals]
    ret_refe = [ma.sqrt(v) for v in test_vals]

    tg = Target()
    tg.compile(ret_test)
    tg.run()

    func_name = 'sqrt'
    for v, test, ref in zip(test_vals, ret_test, ret_refe):
        assert isinstance(test, cp.value)
        val = tg.read_value(test)
        print(f"+ Result of {func_name}: {val}; reference: {ref}")
        assert val == pytest.approx(ref, rel=1e-5), f"Result of {func_name} for input {v} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
def test_log_exp_precision():
    """Check ``cp.log`` and ``cp.exp`` against :mod:`math` (absolute tolerance 1e-5).

    For each input ``v`` two results are produced, ``log(|v|)`` and
    ``exp(v)``, so result ``i`` belongs to ``test_vals[i // 2]``.
    """
    # Positive inputs mirrored by their negatives; log is taken of |v|.
    test_vals = [0.1, 0.5, 0.9, 0.999, 1.0, 2.5,
                 -0.1, -0.5, -0.9, -0.999, -1.0, -2.5]

    ret_test = [r for v in test_vals for r in (cp.log(value(abs(v))),
                                               cp.exp(value(v)))]
    ret_refe = [r for v in test_vals for r in (ma.log(abs(v)),
                                               ma.exp(v))]

    tg = Target()
    tg.compile(ret_test)
    tg.run()

    for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
        func_name = ['log', 'exp'][i % 2]
        assert isinstance(test, cp.value)
        val = tg.read_value(test)
        print(f"+ Result of {func_name}: {val}; reference: {ref}")
        assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 2]} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
    # Run every test in this module when executed directly (outside pytest).
    test_fine()
    test_sqrt_precision()
    test_trig_precision()
    test_arcus_trig_precision()
    test_log_exp_precision()