From 959d80b082fd0bef4ae5a5688503f67c0dcc29a7 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Fri, 5 Dec 2025 08:28:02 +0100 Subject: [PATCH] shape property added and __getitem__ of matrix extended --- src/copapy/_matrices.py | 28 +++++++++++++++++++++++----- src/copapy/_vectors.py | 5 +++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/copapy/_matrices.py b/src/copapy/_matrices.py index 3086a8c..906b983 100644 --- a/src/copapy/_matrices.py +++ b/src/copapy/_matrices.py @@ -34,13 +34,26 @@ class matrix(Generic[TNum]): return f"matrix({self.values})" def __len__(self) -> int: + """Return the number of rows in the matrix.""" return self.rows - def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: - assert len(key) == 2 - row = key[0] - col = key[1] - return self.values[row][col] + @overload + def __getitem__(self, key: int) -> vector[TNum]: ... + @overload + def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: ... + def __getitem__(self, key: int | tuple[int, int]) -> Any: + """Get a row as a vector or a specific element. + Args: + key: row index or (row, col) tuple + + Returns: + vector if row index is given, else the element at (row, col) + """ + if isinstance(key, tuple): + assert len(key) == 2 + return self.values[key[0]][key[1]] + else: + return vector(self.values[key]) def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]: return iter(self.values) @@ -203,6 +216,11 @@ class matrix(Generic[TNum]): tuple(self.values[i][j] for i in range(self.rows)) for j in range(self.cols) ) + + @property + def shape(self) -> tuple[int, int]: + """Return the shape of the matrix as (rows, cols).""" + return (self.rows, self.cols) @property def T(self) -> 'matrix[TNum]': diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index 4188d6d..3c50d2a 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -193,6 +193,11 @@ class vector(Generic[TNum]): assert len(self.values) == len(other.values) return vector(a != b for a, b in zip(self.values, other.values)) return vector(a != other for a in self.values) + + @property + def shape(self) -> tuple[int]: + """Return the shape of the vector as (length,).""" + return (len(self.values),) @overload def sum(self: 'vector[int]') -> int | variable[int]: ...