__eq__ for Op type implemented

This commit is contained in:
Nicolas 2025-12-24 16:40:46 +01:00
parent 2c2d7ca960
commit a6ca756d51
1 changed files with 10 additions and 16 deletions

View File

@ -392,21 +392,15 @@ class Op(Node):
return h if h != -1 else -2 return h if h != -1 else -2
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
#print('* __eq__', self.name, self.args, self.node_hash)
if self is other: if self is other:
return True return True
else:
return False
if not isinstance(other, Op): if not isinstance(other, Op):
return False return NotImplemented
self_nodes: list[Node] = [self]
other_nodes: list[Node] = [other]
while(self_nodes):
s_node = self_nodes.pop()
o_node = other_nodes.pop()
if len(self_nodes) > 100: # Traverse graph for both notes. Return false on first difference
print(' - ', len(self_nodes)) nodes: list[tuple[Node, Node]] = [(self, other)]
while(nodes):
s_node, o_node = nodes.pop()
if s_node.node_hash != o_node.node_hash: if s_node.node_hash != o_node.node_hash:
return False return False
@ -416,12 +410,12 @@ class Op(Node):
if s_node.commutative: if s_node.commutative:
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash), for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
sorted(o_node.args, key=lambda x: x.source.node_hash)): sorted(o_node.args, key=lambda x: x.source.node_hash)):
self_nodes.append(s_net.source) if s_net is not o_net:
other_nodes.append(o_net.source) nodes.append((s_net.source, o_net.source))
else: else:
for s_net, o_net in zip(s_node.args, o_node.args): for s_net, o_net in zip(s_node.args, o_node.args):
self_nodes.append(s_net.source) if s_net is not o_net:
other_nodes.append(o_net.source) nodes.append((s_net.source, o_net.source))
elif s_node != o_node: elif s_node != o_node:
return False return False
return True return True