diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index bc2e326..19b0f73 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -166,8 +166,42 @@ def get_floordiv(op: str, type1: str, type2: str) -> str: """ else: return f""" - STENCIL void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ - result_float_{type2}(floorf((float)arg1 / (float)arg2), arg2); + STENCIL void {op}_{type1}_{type2}({type1} a, {type2} b) {{ + result_float_{type2}(floorf((float)a / (float)b), b); + }} + """ + + +@norm_indent +def get_min(type1: str, type2: str) -> str: + if type1 == 'int' and type2 == 'int': + return f""" + STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{ + result_int_{type2}(a < b ? a : b, b); + }} + """ + else: + return f""" + STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{ + float _a = (float)a; float _b = (float)b; + result_float_{type2}(_a < _b ? _a : _b, b); + }} + """ + + +@norm_indent +def get_max(type1: str, type2: str) -> str: + if type1 == 'int' and type2 == 'int': + return f""" + STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{ + result_int_{type2}(a > b ? a : b, b); + }} + """ + else: + return f""" + STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{ + float _a = (float)a; float _b = (float)b; + result_float_{type2}(_a > _b ? _a : _b, b); }} """ @@ -268,10 +302,17 @@ if __name__ == "__main__": code += get_math_func1('fabsf', 'float', 'abs') code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));') + for t in types: + code += get_custom_stencil(f"sign_{t}({t} arg1)", f"result_int((arg1 > 0) - (arg1 < 0));") + fnames = ['atan2', 'pow'] for fn, t1, t2 in permutate(fnames, types, types): code += get_math_func2(fn, t1, t2) + for t1, t2 in permutate(types, types): + code += get_min(t1, t2) + code += get_max(t1, t2) + for op, t1, t2 in permutate(ops, types, types): t_out = t1 if t1 == t2 else 'float' if op == 'floordiv': diff --git a/tests/test_math.py b/tests/test_math.py index 17efda1..a6361ce 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -20,7 +20,11 @@ def test_fine(): cp.cos(c_f), cp.tan(c_f), cp.abs(-c_i), - cp.abs(-c_f)) + cp.abs(-c_f), + cp.sign(c_i), + cp.sign(-c_f), + cp.min(c_i, 5), + cp.max(c_f, 5)) re2_test = (a_f ** 2, a_i ** -1, @@ -32,7 +36,11 @@ def test_fine(): cp.cos(a_f), cp.tan(a_f), cp.abs(-a_i), - cp.abs(-a_f)) + cp.abs(-a_f), + cp.sign(a_i), + cp.sign(-a_f), + cp.min(a_i, 5), + cp.max(a_f, 5)) ret_refe = (a_f ** 2, a_i ** -1, @@ -43,8 +51,12 @@ def test_fine(): ma.sin(a_f), ma.cos(a_f), ma.tan(a_f), - cp.abs(-a_i), - cp.abs(-a_f)) + abs(-a_i), + abs(-a_f), + (a_i > 0) - (a_i < 0), + (-a_f > 0) - (-a_f < 0), + min(a_i, 5), + max(a_f, 5)) tg = Target() print('* compile and copy ...') @@ -53,10 +65,10 @@ def test_fine(): tg.run() print('* finished') - for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')): + for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ['^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan'] + ['other']*10): assert isinstance(test, cp.value) val = tg.read_value(test) - print('+', val, ref, type(val), test.dtype) + print('+', name, val, ref, type(val), test.dtype) #for t in (int, float, bool): # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]