mirror of https://github.com/Nonannet/copapy.git
get_dag_stats function added to inspect DAG
This commit is contained in:
parent
247fc1a28f
commit
d436dd9116
|
|
@ -299,6 +299,17 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
|
|||
return section_list, function_lookup, offset
|
||||
|
||||
|
||||
def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]:
|
||||
edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list)
|
||||
ops = {node for node, _ in edges}
|
||||
|
||||
op_stat: dict[str, int] = {}
|
||||
for op in ops:
|
||||
op_stat[op.name] = op_stat.get(op.name, 0) + 1
|
||||
|
||||
return op_stat
|
||||
|
||||
|
||||
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||
"""Compiles a DAG identified by provided end nodes to binary code
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from ._target import add_read_command
|
|||
from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package
|
||||
from ._compiler import compile_to_dag, \
|
||||
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \
|
||||
add_write_ops
|
||||
add_write_ops, get_dag_stats
|
||||
|
||||
__all__ = [
|
||||
"add_read_command",
|
||||
|
|
@ -18,5 +18,6 @@ __all__ = [
|
|||
"get_all_dag_edges_between",
|
||||
"add_read_ops",
|
||||
"add_write_ops",
|
||||
"stencil_db_from_package"
|
||||
"stencil_db_from_package",
|
||||
"get_dag_stats"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
import copapy as cp
|
||||
from copapy._basic_types import value
|
||||
from copapy.backend import get_dag_stats
|
||||
|
||||
def test_get_dag_stats():
|
||||
|
||||
sum_size = 10
|
||||
v_size = 200
|
||||
|
||||
v1 = cp.vector(cp.value(float(v)) for v in range(v_size))
|
||||
v2 = cp.vector(cp.value(float(v)) for v in [5]*v_size)
|
||||
|
||||
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
|
||||
|
||||
assert isinstance(v3, value)
|
||||
stat = get_dag_stats([v3])
|
||||
print(stat)
|
||||
|
||||
assert stat['const_float'] == 2 * v_size
|
||||
assert stat['add_float_float'] == sum_size * v_size - 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_get_dag_stats()
|
||||
Loading…
Reference in New Issue