mirror of https://github.com/Nonannet/copapy.git
added optimization for + 0, - 0 and * 1 operations
This commit is contained in:
parent
a30ee12d0f
commit
c5048980c2
|
|
@ -49,12 +49,22 @@ class Node:
|
||||||
name (str): The name of the operation this Node represents.
|
name (str): The name of the operation this Node represents.
|
||||||
"""
|
"""
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.args: Sequence[Net] = []
|
self.args: tuple[Net, ...] = tuple()
|
||||||
self.name: str = ''
|
self.name: str = ''
|
||||||
|
self.node_hash = 0
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||||
|
|
||||||
|
def get_node_hash(self, commutative: bool = False) -> int:
|
||||||
|
if commutative:
|
||||||
|
return hash(self.name) ^ hash(frozenset(a.source.node_hash for a in self.args))
|
||||||
|
return hash(self.name) ^ hash(tuple(a.source.node_hash for a in self.args))
|
||||||
|
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return self.node_hash
|
||||||
|
|
||||||
|
|
||||||
class Net:
|
class Net:
|
||||||
"""A Net represents a variable in the computation graph - or more generally it
|
"""A Net represents a variable in the computation graph - or more generally it
|
||||||
|
|
@ -67,14 +77,13 @@ class Net:
|
||||||
def __init__(self, dtype: str, source: Node):
|
def __init__(self, dtype: str, source: Node):
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.source = source
|
self.source = source
|
||||||
self.grad: NumLike = 1
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
names = get_var_name(self)
|
names = get_var_name(self)
|
||||||
return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}"
|
return f"{'name:' + names[0] if names else 'id:' + str(hash(self))[-5:]}"
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return id(self)
|
return self.source.node_hash
|
||||||
|
|
||||||
|
|
||||||
class variable(Generic[TNum], Net):
|
class variable(Generic[TNum], Net):
|
||||||
|
|
@ -105,8 +114,6 @@ class variable(Generic[TNum], Net):
|
||||||
self.source = CPConstant(source)
|
self.source = CPConstant(source)
|
||||||
self.dtype = 'int'
|
self.dtype = 'int'
|
||||||
|
|
||||||
self.grad = 1
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -118,7 +125,7 @@ class variable(Generic[TNum], Net):
|
||||||
@overload
|
@overload
|
||||||
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
def __add__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
||||||
def __add__(self, other: TVarNumb) -> Any:
|
def __add__(self, other: TVarNumb) -> Any:
|
||||||
if isinstance(other, int | float) and other == 0:
|
if not isinstance(other, variable) and other == 0:
|
||||||
return self
|
return self
|
||||||
return add_op('add', [self, other], True)
|
return add_op('add', [self, other], True)
|
||||||
|
|
||||||
|
|
@ -129,9 +136,7 @@ class variable(Generic[TNum], Net):
|
||||||
@overload
|
@overload
|
||||||
def __radd__(self, other: float) -> 'variable[float]': ...
|
def __radd__(self, other: float) -> 'variable[float]': ...
|
||||||
def __radd__(self, other: NumLike) -> Any:
|
def __radd__(self, other: NumLike) -> Any:
|
||||||
if isinstance(other, int | float) and other == 0:
|
return self + other
|
||||||
return self
|
|
||||||
return add_op('add', [self, other], True)
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
def __sub__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
||||||
|
|
@ -144,6 +149,8 @@ class variable(Generic[TNum], Net):
|
||||||
@overload
|
@overload
|
||||||
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
def __sub__(self, other: TVarNumb) -> 'variable[float] | variable[int]': ...
|
||||||
def __sub__(self, other: TVarNumb) -> Any:
|
def __sub__(self, other: TVarNumb) -> Any:
|
||||||
|
if isinstance(other, int | float) and other == 0:
|
||||||
|
return self
|
||||||
return add_op('sub', [self, other])
|
return add_op('sub', [self, other])
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -168,6 +175,11 @@ class variable(Generic[TNum], Net):
|
||||||
def __mul__(self, other: TVarNumb) -> Any:
|
def __mul__(self, other: TVarNumb) -> Any:
|
||||||
if self.dtype == 'float' and isinstance(other, int):
|
if self.dtype == 'float' and isinstance(other, int):
|
||||||
other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations
|
other = float(other) # Prevent runtime conversion of consts; TODO: add this for other operations
|
||||||
|
if not isinstance(other, variable):
|
||||||
|
if other == 1:
|
||||||
|
return self
|
||||||
|
elif other == 0:
|
||||||
|
return 0
|
||||||
return add_op('mul', [self, other], True)
|
return add_op('mul', [self, other], True)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -177,7 +189,7 @@ class variable(Generic[TNum], Net):
|
||||||
@overload
|
@overload
|
||||||
def __rmul__(self, other: float) -> 'variable[float]': ...
|
def __rmul__(self, other: float) -> 'variable[float]': ...
|
||||||
def __rmul__(self, other: NumLike) -> Any:
|
def __rmul__(self, other: NumLike) -> Any:
|
||||||
return add_op('mul', [self, other], True)
|
return self * other
|
||||||
|
|
||||||
def __truediv__(self, other: NumLike) -> 'variable[float]':
|
def __truediv__(self, other: NumLike) -> 'variable[float]':
|
||||||
return add_op('div', [self, other])
|
return add_op('div', [self, other])
|
||||||
|
|
@ -319,7 +331,8 @@ class CPConstant(Node):
|
||||||
def __init__(self, value: int | float):
|
def __init__(self, value: int | float):
|
||||||
self.dtype, self.value = _get_data_and_dtype(value)
|
self.dtype, self.value = _get_data_and_dtype(value)
|
||||||
self.name = 'const_' + self.dtype
|
self.name = 'const_' + self.dtype
|
||||||
self.args = []
|
self.args = tuple()
|
||||||
|
self.node_hash = id(self)
|
||||||
|
|
||||||
|
|
||||||
class Write(Node):
|
class Write(Node):
|
||||||
|
|
@ -331,14 +344,16 @@ class Write(Node):
|
||||||
net = Net(node.dtype, node)
|
net = Net(node.dtype, node)
|
||||||
|
|
||||||
self.name = 'write_' + transl_type(net.dtype)
|
self.name = 'write_' + transl_type(net.dtype)
|
||||||
self.args = [net]
|
self.args = (net,)
|
||||||
|
self.node_hash = hash(self.name) ^ hash(net.source.node_hash)
|
||||||
|
|
||||||
|
|
||||||
class Op(Node):
|
class Op(Node):
|
||||||
def __init__(self, typed_op_name: str, args: Sequence[Net]):
|
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
|
||||||
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
||||||
self.name: str = typed_op_name
|
self.name: str = typed_op_name
|
||||||
self.args: Sequence[Net] = args
|
self.args: tuple[Net, ...] = tuple(args)
|
||||||
|
self.node_hash = self.get_node_hash(commutative)
|
||||||
|
|
||||||
|
|
||||||
def net_from_value(value: Any) -> variable[Any]:
|
def net_from_value(value: Any) -> variable[Any]:
|
||||||
|
|
@ -378,9 +393,9 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
|
||||||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||||
|
|
||||||
if result_type == 'float':
|
if result_type == 'float':
|
||||||
return variable[float](Op(typed_op, arg_nets), result_type)
|
return variable[float](Op(typed_op, arg_nets, commutative), result_type)
|
||||||
else:
|
else:
|
||||||
return variable[int](Op(typed_op, arg_nets), result_type)
|
return variable[int](Op(typed_op, arg_nets, commutative), result_type)
|
||||||
|
|
||||||
|
|
||||||
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue