From 19fc403d37cb85c7fcc46b98fc97a4073989d283 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Sat, 6 Dec 2025 15:14:06 +0100 Subject: [PATCH] autograd updated with get_all_dag_edges_between to increase speed --- src/copapy/_autograd.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/copapy/_autograd.py b/src/copapy/_autograd.py index 685d9e4..ea5e9c4 100644 --- a/src/copapy/_autograd.py +++ b/src/copapy/_autograd.py @@ -6,14 +6,14 @@ from ._basic_types import Net, unifloat @overload -def grad(x: variable[Any], y: variable[Any]) -> unifloat: ... +def grad(x: Any, y: variable[Any]) -> unifloat: ... @overload -def grad(x: variable[Any], y: Sequence[variable[Any]]) -> list[unifloat]: ... +def grad(x: Any, y: vector[Any]) -> vector[float]: ... @overload -def grad(x: variable[Any], y: vector[Any]) -> vector[float]: ... +def grad(x: Any, y: Sequence[variable[Any]]) -> list[unifloat]: ... @overload -def grad(x: variable[Any], y: matrix[Any]) -> matrix[float]: ... -def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[float]) -> Any: +def grad(x: Any, y: matrix[Any]) -> matrix[float]: ... +def grad(x: Any, y: variable[Any] | Sequence[variable[Any]] | vector[Any] | matrix[Any]) -> Any: """Returns the partial derivative dx/dy where x needs to be a scalar and y might be a scalar, a list of scalars, a vector or matrix. @@ -24,7 +24,17 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A Returns: Derivative of x with the type and dimensions of y """ - edges = cpb.get_all_dag_edges([x.source]) + assert isinstance(x, variable), f"Argument x for grad function must be a variable but is {type(x)}." + + if isinstance(y, variable): + y_set = {y} + if isinstance(y, matrix): + y_set = {v for row in y for v in row} + else: + assert isinstance(y, Sequence) or isinstance(y, vector) + y_set = {v for v in y} + + edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net))) ordered_ops = cpb.stable_toposort(edges) net_lookup = {net.source: net for node in ordered_ops for net in node.args} @@ -34,7 +44,7 @@ def grad(x: variable[Any], y: variable[Any] | Sequence[variable[Any]] | vector[A grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value for node in reversed(ordered_ops): - print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}") + #print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}") if node.args: args: Sequence[Any] = list(node.args) g = 1.0 if node is x.source else grad_dict[net_lookup[node]]