__hash__ and __eq__ for Net and Node partial revised, __eq__ not yet working

This commit is contained in:
Nicolas 2025-12-24 14:09:31 +01:00
parent dc58e5d19a
commit f43e025f2e
1 changed files with 64 additions and 14 deletions

View File

@ -56,17 +56,6 @@ class Node:
def __repr__(self) -> str:
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
def get_node_hash(self, commutative: bool = False) -> int:
if commutative:
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
def __hash__(self) -> int:
return self.node_hash
def __eq__(self, value: object) -> bool:
return isinstance(value, Node) and self.node_hash == value.node_hash # TODO: change to 64 bit hash
class Net:
"""A Net represents a scalar type in the computation graph - or more generally it
@ -87,8 +76,8 @@ class Net:
def __hash__(self) -> int:
return self.source.node_hash
def __eq__(self, value: object) -> bool:
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
def __eq__(self, other: object) -> bool:
return isinstance(other, Net) and self.source == other.source
class value(Generic[TNum]):
@ -132,6 +121,10 @@ class value(Generic[TNum]):
else:
raise ValueError('Unknown type: {dtype}')
def __repr__(self) -> str:
names = get_var_name(self)
return f"{'name:' + names[0] if names else 'h:' + str(self.net.source.node_hash)[-5:]}"
@overload
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
@overload
@ -355,7 +348,18 @@ class CPConstant(Node):
self.name = 'const_' + self.dtype
self.args = tuple()
self.node_hash = ((value if isinstance(value, int) else hash(value)) + 0x1_0000_0000) ^ hash(self.dtype) if anonymous else id(self) # TODO: Simplify hash and compare
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
self.anonymous = anonymous
def __eq__(self, other: object) -> bool:
return (self is other) or (self.anonymous and
isinstance(other, CPConstant) and
other.anonymous and
self.value == other.value and
self.dtype == other.dtype)
def __hash__(self) -> int:
return self.node_hash
class Write(Node):
@ -378,6 +382,52 @@ class Op(Node):
self.name: str = typed_op_name
self.args: tuple[Net, ...] = tuple(args)
self.node_hash = self.get_node_hash(commutative)
self.commutative = commutative
def get_node_hash(self, commutative: bool = False) -> int:
if commutative:
h = hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
else:
h = hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
return h if h != -1 else -2
def __eq__(self, other: object) -> bool:
#print('* __eq__', self.name, self.args, self.node_hash)
if self is other:
return True
else:
return False
if not isinstance(other, Op):
return False
self_nodes: list[Node] = [self]
other_nodes: list[Node] = [other]
while(self_nodes):
s_node = self_nodes.pop()
o_node = other_nodes.pop()
if len(self_nodes) > 100:
print(' - ', len(self_nodes))
if s_node.node_hash != o_node.node_hash:
return False
if isinstance(s_node, Op):
if s_node.name.split('_')[0] != o_node.name.split('_')[0]:
return False
if s_node.commutative:
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
sorted(o_node.args, key=lambda x: x.source.node_hash)):
self_nodes.append(s_net.source)
other_nodes.append(o_net.source)
else:
for s_net, o_net in zip(s_node.args, o_node.args):
self_nodes.append(s_net.source)
other_nodes.append(o_net.source)
elif s_node != o_node:
return False
return True
def __hash__(self) -> int:
return self.node_hash
def value_from_number(val: Any) -> value[Any]: