full type hints added

This commit is contained in:
Nicolas Kruse 2025-10-18 23:20:49 +02:00
parent 0e36b672d8
commit 52f5b28017
1 changed files with 325 additions and 103 deletions

View File

@ -1,5 +1,7 @@
import pkgutil import pkgutil
from typing import Generator, Iterable, Any from typing import Generator, Iterable, Any, TypeVar, overload, TypeAlias
from typing import cast
from . import binwrite as binw from . import binwrite as binw
from .stencil_db import stencil_database from .stencil_db import stencil_database
from collections import defaultdict, deque from collections import defaultdict, deque
@ -7,7 +9,16 @@ from coparun_module import coparun, read_data_mem
import struct import struct
import platform import platform
Operand = type['Net'] | float | int #CPNumLike: TypeAlias = 'cpint | cpfloat | cpbool'
NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool'
NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net'
NetAndNum: TypeAlias = 'Net | int | float'
unifloat: TypeAlias = 'cpfloat | float'
uniint: TypeAlias = 'cpint | int'
unibool: TypeAlias = 'cpbool | bool'
TNumber = TypeVar("TNumber", bound='CPNumber')
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
@ -24,6 +35,7 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s
generic_sdb = stencil_db_from_package() generic_sdb = stencil_db_from_package()
def transl_type(t: str) -> str: def transl_type(t: str) -> str:
return {'bool': 'int'}.get(t, t) return {'bool': 'int'}.get(t, t)
@ -47,63 +59,6 @@ class Net:
self.dtype = dtype self.dtype = dtype
self.source = source self.source = source
def __mul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True)
def __rmul__(self, other: Any) -> 'Net':
return _add_op('mul', [self, other], True)
def __add__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True)
def __radd__(self, other: Any) -> 'Net':
return _add_op('add', [self, other], True)
def __sub__(self, other: Any) -> 'Net':
return _add_op('sub', [self, other])
def __rsub__(self, other: Any) -> 'Net':
return _add_op('sub', [other, self])
def __truediv__(self, other: Any) -> 'Net':
return _add_op('div', [self, other])
def __rtruediv__(self, other: Any) -> 'Net':
return _add_op('div', [other, self])
def __floordiv__(self, other: Any) -> 'Net':
return _add_op('floordiv', [self, other])
def __rfloordiv__(self, other: Any) -> 'Net':
return _add_op('floordiv', [other, self])
def __neg__(self) -> 'Net':
return _add_op('sub', [CPVariable(0), self])
def __gt__(self, other: Any) -> 'Net':
return _add_op('gt', [self, other])
def __lt__(self, other: Any) -> 'Net':
return _add_op('gt', [other, self])
def __eq__(self, other: Any) -> 'Net': # type: ignore
return _add_op('eq', [self, other], True)
def __req__(self, other: Any) -> 'Net':
return _add_op('eq', [self, other], True)
def __ne__(self, other: Any) -> 'Net': # type: ignore
return _add_op('ne', [self, other], True)
def __rne__(self, other: Any) -> 'Net':
return _add_op('ne', [self, other], True)
def __mod__(self, other: Any) -> 'Net':
return _add_op('mod', [self, other])
def __rmod__(self, other: Any) -> 'Net':
return _add_op('mod', [other, self])
def __repr__(self) -> str: def __repr__(self) -> str:
names = get_var_name(self) names = get_var_name(self)
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
@ -112,15 +67,278 @@ class Net:
return id(self) return id(self)
class CPNumber(Net):
def __init__(self, dtype: str, source: Node):
self.dtype = dtype
self.source = source
@overload
def __mul__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __mul__(self, other: unifloat) -> 'cpfloat':
...
def __mul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __rmul__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rmul__(self, other: unifloat) -> 'cpfloat':
...
def __rmul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __add__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __add__(self, other: unifloat) -> 'cpfloat':
...
def __add__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __radd__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __radd__(self, other: unifloat) -> 'cpfloat':
...
def __radd__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __sub__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __sub__(self, other: unifloat) -> 'cpfloat':
...
def __sub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [self, other])
@overload
def __rsub__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rsub__(self, other: unifloat) -> 'cpfloat':
...
def __rsub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [other, self])
def __truediv__(self, other: NumLike) -> 'cpfloat':
ret = _add_op('div', [self, other])
assert isinstance(ret, cpfloat)
return ret
def __rtruediv__(self, other: NumLike) -> 'cpfloat':
ret = _add_op('div', [other, self])
assert isinstance(ret, cpfloat)
return ret
@overload
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __floordiv__(self, other: unifloat) -> 'cpfloat':
...
def __floordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rfloordiv__(self, other: unifloat) -> 'cpfloat':
...
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self])
def __neg__(self: TNumber) -> TNumber:
return cast(TNumber, _add_op('sub', [cpvalue(0), self]))
def __gt__(self, other: NumLike) -> 'cpbool':
ret = _add_op('gt', [self, other])
return cpbool(ret.source)
def __lt__(self, other: NumLike) -> 'cpbool':
ret = _add_op('gt', [other, self])
return cpbool(ret.source)
def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore
ret = _add_op('eq', [self, other], True)
return cpbool(ret.source)
def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore
ret = _add_op('ne', [self, other], True)
return cpbool(ret.source)
@overload
def __mod__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __mod__(self, other: unifloat) -> 'cpfloat':
...
def __mod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [self, other])
@overload
def __rmod__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rmod__(self, other: unifloat) -> 'cpfloat':
...
def __rmod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [other, self])
@overload
def __pow__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __pow__(self, other: unifloat) -> 'cpfloat':
...
def __pow__(self, other: NumLike) -> 'CPNumber':
return _add_op('pow', [other, self])
@overload
def __rpow__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rpow__(self, other: unifloat) -> 'cpfloat':
...
def __rpow__(self, other: NumLike) -> 'CPNumber':
return _add_op('rpow', [self, other])
def __hash__(self) -> int:
return super().__hash__()
class cpint(CPNumber):
def __init__(self, source: int | Node):
if isinstance(source, Node):
self.source = source
else:
self.source = InitVar(int(source))
self.dtype = 'int'
def __lshift__(self, other: uniint) -> 'cpint':
ret = _add_op('lshift', [self, other])
assert isinstance(ret, cpint)
return ret
def __rlshift__(self, other: uniint) -> 'cpint':
ret = _add_op('lshift', [other, self])
assert isinstance(ret, cpint)
return ret
def __rshift__(self, other: uniint) -> 'cpint':
ret = _add_op('rshift', [self, other])
assert isinstance(ret, cpint)
return ret
def __rrshift__(self, other: uniint) -> 'cpint':
ret = _add_op('rshift', [other, self])
assert isinstance(ret, cpint)
return ret
def __and__(self, other: uniint) -> 'cpint':
ret = _add_op('bwand', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __rand__(self, other: uniint) -> 'cpint':
ret = _add_op('rwand', [other, self], True)
assert isinstance(ret, cpint)
return ret
def __or__(self, other: uniint) -> 'cpint':
ret = _add_op('bwor', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __ror__(self, other: uniint) -> 'cpint':
ret = _add_op('bwor', [other, self], True)
assert isinstance(ret, cpint)
return ret
def __xor__(self, other: uniint) -> 'cpint':
ret = _add_op('bwxor', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __rxor__(self, other: uniint) -> 'cpint':
ret = _add_op('bwxor', [other, self], True)
assert isinstance(ret, cpint)
return ret
class cpfloat(CPNumber):
def __init__(self, source: float | Node):
if isinstance(source, Node):
self.source = source
else:
self.source = InitVar(float(source))
self.dtype = 'float'
class cpbool(cpint):
def __init__(self, source: bool | Node):
if isinstance(source, Node):
self.source = source
else:
self.source = InitVar(bool(source))
self.dtype = 'bool'
class cpvector:
def __init__(self, *value: NumLike):
self.value = value
def __add__(self, other: 'cpvector') -> 'cpvector':
assert len(self.value) == len(other.value)
tup = (a + b for a, b in zip(self.value, other.value))
return cpvector(*(v for v in tup if isinstance(v, CPNumber)))
class InitVar(Node): class InitVar(Node):
def __init__(self, value: float | int | bool): def __init__(self, value: int | float):
self.dtype, self.value = _get_data_and_dtype(value) self.dtype, self.value = _get_data_and_dtype(value)
self.name = 'const_' + self.dtype self.name = 'const_' + self.dtype
self.args = [] self.args = []
class Write(Node): class Write(Node):
def __init__(self, net: Net): def __init__(self, input: NetAndNum):
if isinstance(input, Net):
net = input
else:
node = InitVar(input)
net = Net(node.dtype, node)
self.name = 'write_' + transl_type(net.dtype) self.name = 'write_' + transl_type(net.dtype)
self.args = [net] self.args = [net]
@ -132,8 +350,13 @@ class Op(Node):
self.args: list[Net] = args self.args: list[Net] = args
def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: def net_from_value(value: Any) -> Net:
arg_nets = [a if isinstance(a, Net) else CPVariable(a) for a in args] vi = InitVar(value)
return Net(vi.dtype, vi)
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
@ -141,37 +364,49 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
if typed_op not in generic_sdb.stencil_definitions: if typed_op not in generic_sdb.stencil_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") #raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}")
if op in {'eq', 'ne', 'gt'}: result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
result_type = 'bool'
#if op in {'eq', 'ne', 'gt'}:
# assert result_type == 'int'
# result_type = 'bool'
if result_type == 'int':
return cpint(Op(typed_op, arg_nets))
#elif result_type == 'float':
else: else:
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] return cpfloat(Op(typed_op, arg_nets))
#else:
result_net = Net(result_type, Op(typed_op, arg_nets)) # return cpbool(result_type, Op(typed_op, arg_nets))
#return CPNumber(result_type, Op(typed_op, arg_nets))
return result_net
class CPVariable(Net): @overload
def __init__(self, value: Any): def cpvalue(value: bool) -> cpbool:
vi = InitVar(value) ...
Net.__init__(self, vi.dtype, vi)
class CPFloat(CPVariable): @overload
def __init__(self, value: float): def cpvalue(value: int) -> cpint:
super().__init__(value) ...
class CPInt(CPVariable): @overload
def __init__(self, value: int): def cpvalue(value: float) -> cpfloat:
super().__init__(value) ...
class CPBool(CPVariable): def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat:
def __init__(self, value: bool): vi = InitVar(value)
super().__init__(value)
if isinstance(value, bool):
return cpbool(vi)
elif isinstance(value, float):
return cpfloat(vi)
else:
return cpint(vi)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
@ -185,20 +420,6 @@ def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
raise ValueError(f'Non supported data type: {type(value).__name__}') raise ValueError(f'Non supported data type: {type(value).__name__}')
class vec3d:
def __init__(self, value: tuple[Net, Net, Net]):
self.value = value
def __add__(self, other: 'vec3d') -> 'vec3d':
a1, a2, a3 = self.value
b1, b2, b3 = other.value
return vec3d((a1 + b1, a2 + b2, a3 + b3))
def const_vector3d(x: float, y: float, z: float) -> vec3d:
return vec3d((CPVariable(x), CPVariable(y), CPVariable(z)))
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
"""Perform a stable topological sort on a directed acyclic graph (DAG). """Perform a stable topological sort on a directed acyclic graph (DAG).
Arguments: Arguments:
@ -481,15 +702,15 @@ class Target():
self.sdb = stencil_db_from_package(arch, optimization) self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = dict() self._variables: dict[Net, tuple[int, int, str]] = dict()
def compile(self, *variables: Net | list[Net]) -> None: def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None:
nodes: list[Node] = [] nodes: list[Node] = []
for s in variables: for s in variables:
if isinstance(s, Net): if isinstance(s, list):
nodes.append(Write(s))
else:
for net in s: for net in s:
assert isinstance(net, Net), f"The folowing element is not a Net: {net}" assert isinstance(net, Net), f"The folowing element is not a Net: {net}"
nodes.append(Write(net)) nodes.append(Write(net))
else:
nodes.append(Write(s))
dw, self._variables = compile_to_instruction_list(nodes, self.sdb) dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
@ -502,7 +723,8 @@ class Target():
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0
def read_value(self, net: Net) -> float | int | bool: def read_value(self, net: NumLike) -> float | int | bool:
assert isinstance(net, Net), "Variable must be a copapy variable object"
assert net in self._variables, f"Variable {net} not found" assert net in self._variables, f"Variable {net} not found"
addr, lengths, var_type = self._variables[net] addr, lengths, var_type = self._variables[net]
assert lengths > 0 assert lengths > 0