shape property added and __getitem__ of matrix extended

This commit is contained in:
Nicolas Kruse 2025-12-05 08:28:02 +01:00
parent 6d47779c03
commit 959d80b082
2 changed files with 28 additions and 5 deletions

View File

@ -34,13 +34,26 @@ class matrix(Generic[TNum]):
return f"matrix({self.values})"
def __len__(self) -> int:
"""Return the number of rows in the matrix."""
return self.rows
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum:
assert len(key) == 2
row = key[0]
col = key[1]
return self.values[row][col]
@overload
def __getitem__(self, key: int) -> vector[TNum]: ...
@overload
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: ...
def __getitem__(self, key: int | tuple[int, int]) -> Any:
"""Get a row as a vector or a specific element.
Args:
key: row index or (row, col) tuple
Returns:
vector if row index is given, else the element at (row, col)
"""
if isinstance(key, tuple):
assert len(key) == 2
return self.values[key[0]][key[1]]
else:
return vector(self.values[key])
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values)
@ -204,6 +217,11 @@ class matrix(Generic[TNum]):
for j in range(self.cols)
)
@property
def shape(self) -> tuple[int, int]:
"""Return the shape of the matrix as (rows, cols)."""
return (self.rows, self.cols)
@property
def T(self) -> 'matrix[TNum]':
return self.transpose()

View File

@ -194,6 +194,11 @@ class vector(Generic[TNum]):
return vector(a != b for a, b in zip(self.values, other.values))
return vector(a != other for a in self.values)
@property
def shape(self) -> tuple[int]:
"""Return the shape of the vector as (length,)."""
return (len(self.values),)
@overload
def sum(self: 'vector[int]') -> int | variable[int]: ...
@overload