mirror of https://github.com/Nonannet/copapy.git
added get_all_dag_edges_between function
This commit is contained in:
parent
959d80b082
commit
d526c5ddc0
|
|
@ -55,6 +55,43 @@ def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_dag_edges_between(roots: Iterable[Node], leaves: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
||||||
|
"""Get all edges in the DAG connecting given roots with given leaves
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
nodes: Iterable of nodes to start the traversal from
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuples of (source_node, target_node) representing edges in the DAG
|
||||||
|
"""
|
||||||
|
# Walk the full DAG starting from given roots to final leaves
|
||||||
|
parent_lookup: dict[Node, set[Node]] = dict()
|
||||||
|
node_list: list[Node] = [n for n in roots]
|
||||||
|
while(node_list):
|
||||||
|
node = node_list.pop()
|
||||||
|
for net in node.args:
|
||||||
|
if net.source in parent_lookup:
|
||||||
|
parent_lookup[net.source].add(node)
|
||||||
|
else:
|
||||||
|
parent_lookup[net.source] = {node}
|
||||||
|
node_list.append(net.source)
|
||||||
|
|
||||||
|
# Walk the DAG in reverse direction starting from given leaves to given roots
|
||||||
|
emitted_edges: set[tuple[Node, Node]] = set()
|
||||||
|
node_list = [n for n in leaves]
|
||||||
|
while(node_list):
|
||||||
|
child_node = node_list.pop()
|
||||||
|
if child_node in parent_lookup:
|
||||||
|
for node in parent_lookup[child_node]:
|
||||||
|
edge = (child_node, node)
|
||||||
|
if edge not in emitted_edges:
|
||||||
|
yield edge
|
||||||
|
node_list.append(node)
|
||||||
|
emitted_edges.add(edge)
|
||||||
|
|
||||||
|
assert all(r in {e[0] for e in emitted_edges} for r in leaves)
|
||||||
|
|
||||||
|
|
||||||
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]:
|
||||||
"""Get all edges in the DAG by traversing from the given nodes
|
"""Get all edges in the DAG by traversing from the given nodes
|
||||||
|
|
||||||
|
|
@ -64,15 +101,17 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non
|
||||||
Yields:
|
Yields:
|
||||||
Tuples of (source_node, target_node) representing edges in the DAG
|
Tuples of (source_node, target_node) representing edges in the DAG
|
||||||
"""
|
"""
|
||||||
emitted_nodes: set[tuple[Node, Node]] = set()
|
emitted_edges: set[tuple[Node, Node]] = set()
|
||||||
|
node_list: list[Node] = [n for n in nodes]
|
||||||
|
|
||||||
for node in nodes:
|
while(node_list):
|
||||||
yield from get_all_dag_edges(net.source for net in node.args)
|
node = node_list.pop()
|
||||||
for net in node.args:
|
for net in node.args:
|
||||||
edge = (net.source, node)
|
edge = (net.source, node)
|
||||||
if edge not in emitted_nodes:
|
if edge not in emitted_edges:
|
||||||
yield edge
|
yield edge
|
||||||
emitted_nodes.add(edge)
|
node_list.append(net.source)
|
||||||
|
emitted_edges.add(edge)
|
||||||
|
|
||||||
|
|
||||||
def get_const_nets(nodes: list[Node]) -> list[Net]:
|
def get_const_nets(nodes: list[Node]) -> list[Net]:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from ._target import add_read_command
|
from ._target import add_read_command
|
||||||
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
||||||
from ._compiler import compile_to_dag, \
|
from ._compiler import compile_to_dag, \
|
||||||
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \
|
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
|
||||||
add_write_ops
|
add_write_ops
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -15,6 +15,7 @@ __all__ = [
|
||||||
"stable_toposort",
|
"stable_toposort",
|
||||||
"get_const_nets",
|
"get_const_nets",
|
||||||
"get_all_dag_edges",
|
"get_all_dag_edges",
|
||||||
|
"get_all_dag_edges_between",
|
||||||
"add_read_ops",
|
"add_read_ops",
|
||||||
"add_write_ops",
|
"add_write_ops",
|
||||||
"stencil_db_from_package"
|
"stencil_db_from_package"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue