mirror of https://github.com/Nonannet/copapy.git
type hints revised
This commit is contained in:
parent
61dc29e68b
commit
ebb4abc5d3
|
|
@ -1,4 +1,4 @@
|
||||||
from . import variable, vector
|
from . import variable, vector, matrix
|
||||||
import copapy.backend as cpb
|
import copapy.backend as cpb
|
||||||
from typing import Any, Sequence, overload
|
from typing import Any, Sequence, overload
|
||||||
import copapy as cp
|
import copapy as cp
|
||||||
|
|
@ -6,13 +6,25 @@ from ._basic_types import Net, unifloat
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def grad(var: variable[Any], to: variable[Any]) -> unifloat: ...
|
def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
|
||||||
@overload
|
@overload
|
||||||
def grad(var: variable[Any], to: Sequence[variable[Any]]) -> Sequence[unifloat]: ...
|
def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
|
||||||
@overload
|
@overload
|
||||||
def grad(var: variable[Any], to: vector[Any]) -> vector[float]: ...
|
def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
|
||||||
def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vector[Any]) -> unifloat | Sequence[unifloat] | vector[float]:
|
@overload
|
||||||
edges = cpb.get_all_dag_edges([var.source])
|
def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
|
||||||
|
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
|
||||||
|
"""Returns the partial derivative dx/dy where x needs to be a scalar
|
||||||
|
and y might be a scalar, a list of scalars, a vector or matrix.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: Value to return derivative of
|
||||||
|
y: Value(s) to derive in respect to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Derivative of x with the type and dimensions of y
|
||||||
|
"""
|
||||||
|
edges = cpb.get_all_dag_edges([x.source])
|
||||||
ordered_ops = cpb.stable_toposort(edges)
|
ordered_ops = cpb.stable_toposort(edges)
|
||||||
|
|
||||||
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
||||||
|
|
@ -24,79 +36,81 @@ def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vecto
|
||||||
for node in reversed(ordered_ops):
|
for node in reversed(ordered_ops):
|
||||||
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
||||||
if node.args:
|
if node.args:
|
||||||
args: Sequence[Any] = [v for v in node.args]
|
args: Sequence[Any] = list(node.args)
|
||||||
g = 1.0 if node is var.source else grad_dict[net_lookup[node]]
|
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
|
||||||
opn = node.name.split('_')[0]
|
opn = node.name.split('_')[0]
|
||||||
x: variable[Any] = args[0]
|
a: variable[Any] = args[0]
|
||||||
y: variable[Any] = args[1] if len(args) > 1 else x
|
b: variable[Any] = args[1] if len(args) > 1 else a
|
||||||
|
|
||||||
if opn in ['ge', 'gt', 'eq', 'ne']:
|
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
|
||||||
pass # Derivative is 0
|
pass # Derivative is 0 for all ops returning integers
|
||||||
|
|
||||||
elif opn == 'add':
|
elif opn == 'add':
|
||||||
add_grad(x, g)
|
add_grad(a, g)
|
||||||
add_grad(y, g)
|
add_grad(b, g)
|
||||||
|
|
||||||
elif opn == 'sub':
|
elif opn == 'sub':
|
||||||
add_grad(x, g)
|
add_grad(a, g)
|
||||||
add_grad(y, -g)
|
add_grad(b, -g)
|
||||||
|
|
||||||
elif opn == 'mul':
|
elif opn == 'mul':
|
||||||
add_grad(x, y * g)
|
add_grad(a, b * g)
|
||||||
add_grad(y, x * g)
|
add_grad(b, a * g)
|
||||||
|
|
||||||
elif opn == 'div':
|
elif opn == 'div':
|
||||||
add_grad(x, g / y)
|
add_grad(a, g / b)
|
||||||
add_grad(y, -x * g / (y**2))
|
add_grad(b, -a * g / (b**2))
|
||||||
|
|
||||||
elif opn == 'pow':
|
elif opn == 'mod':
|
||||||
add_grad(x, (y * (x ** (y - 1))) * g)
|
add_grad(a, g)
|
||||||
add_grad(y, (x ** y * cp.log(x)) * g)
|
add_grad(b, -a * g / b)
|
||||||
|
|
||||||
elif opn == 'sqrt':
|
|
||||||
add_grad(x, g * (0.5 / cp.sqrt(x)))
|
|
||||||
|
|
||||||
elif opn == 'abs':
|
|
||||||
add_grad(x, g * cp.sign(x))
|
|
||||||
|
|
||||||
elif opn == 'sin':
|
|
||||||
add_grad(x, g * cp.cos(x))
|
|
||||||
|
|
||||||
elif opn == 'cos':
|
|
||||||
add_grad(x, g * -cp.sin(x))
|
|
||||||
|
|
||||||
elif opn == 'tan':
|
|
||||||
add_grad(x, g * (1 / cp.cos(x) ** 2))
|
|
||||||
|
|
||||||
elif opn == 'asin':
|
|
||||||
add_grad(x, g * (1 / cp.sqrt(1 - x**2)))
|
|
||||||
|
|
||||||
elif opn == 'acos':
|
|
||||||
add_grad(x, g * (-1 / cp.sqrt(1 - x**2)))
|
|
||||||
|
|
||||||
elif opn == 'atan':
|
|
||||||
add_grad(x, g * (1 / (1 + x**2)))
|
|
||||||
|
|
||||||
elif opn == 'atan2':
|
|
||||||
denom = x**2 + y**2
|
|
||||||
add_grad(x, g * (-y / denom))
|
|
||||||
add_grad(y, g * ( x / denom))
|
|
||||||
|
|
||||||
elif opn == 'log':
|
elif opn == 'log':
|
||||||
add_grad(x, g / x)
|
add_grad(a, g / a)
|
||||||
|
|
||||||
elif opn == 'exp':
|
elif opn == 'exp':
|
||||||
add_grad(x, g * cp.exp(x))
|
add_grad(a, g * cp.exp(a))
|
||||||
|
|
||||||
elif opn == 'gt':
|
elif opn == 'pow':
|
||||||
add_grad(x, g)
|
add_grad(a, (b * (a ** (b - 1))) * g)
|
||||||
add_grad(y, -g)
|
add_grad(b, (a ** b * cp.log(a)) * g)
|
||||||
|
|
||||||
|
elif opn == 'sqrt':
|
||||||
|
add_grad(a, g * (0.5 / cp.sqrt(a)))
|
||||||
|
|
||||||
|
#elif opn == 'abs':
|
||||||
|
# add_grad(x, g * cp.sign(x))
|
||||||
|
|
||||||
|
elif opn == 'sin':
|
||||||
|
add_grad(a, g * cp.cos(a))
|
||||||
|
|
||||||
|
elif opn == 'cos':
|
||||||
|
add_grad(a, g * -cp.sin(a))
|
||||||
|
|
||||||
|
elif opn == 'tan':
|
||||||
|
add_grad(a, g * (1 / cp.cos(a) ** 2))
|
||||||
|
|
||||||
|
elif opn == 'asin':
|
||||||
|
add_grad(a, g * (1 / cp.sqrt(1 - a**2)))
|
||||||
|
|
||||||
|
elif opn == 'acos':
|
||||||
|
add_grad(a, g * (-1 / cp.sqrt(1 - a**2)))
|
||||||
|
|
||||||
|
elif opn == 'atan':
|
||||||
|
add_grad(a, g * (1 / (1 + a**2)))
|
||||||
|
|
||||||
|
elif opn == 'atan2':
|
||||||
|
denom = a**2 + b**2
|
||||||
|
add_grad(a, g * (-b / denom))
|
||||||
|
add_grad(b, g * ( a / denom))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
|
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
|
||||||
|
|
||||||
if isinstance(to, variable):
|
if isinstance(y, variable):
|
||||||
return grad_dict[to]
|
return grad_dict[y]
|
||||||
if isinstance(to, vector):
|
if isinstance(y, vector):
|
||||||
return vector(grad_dict[dvar] for dvar in to)
|
return vector(grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in y)
|
||||||
return [grad_dict[dvar] for dvar in to]
|
if isinstance(y, matrix):
|
||||||
|
return matrix((grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in row) for row in y)
|
||||||
|
return [grad_dict[yi] for yi in y]
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,13 @@ import pkgutil
|
||||||
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
|
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
|
||||||
from ._stencils import stencil_database, detect_process_arch
|
from ._stencils import stencil_database, detect_process_arch
|
||||||
import copapy as cp
|
import copapy as cp
|
||||||
|
from ._helper_types import TNum
|
||||||
|
|
||||||
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
|
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
|
||||||
unifloat: TypeAlias = 'variable[float] | float'
|
unifloat: TypeAlias = 'variable[float] | float'
|
||||||
uniint: TypeAlias = 'variable[int] | int'
|
uniint: TypeAlias = 'variable[int] | int'
|
||||||
|
|
||||||
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
|
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
|
||||||
TNum = TypeVar("TNum", int, float)
|
|
||||||
TVarNumb: TypeAlias = 'variable[Any] | int | float'
|
TVarNumb: TypeAlias = 'variable[Any] | int | float'
|
||||||
|
|
||||||
stencil_cache: dict[tuple[str, str], stencil_database] = {}
|
stencil_cache: dict[tuple[str, str], stencil_database] = {}
|
||||||
|
|
@ -312,7 +312,7 @@ class variable(Generic[TNum], Net):
|
||||||
return add_op('bwand', [self, other], True)
|
return add_op('bwand', [self, other], True)
|
||||||
|
|
||||||
def __rand__(self, other: uniint) -> 'variable[int]':
|
def __rand__(self, other: uniint) -> 'variable[int]':
|
||||||
return add_op('rwand', [other, self], True)
|
return add_op('bwand', [other, self], True)
|
||||||
|
|
||||||
def __or__(self, other: uniint) -> 'variable[int]':
|
def __or__(self, other: uniint) -> 'variable[int]':
|
||||||
return add_op('bwor', [self, other], True)
|
return add_op('bwor', [self, other], True)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
TNum = TypeVar("TNum", int, float)
|
||||||
|
U = TypeVar("U", int, float)
|
||||||
|
|
@ -2,12 +2,13 @@ from . import vector
|
||||||
from ._vectors import VecNumLike
|
from ._vectors import VecNumLike
|
||||||
from . import variable, NumLike
|
from . import variable, NumLike
|
||||||
from typing import TypeVar, Any, overload, Callable
|
from typing import TypeVar, Any, overload, Callable
|
||||||
from ._basic_types import add_op
|
from ._basic_types import add_op, unifloat
|
||||||
import math
|
import math
|
||||||
|
|
||||||
T = TypeVar("T", int, float, variable[int], variable[float])
|
T = TypeVar("T", int, float, variable[int], variable[float])
|
||||||
U = TypeVar("U", int, float)
|
U = TypeVar("U", int, float)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def exp(x: float | int) -> float: ...
|
def exp(x: float | int) -> float: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -284,7 +285,9 @@ def get_42(x: NumLike) -> variable[float] | float:
|
||||||
def abs(x: U) -> U: ...
|
def abs(x: U) -> U: ...
|
||||||
@overload
|
@overload
|
||||||
def abs(x: variable[U]) -> variable[U]: ...
|
def abs(x: variable[U]) -> variable[U]: ...
|
||||||
def abs(x: U | variable[U]) -> Any:
|
@overload
|
||||||
|
def abs(x: vector[U]) -> vector[U]: ...
|
||||||
|
def abs(x: U | variable[U] | vector[U]) -> Any:
|
||||||
"""Absolute value function
|
"""Absolute value function
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
@ -293,16 +296,18 @@ def abs(x: U | variable[U]) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
Absolute value of x
|
Absolute value of x
|
||||||
"""
|
"""
|
||||||
|
#tt = -x * (x < 0)
|
||||||
ret = (x < 0) * -x + (x >= 0) * x
|
ret = (x < 0) * -x + (x >= 0) * x
|
||||||
return ret # REMpyright: ignore[reportReturnType]
|
return ret # REMpyright: ignore[reportReturnType]
|
||||||
|
|
||||||
|
|
||||||
#TODO: Add vector support
|
|
||||||
@overload
|
@overload
|
||||||
def sign(x: U) -> U: ...
|
def sign(x: U) -> U: ...
|
||||||
@overload
|
@overload
|
||||||
def sign(x: variable[U]) -> variable[U]: ...
|
def sign(x: variable[U]) -> variable[U]: ...
|
||||||
def sign(x: U | variable[U]) -> Any:
|
@overload
|
||||||
|
def sign(x: vector[U]) -> vector[U]: ...
|
||||||
|
def sign(x: U | variable[U] | vector[U]) -> Any:
|
||||||
"""Return 1 for positive numbers and -1 for negative numbers.
|
"""Return 1 for positive numbers and -1 for negative numbers.
|
||||||
For an input of 0 the return value is 0.
|
For an input of 0 the return value is 0.
|
||||||
|
|
||||||
|
|
@ -384,16 +389,16 @@ def max(x: U | variable[U], y: U | variable[U]) -> Any:
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: variable[U], v2: U | variable[U], t: U | variable[U]) -> variable[U]: ...
|
def lerp(v1: variable[U], v2: U | variable[U], t: unifloat) -> variable[U]: ...
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: U | variable[U], v2: variable[U], t: U | variable[U]) -> variable[U]: ...
|
def lerp(v1: U | variable[U], v2: variable[U], t: unifloat) -> variable[U]: ...
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[U]) -> variable[U]: ...
|
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[float]) -> variable[U]: ...
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: U, v2: U, t: U) -> U: ...
|
def lerp(v1: U, v2: U, t: float) -> U: ...
|
||||||
@overload
|
@overload
|
||||||
def lerp(v1: vector[U], v2: vector[U], t: 'U | variable[U]') -> vector[U]: ...
|
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
|
||||||
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U | variable[U]) -> Any:
|
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: unifloat) -> Any:
|
||||||
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
|
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
|
||||||
if isinstance(v1, vector) or isinstance(v2, vector):
|
if isinstance(v1, vector) or isinstance(v2, vector):
|
||||||
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
|
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
|
||||||
|
|
@ -402,12 +407,13 @@ def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U
|
||||||
return v1 * (1 - t) + v2 * t
|
return v1 * (1 - t) + v2 * t
|
||||||
|
|
||||||
|
|
||||||
#TODO: Add vector support
|
|
||||||
@overload
|
@overload
|
||||||
def relu(x: U) -> U: ...
|
def relu(x: U) -> U: ...
|
||||||
@overload
|
@overload
|
||||||
def relu(x: variable[U]) -> variable[U]: ...
|
def relu(x: variable[U]) -> variable[U]: ...
|
||||||
def relu(x: U | variable[U]) -> Any:
|
@overload
|
||||||
|
def relu(x: vector[U]) -> vector[U]: ...
|
||||||
|
def relu(x: U | variable[U] | vector[U]) -> Any:
|
||||||
"""Returns x for x > 0 and otherwise 0."""
|
"""Returns x for x > 0 and otherwise 0."""
|
||||||
ret = (x > 0) * x
|
ret = (x > 0) * x
|
||||||
return ret
|
return ret
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
from . import variable
|
from . import variable
|
||||||
from ._vectors import vector
|
from ._vectors import vector
|
||||||
from ._mixed import mixed_sum
|
from ._mixed import mixed_sum
|
||||||
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
|
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
|
||||||
|
from ._helper_types import TNum
|
||||||
|
|
||||||
MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float'
|
MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float'
|
||||||
MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int'
|
MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int'
|
||||||
MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float'
|
MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float'
|
||||||
TT = TypeVar("TT", int, float)
|
|
||||||
U = TypeVar("U", int, float)
|
U = TypeVar("U", int, float)
|
||||||
|
|
||||||
|
|
||||||
class matrix(Generic[TT]):
|
class matrix(Generic[TNum]):
|
||||||
"""Mathematical matrix class supporting basic operations and interactions with variables.
|
"""Mathematical matrix class supporting basic operations and interactions with variables.
|
||||||
"""
|
"""
|
||||||
def __init__(self, values: Iterable[Iterable[TT | variable[TT]]]):
|
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
|
||||||
"""Create a matrix with given values and variables.
|
"""Create a matrix with given values and variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -23,7 +23,7 @@ class matrix(Generic[TT]):
|
||||||
if rows:
|
if rows:
|
||||||
row_len = len(rows[0])
|
row_len = len(rows[0])
|
||||||
assert all(len(row) == row_len for row in rows), "All rows must have the same length"
|
assert all(len(row) == row_len for row in rows), "All rows must have the same length"
|
||||||
self.values: tuple[tuple[variable[TT] | TT, ...], ...] = tuple(rows)
|
self.values: tuple[tuple[variable[TNum] | TNum, ...], ...] = tuple(rows)
|
||||||
self.rows = len(self.values)
|
self.rows = len(self.values)
|
||||||
self.cols = len(self.values[0]) if self.values else 0
|
self.cols = len(self.values[0]) if self.values else 0
|
||||||
|
|
||||||
|
|
@ -33,13 +33,13 @@ class matrix(Generic[TT]):
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.rows
|
return self.rows
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> tuple[variable[TT] | TT, ...]:
|
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
|
||||||
return self.values[index]
|
return self.values[index]
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[tuple[variable[TT] | TT, ...]]:
|
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
||||||
return iter(self.values)
|
return iter(self.values)
|
||||||
|
|
||||||
def __neg__(self) -> 'matrix[TT]':
|
def __neg__(self) -> 'matrix[TNum]':
|
||||||
return matrix((-a for a in row) for row in self.values)
|
return matrix((-a for a in row) for row in self.values)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -165,10 +165,10 @@ class matrix(Generic[TT]):
|
||||||
)
|
)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __matmul__(self: 'matrix[TT]', other: 'vector[TT]') -> 'vector[TT]': ...
|
def __matmul__(self: 'matrix[TNum]', other: 'vector[TNum]') -> 'vector[TNum]': ...
|
||||||
@overload
|
@overload
|
||||||
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT]') -> 'matrix[TT]': ...
|
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum]') -> 'matrix[TNum]': ...
|
||||||
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT] | vector[TT]') -> 'matrix[TT] | vector[TT]':
|
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum] | vector[TNum]') -> 'matrix[TNum] | vector[TNum]':
|
||||||
"""Matrix multiplication using @ operator"""
|
"""Matrix multiplication using @ operator"""
|
||||||
if isinstance(other, vector):
|
if isinstance(other, vector):
|
||||||
assert self.cols == len(other.values), \
|
assert self.cols == len(other.values), \
|
||||||
|
|
@ -179,9 +179,9 @@ class matrix(Generic[TT]):
|
||||||
assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}"
|
assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}"
|
||||||
assert self.cols == other.rows, \
|
assert self.cols == other.rows, \
|
||||||
f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})"
|
f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})"
|
||||||
result: list[list[TT | variable[TT]]] = []
|
result: list[list[TNum | variable[TNum]]] = []
|
||||||
for row in self.values:
|
for row in self.values:
|
||||||
new_row: list[TT | variable[TT]] = []
|
new_row: list[TNum | variable[TNum]] = []
|
||||||
for col_idx in range(other.cols):
|
for col_idx in range(other.cols):
|
||||||
col = tuple(other.values[i][col_idx] for i in range(other.rows))
|
col = tuple(other.values[i][col_idx] for i in range(other.rows))
|
||||||
element = sum(a * b for a, b in zip(row, col))
|
element = sum(a * b for a, b in zip(row, col))
|
||||||
|
|
@ -189,7 +189,7 @@ class matrix(Generic[TT]):
|
||||||
result.append(new_row)
|
result.append(new_row)
|
||||||
return matrix(result)
|
return matrix(result)
|
||||||
|
|
||||||
def transpose(self) -> 'matrix[TT]':
|
def transpose(self) -> 'matrix[TNum]':
|
||||||
"""Return the transpose of the matrix."""
|
"""Return the transpose of the matrix."""
|
||||||
if not self.values:
|
if not self.values:
|
||||||
return matrix([])
|
return matrix([])
|
||||||
|
|
@ -199,21 +199,21 @@ class matrix(Generic[TT]):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def T(self) -> 'matrix[TT]':
|
def T(self) -> 'matrix[TNum]':
|
||||||
return self.transpose()
|
return self.transpose()
|
||||||
|
|
||||||
def row(self, index: int) -> vector[TT]:
|
def row(self, index: int) -> vector[TNum]:
|
||||||
"""Get a row as a vector."""
|
"""Get a row as a vector."""
|
||||||
assert 0 <= index < self.rows, f"Row index {index} out of bounds"
|
assert 0 <= index < self.rows, f"Row index {index} out of bounds"
|
||||||
return vector(self.values[index])
|
return vector(self.values[index])
|
||||||
|
|
||||||
def col(self, index: int) -> vector[TT]:
|
def col(self, index: int) -> vector[TNum]:
|
||||||
"""Get a column as a vector."""
|
"""Get a column as a vector."""
|
||||||
assert 0 <= index < self.cols, f"Column index {index} out of bounds"
|
assert 0 <= index < self.cols, f"Column index {index} out of bounds"
|
||||||
return vector(self.values[i][index] for i in range(self.rows))
|
return vector(self.values[i][index] for i in range(self.rows))
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def trace(self: 'matrix[TT]') -> TT | variable[TT]: ...
|
def trace(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
|
||||||
@overload
|
@overload
|
||||||
def trace(self: 'matrix[int]') -> int | variable[int]: ...
|
def trace(self: 'matrix[int]') -> int | variable[int]: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -224,7 +224,7 @@ class matrix(Generic[TT]):
|
||||||
return mixed_sum(self.values[i][i] for i in range(self.rows))
|
return mixed_sum(self.values[i][i] for i in range(self.rows))
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def sum(self: 'matrix[TT]') -> TT | variable[TT]: ...
|
def sum(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
|
||||||
@overload
|
@overload
|
||||||
def sum(self: 'matrix[int]') -> int | variable[int]: ...
|
def sum(self: 'matrix[int]') -> int | variable[int]: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -240,7 +240,7 @@ class matrix(Generic[TT]):
|
||||||
for row in self.values
|
for row in self.values
|
||||||
)
|
)
|
||||||
|
|
||||||
def homogenize(self) -> 'matrix[TT]':
|
def homogenize(self) -> 'matrix[TNum]':
|
||||||
"""Convert all elements to variables if any element is a variable."""
|
"""Convert all elements to variables if any element is a variable."""
|
||||||
if any(isinstance(val, variable) for row in self.values for val in row):
|
if any(isinstance(val, variable) for row in self.values for val in row):
|
||||||
return matrix(
|
return matrix(
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,28 @@
|
||||||
from . import variable
|
from . import variable
|
||||||
from ._mixed import mixed_sum, mixed_homogenize
|
from ._mixed import mixed_sum, mixed_homogenize
|
||||||
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
|
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
|
||||||
import copapy as cp
|
import copapy as cp
|
||||||
|
from ._helper_types import TNum
|
||||||
|
|
||||||
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
|
#VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
|
||||||
|
VecNumLike: TypeAlias = 'vector[Any] | variable[Any] | int | float | bool'
|
||||||
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
|
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
|
||||||
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
||||||
T = TypeVar("T", int, float)
|
|
||||||
U = TypeVar("U", int, float)
|
U = TypeVar("U", int, float)
|
||||||
|
|
||||||
epsilon = 1e-20
|
epsilon = 1e-20
|
||||||
|
|
||||||
|
|
||||||
class vector(Generic[T]):
|
class vector(Generic[TNum]):
|
||||||
"""Mathematical vector class supporting basic operations and interactions with variables.
|
"""Mathematical vector class supporting basic operations and interactions with variables.
|
||||||
"""
|
"""
|
||||||
def __init__(self, values: Iterable[T | variable[T]]):
|
def __init__(self, values: Iterable[TNum | variable[TNum]]):
|
||||||
"""Create a vector with given values and variables.
|
"""Create a vector with given values and variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values: iterable of constant values and variables
|
values: iterable of constant values and variables
|
||||||
"""
|
"""
|
||||||
self.values: tuple[variable[T] | T, ...] = tuple(values)
|
self.values: tuple[variable[TNum] | TNum, ...] = tuple(values)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"vector({self.values})"
|
return f"vector({self.values})"
|
||||||
|
|
@ -29,13 +30,13 @@ class vector(Generic[T]):
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.values)
|
return len(self.values)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> variable[T] | T:
|
def __getitem__(self, index: int) -> variable[TNum] | TNum:
|
||||||
return self.values[index]
|
return self.values[index]
|
||||||
|
|
||||||
def __neg__(self) -> 'vector[T]':
|
def __neg__(self) -> 'vector[TNum]':
|
||||||
return vector(-a for a in self.values)
|
return vector(-a for a in self.values)
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[variable[T] | T]:
|
def __iter__(self) -> Iterator[variable[TNum] | TNum]:
|
||||||
return iter(self.values)
|
return iter(self.values)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -56,6 +57,8 @@ class vector(Generic[T]):
|
||||||
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||||
|
@overload
|
||||||
|
def __radd__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||||
def __radd__(self, other: Any) -> Any:
|
def __radd__(self, other: Any) -> Any:
|
||||||
return self + other
|
return self + other
|
||||||
|
|
||||||
|
|
@ -77,6 +80,8 @@ class vector(Generic[T]):
|
||||||
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||||
|
@overload
|
||||||
|
def __rsub__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||||
def __rsub__(self, other: VecNumLike) -> Any:
|
def __rsub__(self, other: VecNumLike) -> Any:
|
||||||
if isinstance(other, vector):
|
if isinstance(other, vector):
|
||||||
assert len(self.values) == len(other.values)
|
assert len(self.values) == len(other.values)
|
||||||
|
|
@ -101,6 +106,8 @@ class vector(Generic[T]):
|
||||||
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||||
@overload
|
@overload
|
||||||
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||||
|
@overload
|
||||||
|
def __rmul__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||||
def __rmul__(self, other: VecNumLike) -> Any:
|
def __rmul__(self, other: VecNumLike) -> Any:
|
||||||
return self * other
|
return self * other
|
||||||
|
|
||||||
|
|
@ -151,6 +158,42 @@ class vector(Generic[T]):
|
||||||
a1 * b2 - a2 * b1
|
a1 * b2 - a2 * b1
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def __gt__(self, other: VecNumLike) -> 'vector[int]':
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a > b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a > other for a in self.values)
|
||||||
|
|
||||||
|
def __lt__(self, other: VecNumLike) -> 'vector[int]':
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a < b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a < other for a in self.values)
|
||||||
|
|
||||||
|
def __ge__(self, other: VecNumLike) -> 'vector[int]':
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a >= b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a >= other for a in self.values)
|
||||||
|
|
||||||
|
def __le__(self, other: VecNumLike) -> 'vector[int]':
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a <= b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a <= other for a in self.values)
|
||||||
|
|
||||||
|
def __eq__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a == b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a == other for a in self.values)
|
||||||
|
|
||||||
|
def __ne__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
|
||||||
|
if isinstance(other, vector):
|
||||||
|
assert len(self.values) == len(other.values)
|
||||||
|
return vector(a != b for a, b in zip(self.values, other.values))
|
||||||
|
return vector(a != other for a in self.values)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def sum(self: 'vector[int]') -> int | variable[int]: ...
|
def sum(self: 'vector[int]') -> int | variable[int]: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -169,7 +212,7 @@ class vector(Generic[T]):
|
||||||
mag = self.magnitude() + epsilon
|
mag = self.magnitude() + epsilon
|
||||||
return self / mag
|
return self / mag
|
||||||
|
|
||||||
def homogenize(self) -> 'vector[T]':
|
def homogenize(self) -> 'vector[TNum]':
|
||||||
if any(isinstance(val, variable) for val in self.values):
|
if any(isinstance(val, variable) for val in self.values):
|
||||||
return vector(mixed_homogenize(self))
|
return vector(mixed_homogenize(self))
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue