diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 61075ff..fbb9125 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -299,6 +299,17 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of return section_list, function_lookup, offset +def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]: + edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list) + ops = {node for node, _ in edges} + + op_stat: dict[str, int] = {} + for op in ops: + op_stat[op.name] = op_stat.get(op.name, 0) + 1 + + return op_stat + + def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: """Compiles a DAG identified by provided end nodes to binary code diff --git a/src/copapy/backend.py b/src/copapy/backend.py index c03c76c..f494d1d 100644 --- a/src/copapy/backend.py +++ b/src/copapy/backend.py @@ -2,7 +2,7 @@ from ._target import add_read_command from ._basic_types import Net, Op, Node, CPConstant, Write, stencil_db_from_package from ._compiler import compile_to_dag, \ stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, get_all_dag_edges_between, \ - add_write_ops + add_write_ops, get_dag_stats __all__ = [ "add_read_command", @@ -18,5 +18,6 @@ __all__ = [ "get_all_dag_edges_between", "add_read_ops", "add_write_ops", - "stencil_db_from_package" + "stencil_db_from_package", + "get_dag_stats" ] diff --git a/tests/test_dag_optimization.py b/tests/test_dag_optimization.py new file mode 100644 index 0000000..5d705ee --- /dev/null +++ b/tests/test_dag_optimization.py @@ -0,0 +1,24 @@ +import copapy as cp +from copapy._basic_types import value +from copapy.backend import get_dag_stats + +def test_get_dag_stats(): + + sum_size = 10 + v_size = 200 + + v1 = cp.vector(cp.value(float(v)) for v in range(v_size)) + v2 = cp.vector(cp.value(float(v)) for v in [5]*v_size) + + v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size)) + + assert isinstance(v3, value) + stat = get_dag_stats([v3]) + print(stat) + + assert stat['const_float'] == 2 * v_size + assert stat['add_float_float'] == sum_size * v_size - 2 + + +if __name__ == "__main__": + test_get_dag_stats() \ No newline at end of file