mirror of https://github.com/Nonannet/copapy.git
eye function added for creating matrices
This commit is contained in:
parent
257fe96bb3
commit
da92aa9e2c
|
|
@ -1,7 +1,7 @@
|
||||||
from ._target import Target
|
from ._target import Target
|
||||||
from ._basic_types import NumLike, variable, generic_sdb, iif
|
from ._basic_types import NumLike, variable, generic_sdb, iif
|
||||||
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
|
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
|
||||||
from ._matrices import matrix, identity, zeros, ones, diagonal
|
from ._matrices import matrix, identity, zeros, ones, diagonal, eye
|
||||||
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
|
from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu
|
||||||
from ._autograd import grad
|
from ._autograd import grad
|
||||||
|
|
||||||
|
|
@ -40,5 +40,6 @@ __all__ = [
|
||||||
"angle_between",
|
"angle_between",
|
||||||
"rotate_vector",
|
"rotate_vector",
|
||||||
"vector_projection",
|
"vector_projection",
|
||||||
"grad"
|
"grad",
|
||||||
|
"eye"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,15 @@ U = TypeVar("U", int, float)
|
||||||
class matrix(Generic[TNum]):
|
class matrix(Generic[TNum]):
|
||||||
"""Mathematical matrix class supporting basic operations and interactions with variables.
|
"""Mathematical matrix class supporting basic operations and interactions with variables.
|
||||||
"""
|
"""
|
||||||
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]]):
|
def __init__(self, values: Iterable[Iterable[TNum | variable[TNum]]] | vector[TNum]):
|
||||||
"""Create a matrix with given values and variables.
|
"""Create a matrix with given values and variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values: iterable of iterable of constant values and variables
|
values: iterable of iterable of constant values and variables
|
||||||
"""
|
"""
|
||||||
|
if isinstance(values, vector):
|
||||||
|
rows = (values.values,)
|
||||||
|
else:
|
||||||
rows = tuple(tuple(row) for row in values)
|
rows = tuple(tuple(row) for row in values)
|
||||||
if rows:
|
if rows:
|
||||||
row_len = len(rows[0])
|
row_len = len(rows[0])
|
||||||
|
|
@ -33,8 +36,11 @@ class matrix(Generic[TNum]):
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.rows
|
return self.rows
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> tuple[variable[TNum] | TNum, ...]:
|
def __getitem__(self, key: tuple[int, int]) -> variable[TNum] | TNum:
|
||||||
return self.values[index]
|
assert len(key) == 2
|
||||||
|
row = key[0]
|
||||||
|
col = key[1]
|
||||||
|
return self.values[row][col]
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
def __iter__(self) -> Iterator[tuple[variable[TNum] | TNum, ...]]:
|
||||||
return iter(self.values)
|
return iter(self.values)
|
||||||
|
|
@ -251,8 +257,6 @@ class matrix(Generic[TNum]):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# Utility functions for matrices
|
|
||||||
|
|
||||||
def identity(size: int) -> matrix[int]:
|
def identity(size: int) -> matrix[int]:
|
||||||
"""Create an identity matrix of given size."""
|
"""Create an identity matrix of given size."""
|
||||||
return matrix(
|
return matrix(
|
||||||
|
|
@ -277,6 +281,15 @@ def ones(rows: int, cols: int) -> matrix[int]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def eye(rows: int, cols: int | None = None) -> matrix[int]:
|
||||||
|
"""Create a matrix with ones on the diagonal and zeros elsewhere."""
|
||||||
|
cols = cols if cols else rows
|
||||||
|
return matrix(
|
||||||
|
tuple(1 if i == j else 0 for j in range(cols))
|
||||||
|
for i in range(rows)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def diagonal(vec: 'vector[int]') -> matrix[int]: ...
|
def diagonal(vec: 'vector[int]') -> matrix[int]: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
|
||||||
|
|
@ -223,8 +223,6 @@ class vector(Generic[TNum]):
|
||||||
return vector(func(x) for x in self.values)
|
return vector(func(x) for x in self.values)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions for 3D vectors with two arguments
|
|
||||||
|
|
||||||
def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]:
|
def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]:
|
||||||
"""Calculate the cross product of two 3D vectors."""
|
"""Calculate the cross product of two 3D vectors."""
|
||||||
return v1.cross(v2)
|
return v1.cross(v2)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue