diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index c5ced93..baca9cf 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -2,6 +2,7 @@ import pkgutil from typing import Any, TypeVar, overload, TypeAlias, Generic, cast from ._stencils import stencil_database import platform +import copapy as cp NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' unifloat: TypeAlias = 'variable[float] | float' @@ -228,12 +229,7 @@ class variable(Generic[TNum], Net): @overload def __pow__(self, other: NumLike) -> 'variable[float] | variable[int]': ... def __pow__(self, other: NumLike) -> Any: - if not isinstance(other, variable): - if other == 2: - return self * self - if other == -1: - return 1 / self - return add_op('pow', [self, other]) + return cp.pow(self, other) @overload def __rpow__(self: TCPNum, other: int) -> TCPNum: ... diff --git a/src/copapy/_math.py b/src/copapy/_math.py index ab64250..31ba2dd 100644 --- a/src/copapy/_math.py +++ b/src/copapy/_math.py @@ -45,8 +45,10 @@ def log(x: NumLike) -> variable[float] | float: @overload def pow(x: float | int, y: float | int) -> float: ... @overload -def pow(x: variable[Any], y: variable[Any]) -> variable[float]: ... -def pow(x: NumLike, y: NumLike) -> variable[float] | float: +def pow(x: variable[Any], y: NumLike) -> variable[float]: ... +@overload +def pow(x: NumLike, y: variable[Any]) -> variable[float]: ... 
+def pow(x: NumLike, y: NumLike) -> NumLike: """x to the power of y Arguments: @@ -55,6 +57,15 @@ def pow(x: NumLike, y: NumLike) -> variable[float] | float: Returns: result of x**y """ + if isinstance(y, int) and 0 <= y < 16: + if y == 0: + return 1 + m = x + for _ in range(y - 1): + m *= x + return m + if y == -1: + return 1 / x return exp(y * log(x)) @@ -198,7 +209,7 @@ def acos(x: NumLike) -> variable[float] | float: Returns: Inverse cosine of x """ - return 2 * math.pi - asin(x) + return math.pi / 2 - asin(x) def get_42() -> variable[float]: diff --git a/stencils/aux_functions.c b/stencils/aux_functions.c index dd81a17..3c73674 100644 --- a/stencils/aux_functions.c +++ b/stencils/aux_functions.c @@ -33,15 +33,35 @@ NOINLINE float aux_get_42(float n) { return n + 42.0; } -float fast_pow_float(float base, float exponent) { - union { - float f; - uint32_t i; - } u; +NOINLINE float aux_log(float x) +{ + union { float f; uint32_t i; } vx = { x }; + float e = (float)((vx.i >> 23) & 0xFF) - 127.0f; + vx.i = (vx.i & 0x007FFFFF) | 0x3F800000; // normalized mantissa in [1,2) + float m = vx.f; - u.f = base; - int32_t x = u.i; - int32_t y = (int32_t)(exponent * (x - 1072632447) + 1072632447); - u.i = (uint32_t)y; - return u.f; + // 3rd-degree minimax polynomial approximation of log2(m) + // over [1, 2): log2(m) ≈ p(m) = a*m^3 + b*m^2 + c*m + d + float p = -0.34484843f * m * m * m + 2.02466578f * m * m - 2.67487759f * m + 1.65149613f; + + float log2x = e + p; + return log2x * 0.69314718f; // convert log2 → ln +} + +NOINLINE float aux_exp(float x) +{ + // Scale by 1/ln(2) + x = x * 1.44269504089f; + float xi = (float)((int)x); + float f = x - xi; + + // Polynomial approximation of 2^f for f ∈ [0,1) + float p = 1.0f + f * (0.69314718f + f * (0.24022651f + f * (0.05550411f))); + + // Reconstruct exponent + int ei = (int)xi + 127; + if (ei <= 0) ei = 0; else if (ei >= 255) ei = 255; + union { uint32_t i; float f; } v = { (uint32_t)(ei << 23) }; + + return v.f * p; } diff --git 
a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 9d5b2a9..0fb0e71 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -87,7 +87,7 @@ def get_cast(type1: str, type2: str, type_out: str) -> str: @norm_indent -def get_func2(func_name: str, type1: str, type2: str) -> str: +def get_func1(func_name: str, type1: str, type2: str) -> str: return f""" {stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ STENCIL_START({func_name}_{type1}_{type2}); @@ -96,6 +96,16 @@ def get_func2(func_name: str, type1: str, type2: str) -> str: """ +@norm_indent +def get_func2(func_name: str, type1: str, type2: str) -> str: + return f""" + {stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ + STENCIL_START({func_name}_{type1}_{type2}); + result_float_{type2}(aux_{func_name}((float)arg1, (float)arg2), arg2); + }} + """ + + @norm_indent def get_conv_code(type1: str, type2: str, type_out: str) -> str: return f""" @@ -205,7 +215,7 @@ if __name__ == "__main__": # Scalar arithmetic: types = ['int', 'float'] - ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow', 'atan2'] + ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'atan2'] int_ops = ['bwand', 'bwor', 'bwxor', 'lshift', 'rshift'] for t1 in types: @@ -220,9 +230,9 @@ if __name__ == "__main__": t_out = 'int' if t1 == 'float' else 'float' code += get_cast(t1, t2, t_out) - fnames = ['sqrt', 'exp', 'sin', 'cos', 'tan', 'asin', 'atan', 'get_42'] + fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan', 'get_42'] for fn, t1 in permutate(fnames, types): - code += get_func2(fn, t1, t1) + code += get_func1(fn, t1, t1) for op, t1, t2 in permutate(ops, types, types): t_out = t1 if t1 == t2 else 'float' @@ -230,7 +240,7 @@ if __name__ == "__main__": code += get_floordiv('floordiv', t1, t2) elif op == 'div': code += get_op_code_float(op, t1, t2) - elif op in {'pow', 'atan2'}: + elif op 
in {'atan2'}: code += get_func2(op, t1, t2) elif op in {'gt', 'eq', 'ge', 'ne'}: code += get_op_code(op, t1, t2, 'int') diff --git a/stencils/test.c b/stencils/test.c index 2d1253a..ec90a20 100644 --- a/stencils/test.c +++ b/stencils/test.c @@ -5,11 +5,15 @@ int main() { // Test aux functions float a = 16.0f; float sqrt_a = aux_sqrt(100000.0f); - float pow_a = fast_pow_float(a, 0.5f); float div_result = (float)floor_div(-7.0f, 3.0f); float sin_30 = aux_sin(30.0f); float cos_60 = aux_cos(60.0f); float tan_45 = aux_tan(45.0f); + float atan_15 = aux_atan(1.5f); + float asin_15 = aux_asin(1.5f); + float atan2_15 = aux_atan2(1.5f, 1.5f); + float exp_5 = aux_exp(5.0); + float log_5 = aux_log(5.0); float g42 = aux_get_42(0.0f); return 0; } diff --git a/stencils/trigonometry.c b/stencils/trigonometry.c index fe027a3..546154d 100644 --- a/stencils/trigonometry.c +++ b/stencils/trigonometry.c @@ -182,9 +182,9 @@ NOINLINE float aux_atan2(float y, float x) { float angle; if (abs_x > abs_y) - angle = fast_atanf_fp32_nolib(y / x); + angle = aux_atan(y / x); else - angle = PI_2 - fast_atanf_fp32_nolib(x / y); + angle = PI_2 - aux_atan(x / y); // Quadrant correction if (x < 0) angle = (y >= 0) ? 
angle + PI : angle - PI; diff --git a/tests/test_math.py b/tests/test_math.py index 8c3fa82..1dc2e3f 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,6 +1,8 @@ from copapy import variable, Target import pytest import copapy as cp +import math as ma +import warnings def test_corse(): @@ -44,7 +46,7 @@ def test_fine(): cp.cos(c_f), cp.tan(c_f)) # , c_i & 3) - ret_refe = (a_f ** 2, + re2_test = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), @@ -52,6 +54,14 @@ def test_fine(): cp.cos(a_f), cp.tan(a_f)) # , a_i & 3) + ret_refe = (a_f ** 2, + a_i ** -1, + ma.sqrt(a_i), + ma.sqrt(a_f), + ma.sin(a_f), + ma.cos(a_f), + ma.tan(a_f)) # , a_i & 3) + tg = Target() print('* compile and copy ...') tg.compile(ret_test) @@ -59,13 +69,14 @@ def test_fine(): tg.run() print('* finished') - for test, ref, name in zip(ret_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')): + for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')): assert isinstance(test, cp.variable) val = tg.read_value(test) print('+', val, ref, type(val), test.dtype) #for t in (int, float, bool): # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" - assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + assert val == pytest.approx(ref, 1e-5), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + assert val2 == pytest.approx(ref, 1e-5), f"Local result for {name} does not match: {val2} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] def test_trig_precision(): @@ -74,8 +85,7 @@ def test_trig_precision(): -0.0001, -0.1, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0, -5.5, -6.0, -6.28318530718, -100.0, -1000.0, -100000.0] ret_test = [r for v in test_vals for r in (cp.sin(variable(v)), 
cp.cos(variable(v)), cp.tan(variable(v)))] - ret_refe = [r for v in test_vals for r in (cp.sin(v), cp.cos(v), cp.tan(v))] - + ret_refe = [r for v in test_vals for r in (ma.sin(v), ma.cos(v), ma.tan(v))] tg = Target() tg.compile(ret_test) @@ -89,6 +99,36 @@ def test_trig_precision(): assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] +def test_arcus_trig_precision(): + + test_vals = [0.0, 0.01, 0.1, 0.5, 0.7, 0.9, 0.95, + -0.01, -0.1, -0.5, -0.7, -0.9, -0.95] + + ret_test = [r for v in test_vals for r in (cp.asin(variable(v)), + cp.acos(variable(v)), + cp.atan(variable(v)), + cp.atan2(variable(v), variable(0.7)), + cp.atan2(variable(v), variable(-0.2)))] + ret_refe = [r for v in test_vals for r in (ma.asin(v), + ma.acos(v), + ma.atan(v), + ma.atan2(v, 0.7), + ma.atan2(v, -0.2))] + + tg = Target() + tg.compile(ret_test) + tg.run() + + for i, (test, ref) in enumerate(zip(ret_test, ret_refe)): + func_name = ['asin', 'acos', 'atan', 'atan2[1]', 'atan2[2]'][i % 5] + assert isinstance(test, cp.variable) + val = tg.read_value(test) + print(f"+ Result of {func_name}: {val}; reference: {ref}") + #assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + if not val == pytest.approx(ref, abs=1e-5): # pyright: ignore[reportUnknownMemberType] + warnings.warn(f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}", UserWarning) + + def test_sqrt_precision(): test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] @@ -105,7 +145,31 @@ def test_sqrt_precision(): assert isinstance(test, cp.variable) val = tg.read_value(test) print(f"+ Result of {func_name}: {val}; reference: {ref}") - assert val ==
pytest.approx(ref, 1e-5), f"Result of {func_name} for input {test_vals[i // 3]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + assert val == pytest.approx(ref, 1e-5), f"Result of {func_name} for input {test_vals[i]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + + +def test_log_exp_precision(): + + test_vals = [0.1, 0.5, 0.9, 0.999, 1.0, 8.8, 12.0, + -0.1, -0.5, -0.9, -0.999, -1.0, 8.8, 12.0] + + ret_test = [r for v in test_vals for r in (cp.log(variable(abs(v))), + cp.exp(variable(v)))] + ret_refe = [r for v in test_vals for r in (ma.log(abs(v)), + ma.exp(v))] + + tg = Target() + tg.compile(ret_test) + tg.run() + + for i, (test, ref) in enumerate(zip(ret_test, ret_refe)): + func_name = ['log', 'exp'][i % 2] + assert isinstance(test, cp.variable) + val = tg.read_value(test) + print(f"+ Result of {func_name}: {val}; reference: {ref}") + #assert val == pytest.approx(ref, rel=4e-2), f"Result of {func_name} for input {test_vals[i // 2]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + if not val == pytest.approx(ref, rel=4e-2): # pyright: ignore[reportUnknownMemberType] + warnings.warn(f"Result of {func_name} for input {test_vals[i // 2]} does not match: {val} and reference: {ref}", UserWarning) if __name__ == "__main__":