relocation patching for constants is fixed, tests added

This commit is contained in:
Nicolas Kruse 2025-10-26 22:26:12 +01:00
parent 82c324b1a6
commit 6445ac9724
5 changed files with 61 additions and 49 deletions

View File

@ -192,8 +192,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
return function_list, offset return function_list, offset
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
variables: dict[Net, tuple[int, int, str]] = dict() variables: dict[Net, tuple[int, int, str]] = {}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int, binw.Command]] = [] patch_list: list[tuple[int, int, int, binw.Command]] = []
@ -263,22 +263,23 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
assert associated_net, f"Relocation found but no net defined for operation {node.name}" assert associated_net, f"Relocation found but no net defined for operation {node.name}"
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
addr = object_addr_lookup[associated_net] addr = object_addr_lookup[associated_net]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.patch_address)
elif patch.target_symbol_name.startswith('result_'): elif patch.target_symbol_name.startswith('result_'):
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.") raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
else: else:
# Patch constants addresses on heap # Patch constants addresses on heap
print('##', section_addr_lookup, node.name, patch) section_addr = section_addr_lookup[patch.target_symbol_section_index]
addr = section_addr_lookup[patch.target_symbol_section_index] obj_addr = section_addr + patch.target_symbol_address
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name) patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC': elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name] addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.patch_address)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else: else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
@ -305,15 +306,15 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap # Patch constants/variable addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index] section_addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = section_addr + patch.addend - (start + patch.addr) obj_addr = section_addr + patch.target_symbol_address
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) patch_value = obj_addr + patch.addend - (start + patch.patch_address)
print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name) patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT) #print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
elif patch.target_symbol_info == 'STT_FUNC': elif patch.target_symbol_info == 'STT_FUNC':
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name] aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = aux_func_addr + patch.addend - (start + patch.addr) patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC)) patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
else: else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")

View File

@ -21,11 +21,12 @@ class patch_entry:
type (RelocationType): relocation type""" type (RelocationType): relocation type"""
type: RelocationType type: RelocationType
addr: int patch_address: int
addend: int addend: int
target_symbol_name: str target_symbol_name: str
target_symbol_info: str target_symbol_info: str
target_symbol_section_index: int target_symbol_section_index: int
target_symbol_address: int
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType: def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
@ -150,10 +151,11 @@ class stencil_database():
reloc.fields['r_addend'], reloc.fields['r_addend'],
reloc.symbol.name, reloc.symbol.name,
reloc.symbol.info, reloc.symbol.info,
reloc.symbol.fields['st_shndx']) reloc.symbol.fields['st_shndx'],
reloc.symbol.fields['st_value'])
# Exclude the call to the result_* function # Exclude the call to the result_* function
if patch.addr < end_index - start_index: if patch.patch_address < end_index - start_index:
yield patch yield patch
def get_stencil_code(self, name: str) -> bytes: def get_stencil_code(self, name: str) -> bytes:
@ -189,7 +191,7 @@ class stencil_database():
return self.elf.sections[id].data return self.elf.sections[id].data
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name""" """Returns machine code for a specified function name."""
func = self.elf.symbols[name] func = self.elf.symbols[name]
assert func.info == 'STT_FUNC', f"{name} is not a function" assert func.info == 'STT_FUNC', f"{name} is not a function"

View File

@ -7,7 +7,7 @@ VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float' VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float) T = TypeVar("T", int, float)
epsilon = 1e-10 epsilon = 1e-20
class vector(Generic[T]): class vector(Generic[T]):

View File

@ -1,15 +1,17 @@
from copapy import variable, Target from copapy import variable, Target
import pytest import pytest
import copapy import copapy as cp
def test_corse(): def test_corse():
c_i = variable(9) a_i = 9
c_f = variable(2.5) a_f = 2.5
c_i = variable(a_i)
c_f = variable(a_f)
# c_b = variable(True) # c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3) ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3)
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3) ret_refe = (a_f ** a_f, a_i ** a_i) # , a_i & 3)
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')
@ -18,8 +20,8 @@ def test_corse():
tg.run() tg.run()
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_refe):
assert isinstance(test, copapy.variable) assert isinstance(test, cp.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):
@ -28,12 +30,14 @@ def test_corse():
def test_fine(): def test_fine():
c_i = variable(9) a_i = 9
c_f = variable(2.5) a_f = 2.5
c_i = variable(a_i)
c_f = variable(a_f)
# c_b = variable(True) # c_b = variable(True)
ret_test = (c_f ** 2, c_i ** -1)#, c_i & 3) ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f)) # , c_i & 3)
ret_ref = (2.5 ** 2, 9 ** -1)#, 9 & 3) ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f)) # , a_i & 3)
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')
@ -42,8 +46,8 @@ def test_fine():
tg.run() tg.run()
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_refe):
assert isinstance(test, copapy.variable) assert isinstance(test, cp.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):

View File

@ -1,8 +1,10 @@
import copapy as cp import copapy as cp
import pytest
def test_vectors_init(): def test_vectors_init():
tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3]) tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3])
tt2 = cp.vector([1.1,2,cp.variable(5)])# + cp.vector(range(3)) tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3))
tt3 = (cp.vector(range(3)) + 5.6) tt3 = (cp.vector(range(3)) + 5.6)
tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3)) tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3))
tt5 = cp.vector([1, 2, 3]).dot(tt4) tt5 = cp.vector([1, 2, 3]).dot(tt4)
@ -15,19 +17,22 @@ def test_compiled_vectors():
t2 = t1.sum() t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3)) t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
#t4 = ((t3 * t1) * 2).magnitude()
t4 = ((t3 * t1) * 2).sum() t4 = ((t3 * t1) * 2).sum()
t5 = cp._math.sqrt2(cp.variable(8.0)) t5 = ((t3 * t1) * 2).magnitude()
t6 = cp._math.get_42()
tg = cp.Target() tg = cp.Target()
tg.compile(t2, t4, t5, t6) tg.compile(t2, t4, t5)
tg.run() tg.run()
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 assert isinstance(t2, cp.variable)
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5 assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
assert tg.read_value(t6) == 42.0 assert isinstance(t4, cp.variable)
assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType]
assert isinstance(t5, cp.variable)
assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__": if __name__ == "__main__":
test_compiled_vectors() test_compiled_vectors()