Net and value types separated

This commit is contained in:
Nicolas 2025-12-23 17:54:57 +01:00
parent df22438ffe
commit 6dcaa6797c
16 changed files with 152 additions and 135 deletions

View File

@ -35,23 +35,23 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
assert isinstance(y, Sequence) or isinstance(y, vector) assert isinstance(y, Sequence) or isinstance(y, vector)
y_set = {v for v in y} y_set = {v for v in y}
edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net))) edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value)))
ordered_ops = cpb.stable_toposort(edges) ordered_ops = cpb.stable_toposort(edges)
net_lookup = {net.source: net for node in ordered_ops for net in node.args} net_lookup = {net.source: net for node in ordered_ops for net in node.args}
grad_dict: dict[Net, unifloat] = dict() grad_dict: dict[Net, unifloat] = dict()
def add_grad(val: value[Any], gradient_value: unifloat) -> None: def add_grad(val: value[Any], gradient_value: unifloat) -> None:
grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value grad_dict[val.net] = grad_dict.get(val.net, 0.0) + gradient_value
for node in reversed(ordered_ops): for node in reversed(ordered_ops):
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}") #print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
if node.args: if node.args:
args: Sequence[Any] = list(node.args) args: Sequence[Net] = list(node.args)
g = 1.0 if node is x.source else grad_dict[net_lookup[node]] g = 1.0 if node is x.net.source else grad_dict[net_lookup[node]]
opn = node.name.split('_')[0] opn = node.name.split('_')[0]
a: value[Any] = args[0] a: value[float] = value(args[0])
b: value[Any] = args[1] if len(args) > 1 else a b: value[float] = value(args[1]) if len(args) > 1 else a
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']: if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
pass # Derivative is 0 for all ops returning integers pass # Derivative is 0 for all ops returning integers
@ -119,9 +119,9 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
raise ValueError(f"Operation {opn} not yet supported for auto diff.") raise ValueError(f"Operation {opn} not yet supported for auto diff.")
if isinstance(y, value): if isinstance(y, value):
return grad_dict[y] return grad_dict[y.net]
if isinstance(y, vector): if isinstance(y, vector):
return vector(grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in y) return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y)
if isinstance(y, matrix): if isinstance(y, matrix):
return matrix((grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in row) for row in y) return matrix((grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in row) for row in y)
return [grad_dict[yi] for yi in y] return [grad_dict[yi.net] for yi in y]

View File

@ -84,34 +84,50 @@ class Net:
def __hash__(self) -> int: def __hash__(self) -> int:
return self.source.node_hash return self.source.node_hash
def __eq__(self, value: object) -> bool:
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
class value(Generic[TNum], Net):
class value(Generic[TNum]):
"""A "value" represents a typed scalar variable. It supports arithmetic and """A "value" represents a typed scalar variable. It supports arithmetic and
comparison operations. comparison operations.
Attributes: Attributes:
dtype (str): Data type of this value. dtype (str): Data type of this value.
""" """
def __init__(self, source: TNum | Node, dtype: str | None = None): def __init__(self, source: TNum | Net, dtype: str | None = None):
"""Instance a value. """Instance a value.
Arguments: Arguments:
source: A numeric value or Node object. dtype: Data type of this value.
dtype: Data type of this value. Required if source is a Node. net: Reference to the underlying Net in the graph
""" """
if isinstance(source, Node): if isinstance(source, Net):
self.source = source self.net: Net = source
assert dtype, 'For source type Node a dtype argument is required.' if dtype:
assert transl_type(dtype) == source.dtype, f"Type of Net ({source.dtype}) does not match {dtype}"
self.dtype: str = dtype
else:
self.dtype = source.dtype
elif dtype == 'int' or dtype == 'bool':
new_node = CPConstant(int(source), False)
self.net = Net(new_node.dtype, new_node)
self.dtype = dtype self.dtype = dtype
elif isinstance(source, float): elif dtype == 'float':
self.source = CPConstant(source, False) new_node = CPConstant(float(source), False)
self.dtype = 'float' self.net = Net(new_node.dtype, new_node)
elif isinstance(source, bool): self.dtype = dtype
self.source = CPConstant(source, False) elif dtype is None:
if isinstance(source, bool):
new_node = CPConstant(source, False)
self.net = Net(new_node.dtype, new_node)
self.dtype = 'bool' self.dtype = 'bool'
else: else:
self.source = CPConstant(source, False) new_node = CPConstant(source, False)
self.dtype = 'int' self.net = Net(new_node.dtype, new_node)
self.dtype = new_node.dtype
else:
raise ValueError('Unknown type: {dtype}')
@overload @overload
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ... def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
@ -227,28 +243,22 @@ class value(Generic[TNum], Net):
return cast(TCPNum, add_op('sub', [value(0), self])) return cast(TCPNum, add_op('sub', [value(0), self]))
def __gt__(self, other: TVarNumb) -> 'value[int]': def __gt__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('gt', [self, other]) return add_op('gt', [self, other], dtype='bool')
return value(ret.source, dtype='bool')
def __lt__(self, other: TVarNumb) -> 'value[int]': def __lt__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('gt', [other, self]) return add_op('gt', [other, self], dtype='bool')
return value(ret.source, dtype='bool')
def __ge__(self, other: TVarNumb) -> 'value[int]': def __ge__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('ge', [self, other]) return add_op('ge', [self, other], dtype='bool')
return value(ret.source, dtype='bool')
def __le__(self, other: TVarNumb) -> 'value[int]': def __le__(self, other: TVarNumb) -> 'value[int]':
ret = add_op('ge', [other, self]) return add_op('ge', [other, self], dtype='bool')
return value(ret.source, dtype='bool')
def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore
ret = add_op('eq', [self, other], True) return add_op('eq', [self, other], True, dtype='bool')
return value(ret.source, dtype='bool')
def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore
ret = add_op('ne', [self, other], True) return add_op('ne', [self, other], True, dtype='bool')
return value(ret.source, dtype='bool')
@overload @overload
def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ... def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
@ -295,7 +305,7 @@ class value(Generic[TNum], Net):
return cp.pow(other, self) return cp.pow(other, self)
def __hash__(self) -> int: def __hash__(self) -> int:
return super().__hash__() return id(self)
# Bitwise and shift operations for cp[int] # Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'value[int]': def __lshift__(self, other: uniint) -> 'value[int]':
@ -330,16 +340,26 @@ class value(Generic[TNum], Net):
class CPConstant(Node): class CPConstant(Node):
def __init__(self, value: int | float, anonymous: bool = True): def __init__(self, value: Any, anonymous: bool = True):
self.dtype, self.value = _get_data_and_dtype(value) if isinstance(value, int):
self.value: int | float = value
self.dtype = 'int'
elif isinstance(value, float):
self.value = value
self.dtype = 'float'
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')
self.name = 'const_' + self.dtype self.name = 'const_' + self.dtype
self.args = tuple() self.args = tuple()
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self) self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
class Write(Node): class Write(Node):
def __init__(self, input: Net | int | float): def __init__(self, input: value[Any] | Net | int | float):
if isinstance(input, Net): if isinstance(input, value):
net = input.net
elif isinstance(input, Net):
net = input net = input
else: else:
node = CPConstant(input) node = CPConstant(input)
@ -352,15 +372,16 @@ class Write(Node):
class Op(Node): class Op(Node):
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False): def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name: str = typed_op_name self.name: str = typed_op_name
self.args: tuple[Net, ...] = tuple(args) self.args: tuple[Net, ...] = tuple(args)
self.node_hash = self.get_node_hash(commutative) self.node_hash = self.get_node_hash(commutative)
def net_from_value(val: Any) -> value[Any]: def value_from_number(val: Any) -> value[Any]:
vi = CPConstant(val) # Create anonymous constant that can be removed during optimization
return value(vi, vi.dtype) new_node = CPConstant(val)
new_net = Net(new_node.dtype, new_node)
return value(new_net)
@overload @overload
@ -392,29 +413,22 @@ def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
return (expression != 0) * true_result + (expression == 0) * false_result return (expression != 0) * true_result + (expression == 0) * false_result
def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False) -> value[Any]: def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False, dtype: str | None = None) -> value[Any]:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] arg_values = [a if isinstance(a, value) else value_from_number(a) for a in args]
if commutative: if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order arg_values = sorted(arg_values, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_values])
if typed_op not in generic_sdb.stencil_definitions: if typed_op not in generic_sdb.stencil_definitions:
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}") raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_values])}")
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'float': result_net = Net(result_type, Op(typed_op, [av.net for av in arg_values], commutative))
return value[float](Op(typed_op, arg_nets, commutative), result_type)
else:
return value[int](Op(typed_op, arg_nets, commutative), result_type)
if dtype:
result_type = dtype
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: return value(result_net, result_type)
if isinstance(value, int):
return ('int', int(value))
elif isinstance(value, float):
return ('float', float(value))
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')

View File

@ -221,6 +221,8 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
for net in el: for net in el:
if isinstance(net, Net): if isinstance(net, Net):
nets.add(net) nets.add(net)
else:
assert net is None or isinstance(net, Node), net
return list(nets) return list(nets)
@ -351,7 +353,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_com(binw.Command.FREE_MEMORY) dw.write_com(binw.Command.FREE_MEMORY)
# Get all nets/variables associated with heap memory # Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops) variable_list = get_nets([const_net_list], extended_output_ops)
stencil_names = {node.name for _, node in extended_output_ops} stencil_names = {node.name for _, node in extended_output_ops}
aux_function_names = sdb.get_sub_functions(stencil_names) aux_function_names = sdb.get_sub_functions(stencil_names)
@ -378,7 +380,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_int(start) dw.write_int(start)
dw.write_int(lengths) dw.write_int(lengths)
dw.write_value(net.source.value, lengths) dw.write_value(net.source.value, lengths)
#print(f'+ {net.dtype} {net.source.value}') print(f'+ {net.dtype} {net.source.value}')
# prep auxiliary_functions # prep auxiliary_functions
code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb) code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb)

View File

@ -66,21 +66,21 @@ class Target():
def __del__(self) -> None: def __del__(self) -> None:
clear_target(self._context) clear_target(self._context)
def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None: def compile(self, *values: int | float | value[Any] | Iterable[int | float | value[Any]]) -> None:
"""Compiles the code to compute the given values. """Compiles the code to compute the given values.
Arguments: Arguments:
values: Values to compute values: Values to compute
""" """
nodes: list[Node] = [] nodes: list[Node] = []
for s in values: for input in values:
if isinstance(s, Iterable): if isinstance(input, Iterable):
for net in s: for v in input:
if isinstance(net, Net): if isinstance(v, value):
nodes.append(Write(net)) nodes.append(Write(v))
else: else:
if isinstance(s, Net): if isinstance(input, value):
nodes.append(Write(s)) nodes.append(Write(input))
dw, self._values = compile_to_dag(nodes, self.sdb) dw, self._values = compile_to_dag(nodes, self.sdb)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
@ -95,32 +95,33 @@ class Target():
assert coparun(self._context, dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0
@overload @overload
def read_value(self, net: value[T]) -> T: ... def read_value(self, variables: value[T]) -> T: ...
@overload @overload
def read_value(self, net: NumLike) -> float | int | bool: ... def read_value(self, variables: NumLike) -> float | int | bool: ...
@overload @overload
def read_value(self, net: Iterable[T | value[T]]) -> list[T]: ... def read_value(self, variables: Iterable[T | value[T]]) -> list[T]: ...
def read_value(self, net: NumLike | value[T] | Iterable[T | value[T]]) -> Any: def read_value(self, variables: NumLike | value[T] | Iterable[T | value[T]]) -> Any:
"""Reads the numeric value of a copapy type. """Reads the numeric value of a copapy type.
Arguments: Arguments:
net: Value or multiple Values to read variables: Variable or multiple variables to read
Returns: Returns:
Numeric value or values Numeric value or values
""" """
if isinstance(net, Iterable): if isinstance(variables, Iterable):
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net] return [self.read_value(ni) if isinstance(ni, value) else ni for ni in variables]
if isinstance(net, float | int): if isinstance(variables, float | int):
return net return variables
assert isinstance(net, Net), "Argument must be a copapy value" assert isinstance(variables, value), "Argument must be a copapy value"
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net] addr, lengths, _ = self._values[variables.net]
var_type = variables.dtype
assert lengths > 0 assert lengths > 0
data = read_data_mem(self._context, addr, lengths) data = read_data_mem(self._context, addr, lengths)
assert data is not None and len(data) == lengths, f"Failed to read value {net}" assert data is not None and len(data) == lengths, f"Failed to read value {variables}"
en = {'little': '<', 'big': '>'}[self.sdb.byteorder] en = {'little': '<', 'big': '>'}[self.sdb.byteorder]
if var_type == 'float': if var_type == 'float':
if lengths == 4: if lengths == 4:
@ -142,24 +143,24 @@ class Target():
else: else:
raise ValueError(f"Unsupported value type: {var_type}") raise ValueError(f"Unsupported value type: {var_type}")
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None: def write_value(self, variables: value[Any] | Iterable[value[Any]], data: int | float | Iterable[int | float]) -> None:
"""Write to a copapy value on the target. """Write to a copapy value on the target.
Arguments: Arguments:
net: Singe variable or multiple variables to overwrite variables: Singe variable or multiple variables to overwrite
value: Singe value or multiple values to write value: Singe value or multiple values to write
""" """
if isinstance(net, Iterable): if isinstance(variables, Iterable):
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too" assert isinstance(data, Iterable), "If net is iterable, value must be iterable too"
for ni, vi in zip(net, value): for ni, vi in zip(variables, data):
self.write_value(ni, vi) self.write_value(ni, vi)
return return
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable" assert not isinstance(data, Iterable), "If net is not iterable, value must not be iterable"
assert isinstance(net, Net), "Argument must be a copapy value" assert isinstance(variables, value), "Argument must be a copapy value"
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._values[net] addr, lengths, var_type = self._values[variables.net]
assert lengths > 0 assert lengths > 0
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
@ -168,17 +169,17 @@ class Target():
dw.write_int(lengths) dw.write_int(lengths)
if var_type == 'float': if var_type == 'float':
dw.write_value(float(value), lengths) dw.write_value(float(data), lengths)
elif var_type == 'int' or var_type == 'bool': elif var_type == 'int' or var_type == 'bool':
dw.write_value(int(value), lengths) dw.write_value(int(data), lengths)
else: else:
raise ValueError(f"Unsupported value type: {var_type}") raise ValueError(f"Unsupported value type: {var_type}")
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
assert coparun(self._context, dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0
def read_value_remote(self, net: Net) -> None: def read_value_remote(self, variable: value[Any]) -> None:
"""Reads the raw data of a value by the runner.""" """Reads the raw data of a value by the runner."""
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
add_read_command(dw, self._values, net) add_read_command(dw, self._values, variable.net)
assert coparun(self._context, dw.get_data()) > 0 assert coparun(self._context, dw.get_data()) > 0

View File

@ -30,9 +30,9 @@ def test_compile():
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
#il.write_com(_binwrite.Command.DUMP_CODE) #il.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test: for v in ret_test:
assert isinstance(net, copapy.backend.Net) assert isinstance(v, value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -70,7 +70,7 @@ def test_timing_compiler():
# Get all nets/variables associated with heap memory # Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops) variable_list = get_nets([const_net_list], extended_output_ops)
stencil_names = {node.name for _, node in extended_output_ops} stencil_names = {node.name for _, node in extended_output_ops}
print(f'-- get_sub_functions: {len(stencil_names)}') print(f'-- get_sub_functions: {len(stencil_names)}')

View File

@ -65,9 +65,9 @@ def test_compile():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, copapy.backend.Net) assert isinstance(v, cp.value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -60,9 +60,9 @@ def test_compile():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, backend.Net) assert isinstance(v, cp.value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -61,9 +61,9 @@ def test_compile():
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
#il.write_com(_binwrite.Command.DUMP_CODE) #il.write_com(_binwrite.Command.DUMP_CODE)
for net in ret: for v in ret:
assert isinstance(net, backend.Net) assert isinstance(v, cp.value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -33,9 +33,9 @@ def test_compile():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, Net) assert isinstance(v, value)
add_read_command(il, vars, net) add_read_command(il, vars, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -28,9 +28,9 @@ def test_compile_sqrt():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, copapy.backend.Net) assert isinstance(v, value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)
@ -62,9 +62,9 @@ def test_compile_log():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, copapy.backend.Net) assert isinstance(v, value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)
@ -96,9 +96,9 @@ def test_compile_sin():
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)
for net in ret: for v in ret:
assert isinstance(net, copapy.backend.Net) assert isinstance(v, copapy.value)
add_read_command(il, variables, net) add_read_command(il, variables, v.net)
il.write_com(_binwrite.Command.END_COM) il.write_com(_binwrite.Command.END_COM)

