type hints added

This commit is contained in:
Nicolas Kruse 2025-08-29 22:58:10 +02:00
parent 9143bfef5b
commit d6b388384e
2 changed files with 140 additions and 70 deletions

View File

@ -1,16 +1,25 @@
import re
import pkgutil
from typing import Generator, Iterable, Any
import pelfy
from . import binwrite as binw
def get_var_name(var, scope=globals()):
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
return [name for name, value in scope.items() if value is var]
def _get_c_function_definitions(code: str):
def _get_c_function_definitions(code: str) -> dict[str, str]:
ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S)
return {r[0]: r[1] for r in ret}
_function_definitions = _get_c_function_definitions(pkgutil.get_data(__name__, 'ops.c').decode('utf-8'))
_ccode = pkgutil.get_data(__name__, 'ops.c')
assert _ccode is not None
_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8'))
class Node:
def __init__(self):
self.args: list[Net] = []
self.name: str = ''
def __repr__(self):
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
@ -23,28 +32,28 @@ class Net:
self.dtype = dtype
self.source = source
def __mul__(self, other):
def __mul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True)
def __rmul__(self, other):
def __rmul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True)
def __add__(self, other):
def __add__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True)
def __radd__(self, other):
def __radd__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True)
def __sub__ (self, other):
def __sub__ (self, other: Any) -> 'Net':
return _add_op('sub', [self, other])
def __rsub__ (self, other):
def __rsub__ (self, other: Any) -> 'Net':
return _add_op('sub', [other, self])
def __truediv__ (self, other):
def __truediv__ (self, other: Any) -> 'Net':
return _add_op('div', [self, other])
def __rtruediv__ (self, other):
def __rtruediv__ (self, other: Any) -> 'Net':
return _add_op('div', [other, self])
def __repr__(self):
@ -73,34 +82,35 @@ class Write(Node):
class Op(Node):
def __init__(self, typed_op_name: str, args: list[Net]):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name = typed_op_name
self.args = args
self.name: str = typed_op_name
self.args: list[Net] = args
def _add_op(op: str, args: list[Net], commutative = False):
args = [a if isinstance(a, Net) else const(a) for a in args]
def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
arg_nets = [a if isinstance(a, Net) else const(a) for a in args]
if commutative:
args = sorted(args, key=lambda a: a.dtype)
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
typed_op = '_'.join([op] + [a.dtype for a in args])
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in _function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in args])}")
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
result_type = _function_definitions[typed_op]
result_net = Net(result_type, Op(typed_op, args))
result_net = Net(result_type, Op(typed_op, arg_nets))
return result_net
#def read_input(hw: Device, test_value: float):
# return Net(type(value))
def const(value: float | int | bool):
def const(value: Any) -> Net:
assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}'
new_const = Const(value)
return Net(new_const.dtype, new_const)
def _get_data_and_dtype(value):
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, int):
return ('int', int(value))
elif isinstance(value, float):
@ -110,25 +120,28 @@ def _get_data_and_dtype(value):
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')
def const_vector3d(x: float, y: float, z: float):
return vec3d((const(x), const(y), const(z)))
class vec3d:
def __init__(self, value: tuple[float, float, float]):
def __init__(self, value: tuple[Net, Net, Net]):
self.value = value
def __add__(self, other):
return vec3d(tuple(a+b for a,b in zip(self.value, other.value)))
def __add__(self, other: 'vec3d') -> 'vec3d':
a1, a2, a3 = self.value
b1, b2, b3 = other.value
return vec3d((a1 + b1, a2 + b2, a3 + b3))
def get_multiuse_nets(root: list[Op]):
"""
Finds all nets that get accessed more than one time. Therefore
def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((const(x), const(y), const(z)))
def get_multiuse_nets(root: list[Node]) -> set[Net]:
"""Finds all nets that get accessed more than one time. Therefore
storage on the heap might be better.
"""
known_nets: set[Net] = set()
def recursiv_node_search(net_list: list[Net]):
def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list:
#print(net)
if net in known_nets:
@ -140,39 +153,39 @@ def get_multiuse_nets(root: list[Op]):
return set(recursiv_node_search(op.args[0] for op in root))
def get_path_segments(root: list[Op]) -> list[list[Op]]:
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
"""List of all possible paths. Ops in order of execution (output at the end)
"""
List of all possible paths. Ops in order of execution (output at the end)
"""
def recursiv_node_search(op_list: list[Op], path: list[Op]) -> list[Op]:
for op in op_list:
new_path = [op] + path
if op.args:
yield from recursiv_node_search([net.source for net in op.args], new_path)
def rekursiv_node_search(node_list: Iterable[Node], path: list[Node]) -> Generator[list[Node], None, None]:
for node in node_list:
new_path = [node] + path
if node.args:
yield from rekursiv_node_search([net.source for net in node.args], new_path)
else:
yield new_path
known_ops: set[Op] = set()
sorted_path_list = sorted(recursiv_node_search(root, []), key=lambda x: -len(x))
known_nodes: set[Node] = set()
sorted_path_list = sorted(rekursiv_node_search(root, []), key=lambda x: -len(x))
for path in sorted_path_list:
sflag = False
for i, net in enumerate(path):
if net in known_ops or i == len(path) - 1:
if net in known_nodes or i == len(path) - 1:
if sflag:
if i > 0:
yield path[:i+1]
break
else:
sflag = True
known_ops.add(net)
known_nodes.add(net)
def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]:
def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, None]:
"""Merge in all tree branches at branch position into the path segments
"""
finished_paths: set[int] = set()
for i, path in enumerate(path_segments):
#print(i)
if i not in finished_paths:
for op in path:
for j in range(i + 1, len(path_segments)):
@ -187,44 +200,93 @@ def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]:
finished_paths.add(i)
def get_consts(op_list: list[Node]):
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
net_lookup = {net.source: net for op in op_list for net in op.args}
return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
def add_read_ops(op_list):
"""Add read operation before each op where arguments are not allredy possitioned
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add read operation before each op where arguments are not already positioned
correctly in the registers
Returns:
Yields a tuples of a net and a operation. The net is the result net
from the retuned operation"""
registers = [None] * 16
from the returned operation"""
registers: list[None | Net] = [None] * 16
net_lookup = {net.source: net for op in op_list for net in op.args}
net_lookup = {net.source: net for node in node_list for net in node.args}
for op in op_list:
if not op.name.startswith('const_'):
for i, net in enumerate(op.args):
for node in node_list:
if not node.name.startswith('const_'):
for i, net in enumerate(node.args):
if net != registers[i]:
#if net in registers:
# print('x swap registers')
new_op = Op('read_reg' + str(i) + '_' + net.dtype, [])
yield net, new_op
new_node = Op('read_reg' + str(i) + '_' + net.dtype, [])
yield net, new_node
registers[i] = net
yield net_lookup.get(op), op
if op in net_lookup:
registers[0] = net_lookup[op]
if node in net_lookup:
yield net_lookup[node], node
registers[0] = net_lookup[node]
else:
print('--->', node)
yield None, node
def add_write_ops(op_list, const_list):
"""Add write operation for each new defined net if a read operation is later folowed"""
def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add write operation for each new defined net if a read operation is later flowed"""
stored_nets = {c[1] for c in const_list}
read_back_nets = {net for net, op in op_list if op.name.startswith('read_reg')}
read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_reg')}
for net, op in op_list:
yield net, op
if net in read_back_nets and net not in stored_nets:
for net, node in net_node_list:
yield net, node
if net and net in read_back_nets and net not in stored_nets:
yield (net, Op('write_' + net.dtype, [net]))
stored_nets.add(net)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
if isinstance(end_nodes, Node):
node_list = [end_nodes]
else:
node_list = end_nodes
path_segments = list(get_path_segments(node_list))
ordered_ops = list(get_ordered_ops(path_segments))
const_list = get_consts(ordered_ops)
output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_list))
obj_file: str = 'src/copapy/obj/test4_o0.o'
elf = pelfy.open_elf_file(obj_file)
dw = binw.data_writer(elf.byteorder)
prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
# write data sections
object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths)
for sym, out_offs, lengths in object_list:
if sym.section and sym.section.type != 'SHT_NOBITS':
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(lengths)
dw.write_bytes(sym.data)
print('-----')
return dw

View File

@ -8,17 +8,17 @@ Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> int:
def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> tuple[int, int]:
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
# S + A - P
value = new_sym_addr + r_addend - new_patch_addr
return RelocationType.RELOC_RELATIVE_32, value
return RelocationType.RELOC_RELATIVE_32.value, value
else:
raise Exception(f"Unknown: {reloc_type}")
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
object_list = []
object_list: list[tuple[elf_symbol, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_OBJECT'
@ -29,7 +29,7 @@ def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol,
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list = []
code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_FUNC'
@ -96,3 +96,11 @@ class data_writer():
def to_file(self, path: str):
with open(path, 'wb') as f:
f.write(self.get_data())
def get_c_consts() -> str:
ret: list[str] = []
for c in Command:
ret.append (f"#define {c.name} {c.value}")
for c in RelocationType:
ret.append(f"#define {c.name} {c.value}")
return '\n'.join(ret)