mirror of https://github.com/Nonannet/copapy.git
autograd updated with get_all_dag_edges_between to increase speed
This commit is contained in:
parent
a21970de79
commit
19fc403d37
|
|
@ -6,14 +6,14 @@ from ._basic_types import Net, unifloat
|
|||
|
||||
|
||||
@overload
|
||||
def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
|
||||
def grad(x: Any, y: variable[Any]) -> unifloat: ...
|
||||
@overload
|
||||
def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
|
||||
def grad(x: Any, y: vector[Any]) -> vector[float]: ...
|
||||
@overload
|
||||
def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
|
||||
def grad(x: Any, y: Sequence[variable[Any]]) -> list[unifloat]: ...
|
||||
@overload
|
||||
def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
|
||||
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
|
||||
def grad(x: Any, y: matrix[Any]) -> matrix[float]: ...
|
||||
def grad(x: Any, y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[Any]) -> Any:
|
||||
"""Returns the partial derivative dx/dy where x needs to be a scalar
|
||||
and y might be a scalar, a list of scalars, a vector or matrix.
|
||||
|
||||
|
|
@ -24,7 +24,17 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A
|
|||
Returns:
|
||||
Derivative of x with the type and dimensions of y
|
||||
"""
|
||||
edges = cpb.get_all_dag_edges([x.source])
|
||||
assert isinstance(x, variable), f"Argument x for grad function must be a variable but is {type(x)}."
|
||||
|
||||
if isinstance(y, variable):
|
||||
y_set = {y}
|
||||
if isinstance(y, matrix):
|
||||
y_set = {v for row in y for v in row}
|
||||
else:
|
||||
assert isinstance(y, Sequence) or isinstance(y, vector)
|
||||
y_set = {v for v in y}
|
||||
|
||||
edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net)))
|
||||
ordered_ops = cpb.stable_toposort(edges)
|
||||
|
||||
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
||||
|
|
@ -34,7 +44,7 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A
|
|||
grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value
|
||||
|
||||
for node in reversed(ordered_ops):
|
||||
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
||||
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
||||
if node.args:
|
||||
args: Sequence[Any] = list(node.args)
|
||||
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue