mirror of https://github.com/Nonannet/copapy.git
__hash__ and __eq__ for Net and Node partial revised, __eq__ not yet working
This commit is contained in:
parent
dc58e5d19a
commit
f43e025f2e
|
|
@ -56,17 +56,6 @@ class Node:
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||||
|
|
||||||
def get_node_hash(self, commutative: bool = False) -> int:
|
|
||||||
if commutative:
|
|
||||||
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
|
||||||
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return self.node_hash
|
|
||||||
|
|
||||||
def __eq__(self, value: object) -> bool:
|
|
||||||
return isinstance(value, Node) and self.node_hash == value.node_hash # TODO: change to 64 bit hash
|
|
||||||
|
|
||||||
|
|
||||||
class Net:
|
class Net:
|
||||||
"""A Net represents a scalar type in the computation graph - or more generally it
|
"""A Net represents a scalar type in the computation graph - or more generally it
|
||||||
|
|
@ -87,8 +76,8 @@ class Net:
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return self.source.node_hash
|
return self.source.node_hash
|
||||||
|
|
||||||
def __eq__(self, value: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
|
return isinstance(other, Net) and self.source == other.source
|
||||||
|
|
||||||
|
|
||||||
class value(Generic[TNum]):
|
class value(Generic[TNum]):
|
||||||
|
|
@ -132,6 +121,10 @@ class value(Generic[TNum]):
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown type: {dtype}')
|
raise ValueError('Unknown type: {dtype}')
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
names = get_var_name(self)
|
||||||
|
return f"{'name:' + names[0] if names else 'h:' + str(self.net.source.node_hash)[-5:]}"
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -355,7 +348,18 @@ class CPConstant(Node):
|
||||||
|
|
||||||
self.name = 'const_' + self.dtype
|
self.name = 'const_' + self.dtype
|
||||||
self.args = tuple()
|
self.args = tuple()
|
||||||
self.node_hash = ((value if isinstance(value, int) else hash(value)) + 0x1_0000_0000) ^ hash(self.dtype) if anonymous else id(self) # TODO: Simplify hash and compare
|
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
|
||||||
|
self.anonymous = anonymous
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return (self is other) or (self.anonymous and
|
||||||
|
isinstance(other, CPConstant) and
|
||||||
|
other.anonymous and
|
||||||
|
self.value == other.value and
|
||||||
|
self.dtype == other.dtype)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return self.node_hash
|
||||||
|
|
||||||
|
|
||||||
class Write(Node):
|
class Write(Node):
|
||||||
|
|
@ -378,6 +382,52 @@ class Op(Node):
|
||||||
self.name: str = typed_op_name
|
self.name: str = typed_op_name
|
||||||
self.args: tuple[Net, ...] = tuple(args)
|
self.args: tuple[Net, ...] = tuple(args)
|
||||||
self.node_hash = self.get_node_hash(commutative)
|
self.node_hash = self.get_node_hash(commutative)
|
||||||
|
self.commutative = commutative
|
||||||
|
|
||||||
|
def get_node_hash(self, commutative: bool = False) -> int:
|
||||||
|
if commutative:
|
||||||
|
h = hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
||||||
|
else:
|
||||||
|
h = hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
||||||
|
return h if h != -1 else -2
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
#print('* __eq__', self.name, self.args, self.node_hash)
|
||||||
|
if self is other:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not isinstance(other, Op):
|
||||||
|
return False
|
||||||
|
self_nodes: list[Node] = [self]
|
||||||
|
other_nodes: list[Node] = [other]
|
||||||
|
while(self_nodes):
|
||||||
|
s_node = self_nodes.pop()
|
||||||
|
o_node = other_nodes.pop()
|
||||||
|
|
||||||
|
if len(self_nodes) > 100:
|
||||||
|
print(' - ', len(self_nodes))
|
||||||
|
|
||||||
|
if s_node.node_hash != o_node.node_hash:
|
||||||
|
return False
|
||||||
|
if isinstance(s_node, Op):
|
||||||
|
if s_node.name.split('_')[0] != o_node.name.split('_')[0]:
|
||||||
|
return False
|
||||||
|
if s_node.commutative:
|
||||||
|
for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash),
|
||||||
|
sorted(o_node.args, key=lambda x: x.source.node_hash)):
|
||||||
|
self_nodes.append(s_net.source)
|
||||||
|
other_nodes.append(o_net.source)
|
||||||
|
else:
|
||||||
|
for s_net, o_net in zip(s_node.args, o_node.args):
|
||||||
|
self_nodes.append(s_net.source)
|
||||||
|
other_nodes.append(o_net.source)
|
||||||
|
elif s_node != o_node:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return self.node_hash
|
||||||
|
|
||||||
|
|
||||||
def value_from_number(val: Any) -> value[Any]:
|
def value_from_number(val: Any) -> value[Any]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue