mirror of https://github.com/Nonannet/copapy.git
replaced list type by Sequence to improve type hinting
This commit is contained in:
parent
d2df1dd3fb
commit
cc5582ae73
|
|
@ -1,5 +1,5 @@
|
|||
import pkgutil
|
||||
from typing import Any, TypeVar, overload, TypeAlias, Generic, cast
|
||||
from typing import Any, Sequence, TypeVar, overload, TypeAlias, Generic, cast
|
||||
from ._stencils import stencil_database, detect_process_arch
|
||||
import copapy as cp
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ class Node:
|
|||
name (str): The name of the operation this Node represents.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.args: list[Net] = []
|
||||
self.args: Sequence[Net] = []
|
||||
self.name: str = ''
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
@ -67,6 +67,7 @@ class Net:
|
|||
def __init__(self, dtype: str, source: Node):
|
||||
self.dtype = dtype
|
||||
self.source = source
|
||||
self.grad: NumLike = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
names = get_var_name(self)
|
||||
|
|
@ -104,6 +105,8 @@ class variable(Generic[TNum], Net):
|
|||
self.source = CPConstant(source)
|
||||
self.dtype = 'int'
|
||||
|
||||
self.grad = 1
|
||||
|
||||
@overload
|
||||
def __add__(self: 'variable[TNum]', other: 'variable[TNum] | TNum') -> 'variable[TNum]': ...
|
||||
@overload
|
||||
|
|
@ -332,15 +335,15 @@ class Write(Node):
|
|||
|
||||
|
||||
class Op(Node):
|
||||
def __init__(self, typed_op_name: str, args: list[Net]):
|
||||
def __init__(self, typed_op_name: str, args: Sequence[Net]):
|
||||
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
||||
self.name: str = typed_op_name
|
||||
self.args: list[Net] = args
|
||||
self.args: Sequence[Net] = args
|
||||
|
||||
|
||||
def net_from_value(value: Any) -> Net:
|
||||
def net_from_value(value: Any) -> variable[Any]:
|
||||
vi = CPConstant(value)
|
||||
return Net(vi.dtype, vi)
|
||||
return variable(vi, vi.dtype)
|
||||
|
||||
|
||||
@overload
|
||||
|
|
|
|||
Loading…
Reference in New Issue