inverse trig functions and log, exp and pow function added

This commit is contained in:
Nicolas Kruse 2025-11-05 21:46:53 +01:00 committed by Nicolas Kruse
parent e0c4bd5280
commit d1935a34f8
4 changed files with 197 additions and 18 deletions

View File

@ -1,7 +1,7 @@
from ._target import Target from ._target import Target
from ._basic_types import NumLike, variable, generic_sdb, iif from ._basic_types import NumLike, variable, generic_sdb, iif
from ._vectors import vector from ._vectors import vector
from ._math import sqrt, abs, sin, cos, tan from ._math import sqrt, abs, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow
__all__ = [ __all__ = [
"Target", "Target",
@ -14,5 +14,12 @@ __all__ = [
"abs", "abs",
"sin", "sin",
"cos", "cos",
"tan" "tan",
"asin",
"acos",
"atan",
"atan2",
"log",
"exp",
"pow"
] ]

View File

@ -6,6 +6,58 @@ import math
T = TypeVar("T", int, float, variable[int], variable[float]) T = TypeVar("T", int, float, variable[int], variable[float])
@overload
def exp(x: float | int) -> float: ...
@overload
def exp(x: variable[Any]) -> variable[float]: ...
def exp(x: NumLike) -> variable[float] | float:
"""Exponential function to basis e
Arguments:
x: Input value
Returns:
result of e**x
"""
if isinstance(x, variable):
return add_op('exp', [x, x]) # TODO: fix 2. dummy argument
return float(math.exp(x))
@overload
def log(x: float | int) -> float: ...
@overload
def log(x: variable[Any]) -> variable[float]: ...
def log(x: NumLike) -> variable[float] | float:
"""Logarithm to basis e
Arguments:
x: Input value
Returns:
result of ln(x)
"""
if isinstance(x, variable):
return add_op('log', [x, x]) # TODO: fix 2. dummy argument
return float(math.log(x))
@overload
def pow(x: float | int, y: float | int) -> float: ...
@overload
def pow(x: variable[Any], y: variable[Any]) -> variable[float]: ...
def pow(x: NumLike, y: NumLike) -> variable[float] | float:
"""x to the power of y
Arguments:
x: Input value
Returns:
result of x**y
"""
return exp(y * log(x))
@overload @overload
def sqrt(x: float | int) -> float: ... def sqrt(x: float | int) -> float: ...
@overload @overload
@ -21,7 +73,7 @@ def sqrt(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5) return float(math.sqrt(x))
@overload @overload
@ -41,6 +93,7 @@ def sin(x: NumLike) -> variable[float] | float:
return add_op('sin', [x, x]) # TODO: fix 2. dummy argument return add_op('sin', [x, x]) # TODO: fix 2. dummy argument
return math.sin(x) return math.sin(x)
@overload @overload
def cos(x: float | int) -> float: ... def cos(x: float | int) -> float: ...
@overload @overload
@ -58,6 +111,7 @@ def cos(x: NumLike) -> variable[float] | float:
return add_op('cos', [x, x]) # TODO: fix 2. dummy argument return add_op('cos', [x, x]) # TODO: fix 2. dummy argument
return math.cos(x) return math.cos(x)
@overload @overload
def tan(x: float | int) -> float: ... def tan(x: float | int) -> float: ...
@overload @overload
@ -76,6 +130,77 @@ def tan(x: NumLike) -> variable[float] | float:
return math.tan(x) return math.tan(x)
@overload
def atan(x: float | int) -> float: ...
@overload
def atan(x: variable[Any]) -> variable[float]: ...
def atan(x: NumLike) -> variable[float] | float:
"""Inverse tangent function
Arguments:
x: Input value
Returns:
Inverse tangent of x
"""
if isinstance(x, variable):
return add_op('atan', [x, x]) # TODO: fix 2. dummy argument
return math.atan(x)
@overload
def atan2(x: float | int, y: float | int) -> float: ...
@overload
def atan2(x: variable[Any], y: variable[Any]) -> variable[float]: ...
def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
"""2-argument arctangent
Arguments:
x: Input value
y: Input value
Returns:
Result in radian
"""
if isinstance(x, variable) or isinstance(y, variable):
return add_op('atan', [x, x]) # TODO: fix 2. dummy argument
return math.atan2(x, y)
@overload
def asin(x: float | int) -> float: ...
@overload
def asin(x: variable[Any]) -> variable[float]: ...
def asin(x: NumLike) -> variable[float] | float:
"""Inverse sine function
Arguments:
x: Input value
Returns:
Inverse sine of x
"""
if isinstance(x, variable):
return add_op('asin', [x, x]) # TODO: fix 2. dummy argument
return math.asin(x)
@overload
def acos(x: float | int) -> float: ...
@overload
def acos(x: variable[Any]) -> variable[float]: ...
def acos(x: NumLike) -> variable[float] | float:
"""Inverse cosine function
Arguments:
x: Input value
Returns:
Inverse cosine of x
"""
return 2 * math.pi - asin(x)
def get_42() -> variable[float]: def get_42() -> variable[float]:
"""Returns the variable representing the constant 42""" """Returns the variable representing the constant 42"""
return add_op('get_42', [0.0, 0.0]) return add_op('get_42', [0.0, 0.0])
@ -92,4 +217,3 @@ def abs(x: T) -> T:
""" """
ret = (x < 0) * -x + (x >= 0) * x ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType] return ret # pyright: ignore[reportReturnType]

View File

@ -116,16 +116,6 @@ def get_op_code_float(op: str, type1: str, type2: str) -> str:
""" """
@norm_indent
def get_pow(type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void pow_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(pow_{type1}_{type2});
result_float_{type2}(fast_pow_float((float)arg1, (float)arg2), arg2);
}}
"""
@norm_indent @norm_indent
def get_floordiv(op: str, type1: str, type2: str) -> str: def get_floordiv(op: str, type1: str, type2: str) -> str:
if type1 == 'int' and type2 == 'int': if type1 == 'int' and type2 == 'int':
@ -215,7 +205,7 @@ if __name__ == "__main__":
# Scalar arithmetic: # Scalar arithmetic:
types = ['int', 'float'] types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow'] ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow', 'atan2']
int_ops = ['bwand', 'bwor', 'bwxor', 'lshift', 'rshift'] int_ops = ['bwand', 'bwor', 'bwxor', 'lshift', 'rshift']
for t1 in types: for t1 in types:
@ -230,7 +220,7 @@ if __name__ == "__main__":
t_out = 'int' if t1 == 'float' else 'float' t_out = 'int' if t1 == 'float' else 'float'
code += get_cast(t1, t2, t_out) code += get_cast(t1, t2, t_out)
fnames = ['sqrt', 'sin', 'cos', 'tan', 'get_42'] fnames = ['sqrt', 'exp', 'sin', 'cos', 'tan', 'asin', 'atan', 'get_42']
for fn, t1 in permutate(fnames, types): for fn, t1 in permutate(fnames, types):
code += get_func2(fn, t1, t1) code += get_func2(fn, t1, t1)
@ -240,8 +230,8 @@ if __name__ == "__main__":
code += get_floordiv('floordiv', t1, t2) code += get_floordiv('floordiv', t1, t2)
elif op == 'div': elif op == 'div':
code += get_op_code_float(op, t1, t2) code += get_op_code_float(op, t1, t2)
elif op == 'pow': elif op in {'pow', 'atan2'}:
code += get_pow(t1, t2) code += get_func2(op, t1, t2)
elif op in {'gt', 'eq', 'ge', 'ne'}: elif op in {'gt', 'eq', 'ge', 'ne'}:
code += get_op_code(op, t1, t2, 'int') code += get_op_code(op, t1, t2, 'int')
else: else:

View File

@ -146,3 +146,61 @@ NOINLINE float aux_tan(float x) {
return sign * t; return sign * t;
} }
NOINLINE float aux_atan(float x) {
const float absx = x < 0 ? -x : x;
// Coefficients for a rational minimax fit on [0,1]
const float a0 = 0.9998660f;
const float a1 = -0.3302995f;
const float b1 = 0.1801410f;
const float b2 = -0.0126492f;
float y;
if (absx <= 1.0f) {
float x2 = x * x;
y = x * (a0 + a1 * x2) / (1.0f + b1 * x2 + b2 * x2 * x2);
} else {
float inv = 1.0f / absx;
float x2 = inv * inv;
float core = inv * (a0 + a1 * x2) / (1.0f + b1 * x2 + b2 * x2 * x2);
y = PI_2 - core;
}
return x < 0 ? -y : y;
}
NOINLINE float aux_atan2(float y, float x) {
if (x == 0.0f) {
if (y > 0.0f) return PI_2;
if (y < 0.0f) return -PI_2;
return 0.0f; // TODO: undefined
}
float abs_y = y < 0 ? -y : y;
float abs_x = x < 0 ? -x : x;
float angle;
if (abs_x > abs_y)
angle = fast_atanf_fp32_nolib(y / x);
else
angle = PI_2 - fast_atanf_fp32_nolib(x / y);
// Quadrant correction
if (x < 0) angle = (y >= 0) ? angle + PI : angle - PI;
return angle;
}
NOINLINE float aux_asin(float x) {
const float PI_2 = 1.57079632679489661923f;
if (x > 1.0f) x = 1.0f;
if (x < -1.0f) x = -1.0f;
const float c3 = 0.16666667f; // ≈ 1/6
const float c5 = 0.07500000f; // ≈ 3/40
const float c7 = 0.04464286f; // ≈ 5/112
float x2 = x * x;
float p = x + x * x2 * (c3 + x2 * (c5 + c7 * x2));
return p;
}