View File

@ -13,7 +13,7 @@ def test_get_dag_stats():
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size)) v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
assert isinstance(v3, value) assert isinstance(v3, value)
stat = get_dag_stats([v3]) stat = get_dag_stats([v3.net])
print(stat) print(stat)
assert stat['const_float'] == 2 * v_size assert stat['const_float'] == 2 * v_size

View File

@ -16,7 +16,7 @@ def test_multi_target():
tg1.compile(e) tg1.compile(e)
# Patch constant value # Patch constant value
a.source = cp._basic_types.CPConstant(1000.0) a.net.source = cp._basic_types.CPConstant(1000.0)
tg2 = cp.Target() tg2 = cp.Target()
tg2.compile(e) tg2.compile(e)

View File

@ -107,9 +107,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG) dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE) #dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test: for v in ret_test:
assert isinstance(net, backend.Net) assert isinstance(v, value)
add_read_command(dw, variables, net) add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA) #dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0) #dw.write_int(0)

View File

@ -109,9 +109,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG) dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE) #dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test: for v in ret_test:
assert isinstance(net, backend.Net) assert isinstance(v, value)
add_read_command(dw, variables, net) add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA) #dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0) #dw.write_int(0)
@ -148,7 +148,7 @@ def test_compile():
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, value) assert isinstance(test, value)
address = variables[test][0] address = variables[test.net][0]
data = result_data[address] data = result_data[address]
if test.dtype == 'int': if test.dtype == 'int':
val = int.from_bytes(data, sdb.byteorder, signed=True) val = int.from_bytes(data, sdb.byteorder, signed=True)

View File

@ -120,9 +120,9 @@ def test_compile():
dw.write_com(_binwrite.Command.RUN_PROG) dw.write_com(_binwrite.Command.RUN_PROG)
#dw.write_com(_binwrite.Command.DUMP_CODE) #dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test: for v in ret_test:
assert isinstance(net, backend.Net) assert isinstance(v, value)
add_read_command(dw, variables, net) add_read_command(dw, variables, v.net)
#dw.write_com(_binwrite.Command.READ_DATA) #dw.write_com(_binwrite.Command.READ_DATA)
#dw.write_int(0) #dw.write_int(0)