mirror of https://github.com/Nonannet/copapy.git
added get_all_dag_edges_between function
This commit is contained in:
parent
959d80b082
commit
d526c5ddc0
|
|
@ -55,6 +55,43 @@ def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
|
|||
return result
|
||||
|
||||
|
||||
def get_all_dag_edges_between(roots: Iterable[Node], leaves: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
||||
"""Get all edges in the DAG connecting given roots with given leaves
|
||||
|
||||
Arguments:
|
||||
nodes: Iterable of nodes to start the traversal from
|
||||
|
||||
Yields:
|
||||
Tuples of (source_node, target_node) representing edges in the DAG
|
||||
"""
|
||||
# Walk the full DAG starting from given roots to final leaves
|
||||
parent_lookup: dict[Node, set[Node]] = dict()
|
||||
node_list: list[Node] = [n for n in roots]
|
||||
while(node_list):
|
||||
node = node_list.pop()
|
||||
for net in node.args:
|
||||
if net.source in parent_lookup:
|
||||
parent_lookup[net.source].add(node)
|
||||
else:
|
||||
parent_lookup[net.source] = {node}
|
||||
node_list.append(net.source)
|
||||
|
||||
# Walk the DAG in reverse direction starting from given leaves to given roots
|
||||
emitted_edges: set[tuple[Node, Node]] = set()
|
||||
node_list = [n for n in leaves]
|
||||
while(node_list):
|
||||
child_node = node_list.pop()
|
||||
if child_node in parent_lookup:
|
||||
for node in parent_lookup[child_node]:
|
||||
edge = (child_node, node)
|
||||
if edge not in emitted_edges:
|
||||
yield edge
|
||||
node_list.append(node)
|
||||
emitted_edges.add(edge)
|
||||
|
||||
assert all(r in {e[0] for e in emitted_edges} for r in leaves)
|
||||
|
||||
|
||||
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
||||
"""Get all edges in the DAG by traversing from the given nodes
|
||||
|
||||
|
|
@ -64,15 +101,17 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non
|
|||
Yields:
|
||||
Tuples of (source_node, target_node) representing edges in the DAG
|
||||
"""
|
||||
emitted_nodes: set[tuple[Node, Node]] = set()
|
||||
emitted_edges: set[tuple[Node, Node]] = set()
|
||||
node_list: list[Node] = [n for n in nodes]
|
||||
|
||||
for node in nodes:
|
||||
yield from get_all_dag_edges(net.source for net in node.args)
|
||||
while(node_list):
|
||||
node = node_list.pop()
|
||||
for net in node.args:
|
||||
edge = (net.source, node)
|
||||
if edge not in emitted_nodes:
|
||||
if edge not in emitted_edges:
|
||||
yield edge
|
||||
emitted_nodes.add(edge)
|
||||
node_list.append(net.source)
|
||||
emitted_edges.add(edge)
|
||||
|
||||
|
||||
def get_const_nets(nodes: list[Node]) -> list[Net]:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from ._target import add_read_command
|
||||
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
||||
from ._compiler import compile_to_dag, \
|
||||
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \
|
||||
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
|
||||
add_write_ops
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -15,6 +15,7 @@ __all__ = [
|
|||
"stable_toposort",
|
||||
"get_const_nets",
|
||||
"get_all_dag_edges",
|
||||
"get_all_dag_edges_between",
|
||||
"add_read_ops",
|
||||
"add_write_ops",
|
||||
"stencil_db_from_package"
|
||||
|
|
|
|||
Loading…
Reference in New Issue