Changed cpfloat, cpint etc. to generic variable[float] etc.

This commit is contained in:
Nicolas Kruse 2025-10-23 17:23:12 +02:00
parent 0d1d0a03c9
commit f61591a6ca
15 changed files with 125 additions and 188 deletions

View File

@ -1,15 +1,12 @@
from ._target import Target
from ._basic_types import NumLike, cpbool, cpfloat, cpint, \
CPNumber, cpvalue, cpvector, generic_sdb, iif
from ._basic_types import NumLike, variable, \
CPNumber, cpvector, generic_sdb, iif
__all__ = [
"Target",
"NumLike",
"cpbool",
"cpfloat",
"cpint",
"variable",
"CPNumber",
"cpvalue",
"cpvector",
"generic_sdb",
"iif",

View File

@ -1,15 +1,15 @@
import pkgutil
from typing import Any, TypeVar, overload, TypeAlias
from typing import Any, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database
import platform
NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool'
NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net'
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
NetAndNum: TypeAlias = 'Net | int | float'
unifloat: TypeAlias = 'cpfloat | float'
uniint: TypeAlias = 'cpint | int'
unibool: TypeAlias = 'cpbool | bool'
unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool'
TNumber = TypeVar("TNumber", bound='CPNumber')
T = TypeVar("T")
@ -40,7 +40,7 @@ class Node:
self.name: str = ''
def __repr__(self) -> str:
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, InitVar) else '')})"
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
class Device():
@ -70,7 +70,7 @@ class CPNumber(Net):
...
@overload
def __mul__(self, other: unifloat) -> 'cpfloat':
def __mul__(self, other: unifloat) -> 'variable[float]':
...
def __mul__(self, other: NumLike) -> 'CPNumber':
@ -81,7 +81,7 @@ class CPNumber(Net):
...
@overload
def __rmul__(self, other: unifloat) -> 'cpfloat':
def __rmul__(self, other: unifloat) -> 'variable[float]':
...
def __rmul__(self, other: NumLike) -> 'CPNumber':
@ -92,7 +92,7 @@ class CPNumber(Net):
...
@overload
def __add__(self, other: unifloat) -> 'cpfloat':
def __add__(self, other: unifloat) -> 'variable[float]':
...
def __add__(self, other: NumLike) -> 'CPNumber':
@ -103,7 +103,7 @@ class CPNumber(Net):
...
@overload
def __radd__(self, other: unifloat) -> 'cpfloat':
def __radd__(self, other: unifloat) -> 'variable[float]':
...
def __radd__(self, other: NumLike) -> 'CPNumber':
@ -114,7 +114,7 @@ class CPNumber(Net):
...
@overload
def __sub__(self, other: unifloat) -> 'cpfloat':
def __sub__(self, other: unifloat) -> 'variable[float]':
...
def __sub__(self, other: NumLike) -> 'CPNumber':
@ -125,28 +125,24 @@ class CPNumber(Net):
...
@overload
def __rsub__(self, other: unifloat) -> 'cpfloat':
def __rsub__(self, other: unifloat) -> 'variable[float]':
...
def __rsub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [other, self])
def __truediv__(self, other: NumLike) -> 'cpfloat':
ret = _add_op('div', [self, other])
assert isinstance(ret, cpfloat)
return ret
def __truediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [self, other])
def __rtruediv__(self, other: NumLike) -> 'cpfloat':
ret = _add_op('div', [other, self])
assert isinstance(ret, cpfloat)
return ret
def __rtruediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [other, self])
@overload
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __floordiv__(self, other: unifloat) -> 'cpfloat':
def __floordiv__(self, other: unifloat) -> 'variable[float]':
...
def __floordiv__(self, other: NumLike) -> 'CPNumber':
@ -157,39 +153,38 @@ class CPNumber(Net):
...
@overload
def __rfloordiv__(self, other: unifloat) -> 'cpfloat':
def __rfloordiv__(self, other: unifloat) -> 'variable[float]':
...
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self])
def __neg__(self: TNumber) -> TNumber:
ret = _add_op('sub', [cpvalue(0), self])
assert isinstance(ret, type(self))
return ret
assert isinstance(T, variable)
return cast(TNumber, _add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'cpbool':
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [self, other])
return cpbool(ret.source)
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'cpbool':
def __lt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [other, self])
return cpbool(ret.source)
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('eq', [self, other], True)
return cpbool(ret.source)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('ne', [self, other], True)
return cpbool(ret.source)
return variable(ret.source, dtype='bool')
@overload
def __mod__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __mod__(self, other: unifloat) -> 'cpfloat':
def __mod__(self, other: unifloat) -> 'variable[float]':
...
def __mod__(self, other: NumLike) -> 'CPNumber':
@ -200,7 +195,7 @@ class CPNumber(Net):
...
@overload
def __rmod__(self, other: unifloat) -> 'cpfloat':
def __rmod__(self, other: unifloat) -> 'variable[float]':
...
def __rmod__(self, other: NumLike) -> 'CPNumber':
@ -211,7 +206,7 @@ class CPNumber(Net):
...
@overload
def __pow__(self, other: unifloat) -> 'cpfloat':
def __pow__(self, other: unifloat) -> 'variable[float]':
...
def __pow__(self, other: NumLike) -> 'CPNumber':
@ -222,7 +217,7 @@ class CPNumber(Net):
...
@overload
def __rpow__(self, other: unifloat) -> 'cpfloat':
def __rpow__(self, other: unifloat) -> 'variable[float]':
...
def __rpow__(self, other: NumLike) -> 'CPNumber':
@ -232,83 +227,54 @@ class CPNumber(Net):
return super().__hash__()
class cpint(CPNumber):
def __init__(self, source: int | Node):
class variable(Generic[T], CPNumber):
def __init__(self, source: T | Node, dtype: str | None = None):
if isinstance(source, Node):
self.source = source
else:
self.source = InitVar(int(source))
self.dtype = 'int'
def __lshift__(self, other: uniint) -> 'cpint':
ret = _add_op('lshift', [self, other])
assert isinstance(ret, cpint)
return ret
def __rlshift__(self, other: uniint) -> 'cpint':
ret = _add_op('lshift', [other, self])
assert isinstance(ret, cpint)
return ret
def __rshift__(self, other: uniint) -> 'cpint':
ret = _add_op('rshift', [self, other])
assert isinstance(ret, cpint)
return ret
def __rrshift__(self, other: uniint) -> 'cpint':
ret = _add_op('rshift', [other, self])
assert isinstance(ret, cpint)
return ret
def __and__(self, other: uniint) -> 'cpint':
ret = _add_op('bwand', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __rand__(self, other: uniint) -> 'cpint':
ret = _add_op('rwand', [other, self], True)
assert isinstance(ret, cpint)
return ret
def __or__(self, other: uniint) -> 'cpint':
ret = _add_op('bwor', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __ror__(self, other: uniint) -> 'cpint':
ret = _add_op('bwor', [other, self], True)
assert isinstance(ret, cpint)
return ret
def __xor__(self, other: uniint) -> 'cpint':
ret = _add_op('bwxor', [self, other], True)
assert isinstance(ret, cpint)
return ret
def __rxor__(self, other: uniint) -> 'cpint':
ret = _add_op('bwxor', [other, self], True)
assert isinstance(ret, cpint)
return ret
class cpfloat(CPNumber):
def __init__(self, source: float | Node | CPNumber):
if isinstance(source, Node):
self.source = source
elif isinstance(source, CPNumber):
self.source = _add_op('cast_float', [source]).source
else:
self.source = InitVar(float(source))
self.dtype = 'float'
class cpbool(cpint):
def __init__(self, source: bool | Node):
if isinstance(source, Node):
self.source = source
else:
self.source = InitVar(bool(source))
assert dtype, 'For source type Node a dtype argument is required.'
self.dtype = dtype
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
elif isinstance(source, int):
self.source = CPConstant(source)
self.dtype = 'int'
elif isinstance(source, float):
self.source = CPConstant(source)
self.dtype = 'float'
else:
raise ValueError(f'Non supported data type: {type(source).__name__}')
# Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [self, other])
def __rlshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [other, self])
def __rshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [self, other])
def __rrshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [other, self])
def __and__(self, other: uniint) -> 'variable[int]':
return _add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]':
return _add_op('rwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [self, other], True)
def __ror__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [other, self], True)
def __xor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [other, self], True)
class cpvector:
@ -321,7 +287,7 @@ class cpvector:
return cpvector(*(v for v in tup if isinstance(v, CPNumber)))
class InitVar(Node):
class CPConstant(Node):
def __init__(self, value: int | float):
self.dtype, self.value = _get_data_and_dtype(value)
self.name = 'const_' + self.dtype
@ -333,7 +299,7 @@ class Write(Node):
if isinstance(input, Net):
net = input
else:
node = InitVar(input)
node = CPConstant(input)
net = Net(node.dtype, node)
self.name = 'write_' + transl_type(net.dtype)
@ -348,22 +314,22 @@ class Op(Node):
def net_from_value(value: Any) -> Net:
vi = InitVar(value)
vi = CPConstant(value)
return Net(vi.dtype, vi)
@overload
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload]
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> variable[bool]: # pyright: ignore[reportOverlappingOverload]
...
@overload
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint:
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> variable[int]:
...
@overload
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat:
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> variable[float]:
...
@ -373,16 +339,15 @@ def iif(expression: NumLike, true_result: T, false_result: T) -> T:
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
# TODO: check that input types are matching
alowed_type = cpint | cpfloat | cpbool | int | float | bool
assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported"
allowed_type = (variable, int, float, bool)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
if isinstance(expression, CPNumber):
return (expression != 0) * true_result + (expression == 0) * false_result
else:
return true_result if expression else false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber:
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative:
@ -396,35 +361,9 @@ def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = Fal
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'int':
return cpint(Op(typed_op, arg_nets))
return variable[int](Op(typed_op, arg_nets), result_type)
else:
return cpfloat(Op(typed_op, arg_nets))
@overload
def cpvalue(value: bool) -> cpbool: # pyright: ignore[reportOverlappingOverload]
...
@overload
def cpvalue(value: int) -> cpint:
...
@overload
def cpvalue(value: float) -> cpfloat:
...
def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat:
vi = InitVar(value)
if isinstance(value, bool):
return cpbool(vi)
elif isinstance(value, float):
return cpfloat(vi)
else:
return cpint(vi)
return variable[float](Op(typed_op, arg_nets), result_type)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:

View File

@ -2,7 +2,7 @@ from typing import Generator, Iterable, Any
from . import _binwrite as binw
from ._stencils import stencil_database, patch_entry
from collections import defaultdict, deque
from ._basic_types import Net, Node, Write, InitVar, Op, transl_type
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
@ -76,7 +76,7 @@ def get_const_nets(nodes: list[Node]) -> list[Net]:
List of nets whose source node is a Const
"""
net_lookup = {net.source: net for node in nodes for net in node.args}
return [net_lookup[node] for node in nodes if isinstance(node, InitVar)]
return [net_lookup[node] for node in nodes if isinstance(node, CPConstant)]
def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]:
@ -97,7 +97,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
net_lookup = {net.source: net for node in node_list for net in node.args}
for node in node_list:
if not isinstance(node, InitVar):
if not isinstance(node, CPConstant):
for i, net in enumerate(node.args):
if id(net) != id(registers[i]):
#if net in registers:
@ -230,7 +230,7 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
# Heap variables
for net, out_offs, lengths in variable_mem_layout:
variables[net] = (out_offs, lengths, net.dtype)
if isinstance(net.source, InitVar):
if isinstance(net.source, CPConstant):
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs)
dw.write_int(lengths)

View File

@ -3,7 +3,7 @@ from . import _binwrite as binw
from coparun_module import coparun, read_data_mem
import struct
from ._basic_types import stencil_db_from_package
from ._basic_types import cpbool, cpint, cpfloat, Net, Node, Write, NumLike
from ._basic_types import variable, Net, Node, Write, NumLike
from ._compiler import compile_to_instruction_list
@ -20,7 +20,7 @@ class Target():
self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = dict()
def compile(self, *variables: int | float | cpint | cpfloat | cpbool | Iterable[int | float | cpint | cpfloat | cpbool]) -> None:
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
nodes: list[Node] = []
for s in variables:
if isinstance(s, Iterable):
@ -42,15 +42,15 @@ class Target():
assert coparun(dw.get_data()) > 0
@overload
def read_value(self, net: cpbool) -> bool:
def read_value(self, net: variable[bool]) -> bool:
...
@overload
def read_value(self, net: cpfloat) -> float:
def read_value(self, net: variable[float]) -> float:
...
@overload
def read_value(self, net: cpint) -> int:
def read_value(self, net: variable[int]) -> int:
...
@overload
@ -61,6 +61,7 @@ class Target():
assert isinstance(net, Net), "Variable must be a copapy variable object"
assert net in self._variables, f"Variable {net} not found"
addr, lengths, var_type = self._variables[net]
print('...', self._variables[net], net.dtype)
assert lengths > 0
data = read_data_mem(addr, lengths)
assert data is not None and len(data) == lengths, f"Failed to read variable {net}"

View File

@ -1,5 +1,5 @@
from ._target import add_read_command
from ._basic_types import Net, Op, Node, InitVar, Write
from ._basic_types import Net, Op, Node, CPConstant, Write
from ._compiler import compile_to_instruction_list, \
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \
add_write_ops
@ -9,7 +9,7 @@ __all__ = [
"Net",
"Op",
"Node",
"InitVar",
"CPConstant",
"Write",
"compile_to_instruction_list",
"stable_toposort",

View File

@ -1,4 +1,4 @@
from copapy import cpvalue
from copapy import variable
from copapy.backend import Write
import copapy.backend as cpbe
@ -21,8 +21,8 @@ def test_ast_generation():
#r2 = i1 + 9
#out = [Write(r1), Write(r2)]
c1 = cpvalue(4)
c2 = cpvalue(2)
c1 = variable(4)
c2 = variable(2)
#i1 = c1 * 2
#r1 = i1 + 7 + (c2 + 7 * 9)
#r2 = i1 + 9

View File

@ -1,4 +1,4 @@
from copapy import cpvalue, NumLike
from copapy import variable, NumLike
from copapy.backend import Write, compile_to_instruction_list, add_read_command
import copapy
import subprocess
@ -41,8 +41,8 @@ def function(c1: NumLike, c2: NumLike) -> tuple[NumLike, ...]:
def test_compile():
c1 = cpvalue(4)
c2 = cpvalue(2)
c1 = variable(4)
c2 = variable(2)
ret = function(c1, c2)
#ret = [c1 // 3.3 + 5]

View File

@ -1,4 +1,4 @@
from copapy import cpvalue, NumLike
from copapy import variable, NumLike
from copapy.backend import Write, compile_to_instruction_list
import copapy
import subprocess
@ -20,7 +20,7 @@ def function(c1: NumLike) -> list[NumLike]:
def test_compile():
c1 = cpvalue(16)
c1 = variable(16)
ret = function(c1)

View File

@ -1,4 +1,4 @@
from copapy import cpvalue, Target, NumLike
from copapy import variable, Target, NumLike
import pytest
@ -14,7 +14,7 @@ def function(c1: NumLike) -> list[NumLike]:
def test_compile():
c1 = cpvalue(16)
c1 = variable(16)
ret = function(c1)

View File

@ -1,5 +1,5 @@
from coparun_module import coparun
from copapy import cpvalue
from copapy import variable
from copapy.backend import Write, compile_to_instruction_list, add_read_command
import copapy
from copapy import _binwrite
@ -7,8 +7,8 @@ from copapy import _binwrite
def test_compile():
c1 = cpvalue(4)
c2 = cpvalue(2) * 4
c1 = variable(4)
c2 = variable(2) * 4
i1 = c2 * 2
r1 = i1 + 7 + (c1 + 7 * 9)

View File

@ -1,4 +1,4 @@
from copapy import NumLike, cpvalue
from copapy import NumLike, variable
from copapy.backend import Write, Net, compile_to_instruction_list, add_read_command
import copapy
import subprocess
@ -24,8 +24,8 @@ def function(c1: NumLike, c2: NumLike) -> tuple[NumLike, ...]:
def test_compile():
c1 = cpvalue(4)
c2 = cpvalue(2)
c1 = variable(4)
c2 = variable(2)
ret = function(c1, c2)

View File

@ -1,12 +1,12 @@
from copapy import cpvalue, Target
from copapy import variable, Target
import pytest
import copapy
def test_compile():
c_i = cpvalue(9)
c_f = cpvalue(2.5)
# c_b = cpvalue(True)
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i)
ret_ref = (2.5 ** 2.5, 9 ** 9)

View File

@ -1,4 +1,4 @@
from copapy import cpvalue, Target, NumLike, iif, cpint
from copapy import variable, Target, NumLike, iif
import pytest
import copapy
@ -42,11 +42,11 @@ def iiftests(c1: NumLike) -> list[NumLike]:
def test_compile():
c_i = cpvalue(9)
c_f = cpvalue(1.111)
c_b = cpvalue(True)
c_i = variable(9)
c_f = variable(1.111)
c_b = variable(True)
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [cpint(9) % 2] + iiftests(c_i) + iiftests(c_f)
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [variable(9) % 2] + iiftests(c_i) + iiftests(c_f)
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111)
tg = Target()

View File

@ -1,12 +1,12 @@
from copapy import cpvalue, Target, iif
from copapy import variable, Target, iif
import pytest
import copapy
def test_compile():
c_i = cpvalue(9)
c_f = cpvalue(2.5)
# c_b = cpvalue(True)
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (iif(c_f > 5, c_f, -1), iif(c_i > 5, c_f, 8.8), iif(c_i > 2, c_i, 1))
ret_ref = (iif(2.5 > 5, 2.5, -1), iif(9 > 5, 2.5, 8.8), iif(9 > 2, 9, 1))

View File

@ -1,11 +1,11 @@
from copapy import _binwrite, cpvalue
from copapy import _binwrite, variable
from copapy.backend import Write, compile_to_instruction_list
import copapy
def test_compile() -> None:
c1 = cpvalue(9)
c1 = variable(9)
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
ret = [c1 // 3.3 + 5]