diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 492442b..e2c7b17 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -2,6 +2,7 @@ from typing import Generator, Iterable, Any from . import binwrite as binw from .stencil_db import stencil_database +from collections import defaultdict, deque Operand = type['Net'] | float | int @@ -144,6 +145,54 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d: return vec3d((const(x), const(y), const(z))) +def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: + # edges: list of (u, v) pairs meaning u -> v + + # Track adjacency and indegrees + adj: defaultdict[Node, list[Node]] = defaultdict(list) + indeg: defaultdict[Node, int] = defaultdict(int) + order: dict[Node, int] = {} # first-appearance order of each node + + # Build graph and order map + pos = 0 + for u, v in edges: + if u not in order: + order[u] = pos; pos += 1 + if v not in order: + order[v] = pos; pos += 1 + adj[u].append(v) + indeg[v] += 1 + indeg.setdefault(u, 0) + + # Initialize queue with nodes of indegree 0, sorted by first appearance + queue = deque(sorted([n for n in indeg if indeg[n] == 0], key=lambda x: order[x])) + result: list[Node] = [] + + while queue: + node = queue.popleft() + result.append(node) + + for nei in adj[node]: + indeg[nei] -= 1 + if indeg[nei] == 0: + queue.append(nei) + + # Maintain stability: sort queue by appearance order + queue = deque(sorted(queue, key=lambda x: order[x])) + + # Check if graph had a cycle (not all nodes output) + if len(result) != len(indeg): + raise ValueError("Graph contains a cycle — topological sort not possible") + + return result + + +def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]: + for node in nodes: + yield from get_all_dag_edges(net.source for net in node.args) + yield from ((net.source, node) for net in node.args) + + def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: """List of all possible paths. Ops in order of execution (output at the end) """ @@ -182,11 +231,12 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No for j in range(i + 1, len(path_segments)): path_stub = path_segments[j] if op == path_stub[-1]: + print(op) for insert_op in path_stub[:-1]: - #print('->', insert_op) + print('->', insert_op) yield insert_op finished_paths.add(j) - #print('- ', op) + print('- ', op) yield op finished_paths.add(i) @@ -219,7 +269,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No #if net in registers: # print('x swap registers') type_list = ['int' if r is None else r.dtype for r in registers] - print(type_list) + print(net, type_list) new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), []) yield net, new_node registers[i] = net @@ -228,7 +278,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No yield None, node registers[0] = net_lookup[node] else: - print('--->', node) + print('>', node) yield None, node @@ -240,6 +290,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list assert all(node.name.startswith('read_') for net, node in net_node_list if net) read_back_nets = {net for net, _ in net_node_list if net} + print('#', read_back_nets, stored_nets) for net, node in net_node_list: if isinstance(node, Write): @@ -248,6 +299,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list yield net, node if net and net in read_back_nets and net not in stored_nets: + print('> add Write') yield net, Write(net) stored_nets.add(net) @@ -270,8 +322,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w else: node_list = list(end_nodes) - path_segments = list(get_path_segments(node_list)) - ordered_ops = list(get_ordered_ops(path_segments)) + ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) const_list = get_consts(ordered_ops) output_ops = list(add_read_ops(ordered_ops)) extended_output_ops = list(add_write_ops(output_ops, const_list)) @@ -335,7 +386,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w assert result_net, f"Relocation found but no net defined for operation {node.name}" object_addr = object_addr_lookp[result_net] patch_value = object_addr + patch.addend - (offset + patch.addr) - print('patch: ', patch, object_addr, patch_value) + #print('patch: ', patch, object_addr, patch_value) patch_list.append((patch.type.value, offset + patch.addr, patch_value)) offset += len(data) diff --git a/tests/test_ast_gen.py b/tests/test_ast_gen.py index c908bc5..c4769a5 100644 --- a/tests/test_ast_gen.py +++ b/tests/test_ast_gen.py @@ -3,7 +3,7 @@ import copapy as rc def test_ast_generation(): - c1 = const(1.11) + #c1 = const(1.11) #c2 = const(2.22) #c3 = const(3.33) @@ -13,19 +13,21 @@ def test_ast_generation(): #r1 = i1 + i3 #r2 = i3 * i2 + c1 = const(4) i1 = c1 * 2 r1 = i1 + 7 - out = Write(r1) + r2 = i1 + 9 + out = [Write(r1), Write(r2)] print(out) print('-- get_path_segments:') - path_segments = list(rc.get_path_segments([out])) + path_segments = list(rc.get_path_segments(out)) for p in path_segments: print(p) print('-- get_ordered_ops:') ordered_ops = list(rc.get_ordered_ops(path_segments)) - for p in path_segments: + for p in ordered_ops: print(p) print('-- get_consts:') diff --git a/tests/test_ast_gen2.py b/tests/test_ast_gen2.py new file mode 100644 index 0000000..460b750 --- /dev/null +++ b/tests/test_ast_gen2.py @@ -0,0 +1,53 @@ +from copapy import Write, const, Node +import copapy as rc +from typing import Iterable, Generator + + +def test_ast_generation(): + #c1 = const(1.11) + #c2 = const(2.22) + #c3 = const(3.33) + + #i1 = c1 + c2 + #i2 = c2 * i1 + #i3 = i2 + 4 + + #r1 = i1 + i3 + #r2 = i3 * i2 + c1 = const(4) + i1 = c1 * 2 + r1 = i1 + 7 + r2 = i1 + 9 + out = [Write(r1), Write(r2)] + + print(out) + print('-- get_edges:') + + edges = list(rc.get_all_dag_edges(out)) + for p in edges: + print('#', p) + + print('-- get_ordered_ops:') + ordered_ops = list(rc.stable_toposort(edges)) + for p in ordered_ops: + print('#', p) + + print('-- get_consts:') + const_list = rc.get_consts(ordered_ops) + for p in const_list: + print('#', p) + + print('-- add_read_ops:') + output_ops = list(rc.add_read_ops(ordered_ops)) + for p in output_ops: + print('#', p) + + print('-- add_write_ops:') + extended_output_ops = list(rc.add_write_ops(output_ops, const_list)) + for p in extended_output_ops: + print('#', p) + print('--') + + +if __name__ == "__main__": + test_ast_generation() diff --git a/tests/test_compile.py b/tests/test_compile.py index a8aefa8..72f3025 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -39,8 +39,8 @@ def test_compile(): print(run_command(['bash', 'build.sh'])) - c1 = const(4) - #c2 = const(2) + #c1 = const(1.11) + #c2 = const(2.22) #i1 = c1 * 2 #i2 = i1 + 3 @@ -50,14 +50,17 @@ def test_compile(): #out = [Write(r1), Write(r2)] + c1 = const(4) + c2 = const(2) i1 = c1 * 2 r1 = i1 + 7 - out = Write(r1) + r2 = i1 + 9 + out = [Write(r1), Write(r2)] il = copapy.compile_to_instruction_list(out) - #copapy.read_variable(il, i1) copapy.read_variable(il, r1) + copapy.read_variable(il, r2) il.write_com(binw.Command.READ_DATA) il.write_int(0)