diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 80bd15a..61846c0 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -83,6 +83,15 @@ def get_cast(type1: str, type2: str, type_out: str) -> str: """ +@norm_indent +def get_neg(type1: str) -> str: + return f""" + STENCIL void neg_{type1}({type1} arg1) {{ + result_{type1}(-arg1); + }} + """ + + @norm_indent def get_func1(func_name: str, type1: str) -> str: return f""" @@ -249,6 +258,9 @@ if __name__ == "__main__": for fn, t1 in permutate(fnames, types): code += get_func1(fn, t1) + for t in types: + code += get_neg(t) + fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan'] for fn, t1 in permutate(fnames, types): code += get_math_func1(fn + 'f', t1, fn)