From 369c279a68c230100edd66c5225a3870792a8894 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Fri, 19 Dec 2025 16:07:15 +0100 Subject: [PATCH] jit decorator added --- src/copapy/__init__.py | 5 ++-- src/copapy/_target.py | 61 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 6e6ea03..6bb1d59 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,4 +1,4 @@ -from ._target import Target +from ._target import Target, jit from ._basic_types import NumLike, value, generic_sdb, iif from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection from ._matrices import matrix, identity, zeros, ones, diagonal, eye @@ -41,5 +41,6 @@ __all__ = [ "rotate_vector", "vector_projection", "grad", - "eye" + "eye", + "jit" ] diff --git a/src/copapy/_target.py b/src/copapy/_target.py index 67fbf9b..f9dedc4 100644 --- a/src/copapy/_target.py +++ b/src/copapy/_target.py @@ -1,4 +1,4 @@ -from typing import Iterable, overload, TypeVar, Any +from typing import Iterable, overload, TypeVar, Any, Callable, TypeAlias from . import _binwrite as binw from coparun_module import coparun, read_data_mem, create_target, clear_target import struct @@ -7,6 +7,11 @@ from ._basic_types import value, Net, Node, Write, NumLike from ._compiler import compile_to_dag T = TypeVar("T", int, float) +Values: TypeAlias = 'Iterable[NumLike] | NumLike' +ArgType: TypeAlias = int | float | Iterable[int | float] +TRet = TypeVar("TRet", Iterable[int | float], int, float) + +_jit_cache: dict[Any, tuple['Target', tuple[value[Any] | Iterable[value[Any]], ...], NumLike | Iterable[NumLike]]] = {} def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None: @@ -17,6 +22,25 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, dw.write_int(lengths) +def jit(func: Callable[..., TRet]) -> Callable[..., TRet]: + def call_helper(*args: ArgType) -> TRet: + if func in _jit_cache: + tg, inputs, out = _jit_cache[func] + for input, arg in zip(inputs, args): + tg.write_value(input, arg) + else: + tg = Target() + inputs = tuple( + tuple(value(ai) for ai in a) if isinstance(a, Iterable) else value(a) for a in args) + out = func(*inputs) # type: ignore + tg.compile(out) + _jit_cache[func] = (tg, inputs, out) + tg.run() + return tg.read_value(out) # type: ignore + + return call_helper + + class Target(): """Target device for compiling for and running on copapy code. """ @@ -110,6 +134,41 @@ class Target(): return val else: raise ValueError(f"Unsupported value type: {var_type}") + + def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None: + """Reads the numeric value of a copapy type. + + Arguments: + net: Variable to overwrite + value: Value + """ + if isinstance(net, Iterable): + assert isinstance(value, Iterable), "If net is iterable, value must be iterable too" + for ni, vi in zip(net, value): + self.write_value(ni, vi) + return + + assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable" + + assert isinstance(net, Net), "Argument must be a copapy value" + assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." + addr, lengths, var_type = self._values[net] + assert lengths > 0 + + dw = binw.data_writer(self.sdb.byteorder) + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(addr) + dw.write_int(lengths) + + if var_type == 'float': + dw.write_value(float(value), lengths) + elif var_type == 'int' or var_type == 'bool': + dw.write_value(int(value), lengths) + else: + raise ValueError(f"Unsupported value type: {var_type}") + + dw.write_com(binw.Command.END_COM) + assert coparun(dw.get_data()) > 0 def read_value_remote(self, net: Net) -> None: """Reads the raw data of a value by the runner."""