diff --git a/src/copapy/_math.py b/src/copapy/_math.py index 4f4d2a8..2fe33d8 100644 --- a/src/copapy/_math.py +++ b/src/copapy/_math.py @@ -16,6 +16,22 @@ def sqrt(x: NumLike) -> variable[float] | float: return float(x ** 0.5) +@overload +def sqrt2(x: float | int) -> float: ... +@overload +def sqrt2(x: variable[Any]) -> variable[float]: ... +def sqrt2(x: NumLike) -> variable[float] | float: + """Square root function""" + if isinstance(x, variable): + return add_op('sqrt2', [x, x]) # TODO: fix 2. dummy argument + return float(x ** 0.5) + + +def get_42() -> variable[float]: + """Returns the variable representing the constant 42""" + return add_op('get_42', [0.0, 0.0]) + + def abs(x: T) -> T: """Absolute value function""" ret = (x < 0) * -x + (x >= 0) * x diff --git a/stencils/aux_functions.c b/stencils/aux_functions.c index c166acf..97b28eb 100644 --- a/stencils/aux_functions.c +++ b/stencils/aux_functions.c @@ -12,7 +12,7 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) { return i; } -__attribute__((noinline)) float fast_sqrt2(float n) { +__attribute__((noinline)) float sqrt(float n) { if (n < 0) return -1; float x = n; // initial guess @@ -25,8 +25,12 @@ __attribute__((noinline)) float fast_sqrt2(float n) { return x; } -__attribute__((noinline)) float fast_sqrt(float n) { - return n * 3.5 + 4.5; +__attribute__((noinline)) float sqrt2(float n) { + return n * 20.5 + 4.5; +} + +__attribute__((noinline)) float get_42(float n) { + return n * + 42.0; } float fast_pow_float(float base, float exponent) { diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 81fbbd7..d26ee88 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -79,10 +79,10 @@ def get_cast(type1: str, type2: str, type_out: str) -> str: @norm_indent -def get_sqrt(type1: str, type2: str) -> str: +def get_func2(func_name: str, type1: str, type2: str) -> str: return f""" - {stencil_func_prefix}void sqrt_{type1}_{type2}({type1} arg1, {type2} arg2) {{ - result_float_{type2}(fast_sqrt((float)arg1), arg2); + {stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ + result_float_{type2}({func_name}((float)arg1), arg2); }} """ @@ -213,7 +213,9 @@ if __name__ == "__main__": code += get_cast(t1, t2, t_out) for t1, t2 in permutate(types, types): - code += get_sqrt(t1, t2) + code += get_func2('sqrt', t1, t2) + code += get_func2('sqrt2', t1, t2) + code += get_func2('get_42', t1, t2) for op, t1, t2 in permutate(ops, types, types): t_out = t1 if t1 == t2 else 'float' diff --git a/tests/test_vector.py b/tests/test_vector.py index 51ba3a9..cbae5a2 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -15,17 +15,19 @@ def test_compiled_vectors(): t2 = t1.sum() t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3)) - t4 = ((t3 * t1) * 2).magnitude() - #t4 = ((t3 * t1) * 2).sum() - t5 = cp.sqrt(cp.variable(8.0)) + #t4 = ((t3 * t1) * 2).magnitude() + t4 = ((t3 * t1) * 2).sum() + t5 = cp._math.sqrt2(cp.variable(8.0)) + t6 = cp._math.get_42() tg = cp.Target() - tg.compile(t2, t4, t5) + tg.compile(t2, t4, t5, t6) tg.run() assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 #assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5 - assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 3.5 + 4.5 + assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5 + assert tg.read_value(t6) == 42.0 if __name__ == "__main__": test_compiled_vectors()