mirror of https://github.com/Nonannet/copapy.git
__eq__ for Op extended with caching
This commit is contained in:
parent
a6ca756d51
commit
e69c0ff24b
|
|
@ -397,19 +397,26 @@ class Op(Node):
|
||||||
if not isinstance(other, Op):
|
if not isinstance(other, Op):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
# Traverse graph for both notes. Return false on first difference
|
# Traverse graph for both notes. Return false on first difference.
|
||||||
|
# A false inequality result in seldom cases is ok, whereas a false
|
||||||
|
# equality result leads to wrong computation results.
|
||||||
nodes: list[tuple[Node, Node]] = [(self, other)]
|
nodes: list[tuple[Node, Node]] = [(self, other)]
|
||||||
|
seen: set[tuple[int, int]] = set()
|
||||||
while(nodes):
|
while(nodes):
|
||||||
s_node, o_node = nodes.pop()
|
s_node, o_node = nodes.pop()
|
||||||
|
|
||||||
if s_node.node_hash != o_node.node_hash:
|
if s_node.node_hash != o_node.node_hash:
|
||||||
return False
|
return False
|
||||||
|
key = (id(s_node), id(o_node))
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
if isinstance(s_node, Op):
|
if isinstance(s_node, Op):
|
||||||
if s_node.name.split('_')[0] != o_node.name.split('_')[0]:
|
if (s_node.name.split('_')[0] != o_node.name.split('_')[0] or
|
||||||
|
len(o_node.args) != len(s_node.args)):
|
||||||
return False
|
return False
|
||||||
if s_node.commutative:
|
if s_node.commutative:
|
||||||
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
|
for s_net, o_net in zip(sorted(s_node.args, key=hash),
|
||||||
sorted(o_node.args, key=lambda x: x.source.node_hash)):
|
sorted(o_node.args, key=hash)):
|
||||||
if s_net is not o_net:
|
if s_net is not o_net:
|
||||||
nodes.append((s_net.source, o_net.source))
|
nodes.append((s_net.source, o_net.source))
|
||||||
else:
|
else:
|
||||||
|
|
@ -418,6 +425,7 @@ class Op(Node):
|
||||||
nodes.append((s_net.source, o_net.source))
|
nodes.append((s_net.source, o_net.source))
|
||||||
elif s_node != o_node:
|
elif s_node != o_node:
|
||||||
return False
|
return False
|
||||||
|
seen.add(key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue