diff --git a/src/copapy/_binwrite.py b/src/copapy/_binwrite.py index 294d0d3..f4d7272 100644 --- a/src/copapy/_binwrite.py +++ b/src/copapy/_binwrite.py @@ -39,7 +39,7 @@ class data_writer(): data = struct.pack(en + 'f', value) else: data = struct.pack(en + 'd', value) - assert len(data) == num_bytes + assert len(data) == num_bytes, (len(data), num_bytes) self.write_bytes(data) def print(self) -> None: diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 535bde1..4732e85 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -1,6 +1,6 @@ from typing import Generator, Iterable, Any from . import _binwrite as binw -from ._stencils import stencil_database +from ._stencils import stencil_database, patch_entry from collections import defaultdict, deque from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type @@ -171,7 +171,7 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: object_list: list[tuple[Net, int, int]] = [] for variable in variable_list: - lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype)) + lengths = sdb.get_type_size(transl_type(variable.dtype)) object_list.append((variable, offset, lengths)) offset += (lengths + 3) // 4 * 4 @@ -236,7 +236,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi """ variables: dict[Net, tuple[int, int, str]] = {} data_list: list[bytes] = [] - patch_list: list[tuple[int, int, int, binw.Command]] = [] + patch_list: list[patch_entry] = [] ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) const_net_list = get_const_nets(ordered_ops) @@ -297,32 +297,31 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi data_list.append(data) #print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) - for patch in sdb.get_patch_positions(node.name, stencil=True): - if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: - if patch.target_symbol_name.startswith('dummy_'): + for reloc in sdb.get_relocations(node.name, stencil=True): + if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: + #print('-- ' + reloc.target_symbol_name + ' // ' + node.name) + if reloc.target_symbol_name.startswith('dummy_'): # Patch for write and read addresses to/from heap variables assert associated_net, f"Relocation found but no net defined for operation {node.name}" #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") - addr = object_addr_lookup[associated_net] - patch_value = addr + patch.addend - (offset + patch.patch_address) - elif patch.target_symbol_name.startswith('result_'): + obj_addr = object_addr_lookup[associated_net] + patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value) + elif reloc.target_symbol_name.startswith('result_'): raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.") else: # Patch constants addresses on heap - section_addr = section_addr_lookup[patch.target_symbol_section_index] - obj_addr = section_addr + patch.target_symbol_address - patch_value = obj_addr + patch.addend - (offset + patch.patch_address) + obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index] + patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value) #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name) - patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) - #print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name) - elif patch.target_symbol_info == 'STT_FUNC': - addr = aux_func_addr_lookup[patch.target_symbol_name] - patch_value = addr + patch.addend - (offset + patch.patch_address) - patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) + elif reloc.target_symbol_info == 'STT_FUNC': + func_addr = aux_func_addr_lookup[reloc.target_symbol_name] + patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value) #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) else: - raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") + raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}") + + patch_list.append(patch) offset += len(data) @@ -334,6 +333,8 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi dw.write_com(binw.Command.ALLOCATE_CODE) dw.write_int(offset) + print('o aux: ', aux_function_mem_layout) + # write aux functions code for name, start, lengths in aux_function_mem_layout: dw.write_com(binw.Command.COPY_CODE) @@ -343,22 +344,21 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi # Patch aux functions for name, start, lengths in aux_function_mem_layout: - for patch in sdb.get_patch_positions(name): - if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: + for reloc in sdb.get_relocations(name): + if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: # Patch constants/variable addresses on heap - section_addr = section_addr_lookup[patch.target_symbol_section_index] - obj_addr = section_addr + patch.target_symbol_address - patch_value = obj_addr + patch.addend - (start + patch.patch_address) - patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) + obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index] + patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value) #print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name) - elif patch.target_symbol_info == 'STT_FUNC': - aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name] - patch_value = aux_func_addr + patch.addend - (start + patch.patch_address) - patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) + elif reloc.target_symbol_info == 'STT_FUNC': + func_addr = aux_func_addr_lookup[reloc.target_symbol_name] + patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value) else: - raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") + raise ValueError(f"Unsupported: {name} {reloc.target_symbol_info} {reloc.target_symbol_name}") + + patch_list.append(patch) #assert False, aux_function_mem_layout @@ -369,11 +369,11 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi dw.write_bytes(b''.join(data_list)) # write patch operations - for mask, patch_addr, addr, patch_command in patch_list: - dw.write_com(patch_command) - dw.write_int(patch_addr) - dw.write_int(mask) - dw.write_int(addr, signed=True) + for patch in patch_list: + dw.write_int(patch.patch_type) + dw.write_int(patch.address) + dw.write_int(patch.mask) + dw.write_int(patch.value, signed=True) dw.write_com(binw.Command.ENTRY_POINT) dw.write_int(aux_function_lengths) diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 980838e..69a9286 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -5,6 +5,22 @@ import pelfy ByteOrder = Literal['little', 'big'] + +@dataclass +class relocation_entry: + """ + A dataclass for representing a relocation entry + """ + + target_symbol_name: str + target_symbol_info: str + target_symbol_offset: int + target_section_index: int + function_offset: int + start: int + pelfy_reloc: pelfy.elf_relocation + + @dataclass class patch_entry: """ @@ -15,36 +31,9 @@ class patch_entry: type (RelocationType): relocation type""" mask: int - patch_address: int - addend: int - target_symbol_name: str - target_symbol_info: str - target_symbol_section_index: int - target_symbol_address: int - - -def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry: - if reloc.type.endswith('_PLT32') or reloc.type.endswith('_PC32'): - # S + A - P - mask = 0xFFFFFFFF # 32 bit - imm = offset - - elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'): - # S + A - P - assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64" - mask = 0x3ffffff # 26 bit - imm = offset >> 2 - assert imm < mask, "Relocation immediate value too large" - - else: - raise NotImplementedError(f"Relocation type {reloc.type} not implemented") - - return patch_entry(mask, imm, - reloc.fields['r_addend'], - reloc.symbol.name, - reloc.symbol.info, - reloc.symbol.fields['st_shndx'], - reloc.symbol.fields['st_value']) + address: int + value: int + patch_type: int def get_return_function_type(symbol: elf_symbol) -> str: @@ -72,7 +61,6 @@ def get_last_call_in_function(func: elf_symbol) -> int: # Find last relocation in function assert func.relocations, f'No call function in stencil function {func.name}.' reloc = func.relocations[-1] - print(f"reloc.fields['r_addend'] {reloc.fields['r_addend']}") instruction_lenghs = 4 if reloc.bits < 32 else 5 return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs @@ -140,16 +128,7 @@ class stencil_database(): ret.add(sym.section.index) return list(ret) - def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]: - """Return patch positions for a provided symbol (function or object) - - Args: - symbol_name: function or object name - - Yields: - patch_entry: every relocation for the symbol - """ - arm_hi_byte_flag: bool = False + def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]: symbol = self.elf.symbols[symbol_name] if stencil: start_index, end_index = get_stencil_position(symbol) @@ -159,18 +138,58 @@ class stencil_database(): print('->', symbol_name) for reloc in symbol.relocations: - print(' ', symbol_name, arm_hi_byte_flag, reloc.symbol.info) + print(' ', symbol_name, reloc.symbol.info, reloc.symbol.name, reloc.type) + # address to fist byte to patch relative to the start of the symbol patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index if patch_offset < end_index - start_index: # Exclude the call to the result_* function - if reloc.symbol.info == 'STT_SECTION': - arm_hi_byte_flag = True - else: - assert not arm_hi_byte_flag, "Page based relocation for ARM not supported" - # address to fist byte to patch relative to the start of the symbol + yield relocation_entry(reloc.symbol.name, + reloc.symbol.info, + reloc.symbol.fields['st_value'], + reloc.symbol.fields['st_shndx'], + symbol.fields['st_value'], + start_index, + reloc) + + def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry: + """Return patch positions for a provided symbol (function or object) + + Args: + relocation: relocation entry + symbol_address: absolute address of the target symbol + function_offset: absolute address of the first byte of the + function the patch is applied to + + Yields: + patch_entry: every relocation for the symbol + """ + pr = relocation.pelfy_reloc + + # calculate absolut address to the first byte to patch + # relative to the start of the (stripped stencil) function: + patch_offset = pr.fields['r_offset'] - relocation.function_offset - relocation.start + function_offset + + if pr.type.endswith('_PLT32') or pr.type.endswith('_PC32'): + # S + A - P + mask = 0xFFFFFFFF # 32 bit + patch_value = symbol_address + pr.fields['r_addend'] - patch_offset + + print(f'** {patch_offset=} {relocation.target_symbol_name=} {pr.fields['r_offset']=} {relocation.function_offset=} {relocation.start=} {function_offset=}') + print(f' {patch_value=} {symbol_address=} {pr.fields['r_addend']=}, {function_offset=}') + + #elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'): + # # S + A - P + # assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64" + # mask = 0x3ffffff # 26 bit + # imm = offset >> 2 + # assert imm < mask, "Relocation immediate value too large" + + else: + raise NotImplementedError(f"Relocation type {pr.type} not implemented") + + return patch_entry(mask, patch_offset, patch_value, symbol_type) - yield translate_relocation(reloc, patch_offset) def get_stencil_code(self, name: str) -> bytes: """Return the striped function code for a provided function name @@ -202,6 +221,10 @@ class stencil_database(): name_set |= self.get_sub_functions([r.symbol.name]) return name_set + def get_type_size(self, type_name: str) -> int: + """Returns the size of a variable type in bytes.""" + return {'int': 4, 'float': 4}[type_name] + def get_symbol_size(self, name: str) -> int: """Returns the size of a specified symbol name.""" return self.elf.symbols[name].fields['st_size']