diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 3af237d..fcfb338 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -4,15 +4,12 @@ from ._stencils import stencil_database import platform NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' -NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net' -NetAndNum: TypeAlias = 'Net | int | float' - unifloat: TypeAlias = 'variable[float] | float' uniint: TypeAlias = 'variable[int] | int' unibool: TypeAlias = 'variable[bool] | bool' -TNumber = TypeVar("TNumber", bound='CPNumber') -T = TypeVar("T") +TCPNum = TypeVar("TCPNum", bound='CPNumber') +TNum = TypeVar("TNum", int, bool, float) def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: @@ -66,7 +63,7 @@ class CPNumber(Net): self.source = source @overload - def __mul__(self: TNumber, other: uniint) -> TNumber: + def __mul__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -77,7 +74,7 @@ class CPNumber(Net): return _add_op('mul', [self, other], True) @overload - def __rmul__(self: TNumber, other: uniint) -> TNumber: + def __rmul__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -88,7 +85,7 @@ class CPNumber(Net): return _add_op('mul', [self, other], True) @overload - def __add__(self: TNumber, other: uniint) -> TNumber: + def __add__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -99,7 +96,7 @@ class CPNumber(Net): return _add_op('add', [self, other], True) @overload - def __radd__(self: TNumber, other: uniint) -> TNumber: + def __radd__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -110,7 +107,7 @@ class CPNumber(Net): return _add_op('add', [self, other], True) @overload - def __sub__(self: TNumber, other: uniint) -> TNumber: + def __sub__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -121,7 +118,7 @@ class CPNumber(Net): return _add_op('sub', [self, other]) @overload - def __rsub__(self: TNumber, other: uniint) -> TNumber: + def __rsub__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -138,7 +135,7 @@ class CPNumber(Net): return _add_op('div', [other, self]) @overload - def __floordiv__(self: TNumber, other: uniint) -> TNumber: + def __floordiv__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -149,7 +146,7 @@ class CPNumber(Net): return _add_op('floordiv', [self, other]) @overload - def __rfloordiv__(self: TNumber, other: uniint) -> TNumber: + def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -159,9 +156,8 @@ class CPNumber(Net): def __rfloordiv__(self, other: NumLike) -> 'CPNumber': return _add_op('floordiv', [other, self]) - def __neg__(self: TNumber) -> TNumber: - assert isinstance(T, variable) - return cast(TNumber, _add_op('sub', [variable(0), self])) + def __neg__(self: TCPNum) -> TCPNum: + return cast(TCPNum, _add_op('sub', [variable(0), self])) def __gt__(self, other: NumLike) -> 'variable[bool]': ret = _add_op('gt', [self, other]) @@ -180,7 +176,7 @@ class CPNumber(Net): return variable(ret.source, dtype='bool') @overload - def __mod__(self: TNumber, other: uniint) -> TNumber: + def __mod__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -191,7 +187,7 @@ class CPNumber(Net): return _add_op('mod', [self, other]) @overload - def __rmod__(self: TNumber, other: uniint) -> TNumber: + def __rmod__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -202,7 +198,7 @@ class CPNumber(Net): return _add_op('mod', [other, self]) @overload - def __pow__(self: TNumber, other: uniint) -> TNumber: + def __pow__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -213,7 +209,7 @@ class CPNumber(Net): return _add_op('pow', [other, self]) @overload - def __rpow__(self: TNumber, other: uniint) -> TNumber: + def __rpow__(self: TCPNum, other: uniint) -> TCPNum: ... @overload @@ -227,23 +223,21 @@ class CPNumber(Net): return super().__hash__() -class variable(Generic[T], CPNumber): - def __init__(self, source: T | Node, dtype: str | None = None): +class variable(Generic[TNum], CPNumber): + def __init__(self, source: TNum | Node, dtype: str | None = None): if isinstance(source, Node): self.source = source assert dtype, 'For source type Node a dtype argument is required.' self.dtype = dtype - elif isinstance(source, bool): - self.source = CPConstant(source) - self.dtype = 'bool' - elif isinstance(source, int): - self.source = CPConstant(source) - self.dtype = 'int' elif isinstance(source, float): self.source = CPConstant(source) self.dtype = 'float' + elif isinstance(source, bool): + self.source = CPConstant(source) + self.dtype = 'bool' else: - raise ValueError(f'Non supported data type: {type(source).__name__}') + self.source = CPConstant(source) + self.dtype = 'int' # Bitwise and shift operations for cp[int] def __lshift__(self, other: uniint) -> 'variable[int]': @@ -285,7 +279,7 @@ class CPConstant(Node): class Write(Node): - def __init__(self, input: NetAndNum): + def __init__(self, input: Net | int | float): if isinstance(input, Net): net = input else: @@ -324,17 +318,24 @@ def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> @overload -def iif(expression: NumLike, true_result: T, false_result: T) -> T: +def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum: + ... + + +@overload +def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]: + ... + + +@overload +def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ... def iif(expression: Any, true_result: Any, false_result: Any) -> Any: allowed_type = (variable, int, float, bool) assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" - if isinstance(expression, CPNumber): - return (expression != 0) * true_result + (expression == 0) * false_result - else: - return true_result if expression else false_result + return (expression != 0) * true_result + (expression == 0) * false_result def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]: diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index f6ef683..c129f20 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -1,9 +1,6 @@ -from numpy import isin -from copapy import NumLike, CPNumber, variable +from copapy import CPNumber, variable from typing import Generic, TypeVar, Iterable, Any, overload -from copapy._basic_types import TNum - T = TypeVar("T", int, float, bool) T2 = TypeVar("T2", bound=CPNumber) diff --git a/tests/test_ext_ops.py b/tests/test_ext_ops.py index e8086b0..d3a2955 100644 --- a/tests/test_ext_ops.py +++ b/tests/test_ext_ops.py @@ -8,8 +8,8 @@ def test_compile(): c_f = variable(2.5) # c_b = variable(True) - ret_test = (c_f ** c_f, c_i ** c_i) - ret_ref = (2.5 ** 2.5, 9 ** 9) + ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3) + ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3) tg = Target() print('* compile and copy ...')