replaced list type by Sequence to improve type hinting

This commit is contained in:
Nicolas Kruse 2025-12-02 16:59:14 +01:00
parent d2df1dd3fb
commit cc5582ae73
1 changed files with 9 additions and 6 deletions

View File

@ -1,5 +1,5 @@
import pkgutil import pkgutil
from typing import Any, TypeVar, overload, TypeAlias, Generic, cast from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
from ._stencils import stencil_database, detect_process_arch from ._stencils import stencil_database, detect_process_arch
import copapy as cp import copapy as cp
@ -49,7 +49,7 @@ class Node:
name (str): The name of the operation this Node represents. name (str): The name of the operation this Node represents.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.args: list[Net] = [] self.args: Sequence[Net] = []
self.name: str = '' self.name: str = ''
def __repr__(self) -> str: def __repr__(self) -> str:
@ -67,6 +67,7 @@ class Net:
def __init__(self, dtype: str, source: Node): def __init__(self, dtype: str, source: Node):
self.dtype = dtype self.dtype = dtype
self.source = source self.source = source
self.grad: NumLike = 1
def __repr__(self) -> str: def __repr__(self) -> str:
names = get_var_name(self) names = get_var_name(self)
@ -103,6 +104,8 @@ class variable(Generic[TNum], Net):
else: else:
self.source = CPConstant(source) self.source = CPConstant(source)
self.dtype = 'int' self.dtype = 'int'
self.grad = 1
@overload @overload
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ... def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
@ -332,15 +335,15 @@ class Write(Node):
class Op(Node): class Op(Node):
def __init__(self, typed_op_name: str, args: list[Net]): def __init__(self, typed_op_name: str, args: Sequence[Net]):
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
self.name: str = typed_op_name self.name: str = typed_op_name
self.args: list[Net] = args self.args: Sequence[Net] = args
def net_from_value(value: Any) -> Net: def net_from_value(value: Any) -> variable[Any]:
vi = CPConstant(value) vi = CPConstant(value)
return Net(vi.dtype, vi) return variable(vi, vi.dtype)
@overload @overload