From cc5582ae737d414743195ca3fe46cb7439a141a0 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Tue, 2 Dec 2025 16:59:14 +0100 Subject: [PATCH] replaced list type by Sequence to improve type hinting --- src/copapy/_basic_types.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 99ee92c..e3ef3e5 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -1,5 +1,5 @@ import pkgutil -from typing import Any, TypeVar, overload, TypeAlias, Generic, cast +from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast from ._stencils import stencil_database, detect_process_arch import copapy as cp @@ -49,7 +49,7 @@ class Node: name (str): The name of the operation this Node represents. """ def __init__(self) -> None: - self.args: list[Net] = [] + self.args: Sequence[Net] = [] self.name: str = '' def __repr__(self) -> str: @@ -67,6 +67,7 @@ class Net: def __init__(self, dtype: str, source: Node): self.dtype = dtype self.source = source + self.grad: NumLike = 1 def __repr__(self) -> str: names = get_var_name(self) @@ -103,6 +104,8 @@ class variable(Generic[TNum], Net): else: self.source = CPConstant(source) self.dtype = 'int' + + self.grad = 1 @overload def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... @@ -332,15 +335,15 @@ class Write(Node): class Op(Node): - def __init__(self, typed_op_name: str, args: list[Net]): + def __init__(self, typed_op_name: str, args: Sequence[Net]): assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' self.name: str = typed_op_name - self.args: list[Net] = args + self.args: Sequence[Net] = args -def net_from_value(value: Any) -> Net: +def net_from_value(value: Any) -> variable[Any]: vi = CPConstant(value) - return Net(vi.dtype, vi) + return variable(vi, vi.dtype) @overload