Updated quaternion handling for usage with grad() function

This commit is contained in:
Nicolas Kruse 2026-03-31 11:34:46 +02:00
parent f63e09fb99
commit 5eae012d00
2 changed files with 12 additions and 3 deletions

View File

@ -1,3 +1,5 @@
from copapy._quaternion import quaternion
from . import value, vector, tensor from . import value, vector, tensor
import copapy.backend as cpb import copapy.backend as cpb
from typing import Any, Sequence, overload from typing import Any, Sequence, overload
@ -12,8 +14,10 @@ def grad(x: Any, y: vector[Any]) -> vector[float]: ...
@overload @overload
def grad(x: Any, y: tensor[Any]) -> tensor[float]: ... def grad(x: Any, y: tensor[Any]) -> tensor[float]: ...
@overload @overload
def grad(x: Any, y: quaternion) -> quaternion: ...
@overload
def grad(x: Any, y: Sequence[value[Any]]) -> list[unifloat]: ... def grad(x: Any, y: Sequence[value[Any]]) -> list[unifloat]: ...
def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | tensor[Any]) -> Any: def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | tensor[Any] | quaternion) -> Any:
"""Returns the partial derivative dx/dy where x needs to be a scalar """Returns the partial derivative dx/dy where x needs to be a scalar
and y might be a scalar, a list of scalars, a vector or matrix. It and y might be a scalar, a list of scalars, a vector or matrix. It
uses automatic differentiation in reverse-mode. uses automatic differentiation in reverse-mode.
@ -32,7 +36,7 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | tensor[Any
if isinstance(y, tensor): if isinstance(y, tensor):
y_set = {v.get_scalar(0) for v in y.flatten()} y_set = {v.get_scalar(0) for v in y.flatten()}
else: else:
assert isinstance(y, Sequence) or isinstance(y, vector) assert isinstance(y, Sequence) or isinstance(y, vector) or isinstance(y, quaternion)
y_set = set(y) y_set = set(y)
edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value))) edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value)))
@ -125,6 +129,8 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | tensor[Any
return grad_dict[y.net] return grad_dict[y.net]
if isinstance(y, vector): if isinstance(y, vector):
return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y.values) return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y.values)
if isinstance(y, quaternion):
return quaternion(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y.values)
if isinstance(y, tensor): if isinstance(y, tensor):
return tensor([grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y.values], y.shape) return tensor([grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y.values], y.shape)
return [grad_dict[yi.net] for yi in y] return [grad_dict[yi.net] for yi in y]

View File

@ -1,4 +1,4 @@
from typing import overload, Iterable, Callable, Any from typing import overload, Iterable, Callable, Any, Iterator
from ._vectors import vector from ._vectors import vector
from ._tensors import tensor from ._tensors import tensor
import copapy as cp import copapy as cp
@ -208,6 +208,9 @@ class quaternion(ArrayType[float]):
""" """
return quaternion(func(x) for x in self.values) return quaternion(func(x) for x in self.values)
def __iter__(self) -> Iterator[value[float] | float]:
return iter(self.values)
def __neg__(self) -> 'quaternion': def __neg__(self) -> 'quaternion':
return quaternion(-self.w, -self.x, -self.y, -self.z) return quaternion(-self.w, -self.x, -self.y, -self.z)