mirror of https://github.com/Nonannet/copapy.git
jit decorator added
This commit is contained in:
parent
c8e6848530
commit
369c279a68
|
|
@ -1,4 +1,4 @@
|
|||
from ._target import Target
|
||||
from ._target import Target, jit
|
||||
from ._basic_types import NumLike, value, generic_sdb, iif
|
||||
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
|
||||
from ._matrices import matrix, identity, zeros, ones, diagonal, eye
|
||||
|
|
@ -41,5 +41,6 @@ __all__ = [
|
|||
"rotate_vector",
|
||||
"vector_projection",
|
||||
"grad",
|
||||
"eye"
|
||||
"eye",
|
||||
"jit"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Iterable, overload, TypeVar, Any
|
||||
from typing import Iterable, overload, TypeVar, Any, Callable, TypeAlias
|
||||
from . import _binwrite as binw
|
||||
from coparun_module import coparun, read_data_mem, create_target, clear_target
|
||||
import struct
|
||||
|
|
@ -7,6 +7,11 @@ from ._basic_types import value, Net, Node, Write, NumLike
|
|||
from ._compiler import compile_to_dag
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
Values: TypeAlias = 'Iterable[NumLike] | NumLike'
|
||||
ArgType: TypeAlias = int | float | Iterable[int | float]
|
||||
TRet = TypeVar("TRet", Iterable[int | float], int, float)
|
||||
|
||||
_jit_cache: dict[Any, tuple['Target', tuple[value[Any] | Iterable[value[Any]], ...], NumLike | Iterable[NumLike]]] = {}
|
||||
|
||||
|
||||
def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None:
|
||||
|
|
@ -17,6 +22,25 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
|
|||
dw.write_int(lengths)
|
||||
|
||||
|
||||
def jit(func: Callable[..., TRet]) -> Callable[..., TRet]:
|
||||
def call_helper(*args: ArgType) -> TRet:
|
||||
if func in _jit_cache:
|
||||
tg, inputs, out = _jit_cache[func]
|
||||
for input, arg in zip(inputs, args):
|
||||
tg.write_value(input, arg)
|
||||
else:
|
||||
tg = Target()
|
||||
inputs = tuple(
|
||||
tuple(value(ai) for ai in a) if isinstance(a, Iterable) else value(a) for a in args)
|
||||
out = func(*inputs) # type: ignore
|
||||
tg.compile(out)
|
||||
_jit_cache[func] = (tg, inputs, out)
|
||||
tg.run()
|
||||
return tg.read_value(out) # type: ignore
|
||||
|
||||
return call_helper
|
||||
|
||||
|
||||
class Target():
|
||||
"""Target device for compiling for and running on copapy code.
|
||||
"""
|
||||
|
|
@ -111,6 +135,41 @@ class Target():
|
|||
else:
|
||||
raise ValueError(f"Unsupported value type: {var_type}")
|
||||
|
||||
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
|
||||
"""Reads the numeric value of a copapy type.
|
||||
|
||||
Arguments:
|
||||
net: Variable to overwrite
|
||||
value: Value
|
||||
"""
|
||||
if isinstance(net, Iterable):
|
||||
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"
|
||||
for ni, vi in zip(net, value):
|
||||
self.write_value(ni, vi)
|
||||
return
|
||||
|
||||
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable"
|
||||
|
||||
assert isinstance(net, Net), "Argument must be a copapy value"
|
||||
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
|
||||
addr, lengths, var_type = self._values[net]
|
||||
assert lengths > 0
|
||||
|
||||
dw = binw.data_writer(self.sdb.byteorder)
|
||||
dw.write_com(binw.Command.COPY_DATA)
|
||||
dw.write_int(addr)
|
||||
dw.write_int(lengths)
|
||||
|
||||
if var_type == 'float':
|
||||
dw.write_value(float(value), lengths)
|
||||
elif var_type == 'int' or var_type == 'bool':
|
||||
dw.write_value(int(value), lengths)
|
||||
else:
|
||||
raise ValueError(f"Unsupported value type: {var_type}")
|
||||
|
||||
dw.write_com(binw.Command.END_COM)
|
||||
assert coparun(dw.get_data()) > 0
|
||||
|
||||
def read_value_remote(self, net: Net) -> None:
|
||||
"""Reads the raw data of a value by the runner."""
|
||||
dw = binw.data_writer(self.sdb.byteorder)
|
||||
|
|
|
|||
Loading…
Reference in New Issue