typing fixed, variable[bool] replaced by variable[int]

This commit is contained in:
Nicolas Kruse 2025-11-27 12:50:53 +01:00
parent 44b215f728
commit 99a880861a
3 changed files with 51 additions and 67 deletions

View File

@ -3,15 +3,13 @@ from typing import Any, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database, detect_process_arch
import copapy as cp
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
NumLike: TypeAlias = 'variable[int] | variable[float] | int | float'
unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool'
uniboolint: TypeAlias = 'variable[bool] | bool | variable[int] | int'
TCPNum = TypeVar("TCPNum", bound='variable[Any]')
TNum = TypeVar("TNum", int, float, bool)
TVarNumb: TypeAlias = 'variable[Any] | int | float | bool'
TNum = TypeVar("TNum", int, float)
TVarNumb: TypeAlias = 'variable[Any] | int | float'
stencil_cache: dict[tuple[str, str], stencil_database] = {}
@ -107,20 +105,20 @@ class variable(Generic[TNum], Net):
self.dtype = 'int'
@overload
def __add__(self, other: TCPNum) -> TCPNum: ...
@overload
def __add__(self: TCPNum, other: uniint) -> TCPNum: ...
def __add__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __add__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __add__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __add__(self, other: NumLike) -> Any:
def __add__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
@overload
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __add__(self, other: TVarNumb) -> Any:
if isinstance(other, int | float) and other == 0:
return self
return add_op('add', [self, other], True)
@overload
def __radd__(self: TCPNum, other: int) -> TCPNum: ...
def __radd__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload
def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any:
@ -129,36 +127,36 @@ class variable(Generic[TNum], Net):
return add_op('add', [self, other], True)
@overload
def __sub__(self, other: TCPNum) -> TCPNum: ...
@overload
def __sub__(self: TCPNum, other: uniint) -> TCPNum: ...
def __sub__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __sub__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __sub__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: NumLike) -> Any:
def __sub__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
@overload
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: TVarNumb) -> Any:
return add_op('sub', [self, other])
@overload
def __rsub__(self: TCPNum, other: int) -> TCPNum: ...
def __rsub__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload
def __rsub__(self, other: float) -> 'variable[float]': ...
def __rsub__(self, other: NumLike) -> Any:
return add_op('sub', [other, self])
@overload
def __mul__(self, other: TCPNum) -> TCPNum: ...
@overload
def __mul__(self: TCPNum, other: uniint) -> TCPNum: ...
def __mul__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __mul__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __mul__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __mul__(self, other: NumLike) -> Any:
def __mul__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
@overload
def __mul__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __mul__(self, other: TVarNumb) -> Any:
return add_op('mul', [self, other], True)
@overload
def __rmul__(self: TCPNum, other: int) -> TCPNum: ...
def __rmul__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload
def __rmul__(self, other: float) -> 'variable[float]': ...
def __rmul__(self, other: NumLike) -> Any:
@ -171,18 +169,18 @@ class variable(Generic[TNum], Net):
return add_op('div', [other, self])
@overload
def __floordiv__(self, other: TCPNum) -> TCPNum: ...
@overload
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum: ...
def __floordiv__(self: 'variable[int]', other: uniint) -> 'variable[int]': ...
@overload
def __floordiv__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __floordiv__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __floordiv__(self, other: NumLike) -> Any:
def __floordiv__(self: 'variable[float]', other: NumLike) -> 'variable[float]': ...
@overload
def __floordiv__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __floordiv__(self, other: TVarNumb) -> Any:
return add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TCPNum, other: int) -> TCPNum: ...
def __rfloordiv__(self: 'variable[int]', other: int) -> 'variable[int]': ...
@overload
def __rfloordiv__(self, other: float) -> 'variable[float]': ...
def __rfloordiv__(self, other: NumLike) -> Any:
@ -191,27 +189,27 @@ class variable(Generic[TNum], Net):
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, add_op('sub', [variable(0), self]))
def __gt__(self, other: TVarNumb) -> 'variable[bool]':
def __gt__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: TVarNumb) -> 'variable[bool]':
def __lt__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __ge__(self, other: TVarNumb) -> 'variable[bool]':
def __ge__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('ge', [self, other])
return variable(ret.source, dtype='bool')
def __le__(self, other: TVarNumb) -> 'variable[bool]':
def __le__(self, other: TVarNumb) -> 'variable[int]':
ret = add_op('ge', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore
def __eq__(self, other: TVarNumb) -> 'variable[int]': # type: ignore
ret = add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: TVarNumb) -> 'variable[bool]': # type: ignore
def __ne__(self, other: TVarNumb) -> 'variable[int]': # type: ignore
ret = add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@ -255,34 +253,34 @@ class variable(Generic[TNum], Net):
return super().__hash__()
# Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniboolint) -> 'variable[int]':
def __lshift__(self, other: uniint) -> 'variable[int]':
return add_op('lshift', [self, other])
def __rlshift__(self, other: uniboolint) -> 'variable[int]':
def __rlshift__(self, other: uniint) -> 'variable[int]':
return add_op('lshift', [other, self])
def __rshift__(self, other: uniboolint) -> 'variable[int]':
def __rshift__(self, other: uniint) -> 'variable[int]':
return add_op('rshift', [self, other])
def __rrshift__(self, other: uniboolint) -> 'variable[int]':
def __rrshift__(self, other: uniint) -> 'variable[int]':
return add_op('rshift', [other, self])
def __and__(self, other: uniboolint) -> 'variable[int]':
def __and__(self, other: uniint) -> 'variable[int]':
return add_op('bwand', [self, other], True)
def __rand__(self, other: uniboolint) -> 'variable[int]':
def __rand__(self, other: uniint) -> 'variable[int]':
return add_op('rwand', [other, self], True)
def __or__(self, other: uniboolint) -> 'variable[int]':
def __or__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [self, other], True)
def __ror__(self, other: uniboolint) -> 'variable[int]':
def __ror__(self, other: uniint) -> 'variable[int]':
return add_op('bwor', [other, self], True)
def __xor__(self, other: uniboolint) -> 'variable[int]':
def __xor__(self, other: uniint) -> 'variable[int]':
return add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniboolint) -> 'variable[int]':
def __rxor__(self, other: uniint) -> 'variable[int]':
return add_op('bwxor', [other, self], True)
@ -318,9 +316,7 @@ def net_from_value(value: Any) -> Net:
@overload
def iif(expression: variable[Any], true_result: unibool, false_result: unibool) -> variable[bool]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ...
def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def iif(expression: variable[Any], true_result: unifloat, false_result: unifloat) -> variable[float]: ...
@overload
@ -330,7 +326,7 @@ def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]
@overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool)
allowed_type = (variable, int, float)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
return (expression != 0) * true_result + (expression == 0) * false_result
@ -355,9 +351,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, bool):
return ('bool', int(value))
elif isinstance(value, int):
if isinstance(value, int):
return ('int', int(value))
elif isinstance(value, float):
return ('float', float(value))

View File

@ -28,7 +28,7 @@ class Target():
self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = {}
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
def compile(self, *variables: int | float | variable[int] | variable[float] | Iterable[int | float | variable[int] | variable[float]]) -> None:
"""Compiles the code to compute the given variables.
Arguments:
@ -56,21 +56,11 @@ class Target():
assert coparun(dw.get_data()) > 0
@overload
def read_value(self, net: variable[bool]) -> bool:
...
def read_value(self, net: variable[float]) -> float: ...
@overload
def read_value(self, net: variable[float]) -> float:
...
def read_value(self, net: variable[int]) -> int: ...
@overload
def read_value(self, net: variable[int]) -> int:
...
@overload
def read_value(self, net: NumLike) -> float | int | bool:
...
def read_value(self, net: NumLike) -> float | int | bool: ...
def read_value(self, net: NumLike) -> float | int | bool:
"""Reads the value of a variable.

View File

@ -2,7 +2,7 @@ from . import variable
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias, Callable
import copapy as cp
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | variable[bool] | int | float | bool'
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float | bool'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float)