diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index e5f0117..e2dcd01 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -1,802 +1,16 @@ -import pkgutil -from typing import Generator, Iterable, Any, TypeVar, overload, TypeAlias -from typing import cast - -from . import binwrite as binw -from .stencil_db import stencil_database -from collections import defaultdict, deque -from coparun_module import coparun, read_data_mem -import struct -import platform - -NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool' -NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net' -NetAndNum: TypeAlias = 'Net | int | float' - -unifloat: TypeAlias = 'cpfloat | float' -uniint: TypeAlias = 'cpint | int' -unibool: TypeAlias = 'cpbool | bool' - -TNumber = TypeVar("TNumber", bound='CPNumber') -T = TypeVar("T") - - -def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: - return [name for name, value in scope.items() if value is var] - - -def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> stencil_database: - if arch == 'native': - arch = platform.machine() - stencil_data = pkgutil.get_data(__name__, f"obj/stencils_{arch}_{optimization}.o") - assert stencil_data, f"stencils_{arch}_{optimization} not found" - return stencil_database(stencil_data) - - -generic_sdb = stencil_db_from_package() - - -def transl_type(t: str) -> str: - return {'bool': 'int'}.get(t, t) - - -class Node: - def __init__(self) -> None: - self.args: list[Net] = [] - self.name: str = '' - - def __repr__(self) -> str: - return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, InitVar) else '')})" - - -class Device(): - pass - - -class Net: - def __init__(self, dtype: str, source: Node): - self.dtype = dtype - self.source = source - - def __repr__(self) -> str: - names = get_var_name(self) - return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" - - def __hash__(self) -> int: - return id(self) - - -class CPNumber(Net): - def __init__(self, dtype: str, source: Node): - self.dtype = dtype - self.source = source - - @overload - def __mul__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __mul__(self, other: unifloat) -> 'cpfloat': - ... - - def __mul__(self, other: NumLike) -> 'CPNumber': - return _add_op('mul', [self, other], True) - - @overload - def __rmul__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __rmul__(self, other: unifloat) -> 'cpfloat': - ... - - def __rmul__(self, other: NumLike) -> 'CPNumber': - return _add_op('mul', [self, other], True) - - @overload - def __add__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __add__(self, other: unifloat) -> 'cpfloat': - ... - - def __add__(self, other: NumLike) -> 'CPNumber': - return _add_op('add', [self, other], True) - - @overload - def __radd__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __radd__(self, other: unifloat) -> 'cpfloat': - ... - - def __radd__(self, other: NumLike) -> 'CPNumber': - return _add_op('add', [self, other], True) - - @overload - def __sub__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __sub__(self, other: unifloat) -> 'cpfloat': - ... - - def __sub__(self, other: NumLike) -> 'CPNumber': - return _add_op('sub', [self, other]) - - @overload - def __rsub__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __rsub__(self, other: unifloat) -> 'cpfloat': - ... - - def __rsub__(self, other: NumLike) -> 'CPNumber': - return _add_op('sub', [other, self]) - - def __truediv__(self, other: NumLike) -> 'cpfloat': - ret = _add_op('div', [self, other]) - assert isinstance(ret, cpfloat) - return ret - - def __rtruediv__(self, other: NumLike) -> 'cpfloat': - ret = _add_op('div', [other, self]) - assert isinstance(ret, cpfloat) - return ret - - @overload - def __floordiv__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __floordiv__(self, other: unifloat) -> 'cpfloat': - ... - - def __floordiv__(self, other: NumLike) -> 'CPNumber': - return _add_op('floordiv', [self, other]) - - @overload - def __rfloordiv__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __rfloordiv__(self, other: unifloat) -> 'cpfloat': - ... - - def __rfloordiv__(self, other: NumLike) -> 'CPNumber': - return _add_op('floordiv', [other, self]) - - def __neg__(self: TNumber) -> TNumber: - return cast(TNumber, _add_op('sub', [cpvalue(0), self])) - - def __gt__(self, other: NumLike) -> 'cpbool': - ret = _add_op('gt', [self, other]) - return cpbool(ret.source) - - def __lt__(self, other: NumLike) -> 'cpbool': - ret = _add_op('gt', [other, self]) - return cpbool(ret.source) - - def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore - ret = _add_op('eq', [self, other], True) - return cpbool(ret.source) - - def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore - ret = _add_op('ne', [self, other], True) - return cpbool(ret.source) - - @overload - def __mod__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __mod__(self, other: unifloat) -> 'cpfloat': - ... - - def __mod__(self, other: NumLike) -> 'CPNumber': - return _add_op('mod', [self, other]) - - @overload - def __rmod__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __rmod__(self, other: unifloat) -> 'cpfloat': - ... - - def __rmod__(self, other: NumLike) -> 'CPNumber': - return _add_op('mod', [other, self]) - - @overload - def __pow__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __pow__(self, other: unifloat) -> 'cpfloat': - ... - - def __pow__(self, other: NumLike) -> 'CPNumber': - return _add_op('pow', [other, self]) - - @overload - def __rpow__(self: TNumber, other: uniint) -> TNumber: - ... - - @overload - def __rpow__(self, other: unifloat) -> 'cpfloat': - ... - - def __rpow__(self, other: NumLike) -> 'CPNumber': - return _add_op('rpow', [self, other]) - - def __hash__(self) -> int: - return super().__hash__() - - -class cpint(CPNumber): - def __init__(self, source: int | Node): - if isinstance(source, Node): - self.source = source - else: - self.source = InitVar(int(source)) - self.dtype = 'int' - - def __lshift__(self, other: uniint) -> 'cpint': - ret = _add_op('lshift', [self, other]) - assert isinstance(ret, cpint) - return ret - - def __rlshift__(self, other: uniint) -> 'cpint': - ret = _add_op('lshift', [other, self]) - assert isinstance(ret, cpint) - return ret - - def __rshift__(self, other: uniint) -> 'cpint': - ret = _add_op('rshift', [self, other]) - assert isinstance(ret, cpint) - return ret - - def __rrshift__(self, other: uniint) -> 'cpint': - ret = _add_op('rshift', [other, self]) - assert isinstance(ret, cpint) - return ret - - def __and__(self, other: uniint) -> 'cpint': - ret = _add_op('bwand', [self, other], True) - assert isinstance(ret, cpint) - return ret - - def __rand__(self, other: uniint) -> 'cpint': - ret = _add_op('rwand', [other, self], True) - assert isinstance(ret, cpint) - return ret - - def __or__(self, other: uniint) -> 'cpint': - ret = _add_op('bwor', [self, other], True) - assert isinstance(ret, cpint) - return ret - - def __ror__(self, other: uniint) -> 'cpint': - ret = _add_op('bwor', [other, self], True) - assert isinstance(ret, cpint) - return ret - - def __xor__(self, other: uniint) -> 'cpint': - ret = _add_op('bwxor', [self, other], True) - assert isinstance(ret, cpint) - return ret - - def __rxor__(self, other: uniint) -> 'cpint': - ret = _add_op('bwxor', [other, self], True) - assert isinstance(ret, cpint) - return ret - - -class cpfloat(CPNumber): - def __init__(self, source: float | Node | CPNumber): - if isinstance(source, Node): - self.source = source - elif isinstance(source, CPNumber): - self.source = _add_op('cast_float', [source]).source - else: - self.source = InitVar(float(source)) - self.dtype = 'float' - - -class cpbool(cpint): - def __init__(self, source: bool | Node): - if isinstance(source, Node): - self.source = source - else: - self.source = InitVar(bool(source)) - self.dtype = 'bool' - - -class cpvector: - def __init__(self, *value: NumLike): - self.value = value - - def __add__(self, other: 'cpvector') -> 'cpvector': - assert len(self.value) == len(other.value) - tup = (a + b for a, b in zip(self.value, other.value)) - return cpvector(*(v for v in tup if isinstance(v, CPNumber))) - - -class InitVar(Node): - def __init__(self, value: int | float): - self.dtype, self.value = _get_data_and_dtype(value) - self.name = 'const_' + self.dtype - self.args = [] - - -class Write(Node): - def __init__(self, input: NetAndNum): - if isinstance(input, Net): - net = input - else: - node = InitVar(input) - net = Net(node.dtype, node) - - self.name = 'write_' + transl_type(net.dtype) - self.args = [net] - - -class Op(Node): - def __init__(self, typed_op_name: str, args: list[Net]): - assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' - self.name: str = typed_op_name - self.args: list[Net] = args - - -def net_from_value(value: Any) -> Net: - vi = InitVar(value) - return Net(vi.dtype, vi) - - -@overload -def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload] - ... - - -@overload -def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint: - ... - - -@overload -def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat: - ... - - -@overload -def iif(expression: NumLike, true_result: T, false_result: T) -> T: - ... - - -def iif(expression: Any, true_result: Any, false_result: Any) -> Any: - # TODO: check that input types are matching - alowed_type = cpint | cpfloat | cpbool | int | float | bool - assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported" - if isinstance(expression, CPNumber): - return (expression != 0) * true_result + (expression == 0) * false_result - else: - return true_result if expression else false_result - - -def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber: - arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] - - if commutative: - arg_nets = sorted(arg_nets, key=lambda a: a.dtype) - - typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) - - if typed_op not in generic_sdb.stencil_definitions: - raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}") - - result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] - - if result_type == 'int': - return cpint(Op(typed_op, arg_nets)) - else: - return cpfloat(Op(typed_op, arg_nets)) - - -@overload -def cpvalue(value: bool) -> cpbool: # pyright: ignore[reportOverlappingOverload] - ... - - -@overload -def cpvalue(value: int) -> cpint: - ... - - -@overload -def cpvalue(value: float) -> cpfloat: - ... - - -def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat: - vi = InitVar(value) - - if isinstance(value, bool): - return cpbool(vi) - elif isinstance(value, float): - return cpfloat(vi) - else: - return cpint(vi) - - -def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: - if isinstance(value, bool): - return ('bool', int(value)) - elif isinstance(value, int): - return ('int', int(value)) - elif isinstance(value, float): - return ('float', float(value)) - else: - raise ValueError(f'Non supported data type: {type(value).__name__}') - - -def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: - """Perform a stable topological sort on a directed acyclic graph (DAG). - Arguments: - edges: Iterable of (u, v) pairs meaning u -> v - - Returns: - List of nodes in topologically sorted order. - """ - - # Track adjacency and indegrees - adj: defaultdict[Node, list[Node]] = defaultdict(list) - indeg: defaultdict[Node, int] = defaultdict(int) - order: dict[Node, int] = {} # first-appearance order of each node - - # Build graph and order map - pos = 0 - for u, v in edges: - if u not in order: - order[u] = pos - pos += 1 - if v not in order: - order[v] = pos - pos += 1 - adj[u].append(v) - indeg[v] += 1 - indeg.setdefault(u, 0) - - # Initialize queue with nodes of indegree 0, sorted by first appearance - queue = deque(sorted([n for n in indeg if indeg[n] == 0], key=lambda x: order[x])) - result: list[Node] = [] - - while queue: - node = queue.popleft() - result.append(node) - - for nei in adj[node]: - indeg[nei] -= 1 - if indeg[nei] == 0: - queue.append(nei) - - # Maintain stability: sort queue by appearance order - queue = deque(sorted(queue, key=lambda x: order[x])) - - # Check if graph had a cycle (not all nodes output) - if len(result) != len(indeg): - raise ValueError("Graph contains a cycle — topological sort not possible") - - return result - - -def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]: - """Get all edges in the DAG by traversing from the given nodes - - Arguments: - nodes: Iterable of nodes to start the traversal from - - Yields: - Tuples of (source_node, target_node) representing edges in the DAG - """ - for node in nodes: - yield from get_all_dag_edges(net.source for net in node.args) - yield from ((net.source, node) for net in node.args) - - -def get_const_nets(nodes: list[Node]) -> list[Net]: - """Get all nets with a constant nodes value - - Returns: - List of nets whose source node is a Const - """ - net_lookup = {net.source: net for node in nodes for net in node.args} - return [net_lookup[node] for node in nodes if isinstance(node, InitVar)] - - -def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: - """Add read node before each op where arguments are not already positioned - correctly in the registers - - Arguments: - node_list: List of nodes in the order of execution - - Returns: - Yields tuples of a net and a node. The net is the result net - for the node. If the node has no result net None is returned in the tuple. - """ - - registers: list[None | Net] = [None] * 2 - - # Generate result net lookup table - net_lookup = {net.source: net for node in node_list for net in node.args} - - for node in node_list: - if not isinstance(node, InitVar): - for i, net in enumerate(node.args): - if id(net) != id(registers[i]): - #if net in registers: - # print('x swap registers') - type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] - new_node = Op(f"read_{transl_type(net.dtype)}_reg{i}_" + '_'.join(type_list), []) - yield net, new_node - registers[i] = net - - if node in net_lookup: - yield net_lookup[node], node - registers[0] = net_lookup[node] - else: - yield None, node - - -def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list[Net]) -> Generator[tuple[Net | None, Node], None, None]: - """Add write operation for each new defined net if a read operation is later followed - - Returns: - Yields tuples of a net and a node. The associated net is provided for read and write nodes. - Otherwise None is returned in the tuple. - """ - - # Initialize set of nets with constants - stored_nets = set(const_nets) - - #assert all(node.name.startswith('read_') for net, node in net_node_list if net) - read_back_nets = { - net for net, node in net_node_list - if net and node.name.startswith('read_')} - - for net, node in net_node_list: - if isinstance(node, Write): - yield node.args[0], node - elif node.name.startswith('read_'): - yield net, node - else: - yield None, node - - if net and net in read_back_nets and net not in stored_nets: - yield net, Write(net) - stored_nets.add(net) - - -def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: - nets: set[Net] = set() - - for input in inputs: - for el in input: - for net in el: - if isinstance(net, Net): - nets.add(net) - - return list(nets) - - -def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]: - offset: int = 0 - object_list: list[tuple[Net, int, int]] = [] - - for variable in variable_list: - lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype)) - object_list.append((variable, offset, lengths)) - offset += (lengths + 3) // 4 * 4 - - return object_list, offset - - -def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]: - offset: int = 0 - function_list: list[tuple[str, int, int]] = [] - - for name in function_names: - lengths = sdb.get_symbol_size(name) - function_list.append((name, offset, lengths)) - offset += (lengths + 3) // 4 * 4 - - return function_list, offset - - -def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: - variables: dict[Net, tuple[int, int, str]] = dict() - - ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) - const_net_list = get_const_nets(ordered_ops) - output_ops = list(add_read_ops(ordered_ops)) - extended_output_ops = list(add_write_ops(output_ops, const_net_list)) - - dw = binw.data_writer(sdb.byteorder) - - # Deallocate old allocated memory (if existing) - dw.write_com(binw.Command.FREE_MEMORY) - - # Get all nets/variables associated with heap memory - variable_list = get_nets([[const_net_list]], extended_output_ops) - - # Write data - variable_mem_layout, data_section_lengths = get_variable_mem_layout(variable_list, sdb) - dw.write_com(binw.Command.ALLOCATE_DATA) - dw.write_int(data_section_lengths) - - for net, out_offs, lengths in variable_mem_layout: - variables[net] = (out_offs, lengths, net.dtype) - if isinstance(net.source, InitVar): - dw.write_com(binw.Command.COPY_DATA) - dw.write_int(out_offs) - dw.write_int(lengths) - dw.write_value(net.source.value, lengths) - # print(f'+ {net.dtype} {net.source.value}') - - # prep auxiliary_functions - aux_function_names = sdb.get_sub_functions(node.name for _, node in extended_output_ops) - aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb) - aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout} - - # Prepare program code and relocations - object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} - data_list: list[bytes] = [] - patch_list: list[tuple[int, int, int, binw.Command]] = [] - offset = aux_function_lengths # offset in generated code chunk - - # assemble stencils to main program - data = sdb.get_function_code('entry_function_shell', 'start') - data_list.append(data) - offset += len(data) - - for associated_net, node in extended_output_ops: - assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" - data = sdb.get_stencil_code(node.name) - data_list.append(data) - # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) - - for patch in sdb.get_patch_positions(node.name): - if patch.target_symbol_info == 'STT_OBJECT': - assert associated_net, f"Relocation found but no net defined for operation {node.name}" - addr = object_addr_lookup[associated_net] - patch_value = addr + patch.addend - (offset + patch.addr) - patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) - elif patch.target_symbol_info == 'STT_FUNC': - addr = aux_func_addr_lookup[patch.target_symbol_name] - patch_value = addr + patch.addend - (offset + patch.addr) - patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) - else: - raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") - - offset += len(data) - - data = sdb.get_function_code('entry_function_shell', 'end') - data_list.append(data) - offset += len(data) - - # allocate program data - dw.write_com(binw.Command.ALLOCATE_CODE) - dw.write_int(offset) - - # write aux code - for name, out_offs, lengths in aux_function_mem_layout: - dw.write_com(binw.Command.COPY_CODE) - dw.write_int(out_offs) - dw.write_int(lengths) - dw.write_bytes(sdb.get_function_code(name)) - - # write program code - dw.write_com(binw.Command.COPY_CODE) - dw.write_int(aux_function_lengths) - dw.write_int(offset - aux_function_lengths) - dw.write_bytes(b''.join(data_list)) - - # write relocations - for patch_type, patch_addr, addr, patch_command in patch_list: - dw.write_com(patch_command) - dw.write_int(patch_addr) - dw.write_int(patch_type) - dw.write_int(addr, signed=True) - - dw.write_com(binw.Command.ENTRY_POINT) - dw.write_int(aux_function_lengths) - - return dw, variables - - -class Target(): - - def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None: - self.sdb = stencil_db_from_package(arch, optimization) - self._variables: dict[Net, tuple[int, int, str]] = dict() - - def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None: - nodes: list[Node] = [] - for s in variables: - if isinstance(s, list): - for net in s: - assert isinstance(net, Net), f"The folowing element is not a Net: {net}" - nodes.append(Write(net)) - else: - nodes.append(Write(s)) - - dw, self._variables = compile_to_instruction_list(nodes, self.sdb) - dw.write_com(binw.Command.END_COM) - assert coparun(dw.get_data()) > 0 - - def run(self) -> None: - # set entry point and run code - dw = binw.data_writer(self.sdb.byteorder) - dw.write_com(binw.Command.RUN_PROG) - dw.write_com(binw.Command.END_COM) - assert coparun(dw.get_data()) > 0 - - @overload - def read_value(self, net: cpbool) -> bool: - ... - - @overload - def read_value(self, net: cpfloat) -> float: - ... - - @overload - def read_value(self, net: cpint) -> int: - ... - - @overload - def read_value(self, net: NumLike) -> float | int | bool: - ... - - def read_value(self, net: NumLike) -> float | int | bool: - assert isinstance(net, Net), "Variable must be a copapy variable object" - assert net in self._variables, f"Variable {net} not found" - addr, lengths, var_type = self._variables[net] - assert lengths > 0 - data = read_data_mem(addr, lengths) - assert data is not None and len(data) == lengths, f"Failed to read variable {net}" - en = {'little': '<', 'big': '>'}[self.sdb.byteorder] - if var_type == 'float': - if lengths == 4: - value = struct.unpack(en + 'f', data)[0] - elif lengths == 8: - value = struct.unpack(en + 'd', data)[0] - else: - raise ValueError(f"Unsupported float length: {lengths} bytes") - assert isinstance(value, float) - return value - elif var_type == 'int': - assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" - value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) - return value - elif var_type == 'bool': - assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" - value = bool.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) - return value - else: - raise ValueError(f"Unsupported variable type: {var_type}") - - def read_value_remote(self, net: Net) -> None: - dw = binw.data_writer(self.sdb.byteorder) - add_read_command(dw, self._variables, net) - assert coparun(dw.get_data()) > 0 - - -def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None: - assert net in variables, f"Variable {net} not found in data writer variables" - addr, lengths, _ = variables[net] - dw.write_com(binw.Command.READ_DATA) - dw.write_int(addr) - dw.write_int(lengths) +from ._target import Target +from ._basic_types import NumLike, cpbool, cpfloat, cpint, \ + CPNumber, cpvalue, cpvector, generic_sdb, iif + +__all__ = [ + "Target", + "NumLike", + "cpbool", + "cpfloat", + "cpint", + "CPNumber", + "cpvalue", + "cpvector", + "generic_sdb", + "iif", +] \ No newline at end of file diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py new file mode 100644 index 0000000..edf630b --- /dev/null +++ b/src/copapy/_basic_types.py @@ -0,0 +1,438 @@ +import pkgutil +from typing import Any, TypeVar, overload, TypeAlias +from ._stencils import stencil_database +import platform + +NumLike: TypeAlias = 'cpint | cpfloat | cpbool | int | float| bool' +NumLikeAndNet: TypeAlias = 'cpint | cpfloat | cpbool | int | float | bool | Net' +NetAndNum: TypeAlias = 'Net | int | float' + +unifloat: TypeAlias = 'cpfloat | float' +uniint: TypeAlias = 'cpint | int' +unibool: TypeAlias = 'cpbool | bool' + +TNumber = TypeVar("TNumber", bound='CPNumber') +T = TypeVar("T") + + +def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: + return [name for name, value in scope.items() if value is var] + + +def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> stencil_database: + if arch == 'native': + arch = platform.machine() + stencil_data = pkgutil.get_data(__name__, f"obj/stencils_{arch}_{optimization}.o") + assert stencil_data, f"stencils_{arch}_{optimization} not found" + return stencil_database(stencil_data) + + +generic_sdb = stencil_db_from_package() + + +def transl_type(t: str) -> str: + return {'bool': 'int'}.get(t, t) + + +class Node: + def __init__(self) -> None: + self.args: list[Net] = [] + self.name: str = '' + + def __repr__(self) -> str: + return f"Node:{self.name}({', '.join(str(a) for a in self.args) if self.args else (self.value if isinstance(self, InitVar) else '')})" + + +class Device(): + pass + + +class Net: + def __init__(self, dtype: str, source: Node): + self.dtype = dtype + self.source = source + + def __repr__(self) -> str: + names = get_var_name(self) + return f"{'name:' + names[0] if names else 'id:' + str(id(self))[-5:]}" + + def __hash__(self) -> int: + return id(self) + + +class CPNumber(Net): + def __init__(self, dtype: str, source: Node): + self.dtype = dtype + self.source = source + + @overload + def __mul__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __mul__(self, other: unifloat) -> 'cpfloat': + ... + + def __mul__(self, other: NumLike) -> 'CPNumber': + return _add_op('mul', [self, other], True) + + @overload + def __rmul__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rmul__(self, other: unifloat) -> 'cpfloat': + ... + + def __rmul__(self, other: NumLike) -> 'CPNumber': + return _add_op('mul', [self, other], True) + + @overload + def __add__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __add__(self, other: unifloat) -> 'cpfloat': + ... + + def __add__(self, other: NumLike) -> 'CPNumber': + return _add_op('add', [self, other], True) + + @overload + def __radd__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __radd__(self, other: unifloat) -> 'cpfloat': + ... + + def __radd__(self, other: NumLike) -> 'CPNumber': + return _add_op('add', [self, other], True) + + @overload + def __sub__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __sub__(self, other: unifloat) -> 'cpfloat': + ... + + def __sub__(self, other: NumLike) -> 'CPNumber': + return _add_op('sub', [self, other]) + + @overload + def __rsub__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rsub__(self, other: unifloat) -> 'cpfloat': + ... + + def __rsub__(self, other: NumLike) -> 'CPNumber': + return _add_op('sub', [other, self]) + + def __truediv__(self, other: NumLike) -> 'cpfloat': + ret = _add_op('div', [self, other]) + assert isinstance(ret, cpfloat) + return ret + + def __rtruediv__(self, other: NumLike) -> 'cpfloat': + ret = _add_op('div', [other, self]) + assert isinstance(ret, cpfloat) + return ret + + @overload + def __floordiv__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __floordiv__(self, other: unifloat) -> 'cpfloat': + ... + + def __floordiv__(self, other: NumLike) -> 'CPNumber': + return _add_op('floordiv', [self, other]) + + @overload + def __rfloordiv__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rfloordiv__(self, other: unifloat) -> 'cpfloat': + ... + + def __rfloordiv__(self, other: NumLike) -> 'CPNumber': + return _add_op('floordiv', [other, self]) + + def __neg__(self: TNumber) -> TNumber: + ret = _add_op('sub', [cpvalue(0), self]) + assert isinstance(ret, type(self)) + return ret + + def __gt__(self, other: NumLike) -> 'cpbool': + ret = _add_op('gt', [self, other]) + return cpbool(ret.source) + + def __lt__(self, other: NumLike) -> 'cpbool': + ret = _add_op('gt', [other, self]) + return cpbool(ret.source) + + def __eq__(self, other: NumLike) -> 'cpbool': # type: ignore + ret = _add_op('eq', [self, other], True) + return cpbool(ret.source) + + def __ne__(self, other: NumLike) -> 'cpbool': # type: ignore + ret = _add_op('ne', [self, other], True) + return cpbool(ret.source) + + @overload + def __mod__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __mod__(self, other: unifloat) -> 'cpfloat': + ... + + def __mod__(self, other: NumLike) -> 'CPNumber': + return _add_op('mod', [self, other]) + + @overload + def __rmod__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rmod__(self, other: unifloat) -> 'cpfloat': + ... + + def __rmod__(self, other: NumLike) -> 'CPNumber': + return _add_op('mod', [other, self]) + + @overload + def __pow__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __pow__(self, other: unifloat) -> 'cpfloat': + ... + + def __pow__(self, other: NumLike) -> 'CPNumber': + return _add_op('pow', [other, self]) + + @overload + def __rpow__(self: TNumber, other: uniint) -> TNumber: + ... + + @overload + def __rpow__(self, other: unifloat) -> 'cpfloat': + ... + + def __rpow__(self, other: NumLike) -> 'CPNumber': + return _add_op('rpow', [self, other]) + + def __hash__(self) -> int: + return super().__hash__() + + +class cpint(CPNumber): + def __init__(self, source: int | Node): + if isinstance(source, Node): + self.source = source + else: + self.source = InitVar(int(source)) + self.dtype = 'int' + + def __lshift__(self, other: uniint) -> 'cpint': + ret = _add_op('lshift', [self, other]) + assert isinstance(ret, cpint) + return ret + + def __rlshift__(self, other: uniint) -> 'cpint': + ret = _add_op('lshift', [other, self]) + assert isinstance(ret, cpint) + return ret + + def __rshift__(self, other: uniint) -> 'cpint': + ret = _add_op('rshift', [self, other]) + assert isinstance(ret, cpint) + return ret + + def __rrshift__(self, other: uniint) -> 'cpint': + ret = _add_op('rshift', [other, self]) + assert isinstance(ret, cpint) + return ret + + def __and__(self, other: uniint) -> 'cpint': + ret = _add_op('bwand', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __rand__(self, other: uniint) -> 'cpint': + ret = _add_op('rwand', [other, self], True) + assert isinstance(ret, cpint) + return ret + + def __or__(self, other: uniint) -> 'cpint': + ret = _add_op('bwor', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __ror__(self, other: uniint) -> 'cpint': + ret = _add_op('bwor', [other, self], True) + assert isinstance(ret, cpint) + return ret + + def __xor__(self, other: uniint) -> 'cpint': + ret = _add_op('bwxor', [self, other], True) + assert isinstance(ret, cpint) + return ret + + def __rxor__(self, other: uniint) -> 'cpint': + ret = _add_op('bwxor', [other, self], True) + assert isinstance(ret, cpint) + return ret + + +class cpfloat(CPNumber): + def __init__(self, source: float | Node | CPNumber): + if isinstance(source, Node): + self.source = source + elif isinstance(source, CPNumber): + self.source = _add_op('cast_float', [source]).source + else: + self.source = InitVar(float(source)) + self.dtype = 'float' + + +class cpbool(cpint): + def __init__(self, source: bool | Node): + if isinstance(source, Node): + self.source = source + else: + self.source = InitVar(bool(source)) + self.dtype = 'bool' + + +class cpvector: + def __init__(self, *value: NumLike): + self.value = value + + def __add__(self, other: 'cpvector') -> 'cpvector': + assert len(self.value) == len(other.value) + tup = (a + b for a, b in zip(self.value, other.value)) + return cpvector(*(v for v in tup if isinstance(v, CPNumber))) + + +class InitVar(Node): + def __init__(self, value: int | float): + self.dtype, self.value = _get_data_and_dtype(value) + self.name = 'const_' + self.dtype + self.args = [] + + +class Write(Node): + def __init__(self, input: NetAndNum): + if isinstance(input, Net): + net = input + else: + node = InitVar(input) + net = Net(node.dtype, node) + + self.name = 'write_' + transl_type(net.dtype) + self.args = [net] + + +class Op(Node): + def __init__(self, typed_op_name: str, args: list[Net]): + assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]' + self.name: str = typed_op_name + self.args: list[Net] = args + + +def net_from_value(value: Any) -> Net: + vi = InitVar(value) + return Net(vi.dtype, vi) + + +@overload +def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> cpbool: # pyright: ignore[reportOverlappingOverload] + ... + + +@overload +def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> cpint: + ... + + +@overload +def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> cpfloat: + ... + + +@overload +def iif(expression: NumLike, true_result: T, false_result: T) -> T: + ... + + +def iif(expression: Any, true_result: Any, false_result: Any) -> Any: + # TODO: check that input types are matching + alowed_type = cpint | cpfloat | cpbool | int | float | bool + assert isinstance(true_result, alowed_type) and isinstance(false_result, alowed_type), "Result type not supported" + if isinstance(expression, CPNumber): + return (expression != 0) * true_result + (expression == 0) * false_result + else: + return true_result if expression else false_result + + +def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> CPNumber: + arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] + + if commutative: + arg_nets = sorted(arg_nets, key=lambda a: a.dtype) + + typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) + + if typed_op not in generic_sdb.stencil_definitions: + raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}") + + result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] + + if result_type == 'int': + return cpint(Op(typed_op, arg_nets)) + else: + return cpfloat(Op(typed_op, arg_nets)) + + +@overload +def cpvalue(value: bool) -> cpbool: # pyright: ignore[reportOverlappingOverload] + ... + + +@overload +def cpvalue(value: int) -> cpint: + ... + + +@overload +def cpvalue(value: float) -> cpfloat: + ... + + +def cpvalue(value: bool | int | float) -> cpbool | cpint | cpfloat: + vi = InitVar(value) + + if isinstance(value, bool): + return cpbool(vi) + elif isinstance(value, float): + return cpfloat(vi) + else: + return cpint(vi) + + +def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: + if isinstance(value, bool): + return ('bool', int(value)) + elif isinstance(value, int): + return ('int', int(value)) + elif isinstance(value, float): + return ('float', float(value)) + else: + raise ValueError(f'Non supported data type: {type(value).__name__}') diff --git a/src/copapy/binwrite.py b/src/copapy/_binwrite.py similarity index 100% rename from src/copapy/binwrite.py rename to src/copapy/_binwrite.py diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py new file mode 100644 index 0000000..8019855 --- /dev/null +++ b/src/copapy/_compiler.py @@ -0,0 +1,280 @@ +from typing import Generator, Iterable, Any +from . import _binwrite as binw +from ._stencils import stencil_database +from collections import defaultdict, deque +from ._basic_types import Net, Node, Write, InitVar, Op, transl_type + + +def stable_toposort(edges: Iterable[tuple[Node, Node]]) -> list[Node]: + """Perform a stable topological sort on a directed acyclic graph (DAG). + Arguments: + edges: Iterable of (u, v) pairs meaning u -> v + + Returns: + List of nodes in topologically sorted order. + """ + + # Track adjacency and indegrees + adj: defaultdict[Node, list[Node]] = defaultdict(list) + indeg: defaultdict[Node, int] = defaultdict(int) + order: dict[Node, int] = {} # first-appearance order of each node + + # Build graph and order map + pos = 0 + for u, v in edges: + if u not in order: + order[u] = pos + pos += 1 + if v not in order: + order[v] = pos + pos += 1 + adj[u].append(v) + indeg[v] += 1 + indeg.setdefault(u, 0) + + # Initialize queue with nodes of indegree 0, sorted by first appearance + queue = deque(sorted([n for n in indeg if indeg[n] == 0], key=lambda x: order[x])) + result: list[Node] = [] + + while queue: + node = queue.popleft() + result.append(node) + + for nei in adj[node]: + indeg[nei] -= 1 + if indeg[nei] == 0: + queue.append(nei) + + # Maintain stability: sort queue by appearance order + queue = deque(sorted(queue, key=lambda x: order[x])) + + # Check if graph had a cycle (not all nodes output) + if len(result) != len(indeg): + raise ValueError("Graph contains a cycle — topological sort not possible") + + return result + + +def get_all_dag_edges(nodes: Iterable[Node]) -> Generator[tuple[Node, Node], None, None]: + """Get all edges in the DAG by traversing from the given nodes + + Arguments: + nodes: Iterable of nodes to start the traversal from + + Yields: + Tuples of (source_node, target_node) representing edges in the DAG + """ + for node in nodes: + yield from get_all_dag_edges(net.source for net in node.args) + yield from ((net.source, node) for net in node.args) + + +def get_const_nets(nodes: list[Node]) -> list[Net]: + """Get all nets with a constant nodes value + + Returns: + List of nets whose source node is a Const + """ + net_lookup = {net.source: net for node in nodes for net in node.args} + return [net_lookup[node] for node in nodes if isinstance(node, InitVar)] + + +def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], None, None]: + """Add read node before each op where arguments are not already positioned + correctly in the registers + + Arguments: + node_list: List of nodes in the order of execution + + Returns: + Yields tuples of a net and a node. The net is the result net + for the node. If the node has no result net None is returned in the tuple. + """ + + registers: list[None | Net] = [None] * 2 + + # Generate result net lookup table + net_lookup = {net.source: net for node in node_list for net in node.args} + + for node in node_list: + if not isinstance(node, InitVar): + for i, net in enumerate(node.args): + if id(net) != id(registers[i]): + #if net in registers: + # print('x swap registers') + type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] + new_node = Op(f"read_{transl_type(net.dtype)}_reg{i}_" + '_'.join(type_list), []) + yield net, new_node + registers[i] = net + + if node in net_lookup: + yield net_lookup[node], node + registers[0] = net_lookup[node] + else: + yield None, node + + +def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list[Net]) -> Generator[tuple[Net | None, Node], None, None]: + """Add write operation for each new defined net if a read operation is later followed + + Returns: + Yields tuples of a net and a node. The associated net is provided for read and write nodes. + Otherwise None is returned in the tuple. + """ + + # Initialize set of nets with constants + stored_nets = set(const_nets) + + #assert all(node.name.startswith('read_') for net, node in net_node_list if net) + read_back_nets = { + net for net, node in net_node_list + if net and node.name.startswith('read_')} + + for net, node in net_node_list: + if isinstance(node, Write): + yield node.args[0], node + elif node.name.startswith('read_'): + yield net, node + else: + yield None, node + + if net and net in read_back_nets and net not in stored_nets: + yield net, Write(net) + stored_nets.add(net) + + +def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: + nets: set[Net] = set() + + for input in inputs: + for el in input: + for net in el: + if isinstance(net, Net): + nets.add(net) + + return list(nets) + + +def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) -> tuple[list[tuple[Net, int, int]], int]: + offset: int = 0 + object_list: list[tuple[Net, int, int]] = [] + + for variable in variable_list: + lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype)) + object_list.append((variable, offset, lengths)) + offset += (lengths + 3) // 4 * 4 + + return object_list, offset + + +def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_database) -> tuple[list[tuple[str, int, int]], int]: + offset: int = 0 + function_list: list[tuple[str, int, int]] = [] + + for name in function_names: + lengths = sdb.get_symbol_size(name) + function_list.append((name, offset, lengths)) + offset += (lengths + 3) // 4 * 4 + + return function_list, offset + + +def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: + variables: dict[Net, tuple[int, int, str]] = dict() + + ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) + const_net_list = get_const_nets(ordered_ops) + output_ops = list(add_read_ops(ordered_ops)) + extended_output_ops = list(add_write_ops(output_ops, const_net_list)) + + dw = binw.data_writer(sdb.byteorder) + + # Deallocate old allocated memory (if existing) + dw.write_com(binw.Command.FREE_MEMORY) + + # Get all nets/variables associated with heap memory + variable_list = get_nets([[const_net_list]], extended_output_ops) + + # Write data + variable_mem_layout, data_section_lengths = get_variable_mem_layout(variable_list, sdb) + dw.write_com(binw.Command.ALLOCATE_DATA) + dw.write_int(data_section_lengths) + + for net, out_offs, lengths in variable_mem_layout: + variables[net] = (out_offs, lengths, net.dtype) + if isinstance(net.source, InitVar): + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(out_offs) + dw.write_int(lengths) + dw.write_value(net.source.value, lengths) + # print(f'+ {net.dtype} {net.source.value}') + + # prep auxiliary_functions + aux_function_names = sdb.get_sub_functions(node.name for _, node in extended_output_ops) + aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb) + aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout} + + # Prepare program code and relocations + object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} + data_list: list[bytes] = [] + patch_list: list[tuple[int, int, int, binw.Command]] = [] + offset = aux_function_lengths # offset in generated code chunk + + # assemble stencils to main program + data = sdb.get_function_code('entry_function_shell', 'start') + data_list.append(data) + offset += len(data) + + for associated_net, node in extended_output_ops: + assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" + data = sdb.get_stencil_code(node.name) + data_list.append(data) + # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) + + for patch in sdb.get_patch_positions(node.name): + if patch.target_symbol_info == 'STT_OBJECT': + assert associated_net, f"Relocation found but no net defined for operation {node.name}" + addr = object_addr_lookup[associated_net] + patch_value = addr + patch.addend - (offset + patch.addr) + patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) + elif patch.target_symbol_info == 'STT_FUNC': + addr = aux_func_addr_lookup[patch.target_symbol_name] + patch_value = addr + patch.addend - (offset + patch.addr) + patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) + else: + raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") + + offset += len(data) + + data = sdb.get_function_code('entry_function_shell', 'end') + data_list.append(data) + offset += len(data) + + # allocate program data + dw.write_com(binw.Command.ALLOCATE_CODE) + dw.write_int(offset) + + # write aux functions + for name, out_offs, lengths in aux_function_mem_layout: + dw.write_com(binw.Command.COPY_CODE) + dw.write_int(out_offs) + dw.write_int(lengths) + dw.write_bytes(sdb.get_function_code(name)) + + # write entry function code + dw.write_com(binw.Command.COPY_CODE) + dw.write_int(aux_function_lengths) + dw.write_int(offset - aux_function_lengths) + dw.write_bytes(b''.join(data_list)) + + # write patch operations + for patch_type, patch_addr, addr, patch_command in patch_list: + dw.write_com(patch_command) + dw.write_int(patch_addr) + dw.write_int(patch_type) + dw.write_int(addr, signed=True) + + dw.write_com(binw.Command.ENTRY_POINT) + dw.write_int(aux_function_lengths) + + return dw, variables \ No newline at end of file diff --git a/src/copapy/stencil_db.py b/src/copapy/_stencils.py similarity index 100% rename from src/copapy/stencil_db.py rename to src/copapy/_stencils.py diff --git a/src/copapy/_target.py b/src/copapy/_target.py new file mode 100644 index 0000000..631cb70 --- /dev/null +++ b/src/copapy/_target.py @@ -0,0 +1,91 @@ +from typing import overload +from . import _binwrite as binw +from coparun_module import coparun, read_data_mem +import struct +from ._basic_types import stencil_db_from_package +from ._basic_types import cpbool, cpint, cpfloat, Net, Node, Write, NumLike +from ._compiler import compile_to_instruction_list + + +def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None: + assert net in variables, f"Variable {net} not found in data writer variables" + addr, lengths, _ = variables[net] + dw.write_com(binw.Command.READ_DATA) + dw.write_int(addr) + dw.write_int(lengths) + + +class Target(): + def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None: + self.sdb = stencil_db_from_package(arch, optimization) + self._variables: dict[Net, tuple[int, int, str]] = dict() + + def compile(self, *variables: int | float | cpint | cpfloat | cpbool | list[int | float | cpint | cpfloat | cpbool]) -> None: + nodes: list[Node] = [] + for s in variables: + if isinstance(s, list): + for net in s: + assert isinstance(net, Net), f"The folowing element is not a Net: {net}" + nodes.append(Write(net)) + else: + nodes.append(Write(s)) + + dw, self._variables = compile_to_instruction_list(nodes, self.sdb) + dw.write_com(binw.Command.END_COM) + assert coparun(dw.get_data()) > 0 + + def run(self) -> None: + # set entry point and run code + dw = binw.data_writer(self.sdb.byteorder) + dw.write_com(binw.Command.RUN_PROG) + dw.write_com(binw.Command.END_COM) + assert coparun(dw.get_data()) > 0 + + @overload + def read_value(self, net: cpbool) -> bool: + ... + + @overload + def read_value(self, net: cpfloat) -> float: + ... + + @overload + def read_value(self, net: cpint) -> int: + ... + + @overload + def read_value(self, net: NumLike) -> float | int | bool: + ... + + def read_value(self, net: NumLike) -> float | int | bool: + assert isinstance(net, Net), "Variable must be a copapy variable object" + assert net in self._variables, f"Variable {net} not found" + addr, lengths, var_type = self._variables[net] + assert lengths > 0 + data = read_data_mem(addr, lengths) + assert data is not None and len(data) == lengths, f"Failed to read variable {net}" + en = {'little': '<', 'big': '>'}[self.sdb.byteorder] + if var_type == 'float': + if lengths == 4: + value = struct.unpack(en + 'f', data)[0] + elif lengths == 8: + value = struct.unpack(en + 'd', data)[0] + else: + raise ValueError(f"Unsupported float length: {lengths} bytes") + assert isinstance(value, float) + return value + elif var_type == 'int': + assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" + value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) + return value + elif var_type == 'bool': + assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" + value = bool.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) + return value + else: + raise ValueError(f"Unsupported variable type: {var_type}") + + def read_value_remote(self, net: Net) -> None: + dw = binw.data_writer(self.sdb.byteorder) + add_read_command(dw, self._variables, net) + assert coparun(dw.get_data()) > 0 \ No newline at end of file diff --git a/src/copapy/backend.py b/src/copapy/backend.py new file mode 100644 index 0000000..447a352 --- /dev/null +++ b/src/copapy/backend.py @@ -0,0 +1,20 @@ +from ._target import add_read_command +from ._basic_types import Net, Op, Node, InitVar, Write +from ._compiler import compile_to_instruction_list, \ + stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \ + add_write_ops + +__all__ = [ + "add_read_command", + "Net", + "Op", + "Node", + "InitVar", + "Write", + "compile_to_instruction_list", + "stable_toposort", + "get_const_nets", + "get_all_dag_edges", + "add_read_ops", + "add_write_ops", +] \ No newline at end of file diff --git a/tests/test_ast_gen.py b/tests/test_ast_gen.py index ff4b7ed..15313b9 100644 --- a/tests/test_ast_gen.py +++ b/tests/test_ast_gen.py @@ -1,5 +1,6 @@ -from copapy import Write, cpvalue -import copapy as rc +from copapy import cpvalue +from copapy.backend import Write +import copapy.backend as cpbe def test_ast_generation(): @@ -32,27 +33,27 @@ def test_ast_generation(): print(out) print('-- get_edges:') - edges = list(rc.get_all_dag_edges(out)) + edges = list(cpbe.get_all_dag_edges(out)) for p in edges: print('#', p) print('-- get_ordered_ops:') - ordered_ops = list(rc.stable_toposort(edges)) + ordered_ops = list(cpbe.stable_toposort(edges)) for p in ordered_ops: print('#', p) print('-- get_consts:') - const_list = rc.get_const_nets(ordered_ops) + const_list = cpbe.get_const_nets(ordered_ops) for p in const_list: print('#', p) print('-- add_read_ops:') - output_ops = list(rc.add_read_ops(ordered_ops)) + output_ops = list(cpbe.add_read_ops(ordered_ops)) for p in output_ops: print('#', p) print('-- add_write_ops:') - extended_output_ops = list(rc.add_write_ops(output_ops, const_list)) + extended_output_ops = list(cpbe.add_write_ops(output_ops, const_list)) for p in extended_output_ops: print('#', p) print('--') diff --git a/tests/test_compile.py b/tests/test_compile.py index c49cd98..86acbfa 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -1,8 +1,10 @@ -from copapy import Write, cpvalue, NumLike +from copapy import cpvalue, NumLike +from copapy.backend import Write, compile_to_instruction_list, add_read_command import copapy import subprocess import struct -from copapy import binwrite +from copapy import _binwrite +import copapy.backend def run_command(command: list[str]) -> str: @@ -47,16 +49,16 @@ def test_compile(): out = [Write(r) for r in ret] - il, variables = copapy.compile_to_instruction_list(out, copapy.generic_sdb) + il, variables = compile_to_instruction_list(out, copapy.generic_sdb) # run program command - il.write_com(binwrite.Command.RUN_PROG) + il.write_com(_binwrite.Command.RUN_PROG) for net in ret: - assert isinstance(net, copapy.Net) - copapy.add_read_command(il, variables, net) + assert isinstance(net, copapy.backend.Net) + add_read_command(il, variables, net) - il.write_com(binwrite.Command.END_COM) + il.write_com(_binwrite.Command.END_COM) print('* Data to runner:') il.print() diff --git a/tests/test_compile_div.py b/tests/test_compile_div.py index 9960423..12513bc 100644 --- a/tests/test_compile_div.py +++ b/tests/test_compile_div.py @@ -1,7 +1,8 @@ -from copapy import Write, cpvalue, NumLike +from copapy import cpvalue, NumLike +from copapy.backend import Write, compile_to_instruction_list import copapy import subprocess -from copapy import binwrite +from copapy import _binwrite def run_command(command: list[str], encoding: str = 'utf8') -> str: @@ -25,16 +26,16 @@ def test_compile(): out = [Write(r) for r in ret] - il, _ = copapy.compile_to_instruction_list(out, copapy.generic_sdb) + il, _ = compile_to_instruction_list(out, copapy.generic_sdb) # run program command - il.write_com(binwrite.Command.RUN_PROG) + il.write_com(_binwrite.Command.RUN_PROG) - il.write_com(binwrite.Command.READ_DATA) + il.write_com(_binwrite.Command.READ_DATA) il.write_int(0) il.write_int(36) - il.write_com(binwrite.Command.END_COM) + il.write_com(_binwrite.Command.END_COM) print('* Data to runner:') il.print() diff --git a/tests/test_coparun_module.py b/tests/test_coparun_module.py index 77e7e0f..f05f5f1 100644 --- a/tests/test_coparun_module.py +++ b/tests/test_coparun_module.py @@ -31,7 +31,7 @@ def test_compile(): for test, ref, name in zip(ret, ret_ref, ['i1', 'i2', 'r1', 'r2']): val = tg.read_value(test) print('+', name, val, ref) - assert val == pytest.approx(ref, 1e-5), name + assert val == pytest.approx(ref, 1e-5), name # pyright: ignore[reportUnknownMemberType] if __name__ == "__main__": diff --git a/tests/test_coparun_module2.py b/tests/test_coparun_module2.py index 1c83f9f..00bd1a6 100644 --- a/tests/test_coparun_module2.py +++ b/tests/test_coparun_module2.py @@ -1,7 +1,8 @@ from coparun_module import coparun -from copapy import Write, cpvalue +from copapy import cpvalue +from copapy.backend import Write, compile_to_instruction_list, add_read_command import copapy -from copapy import binwrite +from copapy import _binwrite def test_compile(): @@ -14,16 +15,16 @@ def test_compile(): r2 = i1 + 9 out = [Write(r1), Write(r2), Write(c2)] - il, variables = copapy.compile_to_instruction_list(out, copapy.generic_sdb) + il, variables = compile_to_instruction_list(out, copapy.generic_sdb) # run program command - il.write_com(binwrite.Command.RUN_PROG) + il.write_com(_binwrite.Command.RUN_PROG) for net in (c1, c2, i1, r1, r2): - copapy.add_read_command(il, variables, net) + add_read_command(il, variables, net) # run program command - il.write_com(binwrite.Command.END_COM) + il.write_com(_binwrite.Command.END_COM) #print('* Data to runner:') #il.print() diff --git a/tests/test_crash_win.py b/tests/test_crash_win.py index 49837a1..3f8a3be 100644 --- a/tests/test_crash_win.py +++ b/tests/test_crash_win.py @@ -1,7 +1,8 @@ -from copapy import NumLike, Write, cpvalue, Net +from copapy import NumLike, cpvalue +from copapy.backend import Write, Net, compile_to_instruction_list, add_read_command import copapy import subprocess -from copapy import binwrite +from copapy import _binwrite def run_command(command: list[str], encoding: str = 'utf8') -> str: @@ -28,21 +29,21 @@ def test_compile(): ret = function(c1, c2) - dw, variable_list = copapy.compile_to_instruction_list([Write(net) for net in ret], copapy.generic_sdb) + dw, variable_list = compile_to_instruction_list([Write(net) for net in ret], copapy.generic_sdb) # run program command - dw.write_com(binwrite.Command.RUN_PROG) + dw.write_com(_binwrite.Command.RUN_PROG) - dw.write_com(binwrite.Command.READ_DATA) + dw.write_com(_binwrite.Command.READ_DATA) dw.write_int(0) dw.write_int(36) for net, name in zip(ret, ['i1', 'i2', 'r1', 'r2']): print('+', name) assert isinstance(net, Net) - copapy.add_read_command(dw, variable_list, net) + add_read_command(dw, variable_list, net) - dw.write_com(binwrite.Command.END_COM) + dw.write_com(_binwrite.Command.END_COM) dw.to_file('bin/test.copapy') result = run_command(['bin/coparun', 'bin/test.copapy']) diff --git a/tests/test_ops.py b/tests/test_ops.py index 8b69f74..eff7f5a 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -1,5 +1,6 @@ -from copapy import cpvalue, Target, NumLike, Net, iif, cpint -from pytest import approx +from copapy import cpvalue, Target, NumLike, iif, cpint +import pytest +import copapy def function1(c1: NumLike) -> list[NumLike]: @@ -56,12 +57,12 @@ def test_compile(): print('* finished') for test, ref in zip(ret_test, ret_ref): - assert isinstance(test, Net) + assert isinstance(test, copapy.CPNumber) val = tg.read_value(test) print('+', val, ref, test.dtype) - for t in [int, float, bool]: + for t in (int, float, bool): assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" - assert val == approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}" + assert val == pytest.approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] if __name__ == "__main__": diff --git a/tests/test_stencil_db.py b/tests/test_stencil_db.py index 090079d..49337d9 100644 --- a/tests/test_stencil_db.py +++ b/tests/test_stencil_db.py @@ -1,4 +1,4 @@ -from copapy import stencil_database, stencil_db +from copapy._stencils import stencil_database, get_stencil_position import platform arch = platform.machine() @@ -19,9 +19,9 @@ def test_start_end_function(): if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE': - print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data)) + print('-', sym_name, get_stencil_position(symbol), len(symbol.data)) - start, end = stencil_db.get_stencil_position(symbol) + start, end = get_stencil_position(symbol) assert start >= 0 and end >= start and end <= len(symbol.data) diff --git a/tools/extract_code.py b/tools/extract_code.py index 230f89a..71646f7 100644 --- a/tools/extract_code.py +++ b/tools/extract_code.py @@ -1,5 +1,5 @@ -from copapy.binwrite import data_reader, Command, ByteOrder -from copapy.stencil_db import RelocationType +from copapy._binwrite import data_reader, Command, ByteOrder +from copapy._stencils import RelocationType import argparse if __name__ == "__main__": diff --git a/tools/generate_stencils.py b/tools/generate_stencils.py index a5f0c1a..898e539 100644 --- a/tools/generate_stencils.py +++ b/tools/generate_stencils.py @@ -144,7 +144,7 @@ if __name__ == "__main__": // Auto-generated stencils for copapy // Do not edit manually - double math_pow(double arg1, double arg2); + double (*math_pow)(double, double); volatile int dummy_int = 1337; volatile float dummy_float = 1337; diff --git a/tools/make_example.py b/tools/make_example.py index 39b3df6..bdfb481 100644 --- a/tools/make_example.py +++ b/tools/make_example.py @@ -1,4 +1,5 @@ -from copapy import cpvalue, Write, binwrite +from copapy import _binwrite, cpvalue +from copapy.backend import Write, compile_to_instruction_list import copapy @@ -11,16 +12,16 @@ def test_compile() -> None: out = [Write(r) for r in ret] - il, _ = copapy.compile_to_instruction_list(out, copapy.generic_sdb) + il, _ = compile_to_instruction_list(out, copapy.generic_sdb) # run program command - il.write_com(binwrite.Command.RUN_PROG) + il.write_com(_binwrite.Command.RUN_PROG) - il.write_com(binwrite.Command.READ_DATA) + il.write_com(_binwrite.Command.READ_DATA) il.write_int(0) il.write_int(36) - il.write_com(binwrite.Command.END_COM) + il.write_com(_binwrite.Command.END_COM) print('* Data to runner:') il.print()