Added multi-target support for coparun-module

This commit is contained in:
Nicolas 2025-12-16 16:15:50 +01:00
parent d9f361a6d6
commit c8e6848530
7 changed files with 164 additions and 74 deletions

View File

@ -1,6 +1,6 @@
from typing import Iterable, overload, TypeVar, Any from typing import Iterable, overload, TypeVar, Any
from . import _binwrite as binw from . import _binwrite as binw
from coparun_module import coparun, read_data_mem from coparun_module import coparun, read_data_mem, create_target, clear_target
import struct import struct
from ._basic_types import stencil_db_from_package from ._basic_types import stencil_db_from_package
from ._basic_types import value, Net, Node, Write, NumLike from ._basic_types import value, Net, Node, Write, NumLike
@ -29,6 +29,10 @@ class Target():
""" """
self.sdb = stencil_db_from_package(arch, optimization) self.sdb = stencil_db_from_package(arch, optimization)
self._values: dict[Net, tuple[int, int, str]] = {} self._values: dict[Net, tuple[int, int, str]] = {}
self._context = create_target()
def __del__(self) -> None:
clear_target(self._context)
def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None: def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None:
"""Compiles the code to compute the given values. """Compiles the code to compute the given values.
@ -48,7 +52,7 @@ class Target():
dw, self._values = compile_to_dag(nodes, self.sdb) dw, self._values = compile_to_dag(nodes, self.sdb)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0
def run(self) -> None: def run(self) -> None:
"""Runs the compiled code on the target device. """Runs the compiled code on the target device.
@ -56,7 +60,7 @@ class Target():
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
dw.write_com(binw.Command.RUN_PROG) dw.write_com(binw.Command.RUN_PROG)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0
@overload @overload
def read_value(self, net: value[T]) -> T: ... def read_value(self, net: value[T]) -> T: ...
@ -84,7 +88,7 @@ class Target():
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net] addr, lengths, var_type = self._values[net]
assert lengths > 0 assert lengths > 0
data = read_data_mem(addr, lengths) data = read_data_mem(self._context, addr, lengths)
assert data is not None and len(data) == lengths, f"Failed to read value {net}" assert data is not None and len(data) == lengths, f"Failed to read value {net}"
en = {'little': '<', 'big': '>'}[self.sdb.byteorder] en = {'little': '<', 'big': '>'}[self.sdb.byteorder]
if var_type == 'float': if var_type == 'float':
@ -111,4 +115,4 @@ class Target():
"""Reads the raw data of a value by the runner.""" """Reads the raw data of a value by the runner."""
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
add_read_command(dw, self._values, net) add_read_command(dw, self._values, net)
assert coparun(dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0

View File

@ -45,7 +45,8 @@ int main(int argc, char *argv[]) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
int ret = parse_commands(file_buff); runmem_t targ;
int ret = parse_commands(&targ, file_buff);
if (ret == 2) { if (ret == 2) {
/* Dump code for debugging */ /* Dump code for debugging */
@ -54,11 +55,11 @@ int main(int argc, char *argv[]) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
f = fopen(argv[2], "wb"); f = fopen(argv[2], "wb");
fwrite(executable_memory, 1, (size_t)executable_memory_len, f); fwrite(targ.executable_memory, 1, (size_t)targ.executable_memory_len, f);
fclose(f); fclose(f);
} }
free_memory(); free_memory(&targ);
return ret < 0; return ret < 0;
} }

View File

