diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index e3ef3e5..b32828c 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -49,12 +49,22 @@ class Node: name (str): The name of the operation this Node represents. """ def __init__(self) -> None: - self.args: Sequence[Net] = [] + self.args: tuple[Net, ...] = tuple() self.name: str = '' + self.node_hash = 0 def __repr__(self) -> str: return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})" + def get_node_hash(self, commutative: bool = False) -> int: + if commutative: + return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args)) + return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args)) + + + def __hash__(self) -> int: + return self.node_hash + class Net: """A Net represents a variable in the computation graph - or more generally it @@ -67,14 +77,13 @@ class Net: def __init__(self, dtype: str, source: Node): self.dtype = dtype self.source = source - self.grad: NumLike = 1 def __repr__(self) -> str: names = get_var_name(self) - return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" + return f"{'name:' + names[0] if names else 'id:' + str(hash(self))[-5:]}" def __hash__(self) -> int: - return id(self) + return self.source.node_hash class variable(Generic[TNum], Net): @@ -104,8 +113,6 @@ class variable(Generic[TNum], Net): else: self.source = CPConstant(source) self.dtype = 'int' - - self.grad = 1 @overload def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... @@ -118,7 +125,7 @@ class variable(Generic[TNum], Net): @overload def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ... def __add__(self, other: TVarNumb) -> Any: - if isinstance(other, int | float) and other == 0: + if not isinstance(other, variable) and other == 0: return self return add_op('add', [self, other], True) @@ -129,9 +136,7 @@ class variable(Generic[TNum], Net): @overload def __radd__(self, other: float) -> 'variable[float]': ... def __radd__(self, other: NumLike) -> Any: - if isinstance(other, int | float) and other == 0: - return self - return add_op('add', [self, other], True) + return self + other @overload def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... @@ -144,6 +149,8 @@ class variable(Generic[TNum], Net): @overload def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ... def __sub__(self, other: TVarNumb) -> Any: + if isinstance(other, int | float) and other == 0: + return self return add_op('sub', [self, other]) @overload @@ -168,6 +175,11 @@ class variable(Generic[TNum], Net): def __mul__(self, other: TVarNumb) -> Any: if self.dtype == 'float' and isinstance(other, int): other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations + if not isinstance(other, variable): + if other == 1: + return self + elif other == 0: + return 0 return add_op('mul', [self, other], True) @overload @@ -177,7 +189,7 @@ class variable(Generic[TNum], Net): @overload def __rmul__(self, other: float) -> 'variable[float]': ... def __rmul__(self, other: NumLike) -> Any: - return add_op('mul', [self, other], True) + return self * other def __truediv__(self, other: NumLike) -> 'variable[float]': return add_op('div', [self, other]) @@ -319,7 +331,8 @@ class CPConstant(Node): def __init__(self, value: int | float): self.dtype, self.value = _get_data_and_dtype(value) self.name = 'const_' + self.dtype - self.args = [] + self.args = tuple() + self.node_hash = id(self) class Write(Node): @@ -331,14 +344,16 @@ class Write(Node): net = Net(node.dtype, node) self.name = 'write_' + transl_type(net.dtype) - self.args = [net] + self.args = (net,) + self.node_hash = hash(self.name) ^ hash(net.source.node_hash) class Op(Node): - def __init__(self, typed_op_name: str, args: Sequence[Net]): + def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False): assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' self.name: str = typed_op_name - self.args: Sequence[Net] = args + self.args: tuple[Net, ...] = tuple(args) + self.node_hash = self.get_node_hash(commutative) def net_from_value(value: Any) -> variable[Any]: @@ -378,9 +393,9 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool = result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] if result_type == 'float': - return variable[float](Op(typed_op, arg_nets), result_type) + return variable[float](Op(typed_op, arg_nets, commutative), result_type) else: - return variable[int](Op(typed_op, arg_nets), result_type) + return variable[int](Op(typed_op, arg_nets, commutative), result_type) def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: