mirror of https://github.com/Nonannet/copapy.git
Docstrings added and updated
This commit is contained in:
parent
98418e5e17
commit
7584b316fc
|
|
@ -32,6 +32,14 @@ def transl_type(t: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
|
"""A Node represents an computational operation like ADD or other operations
|
||||||
|
like read and write from or to the memory or IOs. In the computation graph
|
||||||
|
Nodes are connected via Nets.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
args (list[Net]): The input Nets to this Node.
|
||||||
|
name (str): The name of the operation this Node represents.
|
||||||
|
"""
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.args: list[Net] = []
|
self.args: list[Net] = []
|
||||||
self.name: str = ''
|
self.name: str = ''
|
||||||
|
|
@ -40,11 +48,14 @@ class Node:
|
||||||
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, CPConstant) else '')})"
|
||||||
|
|
||||||
|
|
||||||
class Device():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Net:
|
class Net:
|
||||||
|
"""A Net represents a variable in the computation graph - or more generally it
|
||||||
|
connects Nodes together.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dtype (str): The data type of this Net.
|
||||||
|
source (Node): The Node that produces the value for this Net.
|
||||||
|
"""
|
||||||
def __init__(self, dtype: str, source: Node):
|
def __init__(self, dtype: str, source: Node):
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.source = source
|
self.source = source
|
||||||
|
|
@ -58,7 +69,19 @@ class Net:
|
||||||
|
|
||||||
|
|
||||||
class variable(Generic[TNum], Net):
|
class variable(Generic[TNum], Net):
|
||||||
|
"""A "variable" represents a typed variable. It supports arithmetic and
|
||||||
|
comparison operations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dtype (str): Data type of this variable.
|
||||||
|
"""
|
||||||
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
||||||
|
"""Instance a variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A numeric value or Node object.
|
||||||
|
dtype: Data type of this variable. Required if source is a Node.
|
||||||
|
"""
|
||||||
if isinstance(source, Node):
|
if isinstance(source, Node):
|
||||||
self.source = source
|
self.source = source
|
||||||
assert dtype, 'For source type Node a dtype argument is required.'
|
assert dtype, 'For source type Node a dtype argument is required.'
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,8 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
|
||||||
|
|
||||||
|
|
||||||
def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
||||||
|
"""Get all unique nets from the provided inputs
|
||||||
|
"""
|
||||||
nets: set[Net] = set()
|
nets: set[Net] = set()
|
||||||
|
|
||||||
for input in inputs:
|
for input in inputs:
|
||||||
|
|
@ -156,6 +158,16 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
||||||
|
|
||||||
|
|
||||||
def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]:
|
def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[Net, int, int]], int]:
|
||||||
|
"""Get memory layout for the provided variables
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
variable_list: Variables to layout
|
||||||
|
sdb: Stencil database for size lookup
|
||||||
|
offset: Starting offset for layout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of list of (variable, start_offset, length) and total length"""
|
||||||
|
|
||||||
object_list: list[tuple[Net, int, int]] = []
|
object_list: list[tuple[Net, int, int]] = []
|
||||||
|
|
||||||
for variable in variable_list:
|
for variable in variable_list:
|
||||||
|
|
@ -171,17 +183,37 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
|
||||||
|
|
||||||
|
|
||||||
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
|
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
|
||||||
|
"""Get memory layout for the provided sections
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
section_indexes: Sections (by index) to layout
|
||||||
|
sdb: Stencil database for size lookup
|
||||||
|
offset: Starting offset for layout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of list of (section_id, start_offset, length) and total length
|
||||||
|
"""
|
||||||
section_list: list[tuple[int, int, int]] = []
|
section_list: list[tuple[int, int, int]] = []
|
||||||
|
|
||||||
for id in section_indexes:
|
for index in section_indexes:
|
||||||
lengths = sdb.get_section_size(id)
|
lengths = sdb.get_section_size(index)
|
||||||
section_list.append((id, offset, lengths))
|
section_list.append((index, offset, lengths))
|
||||||
offset += (lengths + 3) // 4 * 4
|
offset += (lengths + 3) // 4 * 4
|
||||||
|
|
||||||
return section_list, offset
|
return section_list, offset
|
||||||
|
|
||||||
|
|
||||||
def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]:
|
def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[str, int, int]], int]:
|
||||||
|
"""Get memory layout for the provided auxiliary functions
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
function_names: Function names to layout
|
||||||
|
sdb: Stencil database for size lookup
|
||||||
|
offset: Starting offset for layout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of list of (function_name, start_offset, length) and total length
|
||||||
|
"""
|
||||||
function_list: list[tuple[str, int, int]] = []
|
function_list: list[tuple[str, int, int]] = []
|
||||||
|
|
||||||
for name in function_names:
|
for name in function_names:
|
||||||
|
|
@ -193,6 +225,15 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
|
||||||
|
|
||||||
|
|
||||||
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
|
||||||
|
"""Compiles a DAG identified by provided end nodes to binary code
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
node_list: List of end nodes of the DAG to compile
|
||||||
|
sdb: Stencil database
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of data writer with binary code and variable layout dictionary
|
||||||
|
"""
|
||||||
variables: dict[Net, tuple[int, int, str]] = {}
|
variables: dict[Net, tuple[int, int, str]] = {}
|
||||||
data_list: list[bytes] = []
|
data_list: list[bytes] = []
|
||||||
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,14 @@ def sqrt(x: float | int) -> float: ...
|
||||||
@overload
|
@overload
|
||||||
def sqrt(x: variable[Any]) -> variable[float]: ...
|
def sqrt(x: variable[Any]) -> variable[float]: ...
|
||||||
def sqrt(x: NumLike) -> variable[float] | float:
|
def sqrt(x: NumLike) -> variable[float] | float:
|
||||||
"""Square root function"""
|
"""Square root function
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: Input value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Square root of x
|
||||||
|
"""
|
||||||
if isinstance(x, variable):
|
if isinstance(x, variable):
|
||||||
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
|
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
|
||||||
return float(x ** 0.5)
|
return float(x ** 0.5)
|
||||||
|
|
@ -33,7 +40,14 @@ def get_42() -> variable[float]:
|
||||||
|
|
||||||
|
|
||||||
def abs(x: T) -> T:
|
def abs(x: T) -> T:
|
||||||
"""Absolute value function"""
|
"""Absolute value function
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: Input value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Absolute value of x
|
||||||
|
"""
|
||||||
ret = (x < 0) * -x + (x >= 0) * x
|
ret = (x < 0) * -x + (x >= 0) * x
|
||||||
return ret # pyright: ignore[reportReturnType]
|
return ret # pyright: ignore[reportReturnType]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -170,6 +170,13 @@ class stencil_database():
|
||||||
return strip_function(self.elf.symbols[name])
|
return strip_function(self.elf.symbols[name])
|
||||||
|
|
||||||
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
|
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
|
||||||
|
"""Return recursively all functions called by stencils or by other functions
|
||||||
|
Args:
|
||||||
|
names: function or stencil names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
set of all sub function names
|
||||||
|
"""
|
||||||
name_set: set[str] = set()
|
name_set: set[str] = set()
|
||||||
for name in names:
|
for name in names:
|
||||||
if name not in name_set:
|
if name not in name_set:
|
||||||
|
|
@ -182,16 +189,27 @@ class stencil_database():
|
||||||
return name_set
|
return name_set
|
||||||
|
|
||||||
def get_symbol_size(self, name: str) -> int:
|
def get_symbol_size(self, name: str) -> int:
|
||||||
|
"""Returns the size of a specified symbol name."""
|
||||||
return self.elf.symbols[name].fields['st_size']
|
return self.elf.symbols[name].fields['st_size']
|
||||||
|
|
||||||
def get_section_size(self, id: int) -> int:
|
def get_section_size(self, index: int) -> int:
|
||||||
return self.elf.sections[id].fields['sh_size']
|
"""Returns the size of a section specified by index."""
|
||||||
|
return self.elf.sections[index].fields['sh_size']
|
||||||
|
|
||||||
def get_section_data(self, id: int) -> bytes:
|
def get_section_data(self, index: int) -> bytes:
|
||||||
return self.elf.sections[id].data
|
"""Returns the data of a section specified by index."""
|
||||||
|
return self.elf.sections[index].data
|
||||||
|
|
||||||
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
|
||||||
"""Returns machine code for a specified function name."""
|
"""Returns machine code for a specified function name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: function name
|
||||||
|
part: part of the function to return ('full', 'start', 'end')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Machine code bytes of the specified part of the function
|
||||||
|
"""
|
||||||
func = self.elf.symbols[name]
|
func = self.elf.symbols[name]
|
||||||
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
assert func.info == 'STT_FUNC', f"{name} is not a function"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,24 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
|
||||||
|
|
||||||
|
|
||||||
class Target():
|
class Target():
|
||||||
|
"""Target device for compiling for and running on copapy code.
|
||||||
|
"""
|
||||||
def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None:
|
def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None:
|
||||||
|
"""Initialize Target object
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
arch: Target architecture
|
||||||
|
optimization: Optimization level
|
||||||
|
"""
|
||||||
self.sdb = stencil_db_from_package(arch, optimization)
|
self.sdb = stencil_db_from_package(arch, optimization)
|
||||||
self._variables: dict[Net, tuple[int, int, str]] = {}
|
self._variables: dict[Net, tuple[int, int, str]] = {}
|
||||||
|
|
||||||
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
|
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
|
||||||
|
"""Compiles the code to compute the given variables.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
variables: Variables to compute
|
||||||
|
"""
|
||||||
nodes: list[Node] = []
|
nodes: list[Node] = []
|
||||||
for s in variables:
|
for s in variables:
|
||||||
if isinstance(s, Iterable):
|
if isinstance(s, Iterable):
|
||||||
|
|
@ -35,7 +48,8 @@ class Target():
|
||||||
assert coparun(dw.get_data()) > 0
|
assert coparun(dw.get_data()) > 0
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
# set entry point and run code
|
"""Runs the compiled code on the target device.
|
||||||
|
"""
|
||||||
dw = binw.data_writer(self.sdb.byteorder)
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
dw.write_com(binw.Command.RUN_PROG)
|
dw.write_com(binw.Command.RUN_PROG)
|
||||||
dw.write_com(binw.Command.END_COM)
|
dw.write_com(binw.Command.END_COM)
|
||||||
|
|
@ -58,8 +72,16 @@ class Target():
|
||||||
...
|
...
|
||||||
|
|
||||||
def read_value(self, net: NumLike) -> float | int | bool:
|
def read_value(self, net: NumLike) -> float | int | bool:
|
||||||
|
"""Reads the value of a variable.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
net: Variable to read
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Value of the variable
|
||||||
|
"""
|
||||||
assert isinstance(net, Net), "Variable must be a copapy variable object"
|
assert isinstance(net, Net), "Variable must be a copapy variable object"
|
||||||
assert net in self._variables, f"Variable {net} not found"
|
assert net in self._variables, f"Variable {net} not found. It might not have been compiled for the target."
|
||||||
addr, lengths, var_type = self._variables[net]
|
addr, lengths, var_type = self._variables[net]
|
||||||
print('...', self._variables[net], net.dtype)
|
print('...', self._variables[net], net.dtype)
|
||||||
assert lengths > 0
|
assert lengths > 0
|
||||||
|
|
@ -87,6 +109,7 @@ class Target():
|
||||||
raise ValueError(f"Unsupported variable type: {var_type}")
|
raise ValueError(f"Unsupported variable type: {var_type}")
|
||||||
|
|
||||||
def read_value_remote(self, net: Net) -> None:
|
def read_value_remote(self, net: Net) -> None:
|
||||||
|
"""Reads the raw data of a variable by the runner."""
|
||||||
dw = binw.data_writer(self.sdb.byteorder)
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
add_read_command(dw, self._variables, net)
|
add_read_command(dw, self._variables, net)
|
||||||
assert coparun(dw.get_data()) > 0
|
assert coparun(dw.get_data()) > 0
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,16 @@ epsilon = 1e-20
|
||||||
|
|
||||||
|
|
||||||
class vector(Generic[T]):
|
class vector(Generic[T]):
|
||||||
"""Type-safe vector supporting numeric promotion between vector types."""
|
"""Mathematical vector class supporting basic operations and interactions with variables.
|
||||||
|
"""
|
||||||
def __init__(self, values: Iterable[T | variable[T]]):
|
def __init__(self, values: Iterable[T | variable[T]]):
|
||||||
|
"""Create a vector with given values and variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: iterable of constant values and variables
|
||||||
|
"""
|
||||||
self.values: tuple[variable[T] | T, ...] = tuple(values)
|
self.values: tuple[variable[T] | T, ...] = tuple(values)
|
||||||
|
|
||||||
# ---------- Basic dunder methods ----------
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"vector({self.values})"
|
return f"vector({self.values})"
|
||||||
|
|
||||||
|
|
@ -143,14 +148,17 @@ class vector(Generic[T]):
|
||||||
@overload
|
@overload
|
||||||
def sum(self: 'vector[float]') -> float | variable[float]: ...
|
def sum(self: 'vector[float]') -> float | variable[float]: ...
|
||||||
def sum(self) -> Any:
|
def sum(self) -> Any:
|
||||||
|
"""Sum of all vector elements."""
|
||||||
return sum(a for a in self.values if isinstance(a, variable)) +\
|
return sum(a for a in self.values if isinstance(a, variable)) +\
|
||||||
sum(a for a in self.values if not isinstance(a, variable))
|
sum(a for a in self.values if not isinstance(a, variable))
|
||||||
|
|
||||||
def magnitude(self) -> 'float | variable[float]':
|
def magnitude(self) -> 'float | variable[float]':
|
||||||
|
"""Magnitude (length) of the vector."""
|
||||||
s = sum(a * a for a in self.values)
|
s = sum(a * a for a in self.values)
|
||||||
return sqrt(s) if isinstance(s, variable) else sqrt(s)
|
return sqrt(s) if isinstance(s, variable) else sqrt(s)
|
||||||
|
|
||||||
def normalize(self) -> 'vector[float]':
|
def normalize(self) -> 'vector[float]':
|
||||||
|
"""Returns a normalized (unit length) version of the vector."""
|
||||||
mag = self.magnitude() + epsilon
|
mag = self.magnitude() + epsilon
|
||||||
return self / mag
|
return self / mag
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue