mirror of https://github.com/Nonannet/copapy.git
adding get_sub_functions
This commit is contained in:
parent
228daf5c9e
commit
c551fd1f2e
|
|
@ -128,10 +128,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
||||||
|
|
||||||
if typed_op not in generic_sdb.function_definitions:
|
if typed_op not in generic_sdb.stencil_definitions:
|
||||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||||
|
|
||||||
result_type = generic_sdb.function_definitions[typed_op].split('_')[0]
|
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||||
|
|
||||||
result_net = Net(result_type, Op(typed_op, arg_nets))
|
result_net = Net(result_type, Op(typed_op, arg_nets))
|
||||||
|
|
||||||
|
|
@ -346,6 +346,18 @@ def get_variable_mem_layout(variable_list: list[Net], sdb: stencil_database) ->
|
||||||
return object_list, offset
|
return object_list, offset
|
||||||
|
|
||||||
|
|
||||||
|
def get_aux_function_mem_layout(function_names: list[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]:
|
||||||
|
offset: int = 0
|
||||||
|
function_list: list[tuple[str, int, int]] = []
|
||||||
|
|
||||||
|
for name in function_names:
|
||||||
|
lengths = sdb.get_symbol_size(name)
|
||||||
|
function_list.append((name, offset, lengths))
|
||||||
|
offset += (lengths + 3) // 4 * 4
|
||||||
|
|
||||||
|
return function_list, offset
|
||||||
|
|
||||||
|
|
||||||
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||||
variables: dict[Net, tuple[int, int, str]] = dict()
|
variables: dict[Net, tuple[int, int, str]] = dict()
|
||||||
|
|
||||||
|
|
@ -354,17 +366,16 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
output_ops = list(add_read_ops(ordered_ops))
|
output_ops = list(add_read_ops(ordered_ops))
|
||||||
extended_output_ops = list(add_write_ops(output_ops, const_net_list))
|
extended_output_ops = list(add_write_ops(output_ops, const_net_list))
|
||||||
|
|
||||||
# Get all nets associated with heap memory
|
|
||||||
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
|
||||||
|
|
||||||
dw = binw.data_writer(sdb.byteorder)
|
dw = binw.data_writer(sdb.byteorder)
|
||||||
|
|
||||||
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
|
|
||||||
|
|
||||||
# Deallocate old allocated memory (if existing)
|
# Deallocate old allocated memory (if existing)
|
||||||
dw.write_com(binw.Command.FREE_MEMORY)
|
dw.write_com(binw.Command.FREE_MEMORY)
|
||||||
|
|
||||||
|
# Get all nets/variables associated with heap memory
|
||||||
|
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
||||||
|
|
||||||
# Write data
|
# Write data
|
||||||
|
object_list, data_section_lengths = get_variable_mem_layout(variable_list, sdb)
|
||||||
dw.write_com(binw.Command.ALLOCATE_DATA)
|
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||||
dw.write_int(data_section_lengths)
|
dw.write_int(data_section_lengths)
|
||||||
|
|
||||||
|
|
@ -377,22 +388,33 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
dw.write_value(net.source.value, lengths)
|
dw.write_value(net.source.value, lengths)
|
||||||
# print(f'+ {net.dtype} {net.source.value}')
|
# print(f'+ {net.dtype} {net.source.value}')
|
||||||
|
|
||||||
# write auxiliary_functions
|
# Write auxiliary_functions
|
||||||
# TODO
|
object_list, data_section_lengths = get_aux_function_mem_layout(variable_list, sdb)
|
||||||
|
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||||
|
dw.write_int(data_section_lengths)
|
||||||
|
|
||||||
|
for net, out_offs, lengths in object_list:
|
||||||
|
variables[net] = (out_offs, lengths, net.dtype)
|
||||||
|
if isinstance(net.source, InitVar):
|
||||||
|
dw.write_com(binw.Command.COPY_DATA)
|
||||||
|
dw.write_int(out_offs)
|
||||||
|
dw.write_int(lengths)
|
||||||
|
dw.write_value(net.source.value, lengths)
|
||||||
|
# print(f'+ {net.dtype} {net.source.value}')
|
||||||
|
|
||||||
# Prepare program code and relocations
|
# Prepare program code and relocations
|
||||||
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
|
object_addr_lookup = {net: out_offs for net, out_offs, _ in object_list}
|
||||||
data_list: list[bytes] = []
|
data_list: list[bytes] = []
|
||||||
patch_list: list[tuple[int, int, int]] = []
|
patch_list: list[tuple[int, int, int]] = []
|
||||||
offset = 0 # offset in generated code chunk
|
offset = 0 # offset in generated code chunk
|
||||||
|
|
||||||
# assemble stencils to main program
|
# assemble stencils to main program
|
||||||
data = sdb.get_function_body('entry_function_shell', 'start')
|
data = sdb.get_function_code('entry_function_shell', 'start')
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
||||||
for associated_net, node in extended_output_ops:
|
for associated_net, node in extended_output_ops:
|
||||||
assert node.name in sdb.function_definitions, f"- Warning: {node.name} stencil not found"
|
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
|
||||||
data = sdb.get_stencil_code(node.name)
|
data = sdb.get_stencil_code(node.name)
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
# print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
# print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
||||||
|
|
@ -400,7 +422,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
for patch in sdb.get_patch_positions(node.name):
|
for patch in sdb.get_patch_positions(node.name):
|
||||||
if patch.target_symbol_info == 'STT_OBJECT':
|
if patch.target_symbol_info == 'STT_OBJECT':
|
||||||
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
||||||
object_addr = object_addr_lookp[associated_net]
|
object_addr = object_addr_lookup[associated_net]
|
||||||
patch_value = object_addr + patch.addend - (offset + patch.addr)
|
patch_value = object_addr + patch.addend - (offset + patch.addr)
|
||||||
# print('patch: ', patch, object_addr, patch_value)
|
# print('patch: ', patch, object_addr, patch_value)
|
||||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
||||||
|
|
@ -410,7 +432,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
|
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
||||||
data = sdb.get_function_body('entry_function_shell', 'end')
|
data = sdb.get_function_code('entry_function_shell', 'end')
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
# print('function_end', offset, data)
|
# print('function_end', offset, data)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pelfy import open_elf_file, elf_file, elf_symbol
|
from pelfy import open_elf_file, elf_file, elf_symbol
|
||||||
from typing import Generator, Literal
|
from typing import Generator, Literal, Iterable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
ByteOrder = Literal['little', 'big']
|
ByteOrder = Literal['little', 'big']
|
||||||
|
|
@ -45,11 +45,9 @@ def get_return_function_type(symbol: elf_symbol) -> str:
|
||||||
|
|
||||||
def strip_function(func: elf_symbol) -> bytes:
|
def strip_function(func: elf_symbol) -> bytes:
|
||||||
"""Return stencil code by striped stancil function"""
|
"""Return stencil code by striped stancil function"""
|
||||||
if func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE':
|
assert func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE', f"{func.name} is not a stancil function"
|
||||||
start_index, end_index = get_stencil_position(func)
|
start_index, end_index = get_stencil_position(func)
|
||||||
return func.data[start_index:end_index]
|
return func.data[start_index:end_index]
|
||||||
else:
|
|
||||||
return func.data
|
|
||||||
|
|
||||||
|
|
||||||
def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
|
def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
|
||||||
|
|
@ -64,13 +62,15 @@ def get_last_call_in_function(func: elf_symbol) -> int:
|
||||||
assert reloc, f'No call function in stencil function {func.name}.'
|
assert reloc, f'No call function in stencil function {func.name}.'
|
||||||
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION
|
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - LENGTH_CALL_INSTRUCTION
|
||||||
|
|
||||||
|
def symbol_is_stencil(sym: elf_symbol) -> bool:
|
||||||
|
return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and
|
||||||
|
sym.relocations[-1].symbol.info == 'STT_NOTYPE')
|
||||||
|
|
||||||
class stencil_database():
|
class stencil_database():
|
||||||
"""A class for loading and querying a stencil database from an ELF object file
|
"""A class for loading and querying a stencil database from an ELF object file
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
function_definitions (dict[str, str]): dictionary of function names and their return types
|
stencil_definitions (dict[str, str]): dictionary of function names and their return types
|
||||||
data (dict[str, bytes]): dictionary of function names and their stripped code
|
|
||||||
var_size (dict[str, int]): dictionary of object names and their sizes
|
var_size (dict[str, int]): dictionary of object names and their sizes
|
||||||
byteorder (ByteOrder): byte order of the ELF file
|
byteorder (ByteOrder): byte order of the ELF file
|
||||||
elf (elf_file): the loaded ELF file
|
elf (elf_file): the loaded ELF file
|
||||||
|
|
@ -87,21 +87,23 @@ class stencil_database():
|
||||||
else:
|
else:
|
||||||
self.elf = elf_file(obj_file)
|
self.elf = elf_file(obj_file)
|
||||||
|
|
||||||
self.function_definitions = {s.name: get_return_function_type(s)
|
self.stencil_definitions = {s.name: get_return_function_type(s)
|
||||||
for s in self.elf.symbols
|
for s in self.elf.symbols
|
||||||
if s.info == 'STT_FUNC'}
|
if s.info == 'STT_FUNC'}
|
||||||
self.data = {s.name: strip_function(s)
|
|
||||||
for s in self.elf.symbols
|
#self.data = {s.name: strip_function(s)
|
||||||
if s.info == 'STT_FUNC'}
|
# for s in self.elf.symbols
|
||||||
|
# if s.info == 'STT_FUNC'}
|
||||||
|
|
||||||
self.var_size = {s.name: s.fields['st_size']
|
self.var_size = {s.name: s.fields['st_size']
|
||||||
for s in self.elf.symbols
|
for s in self.elf.symbols
|
||||||
if s.info == 'STT_OBJECT'}
|
if s.info == 'STT_OBJECT'}
|
||||||
self.byteorder: ByteOrder = self.elf.byteorder
|
self.byteorder: ByteOrder = self.elf.byteorder
|
||||||
|
|
||||||
for name in self.function_definitions.keys():
|
#for name in self.function_definitions.keys():
|
||||||
sym = self.elf.symbols[name]
|
# sym = self.elf.symbols[name]
|
||||||
sym.relocations
|
# sym.relocations
|
||||||
self.elf.symbols[name].data
|
# self.elf.symbols[name].data
|
||||||
|
|
||||||
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
|
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
|
||||||
"""Return patch positions for a provided symbol (function or object)
|
"""Return patch positions for a provided symbol (function or object)
|
||||||
|
|
@ -143,10 +145,27 @@ class stencil_database():
|
||||||
"""
|
"""
|
||||||
return strip_function(self.elf.symbols[name])
|
return strip_function(self.elf.symbols[name])
|
||||||
|
|
||||||
def get_function_body(self, name: str, part: Literal['start', 'end']) -> bytes:
|
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
|
||||||
|
name_set: set[str] = set()
|
||||||
|
for name in names:
|
||||||
|
if name not in name_set:
|
||||||
|
func = self.elf.symbols[name]
|
||||||
|
for reloc in func.relocations:
|
||||||
|
name_set.add(reloc.symbol.name)
|
||||||
|
name_set |= self.get_sub_functions(name_set)
|
||||||
|
return name_set
|
||||||
|
|
||||||
|
def get_symbol_size(self, name: str) -> int:
|
||||||
|
return self.elf.symbols[name].fields['st_size']
|
||||||
|
|
||||||
|
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||||
|
"""Returns machine code for a specified function name"""
|
||||||
func = self.elf.symbols[name]
|
func = self.elf.symbols[name]
|
||||||
|
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||||
index = get_last_call_in_function(func)
|
index = get_last_call_in_function(func)
|
||||||
if part == 'start':
|
if part == 'start':
|
||||||
return func.data[:index]
|
return func.data[:index]
|
||||||
else:
|
elif part == 'end':
|
||||||
return func.data[index + LENGTH_CALL_INSTRUCTION:]
|
return func.data[index + LENGTH_CALL_INSTRUCTION:]
|
||||||
|
else:
|
||||||
|
return func.data
|
||||||
|
|
@ -7,13 +7,13 @@ sdb = stencil_database(f'src/copapy/obj/stencils_{arch}_O3.o')
|
||||||
def test_list_symbols():
|
def test_list_symbols():
|
||||||
print('----')
|
print('----')
|
||||||
#print(sdb.function_definitions)
|
#print(sdb.function_definitions)
|
||||||
for sym_name in sdb.function_definitions.keys():
|
for sym_name in sdb.stencil_definitions.keys():
|
||||||
print('\n-', sym_name)
|
print('\n-', sym_name)
|
||||||
#print(list(sdb.get_patch_positions(sym_name)))
|
#print(list(sdb.get_patch_positions(sym_name)))
|
||||||
|
|
||||||
|
|
||||||
def test_start_end_function():
|
def test_start_end_function():
|
||||||
for sym_name in sdb.function_definitions.keys():
|
for sym_name in sdb.stencil_definitions.keys():
|
||||||
symbol = sdb.elf.symbols[sym_name]
|
symbol = sdb.elf.symbols[sym_name]
|
||||||
|
|
||||||
if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE':
|
if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE':
|
||||||
|
|
@ -26,7 +26,7 @@ def test_start_end_function():
|
||||||
|
|
||||||
|
|
||||||
def test_aux_functions():
|
def test_aux_functions():
|
||||||
for sym_name in sdb.function_definitions.keys():
|
for sym_name in sdb.stencil_definitions.keys():
|
||||||
symbol = sdb.elf.symbols[sym_name]
|
symbol = sdb.elf.symbols[sym_name]
|
||||||
for reloc in symbol.relocations:
|
for reloc in symbol.relocations:
|
||||||
if reloc.symbol.info != "STT_NOTYPE":
|
if reloc.symbol.info != "STT_NOTYPE":
|
||||||
|
|
@ -35,5 +35,4 @@ def test_aux_functions():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_list_symbols()
|
test_aux_functions()
|
||||||
test_start_end_function()
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ def test_compile() -> None:
|
||||||
c1 = CPVariable(9)
|
c1 = CPVariable(9)
|
||||||
|
|
||||||
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
|
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
|
||||||
ret = [c1 / 4]
|
ret = [c1 // 4]
|
||||||
|
|
||||||
out = [Write(r) for r in ret]
|
out = [Write(r) for r in ret]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue