refactored names

This commit is contained in:
Nicolas Kruse 2025-10-09 22:50:57 +02:00
parent 60aa550ec7
commit 3485584e5e
7 changed files with 41 additions and 37 deletions

View File

@ -32,7 +32,7 @@ class Node:
def __repr__(self) -> str: def __repr__(self) -> str:
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, InitVar) else '')})"
class Device(): class Device():
@ -75,7 +75,7 @@ class Net:
return _add_op('floordiv', [other, self]) return _add_op('floordiv', [other, self])
def __neg__(self) -> 'Net': def __neg__(self) -> 'Net':
return _add_op('sub', [const(0), self]) return _add_op('sub', [CPVariable(0), self])
def __gt__(self, other: Any) -> 'Net': def __gt__(self, other: Any) -> 'Net':
return _add_op('gt', [self, other]) return _add_op('gt', [self, other])
@ -100,14 +100,10 @@ class Net:
return id(self) return id(self)
class Const(Node): class InitVar(Node):
def __init__(self, value: float | int | bool): def __init__(self, value: float | int | bool):
self.dtype, self.value = _get_data_and_dtype(value) self.dtype, self.value = _get_data_and_dtype(value)
self.name = 'const_' + self.dtype self.name = 'const_' + self.dtype
#if self.name not in _function_definitions:
# raise ValueError(f"Unsupported operand type for a const: {self.dtype}")
self.args = [] self.args = []
@ -116,9 +112,6 @@ class Write(Node):
self.name = 'write_' + net.dtype self.name = 'write_' + net.dtype
self.args = [net] self.args = [net]
#if self.name not in _function_definitions:
# raise ValueError(f"Unsupported operand type for write: {net.dtype}")
class Op(Node): class Op(Node):
def __init__(self, typed_op_name: str, args: list[Net]): def __init__(self, typed_op_name: str, args: list[Net]):
@ -128,7 +121,7 @@ class Op(Node):
def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
arg_nets = [a if isinstance(a, Net) else const(a) for a in args] arg_nets = [a if isinstance(a, Net) else CPVariable(a) for a in args]
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
@ -145,14 +138,25 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
return result_net return result_net
#def read_input(hw: Device, test_value: float): class CPVariable(Net):
# return Net(type(value)) def __init__(self, value: Any):
vi = InitVar(value)
Net.__init__(self, vi.dtype, vi)
def const(value: Any) -> Net: class CPFloat(CPVariable):
assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}' def __init__(self, value: float):
new_const = Const(value) super().__init__(value)
return Net(new_const.dtype, new_const)
class CPInt(CPVariable):
def __init__(self, value: int):
super().__init__(value)
class CPBool(CPVariable):
def __init__(self, value: bool):
super().__init__(value)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
@ -177,7 +181,7 @@ class vec3d:
def const_vector3d(x: float, y: float, z: float) -> vec3d: def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((const(x), const(y), const(z))) return vec3d((CPVariable(x), CPVariable(y), CPVariable(z)))
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
@ -251,7 +255,7 @@ def get_const_nets(nodes: list[Node]) -> list[Net]:
List of nets whose source node is a Const List of nets whose source node is a Const
""" """
net_lookup = {net.source: net for node in nodes for net in node.args} net_lookup = {net.source: net for node in nodes for net in node.args}
return [net_lookup[node] for node in nodes if isinstance(node, Const)] return [net_lookup[node] for node in nodes if isinstance(node, InitVar)]
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
@ -272,7 +276,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
net_lookup = {net.source: net for node in node_list for net in node.args} net_lookup = {net.source: net for node in node_list for net in node.args}
for node in node_list: for node in node_list:
if not node.name.startswith('const_'): if not isinstance(node, InitVar):
for i, net in enumerate(node.args): for i, net in enumerate(node.args):
if id(net) != id(registers[i]): if id(net) != id(registers[i]):
#if net in registers: #if net in registers:
@ -365,7 +369,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
for net, out_offs, lengths in object_list: for net, out_offs, lengths in object_list:
variables[net] = (out_offs, lengths, net.dtype) variables[net] = (out_offs, lengths, net.dtype)
if isinstance(net.source, Const): if isinstance(net.source, InitVar):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(out_offs)
dw.write_int(lengths) dw.write_int(lengths)

View File

@ -1,4 +1,4 @@
from copapy import Write, const from copapy import Write, CPVariable
import copapy as rc import copapy as rc
@ -20,8 +20,8 @@ def test_ast_generation():
#r2 = i1 + 9 #r2 = i1 + 9
#out = [Write(r1), Write(r2)] #out = [Write(r1), Write(r2)]
c1 = const(4) c1 = CPVariable(4)
c2 = const(2) c2 = CPVariable(2)
#i1 = c1 * 2 #i1 = c1 * 2
#r1 = i1 + 7 + (c2 + 7 * 9) #r1 = i1 + 7 + (c2 + 7 * 9)
#r2 = i1 + 9 #r2 = i1 + 9

View File

@ -1,4 +1,4 @@
from copapy import Write, const from copapy import Write, CPVariable
import copapy import copapy
import subprocess import subprocess
import struct import struct
@ -46,8 +46,8 @@ def function(c1, c2):
def test_compile(): def test_compile():
c1 = const(4) c1 = CPVariable(4)
c2 = const(2) c2 = CPVariable(2)
ret = function(c1, c2) ret = function(c1, c2)

View File

@ -1,4 +1,4 @@
from copapy import Write, const from copapy import Write, CPVariable
import copapy import copapy
import subprocess import subprocess
import struct import struct
@ -19,7 +19,7 @@ def function(c1):
def test_compile(): def test_compile():
c1 = const(16) c1 = CPVariable(16)
ret = function(c1) ret = function(c1)

View File

@ -1,4 +1,4 @@
from copapy import const, Target from copapy import CPVariable, Target
from pytest import approx from pytest import approx
@ -10,7 +10,7 @@ def function(c1):
def test_compile(): def test_compile():
c1 = const(16) c1 = CPVariable(16)
ret = function(c1) ret = function(c1)

View File

@ -1,13 +1,13 @@
from coparun_module import coparun from coparun_module import coparun
from copapy import Write, const from copapy import Write, CPVariable
import copapy import copapy
from copapy import binwrite from copapy import binwrite
def test_compile(): def test_compile():
c1 = const(4) c1 = CPVariable(4)
c2 = const(2) * 4 c2 = CPVariable(2) * 4
i1 = c2 * 2 i1 = c2 * 2
r1 = i1 + 7 + (c1 + 7 * 9) r1 = i1 + 7 + (c1 + 7 * 9)

View File

@ -1,4 +1,4 @@
from copapy import Write, const from copapy import Write, CPVariable
import copapy import copapy
import subprocess import subprocess
from copapy import binwrite from copapy import binwrite
@ -23,8 +23,8 @@ def function(c1, c2):
def test_compile(): def test_compile():
c1 = const(4) c1 = CPVariable(4)
c2 = const(2) c2 = CPVariable(2)
ret = function(c1, c2) ret = function(c1, c2)