mirror of https://github.com/Nonannet/copapy.git
full type hints added
This commit is contained in:
parent
0e36b672d8
commit
52f5b28017
|
|
@ -1,5 +1,7 @@
|
|||
import pkgutil
|
||||
from typing import Generator, Iterable, Any
|
||||
from typing import Generator, Iterable, Any, TypeVar, overload, TypeAlias
|
||||
from typing import cast
|
||||
|
||||
from . import binwrite as binw
|
||||
from .stencil_db import stencil_database
|
||||
from collections import defaultdict, deque
|
||||
|
|
@ -7,7 +9,16 @@ from coparun_module import coparun, read_data_mem
|
|||
import struct
|
||||
import platform
|
||||
|
||||
Operand = type['Net'] | float | int
|
||||
#CPNumLike: TypeAlias = 'cpint | cpfloat | cpbool'
|
||||
NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool'
|
||||
NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net'
|
||||
NetAndNum: TypeAlias = 'Net | int | float'
|
||||
|
||||
unifloat: TypeAlias = 'cpfloat | float'
|
||||
uniint: TypeAlias = 'cpint | int'
|
||||
unibool: TypeAlias = 'cpbool | bool'
|
||||
|
||||
TNumber = TypeVar("TNumber", bound='CPNumber')
|
||||
|
||||
|
||||
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
|
||||
|
|
@ -24,6 +35,7 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s
|
|||
|
||||
generic_sdb = stencil_db_from_package()
|
||||
|
||||
|
||||
def transl_type(t: str) -> str:
|
||||
return {'bool': 'int'}.get(t, t)
|
||||
|
||||
|
|
@ -47,63 +59,6 @@ class Net:
|
|||
self.dtype = dtype
|
||||
self.source = source
|
||||
|
||||
def __mul__(self, other: Any) -> 'Net':
|
||||
return _add_op('mul', [self, other], True)
|
||||
|
||||
def __rmul__(self, other: Any) -> 'Net':
|
||||
return _add_op('mul', [self, other], True)
|
||||
|
||||
def __add__(self, other: Any) -> 'Net':
|
||||
return _add_op('add', [self, other], True)
|
||||
|
||||
def __radd__(self, other: Any) -> 'Net':
|
||||
return _add_op('add', [self, other], True)
|
||||
|
||||
def __sub__(self, other: Any) -> 'Net':
|
||||
return _add_op('sub', [self, other])
|
||||
|
||||
def __rsub__(self, other: Any) -> 'Net':
|
||||
return _add_op('sub', [other, self])
|
||||
|
||||
def __truediv__(self, other: Any) -> 'Net':
|
||||
return _add_op('div', [self, other])
|
||||
|
||||
def __rtruediv__(self, other: Any) -> 'Net':
|
||||
return _add_op('div', [other, self])
|
||||
|
||||
def __floordiv__(self, other: Any) -> 'Net':
|
||||
return _add_op('floordiv', [self, other])
|
||||
|
||||
def __rfloordiv__(self, other: Any) -> 'Net':
|
||||
return _add_op('floordiv', [other, self])
|
||||
|
||||
def __neg__(self) -> 'Net':
|
||||
return _add_op('sub', [CPVariable(0), self])
|
||||
|
||||
def __gt__(self, other: Any) -> 'Net':
|
||||
return _add_op('gt', [self, other])
|
||||
|
||||
def __lt__(self, other: Any) -> 'Net':
|
||||
return _add_op('gt', [other, self])
|
||||
|
||||
def __eq__(self, other: Any) -> 'Net': # type: ignore
|
||||
return _add_op('eq', [self, other], True)
|
||||
|
||||
def __req__(self, other: Any) -> 'Net':
|
||||
return _add_op('eq', [self, other], True)
|
||||
|
||||
def __ne__(self, other: Any) -> 'Net': # type: ignore
|
||||
return _add_op('ne', [self, other], True)
|
||||
|
||||
def __rne__(self, other: Any) -> 'Net':
|
||||
return _add_op('ne', [self, other], True)
|
||||
|
||||
def __mod__(self, other: Any) -> 'Net':
|
||||
return _add_op('mod', [self, other])
|
||||
|
||||
def __rmod__(self, other: Any) -> 'Net':
|
||||
return _add_op('mod', [other, self])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
names = get_var_name(self)
|
||||
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
|
||||
|
|
@ -112,15 +67,278 @@ class Net:
|
|||
return id(self)
|
||||
|
||||
|
||||
class CPNumber(Net):
|
||||
def __init__(self, dtype: str, source: Node):
|
||||
self.dtype = dtype
|
||||
self.source = source
|
||||
|
||||
@overload
|
||||
def __mul__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __mul__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __mul__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('mul', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __rmul__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __rmul__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __rmul__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('mul', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __add__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __add__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __add__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('add', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __radd__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __radd__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __radd__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('add', [self, other], True)
|
||||
|
||||
@overload
|
||||
def __sub__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __sub__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __sub__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('sub', [self, other])
|
||||
|
||||
@overload
|
||||
def __rsub__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __rsub__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __rsub__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('sub', [other, self])
|
||||
|
||||
def __truediv__(self, other: NumLike) -> 'cpfloat':
|
||||
ret = _add_op('div', [self, other])
|
||||
assert isinstance(ret, cpfloat)
|
||||
return ret
|
||||
|
||||
def __rtruediv__(self, other: NumLike) -> 'cpfloat':
|
||||
ret = _add_op('div', [other, self])
|
||||
assert isinstance(ret, cpfloat)
|
||||
return ret
|
||||
|
||||
@overload
|
||||
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __floordiv__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __floordiv__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('floordiv', [self, other])
|
||||
|
||||
@overload
|
||||
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __rfloordiv__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('floordiv', [other, self])
|
||||
|
||||
def __neg__(self: TNumber) -> TNumber:
|
||||
return cast(TNumber, _add_op('sub', [cpvalue(0), self]))
|
||||
|
||||
def __gt__(self, other: NumLike) -> 'cpbool':
|
||||
ret = _add_op('gt', [self, other])
|
||||
return cpbool(ret.source)
|
||||
|
||||
def __lt__(self, other: NumLike) -> 'cpbool':
|
||||
ret = _add_op('gt', [other, self])
|
||||
return cpbool(ret.source)
|
||||
|
||||
def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore
|
||||
ret = _add_op('eq', [self, other], True)
|
||||
return cpbool(ret.source)
|
||||
|
||||
def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore
|
||||
ret = _add_op('ne', [self, other], True)
|
||||
return cpbool(ret.source)
|
||||
|
||||
@overload
|
||||
def __mod__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __mod__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __mod__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('mod', [self, other])
|
||||
|
||||
@overload
|
||||
def __rmod__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __rmod__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __rmod__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('mod', [other, self])
|
||||
|
||||
@overload
|
||||
def __pow__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __pow__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __pow__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('pow', [other, self])
|
||||
|
||||
@overload
|
||||
def __rpow__(self: TNumber, other: uniint) -> TNumber:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __rpow__(self, other: unifloat) -> 'cpfloat':
|
||||
...
|
||||
|
||||
def __rpow__(self, other: NumLike) -> 'CPNumber':
|
||||
return _add_op('rpow', [self, other])
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return super().__hash__()
|
||||
|
||||
|
||||
class cpint(CPNumber):
|
||||
def __init__(self, source: int | Node):
|
||||
if isinstance(source, Node):
|
||||
self.source = source
|
||||
else:
|
||||
self.source = InitVar(int(source))
|
||||
self.dtype = 'int'
|
||||
|
||||
def __lshift__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('lshift', [self, other])
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __rlshift__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('lshift', [other, self])
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __rshift__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('rshift', [self, other])
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __rrshift__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('rshift', [other, self])
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __and__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('bwand', [self, other], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __rand__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('rwand', [other, self], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __or__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('bwor', [self, other], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __ror__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('bwor', [other, self], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __xor__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('bwxor', [self, other], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
def __rxor__(self, other: uniint) -> 'cpint':
|
||||
ret = _add_op('bwxor', [other, self], True)
|
||||
assert isinstance(ret, cpint)
|
||||
return ret
|
||||
|
||||
|
||||
class cpfloat(CPNumber):
|
||||
def __init__(self, source: float | Node):
|
||||
if isinstance(source, Node):
|
||||
self.source = source
|
||||
else:
|
||||
self.source = InitVar(float(source))
|
||||
self.dtype = 'float'
|
||||
|
||||
|
||||
class cpbool(cpint):
|
||||
def __init__(self, source: bool | Node):
|
||||
if isinstance(source, Node):
|
||||
self.source = source
|
||||
else:
|
||||
self.source = InitVar(bool(source))
|
||||
self.dtype = 'bool'
|
||||
|
||||
|
||||
class cpvector:
|
||||
def __init__(self, *value: NumLike):
|
||||
self.value = value
|
||||
|
||||
def __add__(self, other: 'cpvector') -> 'cpvector':
|
||||
assert len(self.value) == len(other.value)
|
||||
tup = (a + b for a, b in zip(self.value, other.value))
|
||||
return cpvector(*(v for v in tup if isinstance(v, CPNumber)))
|
||||
|
||||
|
||||
class InitVar(Node):
|
||||
def __init__(self, value: float | int | bool):
|
||||
def __init__(self, value: int | float):
|
||||
self.dtype, self.value = _get_data_and_dtype(value)
|
||||
self.name = 'const_' + self.dtype
|
||||
self.args = []
|
||||
|
||||
|
||||
class Write(Node):
|
||||
def __init__(self, net: Net):
|
||||
def __init__(self, input: NetAndNum):
|
||||
if isinstance(input, Net):
|
||||
net = input
|
||||
else:
|
||||
node = InitVar(input)
|
||||
net = Net(node.dtype, node)
|
||||
|
||||
self.name = 'write_' + transl_type(net.dtype)
|
||||
self.args = [net]
|
||||
|
||||
|
|
@ -132,8 +350,13 @@ class Op(Node):
|
|||
self.args: list[Net] = args
|
||||
|
||||
|
||||
def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
||||
arg_nets = [a if isinstance(a, Net) else CPVariable(a) for a in args]
|
||||
def net_from_value(value: Any) -> Net:
|
||||
vi = InitVar(value)
|
||||
return Net(vi.dtype, vi)
|
||||
|
||||
|
||||
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
|
||||
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
||||
|
||||
if commutative:
|
||||
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
||||
|
|
@ -141,37 +364,49 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
|||
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
|
||||
|
||||
if typed_op not in generic_sdb.stencil_definitions:
|
||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||
#raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}")
|
||||
|
||||
if op in {'eq', 'ne', 'gt'}:
|
||||
result_type = 'bool'
|
||||
else:
|
||||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||
|
||||
result_net = Net(result_type, Op(typed_op, arg_nets))
|
||||
#if op in {'eq', 'ne', 'gt'}:
|
||||
# assert result_type == 'int'
|
||||
# result_type = 'bool'
|
||||
|
||||
return result_net
|
||||
if result_type == 'int':
|
||||
return cpint(Op(typed_op, arg_nets))
|
||||
#elif result_type == 'float':
|
||||
else:
|
||||
return cpfloat(Op(typed_op, arg_nets))
|
||||
#else:
|
||||
# return cpbool(result_type, Op(typed_op, arg_nets))
|
||||
#return CPNumber(result_type, Op(typed_op, arg_nets))
|
||||
|
||||
|
||||
class CPVariable(Net):
|
||||
def __init__(self, value: Any):
|
||||
@overload
|
||||
def cpvalue(value: bool) -> cpbool:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def cpvalue(value: int) -> cpint:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def cpvalue(value: float) -> cpfloat:
|
||||
...
|
||||
|
||||
|
||||
def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat:
|
||||
vi = InitVar(value)
|
||||
Net.__init__(self, vi.dtype, vi)
|
||||
|
||||
|
||||
class CPFloat(CPVariable):
|
||||
def __init__(self, value: float):
|
||||
super().__init__(value)
|
||||
|
||||
|
||||
class CPInt(CPVariable):
|
||||
def __init__(self, value: int):
|
||||
super().__init__(value)
|
||||
|
||||
|
||||
class CPBool(CPVariable):
|
||||
def __init__(self, value: bool):
|
||||
super().__init__(value)
|
||||
if isinstance(value, bool):
|
||||
return cpbool(vi)
|
||||
elif isinstance(value, float):
|
||||
return cpfloat(vi)
|
||||
else:
|
||||
return cpint(vi)
|
||||
|
||||
|
||||
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
||||
|
|
@ -185,20 +420,6 @@ def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
|||
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
||||
|
||||
|
||||
class vec3d:
|
||||
def __init__(self, value: tuple[Net, Net, Net]):
|
||||
self.value = value
|
||||
|
||||
def __add__(self, other: 'vec3d') -> 'vec3d':
|
||||
a1, a2, a3 = self.value
|
||||
b1, b2, b3 = other.value
|
||||
return vec3d((a1 + b1, a2 + b2, a3 + b3))
|
||||
|
||||
|
||||
def const_vector3d(x: float, y: float, z: float) -> vec3d:
|
||||
return vec3d((CPVariable(x), CPVariable(y), CPVariable(z)))
|
||||
|
||||
|
||||
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
|
||||
"""Perform a stable topological sort on a directed acyclic graph (DAG).
|
||||
Arguments:
|
||||
|
|
@ -481,15 +702,15 @@ class Target():
|
|||
self.sdb = stencil_db_from_package(arch, optimization)
|
||||
self._variables: dict[Net, tuple[int, int, str]] = dict()
|
||||
|
||||
def compile(self, *variables: Net | list[Net]) -> None:
|
||||
def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None:
|
||||
nodes: list[Node] = []
|
||||
for s in variables:
|
||||
if isinstance(s, Net):
|
||||
nodes.append(Write(s))
|
||||
else:
|
||||
if isinstance(s, list):
|
||||
for net in s:
|
||||
assert isinstance(net, Net), f"The folowing element is not a Net: {net}"
|
||||
nodes.append(Write(net))
|
||||
else:
|
||||
nodes.append(Write(s))
|
||||
|
||||
dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
|
||||
dw.write_com(binw.Command.END_COM)
|
||||
|
|
@ -502,7 +723,8 @@ class Target():
|
|||
dw.write_com(binw.Command.END_COM)
|
||||
assert coparun(dw.get_data()) > 0
|
||||
|
||||
def read_value(self, net: Net) -> float | int | bool:
|
||||
def read_value(self, net: NumLike) -> float | int | bool:
|
||||
assert isinstance(net, Net), "Variable must be a copapy variable object"
|
||||
assert net in self._variables, f"Variable {net} not found"
|
||||
addr, lengths, var_type = self._variables[net]
|
||||
assert lengths > 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue