type hints added

This commit is contained in:
Nicolas Kruse 2025-08-29 22:58:10 +02:00
parent 9143bfef5b
commit d6b388384e
2 changed files with 140 additions and 70 deletions

View File

@ -1,16 +1,25 @@
import re import re
import pkgutil import pkgutil
from typing import Generator, Iterable, Any
import pelfy
from . import binwrite as binw
def get_var_name(var, scope=globals()): def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
return [name for name, value in scope.items() if value is var] return [name for name, value in scope.items() if value is var]
def _get_c_function_definitions(code: str): def _get_c_function_definitions(code: str) -> dict[str, str]:
ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S) ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S)
return {r[0]: r[1] for r in ret} return {r[0]: r[1] for r in ret}
_function_definitions = _get_c_function_definitions(pkgutil.get_data(__name__, 'ops.c').decode('utf-8')) _ccode = pkgutil.get_data(__name__, 'ops.c')
assert _ccode is not None
_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8'))
class Node: class Node:
def __init__(self):
self.args: list[Net] = []
self.name: str = ''
def __repr__(self): def __repr__(self):
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
@ -23,28 +32,28 @@ class Net:
self.dtype = dtype self.dtype = dtype
self.source = source self.source = source
def __mul__(self, other): def __mul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True) return _add_op('mul', [self, other], True)
def __rmul__(self, other): def __rmul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True) return _add_op('mul', [self, other], True)
def __add__(self, other): def __add__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
def __radd__(self, other): def __radd__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
def __sub__ (self, other): def __sub__ (self, other: Any) -> 'Net':
return _add_op('sub', [self, other]) return _add_op('sub', [self, other])
def __rsub__ (self, other): def __rsub__ (self, other: Any) -> 'Net':
return _add_op('sub', [other, self]) return _add_op('sub', [other, self])
def __truediv__ (self, other): def __truediv__ (self, other: Any) -> 'Net':
return _add_op('div', [self, other]) return _add_op('div', [self, other])
def __rtruediv__ (self, other): def __rtruediv__ (self, other: Any) -> 'Net':
return _add_op('div', [other, self]) return _add_op('div', [other, self])
def __repr__(self): def __repr__(self):
@ -73,34 +82,35 @@ class Write(Node):
class Op(Node): class Op(Node):
def __init__(self, typed_op_name: str, args: list[Net]): def __init__(self, typed_op_name: str, args: list[Net]):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name = typed_op_name self.name: str = typed_op_name
self.args = args self.args: list[Net] = args
def _add_op(op: str, args: list[Net], commutative = False): def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
args = [a if isinstance(a, Net) else const(a) for a in args] arg_nets = [a if isinstance(a, Net) else const(a) for a in args]
if commutative: if commutative:
args = sorted(args, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
typed_op = '_'.join([op] + [a.dtype for a in args]) typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in _function_definitions: if typed_op not in _function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in args])}") raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
result_type = _function_definitions[typed_op] result_type = _function_definitions[typed_op]
result_net = Net(result_type, Op(typed_op, args)) result_net = Net(result_type, Op(typed_op, arg_nets))
return result_net return result_net
#def read_input(hw: Device, test_value: float): #def read_input(hw: Device, test_value: float):
# return Net(type(value)) # return Net(type(value))
def const(value: float | int | bool): def const(value: Any) -> Net:
assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}'
new_const = Const(value) new_const = Const(value)
return Net(new_const.dtype, new_const) return Net(new_const.dtype, new_const)
def _get_data_and_dtype(value): def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, int): if isinstance(value, int):
return ('int', int(value)) return ('int', int(value))
elif isinstance(value, float): elif isinstance(value, float):
@ -110,25 +120,28 @@ def _get_data_and_dtype(value):
else: else:
raise ValueError(f'Non supported data type: {type(value).__name__}') raise ValueError(f'Non supported data type: {type(value).__name__}')
def const_vector3d(x: float, y: float, z: float):
return vec3d((const(x), const(y), const(z)))
class vec3d: class vec3d:
def __init__(self, value: tuple[float, float, float]): def __init__(self, value: tuple[Net, Net, Net]):
self.value = value self.value = value
def __add__(self, other): def __add__(self, other: 'vec3d') -> 'vec3d':
return vec3d(tuple(a+b for a,b in zip(self.value, other.value))) a1, a2, a3 = self.value
b1, b2, b3 = other.value
return vec3d((a1 + b1, a2 + b2, a3 + b3))
def get_multiuse_nets(root: list[Op]): def const_vector3d(x: float, y: float, z: float) -> vec3d:
""" return vec3d((const(x), const(y), const(z)))
Finds all nets that get accessed more than one time. Therefore
def get_multiuse_nets(root: list[Node]) -> set[Net]:
"""Finds all nets that get accessed more than one time. Therefore
storage on the heap might be better. storage on the heap might be better.
""" """
known_nets: set[Net] = set() known_nets: set[Net] = set()
def recursiv_node_search(net_list: list[Net]): def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list: for net in net_list:
#print(net) #print(net)
if net in known_nets: if net in known_nets:
@ -140,39 +153,39 @@ def get_multiuse_nets(root: list[Op]):
return set(recursiv_node_search(op.args[0] for op in root)) return set(recursiv_node_search(op.args[0] for op in root))
def get_path_segments(root: list[Op]) -> list[list[Op]]: def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
"""List of all possible paths. Ops in order of execution (output at the end)
""" """
List of all possible paths. Ops in order of execution (output at the end) def rekursiv_node_search(node_list: Iterable[Node], path: list[Node]) -> Generator[list[Node], None, None]:
""" for node in node_list:
def recursiv_node_search(op_list: list[Op], path: list[Op]) -> list[Op]: new_path = [node] + path
for op in op_list: if node.args:
new_path = [op] + path yield from rekursiv_node_search([net.source for net in node.args], new_path)
if op.args:
yield from recursiv_node_search([net.source for net in op.args], new_path)
else: else:
yield new_path yield new_path
known_ops: set[Op] = set() known_nodes: set[Node] = set()
sorted_path_list = sorted(recursiv_node_search(root, []), key=lambda x: -len(x)) sorted_path_list = sorted(rekursiv_node_search(root, []), key=lambda x: -len(x))
for path in sorted_path_list: for path in sorted_path_list:
sflag = False sflag = False
for i, net in enumerate(path): for i, net in enumerate(path):
if net in known_ops or i == len(path) - 1: if net in known_nodes or i == len(path) - 1:
if sflag: if sflag:
if i > 0: if i > 0:
yield path[:i+1] yield path[:i+1]
break break
else: else:
sflag = True sflag = True
known_ops.add(net) known_nodes.add(net)
def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]: def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, None]:
"""Merge in all tree branches at branch position into the path segments
"""
finished_paths: set[int] = set() finished_paths: set[int] = set()
for i, path in enumerate(path_segments): for i, path in enumerate(path_segments):
#print(i)
if i not in finished_paths: if i not in finished_paths:
for op in path: for op in path:
for j in range(i + 1, len(path_segments)): for j in range(i + 1, len(path_segments)):
@ -187,44 +200,93 @@ def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]:
finished_paths.add(i) finished_paths.add(i)
def get_consts(op_list: list[Node]): def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
net_lookup = {net.source: net for op in op_list for net in op.args} net_lookup = {net.source: net for op in op_list for net in op.args}
return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)] return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
def add_read_ops(op_list): def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add read operation before each op where arguments are not allredy possitioned """Add read operation before each op where arguments are not already positioned
correctly in the registers correctly in the registers
Returns: Returns:
Yields a tuples of a net and a operation. The net is the result net Yields a tuples of a net and a operation. The net is the result net
from the retuned operation""" from the returned operation"""
registers = [None] * 16 registers: list[None | Net] = [None] * 16
net_lookup = {net.source: net for op in op_list for net in op.args} net_lookup = {net.source: net for node in node_list for net in node.args}
for op in op_list: for node in node_list:
if not op.name.startswith('const_'): if not node.name.startswith('const_'):
for i, net in enumerate(op.args): for i, net in enumerate(node.args):
if net != registers[i]: if net != registers[i]:
#if net in registers: #if net in registers:
# print('x swap registers') # print('x swap registers')
new_op = Op('read_reg' + str(i) + '_' + net.dtype, []) new_node = Op('read_reg' + str(i) + '_' + net.dtype, [])
yield net, new_op yield net, new_node
registers[i] = net registers[i] = net
yield net_lookup.get(op), op if node in net_lookup:
if op in net_lookup: yield net_lookup[node], node
registers[0] = net_lookup[op] registers[0] = net_lookup[node]
else:
print('--->', node)
yield None, node
def add_write_ops(op_list, const_list):
"""Add write operation for each new defined net if a read operation is later folowed""" def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add write operation for each new defined net if a read operation is later flowed"""
stored_nets = {c[1] for c in const_list} stored_nets = {c[1] for c in const_list}
read_back_nets = {net for net, op in op_list if op.name.startswith('read_reg')} read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_reg')}
for net, op in op_list: for net, node in net_node_list:
yield net, op yield net, node
if net in read_back_nets and net not in stored_nets: if net and net in read_back_nets and net not in stored_nets:
yield (net, Op('write_' + net.dtype, [net])) yield (net, Op('write_' + net.dtype, [net]))
stored_nets.add(net) stored_nets.add(net)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
if isinstance(end_nodes, Node):
node_list = [end_nodes]
else:
node_list = end_nodes
path_segments = list(get_path_segments(node_list))
ordered_ops = list(get_ordered_ops(path_segments))
const_list = get_consts(ordered_ops)
output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_list))
obj_file: str = 'src/copapy/obj/test4_o0.o'
elf = pelfy.open_elf_file(obj_file)
dw = binw.data_writer(elf.byteorder)
prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
# write data sections
object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths)
for sym, out_offs, lengths in object_list:
if sym.section and sym.section.type != 'SHT_NOBITS':
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(lengths)
dw.write_bytes(sym.data)
print('-----')
return dw

View File

@ -8,17 +8,17 @@ Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)]) RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> int: def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> tuple[int, int]:
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'): if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
# S + A - P # S + A - P
value = new_sym_addr + r_addend - new_patch_addr value = new_sym_addr + r_addend - new_patch_addr
return RelocationType.RELOC_RELATIVE_32, value return RelocationType.RELOC_RELATIVE_32.value, value
else: else:
raise Exception(f"Unknown: {reloc_type}") raise Exception(f"Unknown: {reloc_type}")
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]: def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
object_list = [] object_list: list[tuple[elf_symbol, int, int]] = []
out_offs = 0 out_offs = 0
for sym in symbols: for sym in symbols:
assert sym.info == 'STT_OBJECT' assert sym.info == 'STT_OBJECT'
@ -29,7 +29,7 @@ def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol,
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]: def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list = [] code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0 out_offs = 0
for sym in symbols: for sym in symbols:
assert sym.info == 'STT_FUNC' assert sym.info == 'STT_FUNC'
@ -95,4 +95,12 @@ class data_writer():
def to_file(self, path: str): def to_file(self, path: str):
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.write(self.get_data()) f.write(self.get_data())
def get_c_consts() -> str:
ret: list[str] = []
for c in Command:
ret.append (f"#define {c.name} {c.value}")
for c in RelocationType:
ret.append(f"#define {c.name} {c.value}")
return '\n'.join(ret)