mirror of https://github.com/Nonannet/copapy.git
relocation patching for constants is fixed, tests added
This commit is contained in:
parent
82c324b1a6
commit
6445ac9724
|
|
@ -192,8 +192,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
|
||||||
return function_list, offset
|
return function_list, offset
|
||||||
|
|
||||||
|
|
||||||
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||||
variables: dict[Net, tuple[int, int, str]] = dict()
|
variables: dict[Net, tuple[int, int, str]] = {}
|
||||||
data_list: list[bytes] = []
|
data_list: list[bytes] = []
|
||||||
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
||||||
|
|
||||||
|
|
@ -263,22 +263,23 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
||||||
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||||
addr = object_addr_lookup[associated_net]
|
addr = object_addr_lookup[associated_net]
|
||||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
||||||
elif patch.target_symbol_name.startswith('result_'):
|
elif patch.target_symbol_name.startswith('result_'):
|
||||||
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
|
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
|
||||||
else:
|
else:
|
||||||
# Patch constants addresses on heap
|
# Patch constants addresses on heap
|
||||||
print('##', section_addr_lookup, node.name, patch)
|
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||||
addr = section_addr_lookup[patch.target_symbol_section_index]
|
obj_addr = section_addr + patch.target_symbol_address
|
||||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
|
||||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
|
||||||
print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
||||||
|
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
||||||
|
|
||||||
elif patch.target_symbol_info == 'STT_FUNC':
|
elif patch.target_symbol_info == 'STT_FUNC':
|
||||||
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
||||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
||||||
print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||||
|
|
||||||
|
|
@ -305,15 +306,15 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
||||||
# Patch constants/variable addresses on heap
|
# Patch constants/variable addresses on heap
|
||||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||||
patch_value = section_addr + patch.addend - (start + patch.addr)
|
obj_addr = section_addr + patch.target_symbol_address
|
||||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
|
||||||
print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name)
|
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
||||||
#print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)
|
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
|
||||||
|
|
||||||
elif patch.target_symbol_info == 'STT_FUNC':
|
elif patch.target_symbol_info == 'STT_FUNC':
|
||||||
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
|
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||||
patch_value = aux_func_addr + patch.addend - (start + patch.addr)
|
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
|
||||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||||
|
|
|
||||||
|
|
@ -21,11 +21,12 @@ class patch_entry:
|
||||||
type (RelocationType): relocation type"""
|
type (RelocationType): relocation type"""
|
||||||
|
|
||||||
type: RelocationType
|
type: RelocationType
|
||||||
addr: int
|
patch_address: int
|
||||||
addend: int
|
addend: int
|
||||||
target_symbol_name: str
|
target_symbol_name: str
|
||||||
target_symbol_info: str
|
target_symbol_info: str
|
||||||
target_symbol_section_index: int
|
target_symbol_section_index: int
|
||||||
|
target_symbol_address: int
|
||||||
|
|
||||||
|
|
||||||
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
|
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
|
||||||
|
|
@ -150,10 +151,11 @@ class stencil_database():
|
||||||
reloc.fields['r_addend'],
|
reloc.fields['r_addend'],
|
||||||
reloc.symbol.name,
|
reloc.symbol.name,
|
||||||
reloc.symbol.info,
|
reloc.symbol.info,
|
||||||
reloc.symbol.fields['st_shndx'])
|
reloc.symbol.fields['st_shndx'],
|
||||||
|
reloc.symbol.fields['st_value'])
|
||||||
|
|
||||||
# Exclude the call to the result_* function
|
# Exclude the call to the result_* function
|
||||||
if patch.addr < end_index - start_index:
|
if patch.patch_address < end_index - start_index:
|
||||||
yield patch
|
yield patch
|
||||||
|
|
||||||
def get_stencil_code(self, name: str) -> bytes:
|
def get_stencil_code(self, name: str) -> bytes:
|
||||||
|
|
@ -189,7 +191,7 @@ class stencil_database():
|
||||||
return self.elf.sections[id].data
|
return self.elf.sections[id].data
|
||||||
|
|
||||||
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||||
"""Returns machine code for a specified function name"""
|
"""Returns machine code for a specified function name."""
|
||||||
func = self.elf.symbols[name]
|
func = self.elf.symbols[name]
|
||||||
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
|
||||||
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
||||||
T = TypeVar("T", int, float)
|
T = TypeVar("T", int, float)
|
||||||
|
|
||||||
epsilon = 1e-10
|
epsilon = 1e-20
|
||||||
|
|
||||||
|
|
||||||
class vector(Generic[T]):
|
class vector(Generic[T]):
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,17 @@
|
||||||
from copapy import variable, Target
|
from copapy import variable, Target
|
||||||
import pytest
|
import pytest
|
||||||
import copapy
|
import copapy as cp
|
||||||
|
|
||||||
|
|
||||||
def test_corse():
|
def test_corse():
|
||||||
c_i = variable(9)
|
a_i = 9
|
||||||
c_f = variable(2.5)
|
a_f = 2.5
|
||||||
|
c_i = variable(a_i)
|
||||||
|
c_f = variable(a_f)
|
||||||
# c_b = variable(True)
|
# c_b = variable(True)
|
||||||
|
|
||||||
ret_test = (c_f ** c_f, c_i ** c_i)#, c_i & 3)
|
ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3)
|
||||||
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
|
ret_refe = (a_f ** a_f, a_i ** a_i) # , a_i & 3)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
@ -18,8 +20,8 @@ def test_corse():
|
||||||
tg.run()
|
tg.run()
|
||||||
print('* finished')
|
print('* finished')
|
||||||
|
|
||||||
for test, ref in zip(ret_test, ret_ref):
|
for test, ref in zip(ret_test, ret_refe):
|
||||||
assert isinstance(test, copapy.variable)
|
assert isinstance(test, cp.variable)
|
||||||
val = tg.read_value(test)
|
val = tg.read_value(test)
|
||||||
print('+', val, ref, type(val), test.dtype)
|
print('+', val, ref, type(val), test.dtype)
|
||||||
#for t in (int, float, bool):
|
#for t in (int, float, bool):
|
||||||
|
|
@ -28,12 +30,14 @@ def test_corse():
|
||||||
|
|
||||||
|
|
||||||
def test_fine():
|
def test_fine():
|
||||||
c_i = variable(9)
|
a_i = 9
|
||||||
c_f = variable(2.5)
|
a_f = 2.5
|
||||||
|
c_i = variable(a_i)
|
||||||
|
c_f = variable(a_f)
|
||||||
# c_b = variable(True)
|
# c_b = variable(True)
|
||||||
|
|
||||||
ret_test = (c_f ** 2, c_i ** -1)#, c_i & 3)
|
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f)) # , c_i & 3)
|
||||||
ret_ref = (2.5 ** 2, 9 ** -1)#, 9 & 3)
|
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f)) # , a_i & 3)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
|
|
@ -42,8 +46,8 @@ def test_fine():
|
||||||
tg.run()
|
tg.run()
|
||||||
print('* finished')
|
print('* finished')
|
||||||
|
|
||||||
for test, ref in zip(ret_test, ret_ref):
|
for test, ref in zip(ret_test, ret_refe):
|
||||||
assert isinstance(test, copapy.variable)
|
assert isinstance(test, cp.variable)
|
||||||
val = tg.read_value(test)
|
val = tg.read_value(test)
|
||||||
print('+', val, ref, type(val), test.dtype)
|
print('+', val, ref, type(val), test.dtype)
|
||||||
#for t in (int, float, bool):
|
#for t in (int, float, bool):
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import copapy as cp
|
import copapy as cp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_vectors_init():
|
def test_vectors_init():
|
||||||
tt1 = cp.vector(range(3)) + cp.vector([1.1,2.2,3.3])
|
tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3])
|
||||||
tt2 = cp.vector([1.1,2,cp.variable(5)])# + cp.vector(range(3))
|
tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3))
|
||||||
tt3 = (cp.vector(range(3)) + 5.6)
|
tt3 = (cp.vector(range(3)) + 5.6)
|
||||||
tt4 = cp.vector([1.1,2,3]) + cp.vector(cp.variable(v) for v in range(3))
|
tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3))
|
||||||
tt5 = cp.vector([1,2,3]).dot(tt4)
|
tt5 = cp.vector([1, 2, 3]).dot(tt4)
|
||||||
|
|
||||||
print(tt1, tt2, tt3, tt4, tt5)
|
print(tt1, tt2, tt3, tt4, tt5)
|
||||||
|
|
||||||
|
|
@ -15,19 +17,22 @@ def test_compiled_vectors():
|
||||||
t2 = t1.sum()
|
t2 = t1.sum()
|
||||||
|
|
||||||
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
||||||
#t4 = ((t3 * t1) * 2).magnitude()
|
|
||||||
t4 = ((t3 * t1) * 2).sum()
|
t4 = ((t3 * t1) * 2).sum()
|
||||||
t5 = cp._math.sqrt2(cp.variable(8.0))
|
t5 = ((t3 * t1) * 2).magnitude()
|
||||||
t6 = cp._math.get_42()
|
|
||||||
|
|
||||||
tg = cp.Target()
|
tg = cp.Target()
|
||||||
tg.compile(t2, t4, t5, t6)
|
tg.compile(t2, t4, t5)
|
||||||
tg.run()
|
tg.run()
|
||||||
|
|
||||||
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
assert isinstance(t2, cp.variable)
|
||||||
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||||
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
|
|
||||||
assert tg.read_value(t6) == 42.0
|
assert isinstance(t4, cp.variable)
|
||||||
|
assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
assert isinstance(t5, cp.variable)
|
||||||
|
assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_compiled_vectors()
|
test_compiled_vectors()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue