Merge pull request #9 from Nonannet/dev

Element wise math operations for vectors
This commit is contained in:
Nicolas Kruse 2025-11-17 21:21:21 +01:00 committed by GitHub
commit b279da800a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 42 deletions

View File

@ -9,7 +9,7 @@ uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool' unibool: TypeAlias = 'variable[bool] | bool'
TCPNum = TypeVar("TCPNum", bound='variable[Any]') TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, bool, float) TNum = TypeVar("TNum", int, float, bool)
stencil_cache: dict[tuple[str, str], stencil_database] = {} stencil_cache: dict[tuple[str, str], stencil_database] = {}

View File

@ -1,16 +1,20 @@
from . import vector
from ._vectors import VecNumLike
from . import variable, NumLike from . import variable, NumLike
from typing import TypeVar, Any, overload from typing import TypeVar, Any, overload, Callable
from ._basic_types import add_op from ._basic_types import add_op
import math import math
T = TypeVar("T", int, float, variable[int], variable[float]) T = TypeVar("T", int, float, variable[int], variable[float])
U = TypeVar("U", int, float)
@overload @overload
def exp(x: float | int) -> float: ... def exp(x: float | int) -> float: ...
@overload @overload
def exp(x: variable[Any]) -> variable[float]: ... def exp(x: variable[Any]) -> variable[float]: ...
def exp(x: NumLike) -> variable[float] | float: @overload
def exp(x: vector[Any]) -> vector[float]: ...
def exp(x: Any) -> Any:
"""Exponential function to basis e """Exponential function to basis e
Arguments: Arguments:
@ -21,6 +25,8 @@ def exp(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('exp', [x]) return add_op('exp', [x])
if isinstance(x, vector):
return x.map(exp)
return float(math.exp(x)) return float(math.exp(x))
@ -28,7 +34,9 @@ def exp(x: NumLike) -> variable[float] | float:
def log(x: float | int) -> float: ... def log(x: float | int) -> float: ...
@overload @overload
def log(x: variable[Any]) -> variable[float]: ... def log(x: variable[Any]) -> variable[float]: ...
def log(x: NumLike) -> variable[float] | float: @overload
def log(x: vector[Any]) -> vector[float]: ...
def log(x: Any) -> Any:
"""Logarithm to basis e """Logarithm to basis e
Arguments: Arguments:
@ -39,6 +47,8 @@ def log(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('log', [x]) return add_op('log', [x])
if isinstance(x, vector):
return x.map(log)
return float(math.log(x)) return float(math.log(x))
@ -48,7 +58,9 @@ def pow(x: float | int, y: float | int) -> float: ...
def pow(x: variable[Any], y: NumLike) -> variable[float]: ... def pow(x: variable[Any], y: NumLike) -> variable[float]: ...
@overload @overload
def pow(x: NumLike, y: variable[Any]) -> variable[float]: ... def pow(x: NumLike, y: variable[Any]) -> variable[float]: ...
def pow(x: NumLike, y: NumLike) -> NumLike: @overload
def pow(x: vector[Any], y: Any) -> vector[float]: ...
def pow(x: VecNumLike, y: VecNumLike) -> Any:
"""x to the power of y """x to the power of y
Arguments: Arguments:
@ -57,6 +69,8 @@ def pow(x: NumLike, y: NumLike) -> NumLike:
Returns: Returns:
result of x**y result of x**y
""" """
if isinstance(x, vector) or isinstance(y, vector):
return map2(x, y, pow)
if isinstance(y, int) and 0 <= y < 8: if isinstance(y, int) and 0 <= y < 8:
if y == 0: if y == 0:
return 1 return 1
@ -76,7 +90,9 @@ def pow(x: NumLike, y: NumLike) -> NumLike:
def sqrt(x: float | int) -> float: ... def sqrt(x: float | int) -> float: ...
@overload @overload
def sqrt(x: variable[Any]) -> variable[float]: ... def sqrt(x: variable[Any]) -> variable[float]: ...
def sqrt(x: NumLike) -> variable[float] | float: @overload
def sqrt(x: vector[Any]) -> vector[float]: ...
def sqrt(x: Any) -> Any:
"""Square root function """Square root function
Arguments: Arguments:
@ -87,6 +103,8 @@ def sqrt(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sqrt', [x]) return add_op('sqrt', [x])
if isinstance(x, vector):
return x.map(sqrt)
return float(math.sqrt(x)) return float(math.sqrt(x))
@ -94,7 +112,9 @@ def sqrt(x: NumLike) -> variable[float] | float:
def sin(x: float | int) -> float: ... def sin(x: float | int) -> float: ...
@overload @overload
def sin(x: variable[Any]) -> variable[float]: ... def sin(x: variable[Any]) -> variable[float]: ...
def sin(x: NumLike) -> variable[float] | float: @overload
def sin(x: vector[Any]) -> vector[float]: ...
def sin(x: Any) -> Any:
"""Sine function """Sine function
Arguments: Arguments:
@ -105,6 +125,8 @@ def sin(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sin', [x]) return add_op('sin', [x])
if isinstance(x, vector):
return x.map(sin)
return math.sin(x) return math.sin(x)
@ -112,7 +134,9 @@ def sin(x: NumLike) -> variable[float] | float:
def cos(x: float | int) -> float: ... def cos(x: float | int) -> float: ...
@overload @overload
def cos(x: variable[Any]) -> variable[float]: ... def cos(x: variable[Any]) -> variable[float]: ...
def cos(x: NumLike) -> variable[float] | float: @overload
def cos(x: vector[Any]) -> vector[float]: ...
def cos(x: Any) -> Any:
"""Cosine function """Cosine function
Arguments: Arguments:
@ -123,6 +147,8 @@ def cos(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('cos', [x]) return add_op('cos', [x])
if isinstance(x, vector):
return x.map(cos)
return math.cos(x) return math.cos(x)
@ -130,7 +156,9 @@ def cos(x: NumLike) -> variable[float] | float:
def tan(x: float | int) -> float: ... def tan(x: float | int) -> float: ...
@overload @overload
def tan(x: variable[Any]) -> variable[float]: ... def tan(x: variable[Any]) -> variable[float]: ...
def tan(x: NumLike) -> variable[float] | float: @overload
def tan(x: vector[Any]) -> vector[float]: ...
def tan(x: Any) -> Any:
"""Tangent function """Tangent function
Arguments: Arguments:
@ -141,6 +169,9 @@ def tan(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('tan', [x]) return add_op('tan', [x])
if isinstance(x, vector):
#return x.map(tan)
return x.map(tan)
return math.tan(x) return math.tan(x)
@ -148,7 +179,9 @@ def tan(x: NumLike) -> variable[float] | float:
def atan(x: float | int) -> float: ... def atan(x: float | int) -> float: ...
@overload @overload
def atan(x: variable[Any]) -> variable[float]: ... def atan(x: variable[Any]) -> variable[float]: ...
def atan(x: NumLike) -> variable[float] | float: @overload
def atan(x: vector[Any]) -> vector[float]: ...
def atan(x: Any) -> Any:
"""Inverse tangent function """Inverse tangent function
Arguments: Arguments:
@ -159,14 +192,22 @@ def atan(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('atan', [x]) return add_op('atan', [x])
if isinstance(x, vector):
return x.map(atan)
return math.atan(x) return math.atan(x)
@overload @overload
def atan2(x: float | int, y: float | int) -> float: ... def atan2(x: float | int, y: float | int) -> float: ...
@overload @overload
def atan2(x: variable[Any], y: variable[Any]) -> variable[float]: ... def atan2(x: variable[Any], y: NumLike) -> variable[float]: ...
def atan2(x: NumLike, y: NumLike) -> variable[float] | float: @overload
def atan2(x: NumLike, y: variable[Any]) -> variable[float]: ...
@overload
def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ...
@overload
def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ...
def atan2(x: VecNumLike, y: VecNumLike) -> Any:
"""2-argument arctangent """2-argument arctangent
Arguments: Arguments:
@ -176,6 +217,8 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
Returns: Returns:
Result in radian Result in radian
""" """
if isinstance(x, vector) or isinstance(y, vector):
return map2(x, y, atan2)
if isinstance(x, variable) or isinstance(y, variable): if isinstance(x, variable) or isinstance(y, variable):
return add_op('atan2', [x, y]) return add_op('atan2', [x, y])
return math.atan2(x, y) return math.atan2(x, y)
@ -185,7 +228,9 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
def asin(x: float | int) -> float: ... def asin(x: float | int) -> float: ...
@overload @overload
def asin(x: variable[Any]) -> variable[float]: ... def asin(x: variable[Any]) -> variable[float]: ...
def asin(x: NumLike) -> variable[float] | float: @overload
def asin(x: vector[Any]) -> vector[float]: ...
def asin(x: Any) -> Any:
"""Inverse sine function """Inverse sine function
Arguments: Arguments:
@ -196,6 +241,8 @@ def asin(x: NumLike) -> variable[float] | float:
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('asin', [x]) return add_op('asin', [x])
if isinstance(x, vector):
return x.map(asin)
return math.asin(x) return math.asin(x)
@ -203,7 +250,9 @@ def asin(x: NumLike) -> variable[float] | float:
def acos(x: float | int) -> float: ... def acos(x: float | int) -> float: ...
@overload @overload
def acos(x: variable[Any]) -> variable[float]: ... def acos(x: variable[Any]) -> variable[float]: ...
def acos(x: NumLike) -> variable[float] | float: @overload
def acos(x: vector[Any]) -> vector[float]: ...
def acos(x: Any) -> Any:
"""Inverse cosine function """Inverse cosine function
Arguments: Arguments:
@ -212,7 +261,11 @@ def acos(x: NumLike) -> variable[float] | float:
Returns: Returns:
Inverse cosine of x Inverse cosine of x
""" """
return math.pi / 2 - asin(x) if isinstance(x, variable):
return add_op('acos', [x])
if isinstance(x, vector):
return x.map(acos)
return math.asin(x)
@overload @overload
@ -237,3 +290,15 @@ def abs(x: T) -> T:
""" """
ret = (x < 0) * -x + (x >= 0) * x ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType] return ret # pyright: ignore[reportReturnType]
def map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], variable[U] | U]) -> vector[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector) and isinstance(other, vector):
return vector(func(x, y) for x, y in zip(self.values, other.values))
elif isinstance(self, vector):
return vector(func(x, other) for x in self.values)
elif isinstance(other, vector):
return vector(func(self, x) for x in other.values)
else:
return vector([func(self, other)])

View File

@ -1,11 +1,12 @@
from . import variable from . import variable
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable
from ._math import sqrt import copapy as cp
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float' VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | variable[bool] | int | float | bool'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int' VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float' VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float) T = TypeVar("T", int, float)
U = TypeVar("U", int, float)
epsilon = 1e-20 epsilon = 1e-20
@ -155,7 +156,7 @@ class vector(Generic[T]):
def magnitude(self) -> 'float | variable[float]': def magnitude(self) -> 'float | variable[float]':
"""Magnitude (length) of the vector.""" """Magnitude (length) of the vector."""
s = sum(a * a for a in self.values) s = sum(a * a for a in self.values)
return sqrt(s) if isinstance(s, variable) else sqrt(s) return cp.sqrt(s) if isinstance(s, variable) else cp.sqrt(s)
def normalize(self) -> 'vector[float]': def normalize(self) -> 'vector[float]':
"""Returns a normalized (unit length) version of the vector.""" """Returns a normalized (unit length) version of the vector."""
@ -164,3 +165,7 @@ class vector(Generic[T]):
def __iter__(self) -> Iterable[variable[T] | T]: def __iter__(self) -> Iterable[variable[T] | T]:
return iter(self.values) return iter(self.values)
def map(self, func: Callable[[Any], variable[U] | U]) -> 'vector[U]':
"""Applies a function to each element of the vector and returns a new vector."""
return vector(func(x) for x in self.values)

View File

@ -240,7 +240,7 @@ if __name__ == "__main__":
for fn, t1 in permutate(fnames, types): for fn, t1 in permutate(fnames, types):
code += get_func1(fn, t1, t1) code += get_func1(fn, t1, t1)
fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan'] fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan']
for fn, t1 in permutate(fnames, types): for fn, t1 in permutate(fnames, types):
code += get_math_func1(fn, t1) code += get_math_func1(fn, t1)

View File

@ -110,27 +110,6 @@ def test_arcus_trig_precision():
warnings.warn(f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}", UserWarning) warnings.warn(f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}", UserWarning)
def test_arcus_trig_crash():
v = 0.0
ret_test = [cp.asin(variable(v))]
ret_refe = [ma.asin(v)]
tg = Target()
tg.compile(ret_test)
tg.run()
for i, (test, ref) in enumerate(zip(ret_test, ret_refe)):
func_name = ['asin', 'acos', 'atan', 'atan2[1]', 'atan2[2]'][i % 5]
assert isinstance(test, cp.variable)
val = tg.read_value(test)
print(f"+ Result of {func_name}: {val}; reference: {ref}")
#assert val == pytest.approx(ref, abs=1e-5), f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if not val == pytest.approx(ref, abs=1e-5): # pyright: ignore[reportUnknownMemberType]
warnings.warn(f"Result of {func_name} for input {test_vals[i // 5]} does not match: {val} and reference: {ref}", UserWarning)
def test_sqrt_precision(): def test_sqrt_precision():
test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0] test_vals = [0.0, 0.0001, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.28318530718, 100.0, 1000.0, 100000.0]