added get_all_dag_edges_between function

This commit is contained in:
Nicolas 2025-12-06 15:11:42 +01:00
parent 959d80b082
commit d526c5ddc0
2 changed files with 46 additions and 6 deletions

View File

@ -55,6 +55,43 @@ def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
return result
def get_all_dag_edges_between(roots: Iterable[Node], leaves: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
"""Get all edges in the DAG connecting given roots with given leaves
Arguments:
nodes: Iterable of nodes to start the traversal from
Yields:
Tuples of (source_node, target_node) representing edges in the DAG
"""
# Walk the full DAG starting from given roots to final leaves
parent_lookup: dict[Node, set[Node]] = dict()
node_list: list[Node] = [n for n in roots]
while(node_list):
node = node_list.pop()
for net in node.args:
if net.source in parent_lookup:
parent_lookup[net.source].add(node)
else:
parent_lookup[net.source] = {node}
node_list.append(net.source)
# Walk the DAG in reverse direction starting from given leaves to given roots
emitted_edges: set[tuple[Node, Node]] = set()
node_list = [n for n in leaves]
while(node_list):
child_node = node_list.pop()
if child_node in parent_lookup:
for node in parent_lookup[child_node]:
edge = (child_node, node)
if edge not in emitted_edges:
yield edge
node_list.append(node)
emitted_edges.add(edge)
assert all(r in {e[0] for e in emitted_edges} for r in leaves)
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
"""Get all edges in the DAG by traversing from the given nodes
@ -64,15 +101,17 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non
Yields:
Tuples of (source_node, target_node) representing edges in the DAG
"""
emitted_nodes: set[tuple[Node, Node]] = set()
emitted_edges: set[tuple[Node, Node]] = set()
node_list: list[Node] = [n for n in nodes]
for node in nodes:
yield from get_all_dag_edges(net.source for net in node.args)
while(node_list):
node = node_list.pop()
for net in node.args:
edge = (net.source, node)
if edge not in emitted_nodes:
if edge not in emitted_edges:
yield edge
emitted_nodes.add(edge)
node_list.append(net.source)
emitted_edges.add(edge)
def get_const_nets(nodes: list[Node]) -> list[Net]:

View File

@ -1,7 +1,7 @@
from ._target import add_read_command
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
from ._compiler import compile_to_dag, \
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
add_write_ops
__all__ = [
@ -15,6 +15,7 @@ __all__ = [
"stable_toposort",
"get_const_nets",
"get_all_dag_edges",
"get_all_dag_edges_between",
"add_read_ops",
"add_write_ops",
"stencil_db_from_package"