code clean up

This commit is contained in:
Nicolas Kruse 2025-10-11 22:53:55 +02:00
parent 1501cc71a8
commit 3509608f7c
3 changed files with 27 additions and 37 deletions

View File

@ -334,6 +334,18 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
return list(nets) return list(nets)
def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.var_size['dummy_' + variable.dtype]
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
return object_list, offset
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
variables: dict[Net, tuple[int, int, str]] = dict() variables: dict[Net, tuple[int, int, str]] = dict()
@ -347,18 +359,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw = binw.data_writer(sdb.byteorder) dw = binw.data_writer(sdb.byteorder)
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]: object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
offset: int = 0
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.var_size['dummy_' + variable.dtype]
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
return object_list, offset
object_list, data_section_lengths = variable_mem_layout(variable_list)
# Deallocate old allocated memory (if existing) # Deallocate old allocated memory (if existing)
dw.write_com(binw.Command.FREE_MEMORY) dw.write_com(binw.Command.FREE_MEMORY)
@ -385,8 +386,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
patch_list: list[tuple[int, int, int]] = [] patch_list: list[tuple[int, int, int]] = []
offset = 0 # offset in generated code chunk offset = 0 # offset in generated code chunk
# print('object_addr_lookp: ', object_addr_lookp) # assemble stencils to main program
data = sdb.get_function_body('function_start', 'start') data = sdb.get_function_body('function_start', 'start')
data_list.append(data) data_list.append(data)
offset += len(data) offset += len(data)
@ -456,8 +456,6 @@ class Target():
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
dw.write_com(binw.Command.RUN_PROG) dw.write_com(binw.Command.RUN_PROG)
dw.write_int(0) dw.write_int(0)
#for s in self._variables:
# add_read_command(dw, self._variables, s)
dw.write_com(binw.Command.END_PROG) dw.write_com(binw.Command.END_PROG)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0

View File

@ -5,10 +5,6 @@ from enum import Enum
ByteOrder = Literal['little', 'big'] ByteOrder = Literal['little', 'big']
START_MARKER = 0xE1401F0F # Nop on x86-64
END_MARKER = 0xE2401F0F # Nop on x86-64
MARKER_LENGTH = 4
# on x86_64: call or jmp instruction when tail call optimized # on x86_64: call or jmp instruction when tail call optimized
LENGTH_CALL_INSTRUCTION = 5 LENGTH_CALL_INSTRUCTION = 5
@ -69,18 +65,6 @@ def get_last_call_in_function(func: elf_symbol) -> int:
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION
def get_stencil_position2(func: elf_symbol) -> tuple[int, int]:
# Find first start marker
marker_index = func.data.find(START_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder))
start_index = 0 if marker_index < 0 else marker_index + MARKER_LENGTH
# Find last end marker
end_index = func.data.rfind(END_MARKER.to_bytes(MARKER_LENGTH, func.file.byteorder))
end_index = len(func.data) if end_index < 0 else end_index - LENGTH_CALL_INSTRUCTION
return start_index, end_index
class stencil_database(): class stencil_database():
"""A class for loading and querying a stencil database from an ELF object file """A class for loading and querying a stencil database from an ELF object file
@ -140,7 +124,7 @@ class stencil_database():
reloc.bits, reloc.bits,
reloc.fields['r_addend']) reloc.fields['r_addend'])
# Exclude the call to the result_x function # Exclude the call to the result_* function
if patch.addr < end_index - start_index: if patch.addr < end_index - start_index:
yield patch yield patch

View File

@ -1,10 +1,11 @@
from numpy import info
from copapy import stencil_database, stencil_db from copapy import stencil_database, stencil_db
import platform import platform
def test_list_symbols():
arch = platform.machine() arch = platform.machine()
sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o') sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o')
def test_list_symbols():
print('----') print('----')
#print(sdb.function_definitions) #print(sdb.function_definitions)
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.function_definitions.keys():
@ -13,8 +14,6 @@ def test_list_symbols():
def test_start_end_function(): def test_start_end_function():
arch = platform.machine()
sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o')
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.function_definitions.keys():
symbol = sdb.elf.symbols[sym_name] symbol = sdb.elf.symbols[sym_name]
print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data)) print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data))
@ -24,6 +23,15 @@ def test_start_end_function():
assert start >= 0 and end >= start and end <= len(symbol.data) assert start >= 0 and end >= start and end <= len(symbol.data)
def test_aux_functions():
for sym_name in sdb.function_definitions.keys():
symbol = sdb.elf.symbols[sym_name]
for reloc in symbol.relocations:
if reloc.symbol.info != "STT_NOTYPE":
print(reloc.symbol.name, reloc.symbol.info)
if __name__ == "__main__": if __name__ == "__main__":
test_list_symbols() test_list_symbols()
test_start_end_function() test_start_end_function()