Single argument functions do not need a dummy argument anymore

This commit is contained in:
Nicolas 2025-11-14 17:28:05 +01:00
parent 4b752a6094
commit 20a8dcd1a2
4 changed files with 20 additions and 17 deletions

View File

@ -333,7 +333,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])

View File

@ -99,7 +99,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
for node in node_list: for node in node_list:
if not isinstance(node, CPConstant): if not isinstance(node, CPConstant):
for i, net in enumerate(node.args): for i, net in enumerate(node.args):
if id(net) != id(registers[i]): if id(net) != id(registers[i]): # TODO: consider register swap and commutative ops
#if net in registers: #if net in registers:
# print('x swap registers') # print('x swap registers')
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
@ -108,8 +108,11 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
registers[i] = net registers[i] = net
if node in net_lookup: if node in net_lookup:
yield net_lookup[node], node result_net = net_lookup[node]
registers[0] = net_lookup[node] yield result_net, node
registers[0] = result_net
if len(node.args) < 2: # Reset virtual register for single argument functions
registers[1] = None
else: else:
yield None, node yield None, node

View File

@ -20,7 +20,7 @@ def exp(x: NumLike) -> variable[float] | float:
result of e**x result of e**x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('exp', [x, x]) # TODO: fix 2. dummy argument return add_op('exp', [x])
return float(math.exp(x)) return float(math.exp(x))
@ -38,7 +38,7 @@ def log(x: NumLike) -> variable[float] | float:
result of ln(x) result of ln(x)
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('log', [x, x]) # TODO: fix 2. dummy argument return add_op('log', [x])
return float(math.log(x)) return float(math.log(x))
@ -86,7 +86,7 @@ def sqrt(x: NumLike) -> variable[float] | float:
Square root of x Square root of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument return add_op('sqrt', [x])
return float(math.sqrt(x)) return float(math.sqrt(x))
@ -104,7 +104,7 @@ def sin(x: NumLike) -> variable[float] | float:
Square root of x Square root of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sin', [x, x]) # TODO: fix 2. dummy argument return add_op('sin', [x])
return math.sin(x) return math.sin(x)
@ -122,7 +122,7 @@ def cos(x: NumLike) -> variable[float] | float:
Cosine of x Cosine of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('cos', [x, x]) # TODO: fix 2. dummy argument return add_op('cos', [x])
return math.cos(x) return math.cos(x)
@ -140,7 +140,7 @@ def tan(x: NumLike) -> variable[float] | float:
Tangent of x Tangent of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('tan', [x, x]) # TODO: fix 2. dummy argument return add_op('tan', [x])
return math.tan(x) return math.tan(x)
@ -158,7 +158,7 @@ def atan(x: NumLike) -> variable[float] | float:
Inverse tangent of x Inverse tangent of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('atan', [x, x]) # TODO: fix 2. dummy argument return add_op('atan', [x])
return math.atan(x) return math.atan(x)
@ -177,7 +177,7 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
Result in radian Result in radian
""" """
if isinstance(x, variable) or isinstance(y, variable): if isinstance(x, variable) or isinstance(y, variable):
return add_op('atan2', [x, y]) # TODO: fix 2. dummy argument return add_op('atan2', [x, y])
return math.atan2(x, y) return math.atan2(x, y)
@ -195,7 +195,7 @@ def asin(x: NumLike) -> variable[float] | float:
Inverse sine of x Inverse sine of x
""" """
if isinstance(x, variable): if isinstance(x, variable):
return add_op('asin', [x, x]) # TODO: fix 2. dummy argument return add_op('asin', [x])
return math.asin(x) return math.asin(x)

View File

@ -102,10 +102,10 @@ def get_func2(func_name: str, type1: str, type2: str) -> str:
@norm_indent @norm_indent
def get_math_func1(func_name: str, type1: str, type2: str) -> str: def get_math_func1(func_name: str, type1: str) -> str:
return f""" return f"""
STENCIL void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ STENCIL void {func_name}_{type1}({type1} arg1) {{
result_float_{type2}({func_name}f((float)arg1), arg2); result_float({func_name}f((float)arg1));
}} }}
""" """
@ -242,7 +242,7 @@ if __name__ == "__main__":
fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan'] fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan']
for fn, t1 in permutate(fnames, types): for fn, t1 in permutate(fnames, types):
code += get_math_func1(fn, t1, t1) code += get_math_func1(fn, t1)
fnames = ['atan2', 'pow'] fnames = ['atan2', 'pow']
for fn, t1, t2 in permutate(fnames, types, types): for fn, t1, t2 in permutate(fnames, types, types):