refactoring API generics

This commit is contained in:
Nicolas Kruse 2025-10-24 00:36:22 +02:00
parent a8280f8d2d
commit cb1447f125
3 changed files with 39 additions and 41 deletions

View File

@ -4,15 +4,12 @@ from ._stencils import stencil_database
import platform
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
NetAndNum: TypeAlias = 'Net | int | float'
unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool'
TNumber = TypeVar("TNumber", bound='CPNumber')
T = TypeVar("T")
TCPNum = TypeVar("TCPNum", bound='CPNumber')
TNum = TypeVar("TNum", int, bool, float)
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
@ -66,7 +63,7 @@ class CPNumber(Net):
self.source = source
@overload
def __mul__(self: TNumber, other: uniint) -> TNumber:
def __mul__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -77,7 +74,7 @@ class CPNumber(Net):
return _add_op('mul', [self, other], True)
@overload
def __rmul__(self: TNumber, other: uniint) -> TNumber:
def __rmul__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -88,7 +85,7 @@ class CPNumber(Net):
return _add_op('mul', [self, other], True)
@overload
def __add__(self: TNumber, other: uniint) -> TNumber:
def __add__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -99,7 +96,7 @@ class CPNumber(Net):
return _add_op('add', [self, other], True)
@overload
def __radd__(self: TNumber, other: uniint) -> TNumber:
def __radd__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -110,7 +107,7 @@ class CPNumber(Net):
return _add_op('add', [self, other], True)
@overload
def __sub__(self: TNumber, other: uniint) -> TNumber:
def __sub__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -121,7 +118,7 @@ class CPNumber(Net):
return _add_op('sub', [self, other])
@overload
def __rsub__(self: TNumber, other: uniint) -> TNumber:
def __rsub__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -138,7 +135,7 @@ class CPNumber(Net):
return _add_op('div', [other, self])
@overload
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -149,7 +146,7 @@ class CPNumber(Net):
return _add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -159,9 +156,8 @@ class CPNumber(Net):
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self])
def __neg__(self: TNumber) -> TNumber:
assert isinstance(T, variable)
return cast(TNumber, _add_op('sub', [variable(0), self]))
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, _add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [self, other])
@ -180,7 +176,7 @@ class CPNumber(Net):
return variable(ret.source, dtype='bool')
@overload
def __mod__(self: TNumber, other: uniint) -> TNumber:
def __mod__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -191,7 +187,7 @@ class CPNumber(Net):
return _add_op('mod', [self, other])
@overload
def __rmod__(self: TNumber, other: uniint) -> TNumber:
def __rmod__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -202,7 +198,7 @@ class CPNumber(Net):
return _add_op('mod', [other, self])
@overload
def __pow__(self: TNumber, other: uniint) -> TNumber:
def __pow__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -213,7 +209,7 @@ class CPNumber(Net):
return _add_op('pow', [other, self])
@overload
def __rpow__(self: TNumber, other: uniint) -> TNumber:
def __rpow__(self: TCPNum, other: uniint) -> TCPNum:
...
@overload
@ -227,23 +223,21 @@ class CPNumber(Net):
return super().__hash__()
class variable(Generic[T], CPNumber):
def __init__(self, source: T | Node, dtype: str | None = None):
class variable(Generic[TNum], CPNumber):
def __init__(self, source: TNum | Node, dtype: str | None = None):
if isinstance(source, Node):
self.source = source
assert dtype, 'For source type Node a dtype argument is required.'
self.dtype = dtype
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
elif isinstance(source, int):
self.source = CPConstant(source)
self.dtype = 'int'
elif isinstance(source, float):
self.source = CPConstant(source)
self.dtype = 'float'
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
else:
raise ValueError(f'Non supported data type: {type(source).__name__}')
self.source = CPConstant(source)
self.dtype = 'int'
# Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]':
@ -285,7 +279,7 @@ class CPConstant(Node):
class Write(Node):
def __init__(self, input: NetAndNum):
def __init__(self, input: Net | int | float):
if isinstance(input, Net):
net = input
else:
@ -324,17 +318,24 @@ def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) ->
@overload
def iif(expression: NumLike, true_result: T, false_result: T) -> T:
def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum:
...
@overload
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]:
...
@overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]:
...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
if isinstance(expression, CPNumber):
return (expression != 0) * true_result + (expression == 0) * false_result
else:
return true_result if expression else false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:

View File

@ -1,9 +1,6 @@
from numpy import isin
from copapy import NumLike, CPNumber, variable
from copapy import CPNumber, variable
from typing import Generic, TypeVar, Iterable, Any, overload
from copapy._basic_types import TNum
T = TypeVar("T", int, float, bool)
T2 = TypeVar("T2", bound=CPNumber)

View File

@ -8,8 +8,8 @@ def test_compile():
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i)
ret_ref = (2.5 ** 2.5, 9 ** 9)
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
tg = Target()
print('* compile and copy ...')