This commit is contained in:
Nicolas Kruse 2025-09-20 23:25:07 +02:00
parent e6b94da270
commit d4bc56f1db
11 changed files with 392 additions and 30 deletions

View File

@ -1,9 +1,12 @@
import re
import pkgutil
from typing import Generator, Iterable, Any
from tkinter import Variable
from typing import Generator, Iterable, Any, Literal, TypeVar
import pelfy
from . import binwrite as binw
from .stencil_db import stencil_database
Operand = type['Net'] | float | int
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
    """Return every name in ``scope`` that is bound to the object ``var``.

    Matching is done by identity (``is``), not by equality, so two equal but
    distinct objects are not treated as the same variable.

    Args:
        var: Object to look up.
        scope: Mapping of names to objects. The default is the globals
            dictionary of the module that defines this function.

    Returns:
        The matching names in the iteration order of ``scope``.
    """
    matching_names: list[str] = []
    for candidate, bound_object in scope.items():
        if bound_object is var:
            matching_names.append(candidate)
    return matching_names
@ -18,8 +21,9 @@ _ccode = pkgutil.get_data(__name__, 'stencils.c')
assert _ccode is not None
_function_definitions = _get_c_function_definitions(_ccode.decode('utf-8'))
sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')
print(_function_definitions)
#print(_function_definitions)
class Node:
def __init__(self):
@ -104,10 +108,10 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
if typed_op not in _function_definitions:
if typed_op not in sdb.function_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
result_type = _function_definitions[typed_op].split('_')[0]
result_type = sdb.function_definitions[typed_op].split('_')[0]
result_net = Net(result_type, Op(typed_op, arg_nets))
@ -155,16 +159,16 @@ def get_multiuse_nets(root: list[Node]) -> set[Net]:
"""
known_nets: set[Net] = set()
def recursiv_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
def recursive_node_search(net_list: Iterable[Net]) -> Generator[Net, None, None]:
for net in net_list:
#print(net)
if net in known_nets:
yield net
else:
known_nets.add(net)
yield from recursiv_node_search(net.source.args)
yield from recursive_node_search(net.source.args)
return set(recursiv_node_search(op.args[0] for op in root))
return set(recursive_node_search(op.args[0] for op in root))
def get_path_segments(root: Iterable[Node]) -> Generator[list[Node], None, None]:
@ -215,6 +219,10 @@ def get_ordered_ops(path_segments: list[list[Node]]) -> Generator[Node, None, No
def get_consts(op_list: list[Node]) -> list[tuple[str, Net, float | int]]:
    """Collect all constant nodes of an operation list.

    Every net that appears as an argument of a node in ``op_list`` is indexed
    by its source node; each ``Const`` node in ``op_list`` is then paired with
    the net it is the source of.

    Returns:
        List of tuples of (name, net, value), one per ``Const`` node, in the
        order the nodes appear in ``op_list``.
    """
    # Map each producing node to the net it feeds (a later entry overwrites an earlier one).
    source_to_net: dict[Node, Net] = {}
    for op in op_list:
        for arg_net in op.args:
            source_to_net[arg_net.source] = arg_net

    consts: list[tuple[str, Net, float | int]] = []
    for node in op_list:
        if isinstance(node, Const):
            consts.append((node.name, source_to_net[node], node.value))
    return consts
@ -224,7 +232,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
correctly in the registers
Returns:
Yields a tuples of a net and a operation. The net is the result net
Yields tuples of a net and a operation. The net is the result net
from the returned operation"""
registers: list[None | Net] = [None] * 2
@ -252,19 +260,31 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_list: list[tuple[str, Net, float | int]]) -> Generator[tuple[Net | None, Node], None, None]:
"""Add write operation for each new defined net if a read operation is later flowed"""
"""Add write operation for each new defined net if a read operation is later followed"""
stored_nets = {c[1] for c in const_list}
read_back_nets = {net for net, node in net_node_list if node.name.startswith('read_')}
for net, node in net_node_list:
yield net, node
if net and net in read_back_nets and net not in stored_nets:
yield (net, Op('write_' + net.dtype, [net]))
yield (net, Write(net))
stored_nets.add(net)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
def get_variable_nets(nodes: Iterable[Node], nets_in: Iterable[Net]) -> list[Net]:
    """Collect the nets that are written to or read back from memory.

    A net is collected when it is the first argument of a node whose name
    starts with ``'write_'``, or when it is contained in ``nets_in`` and the
    name of its source node starts with ``'read_'``.

    Args:
        nodes: Nodes to scan for write operations.
        nets_in: Nets to scan for read operations.

    Returns:
        The collected nets without duplicates, in first-seen order (written
        nets first, then read nets). A plain ``set`` would give an arbitrary
        order; a dict keeps insertion order, so callers that derive a memory
        layout from this list get a reproducible result.
    """
    ordered_nets: dict[Net, None] = {}
    for node in nodes:
        if node.name.startswith('write_'):
            ordered_nets[node.args[0]] = None

    for net_in in nets_in:
        if net_in.source.name.startswith('read_'):
            ordered_nets[net_in] = None

    return list(ordered_nets)
def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_writer:
if isinstance(end_nodes, Node):
node_list = [end_nodes]
else:
@ -276,42 +296,79 @@ def compile_to_instruction_list(end_nodes: Iterable[Node] | Node) -> binw.data_w
output_ops = list(add_read_ops(ordered_ops))
extended_output_ops = list(add_write_ops(output_ops, const_list))
for net, node in extended_output_ops:
print(node.name)
obj_file: str = 'src/copapy/obj/stencils_x86_64.o'
elf = pelfy.open_elf_file(obj_file)
variable_list = get_variable_nets((node for _, node in extended_output_ops),
(net for net, _ in extended_output_ops if net))
dw = binw.data_writer(elf.byteorder)
#assert False
prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
#obj_file: str = 'src/copapy/obj/stencils_x86_64_O3.o'
#elf = pelfy.open_elf_file(obj_file)
auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
dw = binw.data_writer(sdb.byteorder)
#prototype_functions = {s.name: s for s in elf.symbols if s.info == 'STT_FUNC'}
#prototype_objects = {s.name: s for s in elf.symbols if s.info == 'STT_OBJECT'}
#auxiliary_functions = [s for s in elf.symbols if s.info == 'STT_FUNC']
#auxiliary_objects = [s for s in elf.symbols if s.info == 'STT_OBJECT']
# write data sections
object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
# write auxiliary_objects to data sections
#object_list, data_section_lengths = binw.get_variable_data(auxiliary_objects)
#dw.write_com(binw.Command.ALLOCATE_DATA)
#dw.write_int(data_section_lengths)
#for sym, out_offs, lengths in object_list:
# if sym.section and sym.section.type != 'SHT_NOBITS':
# dw.write_com(binw.Command.COPY_DATA)
# dw.write_int(out_offs)
# dw.write_int(lengths)
# dw.write_bytes(sym.data)
# print(f'+ {sym.name} {sym.fields}')
# write variables to data sections
def variable_mem_layout(variable_list: list[Net]) -> tuple[list[tuple[Net, int, int]], int]:
offset: int = 0
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.var_size['dummy_' + variable.dtype]
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
return object_list, offset
object_list, data_section_lengths = variable_mem_layout(variable_list)
#data_section_lengths = object_list[-1][1] + object_list[-1][2]
dw.write_com(binw.Command.ALLOCATE_DATA)
dw.write_int(data_section_lengths)
for sym, out_offs, lengths in object_list:
if sym.section and sym.section.type != 'SHT_NOBITS':
for net, out_offs, lengths in object_list:
if isinstance(net.source, Const):
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(lengths)
dw.write_bytes(sym.data)
dw.write_value(net.source.value, lengths)
print(f'+ {net.dtype} {net.source.value}')
# write auxiliary_functions
# TODO
# write program
print(list(prototype_functions.keys()))
for net, node in extended_output_ops:
if node.name in prototype_functions:
#print(prototype_functions[node.name])
pass
else: print(f"- Warning: {node.name} prototype not found")
assert node.name in sdb.function_definitions, f"- Warning: {node.name} prototype not found"
data = sdb.get_func_data(node.name)
print('*', node.name, ' '.join(f'{d:02X}' for d in data))
for reloc_offset, lengths, bits, reloc_type in sdb.get_relocs(node.name):
#if not relocation.symbol.name.startswith('result_'):
print(relocation)
print('-----')

