added optimization for + 0, - 0 and * 1 operations

This commit is contained in:
Nicolas Kruse 2025-12-03 17:28:49 +01:00
parent a30ee12d0f
commit c5048980c2
1 changed files with 32 additions and 17 deletions

View File

@ -49,12 +49,22 @@ class Node:
name (str): The name of the operation this Node represents. name (str): The name of the operation this Node represents.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.args: Sequence[Net] = [] self.args: tuple[Net, ...] = tuple()
self.name: str = '' self.name: str = ''
self.node_hash = 0
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
def get_node_hash(self, commutative: bool = False) -> int:
if commutative:
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
def __hash__(self) -> int:
return self.node_hash
class Net: class Net:
"""A Net represents a variable in the computation graph - or more generally it """A Net represents a variable in the computation graph - or more generally it
@ -67,14 +77,13 @@ class Net:
def __init__(self, dtype: str, source: Node): def __init__(self, dtype: str, source: Node):
self.dtype = dtype self.dtype = dtype
self.source = source self.source = source
self.grad: NumLike = 1
def __repr__(self) -> str: def __repr__(self) -> str:
names = get_var_name(self) names = get_var_name(self)
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" return f"{'name:' + names[0] if names else 'id:' + str(hash(self))[-5:]}"
def __hash__(self) -> int: def __hash__(self) -> int:
return id(self) return self.source.node_hash
class variable(Generic[TNum], Net): class variable(Generic[TNum], Net):
@ -104,8 +113,6 @@ class variable(Generic[TNum], Net):
else: else:
self.source = CPConstant(source) self.source = CPConstant(source)
self.dtype = 'int' self.dtype = 'int'
self.grad = 1
@overload @overload
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
@ -118,7 +125,7 @@ class variable(Generic[TNum], Net):
@overload @overload
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ... def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __add__(self, other: TVarNumb) -> Any: def __add__(self, other: TVarNumb) -> Any:
if isinstance(other, int | float) and other == 0: if not isinstance(other, variable) and other == 0:
return self return self
return add_op('add', [self, other], True) return add_op('add', [self, other], True)
@ -129,9 +136,7 @@ class variable(Generic[TNum], Net):
@overload @overload
def __radd__(self, other: float) -> 'variable[float]': ... def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any: def __radd__(self, other: NumLike) -> Any:
if isinstance(other, int | float) and other == 0: return self + other
return self
return add_op('add', [self, other], True)
@overload @overload
def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
@ -144,6 +149,8 @@ class variable(Generic[TNum], Net):
@overload @overload
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ... def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: TVarNumb) -> Any: def __sub__(self, other: TVarNumb) -> Any:
if isinstance(other, int | float) and other == 0:
return self
return add_op('sub', [self, other]) return add_op('sub', [self, other])
@overload @overload
@ -168,6 +175,11 @@ class variable(Generic[TNum], Net):
def __mul__(self, other: TVarNumb) -> Any: def __mul__(self, other: TVarNumb) -> Any:
if self.dtype == 'float' and isinstance(other, int): if self.dtype == 'float' and isinstance(other, int):
other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations
if not isinstance(other, variable):
if other == 1:
return self
elif other == 0:
return 0
return add_op('mul', [self, other], True) return add_op('mul', [self, other], True)
@overload @overload
@ -177,7 +189,7 @@ class variable(Generic[TNum], Net):
@overload @overload
def __rmul__(self, other: float) -> 'variable[float]': ... def __rmul__(self, other: float) -> 'variable[float]': ...
def __rmul__(self, other: NumLike) -> Any: def __rmul__(self, other: NumLike) -> Any:
return add_op('mul', [self, other], True) return self * other
def __truediv__(self, other: NumLike) -> 'variable[float]': def __truediv__(self, other: NumLike) -> 'variable[float]':
return add_op('div', [self, other]) return add_op('div', [self, other])
@ -319,7 +331,8 @@ class CPConstant(Node):
def __init__(self, value: int | float): def __init__(self, value: int | float):
self.dtype, self.value = _get_data_and_dtype(value) self.dtype, self.value = _get_data_and_dtype(value)
self.name = 'const_' + self.dtype self.name = 'const_' + self.dtype
self.args = [] self.args = tuple()
self.node_hash = id(self)
class Write(Node): class Write(Node):
@ -331,14 +344,16 @@ class Write(Node):
net = Net(node.dtype, node) net = Net(node.dtype, node)
self.name = 'write_' + transl_type(net.dtype) self.name = 'write_' + transl_type(net.dtype)
self.args = [net] self.args = (net,)
self.node_hash = hash(self.name) ^ hash(net.source.node_hash)
class Op(Node): class Op(Node):
def __init__(self, typed_op_name: str, args: Sequence[Net]): def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name: str = typed_op_name self.name: str = typed_op_name
self.args: Sequence[Net] = args self.args: tuple[Net, ...] = tuple(args)
self.node_hash = self.get_node_hash(commutative)
def net_from_value(value: Any) -> variable[Any]: def net_from_value(value: Any) -> variable[Any]:
@ -378,9 +393,9 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'float': if result_type == 'float':
return variable[float](Op(typed_op, arg_nets), result_type) return variable[float](Op(typed_op, arg_nets, commutative), result_type)
else: else:
return variable[int](Op(typed_op, arg_nets), result_type) return variable[int](Op(typed_op, arg_nets, commutative), result_type)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: