diff --git a/stencils/aux_functions.c b/stencils/aux_functions.c
index 18e950e..3a74420 100644
--- a/stencils/aux_functions.c
+++ b/stencils/aux_functions.c
@@ -12,17 +12,20 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
     return i;
 }
 
-__attribute__((noinline)) float aux_sqrt(float n) {
-    if (n < 0) return -1;
+__attribute__((noinline)) float aux_sqrt(float x) {
+    if (x <= 0.0f) return 0.0f;
 
-    float x = n;  // initial guess
-    float epsilon = 0.00001;  // desired accuracy
+    // --- Improved initial guess using bit-level trick ---
+    union { float f; uint32_t i; } conv = { x };
+    conv.i = (conv.i >> 1) + 0x1fc00000;  // halve the biased exponent; 0x1fc00000 == (127 << 23) >> 1 re-adds half the bias
+    float y = conv.f;
 
-    while ((x - n / x) > epsilon || (x - n / x) < -epsilon) {
-        x = 0.5 * (x + n / x);
-    }
+    // --- Fixed number of Newton-Raphson iterations ---
+    y = 0.5f * (y + x / y);
+    y = 0.5f * (y + x / y);
+    y = 0.5f * (y + x / y); // 3 fixed iterations
 
-    return x;
+    return y;
 }
 
 __attribute__((noinline)) float aux_get_42(float n) {
diff --git a/tests/test_math.py b/tests/test_math.py
index 3deaaa7..8241b3e 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -70,7 +70,8 @@ def test_fine():
 
 def test_trig_precision():
 
-    test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0]  # up to 2pi
+    test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0,
+                 -0.0001, -0.1, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0, -5.5, -6.0, -6.28318530718, -100.0, -1000.0, -100000.0]
 
     ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), cp.cos(variable(v)), cp.tan(variable(v)))]
     ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))]
@@ -88,6 +89,27 @@ def test_trig_precision():
         assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
 
 
+def test_sqrt_precision():
+    """Compare cp.sqrt on compiled variables with cp.sqrt on plain floats (rel. tolerance 1e-5)."""
+    test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0]
+
+    ret_test = [cp.sqrt(variable(v)) for v in test_vals]
+    ret_refe = [cp.sqrt(v) for v in test_vals]
+
+    tg = Target()
+    tg.compile(ret_test)
+    tg.run()
+
+    for arg, test, ref in zip(test_vals, ret_test, ret_refe):  # one result per input, so pair them directly
+        func_name = 'sqrt'
+        assert isinstance(test, cp.variable)
+        val = tg.read_value(test)
+        print(f"+ Result of {func_name}: {val}; reference: {ref}")
+        assert val == pytest.approx(ref, rel=1e-5), f"Result of {func_name} for input {arg} does not match: {val} and reference: {ref}"  # pyright: ignore[reportUnknownMemberType]
+
+
 if __name__ == "__main__":
     test_corse()
     test_fine()
+    test_sqrt_precision()
+    test_trig_precision()