code cleaned up, passes flake8 and mypy

This commit is contained in:
Nicolas 2025-10-01 23:09:49 +02:00
parent 3b7fcdb88b
commit 5b7ca52b7c
11 changed files with 159 additions and 261 deletions

View File

@ -13,6 +13,7 @@ exclude =
build, build,
dist, dist,
.conda, .conda,
.venv,
tests/autogenerated_* tests/autogenerated_*
# Enable specific plugins or options # Enable specific plugins or options

View File

@ -3,22 +3,25 @@ from typing import Generator, Iterable, Any
from . import binwrite as binw from . import binwrite as binw
from .stencil_db import stencil_database from .stencil_db import stencil_database
Operand = type['Net'] | float | int Operand = type['Net'] | float | int
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
return [name for name, value in scope.items() if value is var] return [name for name, value in scope.items() if value is var]
# _ccode = pkgutil.get_data(__name__, 'stencils.c') # _ccode = pkgutil.get_data(__name__, 'stencils.c')
# assert _ccode is not None # assert _ccode is not None
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
class Node: class Node:
def __init__(self): def __init__(self) -> None:
self.args: list[Net] = [] self.args: list[Net] = []
self.name: str = '' self.name: str = ''
def __repr__(self): def __repr__(self) -> str:
#return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})" #return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else self.value})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, Const) else '')})"
@ -44,19 +47,19 @@ class Net:
def __radd__(self, other: Any) -> 'Net': def __radd__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True) return _add_op('add', [self, other], True)
def __sub__ (self, other: Any) -> 'Net': def __sub__(self, other: Any) -> 'Net':
return _add_op('sub', [self, other]) return _add_op('sub', [self, other])
def __rsub__ (self, other: Any) -> 'Net': def __rsub__(self, other: Any) -> 'Net':
return _add_op('sub', [other, self]) return _add_op('sub', [other, self])
def __truediv__ (self, other: Any) -> 'Net': def __truediv__(self, other: Any) -> 'Net':
return _add_op('div', [self, other]) return _add_op('div', [self, other])
def __rtruediv__ (self, other: Any) -> 'Net': def __rtruediv__(self, other: Any) -> 'Net':
return _add_op('div', [other, self]) return _add_op('div', [other, self])
def __repr__(self): def __repr__(self) -> str:
names = get_var_name(self) names = get_var_name(self)
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
@ -94,7 +97,7 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in sdb.function_definitions: if typed_op not in sdb.function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
@ -141,24 +144,6 @@ def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((const(x), const(y), const(z))) return vec3d((const(x), const(y), const(z)))
def get_multiuse_nets(root: list[Node]) -> set[Net]:
"""Finds all nets that get accessed more than one time. Therefore
storage on the heap might be better.
"""
known_nets: set[Net] = set()
def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list:
#print(net)
if net in known_nets:
yield net
else:
known_nets.add(net)
yield from recursive_node_search(net.source.args)
return set(recursive_node_search(op.args[0] for op in root))
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
"""List of all possible paths. Ops in order of execution (output at the end) """List of all possible paths. Ops in order of execution (output at the end)
""" """
@ -240,7 +225,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
registers[i] = net registers[i] = net
if node in net_lookup: if node in net_lookup:
yield None , node yield None, node
registers[0] = net_lookup[node] registers[0] = net_lookup[node]
else: else:
print('--->', node) print('--->', node)
@ -283,7 +268,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
if isinstance(end_nodes, Node): if isinstance(end_nodes, Node):
node_list = [end_nodes] node_list = [end_nodes]
else: else:
node_list = end_nodes node_list = list(end_nodes)
path_segments = list(get_path_segments(node_list)) path_segments = list(get_path_segments(node_list))
ordered_ops = list(get_ordered_ops(path_segments)) ordered_ops = list(get_ordered_ops(path_segments))
@ -297,10 +282,8 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
# Get all nets associated with heap memory # Get all nets associated with heap memory
variable_list = get_nets(const_list, extended_output_ops) variable_list = get_nets(const_list, extended_output_ops)
dw = binw.data_writer(sdb.byteorder) dw = binw.data_writer(sdb.byteorder)
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]: def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0 offset: int = 0
object_list: list[tuple[Net, int, int]] = [] object_list: list[tuple[Net, int, int]] = []
@ -312,14 +295,14 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
return object_list, offset return object_list, offset
object_list, data_section_lengths = variable_mem_layout(variable_list) object_list, data_section_lengths = variable_mem_layout(variable_list)
#data_section_lengths = object_list[-1][1] + object_list[-1][2] # Write data
dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths) dw.write_int(data_section_lengths)
for net, out_offs, lengths in object_list: for net, out_offs, lengths in object_list:
dw.add_variable(net, out_offs, lengths, net.dtype)
if isinstance(net.source, Const): if isinstance(net.source, Const):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(out_offs)
@ -330,7 +313,7 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
# write auxiliary_functions # write auxiliary_functions
# TODO # TODO
# Prepare program data and relocations # Prepare program code and relocations
object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list} object_addr_lookp = {net: out_offs for net, out_offs, _ in object_list}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int]] = [] patch_list: list[tuple[int, int, int]] = []
@ -383,7 +366,12 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
dw.write_com(binw.Command.SET_ENTR_POINT) dw.write_com(binw.Command.SET_ENTR_POINT)
dw.write_int(0) dw.write_int(0)
# run program command
dw.write_com(binw.Command.END_PROG)
return dw return dw
def read_variable(bw: binw.data_writer, net: Net) -> None:
assert net in bw.variables, f"Variable {net} not found in data writer variables"
addr, lengths, _ = bw.variables[net]
bw.write_com(binw.Command.READ_DATA)
bw.write_int(addr)
bw.write_int(lengths)

