From 3db85352143ec9afc5b72d6664b8d29ef1d6b58b Mon Sep 17 00:00:00 2001
From: Nicolas Kruse
Date: Wed, 5 Nov 2025 21:46:53 +0100
Subject: [PATCH] inverse trig functions and log, exp and pow function added

---
Review notes (this section is ignored by `git am`):

Fixed in this revision of the patch (all changes keep the hunk line counts):
  * _math.py, atan2(): the variable path emitted the 'atan' op with [x, x];
    it now emits 'atan2' with [x, y], matching the 'atan2' stencil added to
    `ops` in generate_stencils.py.
  * _math.py, acos(): used 2*pi - asin(x); the identity is
    acos(x) = pi/2 - asin(x).
  * trigonometry.c, aux_atan(): the |x| <= 1 branch built y from the signed x
    and then negated it again for x < 0, so negative inputs returned a
    positive angle. The branch now uses absx; the sign is applied once at the
    end.
  * trigonometry.c, aux_atan2(): called fast_atanf_fp32_nolib(), which this
    patch does not define; it now calls aux_atan(). The |y| >= |x| branch now
    uses +/-pi/2 depending on the sign of y, and the +/-pi quadrant
    correction is applied only to the atan(y/x) branch (the other form is
    already in the correct quadrant).

Open points, not changed here (to be confirmed against the full tree):
  * log() emits a 'log' op, but `fnames` in generate_stencils.py gains only
    'exp', 'asin' and 'atan'; a 'log' stencil/aux function may be missing.
  * aux_asin() is a 7th-order Taylor polynomial; it returns about 1.286 for
    x = 1 instead of pi/2. Its local PI_2 is unused and may clash with the
    file-level PI_2 used by aux_atan().
  * pow() evaluates exp(y * log(x)) for plain numbers too, which raises for
    x <= 0; its docstring does not describe `y`.
  * The atan2()/pow() overloads do not cover mixed variable/number arguments.
  * The blob ids on the `index` lines and the commit id above refer to the
    original content, not to the corrected lines.

 src/copapy/__init__.py        |  11 ++-
 src/copapy/_math.py           | 128 +++++++++++++++++++++++++++++++++-
 stencils/generate_stencils.py |  18 ++---
 stencils/trigonometry.c       |  58 +++++++++++++++
 4 files changed, 197 insertions(+), 18 deletions(-)

diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py
index e6478f7..941ea4c 100644
--- a/src/copapy/__init__.py
+++ b/src/copapy/__init__.py
@@ -1,7 +1,7 @@
 from ._target import Target
 from ._basic_types import NumLike, variable, generic_sdb, iif
 from ._vectors import vector
-from ._math import sqrt, abs, sin, cos, tan
+from ._math import sqrt, abs, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow
 
 __all__ = [
     "Target",
@@ -14,5 +14,12 @@ __all__ = [
     "abs",
     "sin",
     "cos",
-    "tan"
+    "tan",
+    "asin",
+    "acos",
+    "atan",
+    "atan2",
+    "log",
+    "exp",
+    "pow"
 ]
diff --git a/src/copapy/_math.py b/src/copapy/_math.py
index 0446693..ab64250 100644
--- a/src/copapy/_math.py
+++ b/src/copapy/_math.py
@@ -6,6 +6,58 @@ import math
 T = TypeVar("T", int, float, variable[int], variable[float])
 
 
+@overload
+def exp(x: float | int) -> float: ...
+@overload
+def exp(x: variable[Any]) -> variable[float]: ...
+def exp(x: NumLike) -> variable[float] | float:
+    """Exponential function to basis e
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        result of e**x
+    """
+    if isinstance(x, variable):
+        return add_op('exp', [x, x])  # TODO: fix 2. dummy argument
+    return float(math.exp(x))
+
+
+@overload
+def log(x: float | int) -> float: ...
+@overload
+def log(x: variable[Any]) -> variable[float]: ...
+def log(x: NumLike) -> variable[float] | float:
+    """Logarithm to basis e
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        result of ln(x)
+    """
+    if isinstance(x, variable):
+        return add_op('log', [x, x])  # TODO: fix 2. dummy argument
+    return float(math.log(x))
+
+
+@overload
+def pow(x: float | int, y: float | int) -> float: ...
+@overload
+def pow(x: variable[Any], y: variable[Any]) -> variable[float]: ...
+def pow(x: NumLike, y: NumLike) -> variable[float] | float:
+    """x to the power of y
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        result of x**y
+    """
+    return exp(y * log(x))
+
+
 @overload
 def sqrt(x: float | int) -> float: ...
 @overload
@@ -21,7 +73,7 @@ def sqrt(x: NumLike) -> variable[float] | float:
     """
     if isinstance(x, variable):
         return add_op('sqrt', [x, x])  # TODO: fix 2. dummy argument
-    return float(x ** 0.5)
+    return float(math.sqrt(x))
 
 
 @overload
@@ -41,6 +93,7 @@ def sin(x: NumLike) -> variable[float] | float:
         return add_op('sin', [x, x])  # TODO: fix 2. dummy argument
     return math.sin(x)
 
+
 @overload
 def cos(x: float | int) -> float: ...
 @overload
@@ -58,6 +111,7 @@ def cos(x: NumLike) -> variable[float] | float:
         return add_op('cos', [x, x])  # TODO: fix 2. dummy argument
     return math.cos(x)
 
+
 @overload
 def tan(x: float | int) -> float: ...
 @overload
@@ -76,6 +130,77 @@ def tan(x: NumLike) -> variable[float] | float:
     return math.tan(x)
 
 
+@overload
+def atan(x: float | int) -> float: ...
+@overload
+def atan(x: variable[Any]) -> variable[float]: ...
+def atan(x: NumLike) -> variable[float] | float:
+    """Inverse tangent function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Inverse tangent of x
+    """
+    if isinstance(x, variable):
+        return add_op('atan', [x, x])  # TODO: fix 2. dummy argument
+    return math.atan(x)
+
+
+@overload
+def atan2(x: float | int, y: float | int) -> float: ...
+@overload
+def atan2(x: variable[Any], y: variable[Any]) -> variable[float]: ...
+def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
+    """2-argument arctangent
+
+    Arguments:
+        x: Input value
+        y: Input value
+
+    Returns:
+        Result in radian
+    """
+    if isinstance(x, variable) or isinstance(y, variable):
+        return add_op('atan2', [x, y])
+    return math.atan2(x, y)
+
+
+@overload
+def asin(x: float | int) -> float: ...
+@overload
+def asin(x: variable[Any]) -> variable[float]: ...
+def asin(x: NumLike) -> variable[float] | float:
+    """Inverse sine function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Inverse sine of x
+    """
+    if isinstance(x, variable):
+        return add_op('asin', [x, x])  # TODO: fix 2. dummy argument
+    return math.asin(x)
+
+
+@overload
+def acos(x: float | int) -> float: ...
+@overload
+def acos(x: variable[Any]) -> variable[float]: ...
+def acos(x: NumLike) -> variable[float] | float:
+    """Inverse cosine function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Inverse cosine of x
+    """
+    return math.pi / 2 - asin(x)
+
+
 def get_42() -> variable[float]:
     """Returns the variable representing the constant 42"""
     return add_op('get_42', [0.0, 0.0])
@@ -92,4 +217,3 @@ def abs(x: T) -> T:
     """
     ret = (x < 0) * -x + (x >= 0) * x
     return ret  # pyright: ignore[reportReturnType]
-
diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py
index 98b1cd4..9d5b2a9 100644
--- a/stencils/generate_stencils.py
+++ b/stencils/generate_stencils.py
@@ -116,16 +116,6 @@ def get_op_code_float(op: str, type1: str, type2: str) -> str:
     """
 
 
-@norm_indent
-def get_pow(type1: str, type2: str) -> str:
-    return f"""
-    {stencil_func_prefix}void pow_{type1}_{type2}({type1} arg1, {type2} arg2) {{
-        STENCIL_START(pow_{type1}_{type2});
-        result_float_{type2}(fast_pow_float((float)arg1, (float)arg2), arg2);
-    }}
-    """
-
-
 @norm_indent
 def get_floordiv(op: str, type1: str, type2: str) -> str:
     if type1 == 'int' and type2 == 'int':
@@ -215,7 +205,7 @@ if __name__ == "__main__":
 
     # Scalar arithmetic:
     types = ['int', 'float']
-    ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow']
+    ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow', 'atan2']
     int_ops = ['bwand', 'bwor', 'bwxor', 'lshift', 'rshift']
 
     for t1 in types:
@@ -230,7 +220,7 @@ if __name__ == "__main__":
         t_out = 'int' if t1 == 'float' else 'float'
         code += get_cast(t1, t2, t_out)
 
-    fnames = ['sqrt', 'sin', 'cos', 'tan', 'get_42']
+    fnames = ['sqrt', 'exp', 'sin', 'cos', 'tan', 'asin', 'atan', 'get_42']
     for fn, t1 in permutate(fnames, types):
         code += get_func2(fn, t1, t1)
 
@@ -240,8 +230,8 @@ if __name__ == "__main__":
             code += get_floordiv('floordiv', t1, t2)
         elif op == 'div':
             code += get_op_code_float(op, t1, t2)
-        elif op == 'pow':
-            code += get_pow(t1, t2)
+        elif op in {'pow', 'atan2'}:
+            code += get_func2(op, t1, t2)
         elif op in {'gt', 'eq', 'ge', 'ne'}:
             code += get_op_code(op, t1, t2, 'int')
         else:
diff --git a/stencils/trigonometry.c b/stencils/trigonometry.c
index 4d390f4..fe027a3 100644
--- a/stencils/trigonometry.c
+++ b/stencils/trigonometry.c
@@ -145,4 +145,62 @@ NOINLINE float aux_tan(float x) {
     if (t < -1e8f) t = -1e8f;
 
     return sign * t;
+}
+
+NOINLINE float aux_atan(float x) {
+    const float absx = x < 0 ? -x : x;
+
+    // Coefficients for a rational minimax fit on [0,1]
+    const float a0 = 0.9998660f;
+    const float a1 = -0.3302995f;
+    const float b1 = 0.1801410f;
+    const float b2 = -0.0126492f;
+
+    float y;
+    if (absx <= 1.0f) {
+        float x2 = x * x;
+        y = absx * (a0 + a1 * x2) / (1.0f + b1 * x2 + b2 * x2 * x2);
+    } else {
+        float inv = 1.0f / absx;
+        float x2 = inv * inv;
+        float core = inv * (a0 + a1 * x2) / (1.0f + b1 * x2 + b2 * x2 * x2);
+        y = PI_2 - core;
+    }
+
+    return x < 0 ? -y : y;
+}
+
+NOINLINE float aux_atan2(float y, float x) {
+    if (x == 0.0f) {
+        if (y > 0.0f) return PI_2;
+        if (y < 0.0f) return -PI_2;
+        return 0.0f; // TODO: undefined
+    }
+
+    float abs_y = y < 0 ? -y : y;
+    float abs_x = x < 0 ? -x : x;
+    float angle;
+
+    if (abs_x > abs_y) {
+        angle = aux_atan(y / x);
+        // Quadrant correction (only needed for the atan(y/x) form)
+        if (x < 0) angle = (y >= 0) ? angle + PI : angle - PI;
+    } else {
+        angle = (y < 0 ? -PI_2 : PI_2) - aux_atan(x / y);
+    }
+    return angle;
+}
+
+NOINLINE float aux_asin(float x) {
+    const float PI_2 = 1.57079632679489661923f;
+    if (x > 1.0f) x = 1.0f;
+    if (x < -1.0f) x = -1.0f;
+
+    const float c3 = 0.16666667f; // ≈ 1/6
+    const float c5 = 0.07500000f; // ≈ 3/40
+    const float c7 = 0.04464286f; // ≈ 5/112
+
+    float x2 = x * x;
+    float p = x + x * x2 * (c3 + x2 * (c5 + c7 * x2));
+    return p;
 }
\ No newline at end of file