Docstrings added and updated

This commit is contained in:
Nicolas Kruse 2025-10-28 23:16:04 +01:00
parent 98418e5e17
commit 7584b316fc
6 changed files with 145 additions and 18 deletions

View File

@ -32,6 +32,14 @@ def transl_type(t: str) -> str:
class Node: class Node:
"""A Node represents an computational operation like ADD or other operations
like read and write from or to the memory or IOs. In the computation graph
Nodes are connected via Nets.
Attributes:
args (list[Net]): The input Nets to this Node.
name (str): The name of the operation this Node represents.
"""
def __init__(self) -> None: def __init__(self) -> None:
self.args: list[Net] = [] self.args: list[Net] = []
self.name: str = '' self.name: str = ''
@ -40,11 +48,14 @@ class Node:
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})" return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
class Device():
pass
class Net: class Net:
"""A Net represents a variable in the computation graph - or more generally it
connects Nodes together.
Attributes:
dtype (str): The data type of this Net.
source (Node): The Node that produces the value for this Net.
"""
def __init__(self, dtype: str, source: Node): def __init__(self, dtype: str, source: Node):
self.dtype = dtype self.dtype = dtype
self.source = source self.source = source
@ -58,7 +69,19 @@ class Net:
class variable(Generic[TNum], Net): class variable(Generic[TNum], Net):
"""A "variable" represents a typed variable. It supports arithmetic and
comparison operations.
Attributes:
dtype (str): Data type of this variable.
"""
def __init__(self, source: TNum | Node, dtype: str | None = None): def __init__(self, source: TNum | Node, dtype: str | None = None):
"""Instance a variable.
Args:
source: A numeric value or Node object.
dtype: Data type of this variable. Required if source is a Node.
"""
if isinstance(source, Node): if isinstance(source, Node):
self.source = source self.source = source
assert dtype, 'For source type Node a dtype argument is required.' assert dtype, 'For source type Node a dtype argument is required.'

View File

@ -144,6 +144,8 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
"""Get all unique nets from the provided inputs
"""
nets: set[Net] = set() nets: set[Net] = set()
for input in inputs: for input in inputs:
@ -156,6 +158,16 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]: def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]:
"""Get memory layout for the provided variables
Arguments:
variable_list: Variables to layout
sdb: Stencil database for size lookup
offset: Starting offset for layout
Returns:
Tuple of list of (variable, start_offset, length) and total length"""
object_list: list[tuple[Net, int, int]] = [] object_list: list[tuple[Net, int, int]] = []
for variable in variable_list: for variable in variable_list:
@ -171,17 +183,37 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]: def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
"""Get memory layout for the provided sections
Arguments:
section_indexes: Sections (by index) to layout
sdb: Stencil database for size lookup
offset: Starting offset for layout
Returns:
Tuple of list of (section_id, start_offset, length) and total length
"""
section_list: list[tuple[int, int, int]] = [] section_list: list[tuple[int, int, int]] = []
for id in section_indexes: for index in section_indexes:
lengths = sdb.get_section_size(id) lengths = sdb.get_section_size(index)
section_list.append((id, offset, lengths)) section_list.append((index, offset, lengths))
offset += (lengths + 3) // 4 * 4 offset += (lengths + 3) // 4 * 4
return section_list, offset return section_list, offset
def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]: def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]:
"""Get memory layout for the provided auxiliary functions
Arguments:
function_names: Function names to layout
sdb: Stencil database for size lookup
offset: Starting offset for layout
Returns:
Tuple of list of (function_name, start_offset, length) and total length
"""
function_list: list[tuple[str, int, int]] = [] function_list: list[tuple[str, int, int]] = []
for name in function_names: for name in function_names:
@ -193,6 +225,15 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
"""Compiles a DAG identified by provided end nodes to binary code
Arguments:
node_list: List of end nodes of the DAG to compile
sdb: Stencil database
Returns:
Tuple of data writer with binary code and variable layout dictionary
"""
variables: dict[Net, tuple[int, int, str]] = {} variables: dict[Net, tuple[int, int, str]] = {}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int, binw.Command]] = [] patch_list: list[tuple[int, int, int, binw.Command]] = []

View File

@ -10,7 +10,14 @@ def sqrt(x: float | int) -> float: ...
@overload @overload
def sqrt(x: variable[Any]) -> variable[float]: ... def sqrt(x: variable[Any]) -> variable[float]: ...
def sqrt(x: NumLike) -> variable[float] | float: def sqrt(x: NumLike) -> variable[float] | float:
"""Square root function""" """Square root function
Arguments:
x: Input value
Returns:
Square root of x
"""
if isinstance(x, variable): if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5) return float(x ** 0.5)
@ -33,7 +40,14 @@ def get_42() -> variable[float]:
def abs(x: T) -> T: def abs(x: T) -> T:
"""Absolute value function""" """Absolute value function
Arguments:
x: Input value
Returns:
Absolute value of x
"""
ret = (x < 0) * -x + (x >= 0) * x ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType] return ret # pyright: ignore[reportReturnType]

View File

@ -170,6 +170,13 @@ class stencil_database():
return strip_function(self.elf.symbols[name]) return strip_function(self.elf.symbols[name])
def get_sub_functions(self, names: Iterable[str]) -> set[str]: def get_sub_functions(self, names: Iterable[str]) -> set[str]:
"""Return recursively all functions called by stencils or by other functions
Args:
names: function or stencil names
Returns:
set of all sub function names
"""
name_set: set[str] = set() name_set: set[str] = set()
for name in names: for name in names:
if name not in name_set: if name not in name_set:
@ -182,16 +189,27 @@ class stencil_database():
return name_set return name_set
def get_symbol_size(self, name: str) -> int: def get_symbol_size(self, name: str) -> int:
"""Returns the size of a specified symbol name."""
return self.elf.symbols[name].fields['st_size'] return self.elf.symbols[name].fields['st_size']
def get_section_size(self, id: int) -> int: def get_section_size(self, index: int) -> int:
return self.elf.sections[id].fields['sh_size'] """Returns the size of a section specified by index."""
return self.elf.sections[index].fields['sh_size']
def get_section_data(self, id: int) -> bytes: def get_section_data(self, index: int) -> bytes:
return self.elf.sections[id].data """Returns the data of a section specified by index."""
return self.elf.sections[index].data
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name.""" """Returns machine code for a specified function name.
Args:
name: function name
part: part of the function to return ('full', 'start', 'end')
Returns:
Machine code bytes of the specified part of the function
"""
func = self.elf.symbols[name] func = self.elf.symbols[name]
assert func.info == 'STT_FUNC', f"{name} is not a function" assert func.info == 'STT_FUNC', f"{name} is not a function"

View File

@ -16,11 +16,24 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
class Target(): class Target():
"""Target device for compiling for and running on copapy code.
"""
def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None: def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None:
"""Initialize Target object
Arguments:
arch: Target architecture
optimization: Optimization level
"""
self.sdb = stencil_db_from_package(arch, optimization) self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = {} self._variables: dict[Net, tuple[int, int, str]] = {}
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None: def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
"""Compiles the code to compute the given variables.
Arguments:
variables: Variables to compute
"""
nodes: list[Node] = [] nodes: list[Node] = []
for s in variables: for s in variables:
if isinstance(s, Iterable): if isinstance(s, Iterable):
@ -35,7 +48,8 @@ class Target():
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0
def run(self) -> None: def run(self) -> None:
# set entry point and run code """Runs the compiled code on the target device.
"""
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
dw.write_com(binw.Command.RUN_PROG) dw.write_com(binw.Command.RUN_PROG)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
@ -58,8 +72,16 @@ class Target():
... ...
def read_value(self, net: NumLike) -> float | int | bool: def read_value(self, net: NumLike) -> float | int | bool:
"""Reads the value of a variable.
Arguments:
net: Variable to read
Returns:
Value of the variable
"""
assert isinstance(net, Net), "Variable must be a copapy variable object" assert isinstance(net, Net), "Variable must be a copapy variable object"
assert net in self._variables, f"Variable {net} not found" assert net in self._variables, f"Variable {net} not found. It might not have been compiled for the target."
addr, lengths, var_type = self._variables[net] addr, lengths, var_type = self._variables[net]
print('...', self._variables[net], net.dtype) print('...', self._variables[net], net.dtype)
assert lengths > 0 assert lengths > 0
@ -87,6 +109,7 @@ class Target():
raise ValueError(f"Unsupported variable type: {var_type}") raise ValueError(f"Unsupported variable type: {var_type}")
def read_value_remote(self, net: Net) -> None: def read_value_remote(self, net: Net) -> None:
"""Reads the raw data of a variable by the runner."""
dw = binw.data_writer(self.sdb.byteorder) dw = binw.data_writer(self.sdb.byteorder)
add_read_command(dw, self._variables, net) add_read_command(dw, self._variables, net)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0

View File

@ -11,11 +11,16 @@ epsilon = 1e-20
class vector(Generic[T]): class vector(Generic[T]):
"""Type-safe vector supporting numeric promotion between vector types.""" """Mathematical vector class supporting basic operations and interactions with variables.
"""
def __init__(self, values: Iterable[T | variable[T]]): def __init__(self, values: Iterable[T | variable[T]]):
"""Create a vector with given values and variables.
Args:
values: iterable of constant values and variables
"""
self.values: tuple[variable[T] | T, ...] = tuple(values) self.values: tuple[variable[T] | T, ...] = tuple(values)
# ---------- Basic dunder methods ----------
def __repr__(self) -> str: def __repr__(self) -> str:
return f"vector({self.values})" return f"vector({self.values})"
@ -143,14 +148,17 @@ class vector(Generic[T]):
@overload @overload
def sum(self: 'vector[float]') -> float | variable[float]: ... def sum(self: 'vector[float]') -> float | variable[float]: ...
def sum(self) -> Any: def sum(self) -> Any:
"""Sum of all vector elements."""
return sum(a for a in self.values if isinstance(a, variable)) +\ return sum(a for a in self.values if isinstance(a, variable)) +\
sum(a for a in self.values if not isinstance(a, variable)) sum(a for a in self.values if not isinstance(a, variable))
def magnitude(self) -> 'float | variable[float]': def magnitude(self) -> 'float | variable[float]':
"""Magnitude (length) of the vector."""
s = sum(a * a for a in self.values) s = sum(a * a for a in self.values)
return sqrt(s) if isinstance(s, variable) else sqrt(s) return sqrt(s) if isinstance(s, variable) else sqrt(s)
def normalize(self) -> 'vector[float]': def normalize(self) -> 'vector[float]':
"""Returns a normalized (unit length) version of the vector."""
mag = self.magnitude() + epsilon mag = self.magnitude() + epsilon
return self / mag return self / mag