eye function added for creating matrices

This commit is contained in:
Nicolas Kruse 2025-12-04 22:38:52 +01:00
parent 257fe96bb3
commit da92aa9e2c
3 changed files with 22 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from ._target import Target
from ._basic_types import NumLike, variable, generic_sdb, iif
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
from ._matrices import matrix, identity, zeros, ones, diagonal
from ._matrices import matrix, identity, zeros, ones, diagonal, eye
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
from ._autograd import grad
@ -40,5 +40,6 @@ __all__ = [
"angle_between",
"rotate_vector",
"vector_projection",
"grad"
"grad",
"eye"
]

View File

@ -13,12 +13,15 @@ U = TypeVar("U", int, float)
class matrix(Generic[TNum]):
"""Mathematical matrix class supporting basic operations and interactions with variables.
"""
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]] | vector[TNum]):
"""Create a matrix with given values and variables.
Args:
values: iterable of iterable of constant values and variables
"""
if isinstance(values, vector):
rows = (values.values,)
else:
rows = tuple(tuple(row) for row in values)
if rows:
row_len = len(rows[0])
@ -33,8 +36,11 @@ class matrix(Generic[TNum]):
def __len__(self) -> int:
return self.rows
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
return self.values[index]
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum:
assert len(key) == 2
row = key[0]
col = key[1]
return self.values[row][col]
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values)
@ -251,8 +257,6 @@ class matrix(Generic[TNum]):
return self
# Utility functions for matrices
def identity(size: int) -> matrix[int]:
"""Create an identity matrix of given size."""
return matrix(
@ -277,6 +281,15 @@ def ones(rows: int, cols: int) -> matrix[int]:
)
def eye(rows: int, cols: int | None = None) -> matrix[int]:
"""Create a matrix with ones on the diagonal and zeros elsewhere."""
cols = cols if cols else rows
return matrix(
tuple(1 if i == j else 0 for j in range(cols))
for i in range(rows)
)
@overload
def diagonal(vec: 'vector[int]') -> matrix[int]: ...
@overload

View File

@ -223,8 +223,6 @@ class vector(Generic[TNum]):
return vector(func(x) for x in self.values)
# Utility functions for 3D vectors with two arguments
def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]:
"""Calculate the cross product of two 3D vectors."""
return v1.cross(v2)