jit decorator added

This commit is contained in:
Nicolas Kruse 2025-12-19 16:07:15 +01:00
parent c8e6848530
commit 369c279a68
2 changed files with 63 additions and 3 deletions

View File

@ -1,4 +1,4 @@
from ._target import Target from ._target import Target, jit
from ._basic_types import NumLike, value, generic_sdb, iif from ._basic_types import NumLike, value, generic_sdb, iif
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
from ._matrices import matrix, identity, zeros, ones, diagonal, eye from ._matrices import matrix, identity, zeros, ones, diagonal, eye
@ -41,5 +41,6 @@ __all__ = [
"rotate_vector", "rotate_vector",
"vector_projection", "vector_projection",
"grad", "grad",
"eye" "eye",
"jit"
] ]

View File

@ -1,4 +1,4 @@
from typing import Iterable, overload, TypeVar, Any from typing import Iterable, overload, TypeVar, Any, Callable, TypeAlias
from . import _binwrite as binw from . import _binwrite as binw
from coparun_module import coparun, read_data_mem, create_target, clear_target from coparun_module import coparun, read_data_mem, create_target, clear_target
import struct import struct
@ -7,6 +7,11 @@ from ._basic_types import value, Net, Node, Write, NumLike
from ._compiler import compile_to_dag from ._compiler import compile_to_dag
T = TypeVar("T", int, float) T = TypeVar("T", int, float)
Values: TypeAlias = 'Iterable[NumLike] | NumLike'
ArgType: TypeAlias = int | float | Iterable[int | float]
TRet = TypeVar("TRet", Iterable[int | float], int, float)
_jit_cache: dict[Any, tuple['Target', tuple[value[Any] | Iterable[value[Any]], ...], NumLike | Iterable[NumLike]]] = {}
def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None: def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None:
@ -17,6 +22,25 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
dw.write_int(lengths) dw.write_int(lengths)
def jit(func: Callable[..., TRet]) -> Callable[..., TRet]:
def call_helper(*args: ArgType) -> TRet:
if func in _jit_cache:
tg, inputs, out = _jit_cache[func]
for input, arg in zip(inputs, args):
tg.write_value(input, arg)
else:
tg = Target()
inputs = tuple(
tuple(value(ai) for ai in a) if isinstance(a, Iterable) else value(a) for a in args)
out = func(*inputs) # type: ignore
tg.compile(out)
_jit_cache[func] = (tg, inputs, out)
tg.run()
return tg.read_value(out) # type: ignore
return call_helper
class Target(): class Target():
"""Target device for compiling for and running on copapy code. """Target device for compiling for and running on copapy code.
""" """
@ -111,6 +135,41 @@ class Target():
else: else:
raise ValueError(f"Unsupported value type: {var_type}") raise ValueError(f"Unsupported value type: {var_type}")
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
"""Reads the numeric value of a copapy type.
Arguments:
net: Variable to overwrite
value: Value
"""
if isinstance(net, Iterable):
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"
for ni, vi in zip(net, value):
self.write_value(ni, vi)
return
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable"
assert isinstance(net, Net), "Argument must be a copapy value"
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net]
assert lengths > 0
dw = binw.data_writer(self.sdb.byteorder)
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(addr)
dw.write_int(lengths)
if var_type == 'float':
dw.write_value(float(value), lengths)
elif var_type == 'int' or var_type == 'bool':
dw.write_value(int(value), lengths)
else:
raise ValueError(f"Unsupported value type: {var_type}")
dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0
def read_value_remote(self, net: Net) -> None: def read_value_remote(self, net: Net) -> None:
"""Reads the raw data of a value by the runner.""" """Reads the raw data of a value by the runner."""
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)