mirror of https://github.com/Nonannet/copapy.git
min, max and sign stencil added
This commit is contained in:
parent
d71922769f
commit
32aad5cafd
|
|
@ -166,8 +166,42 @@ def get_floordiv(op: str, type1: str, type2: str) -> str:
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
return f"""
|
return f"""
|
||||||
STENCIL void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
STENCIL void {op}_{type1}_{type2}({type1} a, {type2} b) {{
|
||||||
result_float_{type2}(floorf((float)arg1 / (float)arg2), arg2);
|
result_float_{type2}(floorf((float)a / (float)b), b);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@norm_indent
|
||||||
|
def get_min(type1: str, type2: str) -> str:
|
||||||
|
if type1 == 'int' and type2 == 'int':
|
||||||
|
return f"""
|
||||||
|
STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{
|
||||||
|
result_int_{type2}(a < b ? a : b, b);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
return f"""
|
||||||
|
STENCIL void min_{type1}_{type2}({type1} a, {type2} b) {{
|
||||||
|
float _a = (float)a; float _b = (float)b;
|
||||||
|
result_float_{type2}(_a < _b ? _a : _b, b);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@norm_indent
|
||||||
|
def get_max(type1: str, type2: str) -> str:
|
||||||
|
if type1 == 'int' and type2 == 'int':
|
||||||
|
return f"""
|
||||||
|
STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{
|
||||||
|
result_int_{type2}(a > b ? a : b, b);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
return f"""
|
||||||
|
STENCIL void max_{type1}_{type2}({type1} a, {type2} b) {{
|
||||||
|
float _a = (float)a; float _b = (float)b;
|
||||||
|
result_float_{type2}(_a > _b ? _a : _b, b);
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -268,10 +302,17 @@ if __name__ == "__main__":
|
||||||
code += get_math_func1('fabsf', 'float', 'abs')
|
code += get_math_func1('fabsf', 'float', 'abs')
|
||||||
code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));')
|
code += get_custom_stencil('abs_int(int arg1)', 'result_int(__builtin_abs(arg1));')
|
||||||
|
|
||||||
|
for t in types:
|
||||||
|
code += get_custom_stencil(f"sign_{t}({t} arg1)", f"result_int((arg1 > 0) - (arg1 < 0));")
|
||||||
|
|
||||||
fnames = ['atan2', 'pow']
|
fnames = ['atan2', 'pow']
|
||||||
for fn, t1, t2 in permutate(fnames, types, types):
|
for fn, t1, t2 in permutate(fnames, types, types):
|
||||||
code += get_math_func2(fn, t1, t2)
|
code += get_math_func2(fn, t1, t2)
|
||||||
|
|
||||||
|
for t1, t2 in permutate(types, types):
|
||||||
|
code += get_min(t1, t2)
|
||||||
|
code += get_max(t1, t2)
|
||||||
|
|
||||||
for op, t1, t2 in permutate(ops, types, types):
|
for op, t1, t2 in permutate(ops, types, types):
|
||||||
t_out = t1 if t1 == t2 else 'float'
|
t_out = t1 if t1 == t2 else 'float'
|
||||||
if op == 'floordiv':
|
if op == 'floordiv':
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,11 @@ def test_fine():
|
||||||
cp.cos(c_f),
|
cp.cos(c_f),
|
||||||
cp.tan(c_f),
|
cp.tan(c_f),
|
||||||
cp.abs(-c_i),
|
cp.abs(-c_i),
|
||||||
cp.abs(-c_f))
|
cp.abs(-c_f),
|
||||||
|
cp.sign(c_i),
|
||||||
|
cp.sign(-c_f),
|
||||||
|
cp.min(c_i, 5),
|
||||||
|
cp.max(c_f, 5))
|
||||||
|
|
||||||
re2_test = (a_f ** 2,
|
re2_test = (a_f ** 2,
|
||||||
a_i ** -1,
|
a_i ** -1,
|
||||||
|
|
@ -32,7 +36,11 @@ def test_fine():
|
||||||
cp.cos(a_f),
|
cp.cos(a_f),
|
||||||
cp.tan(a_f),
|
cp.tan(a_f),
|
||||||
cp.abs(-a_i),
|
cp.abs(-a_i),
|
||||||
cp.abs(-a_f))
|
cp.abs(-a_f),
|
||||||
|
cp.sign(a_i),
|
||||||
|
cp.sign(-a_f),
|
||||||
|
cp.min(a_i, 5),
|
||||||
|
cp.max(a_f, 5))
|
||||||
|
|
||||||
ret_refe = (a_f ** 2,
|
ret_refe = (a_f ** 2,
|
||||||
a_i ** -1,
|
a_i ** -1,
|
||||||
|
|
@ -43,8 +51,12 @@ def test_fine():
|
||||||
ma.sin(a_f),
|
ma.sin(a_f),
|
||||||
ma.cos(a_f),
|
ma.cos(a_f),
|
||||||
ma.tan(a_f),
|
ma.tan(a_f),
|
||||||
cp.abs(-a_i),
|
abs(-a_i),
|
||||||
cp.abs(-a_f))
|
abs(-a_f),
|
||||||
|
(a_i > 0) - (a_i < 0),
|
||||||
|
(-a_f > 0) - (-a_f < 0),
|
||||||
|
min(a_i, 5),
|
||||||
|
max(a_f, 5))
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
@ -53,10 +65,10 @@ def test_fine():
|
||||||
tg.run()
|
tg.run()
|
||||||
print('* finished')
|
print('* finished')
|
||||||
|
|
||||||
for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ('^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan')):
|
for test, val2, ref, name in zip(ret_test, re2_test, ret_refe, ['^2', '**-1', 'sqrt_int', 'sqrt_float', 'sin', 'cos', 'tan'] + ['other']*10):
|
||||||
assert isinstance(test, cp.value)
|
assert isinstance(test, cp.value)
|
||||||
val = tg.read_value(test)
|
val = tg.read_value(test)
|
||||||
print('+', val, ref, type(val), test.dtype)
|
print('+', name, val, ref, type(val), test.dtype)
|
||||||
#for t in (int, float, bool):
|
#for t in (int, float, bool):
|
||||||
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
||||||
assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
assert val == pytest.approx(ref, abs=1e-3), f"Result for {name} does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue