Docstrings updated

This commit is contained in:
Nicolas Kruse 2025-12-20 22:59:31 +01:00
parent a8e70cd5d6
commit c75b4788c3
9 changed files with 56 additions and 21 deletions

View File

@ -95,7 +95,7 @@ class value(Generic[TNum], Net):
def __init__(self, source: TNum | Node, dtype: str | None = None): def __init__(self, source: TNum | Node, dtype: str | None = None):
"""Instance a value. """Instance a value.
Args: Arguments:
source: A numeric value or Node object. source: A numeric value or Node object.
dtype: Data type of this value. Required if source is a Node. dtype: Data type of this value. Required if source is a Node.
""" """
@ -376,6 +376,17 @@ def iif(expression: float | int, true_result: value[TNum], false_result: TNum |
@overload @overload
def iif(expression: float | int | value[Any], true_result: TNum | value[TNum], false_result: TNum | value[TNum]) -> value[TNum] | TNum: ... def iif(expression: float | int | value[Any], true_result: TNum | value[TNum], false_result: TNum | value[TNum]) -> value[TNum] | TNum: ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any: def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
"""Inline if-else operation. Returns true_result if expression is non-zero,
else returns false_result.
Arguments:
expression: The condition to evaluate.
true_result: The result if expression is non-zero.
false_result: The result if expression is zero.
Returns:
The selected result based on the evaluation of expression.
"""
allowed_type = (value, int, float) allowed_type = (value, int, float)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
return (expression != 0) * true_result + (expression == 0) * false_result return (expression != 0) * true_result + (expression == 0) * false_result

View File

@ -308,6 +308,14 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]: def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]:
"""Get operation statistics for the DAG identified by provided end nodes
Arguments:
node_list: List of end nodes of the DAG
Returns:
Dictionary of operation name to occurrence count
"""
edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list) edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list)
ops = {node for node, _ in edges} ops = {node for node, _ in edges}

View File

@ -280,7 +280,6 @@ def get_42(x: NumLike) -> value[float] | float:
return float((int(x) * 3.0 + 42.0) * 5.0 + 21.0) return float((int(x) * 3.0 + 42.0) * 5.0 + 21.0)
#TODO: Add vector support
@overload @overload
def abs(x: U) -> U: ... def abs(x: U) -> U: ...
@overload @overload
@ -296,9 +295,8 @@ def abs(x: U | value[U] | vector[U]) -> Any:
Returns: Returns:
Absolute value of x Absolute value of x
""" """
#tt = -x * (x < 0)
ret = (x < 0) * -x + (x >= 0) * x ret = (x < 0) * -x + (x >= 0) * x
return ret # REMpyright: ignore[reportReturnType] return ret
@overload @overload

View File

@ -16,7 +16,7 @@ class matrix(Generic[TNum]):
def __init__(self, values: Iterable[Iterable[TNum | value[TNum]]] | vector[TNum]): def __init__(self, values: Iterable[Iterable[TNum | value[TNum]]] | vector[TNum]):
"""Create a matrix with given values. """Create a matrix with given values.
Args: Arguments:
values: iterable of iterable of constant values values: iterable of iterable of constant values
""" """
if isinstance(values, vector): if isinstance(values, vector):
@ -44,7 +44,7 @@ class matrix(Generic[TNum]):
def __getitem__(self, key: tuple[int, int]) -> value[TNum] | TNum: ... def __getitem__(self, key: tuple[int, int]) -> value[TNum] | TNum: ...
def __getitem__(self, key: int | tuple[int, int]) -> Any: def __getitem__(self, key: int | tuple[int, int]) -> Any:
"""Get a row as a vector or a specific element. """Get a row as a vector or a specific element.
Args: Arguments:
key: row index or (row, col) tuple key: row index or (row, col) tuple
Returns: Returns:

View File

@ -18,6 +18,17 @@ def mixed_sum(scalars: Iterable[int | float | value[Any]]) -> Any:
def mixed_homogenize(scalars: Iterable[T | value[T]]) -> Iterable[T] | Iterable[value[T]]: def mixed_homogenize(scalars: Iterable[T | value[T]]) -> Iterable[T] | Iterable[value[T]]:
"""Convert all scalars to either python numbers if there are no value types,
or to value types if there is at least one value type.
Arguments:
scalars: Iterable of scalars which can be either
python numbers or value types.
Returns:
Iterable of scalars homogenized to either all plain values
or all value types.
"""
if any(isinstance(val, value) for val in scalars): if any(isinstance(val, value) for val in scalars):
return (value(val) if not isinstance(val, value) else val for val in scalars) return (value(val) if not isinstance(val, value) else val for val in scalars)
else: else:

View File

