mirror of https://github.com/Nonannet/copapy.git
Single argument functions do not need a dummy argument anymore
This commit is contained in:
parent
4b752a6094
commit
20a8dcd1a2
|
|
@ -333,7 +333,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
|
||||||
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
||||||
|
|
||||||
if commutative:
|
if commutative:
|
||||||
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
|
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
for node in node_list:
|
for node in node_list:
|
||||||
if not isinstance(node, CPConstant):
|
if not isinstance(node, CPConstant):
|
||||||
for i, net in enumerate(node.args):
|
for i, net in enumerate(node.args):
|
||||||
if id(net) != id(registers[i]):
|
if id(net) != id(registers[i]): # TODO: consider register swap and commutative ops
|
||||||
#if net in registers:
|
#if net in registers:
|
||||||
# print('x swap registers')
|
# print('x swap registers')
|
||||||
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
|
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
|
||||||
|
|
@ -108,8 +108,11 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
registers[i] = net
|
registers[i] = net
|
||||||
|
|
||||||
if node in net_lookup:
|
if node in net_lookup:
|
||||||
yield net_lookup[node], node
|
result_net = net_lookup[node]
|
||||||
registers[0] = net_lookup[node]
|
yield result_net, node
|
||||||
|
registers[0] = result_net
|
||||||
|
if len(node.args) < 2: # Reset virtual register for single argument functions
|
||||||
|
registers[1] = None
|
||||||
else:
|
else:
|
||||||
yield None, node
|
yield None, node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def exp(x: NumLike) -> variable[float] | float:
|
||||||
result of e**x
|
result of e**x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('exp', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('exp', [x])
|
||||||
return float(math.exp(x))
|
return float(math.exp(x))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ def log(x: NumLike) -> variable[float] | float:
|
||||||
result of ln(x)
|
result of ln(x)
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('log', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('log', [x])
|
||||||
return float(math.log(x))
|
return float(math.log(x))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -86,7 +86,7 @@ def sqrt(x: NumLike) -> variable[float] | float:
|
||||||
Square root of x
|
Square root of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('sqrt', [x])
|
||||||
return float(math.sqrt(x))
|
return float(math.sqrt(x))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,7 +104,7 @@ def sin(x: NumLike) -> variable[float] | float:
|
||||||
Square root of x
|
Square root of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('sin', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('sin', [x])
|
||||||
return math.sin(x)
|
return math.sin(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -122,7 +122,7 @@ def cos(x: NumLike) -> variable[float] | float:
|
||||||
Cosine of x
|
Cosine of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('cos', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('cos', [x])
|
||||||
return math.cos(x)
|
return math.cos(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,7 +140,7 @@ def tan(x: NumLike) -> variable[float] | float:
|
||||||
Tangent of x
|
Tangent of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('tan', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('tan', [x])
|
||||||
return math.tan(x)
|
return math.tan(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -158,7 +158,7 @@ def atan(x: NumLike) -> variable[float] | float:
|
||||||
Inverse tangent of x
|
Inverse tangent of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('atan', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('atan', [x])
|
||||||
return math.atan(x)
|
return math.atan(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -177,7 +177,7 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
|
||||||
Result in radian
|
Result in radian
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable) or isinstance(y, variable):
|
if isinstance(x, variable) or isinstance(y, variable):
|
||||||
return add_op('atan2', [x, y]) # TODO: fix 2. dummy argument
|
return add_op('atan2', [x, y])
|
||||||
return math.atan2(x, y)
|
return math.atan2(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -195,7 +195,7 @@ def asin(x: NumLike) -> variable[float] | float:
|
||||||
Inverse sine of x
|
Inverse sine of x
|
||||||
"""
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('asin', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('asin', [x])
|
||||||
return math.asin(x)
|
return math.asin(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,10 +102,10 @@ def get_func2(func_name: str, type1: str, type2: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
@norm_indent
|
@norm_indent
|
||||||
def get_math_func1(func_name: str, type1: str, type2: str) -> str:
|
def get_math_func1(func_name: str, type1: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
STENCIL void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
STENCIL void {func_name}_{type1}({type1} arg1) {{
|
||||||
result_float_{type2}({func_name}f((float)arg1), arg2);
|
result_float({func_name}f((float)arg1));
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -242,7 +242,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan']
|
fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan']
|
||||||
for fn, t1 in permutate(fnames, types):
|
for fn, t1 in permutate(fnames, types):
|
||||||
code += get_math_func1(fn, t1, t1)
|
code += get_math_func1(fn, t1)
|
||||||
|
|
||||||
fnames = ['atan2', 'pow']
|
fnames = ['atan2', 'pow']
|
||||||
for fn, t1, t2 in permutate(fnames, types, types):
|
for fn, t1, t2 in permutate(fnames, types, types):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue