vector functions extended

This commit is contained in:
Nicolas Kruse 2025-11-19 10:01:30 +01:00
parent beddf2e7e9
commit 4049928139
4 changed files with 186 additions and 27 deletions

View File

@ -1,7 +1,7 @@
from ._target import Target
from ._basic_types import NumLike, variable, generic_sdb, iif
from ._vectors import vector
from ._math import sqrt, abs, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
from ._math import sqrt, abs, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max
__all__ = [
"Target",
@ -22,5 +22,13 @@ __all__ = [
"log",
"exp",
"pow",
"get_42"
"get_42",
"clamp",
"min",
"max",
"distance",
"scalar_projection",
"angle_between",
"rotate_vector",
"vector_projection",
]

View File

@ -7,9 +7,11 @@ NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | f
unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool'
uniboolint: TypeAlias = 'variable[bool] | bool | variable[int] | int'
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, float, bool)
TVarNumb: TypeAlias = 'variable[Any] | int | float | bool'
stencil_cache: dict[tuple[str, str], stencil_database] = {}
@ -113,6 +115,8 @@ class variable(Generic[TNum], Net):
@overload
def __add__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __add__(self, other: NumLike) -> Any:
if isinstance(other, int | float) and other == 0:
return self
return add_op('add', [self, other], True)
@overload
@ -120,6 +124,8 @@ class variable(Generic[TNum], Net):
@overload
def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any:
if isinstance(other, int | float) and other == 0:
return self
return add_op('add', [self, other], True)
@overload
@ -185,27 +191,27 @@ class variable(Generic[TNum], Net):
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
def __gt__(self, other: TVarNumb) -> 'variable[bool]':
ret = add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'variable[bool]':
def __lt__(self, other: TVarNumb) -> 'variable[bool]':
ret = add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __ge__(self, other: NumLike) -> 'variable[bool]':
def __ge__(self, other: TVarNumb) -> 'variable[bool]':
ret = add_op('ge', [self, other])
return variable(ret.source, dtype='bool')
def __le__(self, other: NumLike) -> 'variable[bool]':
def __le__(self, other: TVarNumb) -> 'variable[bool]':
ret = add_op('ge', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
def __eq__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore
ret = add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
def __ne__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore
ret = add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@ -249,34 +255,34 @@ class variable(Generic[TNum], Net):
return super().__hash__()
# Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]':
def __lshift__(self, other: uniboolint) -> 'variable[int]':
return add_op('lshift', [self, other])
def __rlshift__(self, other: uniint) -> 'variable[int]':
def __rlshift__(self, other: uniboolint) -> 'variable[int]':
return add_op('lshift', [other, self])
def __rshift__(self, other: uniint) -> 'variable[int]':
def __rshift__(self, other: uniboolint) -> 'variable[int]':
return add_op('rshift', [self, other])
def __rrshift__(self, other: uniint) -> 'variable[int]':
def __rrshift__(self, other: uniboolint) -> 'variable[int]':
return add_op('rshift', [other, self])
def __and__(self, other: uniint) -> 'variable[int]':
def __and__(self, other: uniboolint) -> 'variable[int]':
return add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]':
def __rand__(self, other: uniboolint) -> 'variable[int]':
return add_op('rwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]':
def __or__(self, other: uniboolint) -> 'variable[int]':
return add_op('bwor', [self, other], True)
def __ror__(self, other: uniint) -> 'variable[int]':
def __ror__(self, other: uniboolint) -> 'variable[int]':
return add_op('bwor', [other, self], True)
def __xor__(self, other: uniint) -> 'variable[int]':
def __xor__(self, other: uniboolint) -> 'variable[int]':
return add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniint) -> 'variable[int]':
def __rxor__(self, other: uniboolint) -> 'variable[int]':
return add_op('bwxor', [other, self], True)

View File

@ -70,7 +70,7 @@ def pow(x: VecNumLike, y: VecNumLike) -> Any:
result of x**y
"""
if isinstance(x, vector) or isinstance(y, vector):
return map2(x, y, pow)
return _map2(x, y, pow)
if isinstance(y, int) and 0 <= y < 8:
if y == 0:
return 1
@ -218,7 +218,7 @@ def atan2(x: VecNumLike, y: VecNumLike) -> Any:
Result in radian
"""
if isinstance(x, vector) or isinstance(y, vector):
return map2(x, y, atan2)
return _map2(x, y, atan2)
if isinstance(x, variable) or isinstance(y, variable):
return add_op('atan2', [x, y])
return math.atan2(x, y)
@ -278,8 +278,11 @@ def get_42(x: NumLike) -> variable[float] | float:
return add_op('get_42', [x, x])
return float((int(x) * 3.0 + 42.0) * 5.0 + 21.0)
def abs(x: T) -> T:
@overload
def abs(x: U) -> U: ...
@overload
def abs(x: variable[U]) -> variable[U]: ...
def abs(x: U | variable[U]) -> Any:
"""Absolute value function
Arguments:
@ -292,7 +295,93 @@ def abs(x: T) -> T:
return ret # pyright: ignore[reportReturnType]
def map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], variable[U] | U]) -> vector[U]:
@overload
def clamp(x: variable[U], min_value: U | variable[U], max_value: U | variable[U]) -> variable[U]: ...
@overload
def clamp(x: U | variable[U], min_value: variable[U], max_value: U | variable[U]) -> variable[U]: ...
@overload
def clamp(x: U | variable[U], min_value: U | variable[U], max_value: variable[U]) -> variable[U]: ...
@overload
def clamp(x: U, min_value: U, max_value: U) -> U: ...
@overload
def clamp(x: vector[U], min_value: 'U | variable[U]', max_value: 'U | variable[U]') -> vector[U]: ...
def clamp(x: U | variable[U] | vector[U], min_value: U | variable[U], max_value: U | variable[U]) -> Any:
"""Clamp function to limit a value between a minimum and maximum.
Arguments:
x: Input value
min_value: Minimum limit
max_value: Maximum limit
Returns:
Clamped value of x
"""
if isinstance(x, vector):
return vector(clamp(comp, min_value, max_value) for comp in x.values)
return (x < min_value) * min_value + \
(x > max_value) * max_value + \
((x >= min_value) & (x <= max_value)) * x
@overload
def min(x: variable[U], y: U | variable[U]) -> variable[U]: ...
@overload
def min(x: U | variable[U], y: variable[U]) -> variable[U]: ...
@overload
def min(x: U, y: U) -> U: ...
def min(x: U | variable[U], y: U | variable[U]) -> Any:
"""Minimum function to get the smaller of two values.
Arguments:
x: First value
y: Second value
Returns:
Minimum of x and y
"""
return (x < y) * x + (x >= y) * y
@overload
def max(x: variable[U], y: U | variable[U]) -> variable[U]: ...
@overload
def max(x: U | variable[U], y: variable[U]) -> variable[U]: ...
@overload
def max(x: U, y: U) -> U: ...
def max(x: U | variable[U], y: U | variable[U]) -> Any:
"""Maximum function to get the larger of two values.
Arguments:
x: First value
y: Second value
Returns:
Maximum of x and y
"""
return (x > y) * x + (x <= y) * y
@overload
def lerp(v1: variable[U], v2: U | variable[U], t: U | variable[U]) -> variable[U]: ...
@overload
def lerp(v1: U | variable[U], v2: variable[U], t: U | variable[U]) -> variable[U]: ...
@overload
def lerp(v1: U | variable[U], v2: U | variable[U], t: variable[U]) -> variable[U]: ...
@overload
def lerp(v1: U, v2: U, t: U) -> U: ...
@overload
def lerp(v1: vector[U], v2: vector[U], t: 'U | variable[U]') -> vector[U]: ...
def lerp(v1: U | variable[U] | vector[U], v2: U | variable[U] | vector[U], t: U | variable[U]) -> Any:
"""Linearly interpolate between two values or vectors v1 and v2 by a factor t."""
if isinstance(v1, vector) or isinstance(v2, vector):
assert isinstance(v1, vector) and isinstance(v2, vector), "None or both v1 and v2 must be vectors."
assert len(v1.values) == len(v2.values), "Vectors must be of the same length."
return vector(lerp(vv1, vv2, t) for vv1, vv2 in zip(v1.values, v2.values))
return v1 * (1 - t) + v2 * t
def _map2(self: VecNumLike, other: VecNumLike, func: Callable[[Any, Any], variable[U] | U]) -> vector[U]:
"""Applies a function to each element of the vector and a second vector or scalar."""
if isinstance(self, vector) and isinstance(other, vector):
return vector(func(x, y) for x, y in zip(self.values, other.values))

View File

@ -81,7 +81,7 @@ class vector(Generic[T]):
@overload
def __mul__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __mul__(self: 'vector[float]', other: 'vector[int] | float | int | variable[int]') -> 'vector[float]': ...
def __mul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __mul__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __mul__(self, other: VecNumLike) -> Any:
@ -118,7 +118,7 @@ class vector(Generic[T]):
@overload
def dot(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ...
def dot(self, other: 'vector[int] | vector[float]') -> Any:
assert len(self.values) == len(other.values)
assert len(self.values) == len(other.values), "Vectors must be of same length."
return sum(a * b for a, b in zip(self.values, other.values))
# @ operator
@ -135,7 +135,7 @@ class vector(Generic[T]):
def cross(self: 'vector[float]', other: 'vector[float]') -> 'vector[float]':
"""3D cross product"""
assert len(self.values) == 3 and len(other.values) == 3
assert len(self.values) == 3 and len(other.values) == 3, "Both vectors must be 3-dimensional."
a1, a2, a3 = self.values
b1, b2, b3 = other.values
return vector([
@ -163,9 +163,65 @@ class vector(Generic[T]):
mag = self.magnitude() + epsilon
return self / mag
def __neg__(self) -> 'vector[float] | vector[int]':
return vector(-a for a in self.values)
def __iter__(self) -> Iterable[variable[T] | T]:
return iter(self.values)
def map(self, func: Callable[[Any], variable[U] | U]) -> 'vector[U]':
"""Applies a function to each element of the vector and returns a new vector."""
return vector(func(x) for x in self.values)
# Utility functions for 3D vectors with two arguments
def cross_product(v1: vector[float], v2: vector[float]) -> vector[float]:
"""Calculate the cross product of two 3D vectors."""
return v1.cross(v2)
def dot_product(v1: vector[float], v2: vector[float]) -> 'float | variable[float]':
"""Calculate the dot product of two vectors."""
return v1.dot(v2)
def distance(v1: vector[float], v2: vector[float]) -> 'float | variable[float]':
"""Calculate the Euclidean distance between two vectors."""
diff = v1 - v2
return diff.magnitude()
def scalar_projection(v1: vector[float], v2: vector[float]) -> 'float | variable[float]':
"""Calculate the scalar projection of v1 onto v2."""
dot_prod = v1.dot(v2)
mag_v2 = v2.magnitude() + epsilon
return dot_prod / mag_v2
def vector_projection(v1: vector[float], v2: vector[float]) -> vector[float]:
"""Calculate the vector projection of v1 onto v2."""
dot_prod = v1.dot(v2)
mag_v2_squared = v2.magnitude() ** 2 + epsilon
scalar_proj = dot_prod / mag_v2_squared
return v2 * scalar_proj
def angle_between(v1: vector[float], v2: vector[float]) -> 'float | variable[float]':
"""Calculate the angle in radians between two vectors."""
dot_prod = v1.dot(v2)
mag_v1 = v1.magnitude()
mag_v2 = v2.magnitude()
cos_angle = dot_prod / (mag_v1 * mag_v2 + epsilon)
return cp.acos(cos_angle)
def rotate_vector(v: vector[float], axis: vector[float], angle: 'float | variable[float]') -> vector[float]:
"""Rotate vector v around a given axis by a specified angle using Rodrigues' rotation formula."""
k = axis.normalize()
cos_angle = cp.cos(angle)
sin_angle = cp.sin(angle)
term1 = v * cos_angle
term2 = k.cross(v) * sin_angle
term3 = k * (k.dot(v)) * (1 - cos_angle)
return term1 + term2 + term3