From ba4531ee6971a86322f1f6dafe60fad1538062a0 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Thu, 23 Oct 2025 23:24:57 +0200 Subject: [PATCH] support for stencils using heap stored constants added --- src/copapy/__init__.py | 2 +- src/copapy/_compiler.py | 70 ++++++++++++++++++++++++++++++----------- src/copapy/_stencils.py | 26 +++++++++++++-- src/copapy/_target.py | 2 +- src/copapy/backend.py | 2 +- tests/test_ext_ops.py | 6 ++-- 6 files changed, 82 insertions(+), 26 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index e2dcd01..2957265 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -13,4 +13,4 @@ __all__ = [ "cpvector", "generic_sdb", "iif", -] \ No newline at end of file +] diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 8019855..9fa2067 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -1,6 +1,6 @@ from typing import Generator, Iterable, Any from . import _binwrite as binw -from ._stencils import stencil_database +from ._stencils import stencil_database, patch_entry from collections import defaultdict, deque from ._basic_types import Net, Node, Write, InitVar, Op, transl_type @@ -155,8 +155,7 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: return list(nets) -def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]: - offset: int = 0 +def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]: object_list: list[tuple[Net, int, int]] = [] for variable in variable_list: @@ -167,8 +166,22 @@ def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) return object_list, offset -def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]: - offset: int = 0 +def get_target_sym_lookup(function_names: Iterable[str], sdb: stencil_database) -> dict[str, patch_entry]: + return {patch.target_symbol_name: patch for name in set(function_names) for patch in sdb.get_patch_positions(name)} + + +def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]: + section_list: list[tuple[int, int, int]] = [] + + for id in section_indexes: + lengths = sdb.get_section_size(id) + section_list.append((id, offset, lengths)) + offset += (lengths + 3) // 4 * 4 + + return section_list, offset + + +def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]: function_list: list[tuple[str, int, int]] = [] for name in function_names: @@ -181,6 +194,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: variables: dict[Net, tuple[int, int, str]] = dict() + data_list: list[bytes] = [] + patch_list: list[tuple[int, int, int, binw.Command]] = [] ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) const_net_list = get_const_nets(ordered_ops) @@ -195,11 +210,24 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database # Get all nets/variables associated with heap memory variable_list = get_nets([[const_net_list]], extended_output_ops) - # Write data - variable_mem_layout, data_section_lengths = get_variable_mem_layout(variable_list, sdb) - dw.write_com(binw.Command.ALLOCATE_DATA) - dw.write_int(data_section_lengths) + stencil_names = [node.name for _, node in extended_output_ops] + aux_function_names = sdb.get_sub_functions(stencil_names) + used_sections = sdb.const_sections_from_functions(aux_function_names | set(stencil_names)) + # Write data + section_mem_layout, sections_length = get_section_layout(used_sections, sdb) + variable_mem_layout, variables_data_lengths = get_data_layout(variable_list, sdb, sections_length) + dw.write_com(binw.Command.ALLOCATE_DATA) + dw.write_int(variables_data_lengths) + + # Heap constants + for section_id, out_offs, lengths in section_mem_layout: + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(out_offs) + dw.write_int(lengths) + dw.write_bytes(sdb.get_section_data(section_id)) + + # Heap variables for net, out_offs, lengths in variable_mem_layout: variables[net] = (out_offs, lengths, net.dtype) if isinstance(net.source, InitVar): @@ -210,14 +238,12 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database # print(f'+ {net.dtype} {net.source.value}') # prep auxiliary_functions - aux_function_names = sdb.get_sub_functions(node.name for _, node in extended_output_ops) aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb) aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout} # Prepare program code and relocations object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} - data_list: list[bytes] = [] - patch_list: list[tuple[int, int, int, binw.Command]] = [] + section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout} offset = aux_function_lengths # offset in generated code chunk # assemble stencils to main program @@ -229,14 +255,22 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" data = sdb.get_stencil_code(node.name) data_list.append(data) - # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) + #print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) for patch in sdb.get_patch_positions(node.name): - if patch.target_symbol_info == 'STT_OBJECT': - assert associated_net, f"Relocation found but no net defined for operation {node.name}" - addr = object_addr_lookup[associated_net] - patch_value = addr + patch.addend - (offset + patch.addr) + if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: + if patch.target_symbol_name.startswith('dummy_'): + # Patch for write and read addresses to/from heap variables + assert associated_net, f"Relocation found but no net defined for operation {node.name}" + #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") + addr = object_addr_lookup[associated_net] + patch_value = addr + patch.addend - (offset + patch.addr) + else: + # Patch constants addresses on heap + addr = section_addr_lookup[patch.target_symbol_section_index] + patch_value = addr + patch.addend - (offset + patch.addr) patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) + elif patch.target_symbol_info == 'STT_FUNC': addr = aux_func_addr_lookup[patch.target_symbol_name] patch_value = addr + patch.addend - (offset + patch.addr) @@ -277,4 +311,4 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database dw.write_com(binw.Command.ENTRY_POINT) dw.write_int(aux_function_lengths) - return dw, variables \ No newline at end of file + return dw, variables diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index cd6b750..91714dc 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -25,6 +25,7 @@ class patch_entry: addend: int target_symbol_name: str target_symbol_info: str + target_symbol_section_index: int def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType: @@ -58,8 +59,8 @@ def get_stencil_position(func: elf_symbol) -> tuple[int, int]: def get_last_call_in_function(func: elf_symbol) -> int: # Find last relocation in function + assert func.relocations, f'No call function in stencil function {func.name}.' reloc = func.relocations[-1] - assert reloc, f'No call function in stencil function {func.name}.' return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION @@ -107,6 +108,17 @@ class stencil_database(): # sym.relocations # self.elf.symbols[name].data + def const_sections_from_functions(self, symbol_names: Iterable[str]) -> list[int]: + ret: set[int] = set() + + for name in symbol_names: + for reloc in self.elf.symbols[name].relocations: + sym = reloc.symbol + if sym.section and sym.section.type == 'SHT_PROGBITS' and \ + sym.info != 'STT_FUNC' and not sym.name.startswith('dummy_'): + ret.add(sym.section.index) + return list(ret) + def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: """Return patch positions for a provided symbol (function or object) @@ -130,7 +142,11 @@ class stencil_database(): reloc.bits, reloc.fields['r_addend']) - patch = patch_entry(rtype, patch_offset, reloc.fields['r_addend'], reloc.symbol.name, reloc.symbol.info) + patch = patch_entry(rtype, patch_offset, + reloc.fields['r_addend'], + reloc.symbol.name, + reloc.symbol.info, + reloc.symbol.fields['st_shndx']) # Exclude the call to the result_* function if patch.addr < end_index - start_index: @@ -162,6 +178,12 @@ class stencil_database(): def get_symbol_size(self, name: str) -> int: return self.elf.symbols[name].fields['st_size'] + def get_section_size(self, id: int) -> int: + return self.elf.sections[id].fields['sh_size'] + + def get_section_data(self, id: int) -> bytes: + return self.elf.sections[id].data + def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: """Returns machine code for a specified function name""" func = self.elf.symbols[name] diff --git a/src/copapy/_target.py b/src/copapy/_target.py index 9f77903..7ed7944 100644 --- a/src/copapy/_target.py +++ b/src/copapy/_target.py @@ -88,4 +88,4 @@ class Target(): def read_value_remote(self, net: Net) -> None: dw = binw.data_writer(self.sdb.byteorder) add_read_command(dw, self._variables, net) - assert coparun(dw.get_data()) > 0 \ No newline at end of file + assert coparun(dw.get_data()) > 0 diff --git a/src/copapy/backend.py b/src/copapy/backend.py index 447a352..3fca0ed 100644 --- a/src/copapy/backend.py +++ b/src/copapy/backend.py @@ -17,4 +17,4 @@ __all__ = [ "get_all_dag_edges", "add_read_ops", "add_write_ops", -] \ No newline at end of file +] diff --git a/tests/test_ext_ops.py b/tests/test_ext_ops.py index 8123694..009c486 100644 --- a/tests/test_ext_ops.py +++ b/tests/test_ext_ops.py @@ -5,11 +5,11 @@ import copapy def test_compile(): c_i = cpvalue(9) - c_f = cpvalue(1.111) + c_f = cpvalue(2.5) # c_b = cpvalue(True) ret_test = (c_f ** c_f, c_i ** c_i) - ret_ref = (1.111 ** 1.111, 9 ** 9) + ret_ref = (2.5 ** 2.5, 9 ** 9) tg = Target() print('* compile and copy ...') @@ -24,7 +24,7 @@ def test_compile(): print('+', val, ref, type(val), test.dtype) #for t in (int, float, bool): # assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" - assert val == pytest.approx(ref, 1e-3), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] + assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] if __name__ == "__main__":