mirror of https://github.com/Nonannet/copapy.git
__hash__ and __eq__ for Net and Node partial revised, __eq__ not yet working
This commit is contained in:
parent
dc58e5d19a
commit
f43e025f2e
|
|
@ -56,17 +56,6 @@ class Node:
|
|||
def __repr__(self) -> str:
|
||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||
|
||||
def get_node_hash(self, commutative: bool = False) -> int:
|
||||
if commutative:
|
||||
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
||||
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.node_hash
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, Node) and self.node_hash == value.node_hash # TODO: change to 64 bit hash
|
||||
|
||||
|
||||
class Net:
|
||||
"""A Net represents a scalar type in the computation graph - or more generally it
|
||||
|
|
@ -87,8 +76,8 @@ class Net:
|
|||
def __hash__(self) -> int:
|
||||
return self.source.node_hash
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Net) and self.source == other.source
|
||||
|
||||
|
||||
class value(Generic[TNum]):
|
||||
|
|
@ -132,6 +121,10 @@ class value(Generic[TNum]):
|
|||
else:
|
||||
raise ValueError('Unknown type: {dtype}')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
names = get_var_name(self)
|
||||
return f"{'name:' + names[0] if names else 'h:' + str(self.net.source.node_hash)[-5:]}"
|
||||
|
||||
@overload
|
||||
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
||||
@overload
|
||||
|
|
@ -355,7 +348,18 @@ class CPConstant(Node):
|
|||
|
||||
self.name = 'const_' + self.dtype
|
||||
self.args = tuple()
|
||||
self.node_hash = ((value if isinstance(value, int) else hash(value)) + 0x1_0000_0000) ^ hash(self.dtype) if anonymous else id(self) # TODO: Simplify hash and compare
|
||||
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
|
||||
self.anonymous = anonymous
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (self is other) or (self.anonymous and
|
||||
isinstance(other, CPConstant) and
|
||||
other.anonymous and
|
||||
self.value == other.value and
|
||||
self.dtype == other.dtype)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.node_hash
|
||||
|
||||
|
||||
class Write(Node):
|
||||
|
|
@ -378,6 +382,52 @@ class Op(Node):
|
|||
self.name: str = typed_op_name
|
||||
self.args: tuple[Net, ...] = tuple(args)
|
||||
self.node_hash = self.get_node_hash(commutative)
|
||||
self.commutative = commutative
|
||||
|
||||
def get_node_hash(self, commutative: bool = False) -> int:
|
||||
if commutative:
|
||||
h = hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
||||
else:
|
||||
h = hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
||||
return h if h != -1 else -2
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
#print('* __eq__', self.name, self.args, self.node_hash)
|
||||
if self is other:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if not isinstance(other, Op):
|
||||
return False
|
||||
self_nodes: list[Node] = [self]
|
||||
other_nodes: list[Node] = [other]
|
||||
while(self_nodes):
|
||||
s_node = self_nodes.pop()
|
||||
o_node = other_nodes.pop()
|
||||
|
||||
if len(self_nodes) > 100:
|
||||
print(' - ', len(self_nodes))
|
||||
|
||||
if s_node.node_hash != o_node.node_hash:
|
||||
return False
|
||||
if isinstance(s_node, Op):
|
||||
if s_node.name.split('_')[0] != o_node.name.split('_')[0]:
|
||||
return False
|
||||
if s_node.commutative:
|
||||
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
|
||||
sorted(o_node.args, key=lambda x: x.source.node_hash)):
|
||||
self_nodes.append(s_net.source)
|
||||
other_nodes.append(o_net.source)
|
||||
else:
|
||||
for s_net, o_net in zip(s_node.args, o_node.args):
|
||||
self_nodes.append(s_net.source)
|
||||
other_nodes.append(o_net.source)
|
||||
elif s_node != o_node:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.node_hash
|
||||
|
||||
|
||||
def value_from_number(val: Any) -> value[Any]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue