sub functions added

This commit is contained in:
Nicolas Kruse 2025-10-12 23:21:34 +02:00
parent c551fd1f2e
commit 54ccdfe867
11 changed files with 60 additions and 53 deletions

View File

@ -334,19 +334,19 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
return list(nets) return list(nets)
def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]: def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0 offset: int = 0
object_list: list[tuple[Net, int, int]] = [] object_list: list[tuple[Net, int, int]] = []
for variable in variable_list: for variable in variable_list:
lengths = sdb.var_size['dummy_' + variable.dtype] lengths = sdb.get_symbol_size('dummy_' + variable.dtype)
object_list.append((variable, offset, lengths)) object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4 offset += (lengths + 3) // 4 * 4
return object_list, offset return object_list, offset
def get_aux_function_mem_layout(function_names: list[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]: def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]:
offset: int = 0 offset: int = 0
function_list: list[tuple[str, int, int]] = [] function_list: list[tuple[str, int, int]] = []
@ -375,11 +375,11 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
variable_list = get_nets([[const_net_list]], extended_output_ops) variable_list = get_nets([[const_net_list]], extended_output_ops)
# Write data # Write data
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb) variable_mem_layout, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths) dw.write_int(data_section_lengths)
for net, out_offs, lengths in object_list: for net, out_offs, lengths in variable_mem_layout:
variables[net] = (out_offs, lengths, net.dtype) variables[net] = (out_offs, lengths, net.dtype)
if isinstance(net.source, InitVar): if isinstance(net.source, InitVar):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
@ -388,25 +388,22 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_value(net.source.value, lengths) dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}') # print(f'+ {net.dtype} {net.source.value}')
# Write auxiliary_functions # write auxiliary_functions
object_list, data_section_lengths = get_aux_function_mem_layout(variable_list, sdb) aux_function_names = sdb.get_sub_functions(node.name for _, node in extended_output_ops)
dw.write_com(binw.Command.ALLOCATE_DATA) aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb)
dw.write_int(data_section_lengths) aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout}
for net, out_offs, lengths in object_list: dw.write_com(binw.Command.COPY_CODE)
variables[net] = (out_offs, lengths, net.dtype) dw.write_int(0)
if isinstance(net.source, InitVar): dw.write_int(aux_function_lengths)
dw.write_com(binw.Command.COPY_DATA) for name, _, _ in aux_function_mem_layout:
dw.write_int(out_offs) dw.write_bytes(sdb.get_function_code(name))
dw.write_int(lengths)
dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}')
# Prepare program code and relocations # Prepare program code and relocations
object_addr_lookup = {net: out_offs for net, out_offs, _ in object_list} object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int]] = [] patch_list: list[tuple[int, int, int]] = []
offset = 0 # offset in generated code chunk offset = aux_function_lengths # offset in generated code chunk
# assemble stencils to main program # assemble stencils to main program
data = sdb.get_function_code('entry_function_shell', 'start') data = sdb.get_function_code('entry_function_shell', 'start')
@ -422,14 +419,16 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
for patch in sdb.get_patch_positions(node.name): for patch in sdb.get_patch_positions(node.name):
if patch.target_symbol_info == 'STT_OBJECT': if patch.target_symbol_info == 'STT_OBJECT':
assert associated_net, f"Relocation found but no net defined for operation {node.name}" assert associated_net, f"Relocation found but no net defined for operation {node.name}"
object_addr = object_addr_lookup[associated_net] addr = object_addr_lookup[associated_net]
patch_value = object_addr + patch.addend - (offset + patch.addr) elif patch.target_symbol_info == 'STT_FUNC':
# print('patch: ', patch, object_addr, patch_value) addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
print('++ ', patch.target_symbol_info, patch.target_symbol_name)
else: else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
patch_value = addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
print('++ ', patch.target_symbol_info, patch.target_symbol_name)
offset += len(data) offset += len(data)
data = sdb.get_function_code('entry_function_shell', 'end') data = sdb.get_function_code('entry_function_shell', 'end')
@ -448,11 +447,14 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_bytes(b''.join(data_list)) dw.write_bytes(b''.join(data_list))
# write relocations # write relocations
for patch_type, patch_addr, object_addr in patch_list: for patch_type, patch_addr, addr in patch_list:
dw.write_com(binw.Command.PATCH_OBJECT) dw.write_com(binw.Command.PATCH_OBJECT)
dw.write_int(patch_addr) dw.write_int(patch_addr)
dw.write_int(patch_type) dw.write_int(patch_type)
dw.write_int(object_addr, signed=True) dw.write_int(addr, signed=True)
dw.write_com(binw.Command.ENTRY_POINT)
dw.write_int(aux_function_lengths)
return dw, variables return dw, variables
@ -474,15 +476,14 @@ class Target():
nodes.append(Write(net)) nodes.append(Write(net))
dw, self._variables = compile_to_instruction_list(nodes, self.sdb) dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
dw.write_com(binw.Command.END_PROG) dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0
def run(self) -> None: def run(self) -> None:
# set entry point and run code # set entry point and run code
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
dw.write_com(binw.Command.RUN_PROG) dw.write_com(binw.Command.RUN_PROG)
dw.write_int(0) dw.write_com(binw.Command.END_COM)
dw.write_com(binw.Command.END_PROG)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0
def read_value(self, net: Net) -> float | int: def read_value(self, net: Net) -> float | int:

View File

@ -6,9 +6,9 @@ ByteOrder = Literal['little', 'big']
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
('ALLOCATE_CODE', 3), ('COPY_CODE', 4), ('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
('PATCH_FUNC', 5), ('PATCH_OBJECT', 6), ('PATCH_FUNC', 5), ('PATCH_OBJECT', 6), ('ENTRY_POINT', 7),
('RUN_PROG', 64), ('READ_DATA', 65), ('RUN_PROG', 64), ('READ_DATA', 65),
('END_PROG', 256), ('FREE_MEMORY', 257)]) ('END_COM', 256), ('FREE_MEMORY', 257)])
COMMAND_SIZE = 4 COMMAND_SIZE = 4

View File

@ -95,9 +95,9 @@ class stencil_database():
# for s in self.elf.symbols # for s in self.elf.symbols
# if s.info == 'STT_FUNC'} # if s.info == 'STT_FUNC'}
self.var_size = {s.name: s.fields['st_size'] #self.var_size = {s.name: s.fields['st_size']
for s in self.elf.symbols # for s in self.elf.symbols
if s.info == 'STT_OBJECT'} # if s.info == 'STT_OBJECT'}
self.byteorder: ByteOrder = self.elf.byteorder self.byteorder: ByteOrder = self.elf.byteorder
#for name in self.function_definitions.keys(): #for name in self.function_definitions.keys():
@ -150,9 +150,10 @@ class stencil_database():
for name in names: for name in names:
if name not in name_set: if name not in name_set:
func = self.elf.symbols[name] func = self.elf.symbols[name]
for reloc in func.relocations: for r in func.relocations:
name_set.add(reloc.symbol.name) if r.symbol.info == 'STT_FUNC':
name_set |= self.get_sub_functions(name_set) name_set.add(r.symbol.name)
name_set |= self.get_sub_functions([r.symbol.name])
return name_set return name_set
def get_symbol_size(self, name: str) -> int: def get_symbol_size(self, name: str) -> int:

View File

@ -105,12 +105,15 @@ int parse_commands(uint8_t *bytes) {
patch(executable_memory + offs, reloc_type, value + data_offs); patch(executable_memory + offs, reloc_type, value + data_offs);
break; break;
case RUN_PROG: case ENTRY_POINT:
printf("ENTRY_POINT rel_entr_point=%i\n", rel_entr_point);
rel_entr_point = *(uint32_t*)bytes; bytes += 4; rel_entr_point = *(uint32_t*)bytes; bytes += 4;
printf("RUN_PROG rel_entr_point=%i\n", rel_entr_point);
entr_point = (entry_point_t)(executable_memory + rel_entr_point); entr_point = (entry_point_t)(executable_memory + rel_entr_point);
mark_mem_executable(executable_memory, executable_memory_len); mark_mem_executable(executable_memory, executable_memory_len);
break;
case RUN_PROG:
printf("RUN_PROG");
int ret = entr_point(); int ret = entr_point();
printf("Return value: %i\n", ret); printf("Return value: %i\n", ret);
break; break;

View File

@ -10,9 +10,10 @@
#define COPY_CODE 4 #define COPY_CODE 4
#define PATCH_FUNC 5 #define PATCH_FUNC 5
#define PATCH_OBJECT 6 #define PATCH_OBJECT 6
#define ENTRY_POINT 7
#define RUN_PROG 64 #define RUN_PROG 64
#define READ_DATA 65 #define READ_DATA 65
#define END_PROG 256 #define END_COM 256
#define FREE_MEMORY 257 #define FREE_MEMORY 257
/* Relocation types */ /* Relocation types */

View File

@ -63,7 +63,7 @@ def test_compile():
il.write_int(0) il.write_int(0)
il.write_int(36) il.write_int(36)
il.write_com(binwrite.Command.END_PROG) il.write_com(binwrite.Command.END_COM)
print('* Data to runner:') print('* Data to runner:')
il.print() il.print()

View File

@ -29,13 +29,12 @@ def test_compile():
# run program command # run program command
il.write_com(binwrite.Command.RUN_PROG) il.write_com(binwrite.Command.RUN_PROG)
il.write_int(0)
il.write_com(binwrite.Command.READ_DATA) il.write_com(binwrite.Command.READ_DATA)
il.write_int(0) il.write_int(0)
il.write_int(36) il.write_int(36)
il.write_com(binwrite.Command.END_PROG) il.write_com(binwrite.Command.END_COM)
print('* Data to runner:') print('* Data to runner:')
il.print() il.print()

View File

@ -25,7 +25,7 @@ def test_compile():
il.write_int(36) il.write_int(36)
# run program command # run program command
il.write_com(binwrite.Command.END_PROG) il.write_com(binwrite.Command.END_COM)
#print('* Data to runner:') #print('* Data to runner:')
#il.print() #il.print()

View File

@ -42,7 +42,7 @@ def test_compile():
print('+', name) print('+', name)
copapy.add_read_command(dw, variable_list, net) copapy.add_read_command(dw, variable_list, net)
dw.write_com(binwrite.Command.END_PROG) dw.write_com(binwrite.Command.END_COM)
dw.to_file('bin/test.copapy') dw.to_file('bin/test.copapy')
result = run_command(['bin/coparun', 'bin/test.copapy']) result = run_command(['bin/coparun', 'bin/test.copapy'])

View File

@ -55,17 +55,20 @@ if __name__ == "__main__":
assert reloc_type == RelocationType.RELOC_RELATIVE_32.value assert reloc_type == RelocationType.RELOC_RELATIVE_32.value
program_data[offs:offs + 4] = (value + data_section_offset).to_bytes(4, byteorder, signed=True) program_data[offs:offs + 4] = (value + data_section_offset).to_bytes(4, byteorder, signed=True)
print(f"PATCH_OBJECT patch_offs={offs} reloc_type={reloc_type} value={value}") print(f"PATCH_OBJECT patch_offs={offs} reloc_type={reloc_type} value={value}")
elif com == Command.ENTRY_POINT:
rel_entr_point = dr.read_int()
print(f"ENTRY_POINT rel_entr_point={rel_entr_point}")
elif com == Command.RUN_PROG: elif com == Command.RUN_PROG:
rel_entr_point = dr.read_int() rel_entr_point = dr.read_int()
print(f"RUN_PROG rel_entr_point={rel_entr_point}") print(f"RUN_PROG")
elif com == Command.READ_DATA: elif com == Command.READ_DATA:
offs = dr.read_int() offs = dr.read_int()
size = dr.read_int() size = dr.read_int()
print(f"READ_DATA offs={offs} size={size}") print(f"READ_DATA offs={offs} size={size}")
elif com == Command.FREE_MEMORY: elif com == Command.FREE_MEMORY:
print("READ_DATA") print("READ_DATA")
elif com == Command.END_PROG: elif com == Command.END_COM:
print("END_PROG") print("END_COM")
end_flag = 1 end_flag = 1
else: else:
assert False, f"Unknown command: {com}" assert False, f"Unknown command: {com}"

View File

@ -16,13 +16,12 @@ def test_compile() -> None:
# run program command # run program command
il.write_com(binwrite.Command.RUN_PROG) il.write_com(binwrite.Command.RUN_PROG)
il.write_int(0)
il.write_com(binwrite.Command.READ_DATA) il.write_com(binwrite.Command.READ_DATA)
il.write_int(0) il.write_int(0)
il.write_int(36) il.write_int(36)
il.write_com(binwrite.Command.END_PROG) il.write_com(binwrite.Command.END_COM)
print('* Data to runner:') print('* Data to runner:')
il.print() il.print()