diff --git a/src/copapy/_mixed.py b/src/copapy/_mixed.py new file mode 100644 index 0000000..e4d2632 --- /dev/null +++ b/src/copapy/_mixed.py @@ -0,0 +1,24 @@ + +from . import variable +from typing import TypeVar, Iterable, Any, overload + +T = TypeVar("T", int, float) + + +@overload +def mixed_sum(scalars: Iterable[float | variable[float]]) -> float | variable[float]: ... +@overload +def mixed_sum(scalars: Iterable[int | variable[int]]) -> int | variable[int]: ... +@overload +def mixed_sum(scalars: Iterable[T | variable[T]]) -> T | variable[T]: ... +def mixed_sum(scalars: Iterable[int | float | variable[Any]]) -> Any: + sl = list(scalars) + return sum(a for a in sl if not isinstance(a, variable)) +\ + sum(a for a in sl if isinstance(a, variable)) + + +def mixed_homogenize(scalars: Iterable[T | variable[T]]) -> Iterable[T] | Iterable[variable[T]]: + if any(isinstance(val, variable) for val in scalars): + return (variable(val) if not isinstance(val, variable) else val for val in scalars) + else: + return (val for val in scalars if not isinstance(val, variable)) diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index 3d8a721..e6fbf40 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -1,4 +1,5 @@ from . import variable +from ._mixed import mixed_sum, mixed_homogenize from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable, Iterator import copapy as cp @@ -31,7 +32,7 @@ class vector(Generic[T]): def __getitem__(self, index: int) -> variable[T] | T: return self.values[index] - def __neg__(self) -> 'vector[float] | vector[int]': + def __neg__(self) -> 'vector[T]': return vector(-a for a in self.values) def __iter__(self) -> Iterator[variable[T] | T]: @@ -125,7 +126,7 @@ class vector(Generic[T]): def dot(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ... def dot(self, other: 'vector[int] | vector[float]') -> Any: assert len(self.values) == len(other.values), "Vectors must be of same length." - return sum(a * b for a, b in zip(self.values, other.values)) + return mixed_sum(a * b for a, b in zip(self.values, other.values)) # @ operator @overload @@ -156,13 +157,12 @@ class vector(Generic[T]): def sum(self: 'vector[float]') -> float | variable[float]: ... def sum(self) -> Any: """Sum of all vector elements.""" - return sum(a for a in self.values if isinstance(a, variable)) +\ - sum(a for a in self.values if not isinstance(a, variable)) + return mixed_sum(self.values) def magnitude(self) -> 'float | variable[float]': """Magnitude (length) of the vector.""" - s = sum(a * a for a in self.values) - return cp.sqrt(s) if isinstance(s, variable) else cp.sqrt(s) + s = mixed_sum(a * a for a in self.values) + return cp.sqrt(s) def normalize(self) -> 'vector[float]': """Returns a normalized (unit length) version of the vector.""" @@ -171,7 +171,7 @@ class vector(Generic[T]): def homogenize(self) -> 'vector[T]': if any(isinstance(val, variable) for val in self.values): - return vector(variable(val) if not isinstance(val, variable) else val for val in self.values) + return vector(mixed_homogenize(self)) else: return self