diff --git a/stencils/trigonometry.c b/stencils/trigonometry.c index da995bd..927ec36 100644 --- a/stencils/trigonometry.c +++ b/stencils/trigonometry.c @@ -67,7 +67,7 @@ __attribute__((noinline)) float aux_cos(float x) { // Select function and sign based on quadrant int qm = q & 3; int use_sin = (qm == 1 || qm == 3); - int sign = (qm == 0 || qm == 1) ? +1 : -1; + int sign = (qm == 0 || qm == 3) ? +1 : -1; float r2 = r * r; @@ -101,15 +101,15 @@ __attribute__((noinline)) float aux_tan(float x) { int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer // Range reduce: r = x - q*(pi/2) - const double PIO2_HI = 1.57079625129699707031; // π/2 high part - const double PIO2_LO = 7.54978941586159635335e-08; // π/2 low part + const double PIO2_HI = 1.57079625129699707031; // pi/2 high part + const double PIO2_LO = 7.54978941586159635335e-08; // pi/2 low part double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO; float r = (float)r_d; - // For tan: period is π, so q mod 2 determines sign + // For tan: period is pi, so q mod 2 determines sign int qm = q & 3; int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants - int sign = (qm == 1 || qm == 2) ? -1 : +1; + int sign = (qm == 0 || qm == 2) ? 
+1 : -1; // Polynomial approximations // sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9 diff --git a/tests/test_math.py b/tests/test_math.py index acc0d1d..3deaaa7 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -29,7 +29,6 @@ def test_corse(): assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] -@pytest.mark.skip(reason="sqrt must be fixed") def test_fine(): a_i = 9 a_f = 2.5 @@ -37,8 +36,21 @@ def test_fine(): c_f = variable(a_f) # c_b = variable(True) - ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f), cp.sin(c_f), cp.cos(c_f), cp.tan(c_f)) # , c_i & 3) - ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), cp.sin(a_f), cp.cos(a_f), cp.tan(a_f)) # , a_i & 3) + ret_test = (c_f ** 2, + c_i ** -1, + cp.sqrt(c_i), + cp.sqrt(c_f), + cp.sin(c_f), + cp.cos(c_f), + cp.tan(c_f)) # , c_i & 3) + + ret_refe = (a_f ** 2, + a_i ** -1, + cp.sqrt(a_i), + cp.sqrt(a_f), + cp.sin(a_f), + cp.cos(a_f), + cp.tan(a_f)) # , a_i & 3) tg = Target() print('* compile and copy ...') @@ -47,13 +59,33 @@ def test_fine(): tg.run() print('* finished') - for test, ref in zip(ret_test, ret_refe): + for test, ref, name in zip(ret_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')): assert isinstance(test, cp.variable) val = tg.read_value(test) print('+', val, ref, type(val), test.dtype) #for t in (int, float, bool): # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" - assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + + +def test_trig_precision(): + + test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] 
# 0 through 2*pi, plus large arguments (100, 1e3, 1e5) beyond one period + + ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))] + ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))] + + + tg = Target() + tg.compile(ret_test) + tg.run() + + for i, (test, ref) in enumerate(zip(ret_test, ret_refe)): + func_name = ['sin', 'cos', 'tan'][i % 3] + assert isinstance(test, cp.variable) + val = tg.read_value(test) + print(f"+ Result of {func_name}: {val}; reference: {ref}") + assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] if __name__ == "__main__": diff --git a/tests/test_vector.py b/tests/test_vector.py index 297551b..b448417 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -11,7 +11,7 @@ def test_vectors_init(): print(tt1, tt2, tt3, tt4, tt5) -@pytest.mark.skip(reason="sqrt must be fixed") + def test_compiled_vectors(): t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3)) t2 = t1.sum()