diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 81436e9..d8bce93 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,9 +1,12 @@ import re import pkgutil -from typing import Generator, Iterable, Any +from tkinter import Variable +from typing import Generator, Iterable, Any, Literal, TypeVar import pelfy from . import binwrite as binw +from .stencil_db import stencil_database +Operand = type['Net'] | float | int def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: return [name for name, value in scope.items() if value is var] @@ -18,8 +21,9 @@ _ccode = pkgutil.get_data(__name__, 'stencils.c') assert _ccode is not None _function_definitions = _get_c_function_definitions(_ccode.decode('utf-8')) +sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') -print(_function_definitions) +#print(_function_definitions) class Node: def __init__(self): @@ -104,10 +108,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) - if typed_op not in _function_definitions: + if typed_op not in sdb.function_definitions: raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") - result_type = _function_definitions[typed_op].split('_')[0] + result_type = sdb.function_definitions[typed_op].split('_')[0] result_net = Net(result_type, Op(typed_op, arg_nets)) @@ -155,16 +159,16 @@ def get_multiuse_nets(root: list[Node]) -> set[Net]: """ known_nets: set[Net] = set() - def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]: + def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]: for net in net_list: #print(net) if net in known_nets: yield net else: known_nets.add(net) - yield from recursiv_node_search(net.source.args) + yield from recursive_node_search(net.source.args) - return set(recursiv_node_search(op.args[0] for op in root)) + return set(recursive_node_search(op.args[0] for op in root)) def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: @@ -215,6 +219,10 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]: + """Get all const nodes in the op list + + Returns: + List of tuples of (name, net, value)""" net_lookup = {net.source: net for op in op_list for net in op.args} return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)] @@ -224,7 +232,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No correctly in the registers Returns: - Yields a tuples of a net and a operation. The net is the result net + Yields tuples of a net and a operation. The net is the result net from the returned operation""" registers: list[None | Net] = [None] * 2 @@ -252,19 +260,31 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]: - """Add write operation for each new defined net if a read operation is later flowed""" + """Add write operation for each new defined net if a read operation is later followed""" stored_nets = {c[1] for c in const_list} read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_')} for net, node in net_node_list: yield net, node if net and net in read_back_nets and net not in stored_nets: - yield (net, Op('write_' + net.dtype, [net])) + yield (net, Write(net)) stored_nets.add(net) -def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer: +def get_variable_nets(nodes: Iterable[Node], nets_in: Iterable[Net]) -> list[Net]: + nets: set[Net] = set() + for node in nodes: + if node.name.startswith('write_'): + nets.add(node.args[0]) + for net_in in nets_in: + if net_in.source.name.startswith('read_'): + nets.add(net_in) + + return list(nets) + + +def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer: if isinstance(end_nodes, Node): node_list = [end_nodes] else: @@ -276,42 +296,79 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w output_ops = list(add_read_ops(ordered_ops)) extended_output_ops = list(add_write_ops(output_ops, const_list)) + for net, node in extended_output_ops: + print(node.name) - obj_file: str = 'src/copapy/obj/stencils_x86_64.o' - elf = pelfy.open_elf_file(obj_file) + variable_list = get_variable_nets((node for _, node in extended_output_ops), + (net for net, _ in extended_output_ops if net)) + + #assert False - dw = binw.data_writer(elf.byteorder) + #obj_file: str = 'src/copapy/obj/stencils_x86_64_O3.o' + #elf = pelfy.open_elf_file(obj_file) - prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'} - prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'} + dw = binw.data_writer(sdb.byteorder) - auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC'] - auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT'] + #prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'} + #prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'} + + #auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC'] + #auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT'] - # write data sections - object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects) + # write auxiliary_objects to data sections + #object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects) + #dw.write_com(binw.Command.ALLOCATE_DATA) + #dw.write_int(data_section_lengths) + + #for sym, out_offs, lengths in object_list: + # if sym.section and sym.section.type != 'SHT_NOBITS': + # dw.write_com(binw.Command.COPY_DATA) + # dw.write_int(out_offs) + # dw.write_int(lengths) + # dw.write_bytes(sym.data) + # print(f'+ {sym.name} {sym.fields}') + + # write variables to data sections + + def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]: + offset: int = 0 + object_list: list[tuple[Net, int, int]] = [] + + for variable in variable_list: + lengths = sdb.var_size['dummy_' + variable.dtype] + object_list.append((variable, offset, lengths)) + offset += (lengths + 3) // 4 * 4 + + return object_list, offset + + + object_list, data_section_lengths = variable_mem_layout(variable_list) + + #data_section_lengths = object_list[-1][1] + object_list[-1][2] dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_int(data_section_lengths) - for sym, out_offs, lengths in object_list: - if sym.section and sym.section.type != 'SHT_NOBITS': + for net, out_offs, lengths in object_list: + if isinstance(net.source, Const): dw.write_com(binw.Command.COPY_DATA) dw.write_int(out_offs) dw.write_int(lengths) - dw.write_bytes(sym.data) + dw.write_value(net.source.value, lengths) + print(f'+ {net.dtype} {net.source.value}') # write auxiliary_functions # TODO # write program - print(list(prototype_functions.keys())) for net, node in extended_output_ops: - if node.name in prototype_functions: - #print(prototype_functions[node.name]) - pass - else: print(f"- Warning: {node.name} prototype not found") + assert node.name in sdb.function_definitions, f"- Warning: {node.name} prototype not found" + data = sdb.get_func_data(node.name) + print('*', node.name, ' '.join(f'{d:02X}' for d in data)) + for reloc_offset, lengths, bits, reloc_type in sdb.get_relocs(node.name): + #if not relocation.symbol.name.startswith('result_'): + print(relocation) print('-----') diff --git a/src/copapy/binwrite.py b/src/copapy/binwrite.py index 86980be..34ec547 100644 --- a/src/copapy/binwrite.py +++ b/src/copapy/binwrite.py @@ -1,6 +1,7 @@ from enum import Enum from pelfy import elf_symbol from typing import Literal +import struct Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), ('ALLOCATE_CODE', 3), ('COPY_CODE', 4), @@ -67,7 +68,6 @@ def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_sy out_offs += (lengths + 3) // 4 * 4 return code_list, out_offs - class data_writer(): def __init__(self, byteorder: Literal['little', 'big']): self._data: list[tuple[str, bytes, int]] = list() @@ -85,6 +85,18 @@ class data_writer(): def write_bytes(self, value: bytes): self._data.append((f"BYTES {len(value)}", value, 0)) + def write_value(self, value: int | float, num_bytes: int = 4): + if isinstance(value, int): + self.write_int(value, num_bytes, True) + else: + en = {'little': '<', 'big': '>'}[self.byteorder] + if num_bytes == 4: + data = struct.pack(en + 'f', value) + else: + data = struct.pack(en + 'd', value) + assert len(data) == num_bytes + self.write_bytes(data) + def print(self) -> None: for name, dat, flag in self._data: if flag: diff --git a/src/copapy/generate_stencils.py b/src/copapy/generate_stencils.py index f8b1c1d..6d3c237 100644 --- a/src/copapy/generate_stencils.py +++ b/src/copapy/generate_stencils.py @@ -3,6 +3,7 @@ from typing import Generator op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'} + def get_op_code(op: str, type1: str, type2: str, type_out: str): return f""" void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ @@ -17,11 +18,13 @@ def get_result_stubs1(type1: str): void result_{type1}({type1} arg1); """ + def get_result_stubs2(type1: str, type2: str): return f""" void result_{type1}_{type2}({type1} arg1, {type2} arg2); """ + def get_read_reg0_code(type1: str, type2: str, type_out: str): return f""" void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{ @@ -31,6 +34,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str): }} """ + def get_read_reg1_code(type1: str, type2: str, type_out: str): return f""" void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{ @@ -40,6 +44,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str): }} """ + def get_write_code(type1: str): return f""" void write_{type1}({type1} arg1) {{ @@ -50,6 +55,7 @@ def get_write_code(type1: str): }} """ + def permutate(*lists: list[str]) -> Generator[list[str], None, None]: if len(lists) == 0: yield [] @@ -59,6 +65,7 @@ def permutate(*lists: list[str]) -> Generator[list[str], None, None]: for items in permutate(*rest): yield [item, *items] + if __name__ == "__main__": types = ['int', 'float'] ops = ['add', 'sub', 'mul', 'div'] diff --git a/src/copapy/obj/stencils_x86_64.o b/src/copapy/obj/stencils_x86_64.o new file mode 100644 index 0000000..0af302f Binary files /dev/null and b/src/copapy/obj/stencils_x86_64.o differ diff --git a/src/copapy/obj/stencils_x86_64_O1.o b/src/copapy/obj/stencils_x86_64_O1.o new file mode 100644 index 0000000..df5c0cc Binary files /dev/null and b/src/copapy/obj/stencils_x86_64_O1.o differ diff --git a/src/copapy/obj/stencils_x86_64_O2.o b/src/copapy/obj/stencils_x86_64_O2.o new file mode 100644 index 0000000..5a06a56 Binary files /dev/null and b/src/copapy/obj/stencils_x86_64_O2.o differ diff --git a/src/copapy/obj/stencils_x86_64_O3.o b/src/copapy/obj/stencils_x86_64_O3.o new file mode 100644 index 0000000..5a06a56 Binary files /dev/null and b/src/copapy/obj/stencils_x86_64_O3.o differ diff --git a/src/copapy/stencil_db.py b/src/copapy/stencil_db.py new file mode 100644 index 0000000..2fe2853 --- /dev/null +++ b/src/copapy/stencil_db.py @@ -0,0 +1,81 @@ +from ast import Tuple +from os import name +from tkinter import NO +import pelfy +from typing import Generator, Literal + + +start_marker = 0xF17ECAFE +end_marker = 0xF27ECAFE + +LENGTH_CALL_INSTRUCTION = 4 # x86_64 + +def get_ret_function_def(symbol: pelfy.elf_symbol): + #print('*', symbol.name, symbol.section) + result_func = symbol.relocations[-1].symbol + + assert result_func.name.startswith('result_') + return result_func.name[7:] + + +def strip_symbol(data: bytes, byteorder: Literal['little', 'big']) -> bytes: + """Return data between start and end marker and removing last instruction (call)""" + + # Find first start marker + start_index = data.find(start_marker.to_bytes(4, byteorder)) + + # Find last end marker + end_index = data.rfind(end_marker.to_bytes(4, byteorder), start_index) + + assert start_index > -1 and end_index > -1, f"Marker not found" + return data[start_index + 4:end_index - LENGTH_CALL_INSTRUCTION] + + + +class stencil_database(): + def __init__(self, obj_file: str): + self.elf = pelfy.open_elf_file(obj_file) + + #print(self.elf.symbols) + + self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols + if s.info == 'STT_FUNC'} + + self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols + if s.info == 'STT_FUNC'} + + self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols + if s.info == 'STT_OBJECT'} + + self.byteorder: Literal['little', 'big'] = self.elf.byteorder + + for name in self.function_definitions.keys(): + sym = self.elf.symbols[name] + sym.relocations + self.elf.symbols[name].data + + + def get_relocs(self, symbol_name: str) -> Generator[tuple[int, int, str], None, None]: + """Return relocation offset relative to striped symbol start. + Yields tuples of (reloc_offset, symbol_lenght, bits, reloc_type) + 1. reloc_offset: Offset of the relocation relative to the start of the stripped symbol data. + 2. Length of the striped symbol. + 3. Bits to patch + 4. reloc_type: Type of the relocation as a string. + """ + symbol = self.elf.symbols[symbol_name] + start_index = symbol.data.find(start_marker.to_bytes(4, symbol.file.byteorder)) + end_index = symbol.data.rfind(end_marker.to_bytes(4, symbol.file.byteorder), start_index) + + for reloc in symbol.relocations: + + reloc_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index + + if reloc_offset < end_index - start_index - LENGTH_CALL_INSTRUCTION: + yield (reloc_offset, reloc.bits, reloc.type) + + + def get_func_data(self, name: str) -> bytes: + return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder) + + diff --git a/src/runner/runmem2.c b/src/runner/runmem2.c new file mode 100644 index 0000000..2db48b5 --- /dev/null +++ b/src/runner/runmem2.c @@ -0,0 +1,196 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#define ALLOCATE_DATA 1 +#define COPY_DATA 2 +#define ALLOCATE_CODE 3 +#define COPY_CODE 4 +#define RELOCATE_FUNC 5 +#define RELOCATE_OBJECT 6 +#define SET_ENTR_POINT 64 +#define END_PROG 255 + +#define RELOC_RELATIVE_32 0 + +uint8_t *data_memory; +uint8_t *executable_memory; +uint32_t executable_memory_len; +int (*entr_point)(); + +uint8_t *get_executable_memory(uint32_t num_bytes){ + // Allocate executable memory + uint8_t *mem = (uint8_t*)mmap(NULL, num_bytes, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + return mem; +} + +uint8_t *get_data_memory(uint32_t num_bytes) { + uint8_t *mem = (uint8_t*)mmap(NULL, num_bytes, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + //uint8_t *mem = (uint8_t*)malloc(num_bytes); + return mem; +} + +int mark_mem_executable(){ + if (mprotect(executable_memory, executable_memory_len, PROT_READ | PROT_EXEC) == -1) { + perror("mprotect failed"); + return 0; + }else{ + return 1; + } +} + +void patch_mem_32(uint8_t *patch_addr, int32_t value){ + int32_t *val_ptr = (int32_t*)patch_addr; + *val_ptr = value; +} + +int relocate(uint8_t *patch_addr, uint32_t reloc_type, int32_t value){ + if (reloc_type == RELOC_RELATIVE_32){ + patch_mem_32(patch_addr, value); + }else{ + printf("Not implemented"); + return 0; + } + return 1; +} + +int parse_commands(uint8_t *bytes){ + int32_t value; + uint32_t command; + uint32_t reloc_type; + uint32_t offs; + int data_offs; + uint32_t size; + int err_flag = 0; + + while(!err_flag){ + command = *(uint32_t*)bytes; + bytes += 4; + switch(command) { + case ALLOCATE_DATA: + size = *(uint32_t*)bytes; bytes += 4; + printf("ALLOCATE_DATA size=%i\n", size); + data_memory = get_data_memory(size); + break; + + case COPY_DATA: + offs = *(uint32_t*)bytes; bytes += 4; + size = *(uint32_t*)bytes; bytes += 4; + printf("COPY_DATA offs=%i size=%i\n", offs, size); + memcpy(data_memory + offs, bytes, size); bytes += size; + break; + + case ALLOCATE_CODE: + size = *(uint32_t*)bytes; bytes += 4; + printf("ALLOCATE_CODE size=%i\n", size); + executable_memory = get_executable_memory(size); + executable_memory_len = size; + //printf("# d %i c %i off %i\n", data_memory, executable_memory, data_offs); + break; + + case COPY_CODE: + offs = *(uint32_t*)bytes; bytes += 4; + size = *(uint32_t*)bytes; bytes += 4; + printf("COPY_CODE offs=%i size=%i\n", offs, size); + memcpy(executable_memory + offs, bytes, size); bytes += size; + break; + + case RELOCATE_FUNC: + offs = *(uint32_t*)bytes; bytes += 4; + reloc_type = *(uint32_t*)bytes; bytes += 4; + value = *(int32_t*)bytes; bytes += 4; + printf("RELOCATE_FUNC patch_offs=%i reloc_type=%i value=%i\n", + offs, reloc_type, value); + relocate(executable_memory + offs, reloc_type, value); + break; + + case RELOCATE_OBJECT: + offs = *(uint32_t*)bytes; bytes += 4; + reloc_type = *(uint32_t*)bytes; bytes += 4; + value = *(int32_t*)bytes; bytes += 4; + printf("RELOCATE_OBJECT patch_offs=%i reloc_type=%i value=%i\n", + offs, reloc_type, value); + data_offs = (int32_t)(data_memory - executable_memory); + if (abs(data_offs) > 0x7FFFFFFF) { + perror("code and data memory to far apart"); + return EXIT_FAILURE; + } + relocate(executable_memory + offs, reloc_type, value + (int32_t)data_offs); + //printf("> %i\n", data_offs); + break; + + case SET_ENTR_POINT: + uint32_t rel_entr_point = *(uint32_t*)bytes; bytes += 4; + printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point); + entr_point = (int (*)())(executable_memory + rel_entr_point); + break; + + case END_PROG: + printf("END_PROG\n"); + mark_mem_executable(); + int ret = entr_point(); + printf("Return value: %i\n", ret); + err_flag = 1; + break; + + default: + printf("Unknown command\n"); + err_flag = -1; + break; + } + } + return err_flag; +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return EXIT_FAILURE; + } + + // Open the file + int fd = open(argv[1], O_RDONLY); + if (fd < 0) { + perror("open"); + return EXIT_FAILURE; + } + + // Get file size + struct stat st; + if (fstat(fd, &st) < 0) { + perror("fstat"); + close(fd); + return EXIT_FAILURE; + } + + if (st.st_size == 0) { + fprintf(stderr, "Error: File is empty\n"); + close(fd); + return EXIT_FAILURE; + } + + //uint8_t *file_buff = get_data_memory((uint32_t)st.st_size); + uint8_t *file_buff = (uint8_t*)malloc((size_t)st.st_size); + + // Read file into allocated memory + if (read(fd, file_buff, (long unsigned int)st.st_size) != st.st_size) { + perror("read"); + close(fd); + return EXIT_FAILURE; + } + close(fd); + + parse_commands(file_buff); + + munmap(executable_memory, executable_memory_len); + return EXIT_SUCCESS; +} diff --git a/tests/test_compile.py b/tests/test_compile.py index c3ed29b..f822f53 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -21,7 +21,7 @@ def test_compile(): il = rc.compile_to_instruction_list(out) - #print('#', il.print()) + print('#', il.print()) if __name__ == "__main__": diff --git a/tests/test_stencil.db.py b/tests/test_stencil.db.py new file mode 100644 index 0000000..cc0af9b --- /dev/null +++ b/tests/test_stencil.db.py @@ -0,0 +1,9 @@ +from copapy import stencil_database + +if __name__ == "__main__": + sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') + print('----') + #print(sdb.function_definitions) + for sym_name in sdb.function_definitions.keys(): + print('-', sym_name) + print(list(sdb.get_relocs(sym_name))) \ No newline at end of file