cp.sign and cp.relu added to _math.py

This commit is contained in:
Nicolas Kruse 2025-12-02 16:57:06 +01:00
parent 5bdd77db91
commit d2df1dd3fb
2 changed files with 36 additions and 2 deletions

View File

@ -2,7 +2,7 @@ from ._target import Target
from ._basic_types import NumLike, variable, generic_sdb, iif
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
from ._matrices import matrix, identity, zeros, ones, diagonal
from ._math import sqrt, abs, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
__all__ = [
"Target",
@ -19,6 +19,7 @@ __all__ = [
"sqrt",
"abs",
"sin",
"sign",
"cos",
"tan",
"asin",
@ -32,6 +33,7 @@ __all__ = [
"clamp",
"min",
"max",
"relu",
"distance",
"scalar_projection",
"angle_between",

View File

@ -278,6 +278,8 @@ def get_42(x: NumLike) -> variable[float] | float:
return add_op('get_42', [x, x])
return float((int(x) * 3.0 + 42.0) * 5.0 + 21.0)
#TODO: Add vector support
@overload
def abs(x: U) -> U: ...
@overload
@ -292,7 +294,26 @@ def abs(x: U | variable[U]) -> Any:
Absolute value of x
"""
ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType]
return ret # REMpyright: ignore[reportReturnType]
#TODO: Add vector support
@overload
def sign(x: U) -> U: ...
@overload
def sign(x: variable[U]) -> variable[U]: ...
def sign(x: U | variable[U]) -> Any:
"""Return 1 for positive numbers and -1 for negative numbers.
For an input of 0 the return value is 0.
Arguments:
x: Input value
Returns:
-1, 0 or 1
"""
ret = (x > 0) - (x < 0)
return ret
@overload
@ -381,6 +402,17 @@ def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U
return v1 * (1 - t) + v2 * t
#TODO: Add vector support
@overload
def relu(x: U) -> U: ...
@overload
def relu(x: variable[U]) -> variable[U]: ...
def relu(x: U | variable[U]) -> Any:
"""Returns x for x > 0 and otherwise 0."""
ret = (x > 0) * x
return ret
def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], variable[U] | U]) -> vector[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector) and isinstance(other, vector):