diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index acf7155..39a4b4e 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -18,6 +18,7 @@ uniint: TypeAlias = 'cpint | int' unibool: TypeAlias = 'cpbool | bool' TNumber = TypeVar("TNumber", bound='CPNumber') +T = TypeVar("T") def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: @@ -355,6 +356,36 @@ def net_from_value(value: Any) -> Net: return Net(vi.dtype, vi) +@overload +def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload] + ... + + +@overload +def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint: + ... + + +@overload +def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat: + ... + + +@overload +def iif(expression: NumLike, true_result: T, false_result: T) -> T: + ... + + +def iif(expression: Any, true_result: Any, false_result: Any) -> Any: + # TODO: check that input types are matching + alowed_type = cpint | cpfloat | cpbool | int | float | bool + assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported" + if isinstance(expression, CPNumber): + return (expression != 0) * true_result + (expression == 0) * false_result + else: + return true_result if expression else false_result + + def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber: arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] diff --git a/tests/test_ops.py b/tests/test_ops.py index 115f4aa..8b69f74 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -1,4 +1,4 @@ -from copapy import cpvalue, Target, NumLike, Net, cpint +from copapy import cpvalue, Target, NumLike, Net, iif, cpint from pytest import approx @@ -32,14 +32,21 @@ def function6(c1: NumLike) -> list[NumLike]: return [c1 == True] -def test_compile(): +def iiftests(c1: NumLike) -> list[NumLike]: + return [iif(c1 > 5, 8, 9), + iif(c1 < 5, 8.5, 9.5), + iif(1 > 5, 3.3, 8.8) + c1, + iif(1 < 5, c1 * 3.3, 8.8), + iif(c1 < 5, c1 * 3.3, 8.8)] + +def test_compile(): c_i = cpvalue(9) c_f = cpvalue(1.111) c_b = cpvalue(True) - ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2] - ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2] + iiftests(c_i) + iiftests(c_f) + ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111) tg = Target() print('* compile and copy ...')