compiler: Added patching for aux functions

This commit is contained in:
Nicolas Kruse 2025-10-26 12:37:44 +01:00
parent 538bf23412
commit e400eff2b0
2 changed files with 37 additions and 14 deletions

View File

@ -166,8 +166,8 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
return object_list, offset
def get_target_sym_lookup(function_names: Iterable[str], sdb: stencil_database) -> dict[str, patch_entry]:
return {patch.target_symbol_name: patch for name in set(function_names) for patch in sdb.get_patch_positions(name)}
#def get_target_sym_lookup(function_names: Iterable[str], sdb: stencil_database) -> dict[str, patch_entry]:
# return {patch.target_symbol_name: patch for name in set(function_names) for patch in sdb.get_patch_positions(name)}
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
@ -221,18 +221,18 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_int(variables_data_lengths)
# Heap constants
for section_id, out_offs, lengths in section_mem_layout:
for section_id, start, lengths in section_mem_layout:
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(start)
dw.write_int(lengths)
dw.write_bytes(sdb.get_section_data(section_id))
# Heap variables
for net, out_offs, lengths in variable_mem_layout:
variables[net] = (out_offs, lengths, net.dtype)
for net, start, lengths in variable_mem_layout:
variables[net] = (start, lengths, net.dtype)
if isinstance(net.source, CPConstant):
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(start)
dw.write_int(lengths)
dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}')
@ -246,7 +246,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout}
offset = aux_function_lengths # offset in generated code chunk
# assemble stencils to main program
# assemble stencils to main program and patch stencils
data = sdb.get_function_code('entry_function_shell', 'start')
data_list.append(data)
offset += len(data)
@ -257,7 +257,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
data_list.append(data)
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name):
for patch in sdb.get_patch_positions(node.name, stencil=True):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
if patch.target_symbol_name.startswith('dummy_'):
# Patch for write and read addresses to/from heap variables
@ -265,8 +265,11 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
addr = object_addr_lookup[associated_net]
patch_value = addr + patch.addend - (offset + patch.addr)
elif patch.target_symbol_name.startswith('result_'):
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
else:
# Patch constants addresses on heap
print('##', section_addr_lookup, node.name, patch)
addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
@ -288,13 +291,29 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_com(binw.Command.ALLOCATE_CODE)
dw.write_int(offset)
# write aux functions
for name, out_offs, lengths in aux_function_mem_layout:
# write aux functions code
for name, start, lengths in aux_function_mem_layout:
dw.write_com(binw.Command.COPY_CODE)
dw.write_int(out_offs)
dw.write_int(start)
dw.write_int(lengths)
dw.write_bytes(sdb.get_function_code(name))
# Patch aux functions
for name, start, lengths in aux_function_mem_layout:
for patch in sdb.get_patch_positions(name):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap
addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = addr + patch.addend - (start + patch.addr)
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (start + patch.addr)
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC))
else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
# write entry function code
dw.write_com(binw.Command.COPY_CODE)
dw.write_int(aux_function_lengths)

View File

@ -119,7 +119,7 @@ class stencil_database():
ret.add(sym.section.index)
return list(ret)
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]:
"""Return patch positions for a provided symbol (function or object)
Args:
@ -129,7 +129,11 @@ class stencil_database():
patch_entry: every relocation for the symbol
"""
symbol = self.elf.symbols[symbol_name]
start_index, end_index = get_stencil_position(symbol)
if stencil:
start_index, end_index = get_stencil_position(symbol)
else:
start_index = 0
end_index = symbol.fields['st_size']
for reloc in symbol.relocations: