iif function added with test

This commit is contained in:
Nicolas Kruse 2025-10-19 22:48:52 +02:00
parent 5d9c9511f5
commit 067e4f32eb
2 changed files with 42 additions and 4 deletions

View File

@ -18,6 +18,7 @@ uniint: TypeAlias = 'cpint | int'
unibool: TypeAlias = 'cpbool | bool'
TNumber = TypeVar("TNumber", bound='CPNumber')
T = TypeVar("T")
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
@ -355,6 +356,36 @@ def net_from_value(value: Any) -> Net:
return Net(vi.dtype, vi)
@overload
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload]
...
@overload
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint:
...
@overload
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat:
...
@overload
def iif(expression: NumLike, true_result: T, false_result: T) -> T:
...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
# TODO: check that input types are matching
alowed_type = cpint | cpfloat | cpbool | int | float | bool
assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported"
if isinstance(expression, CPNumber):
return (expression != 0) * true_result + (expression == 0) * false_result
else:
return true_result if expression else false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]

View File

@ -1,4 +1,4 @@
from copapy import cpvalue, Target, NumLike, Net, cpint
from copapy import cpvalue, Target, NumLike, Net, iif, cpint
from pytest import approx
@ -32,14 +32,21 @@ def function6(c1: NumLike) -> list[NumLike]:
return [c1 == True]
def test_compile():
def iiftests(c1: NumLike) -> list[NumLike]:
return [iif(c1 > 5, 8, 9),
iif(c1 < 5, 8.5, 9.5),
iif(1 > 5, 3.3, 8.8) + c1,
iif(1 < 5, c1 * 3.3, 8.8),
iif(c1 < 5, c1 * 3.3, 8.8)]
def test_compile():
c_i = cpvalue(9)
c_f = cpvalue(1.111)
c_b = cpvalue(True)
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2]
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2]
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2] + iiftests(c_i) + iiftests(c_f)
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111)
tg = Target()
print('* compile and copy ...')