Docstrings updated

This commit is contained in:
Nicolas Kruse 2025-12-20 22:59:31 +01:00
parent a8e70cd5d6
commit c75b4788c3
9 changed files with 56 additions and 21 deletions

View File

@ -95,7 +95,7 @@ class value(Generic[TNum], Net):
def __init__(self, source: TNum | Node, dtype: str | None = None):
"""Instance a value.
Args:
Arguments:
source: A numeric value or Node object.
dtype: Data type of this value. Required if source is a Node.
"""
@ -376,6 +376,17 @@ def iif(expression: float | int, true_result: value[TNum], false_result: TNum |
@overload
def iif(expression: float | int | value[Any], true_result: TNum | value[TNum], false_result: TNum | value[TNum]) -> value[TNum] | TNum: ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
"""Inline if-else operation. Returns true_result if expression is non-zero,
else returns false_result.
Arguments:
expression: The condition to evaluate.
true_result: The result if expression is non-zero.
false_result: The result if expression is zero.
Returns:
The selected result based on the evaluation of expression.
"""
allowed_type = (value, int, float)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
return (expression != 0) * true_result + (expression == 0) * false_result

View File

@ -308,6 +308,14 @@ def get_aux_func_layout(function_names: Iterable[str], sdb: stencil_database, of
def get_dag_stats(node_list: Iterable[Node | Net]) -> dict[str, int]:
"""Get operation statistics for the DAG identified by provided end nodes
Arguments:
node_list: List of end nodes of the DAG
Returns:
Dictionary of operation name to occurrence count
"""
edges = get_all_dag_edges(n.source if isinstance(n, Net) else n for n in node_list)
ops = {node for node, _ in edges}

View File

@ -280,7 +280,6 @@ def get_42(x: NumLike) -> value[float] | float:
return float((int(x) * 3.0 + 42.0) * 5.0 + 21.0)
#TODO: Add vector support
@overload
def abs(x: U) -> U: ...
@overload
@ -296,9 +295,8 @@ def abs(x: U | value[U] | vector[U]) -> Any:
Returns:
Absolute value of x
"""
#tt = -x * (x < 0)
ret = (x < 0) * -x + (x >= 0) * x
return ret # REMpyright: ignore[reportReturnType]
return ret
@overload

View File

@ -16,7 +16,7 @@ class matrix(Generic[TNum]):
def __init__(self, values: Iterable[Iterable[TNum | value[TNum]]] | vector[TNum]):
"""Create a matrix with given values.
Args:
Arguments:
values: iterable of iterable of constant values
"""
if isinstance(values, vector):
@ -44,7 +44,7 @@ class matrix(Generic[TNum]):
def __getitem__(self, key: tuple[int, int]) -> value[TNum] | TNum: ...
def __getitem__(self, key: int | tuple[int, int]) -> Any:
"""Get a row as a vector or a specific element.
Args:
Arguments:
key: row index or (row, col) tuple
Returns:

View File

@ -18,6 +18,17 @@ def mixed_sum(scalars: Iterable[int | float | value[Any]]) -> Any:
def mixed_homogenize(scalars: Iterable[T | value[T]]) -> Iterable[T] | Iterable[value[T]]:
"""Convert all scalars to either python numbers if there are no value types,
or to value types if there is at least one value type.
Arguments:
scalars: Iterable of scalars which can be either
python numbers or value types.
Returns:
Iterable of scalars homogenized to either all plain values
or all value types.
"""
if any(isinstance(val, value) for val in scalars):
return (value(val) if not isinstance(val, value) else val for val in scalars)
else:

View File

@ -123,7 +123,7 @@ class stencil_database():
def __init__(self, obj_file: str | bytes):
"""Load the stencil database from an ELF object file
Args:
Arguments:
obj_file: path to the ELF object file or bytes of the ELF object file
"""
if isinstance(obj_file, str):
@ -201,7 +201,7 @@ class stencil_database():
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
"""Return patch positions for a provided symbol (function or object)
Args:
Arguments:
relocation: relocation entry
symbol_address: absolute address of the target symbol
function_offset: absolute address of the first byte of the
@ -313,7 +313,7 @@ class stencil_database():
def get_stencil_code(self, name: str) -> bytes:
"""Return the striped function code for a provided function name
Args:
Arguments:
name: function name
Returns:
@ -333,7 +333,7 @@ class stencil_database():
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
"""Return recursively all functions called by stencils or by other functions
Args:
Arguments:
names: function or stencil names
Returns:
@ -384,7 +384,7 @@ class stencil_database():
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name.
Args:
Arguments:
name: function name
part: part of the function to return ('full', 'start', 'end')

View File

@ -23,6 +23,14 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
def jit(func: Callable[..., TRet]) -> Callable[..., TRet]:
"""Just-in-time compile a function for the copapy target.
Arguments:
func: Function to compile
Returns:
A callable that runs the compiled function.
"""
def call_helper(*args: ArgType) -> TRet:
if func in _jit_cache:
tg, inputs, out = _jit_cache[func]
@ -96,16 +104,15 @@ class Target():
"""Reads the numeric value of a copapy type.
Arguments:
net: Values to read
net: Value or multiple Values to read
Returns:
Numeric value
Numeric value or values
"""
if isinstance(net, Iterable):
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net]
if isinstance(net, float | int):
print("Warning: value is not a copypy value")
return net
assert isinstance(net, Net), "Argument must be a copapy value"
@ -136,11 +143,11 @@ class Target():
raise ValueError(f"Unsupported value type: {var_type}")
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
"""Reads the numeric value of a copapy type.
"""Write to a copapy value on the target.
Arguments:
net: Variable to overwrite
value: Value
net: Singe variable or multiple variables to overwrite
value: Singe value or multiple values to write
"""
if isinstance(net, Iterable):
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"

View File

@ -19,7 +19,7 @@ class vector(Generic[TNum]):
def __init__(self, values: Iterable[TNum | value[TNum]]):
"""Create a vector with given values.
Args:
Arguments:
values: iterable of constant values
"""
self.values: tuple[value[TNum] | TNum, ...] = tuple(values)

View File

@ -22,7 +22,7 @@ def argsort(input_vector: vector[TNum]) -> vector[int]:
Perform an indirect sort. It returns an array of indices that index data
in sorted order.
Args:
Arguments:
input_vector: The input vector containing numerical values.
Returns:
@ -35,7 +35,7 @@ def median(input_vector: vector[TNum]) -> TNum | value[TNum]:
"""
Applies a median filter to the input vector and returns the median as a unifloat.
Args:
Arguments:
input_vector: The input vector containing numerical values.
Returns:
@ -56,7 +56,7 @@ def mean(input_vector: vector[Any]) -> unifloat:
"""
Applies a mean filter to the input vector and returns the mean as a unifloat.
Args:
Arguments:
input_vector (vector): The input vector containing numerical values.
Returns: