Merge pull request #27 from Nonannet/dev

Dev
This commit is contained in:
Nicolas Kruse 2026-01-05 14:13:04 +01:00 committed by GitHub
commit 8d77ee3a25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 203 additions and 52 deletions

View File

@ -40,7 +40,7 @@ from ._tensors import tensor, zeros, ones, arange, eye, identity, diagonal
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
from ._autograd import grad from ._autograd import grad
from ._tensors import tensor as matrix from ._tensors import tensor as matrix
from ._version import __version__ from ._version import __version__ # Run "pip install -e ." to generate _version.py
__all__ = [ __all__ = [

View File

@ -1,5 +1,7 @@
from . import vector from . import vector
from . import tensor
from ._vectors import VecNumLike from ._vectors import VecNumLike
from ._tensors import TensorNumLike
from . import value, NumLike from . import value, NumLike
from typing import TypeVar, Any, overload, Callable from typing import TypeVar, Any, overload, Callable
from ._basic_types import add_op, unifloat from ._basic_types import add_op, unifloat
@ -15,6 +17,8 @@ def exp(x: float | int) -> float: ...
def exp(x: value[Any]) -> value[float]: ... def exp(x: value[Any]) -> value[float]: ...
@overload @overload
def exp(x: vector[Any]) -> vector[float]: ... def exp(x: vector[Any]) -> vector[float]: ...
@overload
def exp(x: tensor[Any]) -> tensor[float]: ...
def exp(x: Any) -> Any: def exp(x: Any) -> Any:
"""Exponential function to basis e """Exponential function to basis e
@ -26,7 +30,7 @@ def exp(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('exp', [x]) return add_op('exp', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(exp) return x.map(exp)
return float(math.exp(x)) return float(math.exp(x))
@ -37,6 +41,8 @@ def log(x: float | int) -> float: ...
def log(x: value[Any]) -> value[float]: ... def log(x: value[Any]) -> value[float]: ...
@overload @overload
def log(x: vector[Any]) -> vector[float]: ... def log(x: vector[Any]) -> vector[float]: ...
@overload
def log(x: tensor[Any]) -> tensor[float]: ...
def log(x: Any) -> Any: def log(x: Any) -> Any:
"""Logarithm to basis e """Logarithm to basis e
@ -48,7 +54,7 @@ def log(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('log', [x]) return add_op('log', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(log) return x.map(log)
return float(math.log(x)) return float(math.log(x))
@ -61,7 +67,13 @@ def pow(x: value[Any], y: NumLike) -> value[float]: ...
def pow(x: NumLike, y: value[Any]) -> value[float]: ... def pow(x: NumLike, y: value[Any]) -> value[float]: ...
@overload @overload
def pow(x: vector[Any], y: Any) -> vector[float]: ... def pow(x: vector[Any], y: Any) -> vector[float]: ...
def pow(x: VecNumLike, y: VecNumLike) -> Any: @overload
def pow(x: Any, y: vector[Any]) -> vector[float]: ...
@overload
def pow(x: tensor[Any], y: Any) -> tensor[float]: ...
@overload
def pow(x: Any, y: tensor[Any]) -> tensor[float]: ...
def pow(x: TensorNumLike, y: TensorNumLike) -> Any:
"""x to the power of y """x to the power of y
Arguments: Arguments:
@ -70,8 +82,10 @@ def pow(x: VecNumLike, y: VecNumLike) -> Any:
Returns: Returns:
result of x**y result of x**y
""" """
if isinstance(x, tensor) or isinstance(y, tensor):
return _map2_tensor(x, y, pow)
if isinstance(x, vector) or isinstance(y, vector): if isinstance(x, vector) or isinstance(y, vector):
return _map2(x, y, pow) return _map2_vector(x, y, pow)
if isinstance(y, int) and 0 <= y < 8: if isinstance(y, int) and 0 <= y < 8:
if y == 0: if y == 0:
return 1 return 1
@ -93,6 +107,8 @@ def sqrt(x: float | int) -> float: ...
def sqrt(x: value[Any]) -> value[float]: ... def sqrt(x: value[Any]) -> value[float]: ...
@overload @overload
def sqrt(x: vector[Any]) -> vector[float]: ... def sqrt(x: vector[Any]) -> vector[float]: ...
@overload
def sqrt(x: tensor[Any]) -> tensor[float]: ...
def sqrt(x: Any) -> Any: def sqrt(x: Any) -> Any:
"""Square root function """Square root function
@ -104,7 +120,7 @@ def sqrt(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('sqrt', [x]) return add_op('sqrt', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(sqrt) return x.map(sqrt)
return float(math.sqrt(x)) return float(math.sqrt(x))
@ -115,6 +131,8 @@ def sin(x: float | int) -> float: ...
def sin(x: value[Any]) -> value[float]: ... def sin(x: value[Any]) -> value[float]: ...
@overload @overload
def sin(x: vector[Any]) -> vector[float]: ... def sin(x: vector[Any]) -> vector[float]: ...
@overload
def sin(x: tensor[Any]) -> tensor[float]: ...
def sin(x: Any) -> Any: def sin(x: Any) -> Any:
"""Sine function """Sine function
@ -126,7 +144,7 @@ def sin(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('sin', [x]) return add_op('sin', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(sin) return x.map(sin)
return math.sin(x) return math.sin(x)
@ -137,6 +155,8 @@ def cos(x: float | int) -> float: ...
def cos(x: value[Any]) -> value[float]: ... def cos(x: value[Any]) -> value[float]: ...
@overload @overload
def cos(x: vector[Any]) -> vector[float]: ... def cos(x: vector[Any]) -> vector[float]: ...
@overload
def cos(x: tensor[Any]) -> tensor[float]: ...
def cos(x: Any) -> Any: def cos(x: Any) -> Any:
"""Cosine function """Cosine function
@ -148,7 +168,7 @@ def cos(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('cos', [x]) return add_op('cos', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(cos) return x.map(cos)
return math.cos(x) return math.cos(x)
@ -159,6 +179,8 @@ def tan(x: float | int) -> float: ...
def tan(x: value[Any]) -> value[float]: ... def tan(x: value[Any]) -> value[float]: ...
@overload @overload
def tan(x: vector[Any]) -> vector[float]: ... def tan(x: vector[Any]) -> vector[float]: ...
@overload
def tan(x: tensor[Any]) -> tensor[float]: ...
def tan(x: Any) -> Any: def tan(x: Any) -> Any:
"""Tangent function """Tangent function
@ -170,8 +192,7 @@ def tan(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('tan', [x]) return add_op('tan', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
#return x.map(tan)
return x.map(tan) return x.map(tan)
return math.tan(x) return math.tan(x)
@ -182,6 +203,8 @@ def atan(x: float | int) -> float: ...
def atan(x: value[Any]) -> value[float]: ... def atan(x: value[Any]) -> value[float]: ...
@overload @overload
def atan(x: vector[Any]) -> vector[float]: ... def atan(x: vector[Any]) -> vector[float]: ...
@overload
def atan(x: tensor[Any]) -> tensor[float]: ...
def atan(x: Any) -> Any: def atan(x: Any) -> Any:
"""Inverse tangent function """Inverse tangent function
@ -193,7 +216,7 @@ def atan(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('atan', [x]) return add_op('atan', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(atan) return x.map(atan)
return math.atan(x) return math.atan(x)
@ -208,7 +231,11 @@ def atan2(x: NumLike, y: value[Any]) -> value[float]: ...
def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ... def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ...
@overload @overload
def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ... def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ...
def atan2(x: VecNumLike, y: VecNumLike) -> Any: @overload
def atan2(x: tensor[float], y: TensorNumLike) -> tensor[float]: ...
@overload
def atan2(x: TensorNumLike, y: tensor[float]) -> tensor[float]: ...
def atan2(x: TensorNumLike, y: TensorNumLike) -> Any:
"""2-argument arctangent """2-argument arctangent
Arguments: Arguments:
@ -218,8 +245,10 @@ def atan2(x: VecNumLike, y: VecNumLike) -> Any:
Returns: Returns:
Result in radian Result in radian
""" """
if isinstance(x, tensor) or isinstance(y, tensor):
return _map2_tensor(x, y, atan2)
if isinstance(x, vector) or isinstance(y, vector): if isinstance(x, vector) or isinstance(y, vector):
return _map2(x, y, atan2) return _map2_vector(x, y, atan2)
if isinstance(x, value) or isinstance(y, value): if isinstance(x, value) or isinstance(y, value):
return add_op('atan2', [x, y]) return add_op('atan2', [x, y])
return math.atan2(x, y) return math.atan2(x, y)
@ -231,6 +260,8 @@ def asin(x: float | int) -> float: ...
def asin(x: value[Any]) -> value[float]: ... def asin(x: value[Any]) -> value[float]: ...
@overload @overload
def asin(x: vector[Any]) -> vector[float]: ... def asin(x: vector[Any]) -> vector[float]: ...
@overload
def asin(x: tensor[Any]) -> tensor[float]: ...
def asin(x: Any) -> Any: def asin(x: Any) -> Any:
"""Inverse sine function """Inverse sine function
@ -242,7 +273,7 @@ def asin(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('asin', [x]) return add_op('asin', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(asin) return x.map(asin)
return math.asin(x) return math.asin(x)
@ -253,6 +284,8 @@ def acos(x: float | int) -> float: ...
def acos(x: value[Any]) -> value[float]: ... def acos(x: value[Any]) -> value[float]: ...
@overload @overload
def acos(x: vector[Any]) -> vector[float]: ... def acos(x: vector[Any]) -> vector[float]: ...
@overload
def acos(x: tensor[Any]) -> tensor[float]: ...
def acos(x: Any) -> Any: def acos(x: Any) -> Any:
"""Inverse cosine function """Inverse cosine function
@ -264,11 +297,12 @@ def acos(x: Any) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('acos', [x]) return add_op('acos', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(acos) return x.map(acos)
return math.asin(x) return math.asin(x)
# Debug test function
@overload @overload
def get_42(x: float | int) -> float: ... def get_42(x: float | int) -> float: ...
@overload @overload
@ -286,7 +320,9 @@ def abs(x: U) -> U: ...
def abs(x: value[U]) -> value[U]: ... def abs(x: value[U]) -> value[U]: ...
@overload @overload
def abs(x: vector[U]) -> vector[U]: ... def abs(x: vector[U]) -> vector[U]: ...
def abs(x: U | value[U] | vector[U]) -> Any: @overload
def abs(x: tensor[U]) -> tensor[U]: ...
def abs(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Absolute value function """Absolute value function
Arguments: Arguments:
@ -297,18 +333,20 @@ def abs(x: U | value[U] | vector[U]) -> Any:
""" """
if isinstance(x, value): if isinstance(x, value):
return add_op('abs', [x]) return add_op('abs', [x])
if isinstance(x, vector): if isinstance(x, vector | tensor):
return x.map(abs) return x.map(abs)
return (x < 0) * -x + (x >= 0) * x return (x < 0) * -x + (x >= 0) * x
@overload @overload
def sign(x: U) -> U: ... def sign(x: U) -> int: ...
@overload @overload
def sign(x: value[U]) -> value[U]: ... def sign(x: value[U]) -> value[int]: ...
@overload @overload
def sign(x: vector[U]) -> vector[U]: ... def sign(x: vector[U]) -> vector[int]: ...
def sign(x: U | value[U] | vector[U]) -> Any: @overload
def sign(x: tensor[U]) -> tensor[int]: ...
def sign(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Return 1 for positive numbers and -1 for negative numbers. """Return 1 for positive numbers and -1 for negative numbers.
For an input of 0 the return value is 0. For an input of 0 the return value is 0.
@ -318,8 +356,11 @@ def sign(x: U | value[U] | vector[U]) -> Any:
Returns: Returns:
-1, 0 or 1 -1, 0 or 1
""" """
ret = (x > 0) - (x < 0) if isinstance(x, value):
return ret return add_op('sign', [x])
if isinstance(x, vector | tensor):
return x.map(sign)
return (x > 0) - (x < 0)
@overload @overload
@ -367,7 +408,13 @@ def min(x: U | value[U], y: U | value[U]) -> Any:
Returns: Returns:
Minimum of x and y Minimum of x and y
""" """
return (x < y) * x + (x >= y) * y if isinstance(x, value):
return add_op('min', [x, y])
if isinstance(x, tensor):
return _map2_tensor(x, y, min)
if isinstance(x, vector):
return _map2_vector(x, y, min)
return x if x < y else y
@overload @overload
@ -386,7 +433,13 @@ def max(x: U | value[U], y: U | value[U]) -> Any:
Returns: Returns:
Maximum of x and y Maximum of x and y
""" """
return (x > y) * x + (x <= y) * y if isinstance(x, value):
return add_op('max', [x, y])
if isinstance(x, tensor):
return _map2_tensor(x, y, max)
if isinstance(x, vector):
return _map2_vector(x, y, max)
return x if x > y else y
@overload @overload
@ -400,7 +453,16 @@ def lerp(v1: U, v2: U, t: float) -> U: ...
@overload @overload
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ... def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
def lerp(v1: U | value[U] | vector[U], v2: U | value[U] | vector[U], t: unifloat) -> Any: def lerp(v1: U | value[U] | vector[U], v2: U | value[U] | vector[U], t: unifloat) -> Any:
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t.""" """Linearly interpolate between two values or vectors v1 and v2 by a factor t.
Arguments:
v1: First value or vector
v2: Second value or vector
t: Interpolation factor (0.0 to 1.0)
Returns:
Interpolated value or vector
"""
if isinstance(v1, vector) or isinstance(v2, vector): if isinstance(v1, vector) or isinstance(v2, vector):
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors." assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
assert len(v1.values) == len(v2.values), "Vectors must be of the same length." assert len(v1.values) == len(v2.values), "Vectors must be of the same length."
@ -414,13 +476,15 @@ def relu(x: U) -> U: ...
def relu(x: value[U]) -> value[U]: ... def relu(x: value[U]) -> value[U]: ...
@overload @overload
def relu(x: vector[U]) -> vector[U]: ... def relu(x: vector[U]) -> vector[U]: ...
def relu(x: U | value[U] | vector[U]) -> Any: @overload
def relu(x: tensor[U]) -> tensor[U]: ...
def relu(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Returns x for x > 0 and otherwise 0.""" """Returns x for x > 0 and otherwise 0."""
ret = (x > 0) * x ret = x * (x > 0)
return ret return ret
def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]: def _map2_vector(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]:
"""Applies a function to each element of the vector and a second vector or scalar.""" """Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector) and isinstance(other, vector): if isinstance(self, vector) and isinstance(other, vector):
return vector(func(x, y) for x, y in zip(self.values, other.values)) return vector(func(x, y) for x, y in zip(self.values, other.values))
@ -430,3 +494,20 @@ def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[
return vector(func(self, x) for x in other.values) return vector(func(self, x) for x in other.values)
else: else:
return vector([func(self, other)]) return vector([func(self, other)])
def _map2_tensor(self: TensorNumLike, other: TensorNumLike, func: Callable[[Any, Any], value[U] | U]) -> tensor[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector):
self = tensor(self.values, (len(self.values),))
if isinstance(other, vector):
other = tensor(other.values, (len(other.values),))
if isinstance(self, tensor) and isinstance(other, tensor):
assert self.shape == other.shape, "Tensors must have the same shape"
return tensor([func(x, y) for x, y in zip(self.values, other.values)], self.shape)
elif isinstance(self, tensor):
return tensor([func(x, other) for x in self.values], self.shape)
elif isinstance(other, tensor):
return tensor([func(self, x) for x in other.values], other.shape)
else:
return tensor(func(self, other))

View File

@ -1,13 +1,13 @@
from copapy._basic_types import NumLike, ArrayType from copapy._basic_types import NumLike, ArrayType
from . import value from . import value
from ._vectors import vector from ._vectors import vector, VecFloatLike, VecIntLike, VecNumLike
from ._mixed import mixed_sum from ._mixed import mixed_sum
from typing import TypeVar, Any, overload, TypeAlias, Callable, Iterator, Sequence from typing import TypeVar, Any, overload, TypeAlias, Callable, Iterator, Sequence
from ._helper_types import TNum from ._helper_types import TNum
TensorNumLike: TypeAlias = 'tensor[Any] | vector[Any] | value[Any] | int | float | bool' TensorNumLike: TypeAlias = 'tensor[Any] | vector[Any] | value[Any] | int | float | bool'
TensorIntLike: TypeAlias = 'tensor[int] | value[int] | int' TensorIntLike: TypeAlias = 'tensor[int] | vector[int] | value[int] | int | bool'
TensorFloatLike: TypeAlias = 'tensor[float] | value[float] | float' TensorFloatLike: TypeAlias = 'tensor[float] | vector[float] | value[float] | float'
TensorSequence: TypeAlias = 'Sequence[TNum | value[TNum]] | Sequence[Sequence[TNum | value[TNum]]] | Sequence[Sequence[Sequence[TNum | value[TNum]]]]' TensorSequence: TypeAlias = 'Sequence[TNum | value[TNum]] | Sequence[Sequence[TNum | value[TNum]]] | Sequence[Sequence[Sequence[TNum | value[TNum]]]]'
U = TypeVar("U", int, float) U = TypeVar("U", int, float)
@ -26,6 +26,7 @@ class tensor(ArrayType[TNum]):
values: Nested iterables of constant values or copapy values. values: Nested iterables of constant values or copapy values.
Can be a scalar, 1D iterable (vector), Can be a scalar, 1D iterable (vector),
or n-dimensional nested structure. or n-dimensional nested structure.
shape: Optional shape of the tensor. If not provided, inferred from values.
""" """
if shape: if shape:
self.shape: tuple[int, ...] = tuple(shape) self.shape: tuple[int, ...] = tuple(shape)
@ -264,15 +265,19 @@ class tensor(ArrayType[TNum]):
@overload @overload
def __add__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __add__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload @overload
def __add__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ... def __add__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __add__(self, other: TensorNumLike) -> Any: def __add__(self, other: TensorNumLike) -> Any:
"""Element-wise addition.""" """Element-wise addition."""
return self._binary_op(other, lambda a, b: a + b) return self._binary_op(other, lambda a, b: a + b)
@overload @overload
def __radd__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __radd__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload @overload
def __radd__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ... def __radd__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __radd__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __radd__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __radd__(self, other: Any) -> Any: def __radd__(self, other: Any) -> Any:
return self + other return self + other
@ -283,15 +288,19 @@ class tensor(ArrayType[TNum]):
@overload @overload
def __sub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __sub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload @overload
def __sub__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ... def __sub__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __sub__(self, other: TensorNumLike) -> Any: def __sub__(self, other: TensorNumLike) -> Any:
"""Element-wise subtraction.""" """Element-wise subtraction."""
return self._binary_op(other, lambda a, b: a - b, commutative=False) return self._binary_op(other, lambda a, b: a - b, commutative=False)
@overload @overload
def __rsub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __rsub__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload @overload
def __rsub__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ... def __rsub__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rsub__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rsub__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rsub__(self, other: TensorNumLike) -> Any: def __rsub__(self, other: TensorNumLike) -> Any:
return self._binary_op(other, lambda a, b: b - a, commutative=False, reversed=True) return self._binary_op(other, lambda a, b: b - a, commutative=False, reversed=True)
@ -302,15 +311,19 @@ class tensor(ArrayType[TNum]):
@overload @overload
def __mul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __mul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload @overload
def __mul__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ... def __mul__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __mul__(self, other: TensorNumLike) -> Any: def __mul__(self, other: TensorNumLike) -> Any:
"""Element-wise multiplication.""" """Element-wise multiplication."""
return self._binary_op(other, lambda a, b: a * b) return self._binary_op(other, lambda a, b: a * b)
@overload @overload
def __rmul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __rmul__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload @overload
def __rmul__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ... def __rmul__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rmul__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rmul__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rmul__(self, other: TensorNumLike) -> Any: def __rmul__(self, other: TensorNumLike) -> Any:
return self * other return self * other
@ -329,15 +342,19 @@ class tensor(ArrayType[TNum]):
@overload @overload
def __pow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __pow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload @overload
def __pow__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ... def __pow__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __pow__(self, other: TensorNumLike) -> Any: def __pow__(self, other: TensorNumLike) -> Any:
"""Element-wise power.""" """Element-wise power."""
return self._binary_op(other, lambda a, b: a ** b, commutative=False) return self._binary_op(other, lambda a, b: a ** b, commutative=False)
@overload @overload
def __rpow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ... def __rpow__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload @overload
def __rpow__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ... def __rpow__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rpow__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rpow__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rpow__(self, other: TensorNumLike) -> Any: def __rpow__(self, other: TensorNumLike) -> Any:
return self._binary_op(other, lambda a, b: b ** a, commutative=False, reversed=True) return self._binary_op(other, lambda a, b: b ** a, commutative=False, reversed=True)

View File

@ -166,8 +166,42 @@ def get_floordiv(op: str, type1: str, type2: str) -> str:
""" """
else: else:
return f""" return f"""
STENCIL void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ STENCIL void {op}_{type1}_{type2}({type1} a, {type2} b) {{
result_float_{type2}(floorf((float)arg1 / (float)arg2), arg2); result_float_{type2}(floorf((float)a / (float)b), b);
}}
"""
@norm_indent
def get_min(type1: str, type2: str) -> str:
if type1 == 'int' and type2 == 'int':
return f"""
STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{
result_int_{type2}(a < b ? a : b, b);
}}
"""
else:
return f"""
STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{
float _a = (float)a; float _b = (float)b;
result_float_{type2}(_a < _b ? _a : _b, b);
}}
"""
@norm_indent
def get_max(type1: str, type2: str) -> str:
if type1 == 'int' and type2 == 'int':
return f"""
STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{
result_int_{type2}(a > b ? a : b, b);
}}
"""
else:
return f"""
STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{
float _a = (float)a; float _b = (float)b;
result_float_{type2}(_a > _b ? _a : _b, b);
}} }}
""" """
@ -268,10 +302,17 @@ if __name__ == "__main__":
code += get_math_func1('fabsf', 'float', 'abs') code += get_math_func1('fabsf', 'float', 'abs')
code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));') code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));')
for t in types:
code += get_custom_stencil(f"sign_{t}({t} arg1)", f"result_int((arg1 > 0) - (arg1 < 0));")
fnames = ['atan2', 'pow'] fnames = ['atan2', 'pow']
for fn, t1, t2 in permutate(fnames, types, types): for fn, t1, t2 in permutate(fnames, types, types):
code += get_math_func2(fn, t1, t2) code += get_math_func2(fn, t1, t2)
for t1, t2 in permutate(types, types):
code += get_min(t1, t2)
code += get_max(t1, t2)
for op, t1, t2 in permutate(ops, types, types): for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float' t_out = t1 if t1 == t2 else 'float'
if op == 'floordiv': if op == 'floordiv':

View File

@ -20,7 +20,11 @@ def test_fine():
cp.cos(c_f), cp.cos(c_f),
cp.tan(c_f), cp.tan(c_f),
cp.abs(-c_i), cp.abs(-c_i),
cp.abs(-c_f)) cp.abs(-c_f),
cp.sign(c_i),
cp.sign(-c_f),
cp.min(c_i, 5),
cp.max(c_f, 5))
re2_test = (a_f ** 2, re2_test = (a_f ** 2,
a_i ** -1, a_i ** -1,
@ -32,7 +36,11 @@ def test_fine():
cp.cos(a_f), cp.cos(a_f),
cp.tan(a_f), cp.tan(a_f),
cp.abs(-a_i), cp.abs(-a_i),
cp.abs(-a_f)) cp.abs(-a_f),
cp.sign(a_i),
cp.sign(-a_f),
cp.min(a_i, 5),
cp.max(a_f, 5))
ret_refe = (a_f ** 2, ret_refe = (a_f ** 2,
a_i ** -1, a_i ** -1,
@ -43,8 +51,12 @@ def test_fine():
ma.sin(a_f), ma.sin(a_f),
ma.cos(a_f), ma.cos(a_f),
ma.tan(a_f), ma.tan(a_f),
cp.abs(-a_i), abs(-a_i),
cp.abs(-a_f)) abs(-a_f),
(a_i > 0) - (a_i < 0),
(-a_f > 0) - (-a_f < 0),
min(a_i, 5),
max(a_f, 5))
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')
@ -53,10 +65,10 @@ def test_fine():
tg.run() tg.run()
print('* finished') print('* finished')
for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')): for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ['^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan'] + ['other']*10):
assert isinstance(test, cp.value) assert isinstance(test, cp.value)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', name, val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]