diff --git a/src/copapy/_matrices.py b/src/copapy/_matrices.py index 906b983..773ac38 100644 --- a/src/copapy/_matrices.py +++ b/src/copapy/_matrices.py @@ -20,9 +20,10 @@ class matrix(Generic[TNum]): values: iterable of iterable of constant values and variables """ if isinstance(values, vector): - rows = (values.values,) + rows = [values.values] else: - rows = tuple(tuple(row) for row in values) + rows = [tuple(row) for row in values] + if rows: row_len = len(rows[0]) assert all(len(row) == row_len for row in rows), "All rows must have the same length" diff --git a/src/copapy/_target.py b/src/copapy/_target.py index 13fe4f6..d2bbed4 100644 --- a/src/copapy/_target.py +++ b/src/copapy/_target.py @@ -40,11 +40,11 @@ class Target(): for s in variables: if isinstance(s, Iterable): for net in s: - assert isinstance(net, Net), f"The folowing element is not a Net: {net}" - nodes.append(Write(net)) + if isinstance(net, Net): + nodes.append(Write(net)) else: - assert isinstance(s, Net), f"The folowing element is not a Net: {s}" - nodes.append(Write(s)) + if isinstance(s, Net): + nodes.append(Write(s)) dw, self._variables = compile_to_dag(nodes, self.sdb) dw.write_com(binw.Command.END_COM) diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index 3c50d2a..44d5869 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -30,7 +30,13 @@ class vector(Generic[TNum]): def __len__(self) -> int: return len(self.values) - def __getitem__(self, index: int) -> variable[TNum] | TNum: + @overload + def __getitem__(self, index: int) -> variable[TNum] | TNum: ... + @overload + def __getitem__(self, index: slice) -> 'vector[TNum]': ... + def __getitem__(self, index: int | slice) -> 'vector[TNum] | variable[TNum] | TNum': + if isinstance(index, slice): + return vector(self.values[index]) return self.values[index] def __neg__(self) -> 'vector[TNum]': @@ -111,6 +117,29 @@ class vector(Generic[TNum]): def __rmul__(self, other: VecNumLike) -> Any: return self * other + @overload + def __pow__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ... + @overload + def __pow__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ... + @overload + def __pow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ... + @overload + def __pow__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ... + def __pow__(self, other: VecNumLike) -> Any: + if isinstance(other, vector): + assert len(self.values) == len(other.values) + return vector(a ** b for a, b in zip(self.values, other.values)) + return vector(a ** other for a in self.values) + + @overload + def __rpow__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ... + @overload + def __rpow__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ... + @overload + def __rpow__(self, other: VecNumLike) -> 'vector[Any]': ... + def __rpow__(self, other: VecNumLike) -> Any: + return self ** other + def __truediv__(self, other: VecNumLike) -> 'vector[float]': if isinstance(other, vector): assert len(self.values) == len(other.values)