sqrt function apdated, test for it added

This commit is contained in:
Nicolas 2025-11-01 21:51:29 +01:00
parent abf19ea92b
commit 58038cef8b
2 changed files with 34 additions and 9 deletions

View File

@ -12,17 +12,20 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
return i;
}
__attribute__((noinline)) float aux_sqrt(float n) {
if (n < 0) return -1;
__attribute__((noinline)) float aux_sqrt(float x) {
if (x <= 0.0f) return 0.0f;
float x = n; // initial guess
float epsilon = 0.00001; // desired accuracy
// --- Improved initial guess using bit-level trick ---
union { float f; uint32_t i; } conv = { x };
conv.i = (conv.i >> 1) + 0x1fc00000; // better bias constant
float y = conv.f;
while ((x - n / x) > epsilon || (x - n / x) < -epsilon) {
x = 0.5 * (x + n / x);
}
// --- Fixed number of Newton-Raphson iterations ---
y = 0.5f * (y + x / y);
y = 0.5f * (y + x / y);
y = 0.5f * (y + x / y); // 3 fixed iterations
return x;
return y;
}
__attribute__((noinline)) float aux_get_42(float n) {

View File

@ -70,7 +70,8 @@ def test_fine():
def test_trig_precision():
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] # up to 2pi
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0,
-0.0001, -0.1, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0, -5.5, -6.0, -6.28318530718, -100.0, -1000.0, -100000.0]
ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))]
ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))]
@ -88,6 +89,27 @@ def test_trig_precision():
assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
def test_sqrt_precision():
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0]
ret_test = [r for v in test_vals for r in (cp.sqrt(variable(v)),)]
ret_refe = [r for v in test_vals for r in (cp.sqrt(v),)]
tg = Target()
tg.compile(ret_test)
tg.run()
for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
func_name = 'sqrt'
assert isinstance(test, cp.variable)
val = tg.read_value(test)
print(f"+ Result of {func_name}: {val}; reference: {ref}")
assert val == pytest.approx(ref, 1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_corse()
test_fine()
test_sqrt_precision()
test_trig_precision()