@ -1,30 +1,41 @@
#define PY_SSIZE_T_CLEAN #define PY_SSIZE_T_CLEAN
#include <Python.h> #include <Python.h>
#include "runmem.h" #include "runmem.h"
#include <stdlib.h>
static PyObject* coparun(PyObject* self, PyObject* args) { static PyObject* coparun(PyObject* self, PyObject* args) {
PyObject *handle_obj;
const char *buf; const char *buf;
Py_ssize_t buf_len; Py_ssize_t buf_len;
int result; int result;
if (!PyArg_ParseTuple(args, "y#", &buf, &buf_len)) { // Expect: handle, bytes
if (!PyArg_ParseTuple(args, "Oy#", &handle_obj, &buf, &buf_len)) {
return NULL; /* TypeError set by PyArg_ParseTuple */ return NULL; /* TypeError set by PyArg_ParseTuple */
} }
void *ptr = PyLong_AsVoidPtr(handle_obj);
if (!ptr) {
PyErr_SetString(PyExc_ValueError, "Invalid context handle");
return NULL;
}
runmem_t *context = (runmem_t*)ptr;
/* If parse_commands may run for a long time, release the GIL. */ /* If parse_commands may run for a long time, release the GIL. */
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
result = parse_commands((uint8_t*)buf); result = parse_commands(context, (uint8_t*)buf);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
return PyLong_FromLong(result); return PyLong_FromLong(result);
} }
static PyObject* read_data_mem(PyObject* self, PyObject* args) { static PyObject* read_data_mem(PyObject* self, PyObject* args) {
PyObject *handle_obj;
unsigned long rel_addr; unsigned long rel_addr;
unsigned long length; unsigned long length;
// Parse arguments: unsigned long (relative address), Py_ssize_t (length) // Expect: handle, rel_addr, length
if (!PyArg_ParseTuple(args, "nn", &rel_addr, &length)) { if (!PyArg_ParseTuple(args, "Onn", &handle_obj, &rel_addr, &length)) {
return NULL; return NULL;
} }
@ -33,9 +44,21 @@ static PyObject* read_data_mem(PyObject* self, PyObject* args) {
return NULL; return NULL;
} }
const char *ptr = (const char *)(data_memory + rel_addr); void *ptr = PyLong_AsVoidPtr(handle_obj);
if (!ptr) {
PyErr_SetString(PyExc_ValueError, "Invalid context handle");
return NULL;
}
runmem_t *context = (runmem_t*)ptr;
PyObject *result = PyBytes_FromStringAndSize(ptr, length); if (!context->data_memory || rel_addr + length > context->data_memory_len) {
PyErr_SetString(PyExc_ValueError, "Read out of bounds");
return NULL;
}
const char *data_ptr = (const char *)(context->data_memory + rel_addr);
PyObject *result = PyBytes_FromStringAndSize(data_ptr, length);
if (!result) { if (!result) {
return PyErr_NoMemory(); return PyErr_NoMemory();
} }
@ -43,9 +66,36 @@ static PyObject* read_data_mem(PyObject* self, PyObject* args) {
return result; return result;
} }
static PyObject* create_target(PyObject* self, PyObject* args) {
runmem_t *context = (runmem_t*)calloc(1, sizeof(runmem_t));
if (!context) {
return PyErr_NoMemory();
}
// Return the pointer as a Python integer (handle)
return PyLong_FromVoidPtr((void*)context);
}
static PyObject* clear_target(PyObject* self, PyObject* args) {
PyObject *handle_obj;
if (!PyArg_ParseTuple(args, "O", &handle_obj)) {
return NULL;
}
void *ptr = PyLong_AsVoidPtr(handle_obj);
if (!ptr) {
PyErr_SetString(PyExc_ValueError, "Invalid handle");
return NULL;
}
runmem_t *context = (runmem_t*)ptr;
free_memory(context);
free(context);
Py_RETURN_NONE;
}
static PyMethodDef MyMethods[] = { static PyMethodDef MyMethods[] = {
{"coparun", coparun, METH_VARARGS, "Pass raw command data to coparun"}, {"coparun", coparun, METH_VARARGS, "Pass raw command data to coparun"},
{"read_data_mem", read_data_mem, METH_VARARGS, "Read memory and return as bytes"}, {"read_data_mem", read_data_mem, METH_VARARGS, "Read memory and return as bytes"},
{"create_target", create_target, METH_NOARGS, "Create and return a handle to a zero-initialized runmem_t struct"},
{"clear_target", clear_target, METH_VARARGS, "Free all memory associated with the given target handle"},
{NULL, NULL, 0, NULL} {NULL, NULL, 0, NULL}
}; };

View File

@ -5,14 +5,6 @@
#include "runmem.h" #include "runmem.h"
#include "mem_man.h" #include "mem_man.h"
/* Globals declared extern in runmem.h */
uint8_t *data_memory = NULL;
uint32_t data_memory_len = 0;
uint8_t *executable_memory = NULL;
uint32_t executable_memory_len = 0;
entry_point_t entr_point = NULL;
int data_offs = 0;
void patch(uint8_t *patch_addr, uint32_t patch_mask, int32_t value) { void patch(uint8_t *patch_addr, uint32_t patch_mask, int32_t value) {
uint32_t *val_ptr = (uint32_t*)patch_addr; uint32_t *val_ptr = (uint32_t*)patch_addr;
uint32_t original = *val_ptr; uint32_t original = *val_ptr;
@ -58,23 +50,25 @@ void patch_arm32_abs(uint8_t *patch_addr, uint32_t imm16)
*((uint32_t *)patch_addr) = instr; *((uint32_t *)patch_addr) = instr;
} }
void free_memory() { void free_memory(runmem_t *context) {
deallocate_memory(executable_memory, executable_memory_len); deallocate_memory(context->executable_memory, context->executable_memory_len);
deallocate_memory(data_memory, data_memory_len); deallocate_memory(context->data_memory, context->data_memory_len);
executable_memory_len = 0; context->executable_memory_len = 0;
data_memory_len = 0; context->data_memory_len = 0;
context->executable_memory = NULL;
context->data_memory = NULL;
context->entr_point = NULL;
context->data_offs = 0;
} }
int update_data_offs() { int update_data_offs(runmem_t *context) {
if (data_memory && executable_memory && (data_memory - executable_memory > 0x7FFFFFFF || executable_memory - data_memory > 0x7FFFFFFF)) { if (context->data_memory && context->executable_memory &&
(context->data_memory - context->executable_memory > 0x7FFFFFFF ||
context->executable_memory - context->data_memory > 0x7FFFFFFF)) {
perror("Error: code and data memory to far apart"); perror("Error: code and data memory to far apart");
return 0; return 0;
} }
if (data_memory && executable_memory && (data_memory - executable_memory > 0x7FFFFFFF || executable_memory - data_memory > 0x7FFFFFFF)) { context->data_offs = (int)(context->data_memory - context->executable_memory);
perror("Error: code and data memory to far apart");
return 0;
}
data_offs = (int)(data_memory - executable_memory);
return 1; return 1;
} }
@ -82,7 +76,7 @@ int floor_div(int a, int b) {
return a / b - ((a % b != 0) && ((a < 0) != (b < 0))); return a / b - ((a % b != 0) && ((a < 0) != (b < 0)));
} }
int parse_commands(uint8_t *bytes) { int parse_commands(runmem_t *context, uint8_t *bytes) {
int32_t value; int32_t value;
uint32_t command; uint32_t command;
uint32_t patch_mask; uint32_t patch_mask;
@ -98,33 +92,32 @@ int parse_commands(uint8_t *bytes) {
switch(command) { switch(command) {
case ALLOCATE_DATA: case ALLOCATE_DATA:
size = *(uint32_t*)bytes; bytes += 4; size = *(uint32_t*)bytes; bytes += 4;
data_memory = allocate_data_memory(size); context->data_memory = allocate_data_memory(size);
data_memory_len = size; context->data_memory_len = size;
LOG("ALLOCATE_DATA size=%i mem_addr=%p\n", size, (void*)data_memory); LOG("ALLOCATE_DATA size=%i mem_addr=%p\n", size, (void*)context->data_memory);
if (!update_data_offs()) end_flag = -4; if (!update_data_offs(context)) end_flag = -4;
break; break;
case COPY_DATA: case COPY_DATA:
offs = *(uint32_t*)bytes; bytes += 4; offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4; size = *(uint32_t*)bytes; bytes += 4;
LOG("COPY_DATA offs=%i size=%i\n", offs, size); LOG("COPY_DATA offs=%i size=%i\n", offs, size);
memcpy(data_memory + offs, bytes, size); bytes += size; memcpy(context->data_memory + offs, bytes, size); bytes += size;
break; break;
case ALLOCATE_CODE: case ALLOCATE_CODE:
size = *(uint32_t*)bytes; bytes += 4; size = *(uint32_t*)bytes; bytes += 4;
executable_memory = allocate_executable_memory(size); context->executable_memory = allocate_executable_memory(size);
executable_memory_len = size; context->executable_memory_len = size;
LOG("ALLOCATE_CODE size=%i mem_addr=%p\n", size, (void*)executable_memory); LOG("ALLOCATE_CODE size=%i mem_addr=%p\n", size, (void*)context->executable_memory);
//LOG("# d %i c %i off %i\n", data_memory, executable_memory, data_offs); if (!update_data_offs(context)) end_flag = -4;
if (!update_data_offs()) end_flag = -4;
break; break;
case COPY_CODE: case COPY_CODE:
offs = *(uint32_t*)bytes; bytes += 4; offs = *(uint32_t*)bytes; bytes += 4;
size = *(uint32_t*)bytes; bytes += 4; size = *(uint32_t*)bytes; bytes += 4;
LOG("COPY_CODE offs=%i size=%i\n", offs, size); LOG("COPY_CODE offs=%i size=%i\n", offs, size);
memcpy(executable_memory + offs, bytes, size); bytes += size; memcpy(context->executable_memory + offs, bytes, size); bytes += size;
break; break;
case PATCH_FUNC: case PATCH_FUNC:
@ -134,7 +127,7 @@ int parse_commands(uint8_t *bytes) {
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_FUNC patch_offs=%i patch_mask=%#08x scale=%i value=%i\n", LOG("PATCH_FUNC patch_offs=%i patch_mask=%#08x scale=%i value=%i\n",
offs, patch_mask, patch_scale, value); offs, patch_mask, patch_scale, value);
patch(executable_memory + offs, patch_mask, value / patch_scale); patch(context->executable_memory + offs, patch_mask, value / patch_scale);
break; break;
case PATCH_OBJECT: case PATCH_OBJECT:
@ -144,7 +137,7 @@ int parse_commands(uint8_t *bytes) {
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT patch_offs=%i patch_mask=%#08x scale=%i value=%i\n", LOG("PATCH_OBJECT patch_offs=%i patch_mask=%#08x scale=%i value=%i\n",
offs, patch_mask, patch_scale, value); offs, patch_mask, patch_scale, value);
patch(executable_memory + offs, patch_mask, value / patch_scale + data_offs / patch_scale); patch(context->executable_memory + offs, patch_mask, value / patch_scale + context->data_offs / patch_scale);
break; break;
case PATCH_OBJECT_ABS: case PATCH_OBJECT_ABS:
@ -154,7 +147,7 @@ int parse_commands(uint8_t *bytes) {
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT_ABS patch_offs=%i patch_mask=%#08x scale=%i value=%i\n", LOG("PATCH_OBJECT_ABS patch_offs=%i patch_mask=%#08x scale=%i value=%i\n",
offs, patch_mask, patch_scale, value); offs, patch_mask, patch_scale, value);
patch(executable_memory + offs, patch_mask, value / patch_scale); patch(context->executable_memory + offs, patch_mask, value / patch_scale);
break; break;
case PATCH_OBJECT_REL: case PATCH_OBJECT_REL:
@ -163,8 +156,8 @@ int parse_commands(uint8_t *bytes) {
patch_scale = *(int32_t*)bytes; bytes += 4; patch_scale = *(int32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT_REL patch_offs=%i patch_addr=%p scale=%i value=%i\n", LOG("PATCH_OBJECT_REL patch_offs=%i patch_addr=%p scale=%i value=%i\n",
offs, (void*)(data_memory + value), patch_scale, value); offs, (void*)(context->data_memory + value), patch_scale, value);
*(void **)(executable_memory + offs) = data_memory + value; // / patch_scale; *(void **)(context->executable_memory + offs) = context->data_memory + value;
break; break;
case PATCH_OBJECT_HI21: case PATCH_OBJECT_HI21:
@ -173,8 +166,8 @@ int parse_commands(uint8_t *bytes) {
patch_scale = *(int32_t*)bytes; bytes += 4; patch_scale = *(int32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT_HI21 patch_offs=%i scale=%i value=%i res_value=%i\n", LOG("PATCH_OBJECT_HI21 patch_offs=%i scale=%i value=%i res_value=%i\n",
offs, patch_scale, value, floor_div(data_offs + value, patch_scale) - (int32_t)offs / patch_scale); offs, patch_scale, value, floor_div(context->data_offs + value, patch_scale) - (int32_t)offs / patch_scale);
patch_hi21(executable_memory + offs, floor_div(data_offs + value, patch_scale) - (int32_t)offs / patch_scale); patch_hi21(context->executable_memory + offs, floor_div(context->data_offs + value, patch_scale) - (int32_t)offs / patch_scale);
break; break;
case PATCH_OBJECT_ARM32_ABS: case PATCH_OBJECT_ARM32_ABS:
@ -183,21 +176,24 @@ int parse_commands(uint8_t *bytes) {
patch_scale = *(int32_t*)bytes; bytes += 4; patch_scale = *(int32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT_ARM32_ABS patch_offs=%i patch_mask=%#08x scale=%i value=%i imm16=%#04x\n", LOG("PATCH_OBJECT_ARM32_ABS patch_offs=%i patch_mask=%#08x scale=%i value=%i imm16=%#04x\n",
offs, patch_mask, patch_scale, value, (uint32_t)((uintptr_t)(data_memory + value) & patch_mask) / (uint32_t)patch_scale); offs, patch_mask, patch_scale, value, (uint32_t)((uintptr_t)(context->data_memory + value) & patch_mask) / (uint32_t)patch_scale);
patch_arm32_abs(executable_memory + offs, (uint32_t)((uintptr_t)(data_memory + value) & patch_mask) / (uint32_t)patch_scale); patch_arm32_abs(context->executable_memory + offs, (uint32_t)((uintptr_t)(context->data_memory + value) & patch_mask) / (uint32_t)patch_scale);
break; break;
case ENTRY_POINT: case ENTRY_POINT:
rel_entr_point = *(uint32_t*)bytes; bytes += 4; rel_entr_point = *(uint32_t*)bytes; bytes += 4;
entr_point = (entry_point_t)(executable_memory + rel_entr_point); context->entr_point = (entry_point_t)(context->executable_memory + rel_entr_point);
LOG("ENTRY_POINT rel_entr_point=%i\n", rel_entr_point); LOG("ENTRY_POINT rel_entr_point=%i\n", rel_entr_point);
mark_mem_executable(executable_memory, executable_memory_len); mark_mem_executable(context->executable_memory, context->executable_memory_len);
break; break;
case RUN_PROG: case RUN_PROG:
LOG("RUN_PROG\n"); LOG("RUN_PROG\n");
int ret = entr_point(); {
int ret = context->entr_point();
(void)ret;
BLOG("Return value: %i\n", ret); BLOG("Return value: %i\n", ret);
}
break; break;
case READ_DATA: case READ_DATA:
@ -205,14 +201,14 @@ int parse_commands(uint8_t *bytes) {
size = *(uint32_t*)bytes; bytes += 4; size = *(uint32_t*)bytes; bytes += 4;
BLOG("READ_DATA offs=%i size=%i data=", offs, size); BLOG("READ_DATA offs=%i size=%i data=", offs, size);
for (uint32_t i = 0; i < size; i++) { for (uint32_t i = 0; i < size; i++) {
printf("%02X ", data_memory[offs + i]); printf("%02X ", context->data_memory[offs + i]);
} }
printf("\n"); printf("\n");
break; break;
case FREE_MEMORY: case FREE_MEMORY:
LOG("FREE_MENORY\n"); LOG("FREE_MENORY\n");
free_memory(); free_memory(context);
break; break;
case DUMP_CODE: case DUMP_CODE:

View File

@ -32,23 +32,24 @@
#define FREE_MEMORY 257 #define FREE_MEMORY 257
#define DUMP_CODE 258 #define DUMP_CODE 258
/* Memory blobs accessible by other translation units */ /* Entry point type */
extern uint8_t *data_memory;
extern uint32_t data_memory_len;
extern uint8_t *executable_memory;
extern uint32_t executable_memory_len;
extern int data_offs;
/* Entry point type and variable */
typedef int (*entry_point_t)(void); typedef int (*entry_point_t)(void);
extern entry_point_t entr_point;
/* Struct for run-time memory state */
typedef struct runmem_s {
uint8_t *data_memory; // Pointer to data memory
uint32_t data_memory_len; // Length of data memory
uint8_t *executable_memory; // Pointer to executable memory
uint32_t executable_memory_len; // Length of executable memory
int data_offs; // Offset of data memory relative to executable memory
entry_point_t entr_point; // Entry point function pointer
} runmem_t;
/* Command parser: takes a pointer to the command stream and returns /* Command parser: takes a pointer to the command stream and returns
an error flag (0 on success according to current code) */ an error flag (0 on success according to current code) */
int parse_commands(uint8_t *bytes); int parse_commands(runmem_t *context, uint8_t *bytes);
/* Free program and data memory */ /* Free program and data memory */
void free_memory(); void free_memory(runmem_t *context);
#endif /* RUNMEM_H */ #endif /* RUNMEM_H */

View File

@ -1,2 +1,4 @@
def coparun(data: bytes) -> int: ... def coparun(context: int, data: bytes) -> int: ...
def read_data_mem(rel_addr: int, length: int) -> bytes: ... def read_data_mem(context: int, rel_addr: int, length: int) -> bytes: ...
def create_target() -> int: ...
def clear_target(context: int) -> None: ...

View File

@ -0,0 +1,36 @@
import copapy as cp
import pytest
def test_multi_target():
# Define variables
a = cp.value(0.25)
b = cp.value(0.87)
# Define computations
c = a + b * 2.0
d = c ** 2 + cp.sin(a)
e = d + cp.sqrt(b)
# Create a target, compile and run
tg1 = cp.Target()
tg1.compile(e)
# Patch constant value
a.source = cp._basic_types.CPConstant(1000.0)
tg2 = cp.Target()
tg2.compile(e)
tg1.run()
tg2.run()
print("Result tg1:", tg1.read_value(e))
print("Result tg2:", tg2.read_value(e))
# Assertions to verify correctness
assert tg1.read_value(e) == pytest.approx((0.25 + 0.87 * 2.0) ** 2 + cp.sin(0.25) + cp.sqrt(0.87), 0.005) # pyright: ignore[reportUnknownMemberType]
assert tg2.read_value(e) == pytest.approx((1000.0 + 0.87 * 2.0) ** 2 + cp.sin(1000.0) + cp.sqrt(0.87), 0.005) # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_multi_target()