From 6dcaa6797cbff68213d1b15e920c2228b48eab17 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Tue, 23 Dec 2025 17:54:57 +0100 Subject: [PATCH] Net and value types separated --- src/copapy/_autograd.py | 20 +++--- src/copapy/_basic_types.py | 118 +++++++++++++++++-------------- src/copapy/_compiler.py | 6 +- src/copapy/_target.py | 69 +++++++++--------- tests/test_branching_stencils.py | 6 +- tests/test_comp_timing.py | 2 +- tests/test_compile.py | 6 +- tests/test_compile_aarch64.py | 6 +- tests/test_compile_armv7.py | 6 +- tests/test_compile_div.py | 6 +- tests/test_compile_math.py | 18 ++--- tests/test_dag_optimization.py | 2 +- tests/test_multi_targets.py | 2 +- tests/test_ops_aarch64.py | 6 +- tests/test_ops_armv7.py | 8 +-- tests/test_ops_x86.py | 6 +- 16 files changed, 152 insertions(+), 135 deletions(-) diff --git a/src/copapy/_autograd.py b/src/copapy/_autograd.py index a3ff594..d17cedc 100644 --- a/src/copapy/_autograd.py +++ b/src/copapy/_autograd.py @@ -35,23 +35,23 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any assert isinstance(y, Sequence) or isinstance(y, vector) y_set = {v for v in y} - edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net))) + edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value))) ordered_ops = cpb.stable_toposort(edges) net_lookup = {net.source: net for node in ordered_ops for net in node.args} grad_dict: dict[Net, unifloat] = dict() def add_grad(val: value[Any], gradient_value: unifloat) -> None: - grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value + grad_dict[val.net] = grad_dict.get(val.net, 0.0) + gradient_value for node in reversed(ordered_ops): #print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}") if node.args: - args: Sequence[Any] = list(node.args) - g = 1.0 if node is x.source else grad_dict[net_lookup[node]] + args: Sequence[Net] = list(node.args) + g = 1.0 if node is x.net.source else grad_dict[net_lookup[node]] opn = node.name.split('_')[0] - a: value[Any] = args[0] - b: value[Any] = args[1] if len(args) > 1 else a + a: value[float] = value(args[0]) + b: value[float] = value(args[1]) if len(args) > 1 else a if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']: pass # Derivative is 0 for all ops returning integers @@ -119,9 +119,9 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any raise ValueError(f"Operation {opn} not yet supported for auto diff.") if isinstance(y, value): - return grad_dict[y] + return grad_dict[y.net] if isinstance(y, vector): - return vector(grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in y) + return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y) if isinstance(y, matrix): - return matrix((grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in row) for row in y) - return [grad_dict[yi] for yi in y] + return matrix((grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in row) for row in y) + return [grad_dict[yi.net] for yi in y] diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 60743bd..70f23f0 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -83,35 +83,51 @@ class Net: def __hash__(self) -> int: return self.source.node_hash + + def __eq__(self, value: object) -> bool: + return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash -class value(Generic[TNum], Net): +class value(Generic[TNum]): """A "value" represents a typed scalar variable. It supports arithmetic and comparison operations. Attributes: dtype (str): Data type of this value. """ - def __init__(self, source: TNum | Node, dtype: str | None = None): + def __init__(self, source: TNum | Net, dtype: str | None = None): """Instance a value. Arguments: - source: A numeric value or Node object. - dtype: Data type of this value. Required if source is a Node. + dtype: Data type of this value. + net: Reference to the underlying Net in the graph """ - if isinstance(source, Node): - self.source = source - assert dtype, 'For source type Node a dtype argument is required.' + if isinstance(source, Net): + self.net: Net = source + if dtype: + assert transl_type(dtype) == source.dtype, f"Type of Net ({source.dtype}) does not match {dtype}" + self.dtype: str = dtype + else: + self.dtype = source.dtype + elif dtype == 'int' or dtype == 'bool': + new_node = CPConstant(int(source), False) + self.net = Net(new_node.dtype, new_node) self.dtype = dtype - elif isinstance(source, float): - self.source = CPConstant(source, False) - self.dtype = 'float' - elif isinstance(source, bool): - self.source = CPConstant(source, False) - self.dtype = 'bool' + elif dtype == 'float': + new_node = CPConstant(float(source), False) + self.net = Net(new_node.dtype, new_node) + self.dtype = dtype + elif dtype is None: + if isinstance(source, bool): + new_node = CPConstant(source, False) + self.net = Net(new_node.dtype, new_node) + self.dtype = 'bool' + else: + new_node = CPConstant(source, False) + self.net = Net(new_node.dtype, new_node) + self.dtype = new_node.dtype else: - self.source = CPConstant(source, False) - self.dtype = 'int' + raise ValueError('Unknown type: {dtype}') @overload def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ... @@ -227,28 +243,22 @@ class value(Generic[TNum], Net): return cast(TCPNum, add_op('sub', [value(0), self])) def __gt__(self, other: TVarNumb) -> 'value[int]': - ret = add_op('gt', [self, other]) - return value(ret.source, dtype='bool') + return add_op('gt', [self, other], dtype='bool') def __lt__(self, other: TVarNumb) -> 'value[int]': - ret = add_op('gt', [other, self]) - return value(ret.source, dtype='bool') + return add_op('gt', [other, self], dtype='bool') def __ge__(self, other: TVarNumb) -> 'value[int]': - ret = add_op('ge', [self, other]) - return value(ret.source, dtype='bool') + return add_op('ge', [self, other], dtype='bool') def __le__(self, other: TVarNumb) -> 'value[int]': - ret = add_op('ge', [other, self]) - return value(ret.source, dtype='bool') + return add_op('ge', [other, self], dtype='bool') def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore - ret = add_op('eq', [self, other], True) - return value(ret.source, dtype='bool') + return add_op('eq', [self, other], True, dtype='bool') def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore - ret = add_op('ne', [self, other], True) - return value(ret.source, dtype='bool') + return add_op('ne', [self, other], True, dtype='bool') @overload def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ... @@ -295,7 +305,7 @@ class value(Generic[TNum], Net): return cp.pow(other, self) def __hash__(self) -> int: - return super().__hash__() + return id(self) # Bitwise and shift operations for cp[int] def __lshift__(self, other: uniint) -> 'value[int]': @@ -330,16 +340,26 @@ class value(Generic[TNum], Net): class CPConstant(Node): - def __init__(self, value: int | float, anonymous: bool = True): - self.dtype, self.value = _get_data_and_dtype(value) + def __init__(self, value: Any, anonymous: bool = True): + if isinstance(value, int): + self.value: int | float = value + self.dtype = 'int' + elif isinstance(value, float): + self.value = value + self.dtype = 'float' + else: + raise ValueError(f'Non supported data type: {type(value).__name__}') + self.name = 'const_' + self.dtype self.args = tuple() self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self) class Write(Node): - def __init__(self, input: Net | int | float): - if isinstance(input, Net): + def __init__(self, input: value[Any] | Net | int | float): + if isinstance(input, value): + net = input.net + elif isinstance(input, Net): net = input else: node = CPConstant(input) @@ -352,15 +372,16 @@ class Write(Node): class Op(Node): def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False): - assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' self.name: str = typed_op_name self.args: tuple[Net, ...] = tuple(args) self.node_hash = self.get_node_hash(commutative) -def net_from_value(val: Any) -> value[Any]: - vi = CPConstant(val) - return value(vi, vi.dtype) +def value_from_number(val: Any) -> value[Any]: + # Create anonymous constant that can be removed during optimization + new_node = CPConstant(val) + new_net = Net(new_node.dtype, new_node) + return value(new_net) @overload @@ -392,29 +413,22 @@ def iif(expression: Any, true_result: Any, false_result: Any) -> Any: return (expression != 0) * true_result + (expression == 0) * false_result -def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False) -> value[Any]: - arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] +def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False, dtype: str | None = None) -> value[Any]: + arg_values = [a if isinstance(a, value) else value_from_number(a) for a in args] if commutative: - arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order + arg_values = sorted(arg_values, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order - typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) + typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_values]) if typed_op not in generic_sdb.stencil_definitions: - raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}") + raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_values])}") result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] - if result_type == 'float': - return value[float](Op(typed_op, arg_nets, commutative), result_type) - else: - return value[int](Op(typed_op, arg_nets, commutative), result_type) + result_net = Net(result_type, Op(typed_op, [av.net for av in arg_values], commutative)) + if dtype: + result_type = dtype -def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: - if isinstance(value, int): - return ('int', int(value)) - elif isinstance(value, float): - return ('float', float(value)) - else: - raise ValueError(f'Non supported data type: {type(value).__name__}') + return value(result_net, result_type) diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 32b13d2..6af41b2 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -221,6 +221,8 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: for net in el: if isinstance(net, Net): nets.add(net) + else: + assert net is None or isinstance(net, Node), net return list(nets) @@ -351,7 +353,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi dw.write_com(binw.Command.FREE_MEMORY) # Get all nets/variables associated with heap memory - variable_list = get_nets([[const_net_list]], extended_output_ops) + variable_list = get_nets([const_net_list], extended_output_ops) stencil_names = {node.name for _, node in extended_output_ops} aux_function_names = sdb.get_sub_functions(stencil_names) @@ -378,7 +380,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi dw.write_int(start) dw.write_int(lengths) dw.write_value(net.source.value, lengths) - #print(f'+ {net.dtype} {net.source.value}') + print(f'+ {net.dtype} {net.source.value}') # prep auxiliary_functions code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb) diff --git a/src/copapy/_target.py b/src/copapy/_target.py index f3dc33e..8f1866f 100644 --- a/src/copapy/_target.py +++ b/src/copapy/_target.py @@ -66,21 +66,21 @@ class Target(): def __del__(self) -> None: clear_target(self._context) - def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None: + def compile(self, *values: int | float | value[Any] | Iterable[int | float | value[Any]]) -> None: """Compiles the code to compute the given values. Arguments: values: Values to compute """ nodes: list[Node] = [] - for s in values: - if isinstance(s, Iterable): - for net in s: - if isinstance(net, Net): - nodes.append(Write(net)) + for input in values: + if isinstance(input, Iterable): + for v in input: + if isinstance(v, value): + nodes.append(Write(v)) else: - if isinstance(s, Net): - nodes.append(Write(s)) + if isinstance(input, value): + nodes.append(Write(input)) dw, self._values = compile_to_dag(nodes, self.sdb) dw.write_com(binw.Command.END_COM) @@ -95,32 +95,33 @@ class Target(): assert coparun(self._context, dw.get_data()) > 0 @overload - def read_value(self, net: value[T]) -> T: ... + def read_value(self, variables: value[T]) -> T: ... @overload - def read_value(self, net: NumLike) -> float | int | bool: ... + def read_value(self, variables: NumLike) -> float | int | bool: ... @overload - def read_value(self, net: Iterable[T | value[T]]) -> list[T]: ... - def read_value(self, net: NumLike | value[T] | Iterable[T | value[T]]) -> Any: + def read_value(self, variables: Iterable[T | value[T]]) -> list[T]: ... + def read_value(self, variables: NumLike | value[T] | Iterable[T | value[T]]) -> Any: """Reads the numeric value of a copapy type. Arguments: - net: Value or multiple Values to read + variables: Variable or multiple variables to read Returns: Numeric value or values """ - if isinstance(net, Iterable): - return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net] + if isinstance(variables, Iterable): + return [self.read_value(ni) if isinstance(ni, value) else ni for ni in variables] - if isinstance(net, float | int): - return net + if isinstance(variables, float | int): + return variables - assert isinstance(net, Net), "Argument must be a copapy value" - assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." - addr, lengths, var_type = self._values[net] + assert isinstance(variables, value), "Argument must be a copapy value" + assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target." + addr, lengths, _ = self._values[variables.net] + var_type = variables.dtype assert lengths > 0 data = read_data_mem(self._context, addr, lengths) - assert data is not None and len(data) == lengths, f"Failed to read value {net}" + assert data is not None and len(data) == lengths, f"Failed to read value {variables}" en = {'little': '<', 'big': '>'}[self.sdb.byteorder] if var_type == 'float': if lengths == 4: @@ -142,24 +143,24 @@ class Target(): else: raise ValueError(f"Unsupported value type: {var_type}") - def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None: + def write_value(self, variables: value[Any] | Iterable[value[Any]], data: int | float | Iterable[int | float]) -> None: """Write to a copapy value on the target. Arguments: - net: Singe variable or multiple variables to overwrite + variables: Singe variable or multiple variables to overwrite value: Singe value or multiple values to write """ - if isinstance(net, Iterable): - assert isinstance(value, Iterable), "If net is iterable, value must be iterable too" - for ni, vi in zip(net, value): + if isinstance(variables, Iterable): + assert isinstance(data, Iterable), "If net is iterable, value must be iterable too" + for ni, vi in zip(variables, data): self.write_value(ni, vi) return - assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable" + assert not isinstance(data, Iterable), "If net is not iterable, value must not be iterable" - assert isinstance(net, Net), "Argument must be a copapy value" - assert net in self._values, f"Value {net} not found. It might not have been compiled for the target." - addr, lengths, var_type = self._values[net] + assert isinstance(variables, value), "Argument must be a copapy value" + assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target." + addr, lengths, var_type = self._values[variables.net] assert lengths > 0 dw = binw.data_writer(self.sdb.byteorder) @@ -168,17 +169,17 @@ class Target(): dw.write_int(lengths) if var_type == 'float': - dw.write_value(float(value), lengths) + dw.write_value(float(data), lengths) elif var_type == 'int' or var_type == 'bool': - dw.write_value(int(value), lengths) + dw.write_value(int(data), lengths) else: raise ValueError(f"Unsupported value type: {var_type}") dw.write_com(binw.Command.END_COM) assert coparun(self._context, dw.get_data()) > 0 - def read_value_remote(self, net: Net) -> None: + def read_value_remote(self, variable: value[Any]) -> None: """Reads the raw data of a value by the runner.""" dw = binw.data_writer(self.sdb.byteorder) - add_read_command(dw, self._values, net) + add_read_command(dw, self._values, variable.net) assert coparun(self._context, dw.get_data()) > 0 diff --git a/tests/test_branching_stencils.py b/tests/test_branching_stencils.py index 8e8c244..ae26f64 100644 --- a/tests/test_branching_stencils.py +++ b/tests/test_branching_stencils.py @@ -30,9 +30,9 @@ def test_compile(): il.write_com(_binwrite.Command.RUN_PROG) #il.write_com(_binwrite.Command.DUMP_CODE) - for net in ret_test: - assert isinstance(net, copapy.backend.Net) - add_read_command(il, variables, net) + for v in ret_test: + assert isinstance(v, value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_comp_timing.py b/tests/test_comp_timing.py index b848fb9..5a1fe81 100644 --- a/tests/test_comp_timing.py +++ b/tests/test_comp_timing.py @@ -70,7 +70,7 @@ def test_timing_compiler(): # Get all nets/variables associated with heap memory - variable_list = get_nets([[const_net_list]], extended_output_ops) + variable_list = get_nets([const_net_list], extended_output_ops) stencil_names = {node.name for _, node in extended_output_ops} print(f'-- get_sub_functions: {len(stencil_names)}') diff --git a/tests/test_compile.py b/tests/test_compile.py index f34eae4..8f88aac 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -65,9 +65,9 @@ def test_compile(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, copapy.backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, cp.value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_compile_aarch64.py b/tests/test_compile_aarch64.py index 3235a9d..1cdeb14 100644 --- a/tests/test_compile_aarch64.py +++ b/tests/test_compile_aarch64.py @@ -60,9 +60,9 @@ def test_compile(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, cp.value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_compile_armv7.py b/tests/test_compile_armv7.py index 6d52845..79ea027 100644 --- a/tests/test_compile_armv7.py +++ b/tests/test_compile_armv7.py @@ -61,9 +61,9 @@ def test_compile(): il.write_com(_binwrite.Command.RUN_PROG) #il.write_com(_binwrite.Command.DUMP_CODE) - for net in ret: - assert isinstance(net, backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, cp.value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_compile_div.py b/tests/test_compile_div.py index 717db36..8b355ad 100644 --- a/tests/test_compile_div.py +++ b/tests/test_compile_div.py @@ -33,9 +33,9 @@ def test_compile(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, Net) - add_read_command(il, vars, net) + for v in ret: + assert isinstance(v, value) + add_read_command(il, vars, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_compile_math.py b/tests/test_compile_math.py index d65f626..75731a5 100644 --- a/tests/test_compile_math.py +++ b/tests/test_compile_math.py @@ -28,9 +28,9 @@ def test_compile_sqrt(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, copapy.backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) @@ -62,9 +62,9 @@ def test_compile_log(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, copapy.backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) @@ -96,9 +96,9 @@ def test_compile_sin(): # run program command il.write_com(_binwrite.Command.RUN_PROG) - for net in ret: - assert isinstance(net, copapy.backend.Net) - add_read_command(il, variables, net) + for v in ret: + assert isinstance(v, copapy.value) + add_read_command(il, variables, v.net) il.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_dag_optimization.py b/tests/test_dag_optimization.py index 5d705ee..bc4db4e 100644 --- a/tests/test_dag_optimization.py +++ b/tests/test_dag_optimization.py @@ -13,7 +13,7 @@ def test_get_dag_stats(): v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size)) assert isinstance(v3, value) - stat = get_dag_stats([v3]) + stat = get_dag_stats([v3.net]) print(stat) assert stat['const_float'] == 2 * v_size diff --git a/tests/test_multi_targets.py b/tests/test_multi_targets.py index 34642e3..3ccc2ce 100644 --- a/tests/test_multi_targets.py +++ b/tests/test_multi_targets.py @@ -16,7 +16,7 @@ def test_multi_target(): tg1.compile(e) # Patch constant value - a.source = cp._basic_types.CPConstant(1000.0) + a.net.source = cp._basic_types.CPConstant(1000.0) tg2 = cp.Target() tg2.compile(e) diff --git a/tests/test_ops_aarch64.py b/tests/test_ops_aarch64.py index 26f026c..ea54347 100644 --- a/tests/test_ops_aarch64.py +++ b/tests/test_ops_aarch64.py @@ -107,9 +107,9 @@ def test_compile(): dw.write_com(_binwrite.Command.RUN_PROG) #dw.write_com(_binwrite.Command.DUMP_CODE) - for net in ret_test: - assert isinstance(net, backend.Net) - add_read_command(dw, variables, net) + for v in ret_test: + assert isinstance(v, value) + add_read_command(dw, variables, v.net) #dw.write_com(_binwrite.Command.READ_DATA) #dw.write_int(0) diff --git a/tests/test_ops_armv7.py b/tests/test_ops_armv7.py index 63a22eb..415076f 100644 --- a/tests/test_ops_armv7.py +++ b/tests/test_ops_armv7.py @@ -109,9 +109,9 @@ def test_compile(): dw.write_com(_binwrite.Command.RUN_PROG) #dw.write_com(_binwrite.Command.DUMP_CODE) - for net in ret_test: - assert isinstance(net, backend.Net) - add_read_command(dw, variables, net) + for v in ret_test: + assert isinstance(v, value) + add_read_command(dw, variables, v.net) #dw.write_com(_binwrite.Command.READ_DATA) #dw.write_int(0) @@ -148,7 +148,7 @@ def test_compile(): for test, ref in zip(ret_test, ret_ref): assert isinstance(test, value) - address = variables[test][0] + address = variables[test.net][0] data = result_data[address] if test.dtype == 'int': val = int.from_bytes(data, sdb.byteorder, signed=True) diff --git a/tests/test_ops_x86.py b/tests/test_ops_x86.py index 6ea427d..13c902e 100644 --- a/tests/test_ops_x86.py +++ b/tests/test_ops_x86.py @@ -120,9 +120,9 @@ def test_compile(): dw.write_com(_binwrite.Command.RUN_PROG) #dw.write_com(_binwrite.Command.DUMP_CODE) - for net in ret_test: - assert isinstance(net, backend.Net) - add_read_command(dw, variables, net) + for v in ret_test: + assert isinstance(v, value) + add_read_command(dw, variables, v.net) #dw.write_com(_binwrite.Command.READ_DATA) #dw.write_int(0)