mirror of https://github.com/Nonannet/copapy.git
adding get_sub_functions
This commit is contained in:
parent
228daf5c9e
commit
c551fd1f2e
|
|
@ -128,10 +128,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
|||
|
||||
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
||||
|
||||
if typed_op not in generic_sdb.function_definitions:
|
||||
if typed_op not in generic_sdb.stencil_definitions:
|
||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||
|
||||
result_type = generic_sdb.function_definitions[typed_op].split('_')[0]
|
||||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||
|
||||
result_net = Net(result_type, Op(typed_op, arg_nets))
|
||||
|
||||
|
|
@ -346,6 +346,18 @@ def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) ->
|
|||
return object_list, offset
|
||||
|
||||
|
||||
def get_aux_function_mem_layout(function_names: list[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]:
|
||||
offset: int = 0
|
||||
function_list: list[tuple[str, int, int]] = []
|
||||
|
||||
for name in function_names:
|
||||
lengths = sdb.get_symbol_size(name)
|
||||
function_list.append((name, offset, lengths))
|
||||
offset += (lengths + 3) // 4 * 4
|
||||
|
||||
return function_list, offset
|
||||
|
||||
|
||||
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||
variables: dict[Net, tuple[int, int, str]] = dict()
|
||||
|
||||
|
|
@ -354,17 +366,16 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
output_ops = list(add_read_ops(ordered_ops))
|
||||
extended_output_ops = list(add_write_ops(output_ops, const_net_list))
|
||||
|
||||
# Get all nets associated with heap memory
|
||||
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
||||
|
||||
dw = binw.data_writer(sdb.byteorder)
|
||||
|
||||
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
|
||||
|
||||
# Deallocate old allocated memory (if existing)
|
||||
dw.write_com(binw.Command.FREE_MEMORY)
|
||||
|
||||
# Get all nets/variables associated with heap memory
|
||||
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
||||
|
||||
# Write data
|
||||
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
|
||||
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||
dw.write_int(data_section_lengths)
|
||||
|
||||
|
|
@ -377,22 +388,33 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
dw.write_value(net.source.value, lengths)
|
||||
# print(f'+ {net.dtype} {net.source.value}')
|
||||
|
||||
# write auxiliary_functions
|
||||
# TODO
|
||||
# Write auxiliary_functions
|
||||
object_list, data_section_lengths = get_aux_function_mem_layout(variable_list, sdb)
|
||||
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||
dw.write_int(data_section_lengths)
|
||||
|
||||
for net, out_offs, lengths in object_list:
|
||||
variables[net] = (out_offs, lengths, net.dtype)
|
||||
if isinstance(net.source, InitVar):
|
||||
dw.write_com(binw.Command.COPY_DATA)
|
||||
dw.write_int(out_offs)
|
||||
dw.write_int(lengths)
|
||||
dw.write_value(net.source.value, lengths)
|
||||
# print(f'+ {net.dtype} {net.source.value}')
|
||||
|
||||
# Prepare program code and relocations
|
||||
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
|
||||
object_addr_lookup = {net: out_offs for net, out_offs, _ in object_list}
|
||||
data_list: list[bytes] = []
|
||||
patch_list: list[tuple[int, int, int]] = []
|
||||
offset = 0 # offset in generated code chunk
|
||||
|
||||
# assemble stencils to main program
|
||||
data = sdb.get_function_body('entry_function_shell', 'start')
|
||||
data = sdb.get_function_code('entry_function_shell', 'start')
|
||||
data_list.append(data)
|
||||
offset += len(data)
|
||||
|
||||
for associated_net, node in extended_output_ops:
|
||||
assert node.name in sdb.function_definitions, f"- Warning: {node.name} stencil not found"
|
||||
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
|
||||
data = sdb.get_stencil_code(node.name)
|
||||
data_list.append(data)
|
||||
# print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
||||
|
|
@ -400,7 +422,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
for patch in sdb.get_patch_positions(node.name):
|
||||
if patch.target_symbol_info == 'STT_OBJECT':
|
||||
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
||||
object_addr = object_addr_lookp[associated_net]
|
||||
object_addr = object_addr_lookup[associated_net]
|
||||
patch_value = object_addr + patch.addend - (offset + patch.addr)
|
||||
# print('patch: ', patch, object_addr, patch_value)
|
||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
||||
|
|
@ -410,7 +432,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
|
||||
offset += len(data)
|
||||
|
||||
data = sdb.get_function_body('entry_function_shell', 'end')
|
||||
data = sdb.get_function_code('entry_function_shell', 'end')
|
||||
data_list.append(data)
|
||||
offset += len(data)
|
||||
# print('function_end', offset, data)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from pelfy import open_elf_file, elf_file, elf_symbol
|
||||
from typing import Generator, Literal
|
||||
from typing import Generator, Literal, Iterable
|
||||
from enum import Enum
|
||||
|
||||
ByteOrder = Literal['little', 'big']
|
||||
|
|
@ -45,11 +45,9 @@ def get_return_function_type(symbol: elf_symbol) -> str:
|
|||
|
||||
def strip_function(func: elf_symbol) -> bytes:
|
||||
"""Return stencil code by striped stancil function"""
|
||||
if func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE':
|
||||
start_index, end_index = get_stencil_position(func)
|
||||
return func.data[start_index:end_index]
|
||||
else:
|
||||
return func.data
|
||||
assert func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE', f"{func.name} is not a stancil function"
|
||||
start_index, end_index = get_stencil_position(func)
|
||||
return func.data[start_index:end_index]
|
||||
|
||||
|
||||
def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
|
||||
|
|
@ -64,13 +62,15 @@ def get_last_call_in_function(func: elf_symbol) -> int:
|
|||
assert reloc, f'No call function in stencil function {func.name}.'
|
||||
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION
|
||||
|
||||
def symbol_is_stencil(sym: elf_symbol) -> bool:
|
||||
return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and
|
||||
sym.relocations[-1].symbol.info == 'STT_NOTYPE')
|
||||
|
||||
class stencil_database():
|
||||
"""A class for loading and querying a stencil database from an ELF object file
|
||||
|
||||
Attributes:
|
||||
function_definitions (dict[str, str]): dictionary of function names and their return types
|
||||
data (dict[str, bytes]): dictionary of function names and their stripped code
|
||||
stencil_definitions (dict[str, str]): dictionary of function names and their return types
|
||||
var_size (dict[str, int]): dictionary of object names and their sizes
|
||||
byteorder (ByteOrder): byte order of the ELF file
|
||||
elf (elf_file): the loaded ELF file
|
||||
|
|
@ -87,21 +87,23 @@ class stencil_database():
|
|||
else:
|
||||
self.elf = elf_file(obj_file)
|
||||
|
||||
self.function_definitions = {s.name: get_return_function_type(s)
|
||||
for s in self.elf.symbols
|
||||
if s.info == 'STT_FUNC'}
|
||||
self.data = {s.name: strip_function(s)
|
||||
for s in self.elf.symbols
|
||||
if s.info == 'STT_FUNC'}
|
||||
self.stencil_definitions = {s.name: get_return_function_type(s)
|
||||
for s in self.elf.symbols
|
||||
if s.info == 'STT_FUNC'}
|
||||
|
||||
#self.data = {s.name: strip_function(s)
|
||||
# for s in self.elf.symbols
|
||||
# if s.info == 'STT_FUNC'}
|
||||
|
||||
self.var_size = {s.name: s.fields['st_size']
|
||||
for s in self.elf.symbols
|
||||
if s.info == 'STT_OBJECT'}
|
||||
self.byteorder: ByteOrder = self.elf.byteorder
|
||||
|
||||
for name in self.function_definitions.keys():
|
||||
sym = self.elf.symbols[name]
|
||||
sym.relocations
|
||||
self.elf.symbols[name].data
|
||||
#for name in self.function_definitions.keys():
|
||||
# sym = self.elf.symbols[name]
|
||||
# sym.relocations
|
||||
# self.elf.symbols[name].data
|
||||
|
||||
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
|
||||
"""Return patch positions for a provided symbol (function or object)
|
||||
|
|
@ -142,11 +144,28 @@ class stencil_database():
|
|||
Striped function code
|
||||
"""
|
||||
return strip_function(self.elf.symbols[name])
|
||||
|
||||
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
|
||||
name_set: set[str] = set()
|
||||
for name in names:
|
||||
if name not in name_set:
|
||||
func = self.elf.symbols[name]
|
||||
for reloc in func.relocations:
|
||||
name_set.add(reloc.symbol.name)
|
||||
name_set |= self.get_sub_functions(name_set)
|
||||
return name_set
|
||||
|
||||
def get_function_body(self, name: str, part: Literal['start', 'end']) -> bytes:
|
||||
def get_symbol_size(self, name: str) -> int:
|
||||
return self.elf.symbols[name].fields['st_size']
|
||||
|
||||
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||
"""Returns machine code for a specified function name"""
|
||||
func = self.elf.symbols[name]
|
||||
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||
index = get_last_call_in_function(func)
|
||||
if part == 'start':
|
||||
return func.data[:index]
|
||||
elif part == 'end':
|
||||
return func.data[index + LENGTH_CALL_INSTRUCTION:]
|
||||
else:
|
||||
return func.data[index + LENGTH_CALL_INSTRUCTION:]
|
||||
return func.data
|
||||
|
|
@ -7,13 +7,13 @@ sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o')
|
|||
def test_list_symbols():
|
||||
print('----')
|
||||
#print(sdb.function_definitions)
|
||||
for sym_name in sdb.function_definitions.keys():
|
||||
for sym_name in sdb.stencil_definitions.keys():
|
||||
print('\n-', sym_name)
|
||||
#print(list(sdb.get_patch_positions(sym_name)))
|
||||
|
||||
|
||||
def test_start_end_function():
|
||||
for sym_name in sdb.function_definitions.keys():
|
||||
for sym_name in sdb.stencil_definitions.keys():
|
||||
symbol = sdb.elf.symbols[sym_name]
|
||||
|
||||
if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE':
|
||||
|
|
@ -26,7 +26,7 @@ def test_start_end_function():
|
|||
|
||||
|
||||
def test_aux_functions():
|
||||
for sym_name in sdb.function_definitions.keys():
|
||||
for sym_name in sdb.stencil_definitions.keys():
|
||||
symbol = sdb.elf.symbols[sym_name]
|
||||
for reloc in symbol.relocations:
|
||||
if reloc.symbol.info != "STT_NOTYPE":
|
||||
|
|
@ -35,5 +35,4 @@ def test_aux_functions():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_list_symbols()
|
||||
test_start_end_function()
|
||||
test_aux_functions()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ def test_compile() -> None:
|
|||
c1 = CPVariable(9)
|
||||
|
||||
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
|
||||
ret = [c1 / 4]
|
||||
ret = [c1 // 4]
|
||||
|
||||
out = [Write(r) for r in ret]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue