diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 6d6b6b4..afac4db 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -14,10 +14,14 @@ def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: return [name for name, value in scope.items() if value is var] -def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> stencil_database: +def get_local_arch() -> str: arch_translation_table = {'ARM64': 'x86_64'} + return arch_translation_table.get(platform.machine(), platform.machine()) + + +def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> stencil_database: if arch == 'native': - arch = arch_translation_table.get(platform.machine(), platform.machine()) + arch = get_local_arch() stencil_data = pkgutil.get_data(__name__, f"obj/stencils_{arch}_{optimization}.o") assert stencil_data, f"stencils_{arch}_{optimization} not found" return stencil_database(stencil_data) diff --git a/tests/test_stencil_db.py b/tests/test_stencil_db.py index a5ff581..cc555b1 100644 --- a/tests/test_stencil_db.py +++ b/tests/test_stencil_db.py @@ -1,10 +1,8 @@ -from copapy import stencil_database -from copapy import stencil_db -import platform +from copapy import stencil_database, stencil_db, get_local_arch def test_list_symbols(): - arch = platform.machine() + arch = get_local_arch() sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o') print('----') #print(sdb.function_definitions) @@ -14,7 +12,7 @@ def test_list_symbols(): def test_start_end_function(): - arch = platform.machine() + arch = get_local_arch() sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o') for sym_name in sdb.function_definitions.keys(): data = sdb.elf.symbols[sym_name].data