autograd updated with get_all_dag_edges_between to increase speed

This commit is contained in:
Nicolas 2025-12-06 15:14:06 +01:00
parent a21970de79
commit 19fc403d37
1 changed files with 17 additions and 7 deletions

View File

@ -6,14 +6,14 @@ from ._basic_types import Net, unifloat
@overload
def grad(x: variable[Any], y: variable[Any]) -> unifloat: ...
def grad(x: Any, y: variable[Any]) -> unifloat: ...
@overload
def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ...
def grad(x: Any, y: vector[Any]) -> vector[float]: ...
@overload
def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ...
def grad(x: Any, y: Sequence[variable[Any]]) -> list[unifloat]: ...
@overload
def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ...
def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any:
def grad(x: Any, y: matrix[Any]) -> matrix[float]: ...
def grad(x: Any, y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[Any]) -> Any:
"""Returns the partial derivative dx/dy where x needs to be a scalar
and y might be a scalar, a list of scalars, a vector or matrix.
@ -24,7 +24,17 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A
Returns:
Derivative of x with the type and dimensions of y
"""
edges = cpb.get_all_dag_edges([x.source])
assert isinstance(x, variable), f"Argument x for grad function must be a variable but is {type(x)}."
if isinstance(y, variable):
y_set = {y}
if isinstance(y, matrix):
y_set = {v for row in y for v in row}
else:
assert isinstance(y, Sequence) or isinstance(y, vector)
y_set = {v for v in y}
edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net)))
ordered_ops = cpb.stable_toposort(edges)
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
@ -34,7 +44,7 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A
grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value
for node in reversed(ordered_ops):
print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
if node.args:
args: Sequence[Any] = list(node.args)
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]