diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index baca9cf..acafd77 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -1,7 +1,6 @@ import pkgutil from typing import Any, TypeVar, overload, TypeAlias, Generic, cast -from ._stencils import stencil_database -import platform +from ._stencils import stencil_database, detect_process_arch import copapy as cp NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' @@ -12,17 +11,25 @@ unibool: TypeAlias = 'variable[bool] | bool' TCPNum = TypeVar("TCPNum", bound='variable[Any]') TNum = TypeVar("TNum", int, bool, float) +stencil_cache: dict[tuple[str, str], stencil_database] = {} + def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: return [name for name, value in scope.items() if value is var] def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> stencil_database: + global stencil_cache + ci = (arch, optimization) + if ci in stencil_cache: + return stencil_cache[ci] # return cached stencil db if arch == 'native': - arch = platform.machine() + arch = detect_process_arch() stencil_data = pkgutil.get_data(__name__, f"obj/stencils_{arch}_{optimization}.o") assert stencil_data, f"stencils_{arch}_{optimization} not found" - return stencil_database(stencil_data) + sdb = stencil_database(stencil_data) + stencil_cache[ci] = sdb + return sdb generic_sdb = stencil_db_from_package() diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 754f72f..ca63547 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from pelfy import open_elf_file, elf_file, elf_symbol from typing import Generator, Literal, Iterable import pelfy +import struct +import platform ByteOrder = Literal['little', 'big'] @@ -37,6 +39,28 @@ class patch_entry: patch_type: int +def detect_process_arch() -> str: + bits = struct.calcsize("P") * 8 + arch = platform.machine().lower() + + if arch in ('amd64', 'x86_64'): + arch_family = 'x86_64' if bits == 64 else 'x86' + elif arch in ('i386', 'i686', 'x86'): + arch_family = 'x86' + elif arch in ('arm64', 'aarch64'): + arch_family = 'arm64' + elif 'arm' in arch: + arch_family = 'arm' + elif 'mips' in arch: + arch_family = 'mips64' if bits == 64 else 'mips' + elif 'riscv' in arch: + arch_family = 'riscv64' if bits == 64 else 'riscv' + else: + raise NotImplementedError(f"Platform {arch} with {bits} bits is not supported.") + + return arch_family + + def get_return_function_type(symbol: elf_symbol) -> str: if symbol.relocations: result_func = symbol.relocations[-1].symbol @@ -275,6 +299,10 @@ class stencil_database(): def get_section_size(self, index: int) -> int: """Returns the size of a section specified by index.""" return self.elf.sections[index].fields['sh_size'] + + def get_section_alignment(self, index: int) -> int: + """Returns the required alignment of a section specified by index.""" + return self.elf.sections[index].fields['sh_addralign'] def get_section_data(self, index: int) -> bytes: """Returns the data of a section specified by index."""