diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index a385874..61075ff 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -55,6 +55,43 @@ def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: return result +def get_all_dag_edges_between(roots: Iterable[Node], leaves: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]: + """Get all edges in the DAG connecting given roots with given leaves + + Arguments: + nodes: Iterable of nodes to start the traversal from + + Yields: + Tuples of (source_node, target_node) representing edges in the DAG + """ + # Walk the full DAG starting from given roots to final leaves + parent_lookup: dict[Node, set[Node]] = dict() + node_list: list[Node] = [n for n in roots] + while(node_list): + node = node_list.pop() + for net in node.args: + if net.source in parent_lookup: + parent_lookup[net.source].add(node) + else: + parent_lookup[net.source] = {node} + node_list.append(net.source) + + # Walk the DAG in reverse direction starting from given leaves to given roots + emitted_edges: set[tuple[Node, Node]] = set() + node_list = [n for n in leaves] + while(node_list): + child_node = node_list.pop() + if child_node in parent_lookup: + for node in parent_lookup[child_node]: + edge = (child_node, node) + if edge not in emitted_edges: + yield edge + node_list.append(node) + emitted_edges.add(edge) + + assert all(r in {e[0] for e in emitted_edges} for r in leaves) + + def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]: """Get all edges in the DAG by traversing from the given nodes @@ -64,15 +101,17 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non Yields: Tuples of (source_node, target_node) representing edges in the DAG """ - emitted_nodes: set[tuple[Node, Node]] = set() + emitted_edges: set[tuple[Node, Node]] = set() + node_list: list[Node] = [n for n in nodes] - for node in nodes: - yield from get_all_dag_edges(net.source for net in node.args) + while(node_list): + node = node_list.pop() for net in node.args: edge = (net.source, node) - if edge not in emitted_nodes: + if edge not in emitted_edges: yield edge - emitted_nodes.add(edge) + node_list.append(net.source) + emitted_edges.add(edge) def get_const_nets(nodes: list[Node]) -> list[Net]: diff --git a/src/copapy/backend.py b/src/copapy/backend.py index 839bef9..c03c76c 100644 --- a/src/copapy/backend.py +++ b/src/copapy/backend.py @@ -1,7 +1,7 @@ from ._target import add_read_command from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package from ._compiler import compile_to_dag, \ - stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \ + stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \ add_write_ops __all__ = [ @@ -15,6 +15,7 @@ __all__ = [ "stable_toposort", "get_const_nets", "get_all_dag_edges", + "get_all_dag_edges_between", "add_read_ops", "add_write_ops", "stencil_db_from_package"