mirror of https://github.com/Nonannet/copapy.git
iif function added with test
This commit is contained in:
parent
5d9c9511f5
commit
067e4f32eb
|
|
@ -18,6 +18,7 @@ uniint: TypeAlias = 'cpint | int'
|
||||||
unibool: TypeAlias = 'cpbool | bool'
|
unibool: TypeAlias = 'cpbool | bool'
|
||||||
|
|
||||||
TNumber = TypeVar("TNumber", bound='CPNumber')
|
TNumber = TypeVar("TNumber", bound='CPNumber')
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||||
|
|
@ -355,6 +356,36 @@ def net_from_value(value: Any) -> Net:
|
||||||
return Net(vi.dtype, vi)
|
return Net(vi.dtype, vi)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload]
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def iif(expression: NumLike, true_result: T, false_result: T) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
|
||||||
|
# TODO: check that input types are matching
|
||||||
|
alowed_type = cpint | cpfloat | cpbool | int | float | bool
|
||||||
|
assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported"
|
||||||
|
if isinstance(expression, CPNumber):
|
||||||
|
return (expression != 0) * true_result + (expression == 0) * false_result
|
||||||
|
else:
|
||||||
|
return true_result if expression else false_result
|
||||||
|
|
||||||
|
|
||||||
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
|
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
|
||||||
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from copapy import cpvalue, Target, NumLike, Net, cpint
|
from copapy import cpvalue, Target, NumLike, Net, iif, cpint
|
||||||
from pytest import approx
|
from pytest import approx
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -32,14 +32,21 @@ def function6(c1: NumLike) -> list[NumLike]:
|
||||||
return [c1 == True]
|
return [c1 == True]
|
||||||
|
|
||||||
|
|
||||||
def test_compile():
|
def iiftests(c1: NumLike) -> list[NumLike]:
|
||||||
|
return [iif(c1 > 5, 8, 9),
|
||||||
|
iif(c1 < 5, 8.5, 9.5),
|
||||||
|
iif(1 > 5, 3.3, 8.8) + c1,
|
||||||
|
iif(1 < 5, c1 * 3.3, 8.8),
|
||||||
|
iif(c1 < 5, c1 * 3.3, 8.8)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile():
|
||||||
c_i = cpvalue(9)
|
c_i = cpvalue(9)
|
||||||
c_f = cpvalue(1.111)
|
c_f = cpvalue(1.111)
|
||||||
c_b = cpvalue(True)
|
c_b = cpvalue(True)
|
||||||
|
|
||||||
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2]
|
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2] + iiftests(c_i) + iiftests(c_f)
|
||||||
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2]
|
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue