From e69c0ff24b125cf7854882a8c1ccaaaafd800222 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Fri, 26 Dec 2025 13:37:15 +0100 Subject: [PATCH] __eq__ for Op extended with caching --- src/copapy/_basic_types.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 923694a..fa5f0cf 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -397,19 +397,26 @@ class Op(Node): if not isinstance(other, Op): return NotImplemented - # Traverse graph for both notes. Return false on first difference + # Traverse graph for both notes. Return false on first difference. + # A false inequality result in seldom cases is ok, whereas a false + # equality result leads to wrong computation results. nodes: list[tuple[Node, Node]] = [(self, other)] + seen: set[tuple[int, int]] = set() while(nodes): s_node, o_node = nodes.pop() if s_node.node_hash != o_node.node_hash: return False + key = (id(s_node), id(o_node)) + if key in seen: + continue if isinstance(s_node, Op): - if s_node.name.split('_')[0] != o_node.name.split('_')[0]: + if (s_node.name.split('_')[0] != o_node.name.split('_')[0] or + len(o_node.args) != len(s_node.args)): return False if s_node.commutative: - for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash), - sorted(o_node.args, key=lambda x: x.source.node_hash)): + for s_net, o_net in zip(sorted(s_node.args, key=hash), + sorted(o_node.args, key=hash)): if s_net is not o_net: nodes.append((s_net.source, o_net.source)) else: @@ -418,6 +425,7 @@ class Op(Node): nodes.append((s_net.source, o_net.source)) elif s_node != o_node: return False + seen.add(key) return True def __hash__(self) -> int: