type hints revised

This commit is contained in:
Nicolas Kruse 2025-12-04 18:18:29 +01:00
parent 61dc29e68b
commit ebb4abc5d3
7 changed files with 182 additions and 115 deletions

View File

@ -1,4 +1,4 @@
from . import variable, vector
from . import variable, vector, matrix
import copapy.backend as cpb
from typing import Any, Sequence, overload
import copapy as cp
@ -6,13 +6,25 @@ from ._basic_types import Net, unifloat
@overload
def grad(var: variable[Any], to: variable[Any]) -> unifloat: ...
def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
@overload
def grad(var: variable[Any], to: Sequence[variable[Any]]) -> Sequence[unifloat]: ...
def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
@overload
def grad(var: variable[Any], to: vector[Any]) -> vector[float]: ...
def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vector[Any]) -> unifloat | Sequence[unifloat] | vector[float]:
edges = cpb.get_all_dag_edges([var.source])
def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
@overload
def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
"""Returns the partial derivative dx/dy where x needs to be a scalar
and y might be a scalar, a list of scalars, a vector or matrix.
Arguments:
x: Value to return derivative of
y: Value(s) to derive in respect to
Returns:
Derivative of x with the type and dimensions of y
"""
edges = cpb.get_all_dag_edges([x.source])
ordered_ops = cpb.stable_toposort(edges)
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
@ -24,79 +36,81 @@ def grad(var: variable[Any], to: variable[Any] | Sequence[variable[Any]] | vecto
for node in reversed(ordered_ops):
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
if node.args:
args: Sequence[Any] = [v for v in node.args]
g = 1.0 if node is var.source else grad_dict[net_lookup[node]]
args: Sequence[Any] = list(node.args)
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
opn = node.name.split('_')[0]
x: variable[Any] = args[0]
y: variable[Any] = args[1] if len(args) > 1 else x
a: variable[Any] = args[0]
b: variable[Any] = args[1] if len(args) > 1 else a
if opn in ['ge', 'gt', 'eq', 'ne']:
pass # Derivative is 0
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
pass # Derivative is 0 for all ops returning integers
elif opn == 'add':
add_grad(x, g)
add_grad(y, g)
add_grad(a, g)
add_grad(b, g)
elif opn == 'sub':
add_grad(x, g)
add_grad(y, -g)
add_grad(a, g)
add_grad(b, -g)
elif opn == 'mul':
add_grad(x, y * g)
add_grad(y, x * g)
add_grad(a, b * g)
add_grad(b, a * g)
elif opn == 'div':
add_grad(x, g / y)
add_grad(y, -x * g / (y**2))
elif opn == 'div':
add_grad(a, g / b)
add_grad(b, -a * g / (b**2))
elif opn == 'pow':
add_grad(x, (y * (x ** (y - 1))) * g)
add_grad(y, (x ** y * cp.log(x)) * g)
elif opn == 'sqrt':
add_grad(x, g * (0.5 / cp.sqrt(x)))
elif opn == 'abs':
add_grad(x, g * cp.sign(x))
elif opn == 'sin':
add_grad(x, g * cp.cos(x))
elif opn == 'cos':
add_grad(x, g * -cp.sin(x))
elif opn == 'tan':
add_grad(x, g * (1 / cp.cos(x) ** 2))
elif opn == 'asin':
add_grad(x, g * (1 / cp.sqrt(1 - x**2)))
elif opn == 'acos':
add_grad(x, g * (-1 / cp.sqrt(1 - x**2)))
elif opn == 'atan':
add_grad(x, g * (1 / (1 + x**2)))
elif opn == 'atan2':
denom = x**2 + y**2
add_grad(x, g * (-y / denom))
add_grad(y, g * ( x / denom))
elif opn == 'mod':
add_grad(a, g)
add_grad(b, -a * g / b)
elif opn == 'log':
add_grad(x, g / x)
add_grad(a, g / a)
elif opn == 'exp':
add_grad(x, g * cp.exp(x))
add_grad(a, g * cp.exp(a))
elif opn == 'gt':
add_grad(x, g)
add_grad(y, -g)
elif opn == 'pow':
add_grad(a, (b * (a ** (b - 1))) * g)
add_grad(b, (a ** b * cp.log(a)) * g)
elif opn == 'sqrt':
add_grad(a, g * (0.5 / cp.sqrt(a)))
#elif opn == 'abs':
# add_grad(x, g * cp.sign(x))
elif opn == 'sin':
add_grad(a, g * cp.cos(a))
elif opn == 'cos':
add_grad(a, g * -cp.sin(a))
elif opn == 'tan':
add_grad(a, g * (1 / cp.cos(a) ** 2))
elif opn == 'asin':
add_grad(a, g * (1 / cp.sqrt(1 - a**2)))
elif opn == 'acos':
add_grad(a, g * (-1 / cp.sqrt(1 - a**2)))
elif opn == 'atan':
add_grad(a, g * (1 / (1 + a**2)))
elif opn == 'atan2':
denom = a**2 + b**2
add_grad(a, g * (-b / denom))
add_grad(b, g * ( a / denom))
else:
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
if isinstance(to, variable):
return grad_dict[to]
if isinstance(to, vector):
return vector(grad_dict[dvar] for dvar in to)
return [grad_dict[dvar] for dvar in to]
if isinstance(y, variable):
return grad_dict[y]
if isinstance(y, vector):
return vector(grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in y)
if isinstance(y, matrix):
return matrix((grad_dict[yi] if isinstance(yi, variable) else 0.0 for yi in row) for row in y)
return [grad_dict[yi] for yi in y]

View File

@ -2,13 +2,13 @@ import pkgutil
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database, detect_process_arch
import copapy as cp
from ._helper_types import TNum
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int'
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, float)
TVarNumb: TypeAlias = 'variable[Any] | int | float'
stencil_cache: dict[tuple[str, str], stencil_database] = {}
@ -312,7 +312,7 @@ class variable(Generic[TNum], Net):
return add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]':
return add_op('rwand', [other, self], True)
return add_op('bwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [self, other], True)

View File

@ -65,7 +65,7 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non
Tuples of (source_node, target_node) representing edges in the DAG
"""
emitted_nodes: set[tuple[Node, Node]] = set()
for node in nodes:
yield from get_all_dag_edges(net.source for net in node.args)
for net in node.args:
@ -138,7 +138,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
read_back_nets = {
net for net, node in net_node_list
if net and node.name.startswith('read_')}
registers: list[Net | None] = [None, None]
for net, node in net_node_list:
@ -253,7 +253,7 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
alignment = sdb.get_section_alignment(index)
offset = (offset + alignment - 1) // alignment * alignment
section_list.append((index, offset, lengths))
section_cache[index] = offset
section_cache[index] = offset
function_lookup[name] = offset + sdb.get_symbol_offset(name)
offset += lengths

View File

@ -0,0 +1,4 @@
from typing import TypeVar
TNum = TypeVar("TNum", int, float)
U = TypeVar("U", int, float)

View File

@ -2,12 +2,13 @@ from . import vector
from ._vectors import VecNumLike
from . import variable, NumLike
from typing import TypeVar, Any, overload, Callable
from ._basic_types import add_op
from ._basic_types import add_op, unifloat
import math
T = TypeVar("T", int, float, variable[int], variable[float])
U = TypeVar("U", int, float)
@overload
def exp(x: float | int) -> float: ...
@overload
@ -284,7 +285,9 @@ def get_42(x: NumLike) -> variable[float] | float:
def abs(x: U) -> U: ...
@overload
def abs(x: variable[U]) -> variable[U]: ...
def abs(x: U | variable[U]) -> Any:
@overload
def abs(x: vector[U]) -> vector[U]: ...
def abs(x: U | variable[U] | vector[U]) -> Any:
"""Absolute value function
Arguments:
@ -293,16 +296,18 @@ def abs(x: U | variable[U]) -> Any:
Returns:
Absolute value of x
"""
#tt = -x * (x < 0)
ret = (x < 0) * -x + (x >= 0) * x
return ret # REMpyright: ignore[reportReturnType]
#TODO: Add vector support
@overload
def sign(x: U) -> U: ...
@overload
def sign(x: variable[U]) -> variable[U]: ...
def sign(x: U | variable[U]) -> Any:
@overload
def sign(x: vector[U]) -> vector[U]: ...
def sign(x: U | variable[U] | vector[U]) -> Any:
"""Return 1 for positive numbers and -1 for negative numbers.
For an input of 0 the return value is 0.
@ -333,13 +338,13 @@ def clamp(x: U | variable[U] | vector[U], min_value: U | variable[U], max_value:
x: Input value
min_value: Minimum limit
max_value: Maximum limit
Returns:
Clamped value of x
"""
if isinstance(x, vector):
return vector(clamp(comp, min_value, max_value) for comp in x.values)
return (x < min_value) * min_value + \
(x > max_value) * max_value + \
((x >= min_value) & (x <= max_value)) * x
@ -384,16 +389,16 @@ def max(x: U | variable[U], y: U | variable[U]) -> Any:
@overload
def lerp(v1: variable[U], v2: U | variable[U], t: U | variable[U]) -> variable[U]: ...
def lerp(v1: variable[U], v2: U | variable[U], t: unifloat) -> variable[U]: ...
@overload
def lerp(v1: U | variable[U], v2: variable[U], t: U | variable[U]) -> variable[U]: ...
def lerp(v1: U | variable[U], v2: variable[U], t: unifloat) -> variable[U]: ...
@overload
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[U]) -> variable[U]: ...
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[float]) -> variable[U]: ...
@overload
def lerp(v1: U, v2: U, t: U) -> U: ...
def lerp(v1: U, v2: U, t: float) -> U: ...
@overload
def lerp(v1: vector[U], v2: vector[U], t: 'U | variable[U]') -> vector[U]: ...
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U | variable[U]) -> Any:
def lerp(v1: vector[U], v2: vector[U], t: unifloat) -> vector[U]: ...
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: unifloat) -> Any:
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
if isinstance(v1, vector) or isinstance(v2, vector):
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
@ -402,12 +407,13 @@ def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U
return v1 * (1 - t) + v2 * t
#TODO: Add vector support
@overload
def relu(x: U) -> U: ...
@overload
def relu(x: variable[U]) -> variable[U]: ...
def relu(x: U | variable[U]) -> Any:
@overload
def relu(x: vector[U]) -> vector[U]: ...
def relu(x: U | variable[U] | vector[U]) -> Any:
"""Returns x for x > 0 and otherwise 0."""
ret = (x > 0) * x
return ret

View File

@ -1,19 +1,19 @@
from . import variable
from ._vectors import vector
from ._mixed import mixed_sum
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
from ._helper_types import TNum
MatNumLike: TypeAlias = 'matrix[int] | matrix[float] | variable[int] | variable[float] | int | float'
MatIntLike: TypeAlias = 'matrix[int] | variable[int] | int'
MatFloatLike: TypeAlias = 'matrix[float] | variable[float] | float'
TT = TypeVar("TT", int, float)
U = TypeVar("U", int, float)
class matrix(Generic[TT]):
class matrix(Generic[TNum]):
"""Mathematical matrix class supporting basic operations and interactions with variables.
"""
def __init__(self, values: Iterable[Iterable[TT | variable[TT]]]):
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
"""Create a matrix with given values and variables.
Args:
@ -23,7 +23,7 @@ class matrix(Generic[TT]):
if rows:
row_len = len(rows[0])
assert all(len(row) == row_len for row in rows), "All rows must have the same length"
self.values: tuple[tuple[variable[TT] | TT, ...], ...] = tuple(rows)
self.values: tuple[tuple[variable[TNum] | TNum, ...], ...] = tuple(rows)
self.rows = len(self.values)
self.cols = len(self.values[0]) if self.values else 0
@ -33,13 +33,13 @@ class matrix(Generic[TT]):
def __len__(self) -> int:
return self.rows
def __getitem__(self, index: int) -> tuple[variable[TT] | TT, ...]:
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
return self.values[index]
def __iter__(self) -> Iterator[tuple[variable[TT] | TT, ...]]:
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values)
def __neg__(self) -> 'matrix[TT]':
def __neg__(self) -> 'matrix[TNum]':
return matrix((-a for a in row) for row in self.values)
@overload
@ -165,23 +165,23 @@ class matrix(Generic[TT]):
)
@overload
def __matmul__(self: 'matrix[TT]', other: 'vector[TT]') -> 'vector[TT]': ...
def __matmul__(self: 'matrix[TNum]', other: 'vector[TNum]') -> 'vector[TNum]': ...
@overload
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT]') -> 'matrix[TT]': ...
def __matmul__(self: 'matrix[TT]', other: 'matrix[TT] | vector[TT]') -> 'matrix[TT] | vector[TT]':
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum]') -> 'matrix[TNum]': ...
def __matmul__(self: 'matrix[TNum]', other: 'matrix[TNum] | vector[TNum]') -> 'matrix[TNum] | vector[TNum]':
"""Matrix multiplication using @ operator"""
if isinstance(other, vector):
assert self.cols == len(other.values), \
f"Matrix columns ({self.cols}) must match vector length ({len(other.values)})"
vec_result = (mixed_sum(a * b for a, b in zip(row, other.values)) for row in self.values)
return vector(vec_result)
else:
else:
assert isinstance(other, matrix), "Cannot multiply matrix with {type(other)}"
assert self.cols == other.rows, \
f"Matrix columns ({self.cols}) must match other matrix rows ({other.rows})"
result: list[list[TT | variable[TT]]] = []
result: list[list[TNum | variable[TNum]]] = []
for row in self.values:
new_row: list[TT | variable[TT]] = []
new_row: list[TNum | variable[TNum]] = []
for col_idx in range(other.cols):
col = tuple(other.values[i][col_idx] for i in range(other.rows))
element = sum(a * b for a, b in zip(row, col))
@ -189,7 +189,7 @@ class matrix(Generic[TT]):
result.append(new_row)
return matrix(result)
def transpose(self) -> 'matrix[TT]':
def transpose(self) -> 'matrix[TNum]':
"""Return the transpose of the matrix."""
if not self.values:
return matrix([])
@ -197,23 +197,23 @@ class matrix(Generic[TT]):
tuple(self.values[i][j] for i in range(self.rows))
for j in range(self.cols)
)
@property
def T(self) -> 'matrix[TT]':
def T(self) -> 'matrix[TNum]':
return self.transpose()
def row(self, index: int) -> vector[TT]:
def row(self, index: int) -> vector[TNum]:
"""Get a row as a vector."""
assert 0 <= index < self.rows, f"Row index {index} out of bounds"
return vector(self.values[index])
def col(self, index: int) -> vector[TT]:
def col(self, index: int) -> vector[TNum]:
"""Get a column as a vector."""
assert 0 <= index < self.cols, f"Column index {index} out of bounds"
return vector(self.values[i][index] for i in range(self.rows))
@overload
def trace(self: 'matrix[TT]') -> TT | variable[TT]: ...
def trace(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
@overload
def trace(self: 'matrix[int]') -> int | variable[int]: ...
@overload
@ -224,7 +224,7 @@ class matrix(Generic[TT]):
return mixed_sum(self.values[i][i] for i in range(self.rows))
@overload
def sum(self: 'matrix[TT]') -> TT | variable[TT]: ...
def sum(self: 'matrix[TNum]') -> TNum | variable[TNum]: ...
@overload
def sum(self: 'matrix[int]') -> int | variable[int]: ...
@overload
@ -240,7 +240,7 @@ class matrix(Generic[TT]):
for row in self.values
)
def homogenize(self) -> 'matrix[TT]':
def homogenize(self) -> 'matrix[TNum]':
"""Convert all elements to variables if any element is a variable."""
if any(isinstance(val, variable) for row in self.values for val in row):
return matrix(

View File

@ -1,27 +1,28 @@
from . import variable
from ._mixed import mixed_sum, mixed_homogenize
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator
from typing import TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator, Generic
import copapy as cp
from ._helper_types import TNum
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
#VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
VecNumLike: TypeAlias = 'vector[Any] | variable[Any] | int | float | bool'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float)
U = TypeVar("U", int, float)
epsilon = 1e-20
class vector(Generic[T]):
class vector(Generic[TNum]):
"""Mathematical vector class supporting basic operations and interactions with variables.
"""
def __init__(self, values: Iterable[T | variable[T]]):
def __init__(self, values: Iterable[TNum | variable[TNum]]):
"""Create a vector with given values and variables.
Args:
values: iterable of constant values and variables
"""
self.values: tuple[variable[T] | T, ...] = tuple(values)
self.values: tuple[variable[TNum] | TNum, ...] = tuple(values)
def __repr__(self) -> str:
return f"vector({self.values})"
@ -29,13 +30,13 @@ class vector(Generic[T]):
def __len__(self) -> int:
return len(self.values)
def __getitem__(self, index: int) -> variable[T] | T:
def __getitem__(self, index: int) -> variable[TNum] | TNum:
return self.values[index]
def __neg__(self) -> 'vector[T]':
def __neg__(self) -> 'vector[TNum]':
return vector(-a for a in self.values)
def __iter__(self) -> Iterator[variable[T] | T]:
def __iter__(self) -> Iterator[variable[TNum] | TNum]:
return iter(self.values)
@overload
@ -56,6 +57,8 @@ class vector(Generic[T]):
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __radd__(self, other: VecNumLike) -> 'vector[Any]': ...
def __radd__(self, other: Any) -> Any:
return self + other
@ -77,6 +80,8 @@ class vector(Generic[T]):
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __rsub__(self, other: VecNumLike) -> 'vector[Any]': ...
def __rsub__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
@ -101,6 +106,8 @@ class vector(Generic[T]):
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __rmul__(self, other: VecNumLike) -> 'vector[Any]': ...
def __rmul__(self, other: VecNumLike) -> Any:
return self * other
@ -150,6 +157,42 @@ class vector(Generic[T]):
a3 * b1 - a1 * b3,
a1 * b2 - a2 * b1
])
def __gt__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a > b for a, b in zip(self.values, other.values))
return vector(a > other for a in self.values)
def __lt__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a < b for a, b in zip(self.values, other.values))
return vector(a < other for a in self.values)
def __ge__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a >= b for a, b in zip(self.values, other.values))
return vector(a >= other for a in self.values)
def __le__(self, other: VecNumLike) -> 'vector[int]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a <= b for a, b in zip(self.values, other.values))
return vector(a <= other for a in self.values)
def __eq__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a == b for a, b in zip(self.values, other.values))
return vector(a == other for a in self.values)
def __ne__(self, other: VecNumLike) -> 'vector[int]': # type: ignore
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a != b for a, b in zip(self.values, other.values))
return vector(a != other for a in self.values)
@overload
def sum(self: 'vector[int]') -> int | variable[int]: ...
@ -168,8 +211,8 @@ class vector(Generic[T]):
"""Returns a normalized (unit length) version of the vector."""
mag = self.magnitude() + epsilon
return self / mag
def homogenize(self) -> 'vector[T]':
def homogenize(self) -> 'vector[TNum]':
if any(isinstance(val, variable) for val in self.values):
return vector(mixed_homogenize(self))
else: