stable_toposort function for stable sorting of the DAG added and get_all_dag_edges function to derive all edges from a list of nodes with now children (last nodes)

This commit is contained in:
Nicolas 2025-10-02 23:10:05 +02:00
parent 5b7ca52b7c
commit a5f0274665
4 changed files with 124 additions and 15 deletions

View File

@ -2,6 +2,7 @@
from typing import Generator, Iterable, Any
from . import binwrite as binw
from .stencil_db import stencil_database
from collections import defaultdict, deque
Operand = type['Net'] | float | int
@ -144,6 +145,54 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((const(x), const(y), const(z)))
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
# edges: list of (u, v) pairs meaning u -> v
# Track adjacency and indegrees
adj: defaultdict[Node, list[Node]] = defaultdict(list)
indeg: defaultdict[Node, int] = defaultdict(int)
order: dict[Node, int] = {} # first-appearance order of each node
# Build graph and order map
pos = 0
for u, v in edges:
if u not in order:
order[u] = pos; pos += 1
if v not in order:
order[v] = pos; pos += 1
adj[u].append(v)
indeg[v] += 1
indeg.setdefault(u, 0)
# Initialize queue with nodes of indegree 0, sorted by first appearance
queue = deque(sorted([n for n in indeg if indeg[n] == 0], key=lambda x: order[x]))
result: list[Node] = []
while queue:
node = queue.popleft()
result.append(node)
for nei in adj[node]:
indeg[nei] -= 1
if indeg[nei] == 0:
queue.append(nei)
# Maintain stability: sort queue by appearance order
queue = deque(sorted(queue, key=lambda x: order[x]))
# Check if graph had a cycle (not all nodes output)
if len(result) != len(indeg):
raise ValueError("Graph contains a cycle — topological sort not possible")
return result
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
for node in nodes:
yield from get_all_dag_edges(net.source for net in node.args)
yield from ((net.source, node) for net in node.args)
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
"""List of all possible paths. Ops in order of execution (output at the end)
"""
@ -182,11 +231,12 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
for j in range(i + 1, len(path_segments)):
path_stub = path_segments[j]
if op == path_stub[-1]:
print(op)
for insert_op in path_stub[:-1]:
#print('->', insert_op)
print('->', insert_op)
yield insert_op
finished_paths.add(j)
#print('- ', op)
print('- ', op)
yield op
finished_paths.add(i)
@ -219,7 +269,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
#if net in registers:
# print('x swap registers')
type_list = ['int' if r is None else r.dtype for r in registers]
print(type_list)
print(net, type_list)
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
yield net, new_node
registers[i] = net
@ -228,7 +278,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
yield None, node
registers[0] = net_lookup[node]
else:
print('--->', node)
print('>', node)
yield None, node
@ -240,6 +290,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list
assert all(node.name.startswith('read_') for net, node in net_node_list if net)
read_back_nets = {net for net, _ in net_node_list if net}
print('#', read_back_nets, stored_nets)
for net, node in net_node_list:
if isinstance(node, Write):
@ -248,6 +299,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list
yield net, node
if net and net in read_back_nets and net not in stored_nets:
print('> add Write')
yield net, Write(net)
stored_nets.add(net)
@ -270,8 +322,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
else:
node_list = list(end_nodes)
path_segments = list(get_path_segments(node_list))
ordered_ops = list(get_ordered_ops(path_segments))
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
const_list = get_consts(ordered_ops)
output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_list))
@ -335,7 +386,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
assert result_net, f"Relocation found but no net defined for operation {node.name}"
object_addr = object_addr_lookp[result_net]
patch_value = object_addr + patch.addend - (offset + patch.addr)
print('patch: ', patch, object_addr, patch_value)
#print('patch: ', patch, object_addr, patch_value)
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
offset += len(data)

View File

@ -3,7 +3,7 @@ import copapy as rc
def test_ast_generation():
c1 = const(1.11)
#c1 = const(1.11)
#c2 = const(2.22)
#c3 = const(3.33)
@ -13,19 +13,21 @@ def test_ast_generation():
#r1 = i1 + i3
#r2 = i3 * i2
c1 = const(4)
i1 = c1 * 2
r1 = i1 + 7
out = Write(r1)
r2 = i1 + 9
out = [Write(r1), Write(r2)]
print(out)
print('-- get_path_segments:')
path_segments = list(rc.get_path_segments([out]))
path_segments = list(rc.get_path_segments(out))
for p in path_segments:
print(p)
print('-- get_ordered_ops:')
ordered_ops = list(rc.get_ordered_ops(path_segments))
for p in path_segments:
for p in ordered_ops:
print(p)
print('-- get_consts:')

53
tests/test_ast_gen2.py Normal file
View File

@ -0,0 +1,53 @@
from copapy import Write, const, Node
import copapy as rc
from typing import Iterable, Generator
def test_ast_generation():
#c1 = const(1.11)
#c2 = const(2.22)
#c3 = const(3.33)
#i1 = c1 + c2
#i2 = c2 * i1
#i3 = i2 + 4
#r1 = i1 + i3
#r2 = i3 * i2
c1 = const(4)
i1 = c1 * 2
r1 = i1 + 7
r2 = i1 + 9
out = [Write(r1), Write(r2)]
print(out)
print('-- get_edges:')
edges = list(rc.get_all_dag_edges(out))
for p in edges:
print('#', p)
print('-- get_ordered_ops:')
ordered_ops = list(rc.stable_toposort(edges))
for p in ordered_ops:
print('#', p)
print('-- get_consts:')
const_list = rc.get_consts(ordered_ops)
for p in const_list:
print('#', p)
print('-- add_read_ops:')
output_ops = list(rc.add_read_ops(ordered_ops))
for p in output_ops:
print('#', p)
print('-- add_write_ops:')
extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
for p in extended_output_ops:
print('#', p)
print('--')
if __name__ == "__main__":
test_ast_generation()

View File

@ -39,8 +39,8 @@ def test_compile():
print(run_command(['bash', 'build.sh']))
c1 = const(4)
#c2 = const(2)
#c1 = const(1.11)
#c2 = const(2.22)
#i1 = c1 * 2
#i2 = i1 + 3
@ -50,14 +50,17 @@ def test_compile():
#out = [Write(r1), Write(r2)]
c1 = const(4)
c2 = const(2)
i1 = c1 * 2
r1 = i1 + 7
out = Write(r1)
r2 = i1 + 9
out = [Write(r1), Write(r2)]
il = copapy.compile_to_instruction_list(out)
#copapy.read_variable(il, i1)
copapy.read_variable(il, r1)
copapy.read_variable(il, r2)
il.write_com(binw.Command.READ_DATA)
il.write_int(0)