mirror of https://github.com/Nonannet/copapy.git
stable_toposort function for stable sorting of the DAG added and get_all_dag_edges function to derive all edges from a list of nodes with now children (last nodes)
This commit is contained in:
parent
5b7ca52b7c
commit
a5f0274665
|
|
@ -2,6 +2,7 @@
|
||||||
from typing import Generator, Iterable, Any
|
from typing import Generator, Iterable, Any
|
||||||
from . import binwrite as binw
|
from . import binwrite as binw
|
||||||
from .stencil_db import stencil_database
|
from .stencil_db import stencil_database
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
Operand = type['Net'] | float | int
|
Operand = type['Net'] | float | int
|
||||||
|
|
||||||
|
|
@ -144,6 +145,54 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
|
||||||
return vec3d((const(x), const(y), const(z)))
|
return vec3d((const(x), const(y), const(z)))
|
||||||
|
|
||||||
|
|
||||||
|
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
|
||||||
|
# edges: list of (u, v) pairs meaning u -> v
|
||||||
|
|
||||||
|
# Track adjacency and indegrees
|
||||||
|
adj: defaultdict[Node, list[Node]] = defaultdict(list)
|
||||||
|
indeg: defaultdict[Node, int] = defaultdict(int)
|
||||||
|
order: dict[Node, int] = {} # first-appearance order of each node
|
||||||
|
|
||||||
|
# Build graph and order map
|
||||||
|
pos = 0
|
||||||
|
for u, v in edges:
|
||||||
|
if u not in order:
|
||||||
|
order[u] = pos; pos += 1
|
||||||
|
if v not in order:
|
||||||
|
order[v] = pos; pos += 1
|
||||||
|
adj[u].append(v)
|
||||||
|
indeg[v] += 1
|
||||||
|
indeg.setdefault(u, 0)
|
||||||
|
|
||||||
|
# Initialize queue with nodes of indegree 0, sorted by first appearance
|
||||||
|
queue = deque(sorted([n for n in indeg if indeg[n] == 0], key=lambda x: order[x]))
|
||||||
|
result: list[Node] = []
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
node = queue.popleft()
|
||||||
|
result.append(node)
|
||||||
|
|
||||||
|
for nei in adj[node]:
|
||||||
|
indeg[nei] -= 1
|
||||||
|
if indeg[nei] == 0:
|
||||||
|
queue.append(nei)
|
||||||
|
|
||||||
|
# Maintain stability: sort queue by appearance order
|
||||||
|
queue = deque(sorted(queue, key=lambda x: order[x]))
|
||||||
|
|
||||||
|
# Check if graph had a cycle (not all nodes output)
|
||||||
|
if len(result) != len(indeg):
|
||||||
|
raise ValueError("Graph contains a cycle — topological sort not possible")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
||||||
|
for node in nodes:
|
||||||
|
yield from get_all_dag_edges(net.source for net in node.args)
|
||||||
|
yield from ((net.source, node) for net in node.args)
|
||||||
|
|
||||||
|
|
||||||
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
|
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
|
||||||
"""List of all possible paths. Ops in order of execution (output at the end)
|
"""List of all possible paths. Ops in order of execution (output at the end)
|
||||||
"""
|
"""
|
||||||
|
|
@ -182,11 +231,12 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
|
||||||
for j in range(i + 1, len(path_segments)):
|
for j in range(i + 1, len(path_segments)):
|
||||||
path_stub = path_segments[j]
|
path_stub = path_segments[j]
|
||||||
if op == path_stub[-1]:
|
if op == path_stub[-1]:
|
||||||
|
print(op)
|
||||||
for insert_op in path_stub[:-1]:
|
for insert_op in path_stub[:-1]:
|
||||||
#print('->', insert_op)
|
print('->', insert_op)
|
||||||
yield insert_op
|
yield insert_op
|
||||||
finished_paths.add(j)
|
finished_paths.add(j)
|
||||||
#print('- ', op)
|
print('- ', op)
|
||||||
yield op
|
yield op
|
||||||
finished_paths.add(i)
|
finished_paths.add(i)
|
||||||
|
|
||||||
|
|
@ -219,7 +269,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
#if net in registers:
|
#if net in registers:
|
||||||
# print('x swap registers')
|
# print('x swap registers')
|
||||||
type_list = ['int' if r is None else r.dtype for r in registers]
|
type_list = ['int' if r is None else r.dtype for r in registers]
|
||||||
print(type_list)
|
print(net, type_list)
|
||||||
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
|
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
|
||||||
yield net, new_node
|
yield net, new_node
|
||||||
registers[i] = net
|
registers[i] = net
|
||||||
|
|
@ -228,7 +278,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
yield None, node
|
yield None, node
|
||||||
registers[0] = net_lookup[node]
|
registers[0] = net_lookup[node]
|
||||||
else:
|
else:
|
||||||
print('--->', node)
|
print('>', node)
|
||||||
yield None, node
|
yield None, node
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -240,6 +290,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list
|
||||||
|
|
||||||
assert all(node.name.startswith('read_') for net, node in net_node_list if net)
|
assert all(node.name.startswith('read_') for net, node in net_node_list if net)
|
||||||
read_back_nets = {net for net, _ in net_node_list if net}
|
read_back_nets = {net for net, _ in net_node_list if net}
|
||||||
|
print('#', read_back_nets, stored_nets)
|
||||||
|
|
||||||
for net, node in net_node_list:
|
for net, node in net_node_list:
|
||||||
if isinstance(node, Write):
|
if isinstance(node, Write):
|
||||||
|
|
@ -248,6 +299,7 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list
|
||||||
yield net, node
|
yield net, node
|
||||||
|
|
||||||
if net and net in read_back_nets and net not in stored_nets:
|
if net and net in read_back_nets and net not in stored_nets:
|
||||||
|
print('> add Write')
|
||||||
yield net, Write(net)
|
yield net, Write(net)
|
||||||
stored_nets.add(net)
|
stored_nets.add(net)
|
||||||
|
|
||||||
|
|
@ -270,8 +322,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
else:
|
else:
|
||||||
node_list = list(end_nodes)
|
node_list = list(end_nodes)
|
||||||
|
|
||||||
path_segments = list(get_path_segments(node_list))
|
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
|
||||||
ordered_ops = list(get_ordered_ops(path_segments))
|
|
||||||
const_list = get_consts(ordered_ops)
|
const_list = get_consts(ordered_ops)
|
||||||
output_ops = list(add_read_ops(ordered_ops))
|
output_ops = list(add_read_ops(ordered_ops))
|
||||||
extended_output_ops = list(add_write_ops(output_ops, const_list))
|
extended_output_ops = list(add_write_ops(output_ops, const_list))
|
||||||
|
|
@ -335,7 +386,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
assert result_net, f"Relocation found but no net defined for operation {node.name}"
|
assert result_net, f"Relocation found but no net defined for operation {node.name}"
|
||||||
object_addr = object_addr_lookp[result_net]
|
object_addr = object_addr_lookp[result_net]
|
||||||
patch_value = object_addr + patch.addend - (offset + patch.addr)
|
patch_value = object_addr + patch.addend - (offset + patch.addr)
|
||||||
print('patch: ', patch, object_addr, patch_value)
|
#print('patch: ', patch, object_addr, patch_value)
|
||||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
||||||
|
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import copapy as rc
|
||||||
|
|
||||||
|
|
||||||
def test_ast_generation():
|
def test_ast_generation():
|
||||||
c1 = const(1.11)
|
#c1 = const(1.11)
|
||||||
#c2 = const(2.22)
|
#c2 = const(2.22)
|
||||||
#c3 = const(3.33)
|
#c3 = const(3.33)
|
||||||
|
|
||||||
|
|
@ -13,19 +13,21 @@ def test_ast_generation():
|
||||||
|
|
||||||
#r1 = i1 + i3
|
#r1 = i1 + i3
|
||||||
#r2 = i3 * i2
|
#r2 = i3 * i2
|
||||||
|
c1 = const(4)
|
||||||
i1 = c1 * 2
|
i1 = c1 * 2
|
||||||
r1 = i1 + 7
|
r1 = i1 + 7
|
||||||
out = Write(r1)
|
r2 = i1 + 9
|
||||||
|
out = [Write(r1), Write(r2)]
|
||||||
|
|
||||||
print(out)
|
print(out)
|
||||||
print('-- get_path_segments:')
|
print('-- get_path_segments:')
|
||||||
|
|
||||||
path_segments = list(rc.get_path_segments([out]))
|
path_segments = list(rc.get_path_segments(out))
|
||||||
for p in path_segments:
|
for p in path_segments:
|
||||||
print(p)
|
print(p)
|
||||||
print('-- get_ordered_ops:')
|
print('-- get_ordered_ops:')
|
||||||
ordered_ops = list(rc.get_ordered_ops(path_segments))
|
ordered_ops = list(rc.get_ordered_ops(path_segments))
|
||||||
for p in path_segments:
|
for p in ordered_ops:
|
||||||
print(p)
|
print(p)
|
||||||
print('-- get_consts:')
|
print('-- get_consts:')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from copapy import Write, const, Node
|
||||||
|
import copapy as rc
|
||||||
|
from typing import Iterable, Generator
|
||||||
|
|
||||||
|
|
||||||
|
def test_ast_generation():
|
||||||
|
#c1 = const(1.11)
|
||||||
|
#c2 = const(2.22)
|
||||||
|
#c3 = const(3.33)
|
||||||
|
|
||||||
|
#i1 = c1 + c2
|
||||||
|
#i2 = c2 * i1
|
||||||
|
#i3 = i2 + 4
|
||||||
|
|
||||||
|
#r1 = i1 + i3
|
||||||
|
#r2 = i3 * i2
|
||||||
|
c1 = const(4)
|
||||||
|
i1 = c1 * 2
|
||||||
|
r1 = i1 + 7
|
||||||
|
r2 = i1 + 9
|
||||||
|
out = [Write(r1), Write(r2)]
|
||||||
|
|
||||||
|
print(out)
|
||||||
|
print('-- get_edges:')
|
||||||
|
|
||||||
|
edges = list(rc.get_all_dag_edges(out))
|
||||||
|
for p in edges:
|
||||||
|
print('#', p)
|
||||||
|
|
||||||
|
print('-- get_ordered_ops:')
|
||||||
|
ordered_ops = list(rc.stable_toposort(edges))
|
||||||
|
for p in ordered_ops:
|
||||||
|
print('#', p)
|
||||||
|
|
||||||
|
print('-- get_consts:')
|
||||||
|
const_list = rc.get_consts(ordered_ops)
|
||||||
|
for p in const_list:
|
||||||
|
print('#', p)
|
||||||
|
|
||||||
|
print('-- add_read_ops:')
|
||||||
|
output_ops = list(rc.add_read_ops(ordered_ops))
|
||||||
|
for p in output_ops:
|
||||||
|
print('#', p)
|
||||||
|
|
||||||
|
print('-- add_write_ops:')
|
||||||
|
extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
|
||||||
|
for p in extended_output_ops:
|
||||||
|
print('#', p)
|
||||||
|
print('--')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_ast_generation()
|
||||||
|
|
@ -39,8 +39,8 @@ def test_compile():
|
||||||
|
|
||||||
print(run_command(['bash', 'build.sh']))
|
print(run_command(['bash', 'build.sh']))
|
||||||
|
|
||||||
c1 = const(4)
|
#c1 = const(1.11)
|
||||||
#c2 = const(2)
|
#c2 = const(2.22)
|
||||||
|
|
||||||
#i1 = c1 * 2
|
#i1 = c1 * 2
|
||||||
#i2 = i1 + 3
|
#i2 = i1 + 3
|
||||||
|
|
@ -50,14 +50,17 @@ def test_compile():
|
||||||
|
|
||||||
#out = [Write(r1), Write(r2)]
|
#out = [Write(r1), Write(r2)]
|
||||||
|
|
||||||
|
c1 = const(4)
|
||||||
|
c2 = const(2)
|
||||||
i1 = c1 * 2
|
i1 = c1 * 2
|
||||||
r1 = i1 + 7
|
r1 = i1 + 7
|
||||||
out = Write(r1)
|
r2 = i1 + 9
|
||||||
|
out = [Write(r1), Write(r2)]
|
||||||
|
|
||||||
il = copapy.compile_to_instruction_list(out)
|
il = copapy.compile_to_instruction_list(out)
|
||||||
|
|
||||||
#copapy.read_variable(il, i1)
|
|
||||||
copapy.read_variable(il, r1)
|
copapy.read_variable(il, r1)
|
||||||
|
copapy.read_variable(il, r2)
|
||||||
|
|
||||||
il.write_com(binw.Command.READ_DATA)
|
il.write_com(binw.Command.READ_DATA)
|
||||||
il.write_int(0)
|
il.write_int(0)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue