From a6ca756d51fcb60a90707de33cbd4bb15064d7e3 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Wed, 24 Dec 2025 16:40:46 +0100 Subject: [PATCH] __eq__ for Op type implemented --- src/copapy/_basic_types.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 06ece1c..923694a 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -392,21 +392,15 @@ class Op(Node): return h if h != -1 else -2 def __eq__(self, other: object) -> bool: - #print('* __eq__', self.name, self.args, self.node_hash) if self is other: return True - else: - return False if not isinstance(other, Op): - return False - self_nodes: list[Node] = [self] - other_nodes: list[Node] = [other] - while(self_nodes): - s_node = self_nodes.pop() - o_node = other_nodes.pop() - - if len(self_nodes) > 100: - print(' - ', len(self_nodes)) + return NotImplemented + + # Traverse graph for both notes. Return false on first difference + nodes: list[tuple[Node, Node]] = [(self, other)] + while(nodes): + s_node, o_node = nodes.pop() if s_node.node_hash != o_node.node_hash: return False @@ -416,12 +410,12 @@ class Op(Node): if s_node.commutative: for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash), sorted(o_node.args, key=lambda x: x.source.node_hash)): - self_nodes.append(s_net.source) - other_nodes.append(o_net.source) + if s_net is not o_net: + nodes.append((s_net.source, o_net.source)) else: for s_net, o_net in zip(s_node.args, o_node.args): - self_nodes.append(s_net.source) - other_nodes.append(o_net.source) + if s_net is not o_net: + nodes.append((s_net.source, o_net.source)) elif s_node != o_node: return False return True