diff --git a/tests/test_dag_optimization.py b/tests/test_dag_optimization.py index bc4db4e..a340e19 100644 --- a/tests/test_dag_optimization.py +++ b/tests/test_dag_optimization.py @@ -1,6 +1,41 @@ import copapy as cp -from copapy._basic_types import value -from copapy.backend import get_dag_stats +from copapy import value +from copapy.backend import get_dag_stats, Write +import copapy.backend as cpb +from typing import Any + + +def show_dag(val: value[Any]): + out = [Write(val.net)] + + print(out) + print('-- get_edges:') + + edges = list(cpb.get_all_dag_edges(out)) + for p in edges: + print('#', p) + + print('-- get_ordered_ops:') + ordered_ops = cpb.stable_toposort(edges) + for p in ordered_ops: + print('#', p) + + print('-- get_consts:') + const_list = cpb.get_const_nets(ordered_ops) + for p in const_list: + print('#', p) + + print('-- add_read_ops:') + output_ops = list(cpb.add_read_ops(ordered_ops)) + for p in output_ops: + print('#', p) + + print('-- add_write_ops:') + extended_output_ops = list(cpb.add_write_ops(output_ops, const_list)) + for p in extended_output_ops: + print('#', p) + print('--') + def test_get_dag_stats(): @@ -20,5 +55,19 @@ def test_get_dag_stats(): assert stat['add_float_float'] == sum_size * v_size - 2 +def test_dag_reduction(): + + a = value(8) + + v3 = (a * 3 + 7 + 2) + (a * 3 + 7 + 2) + + show_dag(v3) + + assert isinstance(v3, value) + stat = get_dag_stats([v3.net]) + print(stat) + + if __name__ == "__main__": - test_get_dag_stats() \ No newline at end of file + test_get_dag_stats() + test_dag_reduction() \ No newline at end of file