From 52f5b2801708f648cca464c27ebd5ae49c38fde9 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Sat, 18 Oct 2025 23:20:49 +0200 Subject: [PATCH] full type hints added --- src/copapy/__init__.py | 428 +++++++++++++++++++++++++++++++---------- 1 file changed, 325 insertions(+), 103 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 3ac106a..d9ae459 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,5 +1,7 @@ import pkgutil -from typing import Generator, Iterable, Any +from typing import Generator, Iterable, Any, TypeVar, overload, TypeAlias +from typing import cast + from . import binwrite as binw from .stencil_db import stencil_database from collections import defaultdict, deque @@ -7,7 +9,16 @@ from coparun_module import coparun, read_data_mem import struct import platform -Operand = type['Net'] | float | int +#CPNumLike: TypeAlias = 'cpint | cpfloat | cpbool' +NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool' +NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net' +NetAndNum: TypeAlias = 'Net | int | float' + +unifloat: TypeAlias = 'cpfloat | float' +uniint: TypeAlias = 'cpint | int' +unibool: TypeAlias = 'cpbool | bool' + +TNumber = TypeVar("TNumber", bound='CPNumber') def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: @@ -24,6 +35,7 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s generic_sdb = stencil_db_from_package() + def transl_type(t: str) -> str: return {'bool': 'int'}.get(t, t) @@ -47,63 +59,6 @@ class Net: self.dtype = dtype self.source = source - def __mul__(self, other: Any) -> 'Net': - return _add_op('mul', [self, other], True) - - def __rmul__(self, other: Any) -> 'Net': - return _add_op('mul', [self, other], True) - - def __add__(self, other: Any) -> 'Net': - return _add_op('add', [self, other], True) - - def __radd__(self, other: Any) -> 'Net': - return _add_op('add', [self, other], True) - - def __sub__(self, other: Any) -> 'Net': - return _add_op('sub', [self, other]) - - def __rsub__(self, other: Any) -> 'Net': - return _add_op('sub', [other, self]) - - def __truediv__(self, other: Any) -> 'Net': - return _add_op('div', [self, other]) - - def __rtruediv__(self, other: Any) -> 'Net': - return _add_op('div', [other, self]) - - def __floordiv__(self, other: Any) -> 'Net': - return _add_op('floordiv', [self, other]) - - def __rfloordiv__(self, other: Any) -> 'Net': - return _add_op('floordiv', [other, self]) - - def __neg__(self) -> 'Net': - return _add_op('sub', [CPVariable(0), self]) - - def __gt__(self, other: Any) -> 'Net': - return _add_op('gt', [self, other]) - - def __lt__(self, other: Any) -> 'Net': - return _add_op('gt', [other, self]) - - def __eq__(self, other: Any) -> 'Net': # type: ignore - return _add_op('eq', [self, other], True) - - def __req__(self, other: Any) -> 'Net': - return _add_op('eq', [self, other], True) - - def __ne__(self, other: Any) -> 'Net': # type: ignore - return _add_op('ne', [self, other], True) - - def __rne__(self, other: Any) -> 'Net': - return _add_op('ne', [self, other], True) - - def __mod__(self, other: Any) -> 'Net': - return _add_op('mod', [self, other]) - - def __rmod__(self, other: Any) -> 'Net': - return _add_op('mod', [other, self]) - def __repr__(self) -> str: names = get_var_name(self) return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" @@ -112,15 +67,278 @@ class Net: return id(self) +class CPNumber(Net): + def __init__(self, dtype: str, source: Node): + self.dtype = dtype + self.source = source + + @overload + def __mul__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __mul__(self, other: unifloat) -> 'cpfloat': + ... + + def __mul__(self, other: NumLike) -> 'CPNumber': + return _add_op('mul', [self, other], True) + + @overload + def __rmul__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rmul__(self, other: unifloat) -> 'cpfloat': + ... + + def __rmul__(self, other: NumLike) -> 'CPNumber': + return _add_op('mul', [self, other], True) + + @overload + def __add__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __add__(self, other: unifloat) -> 'cpfloat': + ... + + def __add__(self, other: NumLike) -> 'CPNumber': + return _add_op('add', [self, other], True) + + @overload + def __radd__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __radd__(self, other: unifloat) -> 'cpfloat': + ... + + def __radd__(self, other: NumLike) -> 'CPNumber': + return _add_op('add', [self, other], True) + + @overload + def __sub__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __sub__(self, other: unifloat) -> 'cpfloat': + ... + + def __sub__(self, other: NumLike) -> 'CPNumber': + return _add_op('sub', [self, other]) + + @overload + def __rsub__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rsub__(self, other: unifloat) -> 'cpfloat': + ... + + def __rsub__(self, other: NumLike) -> 'CPNumber': + return _add_op('sub', [other, self]) + + def __truediv__(self, other: NumLike) -> 'cpfloat': + ret = _add_op('div', [self, other]) + assert isinstance(ret, cpfloat) + return ret + + def __rtruediv__(self, other: NumLike) -> 'cpfloat': + ret = _add_op('div', [other, self]) + assert isinstance(ret, cpfloat) + return ret + + @overload + def __floordiv__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __floordiv__(self, other: unifloat) -> 'cpfloat': + ... + + def __floordiv__(self, other: NumLike) -> 'CPNumber': + return _add_op('floordiv', [self, other]) + + @overload + def __rfloordiv__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rfloordiv__(self, other: unifloat) -> 'cpfloat': + ... + + def __rfloordiv__(self, other: NumLike) -> 'CPNumber': + return _add_op('floordiv', [other, self]) + + def __neg__(self: TNumber) -> TNumber: + return cast(TNumber, _add_op('sub', [cpvalue(0), self])) + + def __gt__(self, other: NumLike) -> 'cpbool': + ret = _add_op('gt', [self, other]) + return cpbool(ret.source) + + def __lt__(self, other: NumLike) -> 'cpbool': + ret = _add_op('gt', [other, self]) + return cpbool(ret.source) + + def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore + ret = _add_op('eq', [self, other], True) + return cpbool(ret.source) + + def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore + ret = _add_op('ne', [self, other], True) + return cpbool(ret.source) + + @overload + def __mod__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __mod__(self, other: unifloat) -> 'cpfloat': + ... + + def __mod__(self, other: NumLike) -> 'CPNumber': + return _add_op('mod', [self, other]) + + @overload + def __rmod__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rmod__(self, other: unifloat) -> 'cpfloat': + ... + + def __rmod__(self, other: NumLike) -> 'CPNumber': + return _add_op('mod', [other, self]) + + @overload + def __pow__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __pow__(self, other: unifloat) -> 'cpfloat': + ... + + def __pow__(self, other: NumLike) -> 'CPNumber': + return _add_op('pow', [other, self]) + + @overload + def __rpow__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rpow__(self, other: unifloat) -> 'cpfloat': + ... + + def __rpow__(self, other: NumLike) -> 'CPNumber': + return _add_op('rpow', [self, other]) + + def __hash__(self) -> int: + return super().__hash__() + + +class cpint(CPNumber): + def __init__(self, source: int | Node): + if isinstance(source, Node): + self.source = source + else: + self.source = InitVar(int(source)) + self.dtype = 'int' + + def __lshift__(self, other: uniint) -> 'cpint': + ret = _add_op('lshift', [self, other]) + assert isinstance(ret, cpint) + return ret + + def __rlshift__(self, other: uniint) -> 'cpint': + ret = _add_op('lshift', [other, self]) + assert isinstance(ret, cpint) + return ret + + def __rshift__(self, other: uniint) -> 'cpint': + ret = _add_op('rshift', [self, other]) + assert isinstance(ret, cpint) + return ret + + def __rrshift__(self, other: uniint) -> 'cpint': + ret = _add_op('rshift', [other, self]) + assert isinstance(ret, cpint) + return ret + + def __and__(self, other: uniint) -> 'cpint': + ret = _add_op('bwand', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __rand__(self, other: uniint) -> 'cpint': + ret = _add_op('rwand', [other, self], True) + assert isinstance(ret, cpint) + return ret + + def __or__(self, other: uniint) -> 'cpint': + ret = _add_op('bwor', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __ror__(self, other: uniint) -> 'cpint': + ret = _add_op('bwor', [other, self], True) + assert isinstance(ret, cpint) + return ret + + def __xor__(self, other: uniint) -> 'cpint': + ret = _add_op('bwxor', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __rxor__(self, other: uniint) -> 'cpint': + ret = _add_op('bwxor', [other, self], True) + assert isinstance(ret, cpint) + return ret + + +class cpfloat(CPNumber): + def __init__(self, source: float | Node): + if isinstance(source, Node): + self.source = source + else: + self.source = InitVar(float(source)) + self.dtype = 'float' + + +class cpbool(cpint): + def __init__(self, source: bool | Node): + if isinstance(source, Node): + self.source = source + else: + self.source = InitVar(bool(source)) + self.dtype = 'bool' + + +class cpvector: + def __init__(self, *value: NumLike): + self.value = value + + def __add__(self, other: 'cpvector') -> 'cpvector': + assert len(self.value) == len(other.value) + tup = (a + b for a, b in zip(self.value, other.value)) + return cpvector(*(v for v in tup if isinstance(v, CPNumber))) + + class InitVar(Node): - def __init__(self, value: float | int | bool): + def __init__(self, value: int | float): self.dtype, self.value = _get_data_and_dtype(value) self.name = 'const_' + self.dtype self.args = [] class Write(Node): - def __init__(self, net: Net): + def __init__(self, input: NetAndNum): + if isinstance(input, Net): + net = input + else: + node = InitVar(input) + net = Net(node.dtype, node) + self.name = 'write_' + transl_type(net.dtype) self.args = [net] @@ -132,8 +350,13 @@ class Op(Node): self.args: list[Net] = args -def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: - arg_nets = [a if isinstance(a, Net) else CPVariable(a) for a in args] +def net_from_value(value: Any) -> Net: + vi = InitVar(value) + return Net(vi.dtype, vi) + + +def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber: + arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] if commutative: arg_nets = sorted(arg_nets, key=lambda a: a.dtype) @@ -141,37 +364,49 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) if typed_op not in generic_sdb.stencil_definitions: - raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") + #raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") + raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}") - if op in {'eq', 'ne', 'gt'}: - result_type = 'bool' + result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] + + #if op in {'eq', 'ne', 'gt'}: + # assert result_type == 'int' + # result_type = 'bool' + + if result_type == 'int': + return cpint(Op(typed_op, arg_nets)) + #elif result_type == 'float': else: - result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] - - result_net = Net(result_type, Op(typed_op, arg_nets)) - - return result_net + return cpfloat(Op(typed_op, arg_nets)) + #else: + # return cpbool(result_type, Op(typed_op, arg_nets)) + #return CPNumber(result_type, Op(typed_op, arg_nets)) -class CPVariable(Net): - def __init__(self, value: Any): - vi = InitVar(value) - Net.__init__(self, vi.dtype, vi) +@overload +def cpvalue(value: bool) -> cpbool: + ... -class CPFloat(CPVariable): - def __init__(self, value: float): - super().__init__(value) +@overload +def cpvalue(value: int) -> cpint: + ... -class CPInt(CPVariable): - def __init__(self, value: int): - super().__init__(value) +@overload +def cpvalue(value: float) -> cpfloat: + ... -class CPBool(CPVariable): - def __init__(self, value: bool): - super().__init__(value) +def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat: + vi = InitVar(value) + + if isinstance(value, bool): + return cpbool(vi) + elif isinstance(value, float): + return cpfloat(vi) + else: + return cpint(vi) def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: @@ -185,20 +420,6 @@ def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: raise ValueError(f'Non supported data type: {type(value).__name__}') -class vec3d: - def __init__(self, value: tuple[Net, Net, Net]): - self.value = value - - def __add__(self, other: 'vec3d') -> 'vec3d': - a1, a2, a3 = self.value - b1, b2, b3 = other.value - return vec3d((a1 + b1, a2 + b2, a3 + b3)) - - -def const_vector3d(x: float, y: float, z: float) -> vec3d: - return vec3d((CPVariable(x), CPVariable(y), CPVariable(z))) - - def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: """Perform a stable topological sort on a directed acyclic graph (DAG). Arguments: @@ -481,15 +702,15 @@ class Target(): self.sdb = stencil_db_from_package(arch, optimization) self._variables: dict[Net, tuple[int, int, str]] = dict() - def compile(self, *variables: Net | list[Net]) -> None: + def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None: nodes: list[Node] = [] for s in variables: - if isinstance(s, Net): - nodes.append(Write(s)) - else: + if isinstance(s, list): for net in s: assert isinstance(net, Net), f"The folowing element is not a Net: {net}" nodes.append(Write(net)) + else: + nodes.append(Write(s)) dw, self._variables = compile_to_instruction_list(nodes, self.sdb) dw.write_com(binw.Command.END_COM) @@ -502,7 +723,8 @@ class Target(): dw.write_com(binw.Command.END_COM) assert coparun(dw.get_data()) > 0 - def read_value(self, net: Net) -> float | int | bool: + def read_value(self, net: NumLike) -> float | int | bool: + assert isinstance(net, Net), "Variable must be a copapy variable object" assert net in self._variables, f"Variable {net} not found" addr, lengths, var_type = self._variables[net] assert lengths > 0