type hint fixes

This commit is contained in:
Nicolas 2025-12-06 15:13:28 +01:00
parent d526c5ddc0
commit a21970de79
3 changed files with 37 additions and 7 deletions

View File

@ -20,9 +20,10 @@ class matrix(Generic[TNum]):
values: iterable of iterable of constant values and variables values: iterable of iterable of constant values and variables
""" """
if isinstance(values, vector): if isinstance(values, vector):
rows = (values.values,) rows = [values.values]
else: else:
rows = tuple(tuple(row) for row in values) rows = [tuple(row) for row in values]
if rows: if rows:
row_len = len(rows[0]) row_len = len(rows[0])
assert all(len(row) == row_len for row in rows), "All rows must have the same length" assert all(len(row) == row_len for row in rows), "All rows must have the same length"

View File

@ -40,10 +40,10 @@ class Target():
for s in variables: for s in variables:
if isinstance(s, Iterable): if isinstance(s, Iterable):
for net in s: for net in s:
assert isinstance(net, Net), f"The folowing element is not a Net: {net}" if isinstance(net, Net):
nodes.append(Write(net)) nodes.append(Write(net))
else: else:
assert isinstance(s, Net), f"The folowing element is not a Net: {s}" if isinstance(s, Net):
nodes.append(Write(s)) nodes.append(Write(s))
dw, self._variables = compile_to_dag(nodes, self.sdb) dw, self._variables = compile_to_dag(nodes, self.sdb)

View File

@ -30,7 +30,13 @@ class vector(Generic[TNum]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.values) return len(self.values)
def __getitem__(self, index: int) -> variable[TNum] | TNum: @overload
def __getitem__(self, index: int) -> variable[TNum] | TNum: ...
@overload
def __getitem__(self, index: slice) -> 'vector[TNum]': ...
def __getitem__(self, index: int | slice) -> 'vector[TNum] | variable[TNum] | TNum':
if isinstance(index, slice):
return vector(self.values[index])
return self.values[index] return self.values[index]
def __neg__(self) -> 'vector[TNum]': def __neg__(self) -> 'vector[TNum]':
@ -111,6 +117,29 @@ class vector(Generic[TNum]):
def __rmul__(self, other: VecNumLike) -> Any: def __rmul__(self, other: VecNumLike) -> Any:
return self * other return self * other
@overload
def __pow__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
@overload
def __pow__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __pow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __pow__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __pow__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a ** b for a, b in zip(self.values, other.values))
return vector(a ** other for a in self.values)
@overload
def __rpow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rpow__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
@overload
def __rpow__(self, other: VecNumLike) -> 'vector[Any]': ...
def __rpow__(self, other: VecNumLike) -> Any:
return self ** other
def __truediv__(self, other: VecNumLike) -> 'vector[float]': def __truediv__(self, other: VecNumLike) -> 'vector[float]':
if isinstance(other, vector): if isinstance(other, vector):
assert len(self.values) == len(other.values) assert len(self.values) == len(other.values)