typing fixed, variable[bool] replaced by variable[int]

This commit is contained in:
Nicolas Kruse 2025-11-27 12:50:53 +01:00
parent 44b215f728
commit 99a880861a
3 changed files with 51 additions and 67 deletions

View File

@ -3,15 +3,13 @@ from typing import Any, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database, detect_process_arch from ._stencils import stencil_database, detect_process_arch
import copapy as cp import copapy as cp
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
unifloat: TypeAlias = 'variable[float] | float' unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int' uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool'
uniboolint: TypeAlias = 'variable[bool] | bool | variable[int] | int'
TCPNum = TypeVar("TCPNum", bound='variable[Any]') TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, float, bool) TNum = TypeVar("TNum", int, float)
TVarNumb: TypeAlias = 'variable[Any] | int | float | bool' TVarNumb: TypeAlias = 'variable[Any] | int | float'
stencil_cache: dict[tuple[str, str], stencil_database] = {} stencil_cache: dict[tuple[str, str], stencil_database] = {}
@ -107,20 +105,20 @@ class variable(Generic[TNum], Net):
self.dtype = 'int' self.dtype = 'int'
@overload @overload
def __add__(self, other: TCPNum) -> TCPNum: ... def __add__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __add__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload @overload
def __add__(self, other: unifloat) -> 'variable[float]': ... def __add__(self, other: unifloat) -> 'variable[float]': ...
@overload @overload
def __add__(self, other: NumLike) -> 'variable[float] | variable[int]': ... def __add__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
def __add__(self, other: NumLike) -> Any: @overload
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __add__(self, other: TVarNumb) -> Any:
if isinstance(other, int | float) and other == 0: if isinstance(other, int | float) and other == 0:
return self return self
return add_op('add', [self, other], True) return add_op('add', [self, other], True)
@overload @overload
def __radd__(self: TCPNum, other: int) -> TCPNum: ... def __radd__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload @overload
def __radd__(self, other: float) -> 'variable[float]': ... def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any: def __radd__(self, other: NumLike) -> Any:
@ -129,36 +127,36 @@ class variable(Generic[TNum], Net):
return add_op('add', [self, other], True) return add_op('add', [self, other], True)
@overload @overload
def __sub__(self, other: TCPNum) -> TCPNum: ... def __sub__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __sub__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload @overload
def __sub__(self, other: unifloat) -> 'variable[float]': ... def __sub__(self, other: unifloat) -> 'variable[float]': ...
@overload @overload
def __sub__(self, other: NumLike) -> 'variable[float] | variable[int]': ... def __sub__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
def __sub__(self, other: NumLike) -> Any: @overload
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: TVarNumb) -> Any:
return add_op('sub', [self, other]) return add_op('sub', [self, other])
@overload @overload
def __rsub__(self: TCPNum, other: int) -> TCPNum: ... def __rsub__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload @overload
def __rsub__(self, other: float) -> 'variable[float]': ... def __rsub__(self, other: float) -> 'variable[float]': ...
def __rsub__(self, other: NumLike) -> Any: def __rsub__(self, other: NumLike) -> Any:
return add_op('sub', [other, self]) return add_op('sub', [other, self])
@overload @overload
def __mul__(self, other: TCPNum) -> TCPNum: ... def __mul__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __mul__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload @overload
def __mul__(self, other: unifloat) -> 'variable[float]': ... def __mul__(self, other: unifloat) -> 'variable[float]': ...
@overload @overload
def __mul__(self, other: NumLike) -> 'variable[float] | variable[int]': ... def __mul__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
def __mul__(self, other: NumLike) -> Any: @overload
def __mul__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __mul__(self, other: TVarNumb) -> Any:
return add_op('mul', [self, other], True) return add_op('mul', [self, other], True)
@overload @overload
def __rmul__(self: TCPNum, other: int) -> TCPNum: ... def __rmul__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload @overload
def __rmul__(self, other: float) -> 'variable[float]': ... def __rmul__(self, other: float) -> 'variable[float]': ...
def __rmul__(self, other: NumLike) -> Any: def __rmul__(self, other: NumLike) -> Any:
@ -171,18 +169,18 @@ class variable(Generic[TNum], Net):
return add_op('div', [other, self]) return add_op('div', [other, self])
@overload @overload
def __floordiv__(self, other: TCPNum) -> TCPNum: ... def __floordiv__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload @overload
def __floordiv__(self, other: unifloat) -> 'variable[float]': ... def __floordiv__(self, other: unifloat) -> 'variable[float]': ...
@overload @overload
def __floordiv__(self, other: NumLike) -> 'variable[float] | variable[int]': ... def __floordiv__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
def __floordiv__(self, other: NumLike) -> Any: @overload
def __floordiv__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __floordiv__(self, other: TVarNumb) -> Any:
return add_op('floordiv', [self, other]) return add_op('floordiv', [self, other])
@overload @overload
def __rfloordiv__(self: TCPNum, other: int) -> TCPNum: ... def __rfloordiv__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload @overload
def __rfloordiv__(self, other: float) -> 'variable[float]': ... def __rfloordiv__(self, other: float) -> 'variable[float]': ...
def __rfloordiv__(self, other: NumLike) -> Any: def __rfloordiv__(self, other: NumLike) -> Any:
@ -191,27 +189,27 @@ class variable(Generic[TNum], Net):
def __neg__(self: TCPNum) -> TCPNum: def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, add_op('sub', [variable(0), self])) return cast(TCPNum, add_op('sub', [variable(0), self]))
def __gt__(self, other: TVarNumb) -> 'variable[bool]': def __gt__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('gt', [self, other]) ret = add_op('gt', [self, other])
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
def __lt__(self, other: TVarNumb) -> 'variable[bool]': def __lt__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('gt', [other, self]) ret = add_op('gt', [other, self])
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
def __ge__(self, other: TVarNumb) -> 'variable[bool]': def __ge__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('ge', [self, other]) ret = add_op('ge', [self, other])
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
def __le__(self, other: TVarNumb) -> 'variable[bool]': def __le__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('ge', [other, self]) ret = add_op('ge', [other, self])
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
def __eq__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore def __eq__(self, other: TVarNumb) -> 'variable[int]': # type: ignore
ret = add_op('eq', [self, other], True) ret = add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
def __ne__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore def __ne__(self, other: TVarNumb) -> 'variable[int]': # type: ignore
ret = add_op('ne', [self, other], True) ret = add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool') return variable(ret.source, dtype='bool')
@ -255,34 +253,34 @@ class variable(Generic[TNum], Net):
return super().__hash__() return super().__hash__()
# Bitwise and shift operations for cp[int] # Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniboolint) -> 'variable[int]': def __lshift__(self, other: uniint) -> 'variable[int]':
return add_op('lshift', [self, other]) return add_op('lshift', [self, other])
def __rlshift__(self, other: uniboolint) -> 'variable[int]': def __rlshift__(self, other: uniint) -> 'variable[int]':
return add_op('lshift', [other, self]) return add_op('lshift', [other, self])
def __rshift__(self, other: uniboolint) -> 'variable[int]': def __rshift__(self, other: uniint) -> 'variable[int]':
return add_op('rshift', [self, other]) return add_op('rshift', [self, other])
def __rrshift__(self, other: uniboolint) -> 'variable[int]': def __rrshift__(self, other: uniint) -> 'variable[int]':
return add_op('rshift', [other, self]) return add_op('rshift', [other, self])
def __and__(self, other: uniboolint) -> 'variable[int]': def __and__(self, other: uniint) -> 'variable[int]':
return add_op('bwand', [self, other], True) return add_op('bwand', [self, other], True)
def __rand__(self, other: uniboolint) -> 'variable[int]': def __rand__(self, other: uniint) -> 'variable[int]':
return add_op('rwand', [other, self], True) return add_op('rwand', [other, self], True)
def __or__(self, other: uniboolint) -> 'variable[int]': def __or__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [self, other], True) return add_op('bwor', [self, other], True)
def __ror__(self, other: uniboolint) -> 'variable[int]': def __ror__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [other, self], True) return add_op('bwor', [other, self], True)
def __xor__(self, other: uniboolint) -> 'variable[int]': def __xor__(self, other: uniint) -> 'variable[int]':
return add_op('bwxor', [self, other], True) return add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniboolint) -> 'variable[int]': def __rxor__(self, other: uniint) -> 'variable[int]':
return add_op('bwxor', [other, self], True) return add_op('bwxor', [other, self], True)
@ -318,9 +316,7 @@ def net_from_value(value: Any) -> Net:
@overload @overload
def iif(expression: variable[Any], true_result: unibool, false_result: unibool) -> variable[bool]: ... # pyright: ignore[reportOverlappingOverload] def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ...
@overload @overload
def iif(expression: variable[Any], true_result: unifloat, false_result: unifloat) -> variable[float]: ... def iif(expression: variable[Any], true_result: unifloat, false_result: unifloat) -> variable[float]: ...
@overload @overload
@ -330,7 +326,7 @@ def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]
@overload @overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ... def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any: def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool) allowed_type = (variable, int, float)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
return (expression != 0) * true_result + (expression == 0) * false_result return (expression != 0) * true_result + (expression == 0) * false_result
@ -355,9 +351,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, bool): if isinstance(value, int):
return ('bool', int(value))
elif isinstance(value, int):
return ('int', int(value)) return ('int', int(value))
elif isinstance(value, float): elif isinstance(value, float):
return ('float', float(value)) return ('float', float(value))

View File

@ -28,7 +28,7 @@ class Target():
self.sdb = stencil_db_from_package(arch, optimization) self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = {} self._variables: dict[Net, tuple[int, int, str]] = {}
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None: def compile(self, *variables: int | float | variable[int] | variable[float] | Iterable[int | float | variable[int] | variable[float]]) -> None:
"""Compiles the code to compute the given variables. """Compiles the code to compute the given variables.
Arguments: Arguments:
@ -56,21 +56,11 @@ class Target():
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0
@overload @overload
def read_value(self, net: variable[bool]) -> bool: def read_value(self, net: variable[float]) -> float: ...
...
@overload @overload
def read_value(self, net: variable[float]) -> float: def read_value(self, net: variable[int]) -> int: ...
...
@overload @overload
def read_value(self, net: variable[int]) -> int: def read_value(self, net: NumLike) -> float | int | bool: ...
...
@overload
def read_value(self, net: NumLike) -> float | int | bool:
...
def read_value(self, net: NumLike) -> float | int | bool: def read_value(self, net: NumLike) -> float | int | bool:
"""Reads the value of a variable. """Reads the value of a variable.

View File

@ -2,7 +2,7 @@ from . import variable
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable
import copapy as cp import copapy as cp
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | variable[bool] | int | float | bool' VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int' VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float' VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float) T = TypeVar("T", int, float)