From f43e025f2efb4cdb2de61fadd87c3763cc96a06d Mon Sep 17 00:00:00 2001 From: Nicolas Date: Wed, 24 Dec 2025 14:09:31 +0100 Subject: [PATCH] __hash__ and __eq__ for Net and Node partial revised, __eq__ not yet working --- src/copapy/_basic_types.py | 78 +++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 9ed7f49..06ece1c 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -56,17 +56,6 @@ class Node: def __repr__(self) -> str: return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})" - def get_node_hash(self, commutative: bool = False) -> int: - if commutative: - return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args)) - return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args)) - - def __hash__(self) -> int: - return self.node_hash - - def __eq__(self, value: object) -> bool: - return isinstance(value, Node) and self.node_hash == value.node_hash # TODO: change to 64 bit hash - class Net: """A Net represents a scalar type in the computation graph - or more generally it @@ -87,8 +76,8 @@ class Net: def __hash__(self) -> int: return self.source.node_hash - def __eq__(self, value: object) -> bool: - return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash + def __eq__(self, other: object) -> bool: + return isinstance(other, Net) and self.source == other.source class value(Generic[TNum]): @@ -132,6 +121,10 @@ class value(Generic[TNum]): else: raise ValueError('Unknown type: {dtype}') + def __repr__(self) -> str: + names = get_var_name(self) + return f"{'name:' + names[0] if names else 'h:' + str(self.net.source.node_hash)[-5:]}" + @overload def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ... @overload @@ -355,7 +348,18 @@ class CPConstant(Node): self.name = 'const_' + self.dtype self.args = tuple() - self.node_hash = ((value if isinstance(value, int) else hash(value)) + 0x1_0000_0000) ^ hash(self.dtype) if anonymous else id(self) # TODO: Simplify hash and compare + self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self) + self.anonymous = anonymous + + def __eq__(self, other: object) -> bool: + return (self is other) or (self.anonymous and + isinstance(other, CPConstant) and + other.anonymous and + self.value == other.value and + self.dtype == other.dtype) + + def __hash__(self) -> int: + return self.node_hash class Write(Node): @@ -378,6 +382,52 @@ class Op(Node): self.name: str = typed_op_name self.args: tuple[Net, ...] = tuple(args) self.node_hash = self.get_node_hash(commutative) + self.commutative = commutative + + def get_node_hash(self, commutative: bool = False) -> int: + if commutative: + h = hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args)) + else: + h = hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args)) + return h if h != -1 else -2 + + def __eq__(self, other: object) -> bool: + #print('* __eq__', self.name, self.args, self.node_hash) + if self is other: + return True + else: + return False + if not isinstance(other, Op): + return False + self_nodes: list[Node] = [self] + other_nodes: list[Node] = [other] + while(self_nodes): + s_node = self_nodes.pop() + o_node = other_nodes.pop() + + if len(self_nodes) > 100: + print(' - ', len(self_nodes)) + + if s_node.node_hash != o_node.node_hash: + return False + if isinstance(s_node, Op): + if s_node.name.split('_')[0] != o_node.name.split('_')[0]: + return False + if s_node.commutative: + for s_net, o_net in zip(sorted(s_node.args, key=lambda x: x.source.node_hash), + sorted(o_node.args, key=lambda x: x.source.node_hash)): + self_nodes.append(s_net.source) + other_nodes.append(o_net.source) + else: + for s_net, o_net in zip(s_node.args, o_node.args): + self_nodes.append(s_net.source) + other_nodes.append(o_net.source) + elif s_node != o_node: + return False + return True + + def __hash__(self) -> int: + return self.node_hash def value_from_number(val: Any) -> value[Any]: