mirror of https://github.com/Nonannet/copapy.git
added optimization for + 0, - 0 and * 1 operations
This commit is contained in:
parent
a30ee12d0f
commit
c5048980c2
|
|
@ -49,12 +49,22 @@ class Node:
|
|||
name (str): The name of the operation this Node represents.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.args: Sequence[Net] = []
|
||||
self.args: tuple[Net, ...] = tuple()
|
||||
self.name: str = ''
|
||||
self.node_hash = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||
|
||||
def get_node_hash(self, commutative: bool = False) -> int:
|
||||
if commutative:
|
||||
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
||||
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
||||
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.node_hash
|
||||
|
||||
|
||||
class Net:
|
||||
"""A Net represents a variable in the computation graph - or more generally it
|
||||
|
|
@ -67,14 +77,13 @@ class Net:
|
|||
def __init__(self, dtype: str, source: Node):
|
||||
self.dtype = dtype
|
||||
self.source = source
|
||||
self.grad: NumLike = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
names = get_var_name(self)
|
||||
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
|
||||
return f"{'name:' + names[0] if names else 'id:' + str(hash(self))[-5:]}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
return self.source.node_hash
|
||||
|
||||
|
||||
class variable(Generic[TNum], Net):
|
||||
|
|
@ -105,8 +114,6 @@ class variable(Generic[TNum], Net):
|
|||
self.source = CPConstant(source)
|
||||
self.dtype = 'int'
|
||||
|
||||
self.grad = 1
|
||||
|
||||
@overload
|
||||
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
||||
@overload
|
||||
|
|
@ -118,7 +125,7 @@ class variable(Generic[TNum], Net):
|
|||
@overload
|
||||
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
||||
def __add__(self, other: TVarNumb) -> Any:
|
||||
if isinstance(other, int | float) and other == 0:
|
||||
if not isinstance(other, variable) and other == 0:
|
||||
return self
|
||||
return add_op('add', [self, other], True)
|
||||
|
||||
|
|
@ -129,9 +136,7 @@ class variable(Generic[TNum], Net):
|
|||
@overload
|
||||
def __radd__(self, other: float) -> 'variable[float]': ...
|
||||
def __radd__(self, other: NumLike) -> Any:
|
||||
if isinstance(other, int | float) and other == 0:
|
||||
return self
|
||||
return add_op('add', [self, other], True)
|
||||
return self + other
|
||||
|
||||
@overload
|
||||
def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
||||
|
|
@ -144,6 +149,8 @@ class variable(Generic[TNum], Net):
|
|||
@overload
|
||||
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
||||
def __sub__(self, other: TVarNumb) -> Any:
|
||||
if isinstance(other, int | float) and other == 0:
|
||||
return self
|
||||
return add_op('sub', [self, other])
|
||||
|
||||
@overload
|
||||
|
|
@ -168,6 +175,11 @@ class variable(Generic[TNum], Net):
|
|||
def __mul__(self, other: TVarNumb) -> Any:
|
||||
if self.dtype == 'float' and isinstance(other, int):
|
||||
other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations
|
||||
if not isinstance(other, variable):
|
||||
if other == 1:
|
||||
return self
|
||||
elif other == 0:
|
||||
return 0
|
||||
return add_op('mul', [self, other], True)
|
||||
|
||||
@overload
|
||||
|
|
@ -177,7 +189,7 @@ class variable(Generic[TNum], Net):
|
|||
@overload
|
||||
def __rmul__(self, other: float) -> 'variable[float]': ...
|
||||
def __rmul__(self, other: NumLike) -> Any:
|
||||
return add_op('mul', [self, other], True)
|
||||
return self * other
|
||||
|
||||
def __truediv__(self, other: NumLike) -> 'variable[float]':
|
||||
return add_op('div', [self, other])
|
||||
|
|
@ -319,7 +331,8 @@ class CPConstant(Node):
|
|||
def __init__(self, value: int | float):
|
||||
self.dtype, self.value = _get_data_and_dtype(value)
|
||||
self.name = 'const_' + self.dtype
|
||||
self.args = []
|
||||
self.args = tuple()
|
||||
self.node_hash = id(self)
|
||||
|
||||
|
||||
class Write(Node):
|
||||
|
|
@ -331,14 +344,16 @@ class Write(Node):
|
|||
net = Net(node.dtype, node)
|
||||
|
||||
self.name = 'write_' + transl_type(net.dtype)
|
||||
self.args = [net]
|
||||
self.args = (net,)
|
||||
self.node_hash = hash(self.name) ^ hash(net.source.node_hash)
|
||||
|
||||
|
||||
class Op(Node):
|
||||
def __init__(self, typed_op_name: str, args: Sequence[Net]):
|
||||
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
|
||||
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
||||
self.name: str = typed_op_name
|
||||
self.args: Sequence[Net] = args
|
||||
self.args: tuple[Net, ...] = tuple(args)
|
||||
self.node_hash = self.get_node_hash(commutative)
|
||||
|
||||
|
||||
def net_from_value(value: Any) -> variable[Any]:
|
||||
|
|
@ -378,9 +393,9 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
|
|||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||
|
||||
if result_type == 'float':
|
||||
return variable[float](Op(typed_op, arg_nets), result_type)
|
||||
return variable[float](Op(typed_op, arg_nets, commutative), result_type)
|
||||
else:
|
||||
return variable[int](Op(typed_op, arg_nets), result_type)
|
||||
return variable[int](Op(typed_op, arg_nets, commutative), result_type)
|
||||
|
||||
|
||||
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue