From 3c5f01db7f170e4252ed5280d91df2ff3d44b8ed Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Fri, 19 Dec 2025 18:20:29 +0100 Subject: [PATCH] auto stripping of equal graph branches added --- src/copapy/_basic_types.py | 10 +++++----- src/copapy/_compiler.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 9b7455c..f03e259 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -105,13 +105,13 @@ class value(Generic[TNum], Net): assert dtype, 'For source type Node a dtype argument is required.' self.dtype = dtype elif isinstance(source, float): - self.source = CPConstant(source) + self.source = CPConstant(source, False) self.dtype = 'float' elif isinstance(source, bool): - self.source = CPConstant(source) + self.source = CPConstant(source, False) self.dtype = 'bool' else: - self.source = CPConstant(source) + self.source = CPConstant(source, False) self.dtype = 'int' self.volatile = volatile @@ -332,11 +332,11 @@ class value(Generic[TNum], Net): class CPConstant(Node): - def __init__(self, value: int | float): + def __init__(self, value: int | float, constant: bool = True): self.dtype, self.value = _get_data_and_dtype(value) self.name = 'const_' + self.dtype self.args = tuple() - self.node_hash = id(self) + self.node_hash = hash(value) if constant else id(self) class Write(Node): diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index fbb9125..124755c 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -102,16 +102,19 @@ def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], Non Tuples of (source_node, target_node) representing edges in the DAG """ emitted_edges: set[tuple[Node, Node]] = set() + used_nets: set[Net] = set() node_list: list[Node] = [n for n in nodes] while(node_list): node = node_list.pop() for net in node.args: - edge = (net.source, node) - if edge not in emitted_edges: - yield edge - node_list.append(net.source) - emitted_edges.add(edge) + if net not in used_nets: + used_nets.add(net) + edge = (net.source, node) + if edge not in emitted_edges: + yield edge + node_list.append(net.source) + emitted_edges.add(edge) def get_const_nets(nodes: list[Node]) -> list[Net]: