From 09d7d01eeead4880261e19dd7cdb22eaf37e332b Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Fri, 27 Mar 2026 15:22:33 +0100 Subject: [PATCH] quaternion class added --- src/copapy/__init__.py | 4 +- src/copapy/_quaternion.py | 282 ++++++++++++++++++++++++++++++++++++ tests/test_quaternion.py | 296 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 581 insertions(+), 1 deletion(-) create mode 100644 src/copapy/_quaternion.py create mode 100644 tests/test_quaternion.py diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index b9d99a1..680e6a4 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -36,6 +36,7 @@ Example usage: from ._target import Target, jit from ._basic_types import NumLike, value, generic_sdb, iif from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection +from ._quaternion import quaternion from ._tensors import tensor, zeros, ones, arange, eye, identity, diagonal from ._math import sqrt, abs, sign, sin, cos, tan, asin, acos, atan, atan2, log, exp, pow, get_42, clamp, min, max, relu from ._autograd import grad @@ -80,7 +81,8 @@ __all__ = [ "scalar_projection", "angle_between", "rotate_vector", - "vector_projection", +"vector_projection", + "quaternion", "grad", "eye", "jit" diff --git a/src/copapy/_quaternion.py b/src/copapy/_quaternion.py new file mode 100644 index 0000000..6750a70 --- /dev/null +++ b/src/copapy/_quaternion.py @@ -0,0 +1,282 @@ +from typing import overload, Iterable, Callable, Any +from ._vectors import vector +from ._tensors import tensor +import copapy as cp +from ._basic_types import NumLike, value, unifloat, ArrayType +from ._mixed import mixed_sum + + +class quaternion(ArrayType[float]): + """Mathematical quaternion class for representing 3D rotations. + + Attributes: + values (tuple[unifloat, ...]): Internal storage of the (w, x, y, z) components. + w, x, y, z (unifloat): Property accessors to individual components. + """ + def __init__( + self, + w: unifloat | Iterable[unifloat] = 1.0, + x: unifloat = 0.0, + y: unifloat = 0.0, + z: unifloat = 0.0): + """Create a quaternion with given components. + + Arguments: + w: w component, or an iterable of 4 components. + x: x component (ignored if w is an iterable). + y: y component (ignored if w is an iterable). + z: z component (ignored if w is an iterable). + """ + self.shape = (4,) + if isinstance(w, Iterable): + self.values = tuple(v for v in w) + assert len(self.values) == 4, "Sequence must have exactly 4 elements for quaternion initialization." + else: + self.values = (w, x, y, z) + + + @classmethod + def from_euler(cls, roll: NumLike, pitch: NumLike, yaw: NumLike) -> 'quaternion': + """Create a quaternion from Euler angles (roll, pitch, yaw). + + Arguments: + roll: Rotation around the x-axis in radians. + pitch: Rotation around the y-axis in radians. + yaw: Rotation around the z-axis in radians. + + Returns: + A quaternion representing the rotation. + """ + cy = cp.cos(yaw * 0.5) + sy = cp.sin(yaw * 0.5) + ci = cp.cos(pitch * 0.5) + sp = cp.sin(pitch * 0.5) + cr = cp.cos(roll * 0.5) + sr = cp.sin(roll * 0.5) + + w = cr * ci * cy + sr * sp * sy + x = sr * ci * cy - cr * sp * sy + y = cr * sp * cy + sr * ci * sy + z = cr * ci * sy - sr * sp * cy + + return cls(w, x, y, z) + + @classmethod + def identity(cls) -> 'quaternion': + """Return the identity quaternion (no rotation). + + Returns: + The identity quaternion (x=0, y=0, z=0, w=1). + """ + return cls(w=1.0, x=0.0, y=0.0, z=0.0) + + @property + def x(self) -> unifloat: + return self.values[1] + + @property + def y(self) -> unifloat: + return self.values[2] + + @property + def z(self) -> unifloat: + return self.values[3] + + @property + def w(self) -> unifloat: + return self.values[0] + + + def normalize(self) -> 'quaternion': + """Normalize the quaternion to unit length. + + Returns: + A normalized (unit) quaternion. Returns identity if the norm is zero. + """ + n = self.norm() + if not isinstance(n, value) and n == 0: + return quaternion.identity() + return quaternion(v / n for v in self.values) + + def toRotationMatrix(self) -> tensor[float]: + """Convert the quaternion to a 4x4 rotation matrix. + + Returns: + A 4x4 tensor representing the rotation matrix. + """ + w, x, y, z = self.values + x2 = x + x + y2 = y + y + z2 = z + z + xx = x * x2 + xy = x * y2 + xz = x * z2 + yy = y * y2 + yz = y * z2 + zz = z * z2 + wx = w * x2 + wy = w * y2 + wz = w * z2 + + s1: list[unifloat] = [1.0 - (yy + zz), xy - wz, xz + wy, 0.0] + s2: list[unifloat] = [xy + wz, 1.0 - (xx + zz), yz - wx, 0.0] + s3: list[unifloat] = [xz - wy, yz + wx, 1.0 - (xx + yy), 0.0] + s4: list[unifloat] = [0.0, 0.0, 0.0, 1.0] + return tensor([s1, s2, s3, s4]) + + def toEulerAngles(self) -> vector[float]: + """Convert the quaternion to Euler angles (roll, pitch, yaw). + + Returns: + A vector of [roll, pitch, yaw] in radians. + """ + w, x, y, z = self.w, self.x, self.y, self.z + + yaw = cp.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z)) + pitch_sin = cp.clamp(2 * (w * y - z * x), -1.0, 1.0) + pitch = cp.asin(pitch_sin) + roll = cp.atan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y)) + + return vector([roll, pitch, yaw]) + + def toAxisAngle(self) -> tuple[vector[float], unifloat]: + """Convert the quaternion to axis-angle representation. + + Returns: + A tuple of (axis, angle) where axis is a unit vector and angle is in radians. + """ + n = self.normalize() + sin_half_angle_sq = 1 - n.w * n.w + is_near_identity = sin_half_angle_sq < 1e-6 + s = 1 / cp.sqrt(cp.iif(is_near_identity, 1e-6, sin_half_angle_sq)) + angle = cp.iif(is_near_identity, 0.0, 2 * cp.acos(n.w)) + axis = vector([ + cp.iif(is_near_identity, 1.0, n.x * s), + cp.iif(is_near_identity, 0.0, n.y * s), + cp.iif(is_near_identity, 0.0, n.z * s), + ]) + return axis, angle + + def conjugate(self) -> 'quaternion': + """Return the conjugate of the quaternion. + + Returns: + The conjugate quaternion (negates x, y, z components). + """ + return quaternion(self.w, -self.x, -self.y, -self.z) + + def inverse(self) -> 'quaternion': + """Return the inverse of the quaternion. + + Returns: + The inverse quaternion. Returns identity if the norm is zero. + """ + n2 = self.norm() ** 2 + if not isinstance(n2, value) and n2 == 0: + return quaternion.identity() + return quaternion(v / n2 for v in self.conjugate().values) + + def norm(self) -> unifloat: + """Calculate the norm (magnitude) of the quaternion. + + Returns: + The norm (square root of the sum of squared components). + """ + return cp.sqrt(mixed_sum(v**2 for v in self.values)) + + def rotate_vector(self, vec: vector[float]) -> vector[float]: + """Rotate a 3D vector by this quaternion. + + Arguments: + vec: A 3D vector to rotate. + + Returns: + The rotated vector. + """ + q_vec = quaternion(0, *vec) + rotated_q = self @ q_vec @ self.inverse() + return vector(rotated_q.values[1:]) + + def map(self, func: Callable[[Any], value[float] | float]) -> 'quaternion': + """Applies a function to each element of the quaternion and returns a new quaternion. + + Arguments: + func: A function that takes a single argument. + + Returns: + A new quaternion with the function applied to each element. + """ + return quaternion(func(x) for x in self.values) + + def __neg__(self) -> 'quaternion': + return quaternion(-self.w, -self.x, -self.y, -self.z) + + def __abs__(self) -> unifloat: + return self.norm() + + def __repr__(self) -> str: + return f"vector({self.values})" + + def __len__(self) -> int: + return len(self.values) + + @overload + def __getitem__(self, index: int) -> value[float] | float: ... + @overload + def __getitem__(self, index: slice) -> 'vector[float]': ... + def __getitem__(self, index: int | slice) -> 'vector[float] | value[float] | float': + if isinstance(index, slice): + return vector(self.values[index]) + return self.values[index] + + @overload + def __add__(self, other: 'quaternion') -> 'quaternion': ... + @overload + def __add__(self, other: NumLike) -> 'quaternion': ... + def __add__(self, other: 'quaternion | NumLike') -> 'quaternion': + if isinstance(other, quaternion): + return quaternion(a + b for a, b in zip(self.values, other.values)) + if isinstance(other, value): + return quaternion(v + other for v in self.values) + o = value(other) # Make sure a single constant is allocated + return quaternion(a + o if isinstance(a, value) else a + other for a in self.values) + + def __radd__(self, other: int | float) -> 'quaternion': + return self + other + + @overload + def __sub__(self, other: 'quaternion') -> 'quaternion': ... + @overload + def __sub__(self, other: NumLike) -> 'quaternion': ... + def __sub__(self, other: 'quaternion | NumLike') -> 'quaternion': + if isinstance(other, quaternion): + return quaternion(a - b for a, b in zip(self.values, other.values)) + if isinstance(other, value): + return quaternion(v - other for v in self.values) + o = value(other) # Make sure a single constant is allocated + return quaternion(a - o if isinstance(a, value) else a - other for a in self.values) + + def __rsub__(self, other: NumLike) -> 'quaternion': + return -self + other + + def __mul__(self, other: NumLike) -> 'quaternion': + if isinstance(other, value): + return quaternion(v * other for v in self.values) + o = value(other) # Make sure a single constant is allocated + return quaternion(v * o if isinstance(v, value) else v * other for v in self.values) + + def __rmul__(self, other: NumLike) -> 'quaternion': + return self * other + + def __matmul__(self, other: 'quaternion') -> 'quaternion': + w = self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z + x = self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y + y = self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x + z = self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w + return quaternion(w, x, y, z) + + def __truediv__(self, other: NumLike) -> 'quaternion': + if isinstance(other, value): + return quaternion(v / other for v in self.values) + o = value(other) # Make sure a single constant is allocated + return quaternion(v / o if isinstance(v, value) else v / other for v in self.values) diff --git a/tests/test_quaternion.py b/tests/test_quaternion.py new file mode 100644 index 0000000..88f1134 --- /dev/null +++ b/tests/test_quaternion.py @@ -0,0 +1,296 @@ +import math +from copapy import quaternion, tensor +from copapy import vector, Target +import copapy as cp + + +def isclose(a, b, rel_tol=1e-9, abs_tol=0.0): + if isinstance(a, tensor) and isinstance(b, tensor): + return all(isclose(av, bv, rel_tol=rel_tol, abs_tol=abs_tol) for av, bv in zip(a.values, b.values)) + if isinstance(a, tensor): + return all(isclose(av, b, rel_tol=rel_tol, abs_tol=abs_tol) for av in a.values) + if isinstance(b, tensor): + return all(isclose(a, bv, rel_tol=rel_tol, abs_tol=abs_tol) for bv in b.values) + return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + + +def test_identity(): + q = quaternion.identity() + assert q.x == 0.0 + assert q.y == 0.0 + assert q.z == 0.0 + assert q.w == 1.0 + + +def test_constructor_default(): + q = quaternion() + assert q.x == 0.0 + assert q.y == 0.0 + assert q.z == 0.0 + assert q.w == 1.0 + + +def test_constructor_with_values(): + q = quaternion(1.0, 2.0, 3.0, 4.0) + assert q.x == 2.0 + assert q.y == 3.0 + assert q.z == 4.0 + assert q.w == 1.0 + + +def test_from_euler_90_roll(): + q = quaternion.from_euler(math.pi / 2, 0.0, 0.0) + assert isclose(q.w, math.sqrt(2) / 2) + assert isclose(q.x, math.sqrt(2) / 2) + + +def test_from_euler_90_pitch(): + q = quaternion.from_euler(0.0, math.pi / 2, 0.0) + assert isclose(q.w, math.sqrt(2) / 2) + assert isclose(q.y, math.sqrt(2) / 2) + + +def test_from_euler_90_yaw(): + q = quaternion.from_euler(0.0, 0.0, math.pi / 2) + assert isclose(q.w, math.sqrt(2) / 2) + assert isclose(q.z, math.sqrt(2) / 2) + + +def test_normalize(): + q = quaternion(0.0, 2.0, 0.0, 0.0) + n = q.normalize() + assert isclose(n.x, 1.0) + assert n.y == 0.0 + assert n.z == 0.0 + assert n.w == 0.0 + + +def test_normalize_identity(): + q = quaternion.identity().normalize() + assert q.x == 0.0 + assert q.y == 0.0 + assert q.z == 0.0 + assert isclose(q.w, 1.0) + + +def test_to_rotation_matrix_identity(): + q = quaternion.identity() + m = q.toRotationMatrix() + expected = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + for i in range(4): + for j in range(4): + assert isclose(m[i, j], expected[i][j]) + + +def test_to_euler_roundtrip(): + roll, pitch, yaw = math.pi / 4, math.pi / 6, math.pi / 3 + q = quaternion.from_euler(roll, pitch, yaw) + r, p, y = q.toEulerAngles() + assert isclose(r, roll) + assert isclose(p, pitch) + assert isclose(y, yaw) + + +def test_conjugate(): + q = quaternion(4.0, 1.0, 2.0, 3.0) + c = q.conjugate() + assert c.x == -1.0 + assert c.y == -2.0 + assert c.z == -3.0 + assert c.w == 4.0 + + +def test_inverse_identity(): + q = quaternion.identity() + inv = q.inverse() + assert isclose(inv.x, 0.0) + assert isclose(inv.y, 0.0) + assert isclose(inv.z, 0.0) + assert isclose(inv.w, 1.0) + + +def test_inverse_product(): + q = quaternion(4.0, 1.0, 2.0, 3.0) + inv = q.inverse() + result = q @ inv + assert isclose(result.x, 0.0, abs_tol=1e-6) + assert isclose(result.y, 0.0, abs_tol=1e-6) + assert isclose(result.z, 0.0, abs_tol=1e-6) + assert isclose(result.w, 1.0, abs_tol=1e-6) + + +def test_norm_identity(): + q = quaternion.identity() + assert isclose(abs(q), 1.0) + + +def test_norm_unit(): + q = quaternion(0.0, 1.0, 0.0, 0.0) + assert isclose(abs(q), 1.0) + + +def test_negation(): + q = quaternion(4.0, 1.0, 2.0, 3.0) + assert (-q).x == -1.0 + assert (-q).y == -2.0 + assert (-q).z == -3.0 + assert (-q).w == -4.0 + + +def test_add(): + q1 = quaternion(0.0, 1.0, 0.0, 0.0) + q2 = quaternion(0.0, 0.0, 1.0, 0.0) + s = q1 + q2 + assert s.x == 1.0 + assert s.y == 1.0 + assert s.z == 0.0 + assert s.w == 0.0 + + +def test_add_scalar(): + q = quaternion(0.0, 1.0, 0.0, 0.0) + s = q + 1.0 + assert s.x == 2.0 + assert s.y == 1.0 + assert s.z == 1.0 + assert s.w == 1.0 + + +def test_sub(): + q1 = quaternion(0.0, 1.0, 1.0, 0.0) + q2 = quaternion(0.0, 0.0, 1.0, 0.0) + s = q1 - q2 + assert s.x == 1.0 + assert s.y == 0.0 + + +def test_sub_scalar(): + q = quaternion(2.0, 2.0, 2.0, 2.0) + s = q - 1.0 + assert s.x == 1.0 + assert s.y == 1.0 + assert s.z == 1.0 + assert s.w == 1.0 + + +def test_mul_scalar(): + q = quaternion(4.0, 1.0, 2.0, 3.0) + m = q * 2.0 + assert m.x == 2.0 + assert m.y == 4.0 + assert m.z == 6.0 + assert m.w == 8.0 + + +def test_rmul_scalar(): + q = quaternion(4.0, 1.0, 2.0, 3.0) + m = 2.0 * q + assert m.x == 2.0 + assert m.y == 4.0 + assert m.z == 6.0 + assert m.w == 8.0 + + +def test_matmul(): + q1 = quaternion(0.0, 1.0, 0.0, 0.0) # i + q2 = quaternion(0.0, 0.0, 1.0, 0.0) # j + m = q1 @ q2 + + assert isclose(m.x, 0.0) + assert isclose(m.y, 0.0) + assert isclose(m.z, 1.0) + assert isclose(m.w, 0.0) + + +def test_div(): + q = quaternion(8.0, 2.0, 4.0, 6.0) + d = q / 2.0 + assert d.x == 1.0 + assert d.y == 2.0 + assert d.z == 3.0 + assert d.w == 4.0 + + +def test_to_axis_angle_identity(): + q = quaternion.identity() + axis, angle = q.toAxisAngle() + assert isclose(angle, 0.0) + assert isclose(axis[0], 1.0) + assert isclose(axis[1], 0.0) + assert isclose(axis[2], 0.0) + + +def test_to_axis_angle_90_degrees(): + q = quaternion.from_euler(math.pi / 2, 0.0, 0.0) + axis, angle = q.toAxisAngle() + assert isclose(angle, math.pi / 2) + assert isclose(axis[0], 1.0) + assert isclose(axis[1], 0.0) + assert isclose(axis[2], 0.0) + + +def test_rotate_vector_identity(): + from copapy import vector + q = quaternion.identity() + v = vector([1.0, 2.0, 3.0]) + rotated = q.rotate_vector(v) + assert isclose(rotated[0], 1.0) + assert isclose(rotated[1], 2.0) + assert isclose(rotated[2], 3.0) + + +def test_rotate_vector_90_degrees_x(): + from copapy import vector + q = quaternion.from_euler(math.pi / 2, 0.0, 0.0) + v = vector([0.0, 1.0, 0.0]) + rotated = q.rotate_vector(v) + assert isclose(rotated[0], 0.0, abs_tol=1e-9) + assert isclose(rotated[1], 0.0, abs_tol=1e-9) + assert isclose(rotated[2], 1.0, abs_tol=1e-9) + + +def test_rotate_vector_roundtrip(): + from copapy import vector + q = quaternion.from_euler(math.pi / 4, math.pi / 6, math.pi / 3) + v = vector([1.0, 0.5, 0.25]) + rotated = q.rotate_vector(v) + q_inv = q.inverse() + restored = q_inv.rotate_vector(rotated) + for i in range(3): + assert isclose(restored[i], v[i], abs_tol=1e-9) + + +def test_satellite_attitude_correction(): + current_q = quaternion.from_euler(math.pi / 8, math.pi / 6, 0.0) + desired_q = quaternion.from_euler(cp.value(-math.pi / 8), cp.value(math.pi / 3), cp.value(math.pi / 4)) + solar_panel_normal = vector([0.0, 0.0, 1.0]) + + rotation_q = desired_q @ current_q.inverse() + rotated_normal = rotation_q.rotate_vector(solar_panel_normal) + + expected_current = quaternion.from_euler(math.pi / 8, math.pi / 6, 0.0) + expected_desired = quaternion.from_euler(-math.pi / 8, math.pi / 3, math.pi / 4) + expected_rotation = expected_desired @ expected_current.inverse() + expected_rotated = expected_rotation.rotate_vector(solar_panel_normal) + + tg = Target() + tg.compile(rotation_q, rotated_normal) + tg.run() + + result_q = tg.read_value(rotation_q) + result_normal = tg.read_value(rotated_normal) + + print(rotation_q,result_q) + + assert isclose(result_q[0], expected_rotation.w, abs_tol=1e-6) + assert isclose(result_q[1], expected_rotation.x, abs_tol=1e-6) + assert isclose(result_q[2], expected_rotation.y, abs_tol=1e-6) + assert isclose(result_q[3], expected_rotation.z, abs_tol=1e-6) + assert isclose(result_normal[0], expected_rotated[0], abs_tol=1e-6) + assert isclose(result_normal[1], expected_rotated[1], abs_tol=1e-6) + assert isclose(result_normal[2], expected_rotated[2], abs_tol=1e-6)