diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 90938ac..a30c9df 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -128,10 +128,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) - if typed_op not in generic_sdb.function_definitions: + if typed_op not in generic_sdb.stencil_definitions: raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") - result_type = generic_sdb.function_definitions[typed_op].split('_')[0] + result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_net = Net(result_type, Op(typed_op, arg_nets)) @@ -346,6 +346,18 @@ def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) -> return object_list, offset +def get_aux_function_mem_layout(function_names: list[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]: + offset: int = 0 + function_list: list[tuple[str, int, int]] = [] + + for name in function_names: + lengths = sdb.get_symbol_size(name) + function_list.append((name, offset, lengths)) + offset += (lengths + 3) // 4 * 4 + + return function_list, offset + + def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: variables: dict[Net, tuple[int, int, str]] = dict() @@ -354,17 +366,16 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database output_ops = list(add_read_ops(ordered_ops)) extended_output_ops = list(add_write_ops(output_ops, const_net_list)) - # Get all nets associated with heap memory - variable_list = get_nets([[const_net_list]], extended_output_ops) - dw = binw.data_writer(sdb.byteorder) - object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb) - # Deallocate old allocated memory (if existing) dw.write_com(binw.Command.FREE_MEMORY) + # Get all nets/variables associated with heap memory + variable_list = get_nets([[const_net_list]], extended_output_ops) + # Write data + object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb) dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_int(data_section_lengths) @@ -377,22 +388,33 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database dw.write_value(net.source.value, lengths) # print(f'+ {net.dtype} {net.source.value}') - # write auxiliary_functions - # TODO + # Write auxiliary_functions + object_list, data_section_lengths = get_aux_function_mem_layout(variable_list, sdb) + dw.write_com(binw.Command.ALLOCATE_DATA) + dw.write_int(data_section_lengths) + + for net, out_offs, lengths in object_list: + variables[net] = (out_offs, lengths, net.dtype) + if isinstance(net.source, InitVar): + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(out_offs) + dw.write_int(lengths) + dw.write_value(net.source.value, lengths) + # print(f'+ {net.dtype} {net.source.value}') # Prepare program code and relocations - object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list} + object_addr_lookup = {net: out_offs for net, out_offs, _ in object_list} data_list: list[bytes] = [] patch_list: list[tuple[int, int, int]] = [] offset = 0 # offset in generated code chunk # assemble stencils to main program - data = sdb.get_function_body('entry_function_shell', 'start') + data = sdb.get_function_code('entry_function_shell', 'start') data_list.append(data) offset += len(data) for associated_net, node in extended_output_ops: - assert node.name in sdb.function_definitions, f"- Warning: {node.name} stencil not found" + assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" data = sdb.get_stencil_code(node.name) data_list.append(data) # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) @@ -400,7 +422,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database for patch in sdb.get_patch_positions(node.name): if patch.target_symbol_info == 'STT_OBJECT': assert associated_net, f"Relocation found but no net defined for operation {node.name}" - object_addr = object_addr_lookp[associated_net] + object_addr = object_addr_lookup[associated_net] patch_value = object_addr + patch.addend - (offset + patch.addr) # print('patch: ', patch, object_addr, patch_value) patch_list.append((patch.type.value, offset + patch.addr, patch_value)) @@ -410,7 +432,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database offset += len(data) - data = sdb.get_function_body('entry_function_shell', 'end') + data = sdb.get_function_code('entry_function_shell', 'end') data_list.append(data) offset += len(data) # print('function_end', offset, data) diff --git a/src/copapy/stencil_db.py b/src/copapy/stencil_db.py index ae7e1b0..c06a5f4 100644 --- a/src/copapy/stencil_db.py +++ b/src/copapy/stencil_db.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pelfy import open_elf_file, elf_file, elf_symbol -from typing import Generator, Literal +from typing import Generator, Literal, Iterable from enum import Enum ByteOrder = Literal['little', 'big'] @@ -45,11 +45,9 @@ def get_return_function_type(symbol: elf_symbol) -> str: def strip_function(func: elf_symbol) -> bytes: """Return stencil code by striped stancil function""" - if func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE': - start_index, end_index = get_stencil_position(func) - return func.data[start_index:end_index] - else: - return func.data + assert func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE', f"{func.name} is not a stancil function" + start_index, end_index = get_stencil_position(func) + return func.data[start_index:end_index] def get_stencil_position(func: elf_symbol) -> tuple[int, int]: @@ -64,13 +62,15 @@ def get_last_call_in_function(func: elf_symbol) -> int: assert reloc, f'No call function in stencil function {func.name}.' return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION +def symbol_is_stencil(sym: elf_symbol) -> bool: + return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and + sym.relocations[-1].symbol.info == 'STT_NOTYPE') class stencil_database(): """A class for loading and querying a stencil database from an ELF object file Attributes: - function_definitions (dict[str, str]): dictionary of function names and their return types - data (dict[str, bytes]): dictionary of function names and their stripped code + stencil_definitions (dict[str, str]): dictionary of function names and their return types var_size (dict[str, int]): dictionary of object names and their sizes byteorder (ByteOrder): byte order of the ELF file elf (elf_file): the loaded ELF file @@ -87,21 +87,23 @@ class stencil_database(): else: self.elf = elf_file(obj_file) - self.function_definitions = {s.name: get_return_function_type(s) - for s in self.elf.symbols - if s.info == 'STT_FUNC'} - self.data = {s.name: strip_function(s) - for s in self.elf.symbols - if s.info == 'STT_FUNC'} + self.stencil_definitions = {s.name: get_return_function_type(s) + for s in self.elf.symbols + if s.info == 'STT_FUNC'} + + #self.data = {s.name: strip_function(s) + # for s in self.elf.symbols + # if s.info == 'STT_FUNC'} + self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols if s.info == 'STT_OBJECT'} self.byteorder: ByteOrder = self.elf.byteorder - for name in self.function_definitions.keys(): - sym = self.elf.symbols[name] - sym.relocations - self.elf.symbols[name].data + #for name in self.function_definitions.keys(): + # sym = self.elf.symbols[name] + # sym.relocations + # self.elf.symbols[name].data def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: """Return patch positions for a provided symbol (function or object) @@ -142,11 +144,28 @@ class stencil_database(): Striped function code """ return strip_function(self.elf.symbols[name]) + + def get_sub_functions(self, names: Iterable[str]) -> set[str]: + name_set: set[str] = set() + for name in names: + if name not in name_set: + func = self.elf.symbols[name] + for reloc in func.relocations: + name_set.add(reloc.symbol.name) + name_set |= self.get_sub_functions(name_set) + return name_set - def get_function_body(self, name: str, part: Literal['start', 'end']) -> bytes: + def get_symbol_size(self, name: str) -> int: + return self.elf.symbols[name].fields['st_size'] + + def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: + """Returns machine code for a specified function name""" func = self.elf.symbols[name] + assert func.info == 'STT_FUNC', f"{name} is not a function" index = get_last_call_in_function(func) if part == 'start': return func.data[:index] + elif part == 'end': + return func.data[index + LENGTH_CALL_INSTRUCTION:] else: - return func.data[index + LENGTH_CALL_INSTRUCTION:] \ No newline at end of file + return func.data \ No newline at end of file diff --git a/tests/test_stencil_db.py b/tests/test_stencil_db.py index 39441b4..c28442b 100644 --- a/tests/test_stencil_db.py +++ b/tests/test_stencil_db.py @@ -7,13 +7,13 @@ sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o') def test_list_symbols(): print('----') #print(sdb.function_definitions) - for sym_name in sdb.function_definitions.keys(): + for sym_name in sdb.stencil_definitions.keys(): print('\n-', sym_name) #print(list(sdb.get_patch_positions(sym_name))) def test_start_end_function(): - for sym_name in sdb.function_definitions.keys(): + for sym_name in sdb.stencil_definitions.keys(): symbol = sdb.elf.symbols[sym_name] if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE': @@ -26,7 +26,7 @@ def test_start_end_function(): def test_aux_functions(): - for sym_name in sdb.function_definitions.keys(): + for sym_name in sdb.stencil_definitions.keys(): symbol = sdb.elf.symbols[sym_name] for reloc in symbol.relocations: if reloc.symbol.info != "STT_NOTYPE": @@ -35,5 +35,4 @@ def test_aux_functions(): if __name__ == "__main__": - test_list_symbols() - test_start_end_function() + test_aux_functions() diff --git a/tools/make_example.py b/tools/make_example.py index 312557c..8257d90 100644 --- a/tools/make_example.py +++ b/tools/make_example.py @@ -8,7 +8,7 @@ def test_compile() -> None: c1 = CPVariable(9) #ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4] - ret = [c1 / 4] + ret = [c1 // 4] out = [Write(r) for r in ret]