@ -123,7 +123,7 @@ class stencil_database():
def __init__(self, obj_file: str | bytes): def __init__(self, obj_file: str | bytes):
"""Load the stencil database from an ELF object file """Load the stencil database from an ELF object file
Args: Arguments:
obj_file: path to the ELF object file or bytes of the ELF object file obj_file: path to the ELF object file or bytes of the ELF object file
""" """
if isinstance(obj_file, str): if isinstance(obj_file, str):
@ -201,7 +201,7 @@ class stencil_database():
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry: def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
"""Return patch positions for a provided symbol (function or object) """Return patch positions for a provided symbol (function or object)
Args: Arguments:
relocation: relocation entry relocation: relocation entry
symbol_address: absolute address of the target symbol symbol_address: absolute address of the target symbol
function_offset: absolute address of the first byte of the function_offset: absolute address of the first byte of the
@ -313,7 +313,7 @@ class stencil_database():
def get_stencil_code(self, name: str) -> bytes: def get_stencil_code(self, name: str) -> bytes:
"""Return the striped function code for a provided function name """Return the striped function code for a provided function name
Args: Arguments:
name: function name name: function name
Returns: Returns:
@ -333,7 +333,7 @@ class stencil_database():
def get_sub_functions(self, names: Iterable[str]) -> set[str]: def get_sub_functions(self, names: Iterable[str]) -> set[str]:
"""Return recursively all functions called by stencils or by other functions """Return recursively all functions called by stencils or by other functions
Args: Arguments:
names: function or stencil names names: function or stencil names
Returns: Returns:
@ -384,7 +384,7 @@ class stencil_database():
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name. """Returns machine code for a specified function name.
Args: Arguments:
name: function name name: function name
part: part of the function to return ('full', 'start', 'end') part: part of the function to return ('full', 'start', 'end')

View File

@ -23,6 +23,14 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
def jit(func: Callable[..., TRet]) -> Callable[..., TRet]: def jit(func: Callable[..., TRet]) -> Callable[..., TRet]:
"""Just-in-time compile a function for the copapy target.
Arguments:
func: Function to compile
Returns:
A callable that runs the compiled function.
"""
def call_helper(*args: ArgType) -> TRet: def call_helper(*args: ArgType) -> TRet:
if func in _jit_cache: if func in _jit_cache:
tg, inputs, out = _jit_cache[func] tg, inputs, out = _jit_cache[func]
@ -96,16 +104,15 @@ class Target():
"""Reads the numeric value of a copapy type. """Reads the numeric value of a copapy type.
Arguments: Arguments:
net: Values to read net: Value or multiple Values to read
Returns: Returns:
Numeric value Numeric value or values
""" """
if isinstance(net, Iterable): if isinstance(net, Iterable):
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net] return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net]
if isinstance(net, float | int): if isinstance(net, float | int):
print("Warning: value is not a copypy value")
return net return net
assert isinstance(net, Net), "Argument must be a copapy value" assert isinstance(net, Net), "Argument must be a copapy value"
@ -136,11 +143,11 @@ class Target():
raise ValueError(f"Unsupported value type: {var_type}") raise ValueError(f"Unsupported value type: {var_type}")
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None: def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
"""Reads the numeric value of a copapy type. """Write to a copapy value on the target.
Arguments: Arguments:
net: Variable to overwrite net: Singe variable or multiple variables to overwrite
value: Value value: Singe value or multiple values to write
""" """
if isinstance(net, Iterable): if isinstance(net, Iterable):
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too" assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"

View File

@ -19,7 +19,7 @@ class vector(Generic[TNum]):
def __init__(self, values: Iterable[TNum | value[TNum]]): def __init__(self, values: Iterable[TNum | value[TNum]]):
"""Create a vector with given values. """Create a vector with given values.
Args: Arguments:
values: iterable of constant values values: iterable of constant values
""" """
self.values: tuple[value[TNum] | TNum, ...] = tuple(values) self.values: tuple[value[TNum] | TNum, ...] = tuple(values)

View File

@ -22,7 +22,7 @@ def argsort(input_vector: vector[TNum]) -> vector[int]:
Perform an indirect sort. It returns an array of indices that index data Perform an indirect sort. It returns an array of indices that index data
in sorted order. in sorted order.
Args: Arguments:
input_vector: The input vector containing numerical values. input_vector: The input vector containing numerical values.
Returns: Returns:
@ -35,7 +35,7 @@ def median(input_vector: vector[TNum]) -> TNum | value[TNum]:
""" """
Applies a median filter to the input vector and returns the median as a unifloat. Applies a median filter to the input vector and returns the median as a unifloat.
Args: Arguments:
input_vector: The input vector containing numerical values. input_vector: The input vector containing numerical values.
Returns: Returns:
@ -56,7 +56,7 @@ def mean(input_vector: vector[Any]) -> unifloat:
""" """
Applies a mean filter to the input vector and returns the mean as a unifloat. Applies a mean filter to the input vector and returns the mean as a unifloat.
Args: Arguments:
input_vector (vector): The input vector containing numerical values. input_vector (vector): The input vector containing numerical values.
Returns: Returns: