mirror of https://github.com/Nonannet/copapy.git
get_dag_stats function added to inspect DAG
This commit is contained in:
parent
247fc1a28f
commit
d436dd9116
|
|
@ -299,6 +299,17 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
|
||||||
return section_list, function_lookup, offset
|
return section_list, function_lookup, offset
|
||||||
|
|
||||||
|
|
||||||
|
def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]:
|
||||||
|
edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list)
|
||||||
|
ops = {node for node, _ in edges}
|
||||||
|
|
||||||
|
op_stat: dict[str, int] = {}
|
||||||
|
for op in ops:
|
||||||
|
op_stat[op.name] = op_stat.get(op.name, 0) + 1
|
||||||
|
|
||||||
|
return op_stat
|
||||||
|
|
||||||
|
|
||||||
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||||
"""Compiles a DAG identified by provided end nodes to binary code
|
"""Compiles a DAG identified by provided end nodes to binary code
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from ._target import add_read_command
|
||||||
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
||||||
from ._compiler import compile_to_dag, \
|
from ._compiler import compile_to_dag, \
|
||||||
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
|
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
|
||||||
add_write_ops
|
add_write_ops, get_dag_stats
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"add_read_command",
|
"add_read_command",
|
||||||
|
|
@ -18,5 +18,6 @@ __all__ = [
|
||||||
"get_all_dag_edges_between",
|
"get_all_dag_edges_between",
|
||||||
"add_read_ops",
|
"add_read_ops",
|
||||||
"add_write_ops",
|
"add_write_ops",
|
||||||
"stencil_db_from_package"
|
"stencil_db_from_package",
|
||||||
|
"get_dag_stats"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
import copapy as cp
|
||||||
|
from copapy._basic_types import value
|
||||||
|
from copapy.backend import get_dag_stats
|
||||||
|
|
||||||
|
def test_get_dag_stats():
|
||||||
|
|
||||||
|
sum_size = 10
|
||||||
|
v_size = 200
|
||||||
|
|
||||||
|
v1 = cp.vector(cp.value(float(v)) for v in range(v_size))
|
||||||
|
v2 = cp.vector(cp.value(float(v)) for v in [5]*v_size)
|
||||||
|
|
||||||
|
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
|
||||||
|
|
||||||
|
assert isinstance(v3, value)
|
||||||
|
stat = get_dag_stats([v3])
|
||||||
|
print(stat)
|
||||||
|
|
||||||
|
assert stat['const_float'] == 2 * v_size
|
||||||
|
assert stat['add_float_float'] == sum_size * v_size - 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_get_dag_stats()
|
||||||
Loading…
Reference in New Issue