diff --git a/src/copapy/stencil_db.py b/src/copapy/stencil_db.py index febda61..c060354 100644 --- a/src/copapy/stencil_db.py +++ b/src/copapy/stencil_db.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -import pelfy +from pelfy import open_elf_file, elf_file, elf_symbol from typing import Generator, Literal from enum import Enum @@ -42,7 +42,7 @@ def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_add raise Exception(f"Unknown relocation type: {reloc_type}") -def get_return_function_type(symbol: pelfy.elf_symbol) -> str: +def get_return_function_type(symbol: elf_symbol) -> str: if symbol.relocations: result_func = symbol.relocations[-1].symbol if result_func.name.startswith('result_'): @@ -50,20 +50,41 @@ def get_return_function_type(symbol: pelfy.elf_symbol) -> str: return 'void' -def strip_symbol(data: bytes, byteorder: ByteOrder) -> bytes: +def strip_function(func: elf_symbol) -> bytes: """Return striped function code based on NOP markers""" - start_index, end_index = get_stencil_position(data, byteorder) - return data[start_index:end_index] + start_index, end_index = get_stencil_position(func) + return func.data[start_index:end_index] -def get_stencil_position(data: bytes, byteorder: ByteOrder) -> tuple[int, int]: +def get_stencil_position(func: elf_symbol) -> tuple[int, int]: + + #assert func.name != 'function_start', func.relocations + # Find first start marker - marker_index = data.find(START_MARKER.to_bytes(MARKER_LENGTH, byteorder)) + marker_index = func.data.find(START_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder)) start_index = 0 if marker_index < 0 else marker_index + MARKER_LENGTH # Find last end marker - end_index = data.rfind(END_MARKER.to_bytes(MARKER_LENGTH, byteorder)) - end_index = len(data) if end_index < 0 else end_index - LENGTH_CALL_INSTRUCTION + end_index = func.data.rfind(END_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder)) + end_index = len(func.data) if end_index < 0 else end_index - LENGTH_CALL_INSTRUCTION + + reloc = func.relocations[-1] + end_index2 = reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION + + print(func.relocations[-1]) + assert end_index2 == end_index, func.name + + return start_index, end_index + + +def get_stencil_position2(func: elf_symbol) -> tuple[int, int]: + # Find first start marker + marker_index = func.data.find(START_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder)) + start_index = 0 if marker_index < 0 else marker_index + MARKER_LENGTH + + # Find last end marker + end_index = func.data.rfind(END_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder)) + end_index = len(func.data) if end_index < 0 else end_index - LENGTH_CALL_INSTRUCTION return start_index, end_index @@ -86,20 +107,20 @@ class stencil_database(): obj_file: path to the ELF object file or bytes of the ELF object file """ if isinstance(obj_file, str): - self.elf = pelfy.open_elf_file(obj_file) + self.elf = open_elf_file(obj_file) else: - self.elf = pelfy.elf_file(obj_file) + self.elf = elf_file(obj_file) self.function_definitions = {s.name: get_return_function_type(s) for s in self.elf.symbols if s.info == 'STT_FUNC'} - self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) + self.data = {s.name: strip_function(s) for s in self.elf.symbols if s.info == 'STT_FUNC'} self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols if s.info == 'STT_OBJECT'} - self.byteorder = self.elf.byteorder + self.byteorder: ByteOrder = self.elf.byteorder for name in self.function_definitions.keys(): sym = self.elf.symbols[name] @@ -116,7 +137,7 @@ class stencil_database(): patch_entry: every relocation for the symbol """ symbol = self.elf.symbols[symbol_name] - start_index, end_index = get_stencil_position(symbol.data, symbol.file.byteorder) + start_index, end_index = get_stencil_position(symbol) for reloc in symbol.relocations: @@ -140,4 +161,4 @@ class stencil_database(): Returns: Striped function code """ - return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder) + return strip_function(self.elf.symbols[name]) diff --git a/tests/test_stencil_db.py b/tests/test_stencil_db.py index 5f92064..ab7b4b7 100644 --- a/tests/test_stencil_db.py +++ b/tests/test_stencil_db.py @@ -16,12 +16,12 @@ def test_start_end_function(): arch = platform.machine() sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o') for sym_name in sdb.function_definitions.keys(): - data = sdb.elf.symbols[sym_name].data - print('-', sym_name, stencil_db.get_stencil_position(data, sdb.elf.byteorder), len(data)) + symbol = sdb.elf.symbols[sym_name] + print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data)) - start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder) + start, end = stencil_db.get_stencil_position(symbol) - assert start >= 0 and end >= start and end <= len(data) + assert start >= 0 and end >= start and end <= len(symbol.data) if __name__ == "__main__":