Net and value types separated

This commit is contained in:
Nicolas 2025-12-23 17:54:57 +01:00
parent df22438ffe
commit 6dcaa6797c
16 changed files with 152 additions and 135 deletions

View File

@ -35,23 +35,23 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
assert isinstance(y, Sequence) or isinstance(y, vector)
y_set = {v for v in y}
edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net)))
edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value)))
ordered_ops = cpb.stable_toposort(edges)
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
grad_dict: dict[Net, unifloat] = dict()
def add_grad(val: value[Any], gradient_value: unifloat) -> None:
grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value
grad_dict[val.net] = grad_dict.get(val.net, 0.0) + gradient_value
for node in reversed(ordered_ops):
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
if node.args:
args: Sequence[Any] = list(node.args)
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
args: Sequence[Net] = list(node.args)
g = 1.0 if node is x.net.source else grad_dict[net_lookup[node]]
opn = node.name.split('_')[0]
a: value[Any] = args[0]
b: value[Any] = args[1] if len(args) > 1 else a
a: value[float] = value(args[0])
b: value[float] = value(args[1]) if len(args) > 1 else a
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
pass # Derivative is 0 for all ops returning integers
@ -119,9 +119,9 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
if isinstance(y, value):
return grad_dict[y]
return grad_dict[y.net]
if isinstance(y, vector):
return vector(grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in y)
return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y)
if isinstance(y, matrix):
return matrix((grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in row) for row in y)
return [grad_dict[yi] for yi in y]
return matrix((grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in row) for row in y)
return [grad_dict[yi.net] for yi in y]

View File

@ -84,34 +84,50 @@ class Net:
def __hash__(self) -> int:
return self.source.node_hash
def __eq__(self, value: object) -> bool:
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
class value(Generic[TNum], Net):
class value(Generic[TNum]):
"""A "value" represents a typed scalar variable. It supports arithmetic and
comparison operations.
Attributes:
dtype (str): Data type of this value.
"""
def __init__(self, source: TNum | Node, dtype: str | None = None):
def __init__(self, source: TNum | Net, dtype: str | None = None):
"""Instance a value.
Arguments:
source: A numeric value or Node object.
dtype: Data type of this value. Required if source is a Node.
dtype: Data type of this value.
net: Reference to the underlying Net in the graph
"""
if isinstance(source, Node):
self.source = source
assert dtype, 'For source type Node a dtype argument is required.'
if isinstance(source, Net):
self.net: Net = source
if dtype:
assert transl_type(dtype) == source.dtype, f"Type of Net ({source.dtype}) does not match {dtype}"
self.dtype: str = dtype
else:
self.dtype = source.dtype
elif dtype == 'int' or dtype == 'bool':
new_node = CPConstant(int(source), False)
self.net = Net(new_node.dtype, new_node)
self.dtype = dtype
elif isinstance(source, float):
self.source = CPConstant(source, False)
self.dtype = 'float'
elif isinstance(source, bool):
self.source = CPConstant(source, False)
elif dtype == 'float':
new_node = CPConstant(float(source), False)
self.net = Net(new_node.dtype, new_node)
self.dtype = dtype
elif dtype is None:
if isinstance(source, bool):
new_node = CPConstant(source, False)
self.net = Net(new_node.dtype, new_node)
self.dtype = 'bool'
else:
self.source = CPConstant(source, False)
self.dtype = 'int'
new_node = CPConstant(source, False)
self.net = Net(new_node.dtype, new_node)
self.dtype = new_node.dtype
else:
raise ValueError('Unknown type: {dtype}')
@overload
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
@ -227,28 +243,22 @@ class value(Generic[TNum], Net):
return cast(TCPNum, add_op('sub', [value(0), self]))
def __gt__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('gt', [self, other])
return value(ret.source, dtype='bool')
return add_op('gt', [self, other], dtype='bool')
def __lt__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('gt', [other, self])
return value(ret.source, dtype='bool')
return add_op('gt', [other, self], dtype='bool')
def __ge__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('ge', [self, other])
return value(ret.source, dtype='bool')
return add_op('ge', [self, other], dtype='bool')
def __le__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('ge', [other, self])
return value(ret.source, dtype='bool')
return add_op('ge', [other, self], dtype='bool')
def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore
ret = add_op('eq', [self, other], True)
return value(ret.source, dtype='bool')
return add_op('eq', [self, other], True, dtype='bool')
def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore
ret = add_op('ne', [self, other], True)
return value(ret.source, dtype='bool')
return add_op('ne', [self, other], True, dtype='bool')
@overload
def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
@ -295,7 +305,7 @@ class value(Generic[TNum], Net):
return cp.pow(other, self)
def __hash__(self) -> int:
return super().__hash__()
return id(self)
# Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'value[int]':
@ -330,16 +340,26 @@ class value(Generic[TNum], Net):
class CPConstant(Node):
def __init__(self, value: int | float, anonymous: bool = True):
self.dtype, self.value = _get_data_and_dtype(value)
def __init__(self, value: Any, anonymous: bool = True):
if isinstance(value, int):
self.value: int | float = value
self.dtype = 'int'
elif isinstance(value, float):
self.value = value
self.dtype = 'float'
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')
self.name = 'const_' + self.dtype
self.args = tuple()
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
class Write(Node):
def __init__(self, input: Net | int | float):
if isinstance(input, Net):
def __init__(self, input: value[Any] | Net | int | float):
if isinstance(input, value):
net = input.net
elif isinstance(input, Net):
net = input
else:
node = CPConstant(input)
@ -352,15 +372,16 @@ class Write(Node):
class Op(Node):
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name: str = typed_op_name
self.args: tuple[Net, ...] = tuple(args)
self.node_hash = self.get_node_hash(commutative)
def net_from_value(val: Any) -> value[Any]:
vi = CPConstant(val)
return value(vi, vi.dtype)
def value_from_number(val: Any) -> value[Any]:
# Create anonymous constant that can be removed during optimization
new_node = CPConstant(val)
new_net = Net(new_node.dtype, new_node)
return value(new_net)
@overload
@ -392,29 +413,22 @@ def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
return (expression != 0) * true_result + (expression == 0) * false_result
def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False) -> value[Any]:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False, dtype: str | None = None) -> value[Any]:
arg_values = [a if isinstance(a, value) else value_from_number(a) for a in args]
if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
arg_values = sorted(arg_values, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_values])
if typed_op not in generic_sdb.stencil_definitions:
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}")
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_values])}")
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'float':
return value[float](Op(typed_op, arg_nets, commutative), result_type)
else:
return value[int](Op(typed_op, arg_nets, commutative), result_type)
result_net = Net(result_type, Op(typed_op, [av.net for av in arg_values], commutative))
if dtype:
result_type = dtype
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, int):
return ('int', int(value))
elif isinstance(value, float):
return ('float', float(value))
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')
return value(result_net, result_type)

View File

@ -221,6 +221,8 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
for net in el:
if isinstance(net, Net):
nets.add(net)
else:
assert net is None or isinstance(net, Node), net
return list(nets)
@ -351,7 +353,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_com(binw.Command.FREE_MEMORY)
# Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
variable_list = get_nets([const_net_list], extended_output_ops)
stencil_names = {node.name for _, node in extended_output_ops}
aux_function_names = sdb.get_sub_functions(stencil_names)
@ -378,7 +380,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_int(start)
dw.write_int(lengths)
dw.write_value(net.source.value, lengths)
#print(f'+ {net.dtype} {net.source.value}')
print(f'+ {net.dtype} {net.source.value}')
# prep auxiliary_functions
code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb)

View File

@ -66,21 +66,21 @@ class Target():
def __del__(self) -> None:
clear_target(self._context)
def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None:
def compile(self, *values: int | float | value[Any] | Iterable[int | float | value[Any]]) -> None:
"""Compiles the code to compute the given values.
Arguments:
values: Values to compute
"""
nodes: list[Node] = []
for s in values:
if isinstance(s, Iterable):
for net in s:
if isinstance(net, Net):
nodes.append(Write(net))
for input in values:
if isinstance(input, Iterable):
for v in input:
if isinstance(v, value):
nodes.append(Write(v))
else:
if isinstance(s, Net):
nodes.append(Write(s))
if isinstance(input, value):
nodes.append(Write(input))
dw, self._values = compile_to_dag(nodes, self.sdb)
dw.write_com(binw.Command.END_COM)
@ -95,32 +95,33 @@ class Target():
assert coparun(self._context, dw.get_data()) > 0
@overload
def read_value(self, net: value[T]) -> T: ...
def read_value(self, variables: value[T]) -> T: ...
@overload
def read_value(self, net: NumLike) -> float | int | bool: ...
def read_value(self, variables: NumLike) -> float | int | bool: ...
@overload
def read_value(self, net: Iterable[T | value[T]]) -> list[T]: ...
def read_value(self, net: NumLike | value[T] | Iterable[T | value[T]]) -> Any:
def read_value(self, variables: Iterable[T | value[T]]) -> list[T]: ...
def read_value(self, variables: NumLike | value[T] | Iterable[T | value[T]]) -> Any:
"""Reads the numeric value of a copapy type.
Arguments:
net: Value or multiple Values to read
variables: Variable or multiple variables to read
Returns:
Numeric value or values
"""
if isinstance(net, Iterable):
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net]
if isinstance(variables, Iterable):
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in variables]
if isinstance(net, float | int):
return net
if isinstance(variables, float | int):
return variables
assert isinstance(net, Net), "Argument must be a copapy value"
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net]
assert isinstance(variables, value), "Argument must be a copapy value"
assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
addr, lengths, _ = self._values[variables.net]
var_type = variables.dtype
assert lengths > 0
data = read_data_mem(self._context, addr, lengths)
assert data is not None and len(data) == lengths, f"Failed to read value {net}"
assert data is not None and len(data) == lengths, f"Failed to read value {variables}"
en = {'little': '<', 'big': '>'}[self.sdb.byteorder]
if var_type == 'float':
if lengths == 4:
@ -142,24 +143,24 @@ class Target():
else:
raise ValueError(f"Unsupported value type: {var_type}")
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
def write_value(self, variables: value[Any] | Iterable[value[Any]], data: int | float | Iterable[int | float]) -> None:
"""Write to a copapy value on the target.
Arguments:
net: Singe variable or multiple variables to overwrite
variables: Singe variable or multiple variables to overwrite
value: Singe value or multiple values to write
"""
if isinstance(net, Iterable):
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"
for ni, vi in zip(net, value):
if isinstance(variables, Iterable):
assert isinstance(data, Iterable), "If net is iterable, value must be iterable too"
for ni, vi in zip(variables, data):
self.write_value(ni, vi)
return
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable"
assert not isinstance(data, Iterable), "If net is not iterable, value must not be iterable"
assert isinstance(net, Net), "Argument must be a copapy value"
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net]
assert isinstance(variables, value), "Argument must be a copapy value"
assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[variables.net]
assert lengths > 0
dw = binw.data_writer(self.sdb.byteorder)
@ -168,17 +169,17 @@ class Target():
dw.write_int(lengths)
if var_type == 'float':
dw.write_value(float(value), lengths)
dw.write_value(float(data), lengths)
elif var_type == 'int' or var_type == 'bool':
dw.write_value(int(value), lengths)
dw.write_value(int(data), lengths)
else:
raise ValueError(f"Unsupported value type: {var_type}")
dw.write_com(binw.Command.END_COM)
assert coparun(self._context, dw.get_data()) > 0
def read_value_remote(self, net: Net) -> None:
def read_value_remote(self, variable: value[Any]) -> None:
"""Reads the raw data of a value by the runner."""
dw = binw.data_writer(self.sdb.byteorder)
add_read_command(dw, self._values, net)
add_read_command(dw, self._values, variable.net)
assert coparun(self._context, dw.get_data()) > 0

View File

@ -30,9 +30,9 @@ def test_compile():
il.write_com(_binwrite.Command.RUN_PROG)
#il.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test:
assert isinstance(net, copapy.backend.Net)
add_read_command(il, variables, net)
for v in ret_test:
assert isinstance(v, value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -70,7 +70,7 @@ def test_timing_compiler():
# Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
variable_list = get_nets([const_net_list], extended_output_ops)
stencil_names = {node.name for _, node in extended_output_ops}
print(f'-- get_sub_functions: {len(stencil_names)}')

View File

@ -65,9 +65,9 @@ def test_compile():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, copapy.backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, cp.value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -60,9 +60,9 @@ def test_compile():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, cp.value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -61,9 +61,9 @@ def test_compile():
il.write_com(_binwrite.Command.RUN_PROG)
#il.write_com(_binwrite.Command.DUMP_CODE)
for net in ret:
assert isinstance(net, backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, cp.value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -33,9 +33,9 @@ def test_compile():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, Net)
add_read_command(il, vars, net)
for v in ret:
assert isinstance(v, value)
add_read_command(il, vars, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -28,9 +28,9 @@ def test_compile_sqrt():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, copapy.backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)
@ -62,9 +62,9 @@ def test_compile_log():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, copapy.backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)
@ -96,9 +96,9 @@ def test_compile_sin():
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret:
assert isinstance(net, copapy.backend.Net)
add_read_command(il, variables, net)
for v in ret:
assert isinstance(v, copapy.value)
add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM)

View File

@ -13,7 +13,7 @@ def test_get_dag_stats():
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
assert isinstance(v3, value)
stat = get_dag_stats([v3])
stat = get_dag_stats([v3.net])
print(stat)
assert stat['const_float'] == 2 * v_size

View File

@ -16,7 +16,7 @@ def test_multi_target():
tg1.compile(e)
# Patch constant value
a.source = cp._basic_types.CPConstant(1000.0)
a.net.source = cp._basic_types.CPConstant(1000.0)
tg2 = cp.Target()
tg2.compile(e)

View File

@ -107,9 +107,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test:
assert isinstance(net, backend.Net)
add_read_command(dw, variables, net)
for v in ret_test:
assert isinstance(v, value)
add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0)

View File

@ -109,9 +109,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test:
assert isinstance(net, backend.Net)
add_read_command(dw, variables, net)
for v in ret_test:
assert isinstance(v, value)
add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0)
@ -148,7 +148,7 @@ def test_compile():
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, value)
address = variables[test][0]
address = variables[test.net][0]
data = result_data[address]
if test.dtype == 'int':
val = int.from_bytes(data, sdb.byteorder, signed=True)

View File

@ -120,9 +120,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test:
assert isinstance(net, backend.Net)
add_read_command(dw, variables, net)
for v in ret_test:
assert isinstance(v, value)
add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0)