View File

@ -1,6 +1,7 @@
from enum import Enum
from pelfy import elf_symbol
from typing import Literal
import struct
Command = Enum('Command', [('ALLOCATE_DATA', 1), ('COPY_DATA', 2),
('ALLOCATE_CODE', 3), ('COPY_CODE', 4),
@ -67,7 +68,6 @@ def get_function_data_blob(symbols: list[elf_symbol]) -> tuple[list[tuple[elf_sy
out_offs += (lengths + 3) // 4 * 4
return code_list, out_offs
class data_writer():
def __init__(self, byteorder: Literal['little', 'big']):
self._data: list[tuple[str, bytes, int]] = list()
@ -85,6 +85,18 @@ class data_writer():
def write_bytes(self, value: bytes):
self._data.append((f"BYTES {len(value)}", value, 0))
def write_value(self, value: int | float, num_bytes: int = 4):
if isinstance(value, int):
self.write_int(value, num_bytes, True)
else:
en = {'little': '<', 'big': '>'}[self.byteorder]
if num_bytes == 4:
data = struct.pack(en + 'f', value)
else:
data = struct.pack(en + 'd', value)
assert len(data) == num_bytes
self.write_bytes(data)
def print(self) -> None:
for name, dat, flag in self._data:
if flag:

View File

@ -3,6 +3,7 @@ from typing import Generator
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/'}
def get_op_code(op: str, type1: str, type2: str, type_out: str):
return f"""
void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -17,11 +18,13 @@ def get_result_stubs1(type1: str):
void result_{type1}({type1} arg1);
"""
def get_result_stubs2(type1: str, type2: str) -> str:
    """Return the C prototype of the two-argument ``result_<type1>_<type2>`` stub.

    Only a declaration (terminated by ``;``) is generated; no function body
    is emitted here.
    """
    return f"""
void result_{type1}_{type2}({type1} arg1, {type2} arg2);
"""
def get_read_reg0_code(type1: str, type2: str, type_out: str):
return f"""
void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -31,6 +34,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str):
}}
"""
def get_read_reg1_code(type1: str, type2: str, type_out: str):
return f"""
void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
@ -40,6 +44,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str):
}}
"""
def get_write_code(type1: str):
return f"""
void write_{type1}({type1} arg1) {{
@ -50,6 +55,7 @@ def get_write_code(type1: str):
}}
"""
def permutate(*lists: list[str]) -> Generator[list[str], None, None]:
if len(lists) == 0:
yield []
@ -59,6 +65,7 @@ def permutate(*lists: list[str]) -> Generator[list[str], None, None]:
for items in permutate(*rest):
yield [item, *items]
if __name__ == "__main__":
types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div']

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

81
src/copapy/stencil_db.py Normal file
View File

@ -0,0 +1,81 @@
from ast import Tuple
from os import name
from tkinter import NO
import pelfy
from typing import Generator, Literal
# 32-bit marker values that strip_symbol() and stencil_database.get_relocs()
# search for (as 4 bytes in the file's byte order) inside a function symbol's
# data to delimit the part that is kept.
start_marker = 0xF17ECAFE
end_marker = 0xF27ECAFE
# Number of bytes dropped directly in front of the end marker (the trailing call).
# NOTE(review): an x86-64 near call with a 32-bit displacement is 5 bytes
# including its opcode - confirm that 4 is the intended value here.
LENGTH_CALL_INSTRUCTION = 4 # x86_64
def get_ret_function_def(symbol: pelfy.elf_symbol):
    """Return the part of the result-function name that follows ``'result_'``.

    The result function is taken to be the symbol referenced by the last
    relocation of ``symbol``; an assertion fails if that symbol's name does
    not start with ``'result_'``.
    """
    prefix = 'result_'
    callee_name = symbol.relocations[-1].symbol.name
    assert callee_name.startswith(prefix)
    return callee_name[len(prefix):]
def strip_symbol(data: bytes, byteorder: Literal['little', 'big']) -> bytes:
    """Return the bytes between the start and end marker, without the trailing call.

    The first occurrence of the start marker and the last occurrence of the
    end marker (at or after the start marker) delimit the region. The start
    marker itself and the ``LENGTH_CALL_INSTRUCTION`` bytes directly in front
    of the end marker are excluded from the result.

    Args:
        data: Raw bytes of a function symbol.
        byteorder: Byte order used to encode the 4-byte marker values.

    Raises:
        ValueError: If either marker is missing. Raised explicitly (instead of
            ``assert``) so the check is not removed when running with ``-O``.
    """
    marker_size = 4

    start_index = data.find(start_marker.to_bytes(marker_size, byteorder))
    if start_index < 0:
        raise ValueError("Marker not found: start marker missing")

    end_index = data.rfind(end_marker.to_bytes(marker_size, byteorder), start_index)
    if end_index < 0:
        raise ValueError("Marker not found: end marker missing")

    return data[start_index + marker_size:end_index - LENGTH_CALL_INSTRUCTION]
class stencil_database():
    """Index of the stencil functions and objects found in an ELF object file."""

    def __init__(self, obj_file: str):
        """Open ``obj_file`` with pelfy and index its symbols.

        Args:
            obj_file: Path of the object file to load.
        """
        self.elf = pelfy.open_elf_file(obj_file)

        function_symbols = [s for s in self.elf.symbols if s.info == 'STT_FUNC']

        # function name -> value returned by get_ret_function_def() for that symbol
        self.function_definitions = {s.name: get_ret_function_def(s) for s in function_symbols}
        # function name -> symbol data reduced by strip_symbol()
        self.data = {s.name: strip_symbol(s.data, self.elf.byteorder) for s in function_symbols}
        # object name -> 'st_size' field of the symbol
        self.var_size = {s.name: s.fields['st_size'] for s in self.elf.symbols
                         if s.info == 'STT_OBJECT'}
        self.byteorder: Literal['little', 'big'] = self.elf.byteorder

    def get_relocs(self, symbol_name: str) -> Generator[tuple[int, int, int, str], None, None]:
        """Yield the relocations located in front of the trailing call of a symbol.

        Yields tuples of (reloc_offset, symbol_length, bits, reloc_type):

        1. reloc_offset: ``r_offset`` of the relocation minus the symbol's
           ``st_value`` minus the index of the start marker.
        2. symbol_length: Length of the stripped symbol data, i.e. of what
           ``get_func_data`` returns for the same symbol.
        3. bits: The ``bits`` attribute of the relocation (bits to patch).
        4. reloc_type: Type of the relocation as a string.

        Relocations at or behind the dropped trailing call are skipped.

        NOTE(review): reloc_offset is measured from the start marker itself,
        while the stripped data begins 4 bytes later (after the marker) -
        confirm which origin the consumers expect.
        """
        symbol = self.elf.symbols[symbol_name]
        byteorder = symbol.file.byteorder

        # Also validates that both markers exist before any offset is computed.
        stripped_length = len(strip_symbol(symbol.data, byteorder))

        start_index = symbol.data.find(start_marker.to_bytes(4, byteorder))
        end_index = symbol.data.rfind(end_marker.to_bytes(4, byteorder), start_index)
        call_offset = end_index - start_index - LENGTH_CALL_INSTRUCTION

        for reloc in symbol.relocations:
            reloc_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
            if reloc_offset < call_offset:
                yield (reloc_offset, stripped_length, reloc.bits, reloc.type)

    def get_func_data(self, name: str) -> bytes:
        """Return the stripped data (see ``strip_symbol``) of the symbol called ``name``."""
        return strip_symbol(self.elf.symbols[name].data, self.elf.byteorder)

196
src/runner/runmem2.c Normal file
View File

@ -0,0 +1,196 @@
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>
#include <string.h>
#include <stdint.h>
/* Command codes of the binary command stream handled by parse_commands().
 * Each command is read as a 32-bit word followed by its arguments. */
#define ALLOCATE_DATA 1
#define COPY_DATA 2
#define ALLOCATE_CODE 3
#define COPY_CODE 4
#define RELOCATE_FUNC 5
#define RELOCATE_OBJECT 6
#define SET_ENTR_POINT 64
#define END_PROG 255
/* The only relocation type relocate() implements. */
#define RELOC_RELATIVE_32 0
/* Region allocated by ALLOCATE_DATA and filled by COPY_DATA. */
uint8_t *data_memory;
/* Region allocated by ALLOCATE_CODE; executable_memory_len is its size in bytes. */
uint8_t *executable_memory;
uint32_t executable_memory_len;
/* Function pointer assigned by SET_ENTR_POINT and called on END_PROG. */
int (*entr_point)();
/* Map num_bytes of anonymous, private memory for code.
 * The mapping is created readable and writable only; mark_mem_executable()
 * changes the protection of the code region afterwards.
 * Returns the result of mmap() unchanged, i.e. MAP_FAILED on failure. */
uint8_t *get_executable_memory(uint32_t num_bytes){
    void *region = mmap(NULL, num_bytes,
                        PROT_READ | PROT_WRITE,
                        MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    return (uint8_t*)region;
}
/* Map num_bytes of anonymous, private, readable and writable memory for data.
 * Returns the result of mmap() unchanged, i.e. MAP_FAILED on failure. */
uint8_t *get_data_memory(uint32_t num_bytes) {
    void *region = mmap(NULL, num_bytes,
                        PROT_READ | PROT_WRITE,
                        MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    return (uint8_t*)region;
}
/* Switch the code region (executable_memory, executable_memory_len) to
 * read + execute. Returns 1 on success; on failure reports the mprotect
 * error via perror() and returns 0. */
int mark_mem_executable(){
    int status = mprotect(executable_memory, executable_memory_len, PROT_READ | PROT_EXEC);
    if (status != -1) {
        return 1;
    }
    perror("mprotect failed");
    return 0;
}
/* Store the 32-bit value at patch_addr in native byte order.
 * patch_addr points into copied machine code and is generally not 4-byte
 * aligned, so the bytes are copied with memcpy instead of being written
 * through a cast int32_t pointer (misaligned / aliasing-unsafe access). */
void patch_mem_32(uint8_t *patch_addr, int32_t value){
    memcpy(patch_addr, &value, sizeof value);
}
/* Apply one relocation at patch_addr.
 * Only RELOC_RELATIVE_32 is handled: value is written as a 32-bit word.
 * Returns 1 when the patch was applied, 0 for any other relocation type. */
int relocate(uint8_t *patch_addr, uint32_t reloc_type, int32_t value){
    if (reloc_type != RELOC_RELATIVE_32){
        printf("Not implemented");
        return 0;
    }
    patch_mem_32(patch_addr, value);
    return 1;
}
int parse_commands(uint8_t *bytes){
int32_t value;
uint32_t command;
uint32_t reloc_type;
uint32_t offs;
int data_offs;
uint32_t size;
int err_flag = 0;
while(!err_flag){
command = *(uint32_t*)bytes;
bytes += 4;
switch(command) {
case ALLOCATE_DATA:
size = *(uint32_t*)bytes; bytes += 4;
printf("ALLOCATE_DATA size=%i\n", size);
data_memory = get_data_memory(size);
break;
case COPY_DATA:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("COPY_DATA offs=%i size=%i\n", offs, size);
memcpy(data_memory + offs, bytes, size); bytes += size;
break;
case ALLOCATE_CODE:
size = *(uint32_t*)bytes; bytes += 4;
printf("ALLOCATE_CODE size=%i\n", size);
executable_memory = get_executable_memory(size);
executable_memory_len = size;
//printf("# d %i c %i off %i\n", data_memory, executable_memory, data_offs);
break;
case COPY_CODE:
offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4;
printf("COPY_CODE offs=%i size=%i\n", offs, size);
memcpy(executable_memory + offs, bytes, size); bytes += size;
break;
case RELOCATE_FUNC:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
printf("RELOCATE_FUNC patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
relocate(executable_memory + offs, reloc_type, value);
break;
case RELOCATE_OBJECT:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
printf("RELOCATE_OBJECT patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
data_offs = (int32_t)(data_memory - executable_memory);
if (abs(data_offs) > 0x7FFFFFFF) {
perror("code and data memory to far apart");
return EXIT_FAILURE;
}
relocate(executable_memory + offs, reloc_type, value + (int32_t)data_offs);
//printf("> %i\n", data_offs);
break;
case SET_ENTR_POINT:
uint32_t rel_entr_point = *(uint32_t*)bytes; bytes += 4;
printf("SET_ENTR_POINT rel_entr_point=%i\n", rel_entr_point);
entr_point = (int (*)())(executable_memory + rel_entr_point);
break;
case END_PROG:
printf("END_PROG\n");
mark_mem_executable();
int ret = entr_point();
printf("Return value: %i\n", ret);
err_flag = 1;
break;
default:
printf("Unknown command\n");
err_flag = -1;
break;
}
}
return err_flag;
}
/* Load the command file given as the only argument and execute it with
 * parse_commands().
 * Exits with EXIT_FAILURE on usage, file or memory errors and when
 * parse_commands() reports an error (negative result), else EXIT_SUCCESS. */
int main(int argc, char *argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <binary_file>\n", argv[0]);
        return EXIT_FAILURE;
    }

    // Open the file
    int fd = open(argv[1], O_RDONLY);
    if (fd < 0) {
        perror("open");
        return EXIT_FAILURE;
    }

    // Get file size
    struct stat st;
    if (fstat(fd, &st) < 0) {
        perror("fstat");
        close(fd);
        return EXIT_FAILURE;
    }

    if (st.st_size == 0) {
        fprintf(stderr, "Error: File is empty\n");
        close(fd);
        return EXIT_FAILURE;
    }

    size_t file_size = (size_t)st.st_size;
    uint8_t *file_buff = (uint8_t*)malloc(file_size);
    if (file_buff == NULL) {
        perror("malloc");
        close(fd);
        return EXIT_FAILURE;
    }

    // Read file into allocated memory; read() may return fewer bytes than
    // requested, so keep reading until the whole file is in the buffer.
    size_t total = 0;
    while (total < file_size) {
        ssize_t got = read(fd, file_buff + total, file_size - total);
        if (got < 0) {
            perror("read");
            free(file_buff);
            close(fd);
            return EXIT_FAILURE;
        }
        if (got == 0) {
            fprintf(stderr, "Error: unexpected end of file\n");
            free(file_buff);
            close(fd);
            return EXIT_FAILURE;
        }
        total += (size_t)got;
    }
    close(fd);

    int status = parse_commands(file_buff);
    free(file_buff);

    // Only unmap a code region that was actually created. The size of the
    // data region is not recorded globally; it is released at process exit.
    if (executable_memory != NULL && executable_memory != (uint8_t*)MAP_FAILED) {
        munmap(executable_memory, executable_memory_len);
    }

    return status < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
}

View File

@ -21,7 +21,7 @@ def test_compile():
il = rc.compile_to_instruction_list(out)
#print('#', il.print())
print('#', il.print())
if __name__ == "__main__":

9
tests/test_stencil.db.py Normal file
View File

@ -0,0 +1,9 @@
from copapy import stencil_database
if __name__ == "__main__":
    # Manual smoke check: list every stencil function and its relocations.
    sdb = stencil_database('src/copapy/obj/stencils_x86_64_O3.o')

    print('----')
    for sym_name in sdb.function_definitions:
        relocations = list(sdb.get_relocs(sym_name))
        print('-', sym_name)
        print(relocations)