shape property added and __getitem__ of matrix extended

This commit is contained in:
Nicolas Kruse 2025-12-05 08:28:02 +01:00
parent 6d47779c03
commit 959d80b082
2 changed files with 28 additions and 5 deletions

View File

@ -34,13 +34,26 @@ class matrix(Generic[TNum]):
return f"matrix({self.values})" return f"matrix({self.values})"
def __len__(self) -> int: def __len__(self) -> int:
"""Return the number of rows in the matrix."""
return self.rows return self.rows
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: @overload
assert len(key) == 2 def __getitem__(self, key: int) -> vector[TNum]: ...
row = key[0] @overload
col = key[1] def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: ...
return self.values[row][col] def __getitem__(self, key: int | tuple[int, int]) -> Any:
"""Get a row as a vector or a specific element.
Args:
key: row index or (row, col) tuple
Returns:
vector if row index is given, else the element at (row, col)
"""
if isinstance(key, tuple):
assert len(key) == 2
return self.values[key[0]][key[1]]
else:
return vector(self.values[key])
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]: def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
return iter(self.values) return iter(self.values)
@ -203,6 +216,11 @@ class matrix(Generic[TNum]):
tuple(self.values[i][j] for i in range(self.rows)) tuple(self.values[i][j] for i in range(self.rows))
for j in range(self.cols) for j in range(self.cols)
) )
@property
def shape(self) -> tuple[int, int]:
"""Return the shape of the matrix as (rows, cols)."""
return (self.rows, self.cols)
@property @property
def T(self) -> 'matrix[TNum]': def T(self) -> 'matrix[TNum]':

View File

@ -193,6 +193,11 @@ class vector(Generic[TNum]):
assert len(self.values) == len(other.values) assert len(self.values) == len(other.values)
return vector(a != b for a, b in zip(self.values, other.values)) return vector(a != b for a, b in zip(self.values, other.values))
return vector(a != other for a in self.values) return vector(a != other for a in self.values)
@property
def shape(self) -> tuple[int]:
"""Return the shape of the vector as (length,)."""
return (len(self.values),)
@overload @overload
def sum(self: 'vector[int]') -> int | variable[int]: ... def sum(self: 'vector[int]') -> int | variable[int]: ...