View File

@ -1,85 +1,41 @@
from enum import Enum from enum import Enum
from pelfy import elf_symbol from typing import Literal, Any
from typing import Literal
import struct import struct
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
('ALLOCATE_CODE', 3), ('COPY_CODE', 4), ('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
('PATCH_FUNC', 5), ('PATCH_OBJECT', 6), ('PATCH_FUNC', 5), ('PATCH_OBJECT', 6),
('SET_ENTR_POINT', 64), ('END_PROG', 255)]) ('SET_ENTR_POINT', 64), ('READ_DATA', 65),
('END_PROG', 255)])
def get_variable_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int]], int]:
object_list: list[tuple[elf_symbol, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_OBJECT'
lengths = sym.fields['st_size'],
object_list.append((sym, out_offs, lengths))
out_offs += (lengths + 3) // 4 * 4
return object_list, out_offs
def get_function_data(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_FUNC'
lengths = sym.fields['st_size']
#if strip_function:
# assert False, 'Not implemente'
# TODO: Strip functions
# Symbol, start out_offset in symbol, offset in output file, output lengths
# Set in_sym_out_offs and lengths
in_sym_offs = 0
code_list.append((sym, in_sym_offs, out_offs, lengths))
# out_offs += (lengths + 3) // 4 * 4
out_offs += lengths # should be aligned by default?
return code_list, out_offs
def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_symbol, int, int, int]], int]:
code_list: list[tuple[elf_symbol, int, int, int]] = []
out_offs = 0
for sym in symbols:
assert sym.info == 'STT_FUNC'
lengths = sym.fields['st_size']
#if strip_function:
# assert False, 'Not implemente'
# TODO: Strip functions
# Symbol, start out_offset in symbol, offset in output file, output lengths
# Set in_sym_out_offs and lengths
in_sym_offs = 0
code_list.append((sym, in_sym_offs, out_offs, lengths))
out_offs += (lengths + 3) // 4 * 4
return code_list, out_offs
class data_writer(): class data_writer():
def __init__(self, byteorder: Literal['little', 'big']): def __init__(self, byteorder: Literal['little', 'big']):
self._data: list[tuple[str, bytes, int]] = list() self._data: list[tuple[str, bytes, int]] = list()
self.byteorder = byteorder self.byteorder: Literal['little', 'big'] = byteorder
self.variables: dict[Any, tuple[int, int, str]] = dict()
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False): def add_variable(self, net: Any, addr: int, lengths: int, var_type: str) -> None:
self.variables[net] = (addr, lengths, var_type)
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False) -> None:
self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0)) self._data.append((f"INT {value}", value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=signed), 0))
def write_com(self, value: Enum, num_bytes: int = 4): def write_com(self, value: Enum, num_bytes: int = 4) -> None:
self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1)) self._data.append((value.name, value.value.to_bytes(length=num_bytes, byteorder=self.byteorder, signed=False), 1))
def write_byte(self, value: int): def write_byte(self, value: int) -> None:
self._data.append((f"BYTE {value}", bytes([value]), 0)) self._data.append((f"BYTE {value}", bytes([value]), 0))
def write_bytes(self, value: bytes): def write_bytes(self, value: bytes) -> None:
self._data.append((f"BYTES {len(value)}", value, 0)) self._data.append((f"BYTES {len(value)}", value, 0))
def write_value(self, value: int | float, num_bytes: int = 4): def write_value(self, value: int | float, num_bytes: int = 4) -> None:
if isinstance(value, int): if isinstance(value, int):
self.write_int(value, num_bytes, True) self.write_int(value, num_bytes, True)
else: else:
# 32 bit or 64 bit float
en = {'little': '<', 'big': '>'}[self.byteorder] en = {'little': '<', 'big': '>'}[self.byteorder]
if num_bytes == 4: if num_bytes == 4:
data = struct.pack(en + 'f', value) data = struct.pack(en + 'f', value)
@ -97,14 +53,6 @@ class data_writer():
def get_data(self) -> bytes: def get_data(self) -> bytes:
return b''.join(dat for _, dat, _ in self._data) return b''.join(dat for _, dat, _ in self._data)
def to_file(self, path: str): def to_file(self, path: str) -> None:
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.write(self.get_data()) f.write(self.get_data())
def get_c_consts() -> str:
ret: list[str] = []
for c in Command:
ret.append (f"#define {c.name} {c.value}")
for c in PatchType:
ret.append(f"#define {c.name} {c.value}")
return '\n'.join(ret)

View File

@ -4,23 +4,27 @@ from typing import Generator
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'} op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
def get_function_start(): def get_function_start() -> str:
return """ return """
void function_start(){ int function_start(){
result_int(0); result_int(0);
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
return 1;
} }
""" """
def get_function_end():
def get_function_end() -> str:
return """ return """
void function_end(){ int function_end(){
result_int(0); result_int(0);
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
return 1;
} }
""" """
def get_op_code(op: str, type1: str, type2: str, type_out: str):
def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -29,19 +33,20 @@ def get_op_code(op: str, type1: str, type2: str, type_out: str):
}} }}
""" """
def get_result_stubs1(type1: str):
def get_result_stubs1(type1: str) -> str:
return f""" return f"""
void result_{type1}({type1} arg1); void result_{type1}({type1} arg1);
""" """
def get_result_stubs2(type1: str, type2: str): def get_result_stubs2(type1: str, type2: str) -> str:
return f""" return f"""
void result_{type1}_{type2}({type1} arg1, {type2} arg2); void result_{type1}_{type2}({type1} arg1, {type2} arg2);
""" """
def get_read_reg0_code(type1: str, type2: str, type_out: str): def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -51,7 +56,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str):
""" """
def get_read_reg1_code(type1: str, type2: str, type_out: str): def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
@ -61,7 +66,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str):
""" """
def get_write_code(type1: str): def get_write_code(type1: str) -> str:
return f""" return f"""
void write_{type1}({type1} arg1) {{ void write_{type1}({type1} arg1) {{
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");

View File

@ -1,95 +0,0 @@
//notes:
//Helper functions
void result_int(int ret1){
asm ("");
}
void result_float(float ret1){
asm ("");
}
void result_float_float(float ret1, float ret2){
asm ("");
}
//Operations
void add_int_int(int arg1, int arg2){
result_int(arg1 + arg2);
}
void add_float_float(float arg1, float arg2){
result_float(arg1 + arg2);
}
void add_float_int(float arg1, float arg2){
result_float(arg1 + arg2);
}
void sub_int_int(int arg1, int arg2){
result_int(arg1 - arg2);
}
void sub_float_float(float arg1, float arg2){
result_float(arg1 - arg2);
}
void sub_float_int(float arg1, int arg2){
result_float(arg1 - arg2);
}
void sub_int_float(int arg1, float arg2){
result_float(arg1 - arg2);
}
void mul_int_int(int arg1, int arg2){
result_int(arg1 * arg2);
}
void mul_float_float(float arg1, float arg2){
result_float(arg1 * arg2);
}
void mul_float_int(float arg1, int arg2){
result_float(arg1 + arg2);
}
void div_int_int(int arg1, int arg2){
result_int(arg1 / arg2);
}
void div_float_float(float arg1, float arg2){
result_float(arg1 / arg2);
}
void div_float_int(float arg1, int arg2){
result_float(arg1 / arg2);
}
void div_int_float(int arg1, float arg2){
result_float(arg1 / arg2);
}
//Read global variables from heap
int read_int_ret = 1337;
void read_int(){
result_int(read_int_ret);
}
float read_float_ret = 1337;
void read_float(){
result_float(read_float_ret);
}
void read_float_2(float arg1){
result_float_float(arg1, read_float_ret);
}

View File

@ -13,6 +13,7 @@ LENGTH_CALL_INSTRUCTION = 5 # x86_64
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)]) RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
@dataclass @dataclass
class patch_entry: class patch_entry:
""" """
@ -75,10 +76,10 @@ class stencil_database():
if s.info == 'STT_FUNC'} if s.info == 'STT_FUNC'}
self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols
if s.info == 'STT_FUNC'} if s.info == 'STT_FUNC'}
self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
if s.info == 'STT_OBJECT'} if s.info == 'STT_OBJECT'}
self.byteorder: Literal['little', 'big'] = self.elf.byteorder self.byteorder: Literal['little', 'big'] = self.elf.byteorder
@ -87,7 +88,6 @@ class stencil_database():
sym.relocations sym.relocations
self.elf.symbols[name].data self.elf.symbols[name].data
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]:
"""Return patch positions """Return patch positions
""" """
@ -107,8 +107,5 @@ class stencil_database():
if patch.addr < end_index - start_index: if patch.addr < end_index - start_index:
yield patch yield patch
def get_func_data(self, name: str) -> bytes: def get_func_data(self, name: str) -> bytes:
return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder) return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)

View File

@ -223,13 +223,15 @@
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
} }
void function_start(){ int function_start(){
result_int(0); result_int(0);
asm volatile (".long 0xF27ECAFE"); asm volatile (".long 0xF27ECAFE");
return 1;
} }
void function_end(){ int function_end(){
result_int(0); result_int(0);
asm volatile (".long 0xF17ECAFE"); asm volatile (".long 0xF17ECAFE");
return 1;
} }

View File

@ -14,6 +14,7 @@
#define PATCH_FUNC 5 #define PATCH_FUNC 5
#define PATCH_OBJECT 6 #define PATCH_OBJECT 6
#define SET_ENTR_POINT 64 #define SET_ENTR_POINT 64
#define READ_DATA 65
#define END_PROG 255 #define END_PROG 255
#define PATCH_RELATIVE_32 0 #define PATCH_RELATIVE_32 0
@ -133,16 +134,27 @@ int parse_commands(uint8_t *bytes){
rel_entr_point = *(uint32_t*)bytes; bytes += 4; rel_entr_point = *(uint32_t*)bytes; bytes += 4;
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point); printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
entr_point = (int (*)())(executable_memory + rel_entr_point); entr_point = (int (*)())(executable_memory + rel_entr_point);
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
break; break;
case END_PROG: case END_PROG:
printf("END_PROG\n"); printf("END_PROG\n");
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
err_flag = 1; err_flag = 1;
break; break;
case READ_DATA:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("READ_DATA offs=%i size=%i data=", offs, size);
for (uint32_t i = 0; i < size; i++) {
printf("%02X ", data_memory[offs + i]);
}
printf("\n");
break;
default: default:
printf("Unknown command\n"); printf("Unknown command\n");
err_flag = -1; err_flag = -1;

View File

@ -1,9 +1,10 @@
from copapy import Write, const from copapy import Write, const
import copapy as rc import copapy as rc
def test_ast_generation(): def test_ast_generation():
c1 = const(1.11) c1 = const(1.11)
c2 = const(2.22) #c2 = const(2.22)
#c3 = const(3.33) #c3 = const(3.33)
#i1 = c1 + c2 #i1 = c1 + c2
@ -12,36 +13,31 @@ def test_ast_generation():
#r1 = i1 + i3 #r1 = i1 + i3
#r2 = i3 * i2 #r2 = i3 * i2
i1 = c1 * 2 i1 = c1 * 2
i2 = i1 + 3 r1 = i1 + 7
out = Write(r1)
r1 = i1 + i2
r2 = c2 + 4 + c1
out = [Write(r1), Write(r2)]
print(out) print(out)
print('--') print('-- get_path_segments:')
path_segments = list(rc.get_path_segments(out)) path_segments = list(rc.get_path_segments([out]))
for p in path_segments: for p in path_segments:
print(p) print(p)
print('--') print('-- get_ordered_ops:')
ordered_ops = list(rc.get_ordered_ops(path_segments)) ordered_ops = list(rc.get_ordered_ops(path_segments))
for p in path_segments: for p in path_segments:
print(p) print(p)
print('--') print('-- get_consts:')
const_list = rc.get_consts(ordered_ops) const_list = rc.get_consts(ordered_ops)
for p in const_list: for p in const_list:
print(p) print(p)
print('--') print('-- add_read_ops:')
output_ops = list(rc.add_read_ops(ordered_ops)) output_ops = list(rc.add_read_ops(ordered_ops))
for p in output_ops: for p in output_ops:
print(p) print(p)
print('--') print('-- add_write_ops:')
extended_output_ops = list(rc.add_write_ops(output_ops, const_list)) extended_output_ops = list(rc.add_write_ops(output_ops, const_list))
for p in extended_output_ops: for p in extended_output_ops:

View File

@ -1,6 +1,9 @@
from copapy import Write, const from copapy import Write, const
import copapy as rc import copapy
import subprocess import subprocess
import struct
from copapy import binwrite as binw
def run_command(command: list[str], encoding: str = 'utf8') -> str: def run_command(command: list[str], encoding: str = 'utf8') -> str:
process = subprocess.Popen(command, stdout=subprocess.PIPE) process = subprocess.Popen(command, stdout=subprocess.PIPE)
@ -9,9 +12,10 @@ def run_command(command: list[str], encoding: str = 'utf8') -> str:
assert error is None, f"Error occurred: {error.decode(encoding)}" assert error is None, f"Error occurred: {error.decode(encoding)}"
return output.decode(encoding) return output.decode(encoding)
def test_compile():
c1 = const(1.11) def test_example():
c2 = const(2.22) c1 = 1.11
c2 = 2.22
i1 = c1 * 2 i1 = c1 * 2
i2 = i1 + 3 i2 = i1 + 3
@ -19,9 +23,48 @@ def test_compile():
r1 = i1 + i2 r1 = i1 + i2
r2 = c2 + 4 + c1 r2 = c2 + 4 + c1
out = [Write(r1), Write(r2)] en = {'little': '<', 'big': '>'}['little']
data = struct.pack(en + 'f', r1)
print("example r1 " + ' '.join(f'{b:02X}' for b in data))
il = rc.compile_to_instruction_list(out) data = struct.pack(en + 'f', r2)
print("example r2 " + ' '.join(f'{b:02X}' for b in data))
# assert False
# example r1 7B 14 EE 40
# example r2 5C 8F EA 40
def test_compile():
print(run_command(['bash', 'build.sh']))
c1 = const(4)
#c2 = const(2)
#i1 = c1 * 2
#i2 = i1 + 3
#r1 = i1 + i2
#r2 = c2 + 4 + c1
#out = [Write(r1), Write(r2)]
i1 = c1 * 2
r1 = i1 + 7
out = Write(r1)
il = copapy.compile_to_instruction_list(out)
#copapy.read_variable(il, i1)
copapy.read_variable(il, r1)
il.write_com(binw.Command.READ_DATA)
il.write_int(0)
il.write_int(36)
# run program command
il.write_com(binw.Command.END_PROG)
print('#', il.print()) print('#', il.print())
@ -30,7 +73,7 @@ def test_compile():
result = run_command(['./bin/runmem2', 'test.copapy']) result = run_command(['./bin/runmem2', 'test.copapy'])
print(result) print(result)
assert 'Return value: 0' in result assert 'Return value: 1' in result
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
from copapy import stencil_database from copapy import stencil_database
from copapy import stencil_db from copapy import stencil_db
def test_list_symbols(): def test_list_symbols():
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o') sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
print('----') print('----')
@ -18,7 +19,7 @@ def test_start_end_function():
start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder) start, end = stencil_db.get_stencil_position(data, sdb.elf.byteorder)
assert start>= 0 and end >= start and end <= len(data) assert start >= 0 and end >= start and end <= len(data)
if __name__ == "__main__": if __name__ == "__main__":