This commit is contained in:
Nicolas Kruse 2025-09-20 23:25:07 +02:00
parent e6b94da270
commit d4bc56f1db
11 changed files with 392 additions and 30 deletions

View File

@ -1,9 +1,12 @@
import re import re
import pkgutil import pkgutil
from typing import Generator, Iterable, Any from tkinter import Variable
from typing import Generator, Iterable, Any, Literal, TypeVar
import pelfy import pelfy
from . import binwrite as binw from . import binwrite as binw
from .stencil_db import stencil_database
Operand = type['Net'] | float | int
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
return [name for name, value in scope.items() if value is var] return [name for name, value in scope.items() if value is var]
@ -18,8 +21,9 @@ _ccode = pkgutil.get_data(__name__, 'stencils.c')
assert _ccode is not None assert _ccode is not None
_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8')) _function_definitions = _get_c_function_definitions(_ccode.decode('utf-8'))
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
print(_function_definitions) #print(_function_definitions)
class Node: class Node:
def __init__(self): def __init__(self):
@ -104,10 +108,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in _function_definitions: if typed_op not in sdb.function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
result_type = _function_definitions[typed_op].split('_')[0] result_type = sdb.function_definitions[typed_op].split('_')[0]
result_net = Net(result_type, Op(typed_op, arg_nets)) result_net = Net(result_type, Op(typed_op, arg_nets))
@ -155,16 +159,16 @@ def get_multiuse_nets(root: list[Node]) -> set[Net]:
""" """
known_nets: set[Net] = set() known_nets: set[Net] = set()
def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]: def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list: for net in net_list:
#print(net) #print(net)
if net in known_nets: if net in known_nets:
yield net yield net
else: else:
known_nets.add(net) known_nets.add(net)
yield from recursiv_node_search(net.source.args) yield from recursive_node_search(net.source.args)
return set(recursiv_node_search(op.args[0] for op in root)) return set(recursive_node_search(op.args[0] for op in root))
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]: def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
@ -215,6 +219,10 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]: def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
"""Get all const nodes in the op list
Returns:
List of tuples of (name, net, value)"""
net_lookup = {net.source: net for op in op_list for net in op.args} net_lookup = {net.source: net for op in op_list for net in op.args}
return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)] return [(n.name, net_lookup[n], n.value) for n in op_list if isinstance(n, Const)]
@ -224,7 +232,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
correctly in the registers correctly in the registers
Returns: Returns:
Yields a tuples of a net and a operation. The net is the result net Yields tuples of a net and a operation. The net is the result net
from the returned operation""" from the returned operation"""
registers: list[None | Net] = [None] * 2 registers: list[None | Net] = [None] * 2
@ -252,19 +260,31 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]: def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add write operation for each new defined net if a read operation is later flowed""" """Add write operation for each new defined net if a read operation is later followed"""
stored_nets = {c[1] for c in const_list} stored_nets = {c[1] for c in const_list}
read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_')} read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_')}
for net, node in net_node_list: for net, node in net_node_list:
yield net, node yield net, node
if net and net in read_back_nets and net not in stored_nets: if net and net in read_back_nets and net not in stored_nets:
yield (net, Op('write_' + net.dtype, [net])) yield (net, Write(net))
stored_nets.add(net) stored_nets.add(net)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer: def get_variable_nets(nodes: Iterable[Node], nets_in: Iterable[Net]) -> list[Net]:
nets: set[Net] = set()
for node in nodes:
if node.name.startswith('write_'):
nets.add(node.args[0])
for net_in in nets_in:
if net_in.source.name.startswith('read_'):
nets.add(net_in)
return list(nets)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
if isinstance(end_nodes, Node): if isinstance(end_nodes, Node):
node_list = [end_nodes] node_list = [end_nodes]
else: else:
@ -276,42 +296,79 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
output_ops = list(add_read_ops(ordered_ops)) output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_list)) extended_output_ops = list(add_write_ops(output_ops, const_list))
for net, node in extended_output_ops:
print(node.name)
obj_file: str = 'src/copapy/obj/stencils_x86_64.o' variable_list = get_variable_nets((node for _, node in extended_output_ops),
elf = pelfy.open_elf_file(obj_file) (net for net, _ in extended_output_ops if net))
dw = binw.data_writer(elf.byteorder) #assert False
prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'} #obj_file: str = 'src/copapy/obj/stencils_x86_64_O3.o'
prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'} #elf = pelfy.open_elf_file(obj_file)
auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC'] dw = binw.data_writer(sdb.byteorder)
auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
#prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
#prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
#auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
#auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
# write data sections # write auxiliary_objects to data sections
object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects) #object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
#dw.write_com(binw.Command.ALLOCATE_DATA)
#dw.write_int(data_section_lengths)
#for sym, out_offs, lengths in object_list:
# if sym.section and sym.section.type != 'SHT_NOBITS':
# dw.write_com(binw.Command.COPY_DATA)
# dw.write_int(out_offs)
# dw.write_int(lengths)
# dw.write_bytes(sym.data)
# print(f'+ {sym.name} {sym.fields}')
# write variables to data sections
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.var_size['dummy_' + variable.dtype]
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
return object_list, offset
object_list, data_section_lengths = variable_mem_layout(variable_list)
#data_section_lengths = object_list[-1][1] + object_list[-1][2]
dw.write_com(binw.Command.ALLOCATE_DATA) dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths) dw.write_int(data_section_lengths)
for sym, out_offs, lengths in object_list: for net, out_offs, lengths in object_list:
if sym.section and sym.section.type != 'SHT_NOBITS': if isinstance(net.source, Const):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(out_offs)
dw.write_int(lengths) dw.write_int(lengths)
dw.write_bytes(sym.data) dw.write_value(net.source.value, lengths)
print(f'+ {net.dtype} {net.source.value}')
# write auxiliary_functions # write auxiliary_functions
# TODO # TODO
# write program # write program
print(list(prototype_functions.keys()))
for net, node in extended_output_ops: for net, node in extended_output_ops:
if node.name in prototype_functions: assert node.name in sdb.function_definitions, f"- Warning: {node.name} prototype not found"
#print(prototype_functions[node.name]) data = sdb.get_func_data(node.name)
pass print('*', node.name, ' '.join(f'{d:02X}' for d in data))
else: print(f"- Warning: {node.name} prototype not found") for reloc_offset, lengths, bits, reloc_type in sdb.get_relocs(node.name):
#if not relocation.symbol.name.startswith('result_'):
print(relocation)
print('-----') print('-----')

View File

@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from pelfy import elf_symbol from pelfy import elf_symbol
from typing import Literal from typing import Literal
import struct
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2), Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
('ALLOCATE_CODE', 3), ('COPY_CODE', 4), ('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
@ -67,7 +68,6 @@ def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_sy
out_offs += (lengths + 3) // 4 * 4 out_offs += (lengths + 3) // 4 * 4
return code_list, out_offs return code_list, out_offs
class data_writer(): class data_writer():
def __init__(self, byteorder: Literal['little', 'big']): def __init__(self, byteorder: Literal['little', 'big']):
self._data: list[tuple[str, bytes, int]] = list() self._data: list[tuple[str, bytes, int]] = list()
@ -85,6 +85,18 @@ class data_writer():
def write_bytes(self, value: bytes): def write_bytes(self, value: bytes):
self._data.append((f"BYTES {len(value)}", value, 0)) self._data.append((f"BYTES {len(value)}", value, 0))
def write_value(self, value: int | float, num_bytes: int = 4):
if isinstance(value, int):
self.write_int(value, num_bytes, True)
else:
en = {'little': '<', 'big': '>'}[self.byteorder]
if num_bytes == 4:
data = struct.pack(en + 'f', value)
else:
data = struct.pack(en + 'd', value)
assert len(data) == num_bytes
self.write_bytes(data)
def print(self) -> None: def print(self) -> None:
for name, dat, flag in self._data: for name, dat, flag in self._data:
if flag: if flag:

View File

@ -3,6 +3,7 @@ from typing import Generator
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'} op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
def get_op_code(op: str, type1: str, type2: str, type_out: str): def get_op_code(op: str, type1: str, type2: str, type_out: str):
return f""" return f"""
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -17,11 +18,13 @@ def get_result_stubs1(type1: str):
void result_{type1}({type1} arg1); void result_{type1}({type1} arg1);
""" """
def get_result_stubs2(type1: str, type2: str): def get_result_stubs2(type1: str, type2: str):
return f""" return f"""
void result_{type1}_{type2}({type1} arg1, {type2} arg2); void result_{type1}_{type2}({type1} arg1, {type2} arg2);
""" """
def get_read_reg0_code(type1: str, type2: str, type_out: str): def get_read_reg0_code(type1: str, type2: str, type_out: str):
return f""" return f"""
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -31,6 +34,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str):
}} }}
""" """
def get_read_reg1_code(type1: str, type2: str, type_out: str): def get_read_reg1_code(type1: str, type2: str, type_out: str):
return f""" return f"""
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{ void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -40,6 +44,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str):
}} }}
""" """
def get_write_code(type1: str): def get_write_code(type1: str):
return f""" return f"""
void write_{type1}({type1} arg1) {{ void write_{type1}({type1} arg1) {{
@ -50,6 +55,7 @@ def get_write_code(type1: str):
}} }}
""" """
def permutate(*lists: list[str]) -> Generator[list[str], None, None]: def permutate(*lists: list[str]) -> Generator[list[str], None, None]:
if len(lists) == 0: if len(lists) == 0:
yield [] yield []
@ -59,6 +65,7 @@ def permutate(*lists: list[str]) -> Generator[list[str], None, None]:
for items in permutate(*rest): for items in permutate(*rest):
yield [item, *items] yield [item, *items]
if __name__ == "__main__": if __name__ == "__main__":
types = ['int', 'float'] types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div'] ops = ['add', 'sub', 'mul', 'div']

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

81
src/copapy/stencil_db.py Normal file
View File

@ -0,0 +1,81 @@
from ast import Tuple
from os import name
from tkinter import NO
import pelfy
from typing import Generator, Literal
start_marker = 0xF17ECAFE
end_marker = 0xF27ECAFE
LENGTH_CALL_INSTRUCTION = 4 # x86_64
def get_ret_function_def(symbol: pelfy.elf_symbol):
#print('*', symbol.name, symbol.section)
result_func = symbol.relocations[-1].symbol
assert result_func.name.startswith('result_')
return result_func.name[7:]
def strip_symbol(data: bytes, byteorder: Literal['little', 'big']) -> bytes:
"""Return data between start and end marker and removing last instruction (call)"""
# Find first start marker
start_index = data.find(start_marker.to_bytes(4, byteorder))
# Find last end marker
end_index = data.rfind(end_marker.to_bytes(4, byteorder), start_index)
assert start_index > -1 and end_index > -1, f"Marker not found"
return data[start_index + 4:end_index - LENGTH_CALL_INSTRUCTION]
class stencil_database():
def __init__(self, obj_file: str):
self.elf = pelfy.open_elf_file(obj_file)
#print(self.elf.symbols)
self.function_definitions = {s.name: get_ret_function_def(s) for s in self.elf.symbols
if s.info == 'STT_FUNC'}
self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in self.elf.symbols
if s.info == 'STT_FUNC'}
self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
if s.info == 'STT_OBJECT'}
self.byteorder: Literal['little', 'big'] = self.elf.byteorder
for name in self.function_definitions.keys():
sym = self.elf.symbols[name]
sym.relocations
self.elf.symbols[name].data
def get_relocs(self, symbol_name: str) -> Generator[tuple[int, int, str], None, None]:
"""Return relocation offset relative to striped symbol start.
Yields tuples of (reloc_offset, symbol_lenght, bits, reloc_type)
1. reloc_offset: Offset of the relocation relative to the start of the stripped symbol data.
2. Length of the striped symbol.
3. Bits to patch
4. reloc_type: Type of the relocation as a string.
"""
symbol = self.elf.symbols[symbol_name]
start_index = symbol.data.find(start_marker.to_bytes(4, symbol.file.byteorder))
end_index = symbol.data.rfind(end_marker.to_bytes(4, symbol.file.byteorder), start_index)
for reloc in symbol.relocations:
reloc_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
if reloc_offset < end_index - start_index - LENGTH_CALL_INSTRUCTION:
yield (reloc_offset, reloc.bits, reloc.type)
def get_func_data(self, name: str) -> bytes:
return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)

196
src/runner/runmem2.c Normal file
View File

@ -0,0 +1,196 @@
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>
#include <string.h>
#include <stdint.h>
#define ALLOCATE_DATA 1
#define COPY_DATA 2
#define ALLOCATE_CODE 3
#define COPY_CODE 4
#define RELOCATE_FUNC 5
#define RELOCATE_OBJECT 6
#define SET_ENTR_POINT 64
#define END_PROG 255
#define RELOC_RELATIVE_32 0
uint8_t *data_memory;
uint8_t *executable_memory;
uint32_t executable_memory_len;
int (*entr_point)();
uint8_t *get_executable_memory(uint32_t num_bytes){
// Allocate executable memory
uint8_t *mem = (uint8_t*)mmap(NULL, num_bytes,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
return mem;
}
uint8_t *get_data_memory(uint32_t num_bytes) {
uint8_t *mem = (uint8_t*)mmap(NULL, num_bytes,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
//uint8_t *mem = (uint8_t*)malloc(num_bytes);
return mem;
}
int mark_mem_executable(){
if (mprotect(executable_memory, executable_memory_len, PROT_READ | PROT_EXEC) == -1) {
perror("mprotect failed");
return 0;
}else{
return 1;
}
}
void patch_mem_32(uint8_t *patch_addr, int32_t value){
int32_t *val_ptr = (int32_t*)patch_addr;
*val_ptr = value;
}
int relocate(uint8_t *patch_addr, uint32_t reloc_type, int32_t value){
if (reloc_type == RELOC_RELATIVE_32){
patch_mem_32(patch_addr, value);
}else{
printf("Not implemented");
return 0;
}
return 1;
}
int parse_commands(uint8_t *bytes){
int32_t value;
uint32_t command;
uint32_t reloc_type;
uint32_t offs;
int data_offs;
uint32_t size;
int err_flag = 0;
while(!err_flag){
command = *(uint32_t*)bytes;
bytes += 4;
switch(command) {
case ALLOCATE_DATA:
size = *(uint32_t*)bytes; bytes += 4;
printf("ALLOCATE_DATA size=%i\n", size);
data_memory = get_data_memory(size);
break;
case COPY_DATA:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("COPY_DATA offs=%i size=%i\n", offs, size);
memcpy(data_memory + offs, bytes, size); bytes += size;
break;
case ALLOCATE_CODE:
size = *(uint32_t*)bytes; bytes += 4;
printf("ALLOCATE_CODE size=%i\n", size);
executable_memory = get_executable_memory(size);
executable_memory_len = size;
//printf("# d %i c %i off %i\n", data_memory, executable_memory, data_offs);
break;
case COPY_CODE:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("COPY_CODE offs=%i size=%i\n", offs, size);
memcpy(executable_memory + offs, bytes, size); bytes += size;
break;
case RELOCATE_FUNC:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
printf("RELOCATE_FUNC patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
relocate(executable_memory + offs, reloc_type, value);
break;
case RELOCATE_OBJECT:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
printf("RELOCATE_OBJECT patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
data_offs = (int32_t)(data_memory - executable_memory);
if (abs(data_offs) > 0x7FFFFFFF) {
perror("code and data memory to far apart");
return EXIT_FAILURE;
}
relocate(executable_memory + offs, reloc_type, value + (int32_t)data_offs);
//printf("> %i\n", data_offs);
break;
case SET_ENTR_POINT:
uint32_t rel_entr_point = *(uint32_t*)bytes; bytes += 4;
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
entr_point = (int (*)())(executable_memory + rel_entr_point);
break;
case END_PROG:
printf("END_PROG\n");
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
err_flag = 1;
break;
default:
printf("Unknown command\n");
err_flag = -1;
break;
}
}
return err_flag;
}
int main(int argc, char *argv[]) {
if (argc != 2) {
fprintf(stderr, "Usage: %s <binary_file>\n", argv[0]);
return EXIT_FAILURE;
}
// Open the file
int fd = open(argv[1], O_RDONLY);
if (fd < 0) {
perror("open");
return EXIT_FAILURE;
}
// Get file size
struct stat st;
if (fstat(fd, &st) < 0) {
perror("fstat");
close(fd);
return EXIT_FAILURE;
}
if (st.st_size == 0) {
fprintf(stderr, "Error: File is empty\n");
close(fd);
return EXIT_FAILURE;
}
//uint8_t *file_buff = get_data_memory((uint32_t)st.st_size);
uint8_t *file_buff = (uint8_t*)malloc((size_t)st.st_size);
// Read file into allocated memory
if (read(fd, file_buff, (long unsigned int)st.st_size) != st.st_size) {
perror("read");
close(fd);
return EXIT_FAILURE;
}
close(fd);
parse_commands(file_buff);
munmap(executable_memory, executable_memory_len);
return EXIT_SUCCESS;
}

View File

@ -21,7 +21,7 @@ def test_compile():
il = rc.compile_to_instruction_list(out) il = rc.compile_to_instruction_list(out)
#print('#', il.print()) print('#', il.print())
if __name__ == "__main__": if __name__ == "__main__":

9
tests/test_stencil.db.py Normal file
View File

@ -0,0 +1,9 @@
from copapy import stencil_database
if __name__ == "__main__":
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
print('----')
#print(sdb.function_definitions)
for sym_name in sdb.function_definitions.keys():
print('-', sym_name)
print(list(sdb.get_relocs(sym_name)))