mirror of https://github.com/Nonannet/copapy.git
__eq__ for Op extended with caching
This commit is contained in:
parent
a6ca756d51
commit
e69c0ff24b
|
|
@ -397,19 +397,26 @@ class Op(Node):
|
|||
if not isinstance(other, Op):
|
||||
return NotImplemented
|
||||
|
||||
# Traverse graph for both notes. Return false on first difference
|
||||
# Traverse graph for both notes. Return false on first difference.
|
||||
# A false inequality result in seldom cases is ok, whereas a false
|
||||
# equality result leads to wrong computation results.
|
||||
nodes: list[tuple[Node, Node]] = [(self, other)]
|
||||
seen: set[tuple[int, int]] = set()
|
||||
while(nodes):
|
||||
s_node, o_node = nodes.pop()
|
||||
|
||||
if s_node.node_hash != o_node.node_hash:
|
||||
return False
|
||||
key = (id(s_node), id(o_node))
|
||||
if key in seen:
|
||||
continue
|
||||
if isinstance(s_node, Op):
|
||||
if s_node.name.split('_')[0] != o_node.name.split('_')[0]:
|
||||
if (s_node.name.split('_')[0] != o_node.name.split('_')[0] or
|
||||
len(o_node.args) != len(s_node.args)):
|
||||
return False
|
||||
if s_node.commutative:
|
||||
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
|
||||
sorted(o_node.args, key=lambda x: x.source.node_hash)):
|
||||
for s_net, o_net in zip(sorted(s_node.args, key=hash),
|
||||
sorted(o_node.args, key=hash)):
|
||||
if s_net is not o_net:
|
||||
nodes.append((s_net.source, o_net.source))
|
||||
else:
|
||||
|
|
@ -418,6 +425,7 @@ class Op(Node):
|
|||
nodes.append((s_net.source, o_net.source))
|
||||
elif s_node != o_node:
|
||||
return False
|
||||
seen.add(key)
|
||||
return True
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
|
|
|||
Loading…
Reference in New Issue