mirror of https://github.com/Nonannet/copapy.git
type hint fixes
This commit is contained in:
parent
d526c5ddc0
commit
a21970de79
|
|
@ -20,9 +20,10 @@ class matrix(Generic[TNum]):
|
|||
values: iterable of iterable of constant values and variables
|
||||
"""
|
||||
if isinstance(values, vector):
|
||||
rows = (values.values,)
|
||||
rows = [values.values]
|
||||
else:
|
||||
rows = tuple(tuple(row) for row in values)
|
||||
rows = [tuple(row) for row in values]
|
||||
|
||||
if rows:
|
||||
row_len = len(rows[0])
|
||||
assert all(len(row) == row_len for row in rows), "All rows must have the same length"
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ class Target():
|
|||
for s in variables:
|
||||
if isinstance(s, Iterable):
|
||||
for net in s:
|
||||
assert isinstance(net, Net), f"The folowing element is not a Net: {net}"
|
||||
if isinstance(net, Net):
|
||||
nodes.append(Write(net))
|
||||
else:
|
||||
assert isinstance(s, Net), f"The folowing element is not a Net: {s}"
|
||||
if isinstance(s, Net):
|
||||
nodes.append(Write(s))
|
||||
|
||||
dw, self._variables = compile_to_dag(nodes, self.sdb)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,13 @@ class vector(Generic[TNum]):
|
|||
def __len__(self) -> int:
|
||||
return len(self.values)
|
||||
|
||||
def __getitem__(self, index: int) -> variable[TNum] | TNum:
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> variable[TNum] | TNum: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> 'vector[TNum]': ...
|
||||
def __getitem__(self, index: int | slice) -> 'vector[TNum] | variable[TNum] | TNum':
|
||||
if isinstance(index, slice):
|
||||
return vector(self.values[index])
|
||||
return self.values[index]
|
||||
|
||||
def __neg__(self) -> 'vector[TNum]':
|
||||
|
|
@ -111,6 +117,29 @@ class vector(Generic[TNum]):
|
|||
def __rmul__(self, other: VecNumLike) -> Any:
|
||||
return self * other
|
||||
|
||||
@overload
|
||||
def __pow__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __pow__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
|
||||
@overload
|
||||
def __pow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __pow__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
|
||||
def __pow__(self, other: VecNumLike) -> Any:
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
return vector(a ** b for a, b in zip(self.values, other.values))
|
||||
return vector(a ** other for a in self.values)
|
||||
|
||||
@overload
|
||||
def __rpow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
|
||||
@overload
|
||||
def __rpow__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
|
||||
@overload
|
||||
def __rpow__(self, other: VecNumLike) -> 'vector[Any]': ...
|
||||
def __rpow__(self, other: VecNumLike) -> Any:
|
||||
return self ** other
|
||||
|
||||
def __truediv__(self, other: VecNumLike) -> 'vector[float]':
|
||||
if isinstance(other, vector):
|
||||
assert len(self.values) == len(other.values)
|
||||
|
|
|
|||
Loading…
Reference in New Issue