type hints revised

This commit is contained in:
Nicolas Kruse 2025-12-04 18:18:29 +01:00
parent 61dc29e68b
commit ebb4abc5d3
7 changed files with 182 additions and 115 deletions

View File

@ -1,4 +1,4 @@
from . import variable, vector from . import variable, vector, matrix
import copapy.backend as cpb import copapy.backend as cpb
from typing import Any, Sequence, overload from typing import Any, Sequence, overload
import copapy as cp import copapy as cp
@ -6,13 +6,25 @@ from ._basic_types import Net, unifloat
@overload @overload
def grad(var: variable[Any], to: variable[Any]) -> unifloat: ... def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
@overload @overload
def grad(var: variable[Any], to: Sequence[variable[Any]]) -> Sequence[unifloat]: ... def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
@overload @overload
def grad(var: variable[Any], to: vector[Any]) -> vector[float]: ... def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vector[Any]) -> unifloat | Sequence[unifloat] | vector[float]: @overload
edges = cpb.get_all_dag_edges([var.source]) def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
"""Returns the partial derivative dx/dy where x needs to be a scalar
and y might be a scalar, a list of scalars, a vector or matrix.
Arguments:
x: Value to return derivative of
y: Value(s) to derive in respect to
Returns:
Derivative of x with the type and dimensions of y
"""
edges = cpb.get_all_dag_edges([x.source])
ordered_ops = cpb.stable_toposort(edges) ordered_ops = cpb.stable_toposort(edges)
net_lookup = {net.source: net for node in ordered_ops for net in node.args} net_lookup = {net.source: net for node in ordered_ops for net in node.args}
@ -24,79 +36,81 @@ def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vecto
for node in reversed(ordered_ops): for node in reversed(ordered_ops):
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}") print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
if node.args: if node.args:
args: Sequence[Any] = [v for v in node.args] args: Sequence[Any] = list(node.args)
g = 1.0 if node is var.source else grad_dict[net_lookup[node]] g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
opn = node.name.split('_')[0] opn = node.name.split('_')[0]
x: variable[Any] = args[0] a: variable[Any] = args[0]
y: variable[Any] = args[1] if len(args) > 1 else x b: variable[Any] = args[1] if len(args) > 1 else a
if opn in ['ge', 'gt', 'eq', 'ne']: if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
pass # Derivative is 0 pass # Derivative is 0 for all ops returning integers
elif opn == 'add': elif opn == 'add':
add_grad(x, g) add_grad(a, g)
add_grad(y, g) add_grad(b, g)
elif opn == 'sub': elif opn == 'sub':
add_grad(x, g) add_grad(a, g)
add_grad(y, -g) add_grad(b, -g)
elif opn == 'mul': elif opn == 'mul':
add_grad(x, y * g) add_grad(a, b * g)
add_grad(y, x * g) add_grad(b, a * g)
elif opn == 'div': elif opn == 'div':
add_grad(x, g / y) add_grad(a, g / b)
add_grad(y, -x * g / (y**2)) add_grad(b, -a * g / (b**2))
elif opn == 'pow': elif opn == 'mod':
add_grad(x, (y * (x ** (y - 1))) * g) add_grad(a, g)
add_grad(y, (x ** y * cp.log(x)) * g) add_grad(b, -a * g / b)
elif opn == 'sqrt':
add_grad(x, g * (0.5 / cp.sqrt(x)))
elif opn == 'abs':
add_grad(x, g * cp.sign(x))
elif opn == 'sin':
add_grad(x, g * cp.cos(x))
elif opn == 'cos':
add_grad(x, g * -cp.sin(x))
elif opn == 'tan':
add_grad(x, g * (1 / cp.cos(x) ** 2))
elif opn == 'asin':
add_grad(x, g * (1 / cp.sqrt(1 - x**2)))
elif opn == 'acos':
add_grad(x, g * (-1 / cp.sqrt(1 - x**2)))
elif opn == 'atan':
add_grad(x, g * (1 / (1 + x**2)))
elif opn == 'atan2':
denom = x**2 + y**2
add_grad(x, g * (-y / denom))
add_grad(y, g * ( x / denom))
elif opn == 'log': elif opn == 'log':
add_grad(x, g / x) add_grad(a, g / a)
elif opn == 'exp': elif opn == 'exp':
add_grad(x, g * cp.exp(x)) add_grad(a, g * cp.exp(a))
elif opn == 'gt': elif opn == 'pow':
add_grad(x, g) add_grad(a, (b * (a ** (b - 1))) * g)
add_grad(y, -g) add_grad(b, (a ** b * cp.log(a)) * g)
elif opn == 'sqrt':
add_grad(a, g * (0.5 / cp.sqrt(a)))
#elif opn == 'abs':
# add_grad(x, g * cp.sign(x))
elif opn == 'sin':
add_grad(a, g * cp.cos(a))
elif opn == 'cos':
add_grad(a, g * -cp.sin(a))
elif opn == 'tan':
add_grad(a, g * (1 / cp.cos(a) ** 2))
elif opn == 'asin':
add_grad(a, g * (1 / cp.sqrt(1 - a**2)))
elif opn == 'acos':
add_grad(a, g * (-1 / cp.sqrt(1 - a**2)))
elif opn == 'atan':
add_grad(a, g * (1 / (1 + a**2)))
elif opn == 'atan2':
denom = a**2 + b**2
add_grad(a, g * (-b / denom))
add_grad(b, g * ( a / denom))
else: else:
raise ValueError(f"Operation {opn} not yet supported for auto diff.") raise ValueError(f"Operation {opn} not yet supported for auto diff.")
if isinstance(to, variable): if isinstance(y, variable):
return grad_dict[to] return grad_dict[y]
if isinstance(to, vector): if isinstance(y, vector):
return vector(grad_dict[dvar] for dvar in to) return vector(grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in y)
return [grad_dict[dvar] for dvar in to] if isinstance(y, matrix):
return matrix((grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in row) for row in y)
return [grad_dict[yi] for yi in y]

View File

@ -2,13 +2,13 @@ import pkgutil
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database, detect_process_arch from ._stencils import stencil_database, detect_process_arch
import copapy as cp import copapy as cp
from ._helper_types import TNum
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float' NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
unifloat: TypeAlias = 'variable[float] | float' unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int' uniint: TypeAlias = 'variable[int] | int'
TCPNum = TypeVar("TCPNum", bound='variable[Any]') TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, float)
TVarNumb: TypeAlias = 'variable[Any] | int | float' TVarNumb: TypeAlias = 'variable[Any] | int | float'
stencil_cache: dict[tuple[str, str], stencil_database] = {} stencil_cache: dict[tuple[str, str], stencil_database] = {}
@ -312,7 +312,7 @@ class variable(Generic[TNum], Net):
return add_op('bwand', [self, other], True) return add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]': def __rand__(self, other: uniint) -> 'variable[int]':
return add_op('rwand', [other, self], True) return add_op('bwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]': def __or__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [self, other], True) return add_op('bwor', [self, other], True)

View File

@ -0,0 +1,4 @@
from typing import TypeVar
TNum = TypeVar("TNum", int, float)
U = TypeVar("U", int, float)

View File

@ -2,12 +2,13 @@ from . import vector
from ._vectors import VecNumLike from ._vectors import VecNumLike
from . import variable, NumLike from . import variable, NumLike
from typing import TypeVar, Any, overload, Callable from typing import TypeVar, Any, overload, Callable
from ._basic_types import add_op from ._basic_types import add_op, unifloat
import math import math
T = TypeVar("T", int, float, variable[int], variable[float]) T = TypeVar("T", int, float, variable[int], variable[float])
U = TypeVar("U", int, float) U = TypeVar("U", int, float)
@overload @overload
def exp(x: float | int) -> float: ... def exp(x: float | int) -> float: ...
@overload @overload
@ -284,7 +285,9 @@ def get_42(x: NumLike) -> variable[float] | float:
def abs(x: U) -> U: ... def abs(x: U) -> U: ...
@overload @overload
def abs(x: variable[U]) -> variable[U]: ... def abs(x: variable[U]) -> variable[U]: ...
def abs(x: U | variable[U]) -> Any: @overload
def abs(x: vector[U]) -> vector[U]: ...
def abs(x: U | variable[U] | vector[U]) -> Any:
"""Absolute value function """Absolute value function
Arguments: Arguments:
@ -293,16 +296,18 @@ def abs(x: U | variable[U]) -> Any:
Returns: Returns:
Absolute value of x Absolute value of x
""" """
#tt = -x * (x < 0)
ret = (x < 0) * -x + (x >= 0) * x ret = (x < 0) * -x + (x >= 0) * x
return ret # REMpyright: ignore[reportReturnType] return ret # REMpyright: ignore[reportReturnType]
#TODO: Add vector support
@overload @overload
def sign(x: U) -> U: ... def sign(x: U) -> U: ...
@overload @overload
def sign(x: variable[U]) -> variable[U]: ... def sign(x: variable[U]) -> variable[U]: ...
def sign(x: U | variable[U]) -> Any: @overload
def sign(x: vector[U]) -> vector[U]: ...
def sign(x: U | variable[U] | vector[U]) -> Any:
"""Return 1 for positive numbers and -1 for negative numbers. """Return 1 for positive numbers and -1 for negative numbers.
For an input of 0 the return value is 0. For an input of 0 the return value is 0.
@ -384,16 +389,16 @@ def max(x: U | variable[U], y: U | variable[U]) -> Any:
@overload @overload
def lerp(v1: variable[U], v2: U | variable[U], t: U | variable[U]) -> variable[U]: ... def lerp(v1: variable[U], v2: U | variable[U], t: unifloat) -> variable[U]: ...
@overload @overload
def lerp(v1: U | variable[U], v2: variable[U], t: U | variable[U]) -> variable[U]: ... def lerp(v1: U | variable[U], v2: variable[U], t: unifloat) -> variable[U]: ...
@overload @overload
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[U]) -> variable[U]: ... def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[float]) -> variable[U]: ...
@overload @overload
def lerp(v1: U, v2: U, t: U) -> U: ... def lerp(v1: U, v2: U, t: float) -> U: ...
@overload @overload
def lerp(v1: vector[U], v2: vector[U], t: 'U | variable[U]') -> vector[U]: ... def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U | variable[U]) -> Any: def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: unifloat) -> Any:
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t.""" """Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
if isinstance(v1, vector) or isinstance(v2, vector): if isinstance(v1, vector) or isinstance(v2, vector):
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors." assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
@ -402,12 +407,13 @@ def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U
return v1 * (1 - t) + v2 * t return v1 * (1 - t) + v2 * t
#TODO: Add vector support
@overload @overload
def relu(x: U) -> U: ... def relu(x: U) -> U: ...
@overload @overload
def relu(x: variable[U]) -> variable[U]: ... def relu(x: variable[U]) -> variable[U]: ...
def relu(x: U | variable[U]) -> Any: @overload
def relu(x: vector[U]) -> vector[U]: ...
def relu(x: U | variable[U] | vector[U]) -> Any:
"""Returns x for x > 0 and otherwise 0.""" """Returns x for x > 0 and otherwise 0."""
ret = (x > 0) * x ret = (x > 0) * x
return ret return ret

View File

@ -1,19 +1,19 @@
from . import variable from . import variable
from ._vectors import vector from ._vectors import vector
from ._mixed import mixed_sum from ._mixed import mixed_sum
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
from ._helper_types import TNum
MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float' MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float'
MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int' MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int'
MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float' MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float'
TT = TypeVar("TT", int, float)
U = TypeVar("U", int, float) U = TypeVar("U", int, float)
class matrix(Generic[TT]): class matrix(Generic[TNum]):
"""Mathematical matrix class supporting basic operations and interactions with variables. """Mathematical matrix class supporting basic operations and interactions with variables.
""" """
def __init__(self, values: Iterable[Iterable[TT | variable[TT]]]): def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
"""Create a matrix with given values and variables. """Create a matrix with given values and variables.
Args: Args:
@ -23,7 +23,7 @@ class matrix(Generic[TT]):
if rows: if rows:
row_len = len(rows[0]) row_len = len(rows[0])
assert all(len(row) == row_len for row in rows), "All rows must have the same length" assert all(len(row) == row_len for row in rows), "All rows must have the same length"
self.values: tuple[tuple[variable[TT] | TT, ...], ...] = tuple(rows) self.values: tuple[tuple[variable[TNum] | TNum, ...], ...] = tuple(rows)
self.rows = len(self.values) self.rows = len(self.values)
self.cols = len(self.values[0]) if self.values else 0 self.cols = len(self.values[0]) if self.values else 0
@ -33,13 +33,13 @@ class matrix(Generic[TT]):
def __len__(self) -> int: def __len__(self) -> int:
return self.rows return self.rows
def __getitem__(self, index: int) -> tuple[variable[TT] | TT, ...]: def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
return self.values[index] return self.values[index]
def __iter__(self) -> Iterator[tuple[variable[TT] | TT, ...]]: def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values) return iter(self.values)
def __neg__(self) -> 'matrix[TT]': def __neg__(self) -> 'matrix[TNum]':
return matrix((-a for a in row) for row in self.values) return matrix((-a for a in row) for row in self.values)
@overload @overload
@ -165,10 +165,10 @@ class matrix(Generic[TT]):
) )
@overload @overload
def __matmul__(self: 'matrix[TT]', other: 'vector[TT]') -> 'vector[TT]': ... def __matmul__(self: 'matrix[TNum]', other: 'vector[TNum]') -> 'vector[TNum]': ...
@overload @overload
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT]') -> 'matrix[TT]': ... def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum]') -> 'matrix[TNum]': ...
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT] | vector[TT]') -> 'matrix[TT] | vector[TT]': def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum] | vector[TNum]') -> 'matrix[TNum] | vector[TNum]':
"""Matrix multiplication using @ operator""" """Matrix multiplication using @ operator"""
if isinstance(other, vector): if isinstance(other, vector):
assert self.cols == len(other.values), \ assert self.cols == len(other.values), \
@ -179,9 +179,9 @@ class matrix(Generic[TT]):
assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}" assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}"
assert self.cols == other.rows, \ assert self.cols == other.rows, \
f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})" f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})"
result: list[list[TT | variable[TT]]] = [] result: list[list[TNum | variable[TNum]]] = []
for row in self.values: for row in self.values:
new_row: list[TT | variable[TT]] = [] new_row: list[TNum | variable[TNum]] = []
for col_idx in range(other.cols): for col_idx in range(other.cols):
col = tuple(other.values[i][col_idx] for i in range(other.rows)) col = tuple(other.values[i][col_idx] for i in range(other.rows))
element = sum(a * b for a, b in zip(row, col)) element = sum(a * b for a, b in zip(row, col))
@ -189,7 +189,7 @@ class matrix(Generic[TT]):
result.append(new_row) result.append(new_row)
return matrix(result) return matrix(result)
def transpose(self) -> 'matrix[TT]': def transpose(self) -> 'matrix[TNum]':
"""Return the transpose of the matrix.""" """Return the transpose of the matrix."""
if not self.values: if not self.values:
return matrix([]) return matrix([])
@ -199,21 +199,21 @@ class matrix(Generic[TT]):
) )
@property @property
def T(self) -> 'matrix[TT]': def T(self) -> 'matrix[TNum]':
return self.transpose() return self.transpose()
def row(self, index: int) -> vector[TT]: def row(self, index: int) -> vector[TNum]:
"""Get a row as a vector.""" """Get a row as a vector."""
assert 0 <= index < self.rows, f"Row index {index} out of bounds" assert 0 <= index < self.rows, f"Row index {index} out of bounds"
return vector(self.values[index]) return vector(self.values[index])
def col(self, index: int) -> vector[TT]: def col(self, index: int) -> vector[TNum]:
"""Get a column as a vector.""" """Get a column as a vector."""
assert 0 <= index < self.cols, f"Column index {index} out of bounds" assert 0 <= index < self.cols, f"Column index {index} out of bounds"
return vector(self.values[i][index] for i in range(self.rows)) return vector(self.values[i][index] for i in range(self.rows))
@overload @overload
def trace(self: 'matrix[TT]') -> TT | variable[TT]: ... def trace(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
@overload @overload
def trace(self: 'matrix[int]') -> int | variable[int]: ... def trace(self: 'matrix[int]') -> int | variable[int]: ...
@overload @overload
@ -224,7 +224,7 @@ class matrix(Generic[TT]):
return mixed_sum(self.values[i][i] for i in range(self.rows)) return mixed_sum(self.values[i][i] for i in range(self.rows))
@overload @overload
def sum(self: 'matrix[TT]') -> TT | variable[TT]: ... def sum(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
@overload @overload
def sum(self: 'matrix[int]') -> int | variable[int]: ... def sum(self: 'matrix[int]') -> int | variable[int]: ...
@overload @overload
@ -240,7 +240,7 @@ class matrix(Generic[TT]):
for row in self.values for row in self.values
) )
def homogenize(self) -> 'matrix[TT]': def homogenize(self) -> 'matrix[TNum]':
"""Convert all elements to variables if any element is a variable.""" """Convert all elements to variables if any element is a variable."""
if any(isinstance(val, variable) for row in self.values for val in row): if any(isinstance(val, variable) for row in self.values for val in row):
return matrix( return matrix(

View File

@ -1,27 +1,28 @@
from . import variable from . import variable
from ._mixed import mixed_sum, mixed_homogenize from ._mixed import mixed_sum, mixed_homogenize
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
import copapy as cp import copapy as cp
from ._helper_types import TNum
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool' #VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
VecNumLike: TypeAlias = 'vector[Any] | variable[Any] | int | float | bool'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int' VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float' VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float)
U = TypeVar("U", int, float) U = TypeVar("U", int, float)
epsilon = 1e-20 epsilon = 1e-20
class vector(Generic[T]): class vector(Generic[TNum]):
"""Mathematical vector class supporting basic operations and interactions with variables. """Mathematical vector class supporting basic operations and interactions with variables.
""" """
def __init__(self, values: Iterable[T | variable[T]]): def __init__(self, values: Iterable[TNum | variable[TNum]]):
"""Create a vector with given values and variables. """Create a vector with given values and variables.
Args: Args:
values: iterable of constant values and variables values: iterable of constant values and variables
""" """
self.values: tuple[variable[T] | T, ...] = tuple(values) self.values: tuple[variable[TNum] | TNum, ...] = tuple(values)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"vector({self.values})" return f"vector({self.values})"
@ -29,13 +30,13 @@ class vector(Generic[T]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.values) return len(self.values)
def __getitem__(self, index: int) -> variable[T] | T: def __getitem__(self, index: int) -> variable[TNum] | TNum:
return self.values[index] return self.values[index]
def __neg__(self) -> 'vector[T]': def __neg__(self) -> 'vector[TNum]':
return vector(-a for a in self.values) return vector(-a for a in self.values)
def __iter__(self) -> Iterator[variable[T] | T]: def __iter__(self) -> Iterator[variable[TNum] | TNum]:
return iter(self.values) return iter(self.values)
@overload @overload
@ -56,6 +57,8 @@ class vector(Generic[T]):
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ... def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload @overload
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ... def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __radd__(self, other: VecNumLike) -> 'vector[Any]': ...
def __radd__(self, other: Any) -> Any: def __radd__(self, other: Any) -> Any:
return self + other return self + other
@ -77,6 +80,8 @@ class vector(Generic[T]):
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ... def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload @overload
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ... def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __rsub__(self, other: VecNumLike) -> 'vector[Any]': ...
def __rsub__(self, other: VecNumLike) -> Any: def __rsub__(self, other: VecNumLike) -> Any:
if isinstance(other, vector): if isinstance(other, vector):
assert len(self.values) == len(other.values) assert len(self.values) == len(other.values)
@ -101,6 +106,8 @@ class vector(Generic[T]):
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ... def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload @overload
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ... def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __rmul__(self, other: VecNumLike) -> 'vector[Any]': ...
def __rmul__(self, other: VecNumLike) -> Any: def __rmul__(self, other: VecNumLike) -> Any:
return self * other return self * other
@ -151,6 +158,42 @@ class vector(Generic[T]):
a1 * b2 - a2 * b1 a1 * b2 - a2 * b1
]) ])
def __gt__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a > b for a, b in zip(self.values, other.values))
return vector(a > other for a in self.values)
def __lt__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a < b for a, b in zip(self.values, other.values))
return vector(a < other for a in self.values)
def __ge__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a >= b for a, b in zip(self.values, other.values))
return vector(a >= other for a in self.values)
def __le__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a <= b for a, b in zip(self.values, other.values))
return vector(a <= other for a in self.values)
def __eq__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a == b for a, b in zip(self.values, other.values))
return vector(a == other for a in self.values)
def __ne__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a != b for a, b in zip(self.values, other.values))
return vector(a != other for a in self.values)
@overload @overload
def sum(self: 'vector[int]') -> int | variable[int]: ... def sum(self: 'vector[int]') -> int | variable[int]: ...
@overload @overload
@ -169,7 +212,7 @@ class vector(Generic[T]):
mag = self.magnitude() + epsilon mag = self.magnitude() + epsilon
return self / mag return self / mag
def homogenize(self) -> 'vector[T]': def homogenize(self) -> 'vector[TNum]':
if any(isinstance(val, variable) for val in self.values): if any(isinstance(val, variable) for val in self.values):
return vector(mixed_homogenize(self)) return vector(mixed_homogenize(self))
else: else: