mirror of https://github.com/Nonannet/copapy.git
trig-function fixed and tests added
This commit is contained in:
parent
73d32a07b1
commit
d17aa809e1
|
|
@ -67,7 +67,7 @@ __attribute__((noinline)) float aux_cos(float x) {
|
||||||
// Select function and sign based on quadrant
|
// Select function and sign based on quadrant
|
||||||
int qm = q & 3;
|
int qm = q & 3;
|
||||||
int use_sin = (qm == 1 || qm == 3);
|
int use_sin = (qm == 1 || qm == 3);
|
||||||
int sign = (qm == 0 || qm == 1) ? +1 : -1;
|
int sign = (qm == 0 || qm == 3) ? +1 : -1;
|
||||||
|
|
||||||
float r2 = r * r;
|
float r2 = r * r;
|
||||||
|
|
||||||
|
|
@ -101,15 +101,15 @@ __attribute__((noinline)) float aux_tan(float x) {
|
||||||
int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer
|
int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer
|
||||||
|
|
||||||
// Range reduce: r = x - q*(pi/2)
|
// Range reduce: r = x - q*(pi/2)
|
||||||
const double PIO2_HI = 1.57079625129699707031; // π/2 high part
|
const double PIO2_HI = 1.57079625129699707031; // pi/2 high part
|
||||||
const double PIO2_LO = 7.54978941586159635335e-08; // π/2 low part
|
const double PIO2_LO = 7.54978941586159635335e-08; // pi/2 low part
|
||||||
double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
|
double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
|
||||||
float r = (float)r_d;
|
float r = (float)r_d;
|
||||||
|
|
||||||
// For tan: period is π, so q mod 2 determines sign
|
// For tan: period is pi, so q mod 2 determines sign
|
||||||
int qm = q & 3;
|
int qm = q & 3;
|
||||||
int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants
|
int use_cot = (qm == 1 || qm == 3); // tan(x) = ±cot(r) in odd quadrants
|
||||||
int sign = (qm == 1 || qm == 2) ? -1 : +1;
|
int sign = (qm == 0 || qm == 2) ? +1 : -1;
|
||||||
|
|
||||||
// Polynomial approximations
|
// Polynomial approximations
|
||||||
// sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
|
// sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ def test_corse():
|
||||||
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="sqrt must be fixed")
|
|
||||||
def test_fine():
|
def test_fine():
|
||||||
a_i = 9
|
a_i = 9
|
||||||
a_f = 2.5
|
a_f = 2.5
|
||||||
|
|
@ -37,8 +36,21 @@ def test_fine():
|
||||||
c_f = variable(a_f)
|
c_f = variable(a_f)
|
||||||
# c_b = variable(True)
|
# c_b = variable(True)
|
||||||
|
|
||||||
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f), cp.sin(c_f), cp.cos(c_f), cp.tan(c_f)) # , c_i & 3)
|
ret_test = (c_f ** 2,
|
||||||
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), cp.sin(a_f), cp.cos(a_f), cp.tan(a_f)) # , a_i & 3)
|
c_i ** -1,
|
||||||
|
cp.sqrt(c_i),
|
||||||
|
cp.sqrt(c_f),
|
||||||
|
cp.sin(c_f),
|
||||||
|
cp.cos(c_f),
|
||||||
|
cp.tan(c_f)) # , c_i & 3)
|
||||||
|
|
||||||
|
ret_refe = (a_f ** 2,
|
||||||
|
a_i ** -1,
|
||||||
|
cp.sqrt(a_i),
|
||||||
|
cp.sqrt(a_f),
|
||||||
|
cp.sin(a_f),
|
||||||
|
cp.cos(a_f),
|
||||||
|
cp.tan(a_f)) # , a_i & 3)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
@ -47,13 +59,33 @@ def test_fine():
|
||||||
tg.run()
|
tg.run()
|
||||||
print('* finished')
|
print('* finished')
|
||||||
|
|
||||||
for test, ref in zip(ret_test, ret_refe):
|
for test, ref, name in zip(ret_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')):
|
||||||
assert isinstance(test, cp.variable)
|
assert isinstance(test, cp.variable)
|
||||||
val = tg.read_value(test)
|
val = tg.read_value(test)
|
||||||
print('+', val, ref, type(val), test.dtype)
|
print('+', val, ref, type(val), test.dtype)
|
||||||
#for t in (int, float, bool):
|
#for t in (int, float, bool):
|
||||||
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
||||||
assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
|
||||||
|
def test_trig_precision():
|
||||||
|
|
||||||
|
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] # up to 2pi
|
||||||
|
|
||||||
|
ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))]
|
||||||
|
ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))]
|
||||||
|
|
||||||
|
|
||||||
|
tg = Target()
|
||||||
|
tg.compile(ret_test)
|
||||||
|
tg.run()
|
||||||
|
|
||||||
|
for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
|
||||||
|
func_name = ['sin', 'cos', 'tan'][i % 3]
|
||||||
|
assert isinstance(test, cp.variable)
|
||||||
|
val = tg.read_value(test)
|
||||||
|
print(f"+ Result of {func_name}: {val}; reference: {ref}")
|
||||||
|
assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ def test_vectors_init():
|
||||||
|
|
||||||
print(tt1, tt2, tt3, tt4, tt5)
|
print(tt1, tt2, tt3, tt4, tt5)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="sqrt must be fixed")
|
|
||||||
def test_compiled_vectors():
|
def test_compiled_vectors():
|
||||||
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
|
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
|
||||||
t2 = t1.sum()
|
t2 = t1.sum()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue