Seems to work now! The results for the examples match the expected values.

Nicolas 2025-10-02 23:23:07 +02:00
parent a5f0274665
commit 3476fc8980
6 changed files with 89 additions and 164 deletions

View File

@@ -146,7 +146,13 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
 def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
-    # edges: list of (u, v) pairs meaning u -> v
+    """Perform a stable topological sort on a directed acyclic graph (DAG).
+
+    Arguments:
+        edges: Iterable of (u, v) pairs meaning u -> v
+    Returns:
+        List of nodes in topologically sorted order.
+    """
     # Track adjacency and indegrees
     adj: defaultdict[Node, list[Node]] = defaultdict(list)
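
The new docstring states the contract; the function body itself lies outside this hunk. For orientation, a stable Kahn-style topological sort (output order deterministic for a given edge order) can be sketched as follows; this is an illustrative stand-alone sketch with generic node types, not copapy's implementation:

    from collections import defaultdict, deque
    from collections.abc import Hashable, Iterable

    def toposort_sketch(edges: Iterable[tuple[Hashable, Hashable]]) -> list[Hashable]:
        # Record nodes in first-seen order so ties are broken deterministically.
        seen: dict[Hashable, None] = {}
        adj: defaultdict[Hashable, list[Hashable]] = defaultdict(list)
        indeg: defaultdict[Hashable, int] = defaultdict(int)
        for u, v in edges:
            seen.setdefault(u, None)
            seen.setdefault(v, None)
            adj[u].append(v)
            indeg[v] += 1
        # Kahn's algorithm: repeatedly emit nodes whose remaining indegree is zero.
        ready = deque(n for n in seen if indeg[n] == 0)
        result: list[Hashable] = []
        while ready:
            n = ready.popleft()
            result.append(n)
            for m in adj[n]:
                indeg[m] -= 1
                if indeg[m] == 0:
                    ready.append(m)
        return result

    # toposort_sketch([('a', 'b'), ('a', 'c'), ('c', 'b')]) -> ['a', 'c', 'b']
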
@@ -188,75 +194,41 @@ def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
 def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
+    """Get all edges in the DAG by traversing from the given nodes
+
+    Arguments:
+        nodes: Iterable of nodes to start the traversal from
+    Yields:
+        Tuples of (source_node, target_node) representing edges in the DAG
+    """
     for node in nodes:
         yield from get_all_dag_edges(net.source for net in node.args)
         yield from ((net.source, node) for net in node.args)


-def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
-    """List of all possible paths. Ops in order of execution (output at the end)
-    """
-    def rekursiv_node_search(node_list: Iterable[Node], path: list[Node]) -> Generator[list[Node], None, None]:
-        for node in node_list:
-            new_path = [node] + path
-            if node.args:
-                yield from rekursiv_node_search([net.source for net in node.args], new_path)
-            else:
-                yield new_path
-
-    known_nodes: set[Node] = set()
-    sorted_path_list = sorted(rekursiv_node_search(root, []), key=lambda x: -len(x))
-
-    for path in sorted_path_list:
-        sflag = False
-        for i, net in enumerate(path):
-            if net in known_nodes or i == len(path) - 1:
-                if sflag:
-                    if i > 0:
-                        yield path[:i+1]
-                    break
-            else:
-                sflag = True
-            known_nodes.add(net)
-
-
-def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, None]:
-    """Merge in all tree branches at branch position into the path segments
-    """
-    finished_paths: set[int] = set()
-    for i, path in enumerate(path_segments):
-        if i not in finished_paths:
-            for op in path:
-                for j in range(i + 1, len(path_segments)):
-                    path_stub = path_segments[j]
-                    if op == path_stub[-1]:
-                        print(op)
-                        for insert_op in path_stub[:-1]:
-                            print('->', insert_op)
-                            yield insert_op
-                        finished_paths.add(j)
-                print('- ', op)
-                yield op
-            finished_paths.add(i)
-
-
-def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
-    """Get all const nodes in the op list
-
-    Returns:
-        List of tuples of (name, net, value)"""
-    net_lookup = {net.source: net for op in op_list for net in op.args}
-    return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
+def get_const_nets(nodes: list[Node]) -> list[Net]:
+    """Get all nets with a constant nodes value
+
+    Returns:
+        List of nets whose source node is a Const
+    """
+    net_lookup = {net.source: net for node in nodes for net in node.args}
+    return [net_lookup[node] for node in nodes if isinstance(node, Const)]


 def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
-    """Add read operation before each op where arguments are not already positioned
+    """Add read node before each op where arguments are not already positioned
     correctly in the registers
+
+    Arguments:
+        node_list: List of nodes in the order of execution
+
     Returns:
-        Yields tuples of a net and a operation. The net is only provided
-        for new added read operations. Otherwise None is returned in the tuple."""
+        Yields tuples of a net and a node. The net is the result net
+        for the node. If the node has no result net None is returned in the tuple.
+    """
     registers: list[None | Net] = [None] * 2

     # Generate result net lookup table
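
get_const_nets relies on every Net pointing back at its producing node via net.source, so a reverse map from source node to net resolves each Const to the net that carries its value. A toy illustration of that lookup pattern with simplified stand-in classes (not copapy's real Node/Net; the 'leaf node means constant' test below stands in for isinstance(node, Const)):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class _Node:                # stand-in for a copapy Node
        name: str
        args: tuple = ()        # incoming nets (empty for constants)

    @dataclass(frozen=True)
    class _Net:                 # stand-in for a copapy Net
        source: _Node

    const_4 = _Node('const_4')                    # constant leaf node
    mul = _Node('mul', args=(_Net(const_4),))     # op consuming the constant's net

    nodes = [const_4, mul]
    net_lookup = {net.source: net for node in nodes for net in node.args}
    const_nets = [net_lookup[n] for n in nodes if not n.args]   # [_Net(source=const_4)]
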
@@ -269,37 +241,42 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
             #if net in registers:
             #    print('x swap registers')
             type_list = ['int' if r is None else r.dtype for r in registers]
-            print(net, type_list)
             new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
             yield net, new_node
             registers[i] = net

         if node in net_lookup:
-            yield None, node
+            yield net_lookup[node], node
             registers[0] = net_lookup[node]
         else:
-            print('>', node)
             yield None, node


-def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
-    """Add write operation for each new defined net if a read operation is later followed"""
+def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list[Net]) -> Generator[tuple[Net | None, Node], None, None]:
+    """Add write operation for each new defined net if a read operation is later followed
+
+    Returns:
+        Yields tuples of a net and a node. The associated net is provided for read and write nodes.
+        Otherwise None is returned in the tuple.
+    """
     # Initialize set of nets with constants
-    stored_nets = {c[1] for c in const_list}
+    stored_nets = set(const_nets)

-    assert all(node.name.startswith('read_') for net, node in net_node_list if net)
-    read_back_nets = {net for net, _ in net_node_list if net}
-    print('#', read_back_nets, stored_nets)
+    #assert all(node.name.startswith('read_') for net, node in net_node_list if net)
+    read_back_nets = {
+        net for net, node in net_node_list
+        if net and node.name.startswith('read_')}

     for net, node in net_node_list:
         if isinstance(node, Write):
             yield node.args[0], node
-        else:
+        elif node.name.startswith('read_'):
             yield net, node
+        else:
+            yield None, node

-        if net and net in read_back_nets and net not in stored_nets:
-            print('> add Write')
+        if net in read_back_nets and net not in stored_nets:
             yield net, Write(net)
             stored_nets.add(net)
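
The rule add_write_ops encodes can be stated independently of the copapy types: a freshly produced net only needs a write-back if some later read op loads it again, and nets backed by constants are already stored in memory. A generic sketch of just that rule (illustrative names, not copapy's API):

    from collections.abc import Generator, Hashable, Iterable

    def insert_writebacks(
            schedule: Iterable[tuple[Hashable | None, str]],
            read_back: set[Hashable],
            already_stored: set[Hashable]) -> Generator[tuple[Hashable | None, str], None, None]:
        """schedule yields (net, op) pairs; read_back holds nets a later read
        op will load; already_stored holds nets already in memory (constants)."""
        stored = set(already_stored)
        for net, op in schedule:
            yield net, op
            if net is not None and net in read_back and net not in stored:
                yield net, f'write {net}'   # spill the freshly produced net exactly once
                stored.add(net)
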
@@ -323,15 +300,13 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
     node_list = list(end_nodes)

     ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
-    const_list = get_consts(ordered_ops)
+    const_net_list = get_const_nets(ordered_ops)
     output_ops = list(add_read_ops(ordered_ops))
-    extended_output_ops = list(add_write_ops(output_ops, const_list))
-
-    for net, node in extended_output_ops:
-        print(node.name)
+    extended_output_ops = list(add_write_ops(output_ops, const_net_list))

     # Get all nets associated with heap memory
-    variable_list = get_nets(const_list, extended_output_ops)
+    variable_list = get_nets([[const_net_list]], extended_output_ops)
+    assert(len(set(variable_list)) == len(variable_list)), 'Duplicates!'

     dw = binw.data_writer(sdb.byteorder)
@@ -359,7 +334,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
         dw.write_int(out_offs)
         dw.write_int(lengths)
         dw.write_value(net.source.value, lengths)
-        print(f'+ {net.dtype} {net.source.value}')
+        # print(f'+ {net.dtype} {net.source.value}')

     # write auxiliary_functions
     # TODO
@@ -370,23 +345,23 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
     patch_list: list[tuple[int, int, int]] = []
     offset = 0 # offset in generated code chunk
-    print('object_addr_lookp: ', object_addr_lookp)
+    # print('object_addr_lookp: ', object_addr_lookp)

     data = sdb.get_func_data('function_start')
     data_list.append(data)
     offset += len(data)

-    for result_net, node in extended_output_ops:
+    for associated_net, node in extended_output_ops:
         assert node.name in sdb.function_definitions, f"- Warning: {node.name} prototype not found"
         data = sdb.get_func_data(node.name)
         data_list.append(data)
-        print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
+        # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
         for patch in sdb.get_patch_positions(node.name):
-            assert result_net, f"Relocation found but no net defined for operation {node.name}"
-            object_addr = object_addr_lookp[result_net]
+            assert associated_net, f"Relocation found but no net defined for operation {node.name}"
+            object_addr = object_addr_lookp[associated_net]
             patch_value = object_addr + patch.addend - (offset + patch.addr)
-            #print('patch: ', patch, object_addr, patch_value)
+            # print('patch: ', patch, object_addr, patch_value)
             patch_list.append((patch.type.value, offset + patch.addr, patch_value))
         offset += len(data)
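
patch_value above is a relative relocation: the target object's address plus the relocation addend, minus the address of the patch site inside the generated code chunk. A worked example with made-up numbers, purely for illustration:

    object_addr = 0x200          # where the net's data object was placed
    patch_site = 0x40            # offset + patch.addr: position of the field being patched
    addend = -4                  # addend taken from the relocation entry

    patch_value = object_addr + addend - patch_site
    assert patch_value == 0x1BC  # relative displacement stored at the patch site
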
@@ -394,7 +369,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
     data = sdb.get_func_data('function_end')
     data_list.append(data)
     offset += len(data)
-    print('function_end', offset, data)
+    # print('function_end', offset, data)

     # allocate program data
     dw.write_com(binw.Command.ALLOCATE_CODE)
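
Taken together, compile_to_instruction_list now drives the pipeline as toposort, const nets, read ops, write ops before any code is emitted. Condensed into a usage sketch assembled only from calls that appear elsewhere in this commit (the same sequence the updated test exercises):

    from copapy import Write, const
    import copapy as rc

    c1 = const(4)
    i1 = c1 * 2
    out = [Write(i1 + 7), Write(i1 + 9)]                                      # end nodes of the DAG
    ordered_ops = list(rc.stable_toposort(rc.get_all_dag_edges(out)))         # ops in execution order
    const_net_list = rc.get_const_nets(ordered_ops)                           # nets fed by Const nodes
    output_ops = list(rc.add_read_ops(ordered_ops))                           # insert register reads
    extended_output_ops = list(rc.add_write_ops(output_ops, const_net_list))  # insert write-backs
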

View File

@@ -7,7 +7,7 @@ op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
 def get_function_start() -> str:
     return """
 int function_start(){
-    result_int(0);
+    result_int(0); // dummy call instruction before marker gets striped
     asm volatile (".long 0xF27ECAFE");
     return 1;
 }
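
function_start embeds the literal word 0xF27ECAFE as a marker so the harvesting step can cut away the compiled prologue up to and including the dummy call in front of it. That harvesting code is not part of this commit; purely as an illustration of the idea, locating and cutting at such a marker in little-endian object bytes could look like:

    import struct

    MARKER = struct.pack('<I', 0xF27ECAFE)   # marker bytes as laid out in little-endian code

    def strip_prologue(code: bytes) -> bytes:
        """Drop everything up to and including the marker word (illustrative only)."""
        pos = code.find(MARKER)
        if pos < 0:
            raise ValueError('marker not found')
        return code[pos + len(MARKER):]
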

View File

@@ -224,7 +224,7 @@
 }

 int function_start(){
-    result_int(0);
+    result_int(0); // dummy call instruction before marker gets striped
     asm volatile (".long 0xF27ECAFE");
     return 1;
 }

View File

@@ -1,5 +1,6 @@
-from copapy import Write, const
+from copapy import Write, const, Node
 import copapy as rc
+from typing import Iterable, Generator


 def test_ast_generation():
@@ -20,30 +21,31 @@ def test_ast_generation():
     out = [Write(r1), Write(r2)]
     print(out)

-    print('-- get_path_segments:')
-    path_segments = list(rc.get_path_segments(out))
-    for p in path_segments:
-        print(p)
+    print('-- get_edges:')
+    edges = list(rc.get_all_dag_edges(out))
+    for p in edges:
+        print('#', p)

     print('-- get_ordered_ops:')
-    ordered_ops = list(rc.get_ordered_ops(path_segments))
+    ordered_ops = list(rc.stable_toposort(edges))
     for p in ordered_ops:
-        print(p)
+        print('#', p)

     print('-- get_consts:')
-    const_list = rc.get_consts(ordered_ops)
+    const_list = rc.get_const_nets(ordered_ops)
     for p in const_list:
-        print(p)
+        print('#', p)

     print('-- add_read_ops:')
     output_ops = list(rc.add_read_ops(ordered_ops))
     for p in output_ops:
-        print(p)
+        print('#', p)

     print('-- add_write_ops:')
     extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
     for p in extended_output_ops:
-        print(p)
+        print('#', p)

     print('--')

View File

@@ -1,53 +0,0 @@
-from copapy import Write, const, Node
-import copapy as rc
-from typing import Iterable, Generator
-
-
-def test_ast_generation():
-    #c1 = const(1.11)
-    #c2 = const(2.22)
-    #c3 = const(3.33)
-    #i1 = c1 + c2
-    #i2 = c2 * i1
-    #i3 = i2 + 4
-    #r1 = i1 + i3
-    #r2 = i3 * i2
-
-    c1 = const(4)
-    i1 = c1 * 2
-    r1 = i1 + 7
-    r2 = i1 + 9
-
-    out = [Write(r1), Write(r2)]
-    print(out)
-
-    print('-- get_edges:')
-    edges = list(rc.get_all_dag_edges(out))
-    for p in edges:
-        print('#', p)
-
-    print('-- get_ordered_ops:')
-    ordered_ops = list(rc.stable_toposort(edges))
-    for p in ordered_ops:
-        print('#', p)
-
-    print('-- get_consts:')
-    const_list = rc.get_consts(ordered_ops)
-    for p in const_list:
-        print('#', p)
-
-    print('-- add_read_ops:')
-    output_ops = list(rc.add_read_ops(ordered_ops))
-    for p in output_ops:
-        print('#', p)
-
-    print('-- add_write_ops:')
-    extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
-    for p in extended_output_ops:
-        print('#', p)
-
-    print('--')
-
-
-if __name__ == "__main__":
-    test_ast_generation()

View File

@@ -14,25 +14,23 @@ def run_command(command: list[str], encoding: str = 'utf8') -> str:
 def test_example():
-    c1 = 1.11
-    c2 = 2.22
+    c1 = 4
+    c2 = 2
     i1 = c1 * 2
-    i2 = i1 + 3
-    r1 = i1 + i2
-    r2 = c2 + 4 + c1
+    r1 = i1 + 7 + (c2 + 7 * 9)
+    r2 = i1 + 9

     en = {'little': '<', 'big': '>'}['little']
-    data = struct.pack(en + 'f', r1)
+    data = struct.pack(en + 'i', r1)
     print("example r1 " + ' '.join(f'{b:02X}' for b in data))
-    data = struct.pack(en + 'f', r2)
+    data = struct.pack(en + 'i', r2)
     print("example r2 " + ' '.join(f'{b:02X}' for b in data))

     # assert False
-    # example r1 7B 14 EE 40
-    # example r2 5C 8F EA 40
+    #example r1 42 A0 00 00
+    #example r2 41 88 00 00


 def test_compile():
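
For reference, the new example arithmetic gives r1 = 80 and r2 = 17, and the hex strings in the updated reference comments happen to be the big-endian IEEE-754 float encodings of 80.0 and 17.0. A quick stand-alone check (plain Python 3.8+ for bytes.hex with a separator, independent of copapy):

    import struct

    c1, c2 = 4, 2
    i1 = c1 * 2                     # 8
    r1 = i1 + 7 + (c2 + 7 * 9)      # 8 + 7 + (2 + 63) = 80
    r2 = i1 + 9                     # 17

    assert struct.pack('>f', float(r1)).hex(' ').upper() == '42 A0 00 00'
    assert struct.pack('>f', float(r2)).hex(' ').upper() == '41 88 00 00'
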
@@ -53,7 +51,7 @@ def test_compile():
     c1 = const(4)
     c2 = const(2)
     i1 = c1 * 2
-    r1 = i1 + 7
+    r1 = i1 + 7 + (c2 + 7 * 9)
     r2 = i1 + 9

     out = [Write(r1), Write(r2)]
@@ -69,15 +67,18 @@ def test_compile():
     # run program command
     il.write_com(binw.Command.END_PROG)

-    print('#', il.print())
+    print('* Data to runner:')
+    il.print()

     il.to_file('test.copapy')

     result = run_command(['./bin/runmem2', 'test.copapy'])
+    print('* Output from runner:')
     print(result)

     assert 'Return value: 1' in result


 if __name__ == "__main__":
+    #test_example()
     test_compile()