mirror of https://github.com/Nonannet/copapy.git
shape property added and __getitem__ of matrix extended
This commit is contained in:
parent
6d47779c03
commit
959d80b082
|
|
@ -34,13 +34,26 @@ class matrix(Generic[TNum]):
|
||||||
return f"matrix({self.values})"
|
return f"matrix({self.values})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of rows in the matrix."""
|
||||||
return self.rows
|
return self.rows
|
||||||
|
|
||||||
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum:
|
@overload
|
||||||
assert len(key) == 2
|
def __getitem__(self, key: int) -> vector[TNum]: ...
|
||||||
row = key[0]
|
@overload
|
||||||
col = key[1]
|
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum: ...
|
||||||
return self.values[row][col]
|
def __getitem__(self, key: int | tuple[int, int]) -> Any:
|
||||||
|
"""Get a row as a vector or a specific element.
|
||||||
|
Args:
|
||||||
|
key: row index or (row, col) tuple
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
vector if row index is given, else the element at (row, col)
|
||||||
|
"""
|
||||||
|
if isinstance(key, tuple):
|
||||||
|
assert len(key) == 2
|
||||||
|
return self.values[key[0]][key[1]]
|
||||||
|
else:
|
||||||
|
return vector(self.values[key])
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
||||||
return iter(self.values)
|
return iter(self.values)
|
||||||
|
|
@ -204,6 +217,11 @@ class matrix(Generic[TNum]):
|
||||||
for j in range(self.cols)
|
for j in range(self.cols)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> tuple[int, int]:
|
||||||
|
"""Return the shape of the matrix as (rows, cols)."""
|
||||||
|
return (self.rows, self.cols)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def T(self) -> 'matrix[TNum]':
|
def T(self) -> 'matrix[TNum]':
|
||||||
return self.transpose()
|
return self.transpose()
|
||||||
|
|
|
||||||
|
|
@ -194,6 +194,11 @@ class vector(Generic[TNum]):
|
||||||
return vector(a != b for a, b in zip(self.values, other.values))
|
return vector(a != b for a, b in zip(self.values, other.values))
|
||||||
return vector(a != other for a in self.values)
|
return vector(a != other for a in self.values)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> tuple[int]:
|
||||||
|
"""Return the shape of the vector as (length,)."""
|
||||||
|
return (len(self.values),)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def sum(self: 'vector[int]') -> int | variable[int]: ...
|
def sum(self: 'vector[int]') -> int | variable[int]: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue