test for dag optimization extended

This commit is contained in:
Nicolas 2025-12-24 14:10:46 +01:00
parent 33fea7e354
commit 662a168d90
1 changed files with 52 additions and 3 deletions

View File

@ -1,6 +1,41 @@
import copapy as cp import copapy as cp
from copapy._basic_types import value from copapy import value
from copapy.backend import get_dag_stats from copapy.backend import get_dag_stats, Write
import copapy.backend as cpb
from typing import Any
def show_dag(val: value[Any]):
out = [Write(val.net)]
print(out)
print('-- get_edges:')
edges = list(cpb.get_all_dag_edges(out))
for p in edges:
print('#', p)
print('-- get_ordered_ops:')
ordered_ops = cpb.stable_toposort(edges)
for p in ordered_ops:
print('#', p)
print('-- get_consts:')
const_list = cpb.get_const_nets(ordered_ops)
for p in const_list:
print('#', p)
print('-- add_read_ops:')
output_ops = list(cpb.add_read_ops(ordered_ops))
for p in output_ops:
print('#', p)
print('-- add_write_ops:')
extended_output_ops = list(cpb.add_write_ops(output_ops, const_list))
for p in extended_output_ops:
print('#', p)
print('--')
def test_get_dag_stats(): def test_get_dag_stats():
@ -20,5 +55,19 @@ def test_get_dag_stats():
assert stat['add_float_float'] == sum_size * v_size - 2 assert stat['add_float_float'] == sum_size * v_size - 2
def test_dag_reduction():
a = value(8)
v3 = (a * 3 + 7 + 2) + (a * 3 + 7 + 2)
show_dag(v3)
assert isinstance(v3, value)
stat = get_dag_stats([v3.net])
print(stat)
if __name__ == "__main__": if __name__ == "__main__":
test_get_dag_stats() test_get_dag_stats()
test_dag_reduction()