__eq__ for Op extended with caching

This commit is contained in:
Nicolas 2025-12-26 13:37:15 +01:00
parent a6ca756d51
commit e69c0ff24b
1 changed files with 12 additions and 4 deletions

View File

@ -397,19 +397,26 @@ class Op(Node):
if not isinstance(other, Op): if not isinstance(other, Op):
return NotImplemented return NotImplemented
# Traverse graph for both notes. Return false on first difference # Traverse graph for both notes. Return false on first difference.
# A false inequality result in seldom cases is ok, whereas a false
# equality result leads to wrong computation results.
nodes: list[tuple[Node, Node]] = [(self, other)] nodes: list[tuple[Node, Node]] = [(self, other)]
seen: set[tuple[int, int]] = set()
while(nodes): while(nodes):
s_node, o_node = nodes.pop() s_node, o_node = nodes.pop()
if s_node.node_hash != o_node.node_hash: if s_node.node_hash != o_node.node_hash:
return False return False
key = (id(s_node), id(o_node))
if key in seen:
continue
if isinstance(s_node, Op): if isinstance(s_node, Op):
if s_node.name.split('_')[0] != o_node.name.split('_')[0]: if (s_node.name.split('_')[0] != o_node.name.split('_')[0] or
len(o_node.args) != len(s_node.args)):
return False return False
if s_node.commutative: if s_node.commutative:
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash), for s_net, o_net in zip(sorted(s_node.args, key=hash),
sorted(o_node.args, key=lambda x: x.source.node_hash)): sorted(o_node.args, key=hash)):
if s_net is not o_net: if s_net is not o_net:
nodes.append((s_net.source, o_net.source)) nodes.append((s_net.source, o_net.source))
else: else:
@ -418,6 +425,7 @@ class Op(Node):
nodes.append((s_net.source, o_net.source)) nodes.append((s_net.source, o_net.source))
elif s_node != o_node: elif s_node != o_node:
return False return False
seen.add(key)
return True return True
def __hash__(self) -> int: def __hash__(self) -> int: