mirror of https://github.com/Nonannet/copapy.git
refactoring API generics
This commit is contained in:
parent
a8280f8d2d
commit
cb1447f125
|
|
@ -4,15 +4,12 @@ from ._stencils import stencil_database
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
|
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
|
||||||
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
|
|
||||||
NetAndNum: TypeAlias = 'Net | int | float'
|
|
||||||
|
|
||||||
unifloat: TypeAlias = 'variable[float] | float'
|
unifloat: TypeAlias = 'variable[float] | float'
|
||||||
uniint: TypeAlias = 'variable[int] | int'
|
uniint: TypeAlias = 'variable[int] | int'
|
||||||
unibool: TypeAlias = 'variable[bool] | bool'
|
unibool: TypeAlias = 'variable[bool] | bool'
|
||||||
|
|
||||||
TNumber = TypeVar("TNumber", bound='CPNumber')
|
TCPNum = TypeVar("TCPNum", bound='CPNumber')
|
||||||
T = TypeVar("T")
|
TNum = TypeVar("TNum", int, bool, float)
|
||||||
|
|
||||||
|
|
||||||
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||||
|
|
@ -66,7 +63,7 @@ class CPNumber(Net):
|
||||||
self.source = source
|
self.source = source
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __mul__(self: TNumber, other: uniint) -> TNumber:
|
def __mul__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -77,7 +74,7 @@ class CPNumber(Net):
|
||||||
return _add_op('mul', [self, other], True)
|
return _add_op('mul', [self, other], True)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rmul__(self: TNumber, other: uniint) -> TNumber:
|
def __rmul__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -88,7 +85,7 @@ class CPNumber(Net):
|
||||||
return _add_op('mul', [self, other], True)
|
return _add_op('mul', [self, other], True)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __add__(self: TNumber, other: uniint) -> TNumber:
|
def __add__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -99,7 +96,7 @@ class CPNumber(Net):
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __radd__(self: TNumber, other: uniint) -> TNumber:
|
def __radd__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -110,7 +107,7 @@ class CPNumber(Net):
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __sub__(self: TNumber, other: uniint) -> TNumber:
|
def __sub__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -121,7 +118,7 @@ class CPNumber(Net):
|
||||||
return _add_op('sub', [self, other])
|
return _add_op('sub', [self, other])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rsub__(self: TNumber, other: uniint) -> TNumber:
|
def __rsub__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -138,7 +135,7 @@ class CPNumber(Net):
|
||||||
return _add_op('div', [other, self])
|
return _add_op('div', [other, self])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
|
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -149,7 +146,7 @@ class CPNumber(Net):
|
||||||
return _add_op('floordiv', [self, other])
|
return _add_op('floordiv', [self, other])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
|
def __rfloordiv__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -159,9 +156,8 @@ class CPNumber(Net):
|
||||||
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
|
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
|
||||||
return _add_op('floordiv', [other, self])
|
return _add_op('floordiv', [other, self])
|
||||||
|
|
||||||
def __neg__(self: TNumber) -> TNumber:
|
def __neg__(self: TCPNum) -> TCPNum:
|
||||||
assert isinstance(T, variable)
|
return cast(TCPNum, _add_op('sub', [variable(0), self]))
|
||||||
return cast(TNumber, _add_op('sub', [variable(0), self]))
|
|
||||||
|
|
||||||
def __gt__(self, other: NumLike) -> 'variable[bool]':
|
def __gt__(self, other: NumLike) -> 'variable[bool]':
|
||||||
ret = _add_op('gt', [self, other])
|
ret = _add_op('gt', [self, other])
|
||||||
|
|
@ -180,7 +176,7 @@ class CPNumber(Net):
|
||||||
return variable(ret.source, dtype='bool')
|
return variable(ret.source, dtype='bool')
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __mod__(self: TNumber, other: uniint) -> TNumber:
|
def __mod__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -191,7 +187,7 @@ class CPNumber(Net):
|
||||||
return _add_op('mod', [self, other])
|
return _add_op('mod', [self, other])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rmod__(self: TNumber, other: uniint) -> TNumber:
|
def __rmod__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -202,7 +198,7 @@ class CPNumber(Net):
|
||||||
return _add_op('mod', [other, self])
|
return _add_op('mod', [other, self])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __pow__(self: TNumber, other: uniint) -> TNumber:
|
def __pow__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -213,7 +209,7 @@ class CPNumber(Net):
|
||||||
return _add_op('pow', [other, self])
|
return _add_op('pow', [other, self])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __rpow__(self: TNumber, other: uniint) -> TNumber:
|
def __rpow__(self: TCPNum, other: uniint) -> TCPNum:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -227,23 +223,21 @@ class CPNumber(Net):
|
||||||
return super().__hash__()
|
return super().__hash__()
|
||||||
|
|
||||||
|
|
||||||
class variable(Generic[T], CPNumber):
|
class variable(Generic[TNum], CPNumber):
|
||||||
def __init__(self, source: T | Node, dtype: str | None = None):
|
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
||||||
if isinstance(source, Node):
|
if isinstance(source, Node):
|
||||||
self.source = source
|
self.source = source
|
||||||
assert dtype, 'For source type Node a dtype argument is required.'
|
assert dtype, 'For source type Node a dtype argument is required.'
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
elif isinstance(source, bool):
|
|
||||||
self.source = CPConstant(source)
|
|
||||||
self.dtype = 'bool'
|
|
||||||
elif isinstance(source, int):
|
|
||||||
self.source = CPConstant(source)
|
|
||||||
self.dtype = 'int'
|
|
||||||
elif isinstance(source, float):
|
elif isinstance(source, float):
|
||||||
self.source = CPConstant(source)
|
self.source = CPConstant(source)
|
||||||
self.dtype = 'float'
|
self.dtype = 'float'
|
||||||
|
elif isinstance(source, bool):
|
||||||
|
self.source = CPConstant(source)
|
||||||
|
self.dtype = 'bool'
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Non supported data type: {type(source).__name__}')
|
self.source = CPConstant(source)
|
||||||
|
self.dtype = 'int'
|
||||||
|
|
||||||
# Bitwise and shift operations for cp[int]
|
# Bitwise and shift operations for cp[int]
|
||||||
def __lshift__(self, other: uniint) -> 'variable[int]':
|
def __lshift__(self, other: uniint) -> 'variable[int]':
|
||||||
|
|
@ -285,7 +279,7 @@ class CPConstant(Node):
|
||||||
|
|
||||||
|
|
||||||
class Write(Node):
|
class Write(Node):
|
||||||
def __init__(self, input: NetAndNum):
|
def __init__(self, input: Net | int | float):
|
||||||
if isinstance(input, Net):
|
if isinstance(input, Net):
|
||||||
net = input
|
net = input
|
||||||
else:
|
else:
|
||||||
|
|
@ -324,17 +318,24 @@ def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) ->
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def iif(expression: NumLike, true_result: T, false_result: T) -> T:
|
def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
|
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
|
||||||
allowed_type = (variable, int, float, bool)
|
allowed_type = (variable, int, float, bool)
|
||||||
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
|
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
|
||||||
if isinstance(expression, CPNumber):
|
|
||||||
return (expression != 0) * true_result + (expression == 0) * false_result
|
return (expression != 0) * true_result + (expression == 0) * false_result
|
||||||
else:
|
|
||||||
return true_result if expression else false_result
|
|
||||||
|
|
||||||
|
|
||||||
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:
|
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,6 @@
|
||||||
from numpy import isin
|
from copapy import CPNumber, variable
|
||||||
from copapy import NumLike, CPNumber, variable
|
|
||||||
from typing import Generic, TypeVar, Iterable, Any, overload
|
from typing import Generic, TypeVar, Iterable, Any, overload
|
||||||
|
|
||||||
from copapy._basic_types import TNum
|
|
||||||
|
|
||||||
T = TypeVar("T", int, float, bool)
|
T = TypeVar("T", int, float, bool)
|
||||||
T2 = TypeVar("T2", bound=CPNumber)
|
T2 = TypeVar("T2", bound=CPNumber)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ def test_compile():
|
||||||
c_f = variable(2.5)
|
c_f = variable(2.5)
|
||||||
# c_b = variable(True)
|
# c_b = variable(True)
|
||||||
|
|
||||||
ret_test = (c_f ** c_f, c_i ** c_i)
|
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
|
||||||
ret_ref = (2.5 ** 2.5, 9 ** 9)
|
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue