From 6445ac972493ebc16c5371dc87be0a3c47539f55 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Sun, 26 Oct 2025 22:26:12 +0100 Subject: [PATCH] relocation patching for constants is fixed, tests added --- src/copapy/_compiler.py | 35 ++++++++++++++++++----------------- src/copapy/_stencils.py | 10 ++++++---- src/copapy/_vectors.py | 6 +++--- tests/test_math.py | 30 +++++++++++++++++------------- tests/test_vector.py | 29 +++++++++++++++++------------ 5 files changed, 61 insertions(+), 49 deletions(-) diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 3d697c4..3e46f94 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -192,8 +192,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data return function_list, offset -def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: - variables: dict[Net, tuple[int, int, str]] = dict() +def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: + variables: dict[Net, tuple[int, int, str]] = {} data_list: list[bytes] = [] patch_list: list[tuple[int, int, int, binw.Command]] = [] @@ -263,22 +263,23 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database assert associated_net, f"Relocation found but no net defined for operation {node.name}" #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") addr = object_addr_lookup[associated_net] - patch_value = addr + patch.addend - (offset + patch.addr) + patch_value = addr + patch.addend - (offset + patch.patch_address) elif patch.target_symbol_name.startswith('result_'): raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.") else: # Patch constants addresses on heap - print('##', section_addr_lookup, node.name, patch) - addr = section_addr_lookup[patch.target_symbol_section_index] - patch_value = addr + patch.addend - (offset + patch.addr) - patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) - print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name) + section_addr = section_addr_lookup[patch.target_symbol_section_index] + obj_addr = section_addr + patch.target_symbol_address + patch_value = obj_addr + patch.addend - (offset + patch.patch_address) + #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name) + patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) + #print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name) elif patch.target_symbol_info == 'STT_FUNC': addr = aux_func_addr_lookup[patch.target_symbol_name] - patch_value = addr + patch.addend - (offset + patch.addr) - patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) - print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) + patch_value = addr + patch.addend - (offset + patch.patch_address) + patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) + #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) else: raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") @@ -305,15 +306,15 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: # Patch constants/variable addresses on heap section_addr = section_addr_lookup[patch.target_symbol_section_index] - patch_value = section_addr + patch.addend - (start + patch.addr) - patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) - print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name) - #print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT) + obj_addr = section_addr + patch.target_symbol_address + patch_value = obj_addr + patch.addend - (start + patch.patch_address) + patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) + #print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name) elif patch.target_symbol_info == 'STT_FUNC': aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name] - patch_value = aux_func_addr + patch.addend - (start + patch.addr) - patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC)) + patch_value = aux_func_addr + patch.addend - (start + patch.patch_address) + patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) else: raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index c4e6605..9e904d4 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -21,11 +21,12 @@ class patch_entry: type (RelocationType): relocation type""" type: RelocationType - addr: int + patch_address: int addend: int target_symbol_name: str target_symbol_info: str target_symbol_section_index: int + target_symbol_address: int def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType: @@ -150,10 +151,11 @@ class stencil_database(): reloc.fields['r_addend'], reloc.symbol.name, reloc.symbol.info, - reloc.symbol.fields['st_shndx']) + reloc.symbol.fields['st_shndx'], + reloc.symbol.fields['st_value']) # Exclude the call to the result_* function - if patch.addr < end_index - start_index: + if patch.patch_address < end_index - start_index: yield patch def get_stencil_code(self, name: str) -> bytes: @@ -189,7 +191,7 @@ class stencil_database(): return self.elf.sections[id].data def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: - """Returns machine code for a specified function name""" + """Returns machine code for a specified function name.""" func = self.elf.symbols[name] assert func.info == 'STT_FUNC', f"{name} is not a function" diff --git a/src/copapy/_vectors.py b/src/copapy/_vectors.py index 9244d86..16d4cf1 100644 --- a/src/copapy/_vectors.py +++ b/src/copapy/_vectors.py @@ -7,7 +7,7 @@ VecIntLike: TypeAlias = 'vector[int] | variable[int] | int' VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float' T = TypeVar("T", int, float) -epsilon = 1e-10 +epsilon = 1e-20 class vector(Generic[T]): @@ -153,6 +153,6 @@ class vector(Generic[T]): def normalize(self) -> 'vector[float]': mag = self.magnitude() + epsilon return self / mag - + def __iter__(self) -> Iterable[variable[T] | T]: - return iter(self.values) \ No newline at end of file + return iter(self.values) diff --git a/tests/test_math.py b/tests/test_math.py index 246b976..79aa349 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,15 +1,17 @@ from copapy import variable, Target import pytest -import copapy +import copapy as cp def test_corse(): - c_i = variable(9) - c_f = variable(2.5) + a_i = 9 + a_f = 2.5 + c_i = variable(a_i) + c_f = variable(a_f) # c_b = variable(True) - ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3) - ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3) + ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3) + ret_refe = (a_f ** a_f, a_i ** a_i) # , a_i & 3) tg = Target() print('* compile and copy ...') @@ -18,8 +20,8 @@ def test_corse(): tg.run() print('* finished') - for test, ref in zip(ret_test, ret_ref): - assert isinstance(test, copapy.variable) + for test, ref in zip(ret_test, ret_refe): + assert isinstance(test, cp.variable) val = tg.read_value(test) print('+', val, ref, type(val), test.dtype) #for t in (int, float, bool): @@ -28,12 +30,14 @@ def test_corse(): def test_fine(): - c_i = variable(9) - c_f = variable(2.5) + a_i = 9 + a_f = 2.5 + c_i = variable(a_i) + c_f = variable(a_f) # c_b = variable(True) - ret_test = (c_f ** 2, c_i ** -1)#, c_i & 3) - ret_ref = (2.5 ** 2, 9 ** -1)#, 9 & 3) + ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f)) # , c_i & 3) + ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f)) # , a_i & 3) tg = Target() print('* compile and copy ...') @@ -42,8 +46,8 @@ def test_fine(): tg.run() print('* finished') - for test, ref in zip(ret_test, ret_ref): - assert isinstance(test, copapy.variable) + for test, ref in zip(ret_test, ret_refe): + assert isinstance(test, cp.variable) val = tg.read_value(test) print('+', val, ref, type(val), test.dtype) #for t in (int, float, bool): diff --git a/tests/test_vector.py b/tests/test_vector.py index cbae5a2..b448417 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -1,11 +1,13 @@ import copapy as cp +import pytest + def test_vectors_init(): - tt1 = cp.vector(range(3)) + cp.vector([1.1,2.2,3.3]) - tt2 = cp.vector([1.1,2,cp.variable(5)])# + cp.vector(range(3)) + tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3]) + tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3)) tt3 = (cp.vector(range(3)) + 5.6) - tt4 = cp.vector([1.1,2,3]) + cp.vector(cp.variable(v) for v in range(3)) - tt5 = cp.vector([1,2,3]).dot(tt4) + tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3)) + tt5 = cp.vector([1, 2, 3]).dot(tt4) print(tt1, tt2, tt3, tt4, tt5) @@ -15,19 +17,22 @@ def test_compiled_vectors(): t2 = t1.sum() t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3)) - #t4 = ((t3 * t1) * 2).magnitude() t4 = ((t3 * t1) * 2).sum() - t5 = cp._math.sqrt2(cp.variable(8.0)) - t6 = cp._math.get_42() + t5 = ((t3 * t1) * 2).magnitude() tg = cp.Target() - tg.compile(t2, t4, t5, t6) + tg.compile(t2, t4, t5) tg.run() - assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 - #assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5 - assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5 - assert tg.read_value(t6) == 42.0 + assert isinstance(t2, cp.variable) + assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 + + assert isinstance(t4, cp.variable) + assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType] + + assert isinstance(t5, cp.variable) + assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType] + if __name__ == "__main__": test_compiled_vectors()