mirror of https://github.com/Nonannet/copapy.git
type hints revised
This commit is contained in:
parent
61dc29e68b
commit
ebb4abc5d3
|
|
@ -1,4 +1,4 @@
|
|||
from . import variable, vector
|
||||
from . import variable, vector, matrix
|
||||
import copapy.backend as cpb
|
||||
from typing import Any, Sequence, overload
|
||||
import copapy as cp
|
||||
|
|
@ -6,13 +6,25 @@ from ._basic_types import Net, unifloat
|
|||
|
||||
|
||||
@overload
|
||||
def grad(var: variable[Any], to: variable[Any]) -> unifloat: ...
|
||||
def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
|
||||
@overload
|
||||
def grad(var: variable[Any], to: Sequence[variable[Any]]) -> Sequence[unifloat]: ...
|
||||
def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
|
||||
@overload
|
||||
def grad(var: variable[Any], to: vector[Any]) -> vector[float]: ...
|
||||
def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vector[Any]) -> unifloat | Sequence[unifloat] | vector[float]:
|
||||
edges = cpb.get_all_dag_edges([var.source])
|
||||
def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
|
||||
@overload
|
||||
def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
|
||||
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
|
||||
"""Returns the partial derivative dx/dy where x needs to be a scalar
|
||||
and y might be a scalar, a list of scalars, a vector or matrix.
|
||||
|
||||
Arguments:
|
||||
x: Value to return derivative of
|
||||
y: Value(s) to derive in respect to
|
||||
|
||||
Returns:
|
||||
Derivative of x with the type and dimensions of y
|
||||
"""
|
||||
edges = cpb.get_all_dag_edges([x.source])
|
||||
ordered_ops = cpb.stable_toposort(edges)
|
||||
|
||||
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
||||
|
|
@ -24,79 +36,81 @@ def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vecto
|
|||
for node in reversed(ordered_ops):
|
||||
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
||||
if node.args:
|
||||
args: Sequence[Any] = [v for v in node.args]
|
||||
g = 1.0 if node is var.source else grad_dict[net_lookup[node]]
|
||||
args: Sequence[Any] = list(node.args)
|
||||
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
|
||||
opn = node.name.split('_')[0]
|
||||
x: variable[Any] = args[0]
|
||||
y: variable[Any] = args[1] if len(args) > 1 else x
|
||||
a: variable[Any] = args[0]
|
||||
b: variable[Any] = args[1] if len(args) > 1 else a
|
||||
|
||||
if opn in ['ge', 'gt', 'eq', 'ne']:
|
||||
pass # Derivative is 0
|
||||
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
|
||||
pass # Derivative is 0 for all ops returning integers
|
||||
|
||||
elif opn == 'add':
|
||||
add_grad(x, g)
|
||||
add_grad(y, g)
|
||||
add_grad(a, g)
|
||||
add_grad(b, g)
|
||||
|
||||
elif opn == 'sub':
|
||||
add_grad(x, g)
|
||||
add_grad(y, -g)
|
||||
add_grad(a, g)
|
||||
add_grad(b, -g)
|
||||
|
||||
elif opn == 'mul':
|
||||
add_grad(x, y * g)
|
||||
add_grad(y, x * g)
|
||||
add_grad(a, b * g)
|
||||
add_grad(b, a * g)
|
||||
|
||||
elif opn == 'div':
|
||||
add_grad(x, g / y)
|
||||
add_grad(y, -x * g / (y**2))
|
||||
elif opn == 'div':
|
||||
add_grad(a, g / b)
|
||||
add_grad(b, -a * g / (b**2))
|
||||
|
||||
elif opn == 'pow':
|
||||
add_grad(x, (y * (x ** (y - 1))) * g)
|
||||
add_grad(y, (x ** y * cp.log(x)) * g)
|
||||
|
||||
elif opn == 'sqrt':
|
||||
add_grad(x, g * (0.5 / cp.sqrt(x)))
|
||||
|
||||
elif opn == 'abs':
|
||||
add_grad(x, g * cp.sign(x))
|
||||
|
||||
elif opn == 'sin':
|
||||
add_grad(x, g * cp.cos(x))
|
||||
|
||||
elif opn == 'cos':
|
||||
add_grad(x, g * -cp.sin(x))
|
||||
|
||||
elif opn == 'tan':
|
||||
add_grad(x, g * (1 / cp.cos(x) ** 2))
|
||||
|
||||
elif opn == 'asin':
|
||||
add_grad(x, g * (1 / cp.sqrt(1 - x**2)))
|
||||
|
||||
elif opn == 'acos':
|
||||
add_grad(x, g * (-1 / cp.sqrt(1 - x**2)))
|
||||
|
||||
elif opn == 'atan':
|
||||
add_grad(x, g * (1 / (1 + x**2)))
|
||||
|
||||
elif opn == 'atan2':
|
||||
denom = x**2 + y**2
|
||||
add_grad(x, g * (-y / denom))
|
||||
add_grad(y, g * ( x / denom))
|
||||
elif opn == 'mod':
|
||||
add_grad(a, g)
|
||||
add_grad(b, -a * g / b)
|
||||
|
||||
elif opn == 'log':
|
||||
add_grad(x, g / x)
|
||||
add_grad(a, g / a)
|
||||
|
||||
elif opn == 'exp':
|
||||
add_grad(x, g * cp.exp(x))
|
||||
add_grad(a, g * cp.exp(a))
|
||||
|
||||
elif opn == 'gt':
|
||||
add_grad(x, g)
|
||||
add_grad(y, -g)
|
||||
elif opn == 'pow':
|
||||
add_grad(a, (b * (a ** (b - 1))) * g)
|
||||
add_grad(b, (a ** b * cp.log(a)) * g)
|
||||
|
||||
elif opn == 'sqrt':
|
||||
add_grad(a, g * (0.5 / cp.sqrt(a)))
|
||||
|
||||
#elif opn == 'abs':
|
||||
# add_grad(x, g * cp.sign(x))
|
||||
|
||||
elif opn == 'sin':
|
||||
add_grad(a, g * cp.cos(a))
|
||||
|
||||
elif opn == 'cos':
|
||||
add_grad(a, g * -cp.sin(a))
|
||||
|
||||
elif opn == 'tan':
|
||||
add_grad(a, g * (1 / cp.cos(a) ** 2))
|
||||
|
||||
elif opn == 'asin':
|
||||
add_grad(a, g * (1 / cp.sqrt(1 - a**2)))
|
||||
|
||||
elif opn == 'acos':
|
||||
add_grad(a, g * (-1 / cp.sqrt(1 - a**2)))
|
||||
|
||||
elif opn == 'atan':
|
||||
add_grad(a, g * (1 / (1 + a**2)))
|
||||
|
||||
elif opn == 'atan2':
|
||||
denom = a**2 + b**2
|
||||
add_grad(a, g * (-b / denom))
|
||||
add_grad(b, g * ( a / denom))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
|
||||
|
||||
if isinstance(to, variable):
|
||||
return grad_dict[to]
|
||||
if isinstance(to, vector):
|
||||
return vector(grad_dict[dvar] for dvar in to)
|
||||
return [grad_dict[dvar] for dvar in to]
|
||||
if isinstance(y, variable):
|
||||
return grad_dict[y]
|
||||
if isinstance(y, vector):
|
||||
return vector(grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in y)
|
||||
if isinstance(y, matrix):
|
||||
return matrix((grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in row) for row in y)
|
||||
return [grad_dict[yi] for yi in y]
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ import pkgutil
|
|||
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
|
||||
from ._stencils import stencil_database, detect_process_arch
|
||||
import copapy as cp
|
||||
from ._helper_types import TNum
|
||||
|
||||
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
|
||||
unifloat: TypeAlias = 'variable[float] | float'
|
||||
uniint: TypeAlias = 'variable[int] | int'
|
||||
|
||||
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
|
||||
TNum = TypeVar("TNum", int, float)
|
||||
TVarNumb: TypeAlias = 'variable[Any] | int | float'
|
||||
|
||||
stencil_cache: dict[tuple[str, str], stencil_database] = {}
|
||||
|
|
@ -312,7 +312,7 @@ class variable(Generic[TNum], Net):
|
|||
return add_op('bwand', [self, other], True)
|
||||
|
||||
def __rand__(self, other: uniint) -> 'variable[int]':
|
||||
return add_op('rwand', [other, self], True)
|
||||
return add_op('bwand', [other, self], True)
|
||||
|
||||
def __or__(self, other: uniint) -> 'variable[int]':
|
||||
return add_op('bwor', [self, other], True)
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non
|
|||
Tuples of (source_node, target_node) representing edges in the DAG
|
||||
"""
|
||||
emitted_nodes: set[tuple[Node, Node]] = set()
|
||||
|
||||
|
||||
for node in nodes:
|
||||
yield from get_all_dag_edges(net.source for net in node.args)
|
||||
for net in node.args:
|
||||
|
|
@ -138,7 +138,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
|
|||
read_back_nets = {
|
||||
net for net, node in net_node_list
|
||||
if net and node.name.startswith('read_')}
|
||||
|
||||
|
||||
registers: list[Net | None] = [None, None]
|
||||
|
||||
for net, node in net_node_list:
|
||||
|
|
@ -253,7 +253,7 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
|
|||
alignment = sdb.get_section_alignment(index)
|
||||
offset = (offset + alignment - 1) // alignment * alignment
|
||||
section_list.append((index, offset, lengths))
|
||||
section_cache[index] = offset
|
||||
section_cache[index] = offset
|
||||
function_lookup[name] = offset + sdb.get_symbol_offset(name)
|
||||
offset += lengths
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
from typing import TypeVar
|
||||
|
||||
TNum = TypeVar("TNum", int, float)
|
||||
U = TypeVar("U", int, float)
|
||||
|
|
@ -2,12 +2,13 @@ from . import vector
|
|||
from ._vectors import VecNumLike
|
||||
from . import variable, NumLike
|
||||
from typing import TypeVar, Any, overload, Callable
|
||||
from ._basic_types import add_op
|
||||
from ._basic_types import add_op, unifloat
|
||||
import math
|
||||
|
||||
T = TypeVar("T", int, float, variable[int], variable[float])
|
||||
U = TypeVar("U", int, float)
|
||||
|
||||
|
||||
@overload
|
||||
def exp(x: float | int) -> float: ...
|
||||
@overload
|
||||
|
|
@ -284,7 +285,9 @@ def get_42(x: NumLike) -> variable[float] | float:
|
|||
def abs(x: U) -> U: ...
|
||||
@overload
|
||||
def abs(x: variable[U]) -> variable[U]: ...
|
||||
def abs(x: U | variable[U]) -> Any:
|
||||
@overload
|
||||
def abs(x: vector[U]) -> vector[U]: ...
|
||||
def abs(x: U | variable[U] | vector[U]) -> Any:
|
||||
"""Absolute value function
|
||||
|
||||
Arguments:
|
||||
|
|
@ -293,16 +296,18 @@ def abs(x: U | variable[U]) -> Any:
|
|||
Returns:
|
||||
Absolute value of x
|
||||
"""
|
||||
#tt = -x * (x < 0)
|
||||
ret = (x < 0) * -x + (x >= 0) * x
|
||||
return ret # REMpyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
#TODO: Add vector support
|
||||
@overload
|
||||
def sign(x: U) -> U: ...
|
||||
@overload
|
||||
def sign(x: variable[U]) -> variable[U]: ...
|
||||
def sign(x: U | variable[U]) -> Any:
|
||||
@overload
|
||||
def sign(x: vector[U]) -> vector[U]: ...
|
||||
def sign(x: U | variable[U] | vector[U]) -> Any:
|
||||
"""Return 1 for positive numbers and -1 for negative numbers.
|
||||
For an input of 0 the return value is 0.
|
||||
|
||||
|
|
@ -333,13 +338,13 @@ def clamp(x: U | variable[U] | vector[U], min_value: U | variable[U], max_value:
|
|||
x: Input value
|
||||
min_value: Minimum limit
|
||||
max_value: Maximum limit
|
||||
|
||||
|
||||
Returns:
|
||||
Clamped value of x
|
||||
"""
|
||||
if isinstance(x, vector):
|
||||
return vector(clamp(comp, min_value, max_value) for comp in x.values)
|
||||
|
||||
|
||||
return (x < min_value) * min_value + \
|
||||
(x > max_value) * max_value + \
|
||||
((x >= min_value) & (x <= max_value)) * x
|
||||
|
|
@ -384,16 +389,16 @@ def max(x: U | variable[U], y: U | variable[U]) -> Any:
|
|||
|
||||
|
||||
@overload
|
||||
def lerp(v1: variable[U], v2: U | variable[U], t: U | variable[U]) -> variable[U]: ...
|
||||
def lerp(v1: variable[U], v2: U | variable[U], t: unifloat) -> variable[U]: ...
|
||||
@overload
|
||||
def lerp(v1: U | variable[U], v2: variable[U], t: U | variable[U]) -> variable[U]: ...
|
||||
def lerp(v1: U | variable[U], v2: variable[U], t: unifloat) -> variable[U]: ...
|
||||
@overload
|
||||
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[U]) -> variable[U]: ...
|
||||
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[float]) -> variable[U]: ...
|
||||
@overload
|
||||
def lerp(v1: U, v2: U, t: U) -> U: ...
|
||||
def lerp(v1: U, v2: U, t: float) -> U: ...
|
||||
@overload
|
||||
def lerp(v1: vector[U], v2: vector[U], t: 'U | variable[U]') -> vector[U]: ...
|
||||
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U | variable[U]) -> Any:
|
||||
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
|
||||
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: unifloat) -> Any:
|
||||
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
|
||||
if isinstance(v1, vector) or isinstance(v2, vector):
|
||||
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
|
||||
|
|
@ -402,12 +407,13 @@ def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U
|
|||
return v1 * (1 - t) + v2 * t
|
||||
|
||||
|
||||
#TODO: Add vector support
|
||||
@overload
|
||||
def relu(x: U) -> U: ...
|
||||
@overload
|
||||
def relu(x: variable[U]) -> variable[U]: ...
|
||||
def relu(x: U | variable[U]) -> Any:
|
||||
@overload
|
||||
def relu(x: vector[U]) -> vector[U]: ...
|
||||
def relu(x: U | variable[U] | vector[U]) -> Any:
|
||||
"""Returns x for x > 0 and otherwise 0."""
|
||||
ret = (x > 0) * x
|
||||
return ret
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
from . import variable
|
||||
from ._vectors import vector
|
||||
from ._mixed import mixed_sum
|
||||
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
|
||||
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
|
||||
from ._helper_types import TNum
|
||||
|
||||
MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float'
|
||||
MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int'
|
||||
MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float'
|
||||
TT = TypeVar("TT", int, float)
|
||||
U = TypeVar("U", int, float)
|
||||
|
||||
|
||||
class matrix(Generic[TT]):
|
||||
class matrix(Generic[TNum]):
|
||||
"""Mathematical matrix class supporting basic operations and interactions with variables.
|
||||
"""
|
||||
def __init__(self, values: Iterable[Iterable[TT | variable[TT]]]):
|
||||
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
|
||||
"""Create a matrix with given values and variables.
|
||||
|
||||
Args:
|
||||
|
|
@ -23,7 +23,7 @@ class matrix(Generic[TT]):
|
|||
if rows:
|
||||
row_len = len(rows[0])
|
||||
assert all(len(row) == row_len for row in rows), "All rows must have the same length"
|
||||
self.values: tuple[tuple[variable[TT] | TT, ...], ...] = tuple(rows)
|
||||
self.values: tuple[tuple[variable[TNum] | TNum, ...], ...] = tuple(rows)
|
||||
self.rows = len(self.values)
|
||||
self.cols = len(self.values[0]) if self.values else 0
|
||||
|
||||
|
|
@ -33,13 +33,13 @@ class matrix(Generic[TT]):
|
|||
def __len__(self) -> int:
|
||||
return self.rows
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[variable[TT] | TT, ...]:
|
||||
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
|
||||
return self.values[index]
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[variable[TT] | TT, ...]]:
|
||||
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
||||
return iter(self.values)
|
||||
|
||||
def __neg__(self) -> 'matrix[TT]':
|
||||
def __neg__(self) -> 'matrix[TNum]':
|
||||
return matrix((-a for a in row) for row in self.values)
|
||||
|
||||
@overload
|
||||
|
|
@ -165,23 +165,23 @@ class matrix(Generic[TT]):
|
|||
)
|
||||
|
||||
@overload
|
||||
def __matmul__(self: 'matrix[TT]', other: 'vector[TT]') -> 'vector[TT]': ...
|
||||
def __matmul__(self: 'matrix[TNum]', other: 'vector[TNum]') -> 'vector[TNum]': ...
|
||||
@overload
|
||||
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT]') -> 'matrix[TT]': ...
|
||||
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT] | vector[TT]') -> 'matrix[TT] | vector[TT]':
|
||||
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum]') -> 'matrix[TNum]': ...
|
||||
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum] | vector[TNum]') -> 'matrix[TNum] | vector[TNum]':
|
||||
"""Matrix multiplication using @ operator"""
|
||||
if isinstance(other, vector):
|
||||
assert self.cols == len(other.values), \
|
||||
f"Matrix columns ({self.cols}) must match vector length ({len(other.values)})"
|
||||
vec_result = (mixed_sum(a * b for a, b in zip(row, other.values)) for row in self.values)
|
||||
return vector(vec_result)
|
||||
else:
|
||||
else:
|
||||
assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}"
|
||||
assert self.cols == other.rows, \
|
||||
f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})"
|
||||
result: list[list[TT | variable[TT]]] = []
|
||||
result: list[list[TNum | variable[TNum]]] = []
|
||||
for row in self.values:
|
||||
new_row: list[TT | variable[TT]] = []
|
||||
new_row: list[TNum | variable[TNum]] = []
|
||||
for col_idx in range(other.cols):
|
||||
col = tuple(other.values[i][col_idx] for i in range(other.rows))
|
||||
element = sum(a * b for a, b in zip(row, col))
|
||||
|
|
@ -189,7 +189,7 @@ class matrix(Generic[TT]):
|
|||
result.append(new_row)
|
||||
return matrix(result)
|
||||
|
||||
def transpose(self) -> 'matrix[TT]':
|
||||
def transpose(self) -> 'matrix[TNum]':
|
||||
"""Return the transpose of the matrix."""
|
||||
if not self.values:
|
||||
return matrix([])
|
||||
|
|
@ -197,23 +197,23 @@ class matrix(Generic[TT]):
|
|||
tuple(self.values[i][j] for i in range(self.rows))
|
||||
for j in range(self.cols)
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def T(self) -> 'matrix[TT]':
|
||||
def T(self) -> 'matrix[TNum]':
|
||||
return self.transpose()
|
||||
|
||||
def row(self, index: int) -> vector[TT]:
|
||||
def row(self, index: int) -> vector[TNum]:
|
||||
"""Get a row as a vector."""
|
||||
assert 0 <= index < self.rows, f"Row index {index} out of bounds"
|
||||
return vector(self.values[index])
|
||||
|
||||
def col(self, index: int) -> vector[TT]:
|
||||
def col(self, index: int) -> vector[TNum]:
|
||||
"""Get a column as a vector."""
|
||||
assert 0 <= index < self.cols, f"Column index {index} out of bounds"
|
||||
return vector(self.values[i][index] for i in range(self.rows))
|
||||
|
||||
@overload
|
||||
def trace(self: 'matrix[TT]') -> TT | variable[TT]: ...
|
||||
def trace(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
|
||||
@overload
|
||||
def trace(self: 'matrix[int]') -> int | variable[int]: ...
|
||||
@overload
|
||||
|
|
@ -224,7 +224,7 @@ class matrix(Generic[TT]):
|
|||
return mixed_sum(self.values[i][i] for i in range(self.rows))
|
||||
|
||||
@overload
|
||||
def sum(self: 'matrix[TT]') -> TT | variable[TT]: ...
|
||||
def sum(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
|
||||
@overload
|
||||
def sum(self: 'matrix[int]') -> int | variable[int]: ...
|
||||
@overload
|
||||
|
|
@ -240,7 +240,7 @@ class matrix(Generic[TT]):
|
|||
for row in self.values
|
||||
)
|
||||
|
||||
def homogenize(self) -> 'matrix[TT]':
|
||||
def homogenize(self) -> 'matrix[TNum]':
|
||||
"""Convert all elements to variables if any element is a variable."""
|
||||
if any(isinstance(val, variable) for row in self.values for val in row):
|
||||
return matrix(
|
||||
|
|
|
|||
|
|
@ -1,27 +1,28 @@
|
|||
from . import variable
|
||||
from ._mixed import mixed_sum, mixed_homogenize
|
||||
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
|
||||
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
|
||||
import copapy as cp
|
||||
from ._helper_types import TNum
|
||||
|
||||
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
|
||||
#VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
|
||||
VecNumLike: TypeAlias = 'vector[Any] | variable[Any] | int | float | bool'
|
||||
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
|
||||
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
||||
T = TypeVar("T", int, float)
|
||||
U = TypeVar("U", int, float)
|
||||
|
||||
epsilon = 1e-20
|
||||
|
||||
|
||||
class vector(Generic[T]):
|
||||
class vector(Generic[TNum]):
|
||||
"""Mathematical vector class supporting basic operations and interactions with variables.
|
||||
"""
|
||||
def __init__(self, values: Iterable[T | variable[T]]):
|
||||
def __init__(self, values: Iterable[TNum | variable[TNum]]):
|
||||
"""Create a vector with given values and variables.
|
||||
|
||||
Args:
|
||||
values: iterable of constant values and variables
|
||||
"""
|
||||
self.values: tuple[variable[T] | T, ...] = tuple(values)
|
||||
self.values: tuple[variable[TNum] | TNum, ...] = tuple(values)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"vector({self.values})"
|
||||
|
|
@ -29,13 +30,13 @@ class vector(Generic[T]):
|
|||
def __len__(self) -> int:
|
||||
return len(self.values)
|
||||
|
||||
def __getitem__(self, index: int) -> variable[T] | T:
|
||||
def __getitem__(self, index: int) -> variable[TNum] | TNum:
|
||||
return self.values[index]
|
||||
|
||||
def __neg__(self) -> 'vector[T]':
|
||||
def __neg__(self) -> 'vector[TNum]':
|
||||
return vector(-a for a in self.values)
|
||||
|
||||
def __iter__(self) -> Iterator[variable[T] | T]:
|
||||
def __iter__(self) -> Iterator[variable[TNum] | TNum]:
|
||||
return iter(self.values)
|
||||
|
||||
@overload
|
||||
|
|
@ -56,6 +57,8 @@ class vector(Generic[T]):
|
|||
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||
@overload
|
||||
def __radd__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||
def __radd__(self, other: Any) -> Any:
|
||||
return self + other
|
||||
|
||||
|
|
@ -77,6 +80,8 @@ class vector(Generic[T]):
|
|||
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||
@overload
|
||||
def __rsub__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||
def __rsub__(self, other: VecNumLike) -> Any:
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
|
|
@ -101,6 +106,8 @@ class vector(Generic[T]):
|
|||
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||
@overload
|
||||
def __rmul__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||
def __rmul__(self, other: VecNumLike) -> Any:
|
||||
return self * other
|
||||
|
||||
|
|
@ -150,6 +157,42 @@ class vector(Generic[T]):
|
|||
a3 * b1 - a1 * b3,
|
||||
a1 * b2 - a2 * b1
|
||||
])
|
||||
|
||||
def __gt__(self, other: VecNumLike) -> 'vector[int]':
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a > b for a, b in zip(self.values, other.values))
|
||||
return vector(a > other for a in self.values)
|
||||
|
||||
def __lt__(self, other: VecNumLike) -> 'vector[int]':
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a < b for a, b in zip(self.values, other.values))
|
||||
return vector(a < other for a in self.values)
|
||||
|
||||
def __ge__(self, other: VecNumLike) -> 'vector[int]':
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a >= b for a, b in zip(self.values, other.values))
|
||||
return vector(a >= other for a in self.values)
|
||||
|
||||
def __le__(self, other: VecNumLike) -> 'vector[int]':
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a <= b for a, b in zip(self.values, other.values))
|
||||
return vector(a <= other for a in self.values)
|
||||
|
||||
def __eq__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a == b for a, b in zip(self.values, other.values))
|
||||
return vector(a == other for a in self.values)
|
||||
|
||||
def __ne__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a != b for a, b in zip(self.values, other.values))
|
||||
return vector(a != other for a in self.values)
|
||||
|
||||
@overload
|
||||
def sum(self: 'vector[int]') -> int | variable[int]: ...
|
||||
|
|
@ -168,8 +211,8 @@ class vector(Generic[T]):
|
|||
"""Returns a normalized (unit length) version of the vector."""
|
||||
mag = self.magnitude() + epsilon
|
||||
return self / mag
|
||||
|
||||
def homogenize(self) -> 'vector[T]':
|
||||
|
||||
def homogenize(self) -> 'vector[TNum]':
|
||||
if any(isinstance(val, variable) for val in self.values):
|
||||
return vector(mixed_homogenize(self))
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue