Added tensor support and type hints for math functions

This commit is contained in:
Nicolas 2026-01-05 13:39:53 +01:00
parent 32aad5cafd
commit 0f5bb86bd4
3 changed files with 142 additions and 44 deletions

View File

@ -40,7 +40,7 @@ from ._tensors import tensor, zeros, ones, arange, eye, identity, diagonal
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
from ._autograd import grad
from ._tensors import tensor as matrix
from ._version import __version__
from ._version import __version__ # Run "pip install -e ." to generate _version.py
__all__ = [

View File

@ -1,5 +1,7 @@
from . import vector
from . import tensor
from ._vectors import VecNumLike
from ._tensors import TensorNumLike
from . import value, NumLike
from typing import TypeVar, Any, overload, Callable
from ._basic_types import add_op, unifloat
@ -15,6 +17,8 @@ def exp(x: float | int) -> float: ...
def exp(x: value[Any]) -> value[float]: ...
@overload
def exp(x: vector[Any]) -> vector[float]: ...
@overload
def exp(x: tensor[Any]) -> tensor[float]: ...
def exp(x: Any) -> Any:
"""Exponential function to basis e
@ -26,7 +30,7 @@ def exp(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('exp', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(exp)
return float(math.exp(x))
@ -37,6 +41,8 @@ def log(x: float | int) -> float: ...
def log(x: value[Any]) -> value[float]: ...
@overload
def log(x: vector[Any]) -> vector[float]: ...
@overload
def log(x: tensor[Any]) -> tensor[float]: ...
def log(x: Any) -> Any:
"""Logarithm to basis e
@ -48,7 +54,7 @@ def log(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('log', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(log)
return float(math.log(x))
@ -61,7 +67,13 @@ def pow(x: value[Any], y: NumLike) -> value[float]: ...
def pow(x: NumLike, y: value[Any]) -> value[float]: ...
@overload
def pow(x: vector[Any], y: Any) -> vector[float]: ...
def pow(x: VecNumLike, y: VecNumLike) -> Any:
@overload
def pow(x: Any, y: vector[Any]) -> vector[float]: ...
@overload
def pow(x: tensor[Any], y: Any) -> tensor[float]: ...
@overload
def pow(x: Any, y: tensor[Any]) -> tensor[float]: ...
def pow(x: TensorNumLike, y: TensorNumLike) -> Any:
"""x to the power of y
Arguments:
@ -70,8 +82,10 @@ def pow(x: VecNumLike, y: VecNumLike) -> Any:
Returns:
result of x**y
"""
if isinstance(x, tensor) or isinstance(y, tensor):
return _map2_tensor(x, y, pow)
if isinstance(x, vector) or isinstance(y, vector):
return _map2(x, y, pow)
return _map2_vector(x, y, pow)
if isinstance(y, int) and 0 <= y < 8:
if y == 0:
return 1
@ -93,6 +107,8 @@ def sqrt(x: float | int) -> float: ...
def sqrt(x: value[Any]) -> value[float]: ...
@overload
def sqrt(x: vector[Any]) -> vector[float]: ...
@overload
def sqrt(x: tensor[Any]) -> tensor[float]: ...
def sqrt(x: Any) -> Any:
"""Square root function
@ -104,7 +120,7 @@ def sqrt(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('sqrt', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(sqrt)
return float(math.sqrt(x))
@ -115,6 +131,8 @@ def sin(x: float | int) -> float: ...
def sin(x: value[Any]) -> value[float]: ...
@overload
def sin(x: vector[Any]) -> vector[float]: ...
@overload
def sin(x: tensor[Any]) -> tensor[float]: ...
def sin(x: Any) -> Any:
"""Sine function
@ -126,7 +144,7 @@ def sin(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('sin', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(sin)
return math.sin(x)
@ -137,6 +155,8 @@ def cos(x: float | int) -> float: ...
def cos(x: value[Any]) -> value[float]: ...
@overload
def cos(x: vector[Any]) -> vector[float]: ...
@overload
def cos(x: tensor[Any]) -> tensor[float]: ...
def cos(x: Any) -> Any:
"""Cosine function
@ -148,7 +168,7 @@ def cos(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('cos', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(cos)
return math.cos(x)
@ -159,6 +179,8 @@ def tan(x: float | int) -> float: ...
def tan(x: value[Any]) -> value[float]: ...
@overload
def tan(x: vector[Any]) -> vector[float]: ...
@overload
def tan(x: tensor[Any]) -> tensor[float]: ...
def tan(x: Any) -> Any:
"""Tangent function
@ -170,8 +192,7 @@ def tan(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('tan', [x])
if isinstance(x, vector):
#return x.map(tan)
if isinstance(x, vector | tensor):
return x.map(tan)
return math.tan(x)
@ -182,6 +203,8 @@ def atan(x: float | int) -> float: ...
def atan(x: value[Any]) -> value[float]: ...
@overload
def atan(x: vector[Any]) -> vector[float]: ...
@overload
def atan(x: tensor[Any]) -> tensor[float]: ...
def atan(x: Any) -> Any:
"""Inverse tangent function
@ -193,7 +216,7 @@ def atan(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('atan', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(atan)
return math.atan(x)
@ -208,7 +231,11 @@ def atan2(x: NumLike, y: value[Any]) -> value[float]: ...
def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ...
@overload
def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ...
def atan2(x: VecNumLike, y: VecNumLike) -> Any:
@overload
def atan2(x: tensor[float], y: TensorNumLike) -> tensor[float]: ...
@overload
def atan2(x: TensorNumLike, y: tensor[float]) -> tensor[float]: ...
def atan2(x: TensorNumLike, y: TensorNumLike) -> Any:
"""2-argument arctangent
Arguments:
@ -218,8 +245,10 @@ def atan2(x: VecNumLike, y: VecNumLike) -> Any:
Returns:
Result in radian
"""
if isinstance(x, tensor) or isinstance(y, tensor):
return _map2_tensor(x, y, atan2)
if isinstance(x, vector) or isinstance(y, vector):
return _map2(x, y, atan2)
return _map2_vector(x, y, atan2)
if isinstance(x, value) or isinstance(y, value):
return add_op('atan2', [x, y])
return math.atan2(x, y)
@ -231,6 +260,8 @@ def asin(x: float | int) -> float: ...
def asin(x: value[Any]) -> value[float]: ...
@overload
def asin(x: vector[Any]) -> vector[float]: ...
@overload
def asin(x: tensor[Any]) -> tensor[float]: ...
def asin(x: Any) -> Any:
"""Inverse sine function
@ -242,7 +273,7 @@ def asin(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('asin', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(asin)
return math.asin(x)
@ -253,6 +284,8 @@ def acos(x: float | int) -> float: ...
def acos(x: value[Any]) -> value[float]: ...
@overload
def acos(x: vector[Any]) -> vector[float]: ...
@overload
def acos(x: tensor[Any]) -> tensor[float]: ...
def acos(x: Any) -> Any:
"""Inverse cosine function
@ -264,11 +297,12 @@ def acos(x: Any) -> Any:
"""
if isinstance(x, value):
return add_op('acos', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(acos)
return math.asin(x)
# Debug test function
@overload
def get_42(x: float | int) -> float: ...
@overload
@ -286,7 +320,9 @@ def abs(x: U) -> U: ...
def abs(x: value[U]) -> value[U]: ...
@overload
def abs(x: vector[U]) -> vector[U]: ...
def abs(x: U | value[U] | vector[U]) -> Any:
@overload
def abs(x: tensor[U]) -> tensor[U]: ...
def abs(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Absolute value function
Arguments:
@ -297,18 +333,20 @@ def abs(x: U | value[U] | vector[U]) -> Any:
"""
if isinstance(x, value):
return add_op('abs', [x])
if isinstance(x, vector):
if isinstance(x, vector | tensor):
return x.map(abs)
return (x < 0) * -x + (x >= 0) * x
@overload
def sign(x: U) -> U: ...
def sign(x: U) -> int: ...
@overload
def sign(x: value[U]) -> value[U]: ...
def sign(x: value[U]) -> value[int]: ...
@overload
def sign(x: vector[U]) -> vector[U]: ...
def sign(x: U | value[U] | vector[U]) -> Any:
def sign(x: vector[U]) -> vector[int]: ...
@overload
def sign(x: tensor[U]) -> tensor[int]: ...
def sign(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Return 1 for positive numbers and -1 for negative numbers.
For an input of 0 the return value is 0.
@ -318,8 +356,11 @@ def sign(x: U | value[U] | vector[U]) -> Any:
Returns:
-1, 0 or 1
"""
ret = (x > 0) - (x < 0)
return ret
if isinstance(x, value):
return add_op('sign', [x])
if isinstance(x, vector | tensor):
return x.map(sign)
return (x > 0) - (x < 0)
@overload
@ -367,7 +408,13 @@ def min(x: U | value[U], y: U | value[U]) -> Any:
Returns:
Minimum of x and y
"""
return (x < y) * x + (x >= y) * y
if isinstance(x, value):
return add_op('min', [x, y])
if isinstance(x, tensor):
return _map2_tensor(x, y, min)
if isinstance(x, vector):
return _map2_vector(x, y, min)
return x if x < y else y
@overload
@ -386,7 +433,13 @@ def max(x: U | value[U], y: U | value[U]) -> Any:
Returns:
Maximum of x and y
"""
return (x > y) * x + (x <= y) * y
if isinstance(x, value):
return add_op('max', [x, y])
if isinstance(x, tensor):
return _map2_tensor(x, y, max)
if isinstance(x, vector):
return _map2_vector(x, y, max)
return x if x > y else y
@overload
@ -400,7 +453,16 @@ def lerp(v1: U, v2: U, t: float) -> U: ...
@overload
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
def lerp(v1: U | value[U] | vector[U], v2: U | value[U] | vector[U], t: unifloat) -> Any:
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t.
Arguments:
v1: First value or vector
v2: Second value or vector
t: Interpolation factor (0.0 to 1.0)
Returns:
Interpolated value or vector
"""
if isinstance(v1, vector) or isinstance(v2, vector):
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
assert len(v1.values) == len(v2.values), "Vectors must be of the same length."
@ -414,13 +476,15 @@ def relu(x: U) -> U: ...
def relu(x: value[U]) -> value[U]: ...
@overload
def relu(x: vector[U]) -> vector[U]: ...
def relu(x: U | value[U] | vector[U]) -> Any:
@overload
def relu(x: tensor[U]) -> tensor[U]: ...
def relu(x: U | value[U] | vector[U] | tensor[U]) -> Any:
"""Returns x for x > 0 and otherwise 0."""
ret = (x > 0) * x
ret = x * (x > 0)
return ret
def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]:
def _map2_vector(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector) and isinstance(other, vector):
return vector(func(x, y) for x, y in zip(self.values, other.values))
@ -430,3 +494,20 @@ def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[
return vector(func(self, x) for x in other.values)
else:
return vector([func(self, other)])
def _map2_tensor(self: TensorNumLike, other: TensorNumLike, func: Callable[[Any, Any], value[U] | U]) -> tensor[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector):
self = tensor(self.values, (len(self.values),))
if isinstance(other, vector):
other = tensor(other.values, (len(other.values),))
if isinstance(self, tensor) and isinstance(other, tensor):
assert self.shape == other.shape, "Tensors must have the same shape"
return tensor([func(x, y) for x, y in zip(self.values, other.values)], self.shape)
elif isinstance(self, tensor):
return tensor([func(x, other) for x in self.values], self.shape)
elif isinstance(other, tensor):
return tensor([func(self, x) for x in other.values], other.shape)
else:
return tensor(func(self, other))

View File

@ -1,13 +1,13 @@
from copapy._basic_types import NumLike, ArrayType
from . import value
from ._vectors import vector
from ._vectors import vector, VecFloatLike, VecIntLike, VecNumLike
from ._mixed import mixed_sum
from typing import TypeVar, Any, overload, TypeAlias, Callable, Iterator, Sequence
from ._helper_types import TNum
TensorNumLike: TypeAlias = 'tensor[Any] | vector[Any] | value[Any] | int | float | bool'
TensorIntLike: TypeAlias = 'tensor[int] | value[int] | int'
TensorFloatLike: TypeAlias = 'tensor[float] | value[float] | float'
TensorIntLike: TypeAlias = 'tensor[int] | vector[int] | value[int] | int | bool'
TensorFloatLike: TypeAlias = 'tensor[float] | vector[float] | value[float] | float'
TensorSequence: TypeAlias = 'Sequence[TNum | value[TNum]] | Sequence[Sequence[TNum | value[TNum]]] | Sequence[Sequence[Sequence[TNum | value[TNum]]]]'
U = TypeVar("U", int, float)
@ -26,6 +26,7 @@ class tensor(ArrayType[TNum]):
values: Nested iterables of constant values or copapy values.
Can be a scalar, 1D iterable (vector),
or n-dimensional nested structure.
shape: Optional shape of the tensor. If not provided, inferred from values.
"""
if shape:
self.shape: tuple[int, ...] = tuple(shape)
@ -264,15 +265,19 @@ class tensor(ArrayType[TNum]):
@overload
def __add__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload
def __add__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
def __add__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __add__(self, other: TensorNumLike) -> Any:
"""Element-wise addition."""
return self._binary_op(other, lambda a, b: a + b)
@overload
def __radd__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
def __radd__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload
def __radd__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
def __radd__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __radd__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __radd__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __radd__(self, other: Any) -> Any:
return self + other
@ -283,15 +288,19 @@ class tensor(ArrayType[TNum]):
@overload
def __sub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload
def __sub__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
def __sub__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __sub__(self, other: TensorNumLike) -> Any:
"""Element-wise subtraction."""
return self._binary_op(other, lambda a, b: a - b, commutative=False)
@overload
def __rsub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
def __rsub__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload
def __rsub__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
def __rsub__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rsub__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rsub__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rsub__(self, other: TensorNumLike) -> Any:
return self._binary_op(other, lambda a, b: b - a, commutative=False, reversed=True)
@ -302,15 +311,19 @@ class tensor(ArrayType[TNum]):
@overload
def __mul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload
def __mul__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
def __mul__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __mul__(self, other: TensorNumLike) -> Any:
"""Element-wise multiplication."""
return self._binary_op(other, lambda a, b: a * b)
@overload
def __rmul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
def __rmul__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload
def __rmul__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
def __rmul__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rmul__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rmul__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rmul__(self, other: TensorNumLike) -> Any:
return self * other
@ -329,15 +342,19 @@ class tensor(ArrayType[TNum]):
@overload
def __pow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
@overload
def __pow__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
def __pow__(self, other: TensorNumLike) -> 'tensor[Any]': ...
def __pow__(self, other: TensorNumLike) -> Any:
"""Element-wise power."""
return self._binary_op(other, lambda a, b: a ** b, commutative=False)
@overload
def __rpow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
def __rpow__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
@overload
def __rpow__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
def __rpow__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
@overload
def __rpow__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
@overload
def __rpow__(self, other: VecNumLike) -> 'tensor[Any]': ...
def __rpow__(self, other: TensorNumLike) -> Any:
return self._binary_op(other, lambda a, b: b ** a, commutative=False, reversed=True)