mirror of https://github.com/Nonannet/copapy.git
type hints added
This commit is contained in:
parent
9143bfef5b
commit
d6b388384e
|
|
@ -1,16 +1,25 @@
|
||||||
import re
|
import re
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
from typing import Generator, Iterable, Any
|
||||||
|
import pelfy
|
||||||
|
from . import binwrite as binw
|
||||||
|
|
||||||
def get_var_name(var, scope=globals()):
|
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||||
return [name for name, value in scope.items() if value is var]
|
return [name for name, value in scope.items() if value is var]
|
||||||
|
|
||||||
def _get_c_function_definitions(code: str):
|
def _get_c_function_definitions(code: str) -> dict[str, str]:
|
||||||
ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S)
|
ret = re.findall(r".*?void\s+([a-z_1-9]*)\s*\([^\)]*?\)[^\}]*?\{[^\}]*?result_([a-z_]*)\(.*?", code, flags=re.S)
|
||||||
return {r[0]: r[1] for r in ret}
|
return {r[0]: r[1] for r in ret}
|
||||||
|
|
||||||
_function_definitions = _get_c_function_definitions(pkgutil.get_data(__name__, 'ops.c').decode('utf-8'))
|
_ccode = pkgutil.get_data(__name__, 'ops.c')
|
||||||
|
assert _ccode is not None
|
||||||
|
_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8'))
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
|
def __init__(self):
|
||||||
|
self.args: list[Net] = []
|
||||||
|
self.name: str = ''
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
|
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
|
||||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
|
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
|
||||||
|
|
@ -23,28 +32,28 @@ class Net:
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.source = source
|
self.source = source
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other: Any) -> 'Net':
|
||||||
return _add_op('mul', [self, other], True)
|
return _add_op('mul', [self, other], True)
|
||||||
|
|
||||||
def __rmul__(self, other):
|
def __rmul__(self, other: Any) -> 'Net':
|
||||||
return _add_op('mul', [self, other], True)
|
return _add_op('mul', [self, other], True)
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other: Any) -> 'Net':
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
def __radd__(self, other):
|
def __radd__(self, other: Any) -> 'Net':
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
def __sub__ (self, other):
|
def __sub__ (self, other: Any) -> 'Net':
|
||||||
return _add_op('sub', [self, other])
|
return _add_op('sub', [self, other])
|
||||||
|
|
||||||
def __rsub__ (self, other):
|
def __rsub__ (self, other: Any) -> 'Net':
|
||||||
return _add_op('sub', [other, self])
|
return _add_op('sub', [other, self])
|
||||||
|
|
||||||
def __truediv__ (self, other):
|
def __truediv__ (self, other: Any) -> 'Net':
|
||||||
return _add_op('div', [self, other])
|
return _add_op('div', [self, other])
|
||||||
|
|
||||||
def __rtruediv__ (self, other):
|
def __rtruediv__ (self, other: Any) -> 'Net':
|
||||||
return _add_op('div', [other, self])
|
return _add_op('div', [other, self])
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|
@ -73,34 +82,35 @@ class Write(Node):
|
||||||
class Op(Node):
|
class Op(Node):
|
||||||
def __init__(self, typed_op_name: str, args: list[Net]):
|
def __init__(self, typed_op_name: str, args: list[Net]):
|
||||||
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
||||||
self.name = typed_op_name
|
self.name: str = typed_op_name
|
||||||
self.args = args
|
self.args: list[Net] = args
|
||||||
|
|
||||||
def _add_op(op: str, args: list[Net], commutative = False):
|
def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
||||||
args = [a if isinstance(a, Net) else const(a) for a in args]
|
arg_nets = [a if isinstance(a, Net) else const(a) for a in args]
|
||||||
|
|
||||||
if commutative:
|
if commutative:
|
||||||
args = sorted(args, key=lambda a: a.dtype)
|
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [a.dtype for a in args])
|
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
||||||
|
|
||||||
if typed_op not in _function_definitions:
|
if typed_op not in _function_definitions:
|
||||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in args])}")
|
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||||
|
|
||||||
result_type = _function_definitions[typed_op]
|
result_type = _function_definitions[typed_op]
|
||||||
|
|
||||||
result_net = Net(result_type, Op(typed_op, args))
|
result_net = Net(result_type, Op(typed_op, arg_nets))
|
||||||
|
|
||||||
return result_net
|
return result_net
|
||||||
|
|
||||||
#def read_input(hw: Device, test_value: float):
|
#def read_input(hw: Device, test_value: float):
|
||||||
# return Net(type(value))
|
# return Net(type(value))
|
||||||
|
|
||||||
def const(value: float | int | bool):
|
def const(value: Any) -> Net:
|
||||||
|
assert isinstance(value, (int, float, bool)), f'Unsupported type for const: {type(value).__name__}'
|
||||||
new_const = Const(value)
|
new_const = Const(value)
|
||||||
return Net(new_const.dtype, new_const)
|
return Net(new_const.dtype, new_const)
|
||||||
|
|
||||||
def _get_data_and_dtype(value):
|
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return ('int', int(value))
|
return ('int', int(value))
|
||||||
elif isinstance(value, float):
|
elif isinstance(value, float):
|
||||||
|
|
@ -110,25 +120,28 @@ def _get_data_and_dtype(value):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
||||||
|
|
||||||
def const_vector3d(x: float, y: float, z: float):
|
|
||||||
return vec3d((const(x), const(y), const(z)))
|
|
||||||
|
|
||||||
class vec3d:
|
class vec3d:
|
||||||
def __init__(self, value: tuple[float, float, float]):
|
def __init__(self, value: tuple[Net, Net, Net]):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other: 'vec3d') -> 'vec3d':
|
||||||
return vec3d(tuple(a+b for a,b in zip(self.value, other.value)))
|
a1, a2, a3 = self.value
|
||||||
|
b1, b2, b3 = other.value
|
||||||
|
return vec3d((a1 + b1, a2 + b2, a3 + b3))
|
||||||
|
|
||||||
|
|
||||||
def get_multiuse_nets(root: list[Op]):
|
def const_vector3d(x: float, y: float, z: float) -> vec3d:
|
||||||
"""
|
return vec3d((const(x), const(y), const(z)))
|
||||||
Finds all nets that get accessed more than one time. Therefore
|
|
||||||
|
|
||||||
|
def get_multiuse_nets(root: list[Node]) -> set[Net]:
|
||||||
|
"""Finds all nets that get accessed more than one time. Therefore
|
||||||
storage on the heap might be better.
|
storage on the heap might be better.
|
||||||
"""
|
"""
|
||||||
known_nets: set[Net] = set()
|
known_nets: set[Net] = set()
|
||||||
|
|
||||||
def recursiv_node_search(net_list: list[Net]):
|
def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
|
||||||
for net in net_list:
|
for net in net_list:
|
||||||
#print(net)
|
#print(net)
|
||||||
if net in known_nets:
|
if net in known_nets:
|
||||||
|
|
@ -140,39 +153,39 @@ def get_multiuse_nets(root: list[Op]):
|
||||||
return set(recursiv_node_search(op.args[0] for op in root))
|
return set(recursiv_node_search(op.args[0] for op in root))
|
||||||
|
|
||||||
|
|
||||||
def get_path_segments(root: list[Op]) -> list[list[Op]]:
|
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
|
||||||
|
"""List of all possible paths. Ops in order of execution (output at the end)
|
||||||
"""
|
"""
|
||||||
List of all possible paths. Ops in order of execution (output at the end)
|
def rekursiv_node_search(node_list: Iterable[Node], path: list[Node]) -> Generator[list[Node], None, None]:
|
||||||
"""
|
for node in node_list:
|
||||||
def recursiv_node_search(op_list: list[Op], path: list[Op]) -> list[Op]:
|
new_path = [node] + path
|
||||||
for op in op_list:
|
if node.args:
|
||||||
new_path = [op] + path
|
yield from rekursiv_node_search([net.source for net in node.args], new_path)
|
||||||
if op.args:
|
|
||||||
yield from recursiv_node_search([net.source for net in op.args], new_path)
|
|
||||||
else:
|
else:
|
||||||
yield new_path
|
yield new_path
|
||||||
|
|
||||||
known_ops: set[Op] = set()
|
known_nodes: set[Node] = set()
|
||||||
sorted_path_list = sorted(recursiv_node_search(root, []), key=lambda x: -len(x))
|
sorted_path_list = sorted(rekursiv_node_search(root, []), key=lambda x: -len(x))
|
||||||
|
|
||||||
for path in sorted_path_list:
|
for path in sorted_path_list:
|
||||||
sflag = False
|
sflag = False
|
||||||
for i, net in enumerate(path):
|
for i, net in enumerate(path):
|
||||||
if net in known_ops or i == len(path) - 1:
|
if net in known_nodes or i == len(path) - 1:
|
||||||
if sflag:
|
if sflag:
|
||||||
if i > 0:
|
if i > 0:
|
||||||
yield path[:i+1]
|
yield path[:i+1]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
sflag = True
|
sflag = True
|
||||||
known_ops.add(net)
|
known_nodes.add(net)
|
||||||
|
|
||||||
|
|
||||||
def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]:
|
def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, None]:
|
||||||
|
"""Merge in all tree branches at branch position into the path segments
|
||||||
|
"""
|
||||||
finished_paths: set[int] = set()
|
finished_paths: set[int] = set()
|
||||||
|
|
||||||
for i, path in enumerate(path_segments):
|
for i, path in enumerate(path_segments):
|
||||||
#print(i)
|
|
||||||
if i not in finished_paths:
|
if i not in finished_paths:
|
||||||
for op in path:
|
for op in path:
|
||||||
for j in range(i + 1, len(path_segments)):
|
for j in range(i + 1, len(path_segments)):
|
||||||
|
|
@ -187,44 +200,93 @@ def get_ordered_ops(path_segments: list[list[Op]]) -> list[Op]:
|
||||||
finished_paths.add(i)
|
finished_paths.add(i)
|
||||||
|
|
||||||
|
|
||||||
def get_consts(op_list: list[Node]):
|
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
|
||||||
net_lookup = {net.source: net for op in op_list for net in op.args}
|
net_lookup = {net.source: net for op in op_list for net in op.args}
|
||||||
return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
|
return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
|
||||||
|
|
||||||
|
|
||||||
def add_read_ops(op_list):
|
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
|
||||||
"""Add read operation before each op where arguments are not allredy possitioned
|
"""Add read operation before each op where arguments are not already positioned
|
||||||
correctly in the registers
|
correctly in the registers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Yields a tuples of a net and a operation. The net is the result net
|
Yields a tuples of a net and a operation. The net is the result net
|
||||||
from the retuned operation"""
|
from the returned operation"""
|
||||||
registers = [None] * 16
|
registers: list[None | Net] = [None] * 16
|
||||||
|
|
||||||
net_lookup = {net.source: net for op in op_list for net in op.args}
|
net_lookup = {net.source: net for node in node_list for net in node.args}
|
||||||
|
|
||||||
for op in op_list:
|
for node in node_list:
|
||||||
if not op.name.startswith('const_'):
|
if not node.name.startswith('const_'):
|
||||||
for i, net in enumerate(op.args):
|
for i, net in enumerate(node.args):
|
||||||
if net != registers[i]:
|
if net != registers[i]:
|
||||||
#if net in registers:
|
#if net in registers:
|
||||||
# print('x swap registers')
|
# print('x swap registers')
|
||||||
new_op = Op('read_reg' + str(i) + '_' + net.dtype, [])
|
new_node = Op('read_reg' + str(i) + '_' + net.dtype, [])
|
||||||
yield net, new_op
|
yield net, new_node
|
||||||
registers[i] = net
|
registers[i] = net
|
||||||
|
|
||||||
yield net_lookup.get(op), op
|
if node in net_lookup:
|
||||||
if op in net_lookup:
|
yield net_lookup[node], node
|
||||||
registers[0] = net_lookup[op]
|
registers[0] = net_lookup[node]
|
||||||
|
else:
|
||||||
|
print('--->', node)
|
||||||
|
yield None, node
|
||||||
|
|
||||||
|
|
||||||
def add_write_ops(op_list, const_list):
|
|
||||||
"""Add write operation for each new defined net if a read operation is later folowed"""
|
def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
|
||||||
|
"""Add write operation for each new defined net if a read operation is later flowed"""
|
||||||
stored_nets = {c[1] for c in const_list}
|
stored_nets = {c[1] for c in const_list}
|
||||||
read_back_nets = {net for net, op in op_list if op.name.startswith('read_reg')}
|
read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_reg')}
|
||||||
|
|
||||||
for net, op in op_list:
|
for net, node in net_node_list:
|
||||||
yield net, op
|
yield net, node
|
||||||
if net in read_back_nets and net not in stored_nets:
|
if net and net in read_back_nets and net not in stored_nets:
|
||||||
yield (net, Op('write_' + net.dtype, [net]))
|
yield (net, Op('write_' + net.dtype, [net]))
|
||||||
stored_nets.add(net)
|
stored_nets.add(net)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
|
||||||
|
|
||||||
|
if isinstance(end_nodes, Node):
|
||||||
|
node_list = [end_nodes]
|
||||||
|
else:
|
||||||
|
node_list = end_nodes
|
||||||
|
|
||||||
|
path_segments = list(get_path_segments(node_list))
|
||||||
|
ordered_ops = list(get_ordered_ops(path_segments))
|
||||||
|
const_list = get_consts(ordered_ops)
|
||||||
|
output_ops = list(add_read_ops(ordered_ops))
|
||||||
|
extended_output_ops = list(add_write_ops(output_ops, const_list))
|
||||||
|
|
||||||
|
|
||||||
|
obj_file: str = 'src/copapy/obj/test4_o0.o'
|
||||||
|
elf = pelfy.open_elf_file(obj_file)
|
||||||
|
|
||||||
|
dw = binw.data_writer(elf.byteorder)
|
||||||
|
|
||||||
|
prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
|
||||||
|
prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
|
||||||
|
|
||||||
|
auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
|
||||||
|
auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# write data sections
|
||||||
|
object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
|
||||||
|
|
||||||
|
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||||
|
dw.write_int(data_section_lengths)
|
||||||
|
|
||||||
|
for sym, out_offs, lengths in object_list:
|
||||||
|
if sym.section and sym.section.type != 'SHT_NOBITS':
|
||||||
|
dw.write_com(binw.Command.COPY_DATA)
|
||||||
|
dw.write_int(out_offs)
|
||||||
|
dw.write_int(lengths)
|
||||||
|
dw.write_bytes(sym.data)
|
||||||
|
|
||||||
|
print('-----')
|
||||||
|
|
||||||
|
return dw
|
||||||
|
|
|
||||||
|
|
@ -8,17 +8,17 @@ Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
|
||||||
|
|
||||||
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
|
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
|
||||||
|
|
||||||
def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> int:
|
def translate_relocation(new_sym_addr: int, new_patch_addr: int, reloc_type: str, r_addend: int) -> tuple[int, int]:
|
||||||
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
|
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
|
||||||
# S + A - P
|
# S + A - P
|
||||||
value = new_sym_addr + r_addend - new_patch_addr
|
value = new_sym_addr + r_addend - new_patch_addr
|
||||||
return RelocationType.RELOC_RELATIVE_32, value
|
return RelocationType.RELOC_RELATIVE_32.value, value
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown: {reloc_type}")
|
raise Exception(f"Unknown: {reloc_type}")
|
||||||
|
|
||||||
|
|
||||||
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
|
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
|
||||||
object_list = []
|
object_list: list[tuple[elf_symbol, int, int]] = []
|
||||||
out_offs = 0
|
out_offs = 0
|
||||||
for sym in symbols:
|
for sym in symbols:
|
||||||
assert sym.info == 'STT_OBJECT'
|
assert sym.info == 'STT_OBJECT'
|
||||||
|
|
@ -29,7 +29,7 @@ def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol,
|
||||||
|
|
||||||
|
|
||||||
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
|
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
|
||||||
code_list = []
|
code_list: list[tuple[elf_symbol, int, int, int]] = []
|
||||||
out_offs = 0
|
out_offs = 0
|
||||||
for sym in symbols:
|
for sym in symbols:
|
||||||
assert sym.info == 'STT_FUNC'
|
assert sym.info == 'STT_FUNC'
|
||||||
|
|
@ -95,4 +95,12 @@ class data_writer():
|
||||||
|
|
||||||
def to_file(self, path: str):
|
def to_file(self, path: str):
|
||||||
with open(path, 'wb') as f:
|
with open(path, 'wb') as f:
|
||||||
f.write(self.get_data())
|
f.write(self.get_data())
|
||||||
|
|
||||||
|
def get_c_consts() -> str:
|
||||||
|
ret: list[str] = []
|
||||||
|
for c in Command:
|
||||||
|
ret.append (f"#define {c.name} {c.value}")
|
||||||
|
for c in RelocationType:
|
||||||
|
ret.append(f"#define {c.name} {c.value}")
|
||||||
|
return '\n'.join(ret)
|
||||||
Loading…
Reference in New Issue