mirror of https://github.com/Nonannet/copapy.git
refactoring API generics
This commit is contained in:
parent
a8280f8d2d
commit
cb1447f125
|
|
@ -4,15 +4,12 @@ from ._stencils import stencil_database
|
|||
import platform
|
||||
|
||||
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
|
||||
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
|
||||
NetAndNum: TypeAlias = 'Net | int | float'
|
||||
|
||||
unifloat: TypeAlias = 'variable[float] | float'
|
||||
uniint: TypeAlias = 'variable[int] | int'
|
||||
unibool: TypeAlias = 'variable[bool] | bool'
|
||||
|
||||
TNumber = TypeVar("TNumber", bound='CPNumber')
|
||||
T = TypeVar("T")
|
||||
TCPNum = TypeVar("TCPNum", bound='CPNumber')
|
||||
TNum = TypeVar("TNum", int, bool, float)
|
||||
|
||||
|
||||
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||
|
|
@ -66,7 +63,7 @@ class CPNumber(Net):
|
|||
self.source = source
|
||||
|
||||
@overload
|
||||
def __mul__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __mul__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -77,7 +74,7 @@ class CPNumber(Net):
|
|||
return _add_op('mul', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __rmul__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __rmul__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -88,7 +85,7 @@ class CPNumber(Net):
|
|||
return _add_op('mul', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __add__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __add__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -99,7 +96,7 @@ class CPNumber(Net):
|
|||
return _add_op('add', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __radd__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __radd__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -110,7 +107,7 @@ class CPNumber(Net):
|
|||
return _add_op('add', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __sub__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __sub__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -121,7 +118,7 @@ class CPNumber(Net):
|
|||
return _add_op('sub', [self, other])
|
||||
|
||||
@overload
|
||||
def __rsub__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __rsub__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -138,7 +135,7 @@ class CPNumber(Net):
|
|||
return _add_op('div', [other, self])
|
||||
|
||||
@overload
|
||||
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -149,7 +146,7 @@ class CPNumber(Net):
|
|||
return _add_op('floordiv', [self, other])
|
||||
|
||||
@overload
|
||||
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -159,9 +156,8 @@ class CPNumber(Net):
|
|||
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('floordiv', [other, self])
|
||||
|
||||
def __neg__(self: TNumber) -> TNumber:
|
||||
assert isinstance(T, variable)
|
||||
return cast(TNumber, _add_op('sub', [variable(0), self]))
|
||||
def __neg__(self: TCPNum) -> TCPNum:
|
||||
return cast(TCPNum, _add_op('sub', [variable(0), self]))
|
||||
|
||||
def __gt__(self, other: NumLike) -> 'variable[bool]':
|
||||
ret = _add_op('gt', [self, other])
|
||||
|
|
@ -180,7 +176,7 @@ class CPNumber(Net):
|
|||
return variable(ret.source, dtype='bool')
|
||||
|
||||
@overload
|
||||
def __mod__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __mod__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -191,7 +187,7 @@ class CPNumber(Net):
|
|||
return _add_op('mod', [self, other])
|
||||
|
||||
@overload
|
||||
def __rmod__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __rmod__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -202,7 +198,7 @@ class CPNumber(Net):
|
|||
return _add_op('mod', [other, self])
|
||||
|
||||
@overload
|
||||
def __pow__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __pow__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -213,7 +209,7 @@ class CPNumber(Net):
|
|||
return _add_op('pow', [other, self])
|
||||
|
||||
@overload
|
||||
def __rpow__(self: TNumber, other: uniint) -> TNumber:
|
||||
def __rpow__(self: TCPNum, other: uniint) -> TCPNum:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
|
@ -227,23 +223,21 @@ class CPNumber(Net):
|
|||
return super().__hash__()
|
||||
|
||||
|
||||
class variable(Generic[T], CPNumber):
|
||||
def __init__(self, source: T | Node, dtype: str | None = None):
|
||||
class variable(Generic[TNum], CPNumber):
|
||||
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
||||
if isinstance(source, Node):
|
||||
self.source = source
|
||||
assert dtype, 'For source type Node a dtype argument is required.'
|
||||
self.dtype = dtype
|
||||
elif isinstance(source, bool):
|
||||
self.source = CPConstant(source)
|
||||
self.dtype = 'bool'
|
||||
elif isinstance(source, int):
|
||||
self.source = CPConstant(source)
|
||||
self.dtype = 'int'
|
||||
elif isinstance(source, float):
|
||||
self.source = CPConstant(source)
|
||||
self.dtype = 'float'
|
||||
elif isinstance(source, bool):
|
||||
self.source = CPConstant(source)
|
||||
self.dtype = 'bool'
|
||||
else:
|
||||
raise ValueError(f'Non supported data type: {type(source).__name__}')
|
||||
self.source = CPConstant(source)
|
||||
self.dtype = 'int'
|
||||
|
||||
# Bitwise and shift operations for cp[int]
|
||||
def __lshift__(self, other: uniint) -> 'variable[int]':
|
||||
|
|
@ -285,7 +279,7 @@ class CPConstant(Node):
|
|||
|
||||
|
||||
class Write(Node):
|
||||
def __init__(self, input: NetAndNum):
|
||||
def __init__(self, input: Net | int | float):
|
||||
if isinstance(input, Net):
|
||||
net = input
|
||||
else:
|
||||
|
|
@ -324,17 +318,24 @@ def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) ->
|
|||
|
||||
|
||||
@overload
|
||||
def iif(expression: NumLike, true_result: T, false_result: T) -> T:
|
||||
def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]:
|
||||
...
|
||||
|
||||
|
||||
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
|
||||
allowed_type = (variable, int, float, bool)
|
||||
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
|
||||
if isinstance(expression, CPNumber):
|
||||
return (expression != 0) * true_result + (expression == 0) * false_result
|
||||
else:
|
||||
return true_result if expression else false_result
|
||||
return (expression != 0) * true_result + (expression == 0) * false_result
|
||||
|
||||
|
||||
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
from numpy import isin
|
||||
from copapy import NumLike, CPNumber, variable
|
||||
from copapy import CPNumber, variable
|
||||
from typing import Generic, TypeVar, Iterable, Any, overload
|
||||
|
||||
from copapy._basic_types import TNum
|
||||
|
||||
T = TypeVar("T", int, float, bool)
|
||||
T2 = TypeVar("T2", bound=CPNumber)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ def test_compile():
|
|||
c_f = variable(2.5)
|
||||
# c_b = variable(True)
|
||||
|
||||
ret_test = (c_f ** c_f, c_i ** c_i)
|
||||
ret_ref = (2.5 ** 2.5, 9 ** 9)
|
||||
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
|
||||
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
|
||||
|
||||
tg = Target()
|
||||
print('* compile and copy ...')
|
||||
|
|
|
|||
Loading…
Reference in New Issue