mirror of https://github.com/Nonannet/copapy.git
relocation patching for constants is fixed, tests added
This commit is contained in:
parent
82c324b1a6
commit
6445ac9724
|
|
@ -192,8 +192,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
|
|||
return function_list, offset
|
||||
|
||||
|
||||
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||
variables: dict[Net, tuple[int, int, str]] = dict()
|
||||
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||
variables: dict[Net, tuple[int, int, str]] = {}
|
||||
data_list: list[bytes] = []
|
||||
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
||||
|
||||
|
|
@ -263,22 +263,23 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
||||
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||
addr = object_addr_lookup[associated_net]
|
||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
||||
elif patch.target_symbol_name.startswith('result_'):
|
||||
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
|
||||
else:
|
||||
# Patch constants addresses on heap
|
||||
print('##', section_addr_lookup, node.name, patch)
|
||||
addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
||||
print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
obj_addr = section_addr + patch.target_symbol_address
|
||||
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
|
||||
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
|
||||
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
||||
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
||||
|
||||
elif patch.target_symbol_info == 'STT_FUNC':
|
||||
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
||||
print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
||||
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
||||
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
||||
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||
|
||||
|
|
@ -305,15 +306,15 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
||||
# Patch constants/variable addresses on heap
|
||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
patch_value = section_addr + patch.addend - (start + patch.addr)
|
||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
||||
print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name)
|
||||
#print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)
|
||||
obj_addr = section_addr + patch.target_symbol_address
|
||||
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
|
||||
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
||||
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
|
||||
|
||||
elif patch.target_symbol_info == 'STT_FUNC':
|
||||
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||
patch_value = aux_func_addr + patch.addend - (start + patch.addr)
|
||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
||||
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
|
||||
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||
|
|
|
|||
|
|
@ -21,11 +21,12 @@ class patch_entry:
|
|||
type (RelocationType): relocation type"""
|
||||
|
||||
type: RelocationType
|
||||
addr: int
|
||||
patch_address: int
|
||||
addend: int
|
||||
target_symbol_name: str
|
||||
target_symbol_info: str
|
||||
target_symbol_section_index: int
|
||||
target_symbol_address: int
|
||||
|
||||
|
||||
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
|
||||
|
|
@ -150,10 +151,11 @@ class stencil_database():
|
|||
reloc.fields['r_addend'],
|
||||
reloc.symbol.name,
|
||||
reloc.symbol.info,
|
||||
reloc.symbol.fields['st_shndx'])
|
||||
reloc.symbol.fields['st_shndx'],
|
||||
reloc.symbol.fields['st_value'])
|
||||
|
||||
# Exclude the call to the result_* function
|
||||
if patch.addr < end_index - start_index:
|
||||
if patch.patch_address < end_index - start_index:
|
||||
yield patch
|
||||
|
||||
def get_stencil_code(self, name: str) -> bytes:
|
||||
|
|
@ -189,7 +191,7 @@ class stencil_database():
|
|||
return self.elf.sections[id].data
|
||||
|
||||
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||
"""Returns machine code for a specified function name"""
|
||||
"""Returns machine code for a specified function name."""
|
||||
func = self.elf.symbols[name]
|
||||
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
|
|||
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
|
||||
T = TypeVar("T", int, float)
|
||||
|
||||
epsilon = 1e-10
|
||||
epsilon = 1e-20
|
||||
|
||||
|
||||
class vector(Generic[T]):
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from copapy import variable, Target
|
||||
import pytest
|
||||
import copapy
|
||||
import copapy as cp
|
||||
|
||||
|
||||
def test_corse():
|
||||
c_i = variable(9)
|
||||
c_f = variable(2.5)
|
||||
a_i = 9
|
||||
a_f = 2.5
|
||||
c_i = variable(a_i)
|
||||
c_f = variable(a_f)
|
||||
# c_b = variable(True)
|
||||
|
||||
ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3)
|
||||
ret_ref = (2.5 ** 2.5, 9 ** 9)#, 9 & 3)
|
||||
ret_refe = (a_f ** a_f, a_i ** a_i) # , a_i & 3)
|
||||
|
||||
tg = Target()
|
||||
print('* compile and copy ...')
|
||||
|
|
@ -18,8 +20,8 @@ def test_corse():
|
|||
tg.run()
|
||||
print('* finished')
|
||||
|
||||
for test, ref in zip(ret_test, ret_ref):
|
||||
assert isinstance(test, copapy.variable)
|
||||
for test, ref in zip(ret_test, ret_refe):
|
||||
assert isinstance(test, cp.variable)
|
||||
val = tg.read_value(test)
|
||||
print('+', val, ref, type(val), test.dtype)
|
||||
#for t in (int, float, bool):
|
||||
|
|
@ -28,12 +30,14 @@ def test_corse():
|
|||
|
||||
|
||||
def test_fine():
|
||||
c_i = variable(9)
|
||||
c_f = variable(2.5)
|
||||
a_i = 9
|
||||
a_f = 2.5
|
||||
c_i = variable(a_i)
|
||||
c_f = variable(a_f)
|
||||
# c_b = variable(True)
|
||||
|
||||
ret_test = (c_f ** 2, c_i ** -1)#, c_i & 3)
|
||||
ret_ref = (2.5 ** 2, 9 ** -1)#, 9 & 3)
|
||||
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f)) # , c_i & 3)
|
||||
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f)) # , a_i & 3)
|
||||
|
||||
tg = Target()
|
||||
print('* compile and copy ...')
|
||||
|
|
@ -42,8 +46,8 @@ def test_fine():
|
|||
tg.run()
|
||||
print('* finished')
|
||||
|
||||
for test, ref in zip(ret_test, ret_ref):
|
||||
assert isinstance(test, copapy.variable)
|
||||
for test, ref in zip(ret_test, ret_refe):
|
||||
assert isinstance(test, cp.variable)
|
||||
val = tg.read_value(test)
|
||||
print('+', val, ref, type(val), test.dtype)
|
||||
#for t in (int, float, bool):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import copapy as cp
|
||||
import pytest
|
||||
|
||||
|
||||
def test_vectors_init():
|
||||
tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3])
|
||||
tt2 = cp.vector([1.1,2,cp.variable(5)])# + cp.vector(range(3))
|
||||
tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3))
|
||||
tt3 = (cp.vector(range(3)) + 5.6)
|
||||
tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3))
|
||||
tt5 = cp.vector([1, 2, 3]).dot(tt4)
|
||||
|
|
@ -15,19 +17,22 @@ def test_compiled_vectors():
|
|||
t2 = t1.sum()
|
||||
|
||||
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
||||
#t4 = ((t3 * t1) * 2).magnitude()
|
||||
t4 = ((t3 * t1) * 2).sum()
|
||||
t5 = cp._math.sqrt2(cp.variable(8.0))
|
||||
t6 = cp._math.get_42()
|
||||
t5 = ((t3 * t1) * 2).magnitude()
|
||||
|
||||
tg = cp.Target()
|
||||
tg.compile(t2, t4, t5, t6)
|
||||
tg.compile(t2, t4, t5)
|
||||
tg.run()
|
||||
|
||||
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
||||
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
|
||||
assert tg.read_value(t6) == 42.0
|
||||
assert isinstance(t2, cp.variable)
|
||||
assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||
|
||||
assert isinstance(t4, cp.variable)
|
||||
assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
assert isinstance(t5, cp.variable)
|
||||
assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compiled_vectors()
|
||||
|
|
|
|||
Loading…
Reference in New Issue