copapy/tests/test_comp_timing.py

209 lines
8.4 KiB
Python
Raw Normal View History

2025-11-14 21:50:23 +00:00
import time
from copapy import variable
from copapy import backend
from copapy.backend import Write, stencil_db_from_package
import copapy.backend as cpbe
import copapy as cp
import copapy._binwrite as binw
from copapy._compiler import get_nets, get_section_layout, get_data_layout
2025-11-24 15:22:46 +00:00
from copapy._compiler import patch_entry, CPConstant, get_aux_func_layout
2025-11-14 21:50:23 +00:00
class _Stopwatch:
    """Context manager that prints how long its ``with`` body took.

    Uses ``time.perf_counter`` because it is monotonic and high resolution.
    ``time.time`` follows the system clock and can jump when that clock is
    adjusted, which corrupts short interval measurements.

    Output format: ``"<label>: <seconds with 6 decimals>s"``.
    """

    def __init__(self, label: str = "time") -> None:
        self.label = label
        self.start = 0.0

    def __enter__(self) -> "_Stopwatch":
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Report only on success: a failing stage should surface its
        # exception, not a misleading timing line.
        if exc_type is None:
            print(f"{self.label}: {time.perf_counter() - self.start:.6f}s")
        return False  # never swallow exceptions


def test_timing_compiler():
    """Time each stage of turning a copapy expression into patched program data.

    A 256-element vector expression is built. The test then runs, one at a
    time, the stages that process it:
    - DAG edge extraction, topological ordering, constant-net collection;
    - read/write op insertion;
    - stencil-DB loading, aux-function and constant-section discovery,
      and memory layout;
    - data emission;
    - relocation patching of stencils and of auxiliary functions.

    The duration of each timed stage is printed. The test fails only if a
    stage raises or one of the internal assertions trips.
    """
    base = cp.vector([10, 11] * 128) + cp.vector(cp.variable(v) for v in range(256))
    weights = cp.vector(cp.variable(1 / (v + 1)) for v in range(256))
    result = ((weights * base) * 2).magnitude()
    out = [Write(result)]
    print(out)

    print('-- get_edges:')
    with _Stopwatch("get_edges time"):
        edges = list(cpbe.get_all_dag_edges(out))
    print(f' found {len(edges)} edges')

    print('-- get_ordered_ops:')
    with _Stopwatch("get_ordered_ops time"):
        ordered_ops = list(cpbe.stable_toposort(edges))
    print(f' found {len(ordered_ops)} ops')

    print('-- get_consts:')
    with _Stopwatch("get_consts time"):
        const_net_list = cpbe.get_const_nets(ordered_ops)

    print('-- add_read_ops:')
    with _Stopwatch("add_read_ops time"):
        output_ops = list(cpbe.add_read_ops(ordered_ops))

    print('-- add_write_ops:')
    with _Stopwatch("add_write_ops time"):
        extended_output_ops = list(cpbe.add_write_ops(output_ops, const_net_list))
    print('--')

    print('-- load_stencil_db:')
    with _Stopwatch("load_stencil_db time"):
        sdb = stencil_db_from_package()
        dw = binw.data_writer(sdb.byteorder)

    # Get all nets/variables associated with heap memory
    variable_list = get_nets([[const_net_list]], extended_output_ops)
    stencil_names = {node.name for _, node in extended_output_ops}

    print(f'-- get_sub_functions: {len(stencil_names)}')
    with _Stopwatch():
        aux_function_names = sdb.get_sub_functions(stencil_names)

    print('-- const_sections_from_functions:')
    with _Stopwatch():
        used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names)

    print('-- get_section_layout:')
    with _Stopwatch():
        section_mem_layout, sections_length = get_section_layout(used_sections, sdb)
        # Variable layout is placed after the sections (offset by sections_length)
        variable_mem_layout, _ = get_data_layout(variable_list, sdb, sections_length)

    variables: dict[backend.Net, tuple[int, int, str]] = {}
    data_list: list[bytes] = []
    patch_list: list[patch_entry] = []

    print('-- write_data:')
    with _Stopwatch():
        # Heap constants: emit a COPY_DATA command with each used section's bytes
        for section_id, start, length in section_mem_layout:
            dw.write_com(binw.Command.COPY_DATA)
            dw.write_int(start)
            dw.write_int(length)
            dw.write_bytes(sdb.get_section_data(section_id))
        # Heap variables: record their layout; only constant-sourced nets get
        # an initial value written
        for net, start, length in variable_mem_layout:
            variables[net] = (start, length, net.dtype)
            if isinstance(net.source, CPConstant):
                dw.write_com(binw.Command.COPY_DATA)
                dw.write_int(start)
                dw.write_int(length)
                dw.write_value(net.source.value, length)

    # Prepare auxiliary function addresses and heap address lookups
    _, aux_func_addr_lookup, aux_function_lengths = get_aux_func_layout(aux_function_names, sdb)
    object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
    section_addr_lookup = {sec_id: offs for sec_id, offs, _ in section_mem_layout}

    # The program starts with the entry shell code. Stencil offsets begin after
    # the aux functions plus that entry code.
    data = sdb.get_function_code('entry_function_shell', 'start')
    data_list.append(data)
    offset = aux_function_lengths + len(data)

    print('-- relocate stencils:')
    with _Stopwatch():
        for associated_net, node in extended_output_ops:
            assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
            data = sdb.get_stencil_code(node.name)
            data_list.append(data)
            for reloc in sdb.get_relocations(node.name, stencil=True):
                if reloc.target_symbol_info in ('STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'):
                    if reloc.target_symbol_name.startswith('dummy_'):
                        # Patch for write and read addresses to/from heap variables
                        assert associated_net, f"Relocation found but no net defined for operation {node.name}"
                        obj_addr = object_addr_lookup[associated_net]
                        patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
                    elif reloc.target_symbol_name.startswith('result_'):
                        # Set return jump address to address of following stencil
                        patch = sdb.get_patch(reloc, offset + len(data), offset, binw.Command.PATCH_FUNC.value)
                    else:
                        # Patch constants addresses on heap
                        assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {node.name} missing: {reloc.pelfy_reloc.symbol.name}"
                        obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
                        patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
                elif reloc.target_symbol_info == 'STT_FUNC':
                    func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
                    patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
                else:
                    raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
                patch_list.append(patch)
            offset += len(data)

    print('-- relocate aux functions:')
    with _Stopwatch():
        for name, start in aux_func_addr_lookup.items():
            for reloc in sdb.get_relocations(name):
                if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'}:
                    # Patch constants/variable addresses on heap
                    assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {name} missing: {reloc.pelfy_reloc.symbol.name}"
                    obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
                    patch = sdb.get_patch(reloc, obj_addr, start, binw.Command.PATCH_OBJECT.value)
                elif reloc.target_symbol_info == 'STT_FUNC':
                    func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
                    patch = sdb.get_patch(reloc, func_addr, start, binw.Command.PATCH_FUNC.value)
                else:
                    raise ValueError(f"Unsupported: {name=} {reloc.target_symbol_info=} {reloc.target_symbol_name=} {reloc.target_section_index}")
                patch_list.append(patch)
if __name__ == '__main__':
    # Run the timing benchmark directly as a script, without pytest.
    test_timing_compiler()