From 0c2f2812db698beb986b7ba710d2d0f42b48a9b7 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Fri, 14 Nov 2025 16:23:51 +0100 Subject: [PATCH] caching for stencil_database.get_relocations added --- src/copapy/_stencils.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 6890af2..6f1574e 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -13,7 +13,6 @@ class relocation_entry: """ A dataclass for representing a relocation entry """ - target_symbol_name: str target_symbol_info: str target_symbol_offset: int @@ -30,8 +29,8 @@ class patch_entry: Attributes: addr (int): address of first byte to patch relative to the start of the symbol - type (RelocationType): relocation type""" - + type (RelocationType): relocation type + """ mask: int address: int value: int @@ -72,7 +71,8 @@ def get_return_function_type(symbol: elf_symbol) -> str: def strip_function(func: elf_symbol) -> bytes: """Return stencil code by striped stancil function""" - assert func.relocations and any(reloc.symbol.name.startswith('result_') for reloc in func.relocations), f"{func.name} is not a stencil function" + #TODO: Add caching since get_last_call_in_function calls the slow "func.relocations" + #assert func.relocations and any(reloc.symbol.name.startswith('result_') for reloc in func.relocations), f"{func.name} is not a stencil function" # <--- is slow start_index, end_index = get_stencil_position(func) return func.data[start_index:end_index] @@ -109,11 +109,6 @@ def get_op_after_last_call_in_function(func: elf_symbol) -> int: return reloc.fields['r_offset'] - func.fields['st_value'] + 4 -def symbol_is_stencil(sym: elf_symbol) -> bool: - return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and - sym.relocations[-1].symbol.info == 'STT_NOTYPE') - - class stencil_database(): """A class for loading and querying a stencil database from an ELF object file @@ -153,6 +148,8 @@ class stencil_database(): # sym.relocations # self.elf.symbols[name].data + self._relocation_cache: dict[tuple[str, bool], list[relocation_entry]] = {} + def const_sections_from_functions(self, symbol_names: Iterable[str]) -> list[int]: ret: set[int] = set() @@ -165,6 +162,17 @@ class stencil_database(): return list(ret) def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]: + cache_key = (symbol_name, stencil) + if cache_key in self._relocation_cache: + # cache hit: + for reloc_entry in self._relocation_cache[cache_key]: + yield reloc_entry + return + + # cache miss: + cache: list[relocation_entry] = [] + self._relocation_cache[cache_key] = cache + symbol = self.elf.symbols[symbol_name] if stencil: start_index, end_index = get_stencil_position(symbol) @@ -178,13 +186,15 @@ class stencil_database(): patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index if patch_offset < end_index - start_index: # Exclude the call to the result_* function - yield relocation_entry(reloc.symbol.name, + reloc_entry = relocation_entry(reloc.symbol.name, reloc.symbol.info, reloc.symbol.fields['st_value'], reloc.symbol.fields['st_shndx'], symbol.fields['st_value'], start_index, reloc) + cache.append(reloc_entry) + yield reloc_entry def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry: """Return patch positions for a provided symbol (function or object)