mirror of https://github.com/Nonannet/copapy.git
Added tensor support and type hints for math functions
This commit is contained in:
parent
32aad5cafd
commit
0f5bb86bd4
|
|
@ -40,7 +40,7 @@ from ._tensors import tensor, zeros, ones, arange, eye, identity, diagonal
|
||||||
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
|
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
|
||||||
from ._autograd import grad
|
from ._autograd import grad
|
||||||
from ._tensors import tensor as matrix
|
from ._tensors import tensor as matrix
|
||||||
from ._version import __version__
|
from ._version import __version__ # Run "pip install -e ." to generate _version.py
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from . import vector
|
from . import vector
|
||||||
|
from . import tensor
|
||||||
from ._vectors import VecNumLike
|
from ._vectors import VecNumLike
|
||||||
|
from ._tensors import TensorNumLike
|
||||||
from . import value, NumLike
|
from . import value, NumLike
|
||||||
from typing import TypeVar, Any, overload, Callable
|
from typing import TypeVar, Any, overload, Callable
|
||||||
from ._basic_types import add_op, unifloat
|
from ._basic_types import add_op, unifloat
|
||||||
|
|
@ -15,6 +17,8 @@ def exp(x: float | int) -> float: ...
|
||||||
def exp(x: value[Any]) -> value[float]: ...
|
def exp(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def exp(x: vector[Any]) -> vector[float]: ...
|
def exp(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def exp(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def exp(x: Any) -> Any:
|
def exp(x: Any) -> Any:
|
||||||
"""Exponential function to basis e
|
"""Exponential function to basis e
|
||||||
|
|
||||||
|
|
@ -26,7 +30,7 @@ def exp(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('exp', [x])
|
return add_op('exp', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(exp)
|
return x.map(exp)
|
||||||
return float(math.exp(x))
|
return float(math.exp(x))
|
||||||
|
|
||||||
|
|
@ -37,6 +41,8 @@ def log(x: float | int) -> float: ...
|
||||||
def log(x: value[Any]) -> value[float]: ...
|
def log(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def log(x: vector[Any]) -> vector[float]: ...
|
def log(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def log(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def log(x: Any) -> Any:
|
def log(x: Any) -> Any:
|
||||||
"""Logarithm to basis e
|
"""Logarithm to basis e
|
||||||
|
|
||||||
|
|
@ -48,7 +54,7 @@ def log(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('log', [x])
|
return add_op('log', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(log)
|
return x.map(log)
|
||||||
return float(math.log(x))
|
return float(math.log(x))
|
||||||
|
|
||||||
|
|
@ -61,7 +67,13 @@ def pow(x: value[Any], y: NumLike) -> value[float]: ...
|
||||||
def pow(x: NumLike, y: value[Any]) -> value[float]: ...
|
def pow(x: NumLike, y: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def pow(x: vector[Any], y: Any) -> vector[float]: ...
|
def pow(x: vector[Any], y: Any) -> vector[float]: ...
|
||||||
def pow(x: VecNumLike, y: VecNumLike) -> Any:
|
@overload
|
||||||
|
def pow(x: Any, y: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def pow(x: tensor[Any], y: Any) -> tensor[float]: ...
|
||||||
|
@overload
|
||||||
|
def pow(x: Any, y: tensor[Any]) -> tensor[float]: ...
|
||||||
|
def pow(x: TensorNumLike, y: TensorNumLike) -> Any:
|
||||||
"""x to the power of y
|
"""x to the power of y
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
@ -70,8 +82,10 @@ def pow(x: VecNumLike, y: VecNumLike) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
result of x**y
|
result of x**y
|
||||||
"""
|
"""
|
||||||
|
if isinstance(x, tensor) or isinstance(y, tensor):
|
||||||
|
return _map2_tensor(x, y, pow)
|
||||||
if isinstance(x, vector) or isinstance(y, vector):
|
if isinstance(x, vector) or isinstance(y, vector):
|
||||||
return _map2(x, y, pow)
|
return _map2_vector(x, y, pow)
|
||||||
if isinstance(y, int) and 0 <= y < 8:
|
if isinstance(y, int) and 0 <= y < 8:
|
||||||
if y == 0:
|
if y == 0:
|
||||||
return 1
|
return 1
|
||||||
|
|
@ -93,6 +107,8 @@ def sqrt(x: float | int) -> float: ...
|
||||||
def sqrt(x: value[Any]) -> value[float]: ...
|
def sqrt(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def sqrt(x: vector[Any]) -> vector[float]: ...
|
def sqrt(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def sqrt(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def sqrt(x: Any) -> Any:
|
def sqrt(x: Any) -> Any:
|
||||||
"""Square root function
|
"""Square root function
|
||||||
|
|
||||||
|
|
@ -104,7 +120,7 @@ def sqrt(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('sqrt', [x])
|
return add_op('sqrt', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(sqrt)
|
return x.map(sqrt)
|
||||||
return float(math.sqrt(x))
|
return float(math.sqrt(x))
|
||||||
|
|
||||||
|
|
@ -115,6 +131,8 @@ def sin(x: float | int) -> float: ...
|
||||||
def sin(x: value[Any]) -> value[float]: ...
|
def sin(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def sin(x: vector[Any]) -> vector[float]: ...
|
def sin(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def sin(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def sin(x: Any) -> Any:
|
def sin(x: Any) -> Any:
|
||||||
"""Sine function
|
"""Sine function
|
||||||
|
|
||||||
|
|
@ -126,7 +144,7 @@ def sin(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('sin', [x])
|
return add_op('sin', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(sin)
|
return x.map(sin)
|
||||||
return math.sin(x)
|
return math.sin(x)
|
||||||
|
|
||||||
|
|
@ -137,6 +155,8 @@ def cos(x: float | int) -> float: ...
|
||||||
def cos(x: value[Any]) -> value[float]: ...
|
def cos(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def cos(x: vector[Any]) -> vector[float]: ...
|
def cos(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def cos(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def cos(x: Any) -> Any:
|
def cos(x: Any) -> Any:
|
||||||
"""Cosine function
|
"""Cosine function
|
||||||
|
|
||||||
|
|
@ -148,7 +168,7 @@ def cos(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('cos', [x])
|
return add_op('cos', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(cos)
|
return x.map(cos)
|
||||||
return math.cos(x)
|
return math.cos(x)
|
||||||
|
|
||||||
|
|
@ -159,6 +179,8 @@ def tan(x: float | int) -> float: ...
|
||||||
def tan(x: value[Any]) -> value[float]: ...
|
def tan(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def tan(x: vector[Any]) -> vector[float]: ...
|
def tan(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def tan(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def tan(x: Any) -> Any:
|
def tan(x: Any) -> Any:
|
||||||
"""Tangent function
|
"""Tangent function
|
||||||
|
|
||||||
|
|
@ -170,8 +192,7 @@ def tan(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('tan', [x])
|
return add_op('tan', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
#return x.map(tan)
|
|
||||||
return x.map(tan)
|
return x.map(tan)
|
||||||
return math.tan(x)
|
return math.tan(x)
|
||||||
|
|
||||||
|
|
@ -182,6 +203,8 @@ def atan(x: float | int) -> float: ...
|
||||||
def atan(x: value[Any]) -> value[float]: ...
|
def atan(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def atan(x: vector[Any]) -> vector[float]: ...
|
def atan(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def atan(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def atan(x: Any) -> Any:
|
def atan(x: Any) -> Any:
|
||||||
"""Inverse tangent function
|
"""Inverse tangent function
|
||||||
|
|
||||||
|
|
@ -193,7 +216,7 @@ def atan(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('atan', [x])
|
return add_op('atan', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(atan)
|
return x.map(atan)
|
||||||
return math.atan(x)
|
return math.atan(x)
|
||||||
|
|
||||||
|
|
@ -208,7 +231,11 @@ def atan2(x: NumLike, y: value[Any]) -> value[float]: ...
|
||||||
def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ...
|
def atan2(x: vector[float], y: VecNumLike) -> vector[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ...
|
def atan2(x: VecNumLike, y: vector[float]) -> vector[float]: ...
|
||||||
def atan2(x: VecNumLike, y: VecNumLike) -> Any:
|
@overload
|
||||||
|
def atan2(x: tensor[float], y: TensorNumLike) -> tensor[float]: ...
|
||||||
|
@overload
|
||||||
|
def atan2(x: TensorNumLike, y: tensor[float]) -> tensor[float]: ...
|
||||||
|
def atan2(x: TensorNumLike, y: TensorNumLike) -> Any:
|
||||||
"""2-argument arctangent
|
"""2-argument arctangent
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
@ -218,8 +245,10 @@ def atan2(x: VecNumLike, y: VecNumLike) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
Result in radian
|
Result in radian
|
||||||
"""
|
"""
|
||||||
|
if isinstance(x, tensor) or isinstance(y, tensor):
|
||||||
|
return _map2_tensor(x, y, atan2)
|
||||||
if isinstance(x, vector) or isinstance(y, vector):
|
if isinstance(x, vector) or isinstance(y, vector):
|
||||||
return _map2(x, y, atan2)
|
return _map2_vector(x, y, atan2)
|
||||||
if isinstance(x, value) or isinstance(y, value):
|
if isinstance(x, value) or isinstance(y, value):
|
||||||
return add_op('atan2', [x, y])
|
return add_op('atan2', [x, y])
|
||||||
return math.atan2(x, y)
|
return math.atan2(x, y)
|
||||||
|
|
@ -231,6 +260,8 @@ def asin(x: float | int) -> float: ...
|
||||||
def asin(x: value[Any]) -> value[float]: ...
|
def asin(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def asin(x: vector[Any]) -> vector[float]: ...
|
def asin(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def asin(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def asin(x: Any) -> Any:
|
def asin(x: Any) -> Any:
|
||||||
"""Inverse sine function
|
"""Inverse sine function
|
||||||
|
|
||||||
|
|
@ -242,7 +273,7 @@ def asin(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('asin', [x])
|
return add_op('asin', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(asin)
|
return x.map(asin)
|
||||||
return math.asin(x)
|
return math.asin(x)
|
||||||
|
|
||||||
|
|
@ -253,6 +284,8 @@ def acos(x: float | int) -> float: ...
|
||||||
def acos(x: value[Any]) -> value[float]: ...
|
def acos(x: value[Any]) -> value[float]: ...
|
||||||
@overload
|
@overload
|
||||||
def acos(x: vector[Any]) -> vector[float]: ...
|
def acos(x: vector[Any]) -> vector[float]: ...
|
||||||
|
@overload
|
||||||
|
def acos(x: tensor[Any]) -> tensor[float]: ...
|
||||||
def acos(x: Any) -> Any:
|
def acos(x: Any) -> Any:
|
||||||
"""Inverse cosine function
|
"""Inverse cosine function
|
||||||
|
|
||||||
|
|
@ -264,11 +297,12 @@ def acos(x: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('acos', [x])
|
return add_op('acos', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(acos)
|
return x.map(acos)
|
||||||
return math.asin(x)
|
return math.asin(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Debug test function
|
||||||
@overload
|
@overload
|
||||||
def get_42(x: float | int) -> float: ...
|
def get_42(x: float | int) -> float: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -286,7 +320,9 @@ def abs(x: U) -> U: ...
|
||||||
def abs(x: value[U]) -> value[U]: ...
|
def abs(x: value[U]) -> value[U]: ...
|
||||||
@overload
|
@overload
|
||||||
def abs(x: vector[U]) -> vector[U]: ...
|
def abs(x: vector[U]) -> vector[U]: ...
|
||||||
def abs(x: U | value[U] | vector[U]) -> Any:
|
@overload
|
||||||
|
def abs(x: tensor[U]) -> tensor[U]: ...
|
||||||
|
def abs(x: U | value[U] | vector[U] | tensor[U]) -> Any:
|
||||||
"""Absolute value function
|
"""Absolute value function
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
@ -297,18 +333,20 @@ def abs(x: U | value[U] | vector[U]) -> Any:
|
||||||
"""
|
"""
|
||||||
if isinstance(x, value):
|
if isinstance(x, value):
|
||||||
return add_op('abs', [x])
|
return add_op('abs', [x])
|
||||||
if isinstance(x, vector):
|
if isinstance(x, vector | tensor):
|
||||||
return x.map(abs)
|
return x.map(abs)
|
||||||
return (x < 0) * -x + (x >= 0) * x
|
return (x < 0) * -x + (x >= 0) * x
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def sign(x: U) -> U: ...
|
def sign(x: U) -> int: ...
|
||||||
@overload
|
@overload
|
||||||
def sign(x: value[U]) -> value[U]: ...
|
def sign(x: value[U]) -> value[int]: ...
|
||||||
@overload
|
@overload
|
||||||
def sign(x: vector[U]) -> vector[U]: ...
|
def sign(x: vector[U]) -> vector[int]: ...
|
||||||
def sign(x: U | value[U] | vector[U]) -> Any:
|
@overload
|
||||||
|
def sign(x: tensor[U]) -> tensor[int]: ...
|
||||||
|
def sign(x: U | value[U] | vector[U] | tensor[U]) -> Any:
|
||||||
"""Return 1 for positive numbers and -1 for negative numbers.
|
"""Return 1 for positive numbers and -1 for negative numbers.
|
||||||
For an input of 0 the return value is 0.
|
For an input of 0 the return value is 0.
|
||||||
|
|
||||||
|
|
@ -318,8 +356,11 @@ def sign(x: U | value[U] | vector[U]) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
-1, 0 or 1
|
-1, 0 or 1
|
||||||
"""
|
"""
|
||||||
ret = (x > 0) - (x < 0)
|
if isinstance(x, value):
|
||||||
return ret
|
return add_op('sign', [x])
|
||||||
|
if isinstance(x, vector | tensor):
|
||||||
|
return x.map(sign)
|
||||||
|
return (x > 0) - (x < 0)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -367,7 +408,13 @@ def min(x: U | value[U], y: U | value[U]) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
Minimum of x and y
|
Minimum of x and y
|
||||||
"""
|
"""
|
||||||
return (x < y) * x + (x >= y) * y
|
if isinstance(x, value):
|
||||||
|
return add_op('min', [x, y])
|
||||||
|
if isinstance(x, tensor):
|
||||||
|
return _map2_tensor(x, y, min)
|
||||||
|
if isinstance(x, vector):
|
||||||
|
return _map2_vector(x, y, min)
|
||||||
|
return x if x < y else y
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -386,7 +433,13 @@ def max(x: U | value[U], y: U | value[U]) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
Maximum of x and y
|
Maximum of x and y
|
||||||
"""
|
"""
|
||||||
return (x > y) * x + (x <= y) * y
|
if isinstance(x, value):
|
||||||
|
return add_op('max', [x, y])
|
||||||
|
if isinstance(x, tensor):
|
||||||
|
return _map2_tensor(x, y, max)
|
||||||
|
if isinstance(x, vector):
|
||||||
|
return _map2_vector(x, y, max)
|
||||||
|
return x if x > y else y
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -400,7 +453,16 @@ def lerp(v1: U, v2: U, t: float) -> U: ...
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
|
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
|
||||||
def lerp(v1: U | value[U] | vector[U], v2: U | value[U] | vector[U], t: unifloat) -> Any:
|
def lerp(v1: U | value[U] | vector[U], v2: U | value[U] | vector[U], t: unifloat) -> Any:
|
||||||
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
|
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
v1: First value or vector
|
||||||
|
v2: Second value or vector
|
||||||
|
t: Interpolation factor (0.0 to 1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Interpolated value or vector
|
||||||
|
"""
|
||||||
if isinstance(v1, vector) or isinstance(v2, vector):
|
if isinstance(v1, vector) or isinstance(v2, vector):
|
||||||
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
|
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
|
||||||
assert len(v1.values) == len(v2.values), "Vectors must be of the same length."
|
assert len(v1.values) == len(v2.values), "Vectors must be of the same length."
|
||||||
|
|
@ -414,13 +476,15 @@ def relu(x: U) -> U: ...
|
||||||
def relu(x: value[U]) -> value[U]: ...
|
def relu(x: value[U]) -> value[U]: ...
|
||||||
@overload
|
@overload
|
||||||
def relu(x: vector[U]) -> vector[U]: ...
|
def relu(x: vector[U]) -> vector[U]: ...
|
||||||
def relu(x: U | value[U] | vector[U]) -> Any:
|
@overload
|
||||||
|
def relu(x: tensor[U]) -> tensor[U]: ...
|
||||||
|
def relu(x: U | value[U] | vector[U] | tensor[U]) -> Any:
|
||||||
"""Returns x for x > 0 and otherwise 0."""
|
"""Returns x for x > 0 and otherwise 0."""
|
||||||
ret = (x > 0) * x
|
ret = x * (x > 0)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]:
|
def _map2_vector(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[U] | U]) -> vector[U]:
|
||||||
"""Applies a function to each element of the vector and a second vector or scalar."""
|
"""Applies a function to each element of the vector and a second vector or scalar."""
|
||||||
if isinstance(self, vector) and isinstance(other, vector):
|
if isinstance(self, vector) and isinstance(other, vector):
|
||||||
return vector(func(x, y) for x, y in zip(self.values, other.values))
|
return vector(func(x, y) for x, y in zip(self.values, other.values))
|
||||||
|
|
@ -430,3 +494,20 @@ def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], value[
|
||||||
return vector(func(self, x) for x in other.values)
|
return vector(func(self, x) for x in other.values)
|
||||||
else:
|
else:
|
||||||
return vector([func(self, other)])
|
return vector([func(self, other)])
|
||||||
|
|
||||||
|
|
||||||
|
def _map2_tensor(self: TensorNumLike, other: TensorNumLike, func: Callable[[Any, Any], value[U] | U]) -> tensor[U]:
|
||||||
|
"""Applies a function to each element of the vector and a second vector or scalar."""
|
||||||
|
if isinstance(self, vector):
|
||||||
|
self = tensor(self.values, (len(self.values),))
|
||||||
|
if isinstance(other, vector):
|
||||||
|
other = tensor(other.values, (len(other.values),))
|
||||||
|
if isinstance(self, tensor) and isinstance(other, tensor):
|
||||||
|
assert self.shape == other.shape, "Tensors must have the same shape"
|
||||||
|
return tensor([func(x, y) for x, y in zip(self.values, other.values)], self.shape)
|
||||||
|
elif isinstance(self, tensor):
|
||||||
|
return tensor([func(x, other) for x in self.values], self.shape)
|
||||||
|
elif isinstance(other, tensor):
|
||||||
|
return tensor([func(self, x) for x in other.values], other.shape)
|
||||||
|
else:
|
||||||
|
return tensor(func(self, other))
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from copapy._basic_types import NumLike, ArrayType
|
from copapy._basic_types import NumLike, ArrayType
|
||||||
from . import value
|
from . import value
|
||||||
from ._vectors import vector
|
from ._vectors import vector, VecFloatLike, VecIntLike, VecNumLike
|
||||||
from ._mixed import mixed_sum
|
from ._mixed import mixed_sum
|
||||||
from typing import TypeVar, Any, overload, TypeAlias, Callable, Iterator, Sequence
|
from typing import TypeVar, Any, overload, TypeAlias, Callable, Iterator, Sequence
|
||||||
from ._helper_types import TNum
|
from ._helper_types import TNum
|
||||||
|
|
||||||
TensorNumLike: TypeAlias = 'tensor[Any] | vector[Any] | value[Any] | int | float | bool'
|
TensorNumLike: TypeAlias = 'tensor[Any] | vector[Any] | value[Any] | int | float | bool'
|
||||||
TensorIntLike: TypeAlias = 'tensor[int] | value[int] | int'
|
TensorIntLike: TypeAlias = 'tensor[int] | vector[int] | value[int] | int | bool'
|
||||||
TensorFloatLike: TypeAlias = 'tensor[float] | value[float] | float'
|
TensorFloatLike: TypeAlias = 'tensor[float] | vector[float] | value[float] | float'
|
||||||
TensorSequence: TypeAlias = 'Sequence[TNum | value[TNum]] | Sequence[Sequence[TNum | value[TNum]]] | Sequence[Sequence[Sequence[TNum | value[TNum]]]]'
|
TensorSequence: TypeAlias = 'Sequence[TNum | value[TNum]] | Sequence[Sequence[TNum | value[TNum]]] | Sequence[Sequence[Sequence[TNum | value[TNum]]]]'
|
||||||
U = TypeVar("U", int, float)
|
U = TypeVar("U", int, float)
|
||||||
|
|
||||||
|
|
@ -26,6 +26,7 @@ class tensor(ArrayType[TNum]):
|
||||||
values: Nested iterables of constant values or copapy values.
|
values: Nested iterables of constant values or copapy values.
|
||||||
Can be a scalar, 1D iterable (vector),
|
Can be a scalar, 1D iterable (vector),
|
||||||
or n-dimensional nested structure.
|
or n-dimensional nested structure.
|
||||||
|
shape: Optional shape of the tensor. If not provided, inferred from values.
|
||||||
"""
|
"""
|
||||||
if shape:
|
if shape:
|
||||||
self.shape: tuple[int, ...] = tuple(shape)
|
self.shape: tuple[int, ...] = tuple(shape)
|
||||||
|
|
@ -264,15 +265,19 @@ class tensor(ArrayType[TNum]):
|
||||||
@overload
|
@overload
|
||||||
def __add__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __add__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __add__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
|
def __add__(self, other: TensorNumLike) -> 'tensor[Any]': ...
|
||||||
def __add__(self, other: TensorNumLike) -> Any:
|
def __add__(self, other: TensorNumLike) -> Any:
|
||||||
"""Element-wise addition."""
|
"""Element-wise addition."""
|
||||||
return self._binary_op(other, lambda a, b: a + b)
|
return self._binary_op(other, lambda a, b: a + b)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __radd__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __radd__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __radd__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
|
def __radd__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
|
||||||
|
@overload
|
||||||
|
def __radd__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
|
||||||
|
@overload
|
||||||
|
def __radd__(self, other: VecNumLike) -> 'tensor[Any]': ...
|
||||||
def __radd__(self, other: Any) -> Any:
|
def __radd__(self, other: Any) -> Any:
|
||||||
return self + other
|
return self + other
|
||||||
|
|
||||||
|
|
@ -283,15 +288,19 @@ class tensor(ArrayType[TNum]):
|
||||||
@overload
|
@overload
|
||||||
def __sub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __sub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __sub__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
|
def __sub__(self, other: TensorNumLike) -> 'tensor[Any]': ...
|
||||||
def __sub__(self, other: TensorNumLike) -> Any:
|
def __sub__(self, other: TensorNumLike) -> Any:
|
||||||
"""Element-wise subtraction."""
|
"""Element-wise subtraction."""
|
||||||
return self._binary_op(other, lambda a, b: a - b, commutative=False)
|
return self._binary_op(other, lambda a, b: a - b, commutative=False)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rsub__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __rsub__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __rsub__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
|
def __rsub__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
|
||||||
|
@overload
|
||||||
|
def __rsub__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
|
||||||
|
@overload
|
||||||
|
def __rsub__(self, other: VecNumLike) -> 'tensor[Any]': ...
|
||||||
def __rsub__(self, other: TensorNumLike) -> Any:
|
def __rsub__(self, other: TensorNumLike) -> Any:
|
||||||
return self._binary_op(other, lambda a, b: b - a, commutative=False, reversed=True)
|
return self._binary_op(other, lambda a, b: b - a, commutative=False, reversed=True)
|
||||||
|
|
||||||
|
|
@ -302,15 +311,19 @@ class tensor(ArrayType[TNum]):
|
||||||
@overload
|
@overload
|
||||||
def __mul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __mul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __mul__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
|
def __mul__(self, other: TensorNumLike) -> 'tensor[Any]': ...
|
||||||
def __mul__(self, other: TensorNumLike) -> Any:
|
def __mul__(self, other: TensorNumLike) -> Any:
|
||||||
"""Element-wise multiplication."""
|
"""Element-wise multiplication."""
|
||||||
return self._binary_op(other, lambda a, b: a * b)
|
return self._binary_op(other, lambda a, b: a * b)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rmul__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __rmul__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __rmul__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
|
def __rmul__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
|
||||||
|
@overload
|
||||||
|
def __rmul__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
|
||||||
|
@overload
|
||||||
|
def __rmul__(self, other: VecNumLike) -> 'tensor[Any]': ...
|
||||||
def __rmul__(self, other: TensorNumLike) -> Any:
|
def __rmul__(self, other: TensorNumLike) -> Any:
|
||||||
return self * other
|
return self * other
|
||||||
|
|
||||||
|
|
@ -329,15 +342,19 @@ class tensor(ArrayType[TNum]):
|
||||||
@overload
|
@overload
|
||||||
def __pow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __pow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __pow__(self, other: TensorNumLike) -> 'tensor[int] | tensor[float]': ...
|
def __pow__(self, other: TensorNumLike) -> 'tensor[Any]': ...
|
||||||
def __pow__(self, other: TensorNumLike) -> Any:
|
def __pow__(self, other: TensorNumLike) -> Any:
|
||||||
"""Element-wise power."""
|
"""Element-wise power."""
|
||||||
return self._binary_op(other, lambda a, b: a ** b, commutative=False)
|
return self._binary_op(other, lambda a, b: a ** b, commutative=False)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rpow__(self: 'tensor[float]', other: TensorNumLike) -> 'tensor[float]': ...
|
def __rpow__(self: 'tensor[int]', other: VecFloatLike) -> 'tensor[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __rpow__(self: 'tensor[int]', other: value[int] | int) -> 'tensor[int]': ...
|
def __rpow__(self: 'tensor[int]', other: VecIntLike) -> 'tensor[int]': ...
|
||||||
|
@overload
|
||||||
|
def __rpow__(self: 'tensor[float]', other: VecNumLike) -> 'tensor[float]': ...
|
||||||
|
@overload
|
||||||
|
def __rpow__(self, other: VecNumLike) -> 'tensor[Any]': ...
|
||||||
def __rpow__(self, other: TensorNumLike) -> Any:
|
def __rpow__(self, other: TensorNumLike) -> Any:
|
||||||
return self._binary_op(other, lambda a, b: b ** a, commutative=False, reversed=True)
|
return self._binary_op(other, lambda a, b: b ** a, commutative=False, reversed=True)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue