vector type added, sqrt and ge/le added; type hints improved

This commit is contained in:
Nicolas Kruse 2025-10-25 02:21:43 +02:00
parent cb1447f125
commit df84b61a7b
11 changed files with 459 additions and 278 deletions

View File

@ -1,14 +1,16 @@
from ._target import Target from ._target import Target
from ._basic_types import NumLike, variable, \ from ._basic_types import NumLike, variable, \
CPNumber, generic_sdb, iif generic_sdb, iif
from ._vectors import vector from ._vectors import vector
from ._math import sqrt, abs
__all__ = [ __all__ = [
"Target", "Target",
"NumLike", "NumLike",
"variable", "variable",
"CPNumber",
"generic_sdb", "generic_sdb",
"iif", "iif",
"vector" "vector",
"sqrt",
"abs",
] ]

View File

@ -8,7 +8,7 @@ unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int' uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool' unibool: TypeAlias = 'variable[bool] | bool'
TCPNum = TypeVar("TCPNum", bound='CPNumber') TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, bool, float) TNum = TypeVar("TNum", int, bool, float)
@ -57,173 +57,7 @@ class Net:
return id(self) return id(self)
class CPNumber(Net): class variable(Generic[TNum], Net):
def __init__(self, dtype: str, source: Node):
self.dtype = dtype
self.source = source
@overload
def __mul__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __mul__(self, other: unifloat) -> 'variable[float]':
...
def __mul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __rmul__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __rmul__(self, other: unifloat) -> 'variable[float]':
...
def __rmul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __add__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __add__(self, other: unifloat) -> 'variable[float]':
...
def __add__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __radd__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __radd__(self, other: unifloat) -> 'variable[float]':
...
def __radd__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __sub__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __sub__(self, other: unifloat) -> 'variable[float]':
...
def __sub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [self, other])
@overload
def __rsub__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __rsub__(self, other: unifloat) -> 'variable[float]':
...
def __rsub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [other, self])
def __truediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [self, other])
def __rtruediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [other, self])
@overload
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __floordiv__(self, other: unifloat) -> 'variable[float]':
...
def __floordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __rfloordiv__(self, other: unifloat) -> 'variable[float]':
...
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self])
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, _add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@overload
def __mod__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __mod__(self, other: unifloat) -> 'variable[float]':
...
def __mod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [self, other])
@overload
def __rmod__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __rmod__(self, other: unifloat) -> 'variable[float]':
...
def __rmod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [other, self])
@overload
def __pow__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __pow__(self, other: unifloat) -> 'variable[float]':
...
def __pow__(self, other: NumLike) -> 'CPNumber':
return _add_op('pow', [other, self])
@overload
def __rpow__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
def __rpow__(self, other: unifloat) -> 'variable[float]':
...
def __rpow__(self, other: NumLike) -> 'CPNumber':
return _add_op('rpow', [self, other])
def __hash__(self) -> int:
return super().__hash__()
class variable(Generic[TNum], CPNumber):
def __init__(self, source: TNum | Node, dtype: str | None = None): def __init__(self, source: TNum | Node, dtype: str | None = None):
if isinstance(source, Node): if isinstance(source, Node):
self.source = source self.source = source
@ -239,36 +73,185 @@ class variable(Generic[TNum], CPNumber):
self.source = CPConstant(source) self.source = CPConstant(source)
self.dtype = 'int' self.dtype = 'int'
@overload
def __add__(self, other: TCPNum) -> TCPNum: ...
@overload
def __add__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __add__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __add__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __add__(self, other: NumLike) -> Any:
return add_op('add', [self, other], True)
@overload
def __radd__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any:
return add_op('add', [self, other], True)
@overload
def __sub__(self, other: TCPNum) -> TCPNum: ...
@overload
def __sub__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __sub__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __sub__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: NumLike) -> Any:
return add_op('sub', [self, other])
@overload
def __rsub__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rsub__(self, other: float) -> 'variable[float]': ...
def __rsub__(self, other: NumLike) -> Any:
return add_op('sub', [other, self])
@overload
def __mul__(self, other: TCPNum) -> TCPNum: ...
@overload
def __mul__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __mul__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __mul__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __mul__(self, other: NumLike) -> Any:
return add_op('mul', [self, other], True)
@overload
def __rmul__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rmul__(self, other: float) -> 'variable[float]': ...
def __rmul__(self, other: NumLike) -> Any:
return add_op('mul', [self, other], True)
def __truediv__(self, other: NumLike) -> 'variable[float]':
return add_op('div', [self, other])
def __rtruediv__(self, other: NumLike) -> 'variable[float]':
return add_op('div', [other, self])
@overload
def __floordiv__(self, other: TCPNum) -> TCPNum: ...
@overload
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __floordiv__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __floordiv__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __floordiv__(self, other: NumLike) -> Any:
return add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rfloordiv__(self, other: float) -> 'variable[float]': ...
def __rfloordiv__(self, other: NumLike) -> Any:
return add_op('floordiv', [other, self])
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __ge__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('ge', [self, other])
return variable(ret.source, dtype='bool')
def __le__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('ge', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@overload
def __mod__(self, other: TCPNum) -> TCPNum: ...
@overload
def __mod__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __mod__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __mod__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __mod__(self, other: NumLike) -> Any:
return add_op('mod', [self, other])
@overload
def __rmod__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rmod__(self, other: float) -> 'variable[float]': ...
def __rmod__(self, other: NumLike) -> Any:
return add_op('mod', [other, self])
@overload
def __pow__(self, other: TCPNum) -> TCPNum: ...
@overload
def __pow__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __pow__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __pow__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __pow__(self, other: NumLike) -> Any:
if not isinstance(other, variable):
if other == 2:
return self * self
if other == -1:
return 1 / self
return add_op('pow', [self, other])
@overload
def __rpow__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rpow__(self, other: float) -> 'variable[float]': ...
def __rpow__(self, other: NumLike) -> Any:
return add_op('rpow', [other, self])
def __hash__(self) -> int:
return super().__hash__()
# Bitwise and shift operations for cp[int] # Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]': def __lshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [self, other]) return add_op('lshift', [self, other])
def __rlshift__(self, other: uniint) -> 'variable[int]': def __rlshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [other, self]) return add_op('lshift', [other, self])
def __rshift__(self, other: uniint) -> 'variable[int]': def __rshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [self, other]) return add_op('rshift', [self, other])
def __rrshift__(self, other: uniint) -> 'variable[int]': def __rrshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [other, self]) return add_op('rshift', [other, self])
def __and__(self, other: uniint) -> 'variable[int]': def __and__(self, other: uniint) -> 'variable[int]':
return _add_op('bwand', [self, other], True) return add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]': def __rand__(self, other: uniint) -> 'variable[int]':
return _add_op('rwand', [other, self], True) return add_op('rwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]': def __or__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [self, other], True) return add_op('bwor', [self, other], True)
def __ror__(self, other: uniint) -> 'variable[int]': def __ror__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [other, self], True) return add_op('bwor', [other, self], True)
def __xor__(self, other: uniint) -> 'variable[int]': def __xor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [self, other], True) return add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniint) -> 'variable[int]': def __rxor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [other, self], True) return add_op('bwxor', [other, self], True)
class CPConstant(Node): class CPConstant(Node):
@ -303,42 +286,24 @@ def net_from_value(value: Any) -> Net:
@overload @overload
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> variable[bool]: # pyright: ignore[reportOverlappingOverload] def iif(expression: variable[Any], true_result: unibool, false_result: unibool) -> variable[bool]: ... # pyright: ignore[reportOverlappingOverload]
...
@overload @overload
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> variable[int]: def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ...
...
@overload @overload
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> variable[float]: def iif(expression: variable[Any], true_result: unifloat, false_result: unifloat) -> variable[float]: ...
...
@overload @overload
def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum: def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum: ...
...
@overload @overload
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]: def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]: ...
...
@overload @overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ...
...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any: def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool) allowed_type = (variable, int, float, bool)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
return (expression != 0) * true_result + (expression == 0) * false_result return (expression != 0) * true_result + (expression == 0) * false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]: def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool = False) -> variable[Any]:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative: if commutative:
@ -351,10 +316,10 @@ def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = Fal
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'int': if result_type == 'float':
return variable[int](Op(typed_op, arg_nets), result_type)
else:
return variable[float](Op(typed_op, arg_nets), result_type) return variable[float](Op(typed_op, arg_nets), result_type)
else:
return variable[int](Op(typed_op, arg_nets), result_type)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:

24
src/copapy/_math.py Normal file
View File

@ -0,0 +1,24 @@
from . import variable, NumLike
from typing import TypeVar, Any, overload
from ._basic_types import add_op
T = TypeVar("T", int, float, variable[int], variable[float])
@overload
def sqrt(x: float | int) -> float: ...
@overload
def sqrt(x: variable[Any]) -> variable[float]: ...
def sqrt(x: NumLike) -> variable[float] | float:
"""Square root function"""
if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5)
def abs(x: T) -> T:
"""Absolute value function"""
ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType]

View File

@ -1,42 +1,158 @@
from copapy import CPNumber, variable from . import variable
from typing import Generic, TypeVar, Iterable, Any, overload from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias
from ._math import sqrt
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float)
epsilon = 1e-10
T = TypeVar("T", int, float, bool)
T2 = TypeVar("T2", bound=CPNumber)
class vector(Generic[T]): class vector(Generic[T]):
"""Type-safe vector supporting numeric promotion between vector types."""
def __init__(self, values: Iterable[T | variable[T]]): def __init__(self, values: Iterable[T | variable[T]]):
#self.values: tuple[variable[T], ...] = tuple(v if isinstance(v, variable) else variable(v) for v in values)
self.values: tuple[variable[T] | T, ...] = tuple(values) self.values: tuple[variable[T] | T, ...] = tuple(values)
@overload # ---------- Basic dunder methods ----------
def __add__(self, other: 'vector[float] | variable[float] | float') -> 'vector[float]': def __repr__(self) -> str:
... return f"vector({self.values})"
def __len__(self) -> int:
return len(self.values)
def __getitem__(self, index: int) -> variable[T] | T:
return self.values[index]
@overload @overload
def __add__(self: 'vector[T]', other: 'vector[int] | variable[int] | int') -> 'vector[T]': def __add__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
... @overload
def __add__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
def __add__(self, other: 'vector[Any] | variable[Any] | float | int') -> Any: @overload
def __add__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __add__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __add__(self, other: VecNumLike) -> Any:
if isinstance(other, vector): if isinstance(other, vector):
assert len(self.values) == len(other.values) assert len(self.values) == len(other.values)
return vector(a + b for a, b in zip(self.values, other.values)) return vector(a + b for a, b in zip(self.values, other.values))
else:
return vector(a + other for a in self.values) return vector(a + other for a in self.values)
#@overload @overload
#def sum(self: 'vector[float]') -> variable[float]: def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
# ... @overload
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
def __radd__(self, other: Any) -> Any:
return self + other
#@overload @overload
#def sum(self: 'vector[int]') -> variable[int]: def __sub__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
# ... @overload
def __sub__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __sub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __sub__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __sub__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a - b for a, b in zip(self.values, other.values))
return vector(a - other for a in self.values)
#def sum(self: 'vector[T]') -> variable[T] | T: @overload
# comp_time = sum(v for v in self.values if not isinstance(v, variable)) def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
# run_time = sum(v for v in self.values if isinstance(v, variable)) @overload
# if isinstance(run_time, variable): def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
# return comp_time + run_time # type: ignore def __rsub__(self, other: VecNumLike) -> Any:
# else: if isinstance(other, vector):
# return comp_time assert len(self.values) == len(other.values)
return vector(b - a for a, b in zip(self.values, other.values))
return vector(other - a for a in self.values)
@overload
def __mul__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
@overload
def __mul__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __mul__(self: 'vector[float]', other: 'vector[int] | float | int | variable[int]') -> 'vector[float]': ...
@overload
def __mul__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __mul__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a * b for a, b in zip(self.values, other.values))
return vector(a * other for a in self.values)
@overload
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
def __rmul__(self, other: VecNumLike) -> Any:
return self * other
def __truediv__(self, other: VecNumLike) -> 'vector[float]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a / b for a, b in zip(self.values, other.values))
return vector(a / other for a in self.values)
def __rtruediv__(self, other: VecNumLike) -> 'vector[float]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(b / a for a, b in zip(self.values, other.values))
return vector(other / a for a in self.values)
@overload
def dot(self: 'vector[int]', other: 'vector[int]') -> int | variable[int]: ...
@overload
def dot(self, other: 'vector[float]') -> float | variable[float]: ...
@overload
def dot(self: 'vector[float]', other: 'vector[int] | vector[float]') -> float | variable[float]: ...
@overload
def dot(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ...
def dot(self, other: 'vector[int] | vector[float]') -> Any:
assert len(self.values) == len(other.values)
return sum(a * b for a, b in zip(self.values, other.values))
# @ operator
@overload
def __matmul__(self: 'vector[int]', other: 'vector[int]') -> int | variable[int]: ...
@overload
def __matmul__(self, other: 'vector[float]') -> float | variable[float]: ...
@overload
def __matmul__(self: 'vector[float]', other: 'vector[int] | vector[float]') -> float | variable[float]: ...
@overload
def __matmul__(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ...
def __matmul__(self, other: 'vector[int] | vector[float]') -> Any:
return self.dot(other)
def cross(self: 'vector[float]', other: 'vector[float]') -> 'vector[float]':
"""3D cross product"""
assert len(self.values) == 3 and len(other.values) == 3
a1, a2, a3 = self.values
b1, b2, b3 = other.values
return vector([
a2 * b3 - a3 * b2,
a3 * b1 - a1 * b3,
a1 * b2 - a2 * b1
])
@overload
def sum(self: 'vector[int]') -> int | variable[int]: ...
@overload
def sum(self: 'vector[float]') -> float | variable[float]: ...
def sum(self) -> Any:
return sum(a for a in self.values if isinstance(a, variable)) +\
sum(a for a in self.values if not isinstance(a, variable))
def magnitude(self) -> 'float | variable[float]':
s = sum(a * a for a in self.values)
return sqrt(s) if isinstance(s, variable) else sqrt(s)
def normalize(self) -> 'vector[float]':
mag = self.magnitude() + epsilon
return self / mag
def __iter__(self) -> Iterable[variable[T] | T]:
return iter(self.values)

View File

@ -12,6 +12,19 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
return i; return i;
} }
float fast_sqrt(float n) {
if (n < 0) return -1;
float x = n; // initial guess
float epsilon = 0.00001; // desired accuracy
while ((x - n / x) > epsilon || (x - n / x) < -epsilon) {
x = 0.5 * (x + n / x);
}
return x;
}
float fast_pow_float(float base, float exponent) { float fast_pow_float(float base, float exponent) {
union { union {
float f; float f;

View File

@ -4,7 +4,7 @@ from pathlib import Path
import os import os
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**', op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**',
'gt': '>', 'eq': '==', 'ne': '!=', 'mod': '%'} 'gt': '>', 'eq': '==', 'ge': '>=', 'ne': '!=', 'mod': '%'}
entry_func_prefix = '' entry_func_prefix = ''
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
@ -78,6 +78,15 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
""" """
@norm_indent
def get_sqrt(type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void sqrt_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}(fast_sqrt((float)arg1), arg2);
}}
"""
@norm_indent @norm_indent
def get_conv_code(type1: str, type2: str, type_out: str) -> str: def get_conv_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
@ -189,7 +198,7 @@ if __name__ == "__main__":
# Scalar arithmetic: # Scalar arithmetic:
types = ['int', 'float'] types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq', 'ne', 'pow'] ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow']
for t1 in types: for t1 in types:
code += get_result_stubs1(t1) code += get_result_stubs1(t1)
@ -203,6 +212,9 @@ if __name__ == "__main__":
t_out = 'int' if t1 == 'float' else 'float' t_out = 'int' if t1 == 'float' else 'float'
code += get_cast(t1, t2, t_out) code += get_cast(t1, t2, t_out)
for t1, t2 in permutate(types, types):
code += get_sqrt(t1, t2)
for op, t1, t2 in permutate(ops, types, types): for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float' t_out = t1 if t1 == t2 else 'float'
if op == 'floordiv': if op == 'floordiv':
@ -211,7 +223,7 @@ if __name__ == "__main__":
code += get_op_code_float(op, t1, t2) code += get_op_code_float(op, t1, t2)
elif op == 'pow': elif op == 'pow':
code += get_pow(t1, t2) code += get_pow(t1, t2)
elif op == 'gt' or op == 'eq' or op == 'ne': elif op in {'gt', 'eq', 'ge', 'ne'}:
code += get_op_code(op, t1, t2, 'int') code += get_op_code(op, t1, t2, 'int')
else: else:
code += get_op_code(op, t1, t2, t_out) code += get_op_code(op, t1, t2, t_out)

View File

@ -1,31 +0,0 @@
from copapy import variable, Target
import pytest
import copapy
def test_compile():
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_compile()

56
tests/test_math.py Normal file
View File

@ -0,0 +1,56 @@
from copapy import variable, Target
import pytest
import copapy
def test_corse():
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.variable)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
def test_fine():
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** 2, c_i ** -1)#, c_i & 3)
ret_ref = (2.5 ** 2, 9 ** -1)#, 9 & 3)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.variable)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_corse()
test_fine()

View File

@ -57,7 +57,7 @@ def test_compile():
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber) assert isinstance(test, copapy.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, test.dtype) print('+', val, ref, test.dtype)
for t in (int, float, bool): for t in (int, float, bool):

View File

@ -19,7 +19,7 @@ def test_compile():
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber) assert isinstance(test, copapy.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):

View File

@ -1,5 +1,29 @@
import copapy as cp import copapy as cp
def test_vec(): def test_vectors_init():
tt = cp.vector(range(3)) + cp.vector([1.1,2.2,3.3]) tt1 = cp.vector(range(3)) + cp.vector([1.1,2.2,3.3])
tt2 = (cp.vector(range(3)) + 5.6) tt2 = cp.vector([1.1,2,cp.variable(5)])# + cp.vector(range(3))
tt3 = (cp.vector(range(3)) + 5.6)
tt4 = cp.vector([1.1,2,3]) + cp.vector(cp.variable(v) for v in range(3))
tt5 = cp.vector([1,2,3]).dot(tt4)
print(tt1, tt2, tt3, tt4, tt5)
def test_compiled_vectors():
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
t4 = ((t3 * t1) * 2).magnitude()
tg = cp.Target()
tg.compile(t2, t4)
tg.run()
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
if __name__ == "__main__":
test_compiled_vectors()