refactoring API generics

This commit is contained in:
Nicolas Kruse 2025-10-24 00:36:22 +02:00
parent a8280f8d2d
commit cb1447f125
3 changed files with 39 additions and 41 deletions

View File

@ -4,15 +4,12 @@ from ._stencils import stencil_database
import platform import platform
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
NetAndNum: TypeAlias = 'Net | int | float'
unifloat: TypeAlias = 'variable[float] | float' unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int' uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool' unibool: TypeAlias = 'variable[bool] | bool'
TNumber = TypeVar("TNumber", bound='CPNumber') TCPNum = TypeVar("TCPNum", bound='CPNumber')
T = TypeVar("T") TNum = TypeVar("TNum", int, bool, float)
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
@ -66,7 +63,7 @@ class CPNumber(Net):
self.source = source self.source = source
@overload @overload
def __mul__(self: TNumber, other: uniint) -> TNumber: def __mul__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -77,7 +74,7 @@ class CPNumber(Net):
return _add_op('mul', [self, other], True) return _add_op('mul', [self, other], True)
@overload @overload
def __rmul__(self: TNumber, other: uniint) -> TNumber: def __rmul__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -88,7 +85,7 @@ class CPNumber(Net):
return _add_op('mul', [self, other], True) return _add_op('mul', [self, other], True)
@overload @overload
def __add__(self: TNumber, other: uniint) -> TNumber: def __add__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -99,7 +96,7 @@ class CPNumber(Net):
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
@overload @overload
def __radd__(self: TNumber, other: uniint) -> TNumber: def __radd__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -110,7 +107,7 @@ class CPNumber(Net):
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
@overload @overload
def __sub__(self: TNumber, other: uniint) -> TNumber: def __sub__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -121,7 +118,7 @@ class CPNumber(Net):
return _add_op('sub', [self, other]) return _add_op('sub', [self, other])
@overload @overload
def __rsub__(self: TNumber, other: uniint) -> TNumber: def __rsub__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -138,7 +135,7 @@ class CPNumber(Net):
return _add_op('div', [other, self]) return _add_op('div', [other, self])
@overload @overload
def __floordiv__(self: TNumber, other: uniint) -> TNumber: def __floordiv__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -149,7 +146,7 @@ class CPNumber(Net):
return _add_op('floordiv', [self, other]) return _add_op('floordiv', [self, other])
@overload @overload
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber: def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -159,9 +156,8 @@ class CPNumber(Net):
def __rfloordiv__(self, other: NumLike) -> 'CPNumber': def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self]) return _add_op('floordiv', [other, self])
def __neg__(self: TNumber) -> TNumber: def __neg__(self: TCPNum) -> TCPNum:
assert isinstance(T, variable) return cast(TCPNum, _add_op('sub', [variable(0), self]))
return cast(TNumber, _add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]': def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [self, other]) ret = _add_op('gt', [self, other])
@ -180,7 +176,7 @@ class CPNumber(Net):
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
@overload @overload
def __mod__(self: TNumber, other: uniint) -> TNumber: def __mod__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -191,7 +187,7 @@ class CPNumber(Net):
return _add_op('mod', [self, other]) return _add_op('mod', [self, other])
@overload @overload
def __rmod__(self: TNumber, other: uniint) -> TNumber: def __rmod__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -202,7 +198,7 @@ class CPNumber(Net):
return _add_op('mod', [other, self]) return _add_op('mod', [other, self])
@overload @overload
def __pow__(self: TNumber, other: uniint) -> TNumber: def __pow__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -213,7 +209,7 @@ class CPNumber(Net):
return _add_op('pow', [other, self]) return _add_op('pow', [other, self])
@overload @overload
def __rpow__(self: TNumber, other: uniint) -> TNumber: def __rpow__(self: TCPNum, other: uniint) -> TCPNum:
... ...
@overload @overload
@ -227,23 +223,21 @@ class CPNumber(Net):
return super().__hash__() return super().__hash__()
class variable(Generic[T], CPNumber): class variable(Generic[TNum], CPNumber):
def __init__(self, source: T | Node, dtype: str | None = None): def __init__(self, source: TNum | Node, dtype: str | None = None):
if isinstance(source, Node): if isinstance(source, Node):
self.source = source self.source = source
assert dtype, 'For source type Node a dtype argument is required.' assert dtype, 'For source type Node a dtype argument is required.'
self.dtype = dtype self.dtype = dtype
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
elif isinstance(source, int):
self.source = CPConstant(source)
self.dtype = 'int'
elif isinstance(source, float): elif isinstance(source, float):
self.source = CPConstant(source) self.source = CPConstant(source)
self.dtype = 'float' self.dtype = 'float'
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
else: else:
raise ValueError(f'Non supported data type: {type(source).__name__}') self.source = CPConstant(source)
self.dtype = 'int'
# Bitwise and shift operations for cp[int] # Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]': def __lshift__(self, other: uniint) -> 'variable[int]':
@ -285,7 +279,7 @@ class CPConstant(Node):
class Write(Node): class Write(Node):
def __init__(self, input: NetAndNum): def __init__(self, input: Net | int | float):
if isinstance(input, Net): if isinstance(input, Net):
net = input net = input
else: else:
@ -324,17 +318,24 @@ def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) ->
@overload @overload
def iif(expression: NumLike, true_result: T, false_result: T) -> T: def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum:
...
@overload
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]:
...
@overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]:
... ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any: def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool) allowed_type = (variable, int, float, bool)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
if isinstance(expression, CPNumber): return (expression != 0) * true_result + (expression == 0) * false_result
return (expression != 0) * true_result + (expression == 0) * false_result
else:
return true_result if expression else false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]: def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:

View File

@ -1,9 +1,6 @@
from numpy import isin from copapy import CPNumber, variable
from copapy import NumLike, CPNumber, variable
from typing import Generic, TypeVar, Iterable, Any, overload from typing import Generic, TypeVar, Iterable, Any, overload
from copapy._basic_types import TNum
T = TypeVar("T", int, float, bool) T = TypeVar("T", int, float, bool)
T2 = TypeVar("T2", bound=CPNumber) T2 = TypeVar("T2", bound=CPNumber)

View File

@ -8,8 +8,8 @@ def test_compile():
c_f = variable(2.5) c_f = variable(2.5)
# c_b = variable(True) # c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i) ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
ret_ref = (2.5 ** 2.5, 9 ** 9) ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')