From 3485584e5ea89e75179bd732969d52e87c745d45 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Thu, 9 Oct 2025 22:50:57 +0200 Subject: [PATCH] refactored names --- src/copapy/__init__.py | 46 +++++++++++++++++++---------------- tests/test_ast_gen.py | 6 ++--- tests/test_compile.py | 6 ++--- tests/test_compile_div.py | 4 +-- tests/test_coparun_module.py | 4 +-- tests/test_coparun_module2.py | 6 ++--- tests/test_crash_win.py | 6 ++--- 7 files changed, 41 insertions(+), 37 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 21d21ca..e651ac2 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -32,7 +32,7 @@ class Node: def __repr__(self) -> str: #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" - return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" + return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, InitVar) else '')})" class Device(): @@ -75,7 +75,7 @@ class Net: return _add_op('floordiv', [other, self]) def __neg__(self) -> 'Net': - return _add_op('sub', [const(0), self]) + return _add_op('sub', [CPVariable(0), self]) def __gt__(self, other: Any) -> 'Net': return _add_op('gt', [self, other]) @@ -100,14 +100,10 @@ class Net: return id(self) -class Const(Node): +class InitVar(Node): def __init__(self, value: float | int | bool): self.dtype, self.value = _get_data_and_dtype(value) self.name = 'const_' + self.dtype - - #if self.name not in _function_definitions: - # raise ValueError(f"Unsupported operand type for a const: {self.dtype}") - self.args = [] @@ -116,9 +112,6 @@ class Write(Node): self.name = 'write_' + net.dtype self.args = [net] - #if self.name not in _function_definitions: - # raise ValueError(f"Unsupported operand type for write: {net.dtype}") - class Op(Node): def __init__(self, typed_op_name: str, args: list[Net]): @@ -128,7 +121,7 @@ class Op(Node): def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: - arg_nets = [a if isinstance(a, Net) else const(a) for a in args] + arg_nets = [a if isinstance(a, Net) else CPVariable(a) for a in args] if commutative: arg_nets = sorted(arg_nets, key=lambda a: a.dtype) @@ -145,14 +138,25 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: return result_net -#def read_input(hw: Device, test_value: float): -# return Net(type(value)) +class CPVariable(Net): + def __init__(self, value: Any): + vi = InitVar(value) + Net.__init__(self, vi.dtype, vi) -def const(value: Any) -> Net: - assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}' - new_const = Const(value) - return Net(new_const.dtype, new_const) +class CPFloat(CPVariable): + def __init__(self, value: float): + super().__init__(value) + + +class CPInt(CPVariable): + def __init__(self, value: int): + super().__init__(value) + + +class CPBool(CPVariable): + def __init__(self, value: bool): + super().__init__(value) def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: @@ -177,7 +181,7 @@ class vec3d: def const_vector3d(x: float, y: float, z: float) -> vec3d: - return vec3d((const(x), const(y), const(z))) + return vec3d((CPVariable(x), CPVariable(y), CPVariable(z))) def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: @@ -251,7 +255,7 @@ def get_const_nets(nodes: list[Node]) -> list[Net]: List of nets whose source node is a Const """ net_lookup = {net.source: net for node in nodes for net in node.args} - return [net_lookup[node] for node in nodes if isinstance(node, Const)] + return [net_lookup[node] for node in nodes if isinstance(node, InitVar)] def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: @@ -272,7 +276,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No net_lookup = {net.source: net for node in node_list for net in node.args} for node in node_list: - if not node.name.startswith('const_'): + if not isinstance(node, InitVar): for i, net in enumerate(node.args): if id(net) != id(registers[i]): #if net in registers: @@ -365,7 +369,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database for net, out_offs, lengths in object_list: variables[net] = (out_offs, lengths, net.dtype) - if isinstance(net.source, Const): + if isinstance(net.source, InitVar): dw.write_com(binw.Command.COPY_DATA) dw.write_int(out_offs) dw.write_int(lengths) diff --git a/tests/test_ast_gen.py b/tests/test_ast_gen.py index fa7471e..4cead6f 100644 --- a/tests/test_ast_gen.py +++ b/tests/test_ast_gen.py @@ -1,4 +1,4 @@ -from copapy import Write, const +from copapy import Write, CPVariable import copapy as rc @@ -20,8 +20,8 @@ def test_ast_generation(): #r2 = i1 + 9 #out = [Write(r1), Write(r2)] - c1 = const(4) - c2 = const(2) + c1 = CPVariable(4) + c2 = CPVariable(2) #i1 = c1 * 2 #r1 = i1 + 7 + (c2 + 7 * 9) #r2 = i1 + 9 diff --git a/tests/test_compile.py b/tests/test_compile.py index f845b14..f020231 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -1,4 +1,4 @@ -from copapy import Write, const +from copapy import Write, CPVariable import copapy import subprocess import struct @@ -46,8 +46,8 @@ def function(c1, c2): def test_compile(): - c1 = const(4) - c2 = const(2) + c1 = CPVariable(4) + c2 = CPVariable(2) ret = function(c1, c2) diff --git a/tests/test_compile_div.py b/tests/test_compile_div.py index decc97c..c890a6f 100644 --- a/tests/test_compile_div.py +++ b/tests/test_compile_div.py @@ -1,4 +1,4 @@ -from copapy import Write, const +from copapy import Write, CPVariable import copapy import subprocess import struct @@ -19,7 +19,7 @@ def function(c1): def test_compile(): - c1 = const(16) + c1 = CPVariable(16) ret = function(c1) diff --git a/tests/test_coparun_module.py b/tests/test_coparun_module.py index a22dbe2..d6c5563 100644 --- a/tests/test_coparun_module.py +++ b/tests/test_coparun_module.py @@ -1,4 +1,4 @@ -from copapy import const, Target +from copapy import CPVariable, Target from pytest import approx @@ -10,7 +10,7 @@ def function(c1): def test_compile(): - c1 = const(16) + c1 = CPVariable(16) ret = function(c1) diff --git a/tests/test_coparun_module2.py b/tests/test_coparun_module2.py index 9499383..e0a2d0f 100644 --- a/tests/test_coparun_module2.py +++ b/tests/test_coparun_module2.py @@ -1,13 +1,13 @@ from coparun_module import coparun -from copapy import Write, const +from copapy import Write, CPVariable import copapy from copapy import binwrite def test_compile(): - c1 = const(4) - c2 = const(2) * 4 + c1 = CPVariable(4) + c2 = CPVariable(2) * 4 i1 = c2 * 2 r1 = i1 + 7 + (c1 + 7 * 9) diff --git a/tests/test_crash_win.py b/tests/test_crash_win.py index e48a61b..35fa7ff 100644 --- a/tests/test_crash_win.py +++ b/tests/test_crash_win.py @@ -1,4 +1,4 @@ -from copapy import Write, const +from copapy import Write, CPVariable import copapy import subprocess from copapy import binwrite @@ -23,8 +23,8 @@ def function(c1, c2): def test_compile(): - c1 = const(4) - c2 = const(2) + c1 = CPVariable(4) + c2 = CPVariable(2) ret = function(c1, c2)