mirror of https://github.com/Nonannet/copapy.git
jit decorator added
This commit is contained in:
parent
c8e6848530
commit
369c279a68
|
|
@ -1,4 +1,4 @@
|
||||||
from ._target import Target
|
from ._target import Target, jit
|
||||||
from ._basic_types import NumLike, value, generic_sdb, iif
|
from ._basic_types import NumLike, value, generic_sdb, iif
|
||||||
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
|
from ._vectors import vector, distance, scalar_projection, angle_between, rotate_vector, vector_projection
|
||||||
from ._matrices import matrix, identity, zeros, ones, diagonal, eye
|
from ._matrices import matrix, identity, zeros, ones, diagonal, eye
|
||||||
|
|
@ -41,5 +41,6 @@ __all__ = [
|
||||||
"rotate_vector",
|
"rotate_vector",
|
||||||
"vector_projection",
|
"vector_projection",
|
||||||
"grad",
|
"grad",
|
||||||
"eye"
|
"eye",
|
||||||
|
"jit"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Iterable, overload, TypeVar, Any
|
from typing import Iterable, overload, TypeVar, Any, Callable, TypeAlias
|
||||||
from . import _binwrite as binw
|
from . import _binwrite as binw
|
||||||
from coparun_module import coparun, read_data_mem, create_target, clear_target
|
from coparun_module import coparun, read_data_mem, create_target, clear_target
|
||||||
import struct
|
import struct
|
||||||
|
|
@ -7,6 +7,11 @@ from ._basic_types import value, Net, Node, Write, NumLike
|
||||||
from ._compiler import compile_to_dag
|
from ._compiler import compile_to_dag
|
||||||
|
|
||||||
T = TypeVar("T", int, float)
|
T = TypeVar("T", int, float)
|
||||||
|
Values: TypeAlias = 'Iterable[NumLike] | NumLike'
|
||||||
|
ArgType: TypeAlias = int | float | Iterable[int | float]
|
||||||
|
TRet = TypeVar("TRet", Iterable[int | float], int, float)
|
||||||
|
|
||||||
|
_jit_cache: dict[Any, tuple['Target', tuple[value[Any] | Iterable[value[Any]], ...], NumLike | Iterable[NumLike]]] = {}
|
||||||
|
|
||||||
|
|
||||||
def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None:
|
def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None:
|
||||||
|
|
@ -17,6 +22,25 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
|
||||||
dw.write_int(lengths)
|
dw.write_int(lengths)
|
||||||
|
|
||||||
|
|
||||||
|
def jit(func: Callable[..., TRet]) -> Callable[..., TRet]:
|
||||||
|
def call_helper(*args: ArgType) -> TRet:
|
||||||
|
if func in _jit_cache:
|
||||||
|
tg, inputs, out = _jit_cache[func]
|
||||||
|
for input, arg in zip(inputs, args):
|
||||||
|
tg.write_value(input, arg)
|
||||||
|
else:
|
||||||
|
tg = Target()
|
||||||
|
inputs = tuple(
|
||||||
|
tuple(value(ai) for ai in a) if isinstance(a, Iterable) else value(a) for a in args)
|
||||||
|
out = func(*inputs) # type: ignore
|
||||||
|
tg.compile(out)
|
||||||
|
_jit_cache[func] = (tg, inputs, out)
|
||||||
|
tg.run()
|
||||||
|
return tg.read_value(out) # type: ignore
|
||||||
|
|
||||||
|
return call_helper
|
||||||
|
|
||||||
|
|
||||||
class Target():
|
class Target():
|
||||||
"""Target device for compiling for and running on copapy code.
|
"""Target device for compiling for and running on copapy code.
|
||||||
"""
|
"""
|
||||||
|
|
@ -111,6 +135,41 @@ class Target():
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported value type: {var_type}")
|
raise ValueError(f"Unsupported value type: {var_type}")
|
||||||
|
|
||||||
|
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
|
||||||
|
"""Reads the numeric value of a copapy type.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
net: Variable to overwrite
|
||||||
|
value: Value
|
||||||
|
"""
|
||||||
|
if isinstance(net, Iterable):
|
||||||
|
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"
|
||||||
|
for ni, vi in zip(net, value):
|
||||||
|
self.write_value(ni, vi)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable"
|
||||||
|
|
||||||
|
assert isinstance(net, Net), "Argument must be a copapy value"
|
||||||
|
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
|
||||||
|
addr, lengths, var_type = self._values[net]
|
||||||
|
assert lengths > 0
|
||||||
|
|
||||||
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
|
dw.write_com(binw.Command.COPY_DATA)
|
||||||
|
dw.write_int(addr)
|
||||||
|
dw.write_int(lengths)
|
||||||
|
|
||||||
|
if var_type == 'float':
|
||||||
|
dw.write_value(float(value), lengths)
|
||||||
|
elif var_type == 'int' or var_type == 'bool':
|
||||||
|
dw.write_value(int(value), lengths)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported value type: {var_type}")
|
||||||
|
|
||||||
|
dw.write_com(binw.Command.END_COM)
|
||||||
|
assert coparun(dw.get_data()) > 0
|
||||||
|
|
||||||
def read_value_remote(self, net: Net) -> None:
|
def read_value_remote(self, net: Net) -> None:
|
||||||
"""Reads the raw data of a value by the runner."""
|
"""Reads the raw data of a value by the runner."""
|
||||||
dw = binw.data_writer(self.sdb.byteorder)
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue