mirror of https://github.com/Nonannet/copapy.git
trig-function fixed and tests added
This commit is contained in:
parent
73d32a07b1
commit
d17aa809e1
|
|
@ -67,7 +67,7 @@ __attribute__((noinline)) float aux_cos(float x) {
|
|||
// Select function and sign based on quadrant
|
||||
int qm = q & 3;
|
||||
int use_sin = (qm == 1 || qm == 3);
|
||||
int sign = (qm == 0 || qm == 1) ? +1 : -1;
|
||||
int sign = (qm == 0 || qm == 3) ? +1 : -1;
|
||||
|
||||
float r2 = r * r;
|
||||
|
||||
|
|
@ -101,15 +101,15 @@ __attribute__((noinline)) float aux_tan(float x) {
|
|||
int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer
|
||||
|
||||
// Range reduce: r = x - q*(pi/2)
|
||||
const double PIO2_HI = 1.57079625129699707031; // π/2 high part
|
||||
const double PIO2_LO = 7.54978941586159635335e-08; // π/2 low part
|
||||
const double PIO2_HI = 1.57079625129699707031; // pi/2 high part
|
||||
const double PIO2_LO = 7.54978941586159635335e-08; // pi/2 low part
|
||||
double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
|
||||
float r = (float)r_d;
|
||||
|
||||
// For tan: period is π, so q mod 2 determines sign
|
||||
// For tan: period is pi, so q mod 2 determines sign
|
||||
int qm = q & 3;
|
||||
int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants
|
||||
int sign = (qm == 1 || qm == 2) ? -1 : +1;
|
||||
int sign = (qm == 0 || qm == 2) ? +1 : -1;
|
||||
|
||||
// Polynomial approximations
|
||||
// sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ def test_corse():
|
|||
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="sqrt must be fixed")
|
||||
def test_fine():
|
||||
a_i = 9
|
||||
a_f = 2.5
|
||||
|
|
@ -37,8 +36,21 @@ def test_fine():
|
|||
c_f = variable(a_f)
|
||||
# c_b = variable(True)
|
||||
|
||||
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f), cp.sin(c_f), cp.cos(c_f), cp.tan(c_f)) # , c_i & 3)
|
||||
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), cp.sin(a_f), cp.cos(a_f), cp.tan(a_f)) # , a_i & 3)
|
||||
ret_test = (c_f ** 2,
|
||||
c_i ** -1,
|
||||
cp.sqrt(c_i),
|
||||
cp.sqrt(c_f),
|
||||
cp.sin(c_f),
|
||||
cp.cos(c_f),
|
||||
cp.tan(c_f)) # , c_i & 3)
|
||||
|
||||
ret_refe = (a_f ** 2,
|
||||
a_i ** -1,
|
||||
cp.sqrt(a_i),
|
||||
cp.sqrt(a_f),
|
||||
cp.sin(a_f),
|
||||
cp.cos(a_f),
|
||||
cp.tan(a_f)) # , a_i & 3)
|
||||
|
||||
tg = Target()
|
||||
print('* compile and copy ...')
|
||||
|
|
@ -47,13 +59,33 @@ def test_fine():
|
|||
tg.run()
|
||||
print('* finished')
|
||||
|
||||
for test, ref in zip(ret_test, ret_refe):
|
||||
for test, ref, name in zip(ret_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')):
|
||||
assert isinstance(test, cp.variable)
|
||||
val = tg.read_value(test)
|
||||
print('+', val, ref, type(val), test.dtype)
|
||||
#for t in (int, float, bool):
|
||||
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
||||
assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||
assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def test_trig_precision():
|
||||
|
||||
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] # up to 2pi
|
||||
|
||||
ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))]
|
||||
ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))]
|
||||
|
||||
|
||||
tg = Target()
|
||||
tg.compile(ret_test)
|
||||
tg.run()
|
||||
|
||||
for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
|
||||
func_name = ['sin', 'cos', 'tan'][i % 3]
|
||||
assert isinstance(test, cp.variable)
|
||||
val = tg.read_value(test)
|
||||
print(f"+ Result of {func_name}: {val}; reference: {ref}")
|
||||
assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ def test_vectors_init():
|
|||
|
||||
print(tt1, tt2, tt3, tt4, tt5)
|
||||
|
||||
@pytest.mark.skip(reason="sqrt must be fixed")
|
||||
|
||||
def test_compiled_vectors():
|
||||
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
|
||||
t2 = t1.sum()
|
||||
|
|
|
|||
Loading…
Reference in New Issue