min, max and sign stencil added

This commit is contained in:
Nicolas 2026-01-05 13:38:49 +01:00
parent d71922769f
commit 32aad5cafd
2 changed files with 61 additions and 8 deletions

View File

@ -166,8 +166,42 @@ def get_floordiv(op: str, type1: str, type2: str) -> str:
"""
else:
return f"""
STENCIL void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}(floorf((float)arg1 / (float)arg2), arg2);
STENCIL void {op}_{type1}_{type2}({type1} a, {type2} b) {{
result_float_{type2}(floorf((float)a / (float)b), b);
}}
"""
@norm_indent
def get_min(type1: str, type2: str) -> str:
    """Build the C source text of the ``min`` stencil for one operand type pair.

    For the ``int``/``int`` pair the operands are compared directly and the
    smaller one is handed to ``result_int_<type2>``. For every other pair both
    operands are first cast to ``float`` and the smaller one is handed to
    ``result_float_<type2>``. In both variants ``b`` is forwarded unchanged as
    the second argument of the result call.

    Args:
        type1: C type name of the first stencil argument ``a``.
        type2: C type name of the second stencil argument ``b``.

    Returns:
        The stencil definition as text, starting and ending with a newline.
    """
    header = f"STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{"
    if (type1, type2) == ('int', 'int'):
        statements = [f"result_int_{type2}(a < b ? a : b, b);"]
    else:
        statements = [
            "float _a = (float)a; float _b = (float)b;",
            f"result_float_{type2}(_a < _b ? _a : _b, b);",
        ]
    # The empty first and last entries produce the newline that opens and
    # closes the generated snippet.
    return "\n".join(["", header, *statements, "}", ""])
@norm_indent
def get_max(type1: str, type2: str) -> str:
    """Build the C source text of the ``max`` stencil for one operand type pair.

    For the ``int``/``int`` pair the operands are compared directly and the
    larger one is handed to ``result_int_<type2>``. For every other pair both
    operands are first cast to ``float`` and the larger one is handed to
    ``result_float_<type2>``. In both variants ``b`` is forwarded unchanged as
    the second argument of the result call.

    Args:
        type1: C type name of the first stencil argument ``a``.
        type2: C type name of the second stencil argument ``b``.

    Returns:
        The stencil definition as text, starting and ending with a newline.
    """
    header = f"STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{"
    if (type1, type2) == ('int', 'int'):
        statements = [f"result_int_{type2}(a > b ? a : b, b);"]
    else:
        statements = [
            "float _a = (float)a; float _b = (float)b;",
            f"result_float_{type2}(_a > _b ? _a : _b, b);",
        ]
    # The empty first and last entries produce the newline that opens and
    # closes the generated snippet.
    return "\n".join(["", header, *statements, "}", ""])
@ -268,10 +302,17 @@ if __name__ == "__main__":
code += get_math_func1('fabsf', 'float', 'abs')
code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));')
for t in types:
code += get_custom_stencil(f"sign_{t}({t} arg1)", f"result_int((arg1 > 0) - (arg1 < 0));")
fnames = ['atan2', 'pow']
for fn, t1, t2 in permutate(fnames, types, types):
code += get_math_func2(fn, t1, t2)
for t1, t2 in permutate(types, types):
code += get_min(t1, t2)
code += get_max(t1, t2)
for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float'
if op == 'floordiv':

View File

@ -20,7 +20,11 @@ def test_fine():
cp.cos(c_f),
cp.tan(c_f),
cp.abs(-c_i),
cp.abs(-c_f))
cp.abs(-c_f),
cp.sign(c_i),
cp.sign(-c_f),
cp.min(c_i, 5),
cp.max(c_f, 5))
re2_test = (a_f ** 2,
a_i ** -1,
@ -32,7 +36,11 @@ def test_fine():
cp.cos(a_f),
cp.tan(a_f),
cp.abs(-a_i),
cp.abs(-a_f))
cp.abs(-a_f),
cp.sign(a_i),
cp.sign(-a_f),
cp.min(a_i, 5),
cp.max(a_f, 5))
ret_refe = (a_f ** 2,
a_i ** -1,
@ -43,8 +51,12 @@ def test_fine():
ma.sin(a_f),
ma.cos(a_f),
ma.tan(a_f),
cp.abs(-a_i),
cp.abs(-a_f))
abs(-a_i),
abs(-a_f),
(a_i > 0) - (a_i < 0),
(-a_f > 0) - (-a_f < 0),
min(a_i, 5),
max(a_f, 5))
tg = Target()
print('* compile and copy ...')
@ -53,10 +65,10 @@ def test_fine():
tg.run()
print('* finished')
for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')):
for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ['^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan'] + ['other']*10):
assert isinstance(test, cp.value)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
print('+', name, val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]