trig-function fixed and tests added

This commit is contained in:
Nicolas 2025-11-01 13:43:22 +01:00
parent 73d32a07b1
commit d17aa809e1
3 changed files with 43 additions and 11 deletions

View File

@ -67,7 +67,7 @@ __attribute__((noinline)) float aux_cos(float x) {
// Select function and sign based on quadrant // Select function and sign based on quadrant
int qm = q & 3; int qm = q & 3;
int use_sin = (qm == 1 || qm == 3); int use_sin = (qm == 1 || qm == 3);
int sign = (qm == 0 || qm == 1) ? +1 : -1; int sign = (qm == 0 || qm == 3) ? +1 : -1;
float r2 = r * r; float r2 = r * r;
@ -101,15 +101,15 @@ __attribute__((noinline)) float aux_tan(float x) {
int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer
// Range reduce: r = x - q*(pi/2) // Range reduce: r = x - q*(pi/2)
const double PIO2_HI = 1.57079625129699707031; // π/2 high part const double PIO2_HI = 1.57079625129699707031; // pi/2 high part
const double PIO2_LO = 7.54978941586159635335e-08; // π/2 low part const double PIO2_LO = 7.54978941586159635335e-08; // pi/2 low part
double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO; double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
float r = (float)r_d; float r = (float)r_d;
// For tan: period is π, so q mod 2 determines sign // For tan: period is pi, so q mod 2 determines sign
int qm = q & 3; int qm = q & 3;
int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants
int sign = (qm == 1 || qm == 2) ? -1 : +1; int sign = (qm == 0 || qm == 2) ? +1 : -1;
// Polynomial approximations // Polynomial approximations
// sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9 // sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9

View File

@ -29,7 +29,6 @@ def test_corse():
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
@pytest.mark.skip(reason="sqrt must be fixed")
def test_fine(): def test_fine():
a_i = 9 a_i = 9
a_f = 2.5 a_f = 2.5
@ -37,8 +36,21 @@ def test_fine():
c_f = variable(a_f) c_f = variable(a_f)
# c_b = variable(True) # c_b = variable(True)
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f), cp.sin(c_f), cp.cos(c_f), cp.tan(c_f)) # , c_i & 3) ret_test = (c_f ** 2,
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), cp.sin(a_f), cp.cos(a_f), cp.tan(a_f)) # , a_i & 3) c_i ** -1,
cp.sqrt(c_i),
cp.sqrt(c_f),
cp.sin(c_f),
cp.cos(c_f),
cp.tan(c_f)) # , c_i & 3)
ret_refe = (a_f ** 2,
a_i ** -1,
cp.sqrt(a_i),
cp.sqrt(a_f),
cp.sin(a_f),
cp.cos(a_f),
cp.tan(a_f)) # , a_i & 3)
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')
@ -47,13 +59,33 @@ def test_fine():
tg.run() tg.run()
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_refe): for test, ref, name in zip(ret_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')):
assert isinstance(test, cp.variable) assert isinstance(test, cp.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
def test_trig_precision():
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] # up to 2pi
ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))]
ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))]
tg = Target()
tg.compile(ret_test)
tg.run()
for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
func_name = ['sin', 'cos', 'tan'][i % 3]
assert isinstance(test, cp.variable)
val = tg.read_value(test)
print(f"+ Result of {func_name}: {val}; reference: {ref}")
assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -11,7 +11,7 @@ def test_vectors_init():
print(tt1, tt2, tt3, tt4, tt5) print(tt1, tt2, tt3, tt4, tt5)
@pytest.mark.skip(reason="sqrt must be fixed")
def test_compiled_vectors(): def test_compiled_vectors():
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3)) t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
t2 = t1.sum() t2 = t1.sum()