mirror of https://github.com/Nonannet/copapy.git
code cleaned up, passes flake8 and mypy
This commit is contained in:
parent
3b7fcdb88b
commit
5b7ca52b7c
1
.flake8
1
.flake8
|
|
@ -13,6 +13,7 @@ exclude =
|
||||||
build,
|
build,
|
||||||
dist,
|
dist,
|
||||||
.conda,
|
.conda,
|
||||||
|
.venv,
|
||||||
tests/autogenerated_*
|
tests/autogenerated_*
|
||||||
|
|
||||||
# Enable specific plugins or options
|
# Enable specific plugins or options
|
||||||
|
|
|
||||||
|
|
@ -3,22 +3,25 @@ from typing import Generator, Iterable, Any
|
||||||
from . import binwrite as binw
|
from . import binwrite as binw
|
||||||
from .stencil_db import stencil_database
|
from .stencil_db import stencil_database
|
||||||
|
|
||||||
Operand = type['Net'] | float | int
|
Operand = type['Net'] | float | int
|
||||||
|
|
||||||
|
|
||||||
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||||
return [name for name, value in scope.items() if value is var]
|
return [name for name, value in scope.items() if value is var]
|
||||||
|
|
||||||
# _ccode = pkgutil.get_data(__name__, 'stencils.c')
|
# _ccode = pkgutil.get_data(__name__, 'stencils.c')
|
||||||
# assert _ccode is not None
|
# assert _ccode is not None
|
||||||
|
|
||||||
|
|
||||||
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
|
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.args: list[Net] = []
|
self.args: list[Net] = []
|
||||||
self.name: str = ''
|
self.name: str = ''
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
|
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
|
||||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
|
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
|
||||||
|
|
||||||
|
|
@ -37,26 +40,26 @@ class Net:
|
||||||
|
|
||||||
def __rmul__(self, other: Any) -> 'Net':
|
def __rmul__(self, other: Any) -> 'Net':
|
||||||
return _add_op('mul', [self, other], True)
|
return _add_op('mul', [self, other], True)
|
||||||
|
|
||||||
def __add__(self, other: Any) -> 'Net':
|
def __add__(self, other: Any) -> 'Net':
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
def __radd__(self, other: Any) -> 'Net':
|
def __radd__(self, other: Any) -> 'Net':
|
||||||
return _add_op('add', [self, other], True)
|
return _add_op('add', [self, other], True)
|
||||||
|
|
||||||
def __sub__ (self, other: Any) -> 'Net':
|
def __sub__(self, other: Any) -> 'Net':
|
||||||
return _add_op('sub', [self, other])
|
return _add_op('sub', [self, other])
|
||||||
|
|
||||||
def __rsub__ (self, other: Any) -> 'Net':
|
def __rsub__(self, other: Any) -> 'Net':
|
||||||
return _add_op('sub', [other, self])
|
return _add_op('sub', [other, self])
|
||||||
|
|
||||||
def __truediv__ (self, other: Any) -> 'Net':
|
def __truediv__(self, other: Any) -> 'Net':
|
||||||
return _add_op('div', [self, other])
|
return _add_op('div', [self, other])
|
||||||
|
|
||||||
def __rtruediv__ (self, other: Any) -> 'Net':
|
def __rtruediv__(self, other: Any) -> 'Net':
|
||||||
return _add_op('div', [other, self])
|
return _add_op('div', [other, self])
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
names = get_var_name(self)
|
names = get_var_name(self)
|
||||||
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
|
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
|
||||||
|
|
||||||
|
|
@ -94,7 +97,7 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
||||||
if commutative:
|
if commutative:
|
||||||
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
||||||
|
|
||||||
if typed_op not in sdb.function_definitions:
|
if typed_op not in sdb.function_definitions:
|
||||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||||
|
|
@ -141,24 +144,6 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
|
||||||
return vec3d((const(x), const(y), const(z)))
|
return vec3d((const(x), const(y), const(z)))
|
||||||
|
|
||||||
|
|
||||||
def get_multiuse_nets(root: list[Node]) -> set[Net]:
|
|
||||||
"""Finds all nets that get accessed more than one time. Therefore
|
|
||||||
storage on the heap might be better.
|
|
||||||
"""
|
|
||||||
known_nets: set[Net] = set()
|
|
||||||
|
|
||||||
def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
|
|
||||||
for net in net_list:
|
|
||||||
#print(net)
|
|
||||||
if net in known_nets:
|
|
||||||
yield net
|
|
||||||
else:
|
|
||||||
known_nets.add(net)
|
|
||||||
yield from recursive_node_search(net.source.args)
|
|
||||||
|
|
||||||
return set(recursive_node_search(op.args[0] for op in root))
|
|
||||||
|
|
||||||
|
|
||||||
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
|
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
|
||||||
"""List of all possible paths. Ops in order of execution (output at the end)
|
"""List of all possible paths. Ops in order of execution (output at the end)
|
||||||
"""
|
"""
|
||||||
|
|
@ -190,7 +175,7 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
|
||||||
"""Merge in all tree branches at branch position into the path segments
|
"""Merge in all tree branches at branch position into the path segments
|
||||||
"""
|
"""
|
||||||
finished_paths: set[int] = set()
|
finished_paths: set[int] = set()
|
||||||
|
|
||||||
for i, path in enumerate(path_segments):
|
for i, path in enumerate(path_segments):
|
||||||
if i not in finished_paths:
|
if i not in finished_paths:
|
||||||
for op in path:
|
for op in path:
|
||||||
|
|
@ -205,10 +190,10 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
|
||||||
yield op
|
yield op
|
||||||
finished_paths.add(i)
|
finished_paths.add(i)
|
||||||
|
|
||||||
|
|
||||||
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
|
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
|
||||||
"""Get all const nodes in the op list
|
"""Get all const nodes in the op list
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tuples of (name, net, value)"""
|
List of tuples of (name, net, value)"""
|
||||||
net_lookup = {net.source: net for op in op_list for net in op.args}
|
net_lookup = {net.source: net for op in op_list for net in op.args}
|
||||||
|
|
@ -218,7 +203,7 @@ def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
|
||||||
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
|
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
|
||||||
"""Add read operation before each op where arguments are not already positioned
|
"""Add read operation before each op where arguments are not already positioned
|
||||||
correctly in the registers
|
correctly in the registers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Yields tuples of a net and a operation. The net is only provided
|
Yields tuples of a net and a operation. The net is only provided
|
||||||
for new added read operations. Otherwise None is returned in the tuple."""
|
for new added read operations. Otherwise None is returned in the tuple."""
|
||||||
|
|
@ -226,7 +211,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
|
|
||||||
# Generate result net lookup table
|
# Generate result net lookup table
|
||||||
net_lookup = {net.source: net for node in node_list for net in node.args}
|
net_lookup = {net.source: net for node in node_list for net in node.args}
|
||||||
|
|
||||||
for node in node_list:
|
for node in node_list:
|
||||||
if not node.name.startswith('const_'):
|
if not node.name.startswith('const_'):
|
||||||
for i, net in enumerate(node.args):
|
for i, net in enumerate(node.args):
|
||||||
|
|
@ -238,9 +223,9 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
|
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
|
||||||
yield net, new_node
|
yield net, new_node
|
||||||
registers[i] = net
|
registers[i] = net
|
||||||
|
|
||||||
if node in net_lookup:
|
if node in net_lookup:
|
||||||
yield None , node
|
yield None, node
|
||||||
registers[0] = net_lookup[node]
|
registers[0] = net_lookup[node]
|
||||||
else:
|
else:
|
||||||
print('--->', node)
|
print('--->', node)
|
||||||
|
|
@ -283,7 +268,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
if isinstance(end_nodes, Node):
|
if isinstance(end_nodes, Node):
|
||||||
node_list = [end_nodes]
|
node_list = [end_nodes]
|
||||||
else:
|
else:
|
||||||
node_list = end_nodes
|
node_list = list(end_nodes)
|
||||||
|
|
||||||
path_segments = list(get_path_segments(node_list))
|
path_segments = list(get_path_segments(node_list))
|
||||||
ordered_ops = list(get_ordered_ops(path_segments))
|
ordered_ops = list(get_ordered_ops(path_segments))
|
||||||
|
|
@ -297,10 +282,8 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
# Get all nets associated with heap memory
|
# Get all nets associated with heap memory
|
||||||
variable_list = get_nets(const_list, extended_output_ops)
|
variable_list = get_nets(const_list, extended_output_ops)
|
||||||
|
|
||||||
|
|
||||||
dw = binw.data_writer(sdb.byteorder)
|
dw = binw.data_writer(sdb.byteorder)
|
||||||
|
|
||||||
|
|
||||||
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
|
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
|
||||||
offset: int = 0
|
offset: int = 0
|
||||||
object_list: list[tuple[Net, int, int]] = []
|
object_list: list[tuple[Net, int, int]] = []
|
||||||
|
|
@ -312,14 +295,14 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
|
|
||||||
return object_list, offset
|
return object_list, offset
|
||||||
|
|
||||||
|
|
||||||
object_list, data_section_lengths = variable_mem_layout(variable_list)
|
object_list, data_section_lengths = variable_mem_layout(variable_list)
|
||||||
|
|
||||||
#data_section_lengths = object_list[-1][1] + object_list[-1][2]
|
# Write data
|
||||||
dw.write_com(binw.Command.ALLOCATE_DATA)
|
dw.write_com(binw.Command.ALLOCATE_DATA)
|
||||||
dw.write_int(data_section_lengths)
|
dw.write_int(data_section_lengths)
|
||||||
|
|
||||||
for net, out_offs, lengths in object_list:
|
for net, out_offs, lengths in object_list:
|
||||||
|
dw.add_variable(net, out_offs, lengths, net.dtype)
|
||||||
if isinstance(net.source, Const):
|
if isinstance(net.source, Const):
|
||||||
dw.write_com(binw.Command.COPY_DATA)
|
dw.write_com(binw.Command.COPY_DATA)
|
||||||
dw.write_int(out_offs)
|
dw.write_int(out_offs)
|
||||||
|
|
@ -330,7 +313,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
# write auxiliary_functions
|
# write auxiliary_functions
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
# Prepare program data and relocations
|
# Prepare program code and relocations
|
||||||
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
|
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
|
||||||
data_list: list[bytes] = []
|
data_list: list[bytes] = []
|
||||||
patch_list: list[tuple[int, int, int]] = []
|
patch_list: list[tuple[int, int, int]] = []
|
||||||
|
|
@ -347,7 +330,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
data = sdb.get_func_data(node.name)
|
data = sdb.get_func_data(node.name)
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
||||||
|
|
||||||
for patch in sdb.get_patch_positions(node.name):
|
for patch in sdb.get_patch_positions(node.name):
|
||||||
assert result_net, f"Relocation found but no net defined for operation {node.name}"
|
assert result_net, f"Relocation found but no net defined for operation {node.name}"
|
||||||
object_addr = object_addr_lookp[result_net]
|
object_addr = object_addr_lookp[result_net]
|
||||||
|
|
@ -383,7 +366,12 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
|
||||||
dw.write_com(binw.Command.SET_ENTR_POINT)
|
dw.write_com(binw.Command.SET_ENTR_POINT)
|
||||||
dw.write_int(0)
|
dw.write_int(0)
|
||||||
|
|
||||||
# run program command
|
|
||||||
dw.write_com(binw.Command.END_PROG)
|
|
||||||
|
|
||||||
return dw
|
return dw
|
||||||
|
|
||||||
|
|
||||||
|
def read_variable(bw: binw.data_writer, net: Net) -> None:
|
||||||
|
assert net in bw.variables, f"Variable {net} not found in data writer variables"
|
||||||
|
addr, lengths, _ = bw.variables[net]
|
||||||
|
bw.write_com(binw.Command.READ_DATA)
|
||||||
|
bw.write_int(addr)
|
||||||
|
bw.write_int(lengths)
|
||||||
|
|
|
||||||
|
|
@ -1,85 +1,41 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pelfy import elf_symbol
|
from typing import Literal, Any
|
||||||
from typing import Literal
|
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
||||||
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
|
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
|
||||||
('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
|
('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
|
||||||
('PATCH_FUNC', 5), ('PATCH_OBJECT', 6),
|
('PATCH_FUNC', 5), ('PATCH_OBJECT', 6),
|
||||||
('SET_ENTR_POINT', 64), ('END_PROG', 255)])
|
('SET_ENTR_POINT', 64), ('READ_DATA', 65),
|
||||||
|
('END_PROG', 255)])
|
||||||
|
|
||||||
|
|
||||||
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
|
|
||||||
object_list: list[tuple[elf_symbol, int, int]] = []
|
|
||||||
out_offs = 0
|
|
||||||
for sym in symbols:
|
|
||||||
assert sym.info == 'STT_OBJECT'
|
|
||||||
lengths = sym.fields['st_size'],
|
|
||||||
object_list.append((sym, out_offs, lengths))
|
|
||||||
out_offs += (lengths + 3) // 4 * 4
|
|
||||||
return object_list, out_offs
|
|
||||||
|
|
||||||
|
|
||||||
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
|
|
||||||
code_list: list[tuple[elf_symbol, int, int, int]] = []
|
|
||||||
out_offs = 0
|
|
||||||
for sym in symbols:
|
|
||||||
assert sym.info == 'STT_FUNC'
|
|
||||||
lengths = sym.fields['st_size']
|
|
||||||
|
|
||||||
#if strip_function:
|
|
||||||
# assert False, 'Not implemente'
|
|
||||||
# TODO: Strip functions
|
|
||||||
# Symbol, start out_offset in symbol, offset in output file, output lengths
|
|
||||||
# Set in_sym_out_offs and lengths
|
|
||||||
in_sym_offs = 0
|
|
||||||
|
|
||||||
code_list.append((sym, in_sym_offs, out_offs, lengths))
|
|
||||||
# out_offs += (lengths + 3) // 4 * 4
|
|
||||||
out_offs += lengths # should be aligned by default?
|
|
||||||
return code_list, out_offs
|
|
||||||
|
|
||||||
|
|
||||||
def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
|
|
||||||
code_list: list[tuple[elf_symbol, int, int, int]] = []
|
|
||||||
out_offs = 0
|
|
||||||
for sym in symbols:
|
|
||||||
assert sym.info == 'STT_FUNC'
|
|
||||||
lengths = sym.fields['st_size']
|
|
||||||
|
|
||||||
#if strip_function:
|
|
||||||
# assert False, 'Not implemente'
|
|
||||||
# TODO: Strip functions
|
|
||||||
# Symbol, start out_offset in symbol, offset in output file, output lengths
|
|
||||||
# Set in_sym_out_offs and lengths
|
|
||||||
in_sym_offs = 0
|
|
||||||
|
|
||||||
code_list.append((sym, in_sym_offs, out_offs, lengths))
|
|
||||||
out_offs += (lengths + 3) // 4 * 4
|
|
||||||
return code_list, out_offs
|
|
||||||
|
|
||||||
class data_writer():
|
class data_writer():
|
||||||
def __init__(self, byteorder: Literal['little', 'big']):
|
def __init__(self, byteorder: Literal['little', 'big']):
|
||||||
self._data: list[tuple[str, bytes, int]] = list()
|
self._data: list[tuple[str, bytes, int]] = list()
|
||||||
self.byteorder = byteorder
|
self.byteorder: Literal['little', 'big'] = byteorder
|
||||||
|
self.variables: dict[Any, tuple[int, int, str]] = dict()
|
||||||
|
|
||||||
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False):
|
def add_variable(self, net: Any, addr: int, lengths: int, var_type: str) -> None:
|
||||||
|
self.variables[net] = (addr, lengths, var_type)
|
||||||
|
|
||||||
|
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False) -> None:
|
||||||
self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0))
|
self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0))
|
||||||
|
|
||||||
def write_com(self, value: Enum, num_bytes: int = 4):
|
def write_com(self, value: Enum, num_bytes: int = 4) -> None:
|
||||||
self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1))
|
self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1))
|
||||||
|
|
||||||
def write_byte(self, value: int):
|
def write_byte(self, value: int) -> None:
|
||||||
self._data.append((f"BYTE {value}", bytes([value]), 0))
|
self._data.append((f"BYTE {value}", bytes([value]), 0))
|
||||||
|
|
||||||
def write_bytes(self, value: bytes):
|
def write_bytes(self, value: bytes) -> None:
|
||||||
self._data.append((f"BYTES {len(value)}", value, 0))
|
self._data.append((f"BYTES {len(value)}", value, 0))
|
||||||
|
|
||||||
def write_value(self, value: int | float, num_bytes: int = 4):
|
def write_value(self, value: int | float, num_bytes: int = 4) -> None:
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
self.write_int(value, num_bytes, True)
|
self.write_int(value, num_bytes, True)
|
||||||
else:
|
else:
|
||||||
|
# 32 bit or 64 bit float
|
||||||
en = {'little': '<', 'big': '>'}[self.byteorder]
|
en = {'little': '<', 'big': '>'}[self.byteorder]
|
||||||
if num_bytes == 4:
|
if num_bytes == 4:
|
||||||
data = struct.pack(en + 'f', value)
|
data = struct.pack(en + 'f', value)
|
||||||
|
|
@ -97,14 +53,6 @@ class data_writer():
|
||||||
def get_data(self) -> bytes:
|
def get_data(self) -> bytes:
|
||||||
return b''.join(dat for _, dat, _ in self._data)
|
return b''.join(dat for _, dat, _ in self._data)
|
||||||
|
|
||||||
def to_file(self, path: str):
|
def to_file(self, path: str) -> None:
|
||||||
with open(path, 'wb') as f:
|
with open(path, 'wb') as f:
|
||||||
f.write(self.get_data())
|
f.write(self.get_data())
|
||||||
|
|
||||||
def get_c_consts() -> str:
|
|
||||||
ret: list[str] = []
|
|
||||||
for c in Command:
|
|
||||||
ret.append (f"#define {c.name} {c.value}")
|
|
||||||
for c in PatchType:
|
|
||||||
ret.append(f"#define {c.name} {c.value}")
|
|
||||||
return '\n'.join(ret)
|
|
||||||
|
|
@ -4,23 +4,27 @@ from typing import Generator
|
||||||
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
|
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
|
||||||
|
|
||||||
|
|
||||||
def get_function_start():
|
def get_function_start() -> str:
|
||||||
return """
|
return """
|
||||||
void function_start(){
|
int function_start(){
|
||||||
result_int(0);
|
result_int(0);
|
||||||
asm volatile (".long 0xF27ECAFE");
|
asm volatile (".long 0xF27ECAFE");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_function_end():
|
|
||||||
|
def get_function_end() -> str:
|
||||||
return """
|
return """
|
||||||
void function_end(){
|
int function_end(){
|
||||||
result_int(0);
|
result_int(0);
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_op_code(op: str, type1: str, type2: str, type_out: str):
|
|
||||||
|
def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
|
@ -29,19 +33,20 @@ def get_op_code(op: str, type1: str, type2: str, type_out: str):
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_result_stubs1(type1: str):
|
|
||||||
|
def get_result_stubs1(type1: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void result_{type1}({type1} arg1);
|
void result_{type1}({type1} arg1);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_result_stubs2(type1: str, type2: str):
|
def get_result_stubs2(type1: str, type2: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void result_{type1}_{type2}({type1} arg1, {type2} arg2);
|
void result_{type1}_{type2}({type1} arg1, {type2} arg2);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_read_reg0_code(type1: str, type2: str, type_out: str):
|
def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
|
@ -51,7 +56,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_read_reg1_code(type1: str, type2: str, type_out: str):
|
def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
|
@ -61,7 +66,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_write_code(type1: str):
|
def get_write_code(type1: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
void write_{type1}({type1} arg1) {{
|
void write_{type1}({type1} arg1) {{
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
|
@ -107,7 +112,7 @@ if __name__ == "__main__":
|
||||||
for t1, t2, t_out in permutate(types, types, types):
|
for t1, t2, t_out in permutate(types, types, types):
|
||||||
code += get_read_reg0_code(t1, t2, t_out)
|
code += get_read_reg0_code(t1, t2, t_out)
|
||||||
code += get_read_reg1_code(t1, t2, t_out)
|
code += get_read_reg1_code(t1, t2, t_out)
|
||||||
|
|
||||||
for t1 in types:
|
for t1 in types:
|
||||||
code += get_write_code(t1)
|
code += get_write_code(t1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
//notes:
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//Helper functions
|
|
||||||
void result_int(int ret1){
|
|
||||||
asm ("");
|
|
||||||
}
|
|
||||||
|
|
||||||
void result_float(float ret1){
|
|
||||||
asm ("");
|
|
||||||
}
|
|
||||||
|
|
||||||
void result_float_float(float ret1, float ret2){
|
|
||||||
asm ("");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//Operations
|
|
||||||
void add_int_int(int arg1, int arg2){
|
|
||||||
result_int(arg1 + arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void add_float_float(float arg1, float arg2){
|
|
||||||
result_float(arg1 + arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void add_float_int(float arg1, float arg2){
|
|
||||||
result_float(arg1 + arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void sub_int_int(int arg1, int arg2){
|
|
||||||
result_int(arg1 - arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void sub_float_float(float arg1, float arg2){
|
|
||||||
result_float(arg1 - arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void sub_float_int(float arg1, int arg2){
|
|
||||||
result_float(arg1 - arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void sub_int_float(int arg1, float arg2){
|
|
||||||
result_float(arg1 - arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void mul_int_int(int arg1, int arg2){
|
|
||||||
result_int(arg1 * arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void mul_float_float(float arg1, float arg2){
|
|
||||||
result_float(arg1 * arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void mul_float_int(float arg1, int arg2){
|
|
||||||
result_float(arg1 + arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void div_int_int(int arg1, int arg2){
|
|
||||||
result_int(arg1 / arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void div_float_float(float arg1, float arg2){
|
|
||||||
result_float(arg1 / arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void div_float_int(float arg1, int arg2){
|
|
||||||
result_float(arg1 / arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void div_int_float(int arg1, float arg2){
|
|
||||||
result_float(arg1 / arg2);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Read global variables from heap
|
|
||||||
int read_int_ret = 1337;
|
|
||||||
|
|
||||||
void read_int(){
|
|
||||||
result_int(read_int_ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
float read_float_ret = 1337;
|
|
||||||
|
|
||||||
void read_float(){
|
|
||||||
result_float(read_float_ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
void read_float_2(float arg1){
|
|
||||||
result_float_float(arg1, read_float_ret);
|
|
||||||
}
|
|
||||||
|
|
@ -13,6 +13,7 @@ LENGTH_CALL_INSTRUCTION = 5 # x86_64
|
||||||
|
|
||||||
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
|
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class patch_entry:
|
class patch_entry:
|
||||||
"""
|
"""
|
||||||
|
|
@ -21,7 +22,7 @@ class patch_entry:
|
||||||
Attributes:
|
Attributes:
|
||||||
addr (int): address of first byte to patch relative to the start of the symbol
|
addr (int): address of first byte to patch relative to the start of the symbol
|
||||||
type (RelocationType): relocation type"""
|
type (RelocationType): relocation type"""
|
||||||
|
|
||||||
type: RelocationType
|
type: RelocationType
|
||||||
addr: int
|
addr: int
|
||||||
addend: int
|
addend: int
|
||||||
|
|
@ -73,13 +74,13 @@ class stencil_database():
|
||||||
|
|
||||||
self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols
|
self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols
|
||||||
if s.info == 'STT_FUNC'}
|
if s.info == 'STT_FUNC'}
|
||||||
|
|
||||||
self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols
|
self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols
|
||||||
if s.info == 'STT_FUNC'}
|
if s.info == 'STT_FUNC'}
|
||||||
|
|
||||||
self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
|
self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
|
||||||
if s.info == 'STT_OBJECT'}
|
if s.info == 'STT_OBJECT'}
|
||||||
|
|
||||||
self.byteorder: Literal['little', 'big'] = self.elf.byteorder
|
self.byteorder: Literal['little', 'big'] = self.elf.byteorder
|
||||||
|
|
||||||
for name in self.function_definitions.keys():
|
for name in self.function_definitions.keys():
|
||||||
|
|
@ -87,7 +88,6 @@ class stencil_database():
|
||||||
sym.relocations
|
sym.relocations
|
||||||
self.elf.symbols[name].data
|
self.elf.symbols[name].data
|
||||||
|
|
||||||
|
|
||||||
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
|
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
|
||||||
"""Return patch positions
|
"""Return patch positions
|
||||||
"""
|
"""
|
||||||
|
|
@ -95,7 +95,7 @@ class stencil_database():
|
||||||
start_index, end_index = get_stencil_position(symbol.data, symbol.file.byteorder)
|
start_index, end_index = get_stencil_position(symbol.data, symbol.file.byteorder)
|
||||||
|
|
||||||
for reloc in symbol.relocations:
|
for reloc in symbol.relocations:
|
||||||
|
|
||||||
# address to fist byte to patch relative to the start of the symbol
|
# address to fist byte to patch relative to the start of the symbol
|
||||||
patch = translate_relocation(
|
patch = translate_relocation(
|
||||||
reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index,
|
reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index,
|
||||||
|
|
@ -107,8 +107,5 @@ class stencil_database():
|
||||||
if patch.addr < end_index - start_index:
|
if patch.addr < end_index - start_index:
|
||||||
yield patch
|
yield patch
|
||||||
|
|
||||||
|
|
||||||
def get_func_data(self, name: str) -> bytes:
|
def get_func_data(self, name: str) -> bytes:
|
||||||
return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)
|
return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -223,13 +223,15 @@
|
||||||
asm volatile (".long 0xF27ECAFE");
|
asm volatile (".long 0xF27ECAFE");
|
||||||
}
|
}
|
||||||
|
|
||||||
void function_start(){
|
int function_start(){
|
||||||
result_int(0);
|
result_int(0);
|
||||||
asm volatile (".long 0xF27ECAFE");
|
asm volatile (".long 0xF27ECAFE");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void function_end(){
|
int function_end(){
|
||||||
result_int(0);
|
result_int(0);
|
||||||
asm volatile (".long 0xF17ECAFE");
|
asm volatile (".long 0xF17ECAFE");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
#define PATCH_FUNC 5
|
#define PATCH_FUNC 5
|
||||||
#define PATCH_OBJECT 6
|
#define PATCH_OBJECT 6
|
||||||
#define SET_ENTR_POINT 64
|
#define SET_ENTR_POINT 64
|
||||||
|
#define READ_DATA 65
|
||||||
#define END_PROG 255
|
#define END_PROG 255
|
||||||
|
|
||||||
#define PATCH_RELATIVE_32 0
|
#define PATCH_RELATIVE_32 0
|
||||||
|
|
@ -132,16 +133,27 @@ int parse_commands(uint8_t *bytes){
|
||||||
case SET_ENTR_POINT:
|
case SET_ENTR_POINT:
|
||||||
rel_entr_point = *(uint32_t*)bytes; bytes += 4;
|
rel_entr_point = *(uint32_t*)bytes; bytes += 4;
|
||||||
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
|
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
|
||||||
entr_point = (int (*)())(executable_memory + rel_entr_point);
|
entr_point = (int (*)())(executable_memory + rel_entr_point);
|
||||||
|
|
||||||
|
mark_mem_executable();
|
||||||
|
int ret = entr_point();
|
||||||
|
printf("Return value: %i\n", ret);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case END_PROG:
|
case END_PROG:
|
||||||
printf("END_PROG\n");
|
printf("END_PROG\n");
|
||||||
mark_mem_executable();
|
|
||||||
int ret = entr_point();
|
|
||||||
printf("Return value: %i\n", ret);
|
|
||||||
err_flag = 1;
|
err_flag = 1;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case READ_DATA:
|
||||||
|
offs = *(uint32_t*)bytes; bytes += 4;
|
||||||
|
size = *(uint32_t*)bytes; bytes += 4;
|
||||||
|
printf("READ_DATA offs=%i size=%i data=", offs, size);
|
||||||
|
for (uint32_t i = 0; i < size; i++) {
|
||||||
|
printf("%02X ", data_memory[offs + i]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
printf("Unknown command\n");
|
printf("Unknown command\n");
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from copapy import Write, const
|
from copapy import Write, const
|
||||||
import copapy as rc
|
import copapy as rc
|
||||||
|
|
||||||
|
|
||||||
def test_ast_generation():
|
def test_ast_generation():
|
||||||
c1 = const(1.11)
|
c1 = const(1.11)
|
||||||
c2 = const(2.22)
|
#c2 = const(2.22)
|
||||||
#c3 = const(3.33)
|
#c3 = const(3.33)
|
||||||
|
|
||||||
#i1 = c1 + c2
|
#i1 = c1 + c2
|
||||||
|
|
@ -12,36 +13,31 @@ def test_ast_generation():
|
||||||
|
|
||||||
#r1 = i1 + i3
|
#r1 = i1 + i3
|
||||||
#r2 = i3 * i2
|
#r2 = i3 * i2
|
||||||
|
|
||||||
i1 = c1 * 2
|
i1 = c1 * 2
|
||||||
i2 = i1 + 3
|
r1 = i1 + 7
|
||||||
|
out = Write(r1)
|
||||||
|
|
||||||
r1 = i1 + i2
|
|
||||||
r2 = c2 + 4 + c1
|
|
||||||
|
|
||||||
out = [Write(r1), Write(r2)]
|
|
||||||
print(out)
|
print(out)
|
||||||
print('--')
|
print('-- get_path_segments:')
|
||||||
|
|
||||||
path_segments = list(rc.get_path_segments(out))
|
path_segments = list(rc.get_path_segments([out]))
|
||||||
for p in path_segments:
|
for p in path_segments:
|
||||||
print(p)
|
print(p)
|
||||||
print('--')
|
print('-- get_ordered_ops:')
|
||||||
|
|
||||||
ordered_ops = list(rc.get_ordered_ops(path_segments))
|
ordered_ops = list(rc.get_ordered_ops(path_segments))
|
||||||
for p in path_segments:
|
for p in path_segments:
|
||||||
print(p)
|
print(p)
|
||||||
print('--')
|
print('-- get_consts:')
|
||||||
|
|
||||||
const_list = rc.get_consts(ordered_ops)
|
const_list = rc.get_consts(ordered_ops)
|
||||||
for p in const_list:
|
for p in const_list:
|
||||||
print(p)
|
print(p)
|
||||||
print('--')
|
print('-- add_read_ops:')
|
||||||
|
|
||||||
output_ops = list(rc.add_read_ops(ordered_ops))
|
output_ops = list(rc.add_read_ops(ordered_ops))
|
||||||
for p in output_ops:
|
for p in output_ops:
|
||||||
print(p)
|
print(p)
|
||||||
print('--')
|
print('-- add_write_ops:')
|
||||||
|
|
||||||
extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
|
extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
|
||||||
for p in extended_output_ops:
|
for p in extended_output_ops:
|
||||||
|
|
@ -50,4 +46,4 @@ def test_ast_generation():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_ast_generation()
|
test_ast_generation()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
from copapy import Write, const
|
from copapy import Write, const
|
||||||
import copapy as rc
|
import copapy
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import struct
|
||||||
|
from copapy import binwrite as binw
|
||||||
|
|
||||||
|
|
||||||
def run_command(command: list[str], encoding: str = 'utf8') -> str:
|
def run_command(command: list[str], encoding: str = 'utf8') -> str:
|
||||||
process = subprocess.Popen(command, stdout=subprocess.PIPE)
|
process = subprocess.Popen(command, stdout=subprocess.PIPE)
|
||||||
|
|
@ -9,9 +12,10 @@ def run_command(command: list[str], encoding: str = 'utf8') -> str:
|
||||||
assert error is None, f"Error occurred: {error.decode(encoding)}"
|
assert error is None, f"Error occurred: {error.decode(encoding)}"
|
||||||
return output.decode(encoding)
|
return output.decode(encoding)
|
||||||
|
|
||||||
def test_compile():
|
|
||||||
c1 = const(1.11)
|
def test_example():
|
||||||
c2 = const(2.22)
|
c1 = 1.11
|
||||||
|
c2 = 2.22
|
||||||
|
|
||||||
i1 = c1 * 2
|
i1 = c1 * 2
|
||||||
i2 = i1 + 3
|
i2 = i1 + 3
|
||||||
|
|
@ -19,9 +23,48 @@ def test_compile():
|
||||||
r1 = i1 + i2
|
r1 = i1 + i2
|
||||||
r2 = c2 + 4 + c1
|
r2 = c2 + 4 + c1
|
||||||
|
|
||||||
out = [Write(r1), Write(r2)]
|
en = {'little': '<', 'big': '>'}['little']
|
||||||
|
data = struct.pack(en + 'f', r1)
|
||||||
|
print("example r1 " + ' '.join(f'{b:02X}' for b in data))
|
||||||
|
|
||||||
il = rc.compile_to_instruction_list(out)
|
data = struct.pack(en + 'f', r2)
|
||||||
|
print("example r2 " + ' '.join(f'{b:02X}' for b in data))
|
||||||
|
|
||||||
|
# assert False
|
||||||
|
# example r1 7B 14 EE 40
|
||||||
|
# example r2 5C 8F EA 40
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile():
|
||||||
|
|
||||||
|
print(run_command(['bash', 'build.sh']))
|
||||||
|
|
||||||
|
c1 = const(4)
|
||||||
|
#c2 = const(2)
|
||||||
|
|
||||||
|
#i1 = c1 * 2
|
||||||
|
#i2 = i1 + 3
|
||||||
|
|
||||||
|
#r1 = i1 + i2
|
||||||
|
#r2 = c2 + 4 + c1
|
||||||
|
|
||||||
|
#out = [Write(r1), Write(r2)]
|
||||||
|
|
||||||
|
i1 = c1 * 2
|
||||||
|
r1 = i1 + 7
|
||||||
|
out = Write(r1)
|
||||||
|
|
||||||
|
il = copapy.compile_to_instruction_list(out)
|
||||||
|
|
||||||
|
#copapy.read_variable(il, i1)
|
||||||
|
copapy.read_variable(il, r1)
|
||||||
|
|
||||||
|
il.write_com(binw.Command.READ_DATA)
|
||||||
|
il.write_int(0)
|
||||||
|
il.write_int(36)
|
||||||
|
|
||||||
|
# run program command
|
||||||
|
il.write_com(binw.Command.END_PROG)
|
||||||
|
|
||||||
print('#', il.print())
|
print('#', il.print())
|
||||||
|
|
||||||
|
|
@ -30,8 +73,8 @@ def test_compile():
|
||||||
result = run_command(['./bin/runmem2', 'test.copapy'])
|
result = run_command(['./bin/runmem2', 'test.copapy'])
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
assert 'Return value: 0' in result
|
assert 'Return value: 1' in result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_compile()
|
test_compile()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from copapy import stencil_database
|
from copapy import stencil_database
|
||||||
from copapy import stencil_db
|
from copapy import stencil_db
|
||||||
|
|
||||||
|
|
||||||
def test_list_symbols():
|
def test_list_symbols():
|
||||||
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
|
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
|
||||||
print('----')
|
print('----')
|
||||||
|
|
@ -18,9 +19,9 @@ def test_start_end_function():
|
||||||
|
|
||||||
start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder)
|
start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder)
|
||||||
|
|
||||||
assert start>= 0 and end >= start and end <= len(data)
|
assert start >= 0 and end >= start and end <= len(data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_list_symbols()
|
test_list_symbols()
|
||||||
test_start_end_function()
|
test_start_end_function()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue