code cleaned up, passes flake8 and mypy

This commit is contained in:
Nicolas 2025-10-01 23:09:49 +02:00
parent 3b7fcdb88b
commit 5b7ca52b7c
11 changed files with 159 additions and 261 deletions

View File

@ -13,6 +13,7 @@ exclude =
build, build,
dist, dist,
.conda, .conda,
.venv,
tests/autogenerated_* tests/autogenerated_*
# Enable specific plugins or options # Enable specific plugins or options

View File

@ -3,22 +3,25 @@ from typing import Generator, Iterable, Any
from . import binwrite as binw from . import binwrite as binw
from .stencil_db import stencil_database from .stencil_db import stencil_database
Operand = type['Net'] | float | int Operand = type['Net'] | float | int
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
return [name for name, value in scope.items() if value is var] return [name for name, value in scope.items() if value is var]
# _ccode = pkgutil.get_data(__name__, 'stencils.c') # _ccode = pkgutil.get_data(__name__, 'stencils.c')
# assert _ccode is not None # assert _ccode is not None
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
class Node: class Node:
def __init__(self): def __init__(self) -> None:
self.args: list[Net] = [] self.args: list[Net] = []
self.name: str = '' self.name: str = ''
def __repr__(self): def __repr__(self) -> str:
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
@ -37,26 +40,26 @@ class Net:
def __rmul__(self, other: Any) -> 'Net': def __rmul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True) return _add_op('mul', [self, other], True)
def __add__(self, other: Any) -> 'Net': def __add__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
def __radd__(self, other: Any) -> 'Net': def __radd__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
def __sub__ (self, other: Any) -> 'Net': def __sub__(self, other: Any) -> 'Net':
return _add_op('sub', [self, other]) return _add_op('sub', [self, other])
def __rsub__ (self, other: Any) -> 'Net': def __rsub__(self, other: Any) -> 'Net':
return _add_op('sub', [other, self]) return _add_op('sub', [other, self])
def __truediv__ (self, other: Any) -> 'Net': def __truediv__(self, other: Any) -> 'Net':
return _add_op('div', [self, other]) return _add_op('div', [self, other])
def __rtruediv__ (self, other: Any) -> 'Net': def __rtruediv__(self, other: Any) -> 'Net':
return _add_op('div', [other, self]) return _add_op('div', [other, self])
def __repr__(self): def __repr__(self) -> str:
names = get_var_name(self) names = get_var_name(self)
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
@ -94,7 +97,7 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in sdb.function_definitions: if typed_op not in sdb.function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
@ -141,24 +144,6 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((const(x), const(y), const(z))) return vec3d((const(x), const(y), const(z)))
def get_multiuse_nets(root: list[Node]) -> set[Net]:
"""Finds all nets that get accessed more than one time. Therefore
storage on the heap might be better.
"""
known_nets: set[Net] = set()
def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list:
#print(net)
if net in known_nets:
yield net
else:
known_nets.add(net)
yield from recursive_node_search(net.source.args)
return set(recursive_node_search(op.args[0] for op in root))
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
"""List of all possible paths. Ops in order of execution (output at the end) """List of all possible paths. Ops in order of execution (output at the end)
""" """
@ -190,7 +175,7 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
"""Merge in all tree branches at branch position into the path segments """Merge in all tree branches at branch position into the path segments
""" """
finished_paths: set[int] = set() finished_paths: set[int] = set()
for i, path in enumerate(path_segments): for i, path in enumerate(path_segments):
if i not in finished_paths: if i not in finished_paths:
for op in path: for op in path:
@ -205,10 +190,10 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
yield op yield op
finished_paths.add(i) finished_paths.add(i)
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]: def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
"""Get all const nodes in the op list """Get all const nodes in the op list
Returns: Returns:
List of tuples of (name, net, value)""" List of tuples of (name, net, value)"""
net_lookup = {net.source: net for op in op_list for net in op.args} net_lookup = {net.source: net for op in op_list for net in op.args}
@ -218,7 +203,7 @@ def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add read operation before each op where arguments are not already positioned """Add read operation before each op where arguments are not already positioned
correctly in the registers correctly in the registers
Returns: Returns:
Yields tuples of a net and a operation. The net is only provided Yields tuples of a net and a operation. The net is only provided
for new added read operations. Otherwise None is returned in the tuple.""" for new added read operations. Otherwise None is returned in the tuple."""
@ -226,7 +211,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
# Generate result net lookup table # Generate result net lookup table
net_lookup = {net.source: net for node in node_list for net in node.args} net_lookup = {net.source: net for node in node_list for net in node.args}
for node in node_list: for node in node_list:
if not node.name.startswith('const_'): if not node.name.startswith('const_'):
for i, net in enumerate(node.args): for i, net in enumerate(node.args):
@ -238,9 +223,9 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), []) new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
yield net, new_node yield net, new_node
registers[i] = net registers[i] = net
if node in net_lookup: if node in net_lookup:
yield None , node yield None, node
registers[0] = net_lookup[node] registers[0] = net_lookup[node]
else: else:
print('--->', node) print('--->', node)
@ -283,7 +268,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
if isinstance(end_nodes, Node): if isinstance(end_nodes, Node):
node_list = [end_nodes] node_list = [end_nodes]
else: else:
node_list = end_nodes node_list = list(end_nodes)
path_segments = list(get_path_segments(node_list)) path_segments = list(get_path_segments(node_list))
ordered_ops = list(get_ordered_ops(path_segments)) ordered_ops = list(get_ordered_ops(path_segments))
@ -297,10 +282,8 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
# Get all nets associated with heap memory # Get all nets associated with heap memory
variable_list = get_nets(const_list, extended_output_ops) variable_list = get_nets(const_list, extended_output_ops)
dw = binw.data_writer(sdb.byteorder) dw = binw.data_writer(sdb.byteorder)
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]: def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0 offset: int = 0
object_list: list[tuple[Net, int, int]] = [] object_list: list[tuple[Net, int, int]] = []
@ -312,14 +295,14 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
return object_list, offset return object_list, offset
object_list, data_section_lengths = variable_mem_layout(variable_list) object_list, data_section_lengths = variable_mem_layout(variable_list)
#data_section_lengths = object_list[-1][1] + object_list[-1][2] # Write data
dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths) dw.write_int(data_section_lengths)
for net, out_offs, lengths in object_list: for net, out_offs, lengths in object_list:
dw.add_variable(net, out_offs, lengths, net.dtype)
if isinstance(net.source, Const): if isinstance(net.source, Const):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(out_offs)
@ -330,7 +313,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
# write auxiliary_functions # write auxiliary_functions
# TODO # TODO
# Prepare program data and relocations # Prepare program code and relocations
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list} object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int]] = [] patch_list: list[tuple[int, int, int]] = []
@ -347,7 +330,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
data = sdb.get_func_data(node.name) data = sdb.get_func_data(node.name)
data_list.append(data) data_list.append(data)
print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name): for patch in sdb.get_patch_positions(node.name):
assert result_net, f"Relocation found but no net defined for operation {node.name}" assert result_net, f"Relocation found but no net defined for operation {node.name}"
object_addr = object_addr_lookp[result_net] object_addr = object_addr_lookp[result_net]
@ -383,7 +366,12 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
dw.write_com(binw.Command.SET_ENTR_POINT) dw.write_com(binw.Command.SET_ENTR_POINT)
dw.write_int(0) dw.write_int(0)
# run program command
dw.write_com(binw.Command.END_PROG)
return dw return dw
def read_variable(bw: binw.data_writer, net: Net) -> None:
assert net in bw.variables, f"Variable {net} not found in data writer variables"
addr, lengths, _ = bw.variables[net]
bw.write_com(binw.Command.READ_DATA)
bw.write_int(addr)
bw.write_int(lengths)

View File

@ -1,85 +1,41 @@
from enum import Enum from enum import Enum
from pelfy import elf_symbol from typing import Literal, Any
from typing import Literal
import struct import struct
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
('ALLOCATE_CODE', 3), ('COPY_CODE', 4), ('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
('PATCH_FUNC', 5), ('PATCH_OBJECT', 6), ('PATCH_FUNC', 5), ('PATCH_OBJECT', 6),
('SET_ENTR_POINT', 64), ('END_PROG', 255)]) ('SET_ENTR_POINT', 64), ('READ_DATA', 65),
('END_PROG', 255)])
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
object_list: list[tuple[elf_symbol, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_OBJECT'
lengths = sym.fields['st_size'],
object_list.append((sym, out_offs, lengths))
out_offs += (lengths + 3) // 4 * 4
return object_list, out_offs
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_FUNC'
lengths = sym.fields['st_size']
#if strip_function:
# assert False, 'Not implemente'
# TODO: Strip functions
# Symbol, start out_offset in symbol, offset in output file, output lengths
# Set in_sym_out_offs and lengths
in_sym_offs = 0
code_list.append((sym, in_sym_offs, out_offs, lengths))
# out_offs += (lengths + 3) // 4 * 4
out_offs += lengths # should be aligned by default?
return code_list, out_offs
def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_FUNC'
lengths = sym.fields['st_size']
#if strip_function:
# assert False, 'Not implemente'
# TODO: Strip functions
# Symbol, start out_offset in symbol, offset in output file, output lengths
# Set in_sym_out_offs and lengths
in_sym_offs = 0
code_list.append((sym, in_sym_offs, out_offs, lengths))
out_offs += (lengths + 3) // 4 * 4
return code_list, out_offs
class data_writer(): class data_writer():
def __init__(self, byteorder: Literal['little', 'big']): def __init__(self, byteorder: Literal['little', 'big']):
self._data: list[tuple[str, bytes, int]] = list() self._data: list[tuple[str, bytes, int]] = list()
self.byteorder = byteorder self.byteorder: Literal['little', 'big'] = byteorder
self.variables: dict[Any, tuple[int, int, str]] = dict()
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False): def add_variable(self, net: Any, addr: int, lengths: int, var_type: str) -> None:
self.variables[net] = (addr, lengths, var_type)
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False) -> None:
self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0)) self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0))
def write_com(self, value: Enum, num_bytes: int = 4): def write_com(self, value: Enum, num_bytes: int = 4) -> None:
self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1)) self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1))
def write_byte(self, value: int): def write_byte(self, value: int) -> None:
self._data.append((f"BYTE {value}", bytes([value]), 0)) self._data.append((f"BYTE {value}", bytes([value]), 0))
def write_bytes(self, value: bytes): def write_bytes(self, value: bytes) -> None:
self._data.append((f"BYTES {len(value)}", value, 0)) self._data.append((f"BYTES {len(value)}", value, 0))
def write_value(self, value: int | float, num_bytes: int = 4): def write_value(self, value: int | float, num_bytes: int = 4) -> None:
if isinstance(value, int): if isinstance(value, int):
self.write_int(value, num_bytes, True) self.write_int(value, num_bytes, True)
else: else:
# 32 bit or 64 bit float
en = {'little': '<', 'big': '>'}[self.byteorder] en = {'little': '<', 'big': '>'}[self.byteorder]
if num_bytes == 4: if num_bytes == 4:
data = struct.pack(en + 'f', value) data = struct.pack(en + 'f', value)
@ -97,14 +53,6 @@ class data_writer():
def get_data(self) -> bytes: def get_data(self) -> bytes:
return b''.join(dat for _, dat, _ in self._data) return b''.join(dat for _, dat, _ in self._data)
def to_file(self, path: str): def to_file(self, path: str) -> None:
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.write(self.get_data()) f.write(self.get_data())
def get_c_consts() -> str:
ret: list[str] = []
for c in Command:
ret.append (f"#define {c.name} {c.value}")
for c in PatchType:
ret.append(f"#define {c.name} {c.value}")
return '\n'.join(ret)

View File

@ -4,23 +4,27 @@ from typing import Generator
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'} op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
def get_function_start(): def get_function_start() -> str:
return """ return """
void function_start(){ int function_start(){
result_int(0); result_int(0);
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
return 1;
} }
""" """
def get_function_end():
def get_function_end() -> str:
return """ return """
void function_end(){ int function_end(){
result_int(0); result_int(0);
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
return 1;
} }
""" """
def get_op_code(op: str, type1: str, type2: str, type_out: str):
def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -29,19 +33,20 @@ def get_op_code(op: str, type1: str, type2: str, type_out: str):
}} }}
""" """
def get_result_stubs1(type1: str):
def get_result_stubs1(type1: str) -> str:
return f""" return f"""
void result_{type1}({type1} arg1); void result_{type1}({type1} arg1);
""" """
def get_result_stubs2(type1: str, type2: str): def get_result_stubs2(type1: str, type2: str) -> str:
return f""" return f"""
void result_{type1}_{type2}({type1} arg1, {type2} arg2); void result_{type1}_{type2}({type1} arg1, {type2} arg2);
""" """
def get_read_reg0_code(type1: str, type2: str, type_out: str): def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -51,7 +56,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str):
""" """
def get_read_reg1_code(type1: str, type2: str, type_out: str): def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -61,7 +66,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str):
""" """
def get_write_code(type1: str): def get_write_code(type1: str) -> str:
return f""" return f"""
void write_{type1}({type1} arg1) {{ void write_{type1}({type1} arg1) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -107,7 +112,7 @@ if __name__ == "__main__":
for t1, t2, t_out in permutate(types, types, types): for t1, t2, t_out in permutate(types, types, types):
code += get_read_reg0_code(t1, t2, t_out) code += get_read_reg0_code(t1, t2, t_out)
code += get_read_reg1_code(t1, t2, t_out) code += get_read_reg1_code(t1, t2, t_out)
for t1 in types: for t1 in types:
code += get_write_code(t1) code += get_write_code(t1)

View File

@ -1,95 +0,0 @@
//notes:
//Helper functions
void result_int(int ret1){
asm ("");
}
void result_float(float ret1){
asm ("");
}
void result_float_float(float ret1, float ret2){
asm ("");
}
//Operations
void add_int_int(int arg1, int arg2){
result_int(arg1 + arg2);
}
void add_float_float(float arg1, float arg2){
result_float(arg1 + arg2);
}
void add_float_int(float arg1, float arg2){
result_float(arg1 + arg2);
}
void sub_int_int(int arg1, int arg2){
result_int(arg1 - arg2);
}
void sub_float_float(float arg1, float arg2){
result_float(arg1 - arg2);
}
void sub_float_int(float arg1, int arg2){
result_float(arg1 - arg2);
}
void sub_int_float(int arg1, float arg2){
result_float(arg1 - arg2);
}
void mul_int_int(int arg1, int arg2){
result_int(arg1 * arg2);
}
void mul_float_float(float arg1, float arg2){
result_float(arg1 * arg2);
}
void mul_float_int(float arg1, int arg2){
result_float(arg1 + arg2);
}
void div_int_int(int arg1, int arg2){
result_int(arg1 / arg2);
}
void div_float_float(float arg1, float arg2){
result_float(arg1 / arg2);
}
void div_float_int(float arg1, int arg2){
result_float(arg1 / arg2);
}
void div_int_float(int arg1, float arg2){
result_float(arg1 / arg2);
}
//Read global variables from heap
int read_int_ret = 1337;
void read_int(){
result_int(read_int_ret);
}
float read_float_ret = 1337;
void read_float(){
result_float(read_float_ret);
}
void read_float_2(float arg1){
result_float_float(arg1, read_float_ret);
}

View File

@ -13,6 +13,7 @@ LENGTH_CALL_INSTRUCTION = 5 # x86_64
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)]) RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
@dataclass @dataclass
class patch_entry: class patch_entry:
""" """
@ -21,7 +22,7 @@ class patch_entry:
Attributes: Attributes:
addr (int): address of first byte to patch relative to the start of the symbol addr (int): address of first byte to patch relative to the start of the symbol
type (RelocationType): relocation type""" type (RelocationType): relocation type"""
type: RelocationType type: RelocationType
addr: int addr: int
addend: int addend: int
@ -73,13 +74,13 @@ class stencil_database():
self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols
if s.info == 'STT_FUNC'} if s.info == 'STT_FUNC'}
self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols
if s.info == 'STT_FUNC'} if s.info == 'STT_FUNC'}
self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
if s.info == 'STT_OBJECT'} if s.info == 'STT_OBJECT'}
self.byteorder: Literal['little', 'big'] = self.elf.byteorder self.byteorder: Literal['little', 'big'] = self.elf.byteorder
for name in self.function_definitions.keys(): for name in self.function_definitions.keys():
@ -87,7 +88,6 @@ class stencil_database():
sym.relocations sym.relocations
self.elf.symbols[name].data self.elf.symbols[name].data
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
"""Return patch positions """Return patch positions
""" """
@ -95,7 +95,7 @@ class stencil_database():
start_index, end_index = get_stencil_position(symbol.data, symbol.file.byteorder) start_index, end_index = get_stencil_position(symbol.data, symbol.file.byteorder)
for reloc in symbol.relocations: for reloc in symbol.relocations:
# address to fist byte to patch relative to the start of the symbol # address to fist byte to patch relative to the start of the symbol
patch = translate_relocation( patch = translate_relocation(
reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index, reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index,
@ -107,8 +107,5 @@ class stencil_database():
if patch.addr < end_index - start_index: if patch.addr < end_index - start_index:
yield patch yield patch
def get_func_data(self, name: str) -> bytes: def get_func_data(self, name: str) -> bytes:
return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder) return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)

View File

@ -223,13 +223,15 @@
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
} }
void function_start(){ int function_start(){
result_int(0); result_int(0);
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
return 1;
} }
void function_end(){ int function_end(){
result_int(0); result_int(0);
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
return 1;
} }

View File

@ -14,6 +14,7 @@
#define PATCH_FUNC 5 #define PATCH_FUNC 5
#define PATCH_OBJECT 6 #define PATCH_OBJECT 6
#define SET_ENTR_POINT 64 #define SET_ENTR_POINT 64
#define READ_DATA 65
#define END_PROG 255 #define END_PROG 255
#define PATCH_RELATIVE_32 0 #define PATCH_RELATIVE_32 0
@ -132,16 +133,27 @@ int parse_commands(uint8_t *bytes){
case SET_ENTR_POINT: case SET_ENTR_POINT:
rel_entr_point = *(uint32_t*)bytes; bytes += 4; rel_entr_point = *(uint32_t*)bytes; bytes += 4;
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point); printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
entr_point = (int (*)())(executable_memory + rel_entr_point); entr_point = (int (*)())(executable_memory + rel_entr_point);
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
break; break;
case END_PROG: case END_PROG:
printf("END_PROG\n"); printf("END_PROG\n");
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
err_flag = 1; err_flag = 1;
break; break;
case READ_DATA:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("READ_DATA offs=%i size=%i data=", offs, size);
for (uint32_t i = 0; i < size; i++) {
printf("%02X ", data_memory[offs + i]);
}
printf("\n");
break;
default: default:
printf("Unknown command\n"); printf("Unknown command\n");

View File

@ -1,9 +1,10 @@
from copapy import Write, const from copapy import Write, const
import copapy as rc import copapy as rc
def test_ast_generation(): def test_ast_generation():
c1 = const(1.11) c1 = const(1.11)
c2 = const(2.22) #c2 = const(2.22)
#c3 = const(3.33) #c3 = const(3.33)
#i1 = c1 + c2 #i1 = c1 + c2
@ -12,36 +13,31 @@ def test_ast_generation():
#r1 = i1 + i3 #r1 = i1 + i3
#r2 = i3 * i2 #r2 = i3 * i2
i1 = c1 * 2 i1 = c1 * 2
i2 = i1 + 3 r1 = i1 + 7
out = Write(r1)
r1 = i1 + i2
r2 = c2 + 4 + c1
out = [Write(r1), Write(r2)]
print(out) print(out)
print('--') print('-- get_path_segments:')
path_segments = list(rc.get_path_segments(out)) path_segments = list(rc.get_path_segments([out]))
for p in path_segments: for p in path_segments:
print(p) print(p)
print('--') print('-- get_ordered_ops:')
ordered_ops = list(rc.get_ordered_ops(path_segments)) ordered_ops = list(rc.get_ordered_ops(path_segments))
for p in path_segments: for p in path_segments:
print(p) print(p)
print('--') print('-- get_consts:')
const_list = rc.get_consts(ordered_ops) const_list = rc.get_consts(ordered_ops)
for p in const_list: for p in const_list:
print(p) print(p)
print('--') print('-- add_read_ops:')
output_ops = list(rc.add_read_ops(ordered_ops)) output_ops = list(rc.add_read_ops(ordered_ops))
for p in output_ops: for p in output_ops:
print(p) print(p)
print('--') print('-- add_write_ops:')
extended_output_ops = list(rc.add_write_ops(output_ops, const_list)) extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
for p in extended_output_ops: for p in extended_output_ops:
@ -50,4 +46,4 @@ def test_ast_generation():
if __name__ == "__main__": if __name__ == "__main__":
test_ast_generation() test_ast_generation()

View File

@ -1,6 +1,9 @@
from copapy import Write, const from copapy import Write, const
import copapy as rc import copapy
import subprocess import subprocess
import struct
from copapy import binwrite as binw
def run_command(command: list[str], encoding: str = 'utf8') -> str: def run_command(command: list[str], encoding: str = 'utf8') -> str:
process = subprocess.Popen(command, stdout=subprocess.PIPE) process = subprocess.Popen(command, stdout=subprocess.PIPE)
@ -9,9 +12,10 @@ def run_command(command: list[str], encoding: str = 'utf8') -> str:
assert error is None, f"Error occurred: {error.decode(encoding)}" assert error is None, f"Error occurred: {error.decode(encoding)}"
return output.decode(encoding) return output.decode(encoding)
def test_compile():
c1 = const(1.11) def test_example():
c2 = const(2.22) c1 = 1.11
c2 = 2.22
i1 = c1 * 2 i1 = c1 * 2
i2 = i1 + 3 i2 = i1 + 3
@ -19,9 +23,48 @@ def test_compile():
r1 = i1 + i2 r1 = i1 + i2
r2 = c2 + 4 + c1 r2 = c2 + 4 + c1
out = [Write(r1), Write(r2)] en = {'little': '<', 'big': '>'}['little']
data = struct.pack(en + 'f', r1)
print("example r1 " + ' '.join(f'{b:02X}' for b in data))
il = rc.compile_to_instruction_list(out) data = struct.pack(en + 'f', r2)
print("example r2 " + ' '.join(f'{b:02X}' for b in data))
# assert False
# example r1 7B 14 EE 40
# example r2 5C 8F EA 40
def test_compile():
print(run_command(['bash', 'build.sh']))
c1 = const(4)
#c2 = const(2)
#i1 = c1 * 2
#i2 = i1 + 3
#r1 = i1 + i2
#r2 = c2 + 4 + c1
#out = [Write(r1), Write(r2)]
i1 = c1 * 2
r1 = i1 + 7
out = Write(r1)
il = copapy.compile_to_instruction_list(out)
#copapy.read_variable(il, i1)
copapy.read_variable(il, r1)
il.write_com(binw.Command.READ_DATA)
il.write_int(0)
il.write_int(36)
# run program command
il.write_com(binw.Command.END_PROG)
print('#', il.print()) print('#', il.print())
@ -30,8 +73,8 @@ def test_compile():
result = run_command(['./bin/runmem2', 'test.copapy']) result = run_command(['./bin/runmem2', 'test.copapy'])
print(result) print(result)
assert 'Return value: 0' in result assert 'Return value: 1' in result
if __name__ == "__main__": if __name__ == "__main__":
test_compile() test_compile()

View File

@ -1,6 +1,7 @@
from copapy import stencil_database from copapy import stencil_database
from copapy import stencil_db from copapy import stencil_db
def test_list_symbols(): def test_list_symbols():
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
print('----') print('----')
@ -18,9 +19,9 @@ def test_start_end_function():
start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder) start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder)
assert start>= 0 and end >= start and end <= len(data) assert start >= 0 and end >= start and end <= len(data)
if __name__ == "__main__": if __name__ == "__main__":
test_list_symbols() test_list_symbols()
test_start_end_function() test_start_end_function()