diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 26bf9b5..a2a0112 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,16 +1,25 @@ import re import pkgutil +from typing import Generator, Iterable, Any +import pelfy +from . import binwrite as binw -def get_var_name(var, scope=globals()): +def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: return [name for name, value in scope.items() if value is var] -def _get_c_function_definitions(code: str): +def _get_c_function_definitions(code: str) -> dict[str, str]: ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S) return {r[0]: r[1] for r in ret} -_function_definitions = _get_c_function_definitions(pkgutil.get_data(__name__, 'ops.c').decode('utf-8')) +_ccode = pkgutil.get_data(__name__, 'ops.c') +assert _ccode is not None +_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8')) class Node: + def __init__(self): + self.args: list[Net] = [] + self.name: str = '' + def __repr__(self): #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" @@ -23,28 +32,28 @@ class Net: self.dtype = dtype self.source = source - def __mul__(self, other): + def __mul__(self, other: Any) -> 'Net': return _add_op('mul', [self, other], True) - def __rmul__(self, other): + def __rmul__(self, other: Any) -> 'Net': return _add_op('mul', [self, other], True) - def __add__(self, other): + def __add__(self, other: Any) -> 'Net': return _add_op('add', [self, other], True) - def __radd__(self, other): + def __radd__(self, other: Any) -> 'Net': return _add_op('add', [self, other], True) - def __sub__ (self, other): + def __sub__ (self, other: Any) -> 'Net': return _add_op('sub', [self, other]) - def __rsub__ (self, other): + def __rsub__ (self, other: Any) -> 'Net': return _add_op('sub', [other, self]) - def __truediv__ (self, other): + def __truediv__ (self, other: Any) -> 'Net': return _add_op('div', [self, other]) - def __rtruediv__ (self, other): + def __rtruediv__ (self, other: Any) -> 'Net': return _add_op('div', [other, self]) def __repr__(self): @@ -73,34 +82,35 @@ class Write(Node): class Op(Node): def __init__(self, typed_op_name: str, args: list[Net]): assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' - self.name = typed_op_name - self.args = args + self.name: str = typed_op_name + self.args: list[Net] = args -def _add_op(op: str, args: list[Net], commutative = False): - args = [a if isinstance(a, Net) else const(a) for a in args] +def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: + arg_nets = [a if isinstance(a, Net) else const(a) for a in args] if commutative: - args = sorted(args, key=lambda a: a.dtype) + arg_nets = sorted(arg_nets, key=lambda a: a.dtype) - typed_op = '_'.join([op] + [a.dtype for a in args]) + typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) if typed_op not in _function_definitions: - raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in args])}") + raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") result_type = _function_definitions[typed_op] - result_net = Net(result_type, Op(typed_op, args)) + result_net = Net(result_type, Op(typed_op, arg_nets)) return result_net #def read_input(hw: Device, test_value: float): # return Net(type(value)) -def const(value: float | int | bool): +def const(value: Any) -> Net: + assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}' new_const = Const(value) return Net(new_const.dtype, new_const) -def _get_data_and_dtype(value): +def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: if isinstance(value, int): return ('int', int(value)) elif isinstance(value, float): @@ -110,25 +120,28 @@ def _get_data_and_dtype(value): else: raise ValueError(f'Non supported data type: {type(value).__name__}') -def const_vector3d(x: float, y: float, z: float): - return vec3d((const(x), const(y), const(z))) class vec3d: - def __init__(self, value: tuple[float, float, float]): + def __init__(self, value: tuple[Net, Net, Net]): self.value = value - def __add__(self, other): - return vec3d(tuple(a+b for a,b in zip(self.value, other.value))) + def __add__(self, other: 'vec3d') -> 'vec3d': + a1, a2, a3 = self.value + b1, b2, b3 = other.value + return vec3d((a1 + b1, a2 + b2, a3 + b3)) -def get_multiuse_nets(root: list[Op]): - """ - Finds all nets that get accessed more than one time. Therefore +def const_vector3d(x: float, y: float, z: float) -> vec3d: + return vec3d((const(x), const(y), const(z))) + + +def get_multiuse_nets(root: list[Node]) -> set[Net]: + """Finds all nets that get accessed more than one time. Therefore storage on the heap might be better. """ known_nets: set[Net] = set() - def recursiv_node_search(net_list: list[Net]): + def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]: for net in net_list: #print(net) if net in known_nets: @@ -140,39 +153,39 @@ def get_multiuse_nets(root: list[Op]): return set(recursiv_node_search(op.args[0] for op in root)) -def get_path_segments(root: list[Op]) -> list[list[Op]]: +def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: + """List of all possible paths. Ops in order of execution (output at the end) """ - List of all possible paths. Ops in order of execution (output at the end) - """ - def recursiv_node_search(op_list: list[Op], path: list[Op]) -> list[Op]: - for op in op_list: - new_path = [op] + path - if op.args: - yield from recursiv_node_search([net.source for net in op.args], new_path) + def rekursiv_node_search(node_list: Iterable[Node], path: list[Node]) -> Generator[list[Node], None, None]: + for node in node_list: + new_path = [node] + path + if node.args: + yield from rekursiv_node_search([net.source for net in node.args], new_path) else: yield new_path - known_ops: set[Op] = set() - sorted_path_list = sorted(recursiv_node_search(root, []), key=lambda x: -len(x)) + known_nodes: set[Node] = set() + sorted_path_list = sorted(rekursiv_node_search(root, []), key=lambda x: -len(x)) for path in sorted_path_list: sflag = False for i, net in enumerate(path): - if net in known_ops or i == len(path) - 1: + if net in known_nodes or i == len(path) - 1: if sflag: if i > 0: yield path[:i+1] break else: sflag = True - known_ops.add(net) + known_nodes.add(net) -def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]: +def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, None]: + """Merge in all tree branches at branch position into the path segments + """ finished_paths: set[int] = set() for i, path in enumerate(path_segments): - #print(i) if i not in finished_paths: for op in path: for j in range(i + 1, len(path_segments)): @@ -187,44 +200,93 @@ def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]: finished_paths.add(i) -def get_consts(op_list: list[Node]): +def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]: net_lookup = {net.source: net for op in op_list for net in op.args} return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)] -def add_read_ops(op_list): - """Add read operation before each op where arguments are not allredy possitioned +def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: + """Add read operation before each op where arguments are not already positioned correctly in the registers Returns: Yields a tuples of a net and a operation. The net is the result net - from the retuned operation""" - registers = [None] * 16 + from the returned operation""" + registers: list[None | Net] = [None] * 16 - net_lookup = {net.source: net for op in op_list for net in op.args} + net_lookup = {net.source: net for node in node_list for net in node.args} - for op in op_list: - if not op.name.startswith('const_'): - for i, net in enumerate(op.args): + for node in node_list: + if not node.name.startswith('const_'): + for i, net in enumerate(node.args): if net != registers[i]: #if net in registers: # print('x swap registers') - new_op = Op('read_reg' + str(i) + '_' + net.dtype, []) - yield net, new_op + new_node = Op('read_reg' + str(i) + '_' + net.dtype, []) + yield net, new_node registers[i] = net - yield net_lookup.get(op), op - if op in net_lookup: - registers[0] = net_lookup[op] + if node in net_lookup: + yield net_lookup[node], node + registers[0] = net_lookup[node] + else: + print('--->', node) + yield None, node -def add_write_ops(op_list, const_list): - """Add write operation for each new defined net if a read operation is later folowed""" + +def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]: + """Add write operation for each new defined net if a read operation is later flowed""" stored_nets = {c[1] for c in const_list} - read_back_nets = {net for net, op in op_list if op.name.startswith('read_reg')} + read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_reg')} - for net, op in op_list: - yield net, op - if net in read_back_nets and net not in stored_nets: + for net, node in net_node_list: + yield net, node + if net and net in read_back_nets and net not in stored_nets: yield (net, Op('write_' + net.dtype, [net])) - stored_nets.add(net) \ No newline at end of file + stored_nets.add(net) + + +def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer: + + if isinstance(end_nodes, Node): + node_list = [end_nodes] + else: + node_list = end_nodes + + path_segments = list(get_path_segments(node_list)) + ordered_ops = list(get_ordered_ops(path_segments)) + const_list = get_consts(ordered_ops) + output_ops = list(add_read_ops(ordered_ops)) + extended_output_ops = list(add_write_ops(output_ops, const_list)) + + + obj_file: str = 'src/copapy/obj/test4_o0.o' + elf = pelfy.open_elf_file(obj_file) + + dw = binw.data_writer(elf.byteorder) + + prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'} + prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'} + + auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC'] + auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT'] + + + + # write data sections + object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects) + + dw.write_com(binw.Command.ALLOCATE_DATA) + dw.write_int(data_section_lengths) + + for sym, out_offs, lengths in object_list: + if sym.section and sym.section.type != 'SHT_NOBITS': + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(out_offs) + dw.write_int(lengths) + dw.write_bytes(sym.data) + + print('-----') + + return dw diff --git a/src/copapy/binwrite.py b/src/copapy/binwrite.py index 7d714e7..6ee06b3 100644 --- a/src/copapy/binwrite.py +++ b/src/copapy/binwrite.py @@ -8,17 +8,17 @@ Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)]) -def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> int: +def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> tuple[int, int]: if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'): # S + A - P value = new_sym_addr + r_addend - new_patch_addr - return RelocationType.RELOC_RELATIVE_32, value + return RelocationType.RELOC_RELATIVE_32.value, value else: raise Exception(f"Unknown: {reloc_type}") def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]: - object_list = [] + object_list: list[tuple[elf_symbol, int, int]] = [] out_offs = 0 for sym in symbols: assert sym.info == 'STT_OBJECT' @@ -29,7 +29,7 @@ def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]: - code_list = [] + code_list: list[tuple[elf_symbol, int, int, int]] = [] out_offs = 0 for sym in symbols: assert sym.info == 'STT_FUNC' @@ -95,4 +95,12 @@ class data_writer(): def to_file(self, path: str): with open(path, 'wb') as f: - f.write(self.get_data()) \ No newline at end of file + f.write(self.get_data()) + +def get_c_consts() -> str: + ret: list[str] = [] + for c in Command: + ret.append (f"#define {c.name} {c.value}") + for c in RelocationType: + ret.append(f"#define {c.name} {c.value}") + return '\n'.join(ret) \ No newline at end of file