eye function added for creating matrices

This commit is contained in:
Nicolas Kruse 2025-12-04 22:38:52 +01:00
parent 257fe96bb3
commit da92aa9e2c
3 changed files with 22 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from ._target import Target from ._target import Target
from ._basic_types import NumLike, variable, generic_sdb, iif from ._basic_types import NumLike, variable, generic_sdb, iif
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
from ._matrices import matrix, identity, zeros, ones, diagonal from ._matrices import matrix, identity, zeros, ones, diagonal, eye
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
from ._autograd import grad from ._autograd import grad
@ -40,5 +40,6 @@ __all__ = [
"angle_between", "angle_between",
"rotate_vector", "rotate_vector",
"vector_projection", "vector_projection",
"grad" "grad",
"eye"
] ]

View File

@ -13,13 +13,16 @@ U = TypeVar("U", int, float)
class matrix(Generic[TNum]): class matrix(Generic[TNum]):
"""Mathematical matrix class supporting basic operations and interactions with variables. """Mathematical matrix class supporting basic operations and interactions with variables.
""" """
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]): def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]] | vector[TNum]):
"""Create a matrix with given values and variables. """Create a matrix with given values and variables.
Args: Args:
values: iterable of iterable of constant values and variables values: iterable of iterable of constant values and variables
""" """
rows = tuple(tuple(row) for row in values) if isinstance(values, vector):
rows = (values.values,)
else:
rows = tuple(tuple(row) for row in values)
if rows: if rows:
row_len = len(rows[0]) row_len = len(rows[0])
assert all(len(row) == row_len for row in rows), "All rows must have the same length" assert all(len(row) == row_len for row in rows), "All rows must have the same length"
@ -33,8 +36,11 @@ class matrix(Generic[TNum]):
def __len__(self) -> int: def __len__(self) -> int:
return self.rows return self.rows
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]: def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum:
return self.values[index] assert len(key) == 2
row = key[0]
col = key[1]
return self.values[row][col]
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]: def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values) return iter(self.values)
@ -251,8 +257,6 @@ class matrix(Generic[TNum]):
return self return self
# Utility functions for matrices
def identity(size: int) -> matrix[int]: def identity(size: int) -> matrix[int]:
"""Create an identity matrix of given size.""" """Create an identity matrix of given size."""
return matrix( return matrix(
@ -277,6 +281,15 @@ def ones(rows: int, cols: int) -> matrix[int]:
) )
def eye(rows: int, cols: int | None = None) -> matrix[int]:
"""Create a matrix with ones on the diagonal and zeros elsewhere."""
cols = cols if cols else rows
return matrix(
tuple(1 if i == j else 0 for j in range(cols))
for i in range(rows)
)
@overload @overload
def diagonal(vec: 'vector[int]') -> matrix[int]: ... def diagonal(vec: 'vector[int]') -> matrix[int]: ...
@overload @overload

View File

@ -223,8 +223,6 @@ class vector(Generic[TNum]):
return vector(func(x) for x in self.values) return vector(func(x) for x in self.values)
# Utility functions for 3D vectors with two arguments
def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]: def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]:
"""Calculate the cross product of two 3D vectors.""" """Calculate the cross product of two 3D vectors."""
return v1.cross(v2) return v1.cross(v2)