From 20a8dcd1a216c355a958d665a5b9028e4d17b0ab Mon Sep 17 00:00:00 2001 From: Nicolas Date: Fri, 14 Nov 2025 17:28:05 +0100 Subject: [PATCH] Single argument functions do not need a dummy argument anymore --- src/copapy/_basic_types.py | 2 +- src/copapy/_compiler.py | 9 ++++++--- src/copapy/_math.py | 18 +++++++++--------- stencils/generate_stencils.py | 8 ++++---- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 044203e..c77819e 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -333,7 +333,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool = arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] if commutative: - arg_nets = sorted(arg_nets, key=lambda a: a.dtype) + arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 7143a19..87cae78 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -99,7 +99,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No for node in node_list: if not isinstance(node, CPConstant): for i, net in enumerate(node.args): - if id(net) != id(registers[i]): + if id(net) != id(registers[i]): # TODO: consider register swap and commutative ops #if net in registers: # print('x swap registers') type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] @@ -108,8 +108,11 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No registers[i] = net if node in net_lookup: - yield net_lookup[node], node - registers[0] = net_lookup[node] + result_net = net_lookup[node] + yield result_net, node + registers[0] = result_net + if len(node.args) < 2: # Reset virtual register for single argument functions + registers[1] = None else: yield None, node diff --git a/src/copapy/_math.py b/src/copapy/_math.py index 02ac4ae..4fed4fd 100644 --- a/src/copapy/_math.py +++ b/src/copapy/_math.py @@ -20,7 +20,7 @@ def exp(x: NumLike) -> variable[float] | float: result of e**x """ if isinstance(x, variable): - return add_op('exp', [x, x]) # TODO: fix 2. dummy argument + return add_op('exp', [x]) return float(math.exp(x)) @@ -38,7 +38,7 @@ def log(x: NumLike) -> variable[float] | float: result of ln(x) """ if isinstance(x, variable): - return add_op('log', [x, x]) # TODO: fix 2. dummy argument + return add_op('log', [x]) return float(math.log(x)) @@ -86,7 +86,7 @@ def sqrt(x: NumLike) -> variable[float] | float: Square root of x """ if isinstance(x, variable): - return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument + return add_op('sqrt', [x]) return float(math.sqrt(x)) @@ -104,7 +104,7 @@ def sin(x: NumLike) -> variable[float] | float: Square root of x """ if isinstance(x, variable): - return add_op('sin', [x, x]) # TODO: fix 2. dummy argument + return add_op('sin', [x]) return math.sin(x) @@ -122,7 +122,7 @@ def cos(x: NumLike) -> variable[float] | float: Cosine of x """ if isinstance(x, variable): - return add_op('cos', [x, x]) # TODO: fix 2. dummy argument + return add_op('cos', [x]) return math.cos(x) @@ -140,7 +140,7 @@ def tan(x: NumLike) -> variable[float] | float: Tangent of x """ if isinstance(x, variable): - return add_op('tan', [x, x]) # TODO: fix 2. dummy argument + return add_op('tan', [x]) return math.tan(x) @@ -158,7 +158,7 @@ def atan(x: NumLike) -> variable[float] | float: Inverse tangent of x """ if isinstance(x, variable): - return add_op('atan', [x, x]) # TODO: fix 2. dummy argument + return add_op('atan', [x]) return math.atan(x) @@ -177,7 +177,7 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float: Result in radian """ if isinstance(x, variable) or isinstance(y, variable): - return add_op('atan2', [x, y]) # TODO: fix 2. dummy argument + return add_op('atan2', [x, y]) return math.atan2(x, y) @@ -195,7 +195,7 @@ def asin(x: NumLike) -> variable[float] | float: Inverse sine of x """ if isinstance(x, variable): - return add_op('asin', [x, x]) # TODO: fix 2. dummy argument + return add_op('asin', [x]) return math.asin(x) diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 33a74c9..78051c3 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -102,10 +102,10 @@ def get_func2(func_name: str, type1: str, type2: str) -> str: @norm_indent -def get_math_func1(func_name: str, type1: str, type2: str) -> str: +def get_math_func1(func_name: str, type1: str) -> str: return f""" - STENCIL void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ - result_float_{type2}({func_name}f((float)arg1), arg2); + STENCIL void {func_name}_{type1}({type1} arg1) {{ + result_float({func_name}f((float)arg1)); }} """ @@ -242,7 +242,7 @@ if __name__ == "__main__": fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan'] for fn, t1 in permutate(fnames, types): - code += get_math_func1(fn, t1, t1) + code += get_math_func1(fn, t1) fnames = ['atan2', 'pow'] for fn, t1, t2 in permutate(fnames, types, types):