From 1277369f06388846f99d22d1ecc6b117d3f18dfa Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Mon, 20 Oct 2025 22:23:31 +0200 Subject: [PATCH] fast c pow function integrated including unit test --- src/copapy/_target.py | 6 +++--- tests/test_ext_ops.py | 31 +++++++++++++++++++++++++++++++ tools/crosscompile.sh | 2 +- tools/generate_stencils.py | 18 +++++++++++++++++- 4 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 tests/test_ext_ops.py diff --git a/src/copapy/_target.py b/src/copapy/_target.py index 631cb70..9f77903 100644 --- a/src/copapy/_target.py +++ b/src/copapy/_target.py @@ -1,4 +1,4 @@ -from typing import overload +from typing import Iterable, overload from . import _binwrite as binw from coparun_module import coparun, read_data_mem import struct @@ -20,10 +20,10 @@ class Target(): self.sdb = stencil_db_from_package(arch, optimization) self._variables: dict[Net, tuple[int, int, str]] = dict() - def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None: + def compile(self, *variables: int | float | cpint | cpfloat | cpbool | Iterable[int | float | cpint | cpfloat | cpbool]) -> None: nodes: list[Node] = [] for s in variables: - if isinstance(s, list): + if isinstance(s, Iterable): for net in s: assert isinstance(net, Net), f"The folowing element is not a Net: {net}" nodes.append(Write(net)) diff --git a/tests/test_ext_ops.py b/tests/test_ext_ops.py new file mode 100644 index 0000000..8123694 --- /dev/null +++ b/tests/test_ext_ops.py @@ -0,0 +1,31 @@ +from copapy import cpvalue, Target +import pytest +import copapy + + +def test_compile(): + c_i = cpvalue(9) + c_f = cpvalue(1.111) + # c_b = cpvalue(True) + + ret_test = (c_f ** c_f, c_i ** c_i) + ret_ref = (1.111 ** 1.111, 9 ** 9) + + tg = Target() + print('* compile and copy ...') + tg.compile(ret_test) + print('* run and copy ...') + tg.run() + print('* finished') + + for test, ref in zip(ret_test, ret_ref): + assert 
isinstance(test, copapy.CPNumber) + val = tg.read_value(test) + print('+', val, ref, type(val), test.dtype) + #for t in (int, float, bool): + # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" + assert val == pytest.approx(ref, 1e-3), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + + +if __name__ == "__main__": + test_compile() diff --git a/tools/crosscompile.sh b/tools/crosscompile.sh index 684045d..d81b3e8 100644 --- a/tools/crosscompile.sh +++ b/tools/crosscompile.sh @@ -47,7 +47,7 @@ mips-linux-gnu-gcc-13 -$OPT -c $SRC -o $DEST/stencils_mips_$OPT.o mipsel-linux-gnu-gcc-13 -$OPT -c $SRC -o $DEST/stencils_mipsel_$OPT.o # RISCV 32 Bit -riscv64-linux-gnu-gcc-13 -$OPT -march=rv32imac -mabi=ilp32 -c $SRC -o $DEST/stencils_riscv32_$OPT.o +# riscv64-linux-gnu-gcc-13 -$OPT -march=rv32imac -mabi=ilp32 -c $SRC -o $DEST/stencils_riscv32_$OPT.o # RISCV 64 Bit riscv64-linux-gnu-gcc-13 -$OPT -c $SRC -o $DEST/stencils_riscv64_$OPT.o diff --git a/tools/generate_stencils.py b/tools/generate_stencils.py index 93bc081..f81eb32 100644 --- a/tools/generate_stencils.py +++ b/tools/generate_stencils.py @@ -26,6 +26,19 @@ def get_aux_funcs() -> str: if (x < 0 && x != (float)i) i -= 1; return i; } + + float fast_pow_float(float base, float exponent) { + union { + float f; + uint32_t i; + } u; + + u.f = base; + int32_t x = u.i; + int32_t y = (int32_t)(exponent * (x - 1072632447) + 1072632447); + u.i = (uint32_t)y; + return u.f; + } """ @@ -64,7 +77,8 @@ def get_op_code_float(op: str, type1: str, type2: str) -> str: def get_pow(type1: str, type2: str) -> str: return f""" {stencil_func_prefix}void pow_{type1}_{type2}({type1} arg1, {type2} arg2) {{ - result_float_{type2}((float)math_pow((double)arg1, (double)arg2), arg2); + //result_float_{type2}((float)math_pow((double)arg1, (double)arg2), arg2); + result_float_{type2}(fast_pow_float((float)arg1, (float)arg2), arg2); }} """ @@ -144,6 +158,8 
@@ if __name__ == "__main__": // Auto-generated stencils for copapy // Do not edit manually + #include <stdint.h> + double (*math_pow)(double, double); volatile int dummy_int = 1337;