adding get_sub_functions

This commit is contained in:
Nicolas 2025-10-12 23:13:33 +02:00
parent 228daf5c9e
commit c551fd1f2e
4 changed files with 80 additions and 40 deletions

View File

@ -128,10 +128,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in generic_sdb.function_definitions: if typed_op not in generic_sdb.stencil_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
result_type = generic_sdb.function_definitions[typed_op].split('_')[0] result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
result_net = Net(result_type, Op(typed_op, arg_nets)) result_net = Net(result_type, Op(typed_op, arg_nets))
@ -346,6 +346,18 @@ def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) ->
return object_list, offset return object_list, offset
def get_aux_function_mem_layout(function_names: list[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]:
offset: int = 0
function_list: list[tuple[str, int, int]] = []
for name in function_names:
lengths = sdb.get_symbol_size(name)
function_list.append((name, offset, lengths))
offset += (lengths + 3) // 4 * 4
return function_list, offset
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
variables: dict[Net, tuple[int, int, str]] = dict() variables: dict[Net, tuple[int, int, str]] = dict()
@ -354,17 +366,16 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
output_ops = list(add_read_ops(ordered_ops)) output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_net_list)) extended_output_ops = list(add_write_ops(output_ops, const_net_list))
# Get all nets associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
dw = binw.data_writer(sdb.byteorder) dw = binw.data_writer(sdb.byteorder)
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
# Deallocate old allocated memory (if existing) # Deallocate old allocated memory (if existing)
dw.write_com(binw.Command.FREE_MEMORY) dw.write_com(binw.Command.FREE_MEMORY)
# Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
# Write data # Write data
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths) dw.write_int(data_section_lengths)
@ -377,22 +388,33 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_value(net.source.value, lengths) dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}') # print(f'+ {net.dtype} {net.source.value}')
# write auxiliary_functions # Write auxiliary_functions
# TODO object_list, data_section_lengths = get_aux_function_mem_layout(variable_list, sdb)
dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths)
for net, out_offs, lengths in object_list:
variables[net] = (out_offs, lengths, net.dtype)
if isinstance(net.source, InitVar):
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(lengths)
dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}')
# Prepare program code and relocations # Prepare program code and relocations
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list} object_addr_lookup = {net: out_offs for net, out_offs, _ in object_list}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int]] = [] patch_list: list[tuple[int, int, int]] = []
offset = 0 # offset in generated code chunk offset = 0 # offset in generated code chunk
# assemble stencils to main program # assemble stencils to main program
data = sdb.get_function_body('entry_function_shell', 'start') data = sdb.get_function_code('entry_function_shell', 'start')
data_list.append(data) data_list.append(data)
offset += len(data) offset += len(data)
for associated_net, node in extended_output_ops: for associated_net, node in extended_output_ops:
assert node.name in sdb.function_definitions, f"- Warning: {node.name} stencil not found" assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
data = sdb.get_stencil_code(node.name) data = sdb.get_stencil_code(node.name)
data_list.append(data) data_list.append(data)
# print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
@ -400,7 +422,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
for patch in sdb.get_patch_positions(node.name): for patch in sdb.get_patch_positions(node.name):
if patch.target_symbol_info == 'STT_OBJECT': if patch.target_symbol_info == 'STT_OBJECT':
assert associated_net, f"Relocation found but no net defined for operation {node.name}" assert associated_net, f"Relocation found but no net defined for operation {node.name}"
object_addr = object_addr_lookp[associated_net] object_addr = object_addr_lookup[associated_net]
patch_value = object_addr + patch.addend - (offset + patch.addr) patch_value = object_addr + patch.addend - (offset + patch.addr)
# print('patch: ', patch, object_addr, patch_value) # print('patch: ', patch, object_addr, patch_value)
patch_list.append((patch.type.value, offset + patch.addr, patch_value)) patch_list.append((patch.type.value, offset + patch.addr, patch_value))
@ -410,7 +432,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
offset += len(data) offset += len(data)
data = sdb.get_function_body('entry_function_shell', 'end') data = sdb.get_function_code('entry_function_shell', 'end')
data_list.append(data) data_list.append(data)
offset += len(data) offset += len(data)
# print('function_end', offset, data) # print('function_end', offset, data)

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from pelfy import open_elf_file, elf_file, elf_symbol from pelfy import open_elf_file, elf_file, elf_symbol
from typing import Generator, Literal from typing import Generator, Literal, Iterable
from enum import Enum from enum import Enum
ByteOrder = Literal['little', 'big'] ByteOrder = Literal['little', 'big']
@ -45,11 +45,9 @@ def get_return_function_type(symbol: elf_symbol) -> str:
def strip_function(func: elf_symbol) -> bytes: def strip_function(func: elf_symbol) -> bytes:
"""Return stencil code by striped stancil function""" """Return stencil code by striped stancil function"""
if func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE': assert func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE', f"{func.name} is not a stancil function"
start_index, end_index = get_stencil_position(func) start_index, end_index = get_stencil_position(func)
return func.data[start_index:end_index] return func.data[start_index:end_index]
else:
return func.data
def get_stencil_position(func: elf_symbol) -> tuple[int, int]: def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
@ -64,13 +62,15 @@ def get_last_call_in_function(func: elf_symbol) -> int:
assert reloc, f'No call function in stencil function {func.name}.' assert reloc, f'No call function in stencil function {func.name}.'
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION
def symbol_is_stencil(sym: elf_symbol) -> bool:
return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and
sym.relocations[-1].symbol.info == 'STT_NOTYPE')
class stencil_database(): class stencil_database():
"""A class for loading and querying a stencil database from an ELF object file """A class for loading and querying a stencil database from an ELF object file
Attributes: Attributes:
function_definitions (dict[str, str]): dictionary of function names and their return types stencil_definitions (dict[str, str]): dictionary of function names and their return types
data (dict[str, bytes]): dictionary of function names and their stripped code
var_size (dict[str, int]): dictionary of object names and their sizes var_size (dict[str, int]): dictionary of object names and their sizes
byteorder (ByteOrder): byte order of the ELF file byteorder (ByteOrder): byte order of the ELF file
elf (elf_file): the loaded ELF file elf (elf_file): the loaded ELF file
@ -87,21 +87,23 @@ class stencil_database():
else: else:
self.elf = elf_file(obj_file) self.elf = elf_file(obj_file)
self.function_definitions = {s.name: get_return_function_type(s) self.stencil_definitions = {s.name: get_return_function_type(s)
for s in self.elf.symbols for s in self.elf.symbols
if s.info == 'STT_FUNC'} if s.info == 'STT_FUNC'}
self.data = {s.name: strip_function(s)
for s in self.elf.symbols #self.data = {s.name: strip_function(s)
if s.info == 'STT_FUNC'} # for s in self.elf.symbols
# if s.info == 'STT_FUNC'}
self.var_size = {s.name: s.fields['st_size'] self.var_size = {s.name: s.fields['st_size']
for s in self.elf.symbols for s in self.elf.symbols
if s.info == 'STT_OBJECT'} if s.info == 'STT_OBJECT'}
self.byteorder: ByteOrder = self.elf.byteorder self.byteorder: ByteOrder = self.elf.byteorder
for name in self.function_definitions.keys(): #for name in self.function_definitions.keys():
sym = self.elf.symbols[name] # sym = self.elf.symbols[name]
sym.relocations # sym.relocations
self.elf.symbols[name].data # self.elf.symbols[name].data
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
"""Return patch positions for a provided symbol (function or object) """Return patch positions for a provided symbol (function or object)
@ -142,11 +144,28 @@ class stencil_database():
Striped function code Striped function code
""" """
return strip_function(self.elf.symbols[name]) return strip_function(self.elf.symbols[name])
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
name_set: set[str] = set()
for name in names:
if name not in name_set:
func = self.elf.symbols[name]
for reloc in func.relocations:
name_set.add(reloc.symbol.name)
name_set |= self.get_sub_functions(name_set)
return name_set
def get_function_body(self, name: str, part: Literal['start', 'end']) -> bytes: def get_symbol_size(self, name: str) -> int:
return self.elf.symbols[name].fields['st_size']
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name"""
func = self.elf.symbols[name] func = self.elf.symbols[name]
assert func.info == 'STT_FUNC', f"{name} is not a function"
index = get_last_call_in_function(func) index = get_last_call_in_function(func)
if part == 'start': if part == 'start':
return func.data[:index] return func.data[:index]
elif part == 'end':
return func.data[index + LENGTH_CALL_INSTRUCTION:]
else: else:
return func.data[index + LENGTH_CALL_INSTRUCTION:] return func.data

View File

@ -7,13 +7,13 @@ sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o')
def test_list_symbols(): def test_list_symbols():
print('----') print('----')
#print(sdb.function_definitions) #print(sdb.function_definitions)
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.stencil_definitions.keys():
print('\n-', sym_name) print('\n-', sym_name)
#print(list(sdb.get_patch_positions(sym_name))) #print(list(sdb.get_patch_positions(sym_name)))
def test_start_end_function(): def test_start_end_function():
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.stencil_definitions.keys():
symbol = sdb.elf.symbols[sym_name] symbol = sdb.elf.symbols[sym_name]
if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE': if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE':
@ -26,7 +26,7 @@ def test_start_end_function():
def test_aux_functions(): def test_aux_functions():
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.stencil_definitions.keys():
symbol = sdb.elf.symbols[sym_name] symbol = sdb.elf.symbols[sym_name]
for reloc in symbol.relocations: for reloc in symbol.relocations:
if reloc.symbol.info != "STT_NOTYPE": if reloc.symbol.info != "STT_NOTYPE":
@ -35,5 +35,4 @@ def test_aux_functions():
if __name__ == "__main__": if __name__ == "__main__":
test_list_symbols() test_aux_functions()
test_start_end_function()

View File

@ -8,7 +8,7 @@ def test_compile() -> None:
c1 = CPVariable(9) c1 = CPVariable(9)
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4] #ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
ret = [c1 / 4] ret = [c1 // 4]
out = [Write(r) for r in ret] out = [Write(r) for r in ret]