diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 87cae78..64585e1 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -270,9 +270,9 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi # Get all nets/variables associated with heap memory variable_list = get_nets([[const_net_list]], extended_output_ops) - stencil_names = [node.name for _, node in extended_output_ops] + stencil_names = {node.name for _, node in extended_output_ops} aux_function_names = sdb.get_sub_functions(stencil_names) - used_sections = sdb.const_sections_from_functions(aux_function_names | set(stencil_names)) + used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names) # Write data section_mem_layout, sections_length = get_section_layout(used_sections, sdb) diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 6f1574e..35e7483 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -69,14 +69,6 @@ def get_return_function_type(symbol: elf_symbol) -> str: return 'void' -def strip_function(func: elf_symbol) -> bytes: - """Return stencil code by striped stancil function""" - #TODO: Add caching since get_last_call_in_function calls the slow "func.relocations" - #assert func.relocations and any(reloc.symbol.name.startswith('result_') for reloc in func.relocations), f"{func.name} is not a stencil function" # <--- is slow - start_index, end_index = get_stencil_position(func) - return func.data[start_index:end_index] - - def get_stencil_position(func: elf_symbol) -> tuple[int, int]: start_index = 0 # There must be no prolog # Find last relocation in function @@ -149,6 +141,7 @@ class stencil_database(): # self.elf.symbols[name].data self._relocation_cache: dict[tuple[str, bool], list[relocation_entry]] = {} + self._stencil_cache: dict[str, tuple[int, int]] = {} def const_sections_from_functions(self, symbol_names: Iterable[str]) -> list[int]: ret: set[int] = set() @@ -285,7 +278,6 @@ class stencil_database(): return patch_entry(mask, patch_offset, patch_value, scale, symbol_type) - def get_stencil_code(self, name: str) -> bytes: """Return the striped function code for a provided function name @@ -295,7 +287,17 @@ class stencil_database(): Returns: Striped function code """ - return strip_function(self.elf.symbols[name]) + if name in self._stencil_cache: + start_index, lengths = self._stencil_cache[name] + else: + func = self.elf.symbols[name] + start_stencil, end_stencil = get_stencil_position(func) + assert func.section + start_index = func.section['sh_offset'] + func['st_value'] + start_stencil + lengths = end_stencil - start_stencil + self._stencil_cache[name] = (start_index, lengths) + + return self.elf.read_bytes(start_index, lengths) def get_sub_functions(self, names: Iterable[str]) -> set[str]: """Return recursively all functions called by stencils or by other functions