From da92aa9e2cbe26f68cad3e06d8e5967e28f79071 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Thu, 4 Dec 2025 22:38:52 +0100 Subject: [PATCH] eye function added for creating matrices --- src/copapy/__init__.py | 5 +++-- src/copapy/_matrices.py | 25 +++++++++++++++++++------ src/copapy/_vectors.py | 2 -- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index b8216c6..1309d5f 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,7 +1,7 @@ from ._target import Target from ._basic_types import NumLike, variable, generic_sdb, iif from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection -from ._matrices import matrix, identity, zeros, ones, diagonal +from ._matrices import matrix, identity, zeros, ones, diagonal, eye from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu from ._autograd import grad @@ -40,5 +40,6 @@ __all__ = [ "angle_between", "rotate_vector", "vector_projection", - "grad" + "grad", + "eye" ] diff --git a/src/copapy/_matrices.py b/src/copapy/_matrices.py index f1a28c2..3086a8c 100644 --- a/src/copapy/_matrices.py +++ b/src/copapy/_matrices.py @@ -13,13 +13,16 @@ U = TypeVar("U", int, float) class matrix(Generic[TNum]): """Mathematical matrix class supporting basic operations and interactions with variables. """ - def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]): + def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]] | vector[TNum]): """Create a matrix with given values and variables. Args: values: iterable of iterable of constant values and variables """ - rows = tuple(tuple(row) for row in values) + if isinstance(values, vector): + rows = (values.values,) + else: + rows = tuple(tuple(row) for row in values) if rows: row_len = len(rows[0]) assert all(len(row) == row_len for row in rows), "All rows must have the same length" @@ -33,8 +36,11 @@ class matrix(Generic[TNum]): def __len__(self) -> int: return self.rows - def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]: - return self.values[index] + def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: + assert len(key) == 2 + row = key[0] + col = key[1] + return self.values[row][col] def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]: return iter(self.values) @@ -251,8 +257,6 @@ class matrix(Generic[TNum]): return self -# Utility functions for matrices - def identity(size: int) -> matrix[int]: """Create an identity matrix of given size.""" return matrix( @@ -277,6 +281,15 @@ def ones(rows: int, cols: int) -> matrix[int]: ) +def eye(rows: int, cols: int | None = None) -> matrix[int]: + """Create a matrix with ones on the diagonal and zeros elsewhere.""" + cols = cols if cols else rows + return matrix( + tuple(1 if i == j else 0 for j in range(cols)) + for i in range(rows) + ) + + @overload def diagonal(vec: 'vector[int]') -> matrix[int]: ... @overload diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index f6baee4..4188d6d 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -223,8 +223,6 @@ class vector(Generic[TNum]): return vector(func(x) for x in self.values) -# Utility functions for 3D vectors with two arguments - def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]: """Calculate the cross product of two 3D vectors.""" return v1.cross(v2)