mirror of https://github.com/Nonannet/copapy.git
Docstrings added and updated
This commit is contained in:
parent
98418e5e17
commit
7584b316fc
|
|
@ -32,6 +32,14 @@ def transl_type(t: str) -> str:
|
|||
|
||||
|
||||
class Node:
|
||||
"""A Node represents an computational operation like ADD or other operations
|
||||
like read and write from or to the memory or IOs. In the computation graph
|
||||
Nodes are connected via Nets.
|
||||
|
||||
Attributes:
|
||||
args (list[Net]): The input Nets to this Node.
|
||||
name (str): The name of the operation this Node represents.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.args: list[Net] = []
|
||||
self.name: str = ''
|
||||
|
|
@ -40,11 +48,14 @@ class Node:
|
|||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||
|
||||
|
||||
class Device():
|
||||
pass
|
||||
|
||||
|
||||
class Net:
|
||||
"""A Net represents a variable in the computation graph - or more generally it
|
||||
connects Nodes together.
|
||||
|
||||
Attributes:
|
||||
dtype (str): The data type of this Net.
|
||||
source (Node): The Node that produces the value for this Net.
|
||||
"""
|
||||
def __init__(self, dtype: str, source: Node):
|
||||
self.dtype = dtype
|
||||
self.source = source
|
||||
|
|
@ -58,7 +69,19 @@ class Net:
|
|||
|
||||
|
||||
class variable(Generic[TNum], Net):
|
||||
"""A "variable" represents a typed variable. It supports arithmetic and
|
||||
comparison operations.
|
||||
|
||||
Attributes:
|
||||
dtype (str): Data type of this variable.
|
||||
"""
|
||||
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
||||
"""Instance a variable.
|
||||
|
||||
Args:
|
||||
source: A numeric value or Node object.
|
||||
dtype: Data type of this variable. Required if source is a Node.
|
||||
"""
|
||||
if isinstance(source, Node):
|
||||
self.source = source
|
||||
assert dtype, 'For source type Node a dtype argument is required.'
|
||||
|
|
|
|||
|
|
@ -144,6 +144,8 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
|
|||
|
||||
|
||||
def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
||||
"""Get all unique nets from the provided inputs
|
||||
"""
|
||||
nets: set[Net] = set()
|
||||
|
||||
for input in inputs:
|
||||
|
|
@ -156,6 +158,16 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
|||
|
||||
|
||||
def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]:
|
||||
"""Get memory layout for the provided variables
|
||||
|
||||
Arguments:
|
||||
variable_list: Variables to layout
|
||||
sdb: Stencil database for size lookup
|
||||
offset: Starting offset for layout
|
||||
|
||||
Returns:
|
||||
Tuple of list of (variable, start_offset, length) and total length"""
|
||||
|
||||
object_list: list[tuple[Net, int, int]] = []
|
||||
|
||||
for variable in variable_list:
|
||||
|
|
@ -171,17 +183,37 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
|
|||
|
||||
|
||||
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
|
||||
"""Get memory layout for the provided sections
|
||||
|
||||
Arguments:
|
||||
section_indexes: Sections (by index) to layout
|
||||
sdb: Stencil database for size lookup
|
||||
offset: Starting offset for layout
|
||||
|
||||
Returns:
|
||||
Tuple of list of (section_id, start_offset, length) and total length
|
||||
"""
|
||||
section_list: list[tuple[int, int, int]] = []
|
||||
|
||||
for id in section_indexes:
|
||||
lengths = sdb.get_section_size(id)
|
||||
section_list.append((id, offset, lengths))
|
||||
for index in section_indexes:
|
||||
lengths = sdb.get_section_size(index)
|
||||
section_list.append((index, offset, lengths))
|
||||
offset += (lengths + 3) // 4 * 4
|
||||
|
||||
return section_list, offset
|
||||
|
||||
|
||||
def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]:
|
||||
"""Get memory layout for the provided auxiliary functions
|
||||
|
||||
Arguments:
|
||||
function_names: Function names to layout
|
||||
sdb: Stencil database for size lookup
|
||||
offset: Starting offset for layout
|
||||
|
||||
Returns:
|
||||
Tuple of list of (function_name, start_offset, length) and total length
|
||||
"""
|
||||
function_list: list[tuple[str, int, int]] = []
|
||||
|
||||
for name in function_names:
|
||||
|
|
@ -193,6 +225,15 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
|
|||
|
||||
|
||||
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||
"""Compiles a DAG identified by provided end nodes to binary code
|
||||
|
||||
Arguments:
|
||||
node_list: List of end nodes of the DAG to compile
|
||||
sdb: Stencil database
|
||||
|
||||
Returns:
|
||||
Tuple of data writer with binary code and variable layout dictionary
|
||||
"""
|
||||
variables: dict[Net, tuple[int, int, str]] = {}
|
||||
data_list: list[bytes] = []
|
||||
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
||||
|
|
|
|||
|
|
@ -10,7 +10,14 @@ def sqrt(x: float | int) -> float: ...
|
|||
@overload
|
||||
def sqrt(x: variable[Any]) -> variable[float]: ...
|
||||
def sqrt(x: NumLike) -> variable[float] | float:
|
||||
"""Square root function"""
|
||||
"""Square root function
|
||||
|
||||
Arguments:
|
||||
x: Input value
|
||||
|
||||
Returns:
|
||||
Square root of x
|
||||
"""
|
||||
if isinstance(x, variable):
|
||||
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
|
||||
return float(x ** 0.5)
|
||||
|
|
@ -33,7 +40,14 @@ def get_42() -> variable[float]:
|
|||
|
||||
|
||||
def abs(x: T) -> T:
|
||||
"""Absolute value function"""
|
||||
"""Absolute value function
|
||||
|
||||
Arguments:
|
||||
x: Input value
|
||||
|
||||
Returns:
|
||||
Absolute value of x
|
||||
"""
|
||||
ret = (x < 0) * -x + (x >= 0) * x
|
||||
return ret # pyright: ignore[reportReturnType]
|
||||
|
||||
|
|
|
|||
|
|
@ -170,6 +170,13 @@ class stencil_database():
|
|||
return strip_function(self.elf.symbols[name])
|
||||
|
||||
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
|
||||
"""Return recursively all functions called by stencils or by other functions
|
||||
Args:
|
||||
names: function or stencil names
|
||||
|
||||
Returns:
|
||||
set of all sub function names
|
||||
"""
|
||||
name_set: set[str] = set()
|
||||
for name in names:
|
||||
if name not in name_set:
|
||||
|
|
@ -182,16 +189,27 @@ class stencil_database():
|
|||
return name_set
|
||||
|
||||
def get_symbol_size(self, name: str) -> int:
|
||||
"""Returns the size of a specified symbol name."""
|
||||
return self.elf.symbols[name].fields['st_size']
|
||||
|
||||
def get_section_size(self, id: int) -> int:
|
||||
return self.elf.sections[id].fields['sh_size']
|
||||
def get_section_size(self, index: int) -> int:
|
||||
"""Returns the size of a section specified by index."""
|
||||
return self.elf.sections[index].fields['sh_size']
|
||||
|
||||
def get_section_data(self, id: int) -> bytes:
|
||||
return self.elf.sections[id].data
|
||||
def get_section_data(self, index: int) -> bytes:
|
||||
"""Returns the data of a section specified by index."""
|
||||
return self.elf.sections[index].data
|
||||
|
||||
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||
"""Returns machine code for a specified function name."""
|
||||
"""Returns machine code for a specified function name.
|
||||
|
||||
Args:
|
||||
name: function name
|
||||
part: part of the function to return ('full', 'start', 'end')
|
||||
|
||||
Returns:
|
||||
Machine code bytes of the specified part of the function
|
||||
"""
|
||||
func = self.elf.symbols[name]
|
||||
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||
|
||||
|
|
|
|||
|
|
@ -16,11 +16,24 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
|
|||
|
||||
|
||||
class Target():
|
||||
"""Target device for compiling for and running on copapy code.
|
||||
"""
|
||||
def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None:
|
||||
"""Initialize Target object
|
||||
|
||||
Arguments:
|
||||
arch: Target architecture
|
||||
optimization: Optimization level
|
||||
"""
|
||||
self.sdb = stencil_db_from_package(arch, optimization)
|
||||
self._variables: dict[Net, tuple[int, int, str]] = {}
|
||||
|
||||
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
|
||||
"""Compiles the code to compute the given variables.
|
||||
|
||||
Arguments:
|
||||
variables: Variables to compute
|
||||
"""
|
||||
nodes: list[Node] = []
|
||||
for s in variables:
|
||||
if isinstance(s, Iterable):
|
||||
|
|
@ -35,7 +48,8 @@ class Target():
|
|||
assert coparun(dw.get_data()) > 0
|
||||
|
||||
def run(self) -> None:
|
||||
# set entry point and run code
|
||||
"""Runs the compiled code on the target device.
|
||||
"""
|
||||
dw = binw.data_writer(self.sdb.byteorder)
|
||||
dw.write_com(binw.Command.RUN_PROG)
|
||||
dw.write_com(binw.Command.END_COM)
|
||||
|
|
@ -58,8 +72,16 @@ class Target():
|
|||
...
|
||||
|
||||
def read_value(self, net: NumLike) -> float | int | bool:
|
||||
"""Reads the value of a variable.
|
||||
|
||||
Arguments:
|
||||
net: Variable to read
|
||||
|
||||
Returns:
|
||||
Value of the variable
|
||||
"""
|
||||
assert isinstance(net, Net), "Variable must be a copapy variable object"
|
||||
assert net in self._variables, f"Variable {net} not found"
|
||||
assert net in self._variables, f"Variable {net} not found. It might not have been compiled for the target."
|
||||
addr, lengths, var_type = self._variables[net]
|
||||
print('...', self._variables[net], net.dtype)
|
||||
assert lengths > 0
|
||||
|
|
@ -87,6 +109,7 @@ class Target():
|
|||
raise ValueError(f"Unsupported variable type: {var_type}")
|
||||
|
||||
def read_value_remote(self, net: Net) -> None:
|
||||
"""Reads the raw data of a variable by the runner."""
|
||||
dw = binw.data_writer(self.sdb.byteorder)
|
||||
add_read_command(dw, self._variables, net)
|
||||
assert coparun(dw.get_data()) > 0
|
||||
|
|
|
|||
|
|
@ -11,11 +11,16 @@ epsilon = 1e-20
|
|||
|
||||
|
||||
class vector(Generic[T]):
|
||||
"""Type-safe vector supporting numeric promotion between vector types."""
|
||||
"""Mathematical vector class supporting basic operations and interactions with variables.
|
||||
"""
|
||||
def __init__(self, values: Iterable[T | variable[T]]):
|
||||
"""Create a vector with given values and variables.
|
||||
|
||||
Args:
|
||||
values: iterable of constant values and variables
|
||||
"""
|
||||
self.values: tuple[variable[T] | T, ...] = tuple(values)
|
||||
|
||||
# ---------- Basic dunder methods ----------
|
||||
def __repr__(self) -> str:
|
||||
return f"vector({self.values})"
|
||||
|
||||
|
|
@ -143,14 +148,17 @@ class vector(Generic[T]):
|
|||
@overload
|
||||
def sum(self: 'vector[float]') -> float | variable[float]: ...
|
||||
def sum(self) -> Any:
|
||||
"""Sum of all vector elements."""
|
||||
return sum(a for a in self.values if isinstance(a, variable)) +\
|
||||
sum(a for a in self.values if not isinstance(a, variable))
|
||||
|
||||
def magnitude(self) -> 'float | variable[float]':
|
||||
"""Magnitude (length) of the vector."""
|
||||
s = sum(a * a for a in self.values)
|
||||
return sqrt(s) if isinstance(s, variable) else sqrt(s)
|
||||
|
||||
def normalize(self) -> 'vector[float]':
|
||||
"""Returns a normalized (unit length) version of the vector."""
|
||||
mag = self.magnitude() + epsilon
|
||||
return self / mag
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue