Merge pull request #1 from Nonannet/vector_feature

Vector feature
This commit is contained in:
Nicolas Kruse 2025-10-26 22:38:19 +01:00 committed by GitHub
commit eb82afface
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 682 additions and 349 deletions

24
.flake8
View File

@ -1,24 +0,0 @@
[flake8]
# Specify the maximum allowed line length
max-line-length = 88
# Ignore specific rules
# For example, E501: Line too long, W503: Line break before binary operator
ignore = E501, W503, W504, E226, E265
# Exclude specific files or directories
exclude =
.git,
__pycache__,
build,
dist,
.conda,
.venv
# Enable specific plugins or options
# Example: Enabling flake8-docstrings
select = C,E,F,W,D
# Specify custom error codes to ignore or enable
per-file-ignores =
tests/*: D, E712

View File

@ -14,6 +14,9 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Build & test aux functions
run: bash tools/test_stencil_aux.sh
- name: Build object files - name: Build object files
run: bash tools/crosscompile.sh run: bash tools/crosscompile.sh
@ -24,7 +27,7 @@ jobs:
build-ubuntu: build-ubuntu:
needs: [build_stencils] needs: [build_stencils]
runs-on: ubuntu-latest runs-on: ubuntu-24.04
strategy: strategy:
matrix: matrix:
@ -53,7 +56,10 @@ jobs:
gcc -DENABLE_BASIC_LOGGING -O3 -Wall -Wextra -Wconversion -Wsign-conversion -Wshadow -Wstrict-overflow -Werror -g src/coparun/runmem.c src/coparun/coparun.c src/coparun/mem_man.c -o bin/coparun gcc -DENABLE_BASIC_LOGGING -O3 -Wall -Wextra -Wconversion -Wsign-conversion -Wshadow -Wstrict-overflow -Werror -g src/coparun/runmem.c src/coparun/coparun.c src/coparun/mem_man.c -o bin/coparun
- name: Install ARM binutils - name: Install ARM binutils
run: sudo apt-get update && sudo apt-get install -y binutils-aarch64-linux-gnu run: |
which aarch64-linux-gnu-objdump || echo "Not found: aarch64-linux-gnu-objdump"
sudo apt-get install --no-install-recommends -y binutils-aarch64-linux-gnu
- name: Generate debug asm files - name: Generate debug asm files
if: strategy.job-index == 0 if: strategy.job-index == 0

View File

@ -32,7 +32,7 @@ copapy = ["obj/*.o", "py.typed"]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8", "ruff",
"mypy", "mypy",
"pytest" "pytest"
] ]
@ -51,3 +51,22 @@ minversion = "6.0"
addopts = "-ra -q" addopts = "-ra -q"
testpaths = ["tests"] testpaths = ["tests"]
pythonpath = ["src"] pythonpath = ["src"]
[tool.ruff]
lint.ignore = ["E501", "E226", "E265"]
# Equivalent to Flake8's "exclude"
exclude = [
".git",
"__pycache__",
"build",
"dist",
".conda",
".venv",
]
# "D" for dockstrings
lint.select = ["C", "E", "F", "W"]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "E712"]

View File

@ -1,13 +1,16 @@
from ._target import Target from ._target import Target
from ._basic_types import NumLike, variable, \ from ._basic_types import NumLike, variable, \
CPNumber, cpvector, generic_sdb, iif generic_sdb, iif
from ._vectors import vector
from ._math import sqrt, abs
__all__ = [ __all__ = [
"Target", "Target",
"NumLike", "NumLike",
"variable", "variable",
"CPNumber",
"cpvector",
"generic_sdb", "generic_sdb",
"iif", "iif",
"vector",
"sqrt",
"abs",
] ]

View File

@ -4,15 +4,12 @@ from ._stencils import stencil_database
import platform import platform
NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool' NumLike: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool'
NumLikeAndNet: TypeAlias = 'variable[int] | variable[float] | variable[bool] | int | float | bool | Net'
NetAndNum: TypeAlias = 'Net | int | float'
unifloat: TypeAlias = 'variable[float] | float' unifloat: TypeAlias = 'variable[float] | float'
uniint: TypeAlias = 'variable[int] | int' uniint: TypeAlias = 'variable[int] | int'
unibool: TypeAlias = 'variable[bool] | bool' unibool: TypeAlias = 'variable[bool] | bool'
TNumber = TypeVar("TNumber", bound='CPNumber') TCPNum = TypeVar("TCPNum", bound='variable[Any]')
T = TypeVar("T") TNum = TypeVar("TNum", int, bool, float)
def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]: def get_var_name(var: Any, scope: dict[str, Any] = globals()) -> list[str]:
@ -60,231 +57,201 @@ class Net:
return id(self) return id(self)
class CPNumber(Net): class variable(Generic[TNum], Net):
def __init__(self, dtype: str, source: Node): def __init__(self, source: TNum | Node, dtype: str | None = None):
self.dtype = dtype
self.source = source
@overload
def __mul__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __mul__(self, other: unifloat) -> 'variable[float]':
...
def __mul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __rmul__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rmul__(self, other: unifloat) -> 'variable[float]':
...
def __rmul__(self, other: NumLike) -> 'CPNumber':
return _add_op('mul', [self, other], True)
@overload
def __add__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __add__(self, other: unifloat) -> 'variable[float]':
...
def __add__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __radd__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __radd__(self, other: unifloat) -> 'variable[float]':
...
def __radd__(self, other: NumLike) -> 'CPNumber':
return _add_op('add', [self, other], True)
@overload
def __sub__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __sub__(self, other: unifloat) -> 'variable[float]':
...
def __sub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [self, other])
@overload
def __rsub__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rsub__(self, other: unifloat) -> 'variable[float]':
...
def __rsub__(self, other: NumLike) -> 'CPNumber':
return _add_op('sub', [other, self])
def __truediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [self, other])
def __rtruediv__(self, other: NumLike) -> 'variable[float]':
return _add_op('div', [other, self])
@overload
def __floordiv__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __floordiv__(self, other: unifloat) -> 'variable[float]':
...
def __floordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rfloordiv__(self, other: unifloat) -> 'variable[float]':
...
def __rfloordiv__(self, other: NumLike) -> 'CPNumber':
return _add_op('floordiv', [other, self])
def __neg__(self: TNumber) -> TNumber:
assert isinstance(T, variable)
return cast(TNumber, _add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'variable[bool]':
ret = _add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = _add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@overload
def __mod__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __mod__(self, other: unifloat) -> 'variable[float]':
...
def __mod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [self, other])
@overload
def __rmod__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rmod__(self, other: unifloat) -> 'variable[float]':
...
def __rmod__(self, other: NumLike) -> 'CPNumber':
return _add_op('mod', [other, self])
@overload
def __pow__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __pow__(self, other: unifloat) -> 'variable[float]':
...
def __pow__(self, other: NumLike) -> 'CPNumber':
return _add_op('pow', [other, self])
@overload
def __rpow__(self: TNumber, other: uniint) -> TNumber:
...
@overload
def __rpow__(self, other: unifloat) -> 'variable[float]':
...
def __rpow__(self, other: NumLike) -> 'CPNumber':
return _add_op('rpow', [self, other])
def __hash__(self) -> int:
return super().__hash__()
class variable(Generic[T], CPNumber):
def __init__(self, source: T | Node, dtype: str | None = None):
if isinstance(source, Node): if isinstance(source, Node):
self.source = source self.source = source
assert dtype, 'For source type Node a dtype argument is required.' assert dtype, 'For source type Node a dtype argument is required.'
self.dtype = dtype self.dtype = dtype
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
elif isinstance(source, int):
self.source = CPConstant(source)
self.dtype = 'int'
elif isinstance(source, float): elif isinstance(source, float):
self.source = CPConstant(source) self.source = CPConstant(source)
self.dtype = 'float' self.dtype = 'float'
elif isinstance(source, bool):
self.source = CPConstant(source)
self.dtype = 'bool'
else: else:
raise ValueError(f'Non supported data type: {type(source).__name__}') self.source = CPConstant(source)
self.dtype = 'int'
@overload
def __add__(self, other: TCPNum) -> TCPNum: ...
@overload
def __add__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __add__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __add__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __add__(self, other: NumLike) -> Any:
return add_op('add', [self, other], True)
@overload
def __radd__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __radd__(self, other: float) -> 'variable[float]': ...
def __radd__(self, other: NumLike) -> Any:
return add_op('add', [self, other], True)
@overload
def __sub__(self, other: TCPNum) -> TCPNum: ...
@overload
def __sub__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __sub__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __sub__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __sub__(self, other: NumLike) -> Any:
return add_op('sub', [self, other])
@overload
def __rsub__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rsub__(self, other: float) -> 'variable[float]': ...
def __rsub__(self, other: NumLike) -> Any:
return add_op('sub', [other, self])
@overload
def __mul__(self, other: TCPNum) -> TCPNum: ...
@overload
def __mul__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __mul__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __mul__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __mul__(self, other: NumLike) -> Any:
return add_op('mul', [self, other], True)
@overload
def __rmul__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rmul__(self, other: float) -> 'variable[float]': ...
def __rmul__(self, other: NumLike) -> Any:
return add_op('mul', [self, other], True)
def __truediv__(self, other: NumLike) -> 'variable[float]':
return add_op('div', [self, other])
def __rtruediv__(self, other: NumLike) -> 'variable[float]':
return add_op('div', [other, self])
@overload
def __floordiv__(self, other: TCPNum) -> TCPNum: ...
@overload
def __floordiv__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __floordiv__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __floordiv__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __floordiv__(self, other: NumLike) -> Any:
return add_op('floordiv', [self, other])
@overload
def __rfloordiv__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rfloordiv__(self, other: float) -> 'variable[float]': ...
def __rfloordiv__(self, other: NumLike) -> Any:
return add_op('floordiv', [other, self])
def __neg__(self: TCPNum) -> TCPNum:
return cast(TCPNum, add_op('sub', [variable(0), self]))
def __gt__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('gt', [self, other])
return variable(ret.source, dtype='bool')
def __lt__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('gt', [other, self])
return variable(ret.source, dtype='bool')
def __ge__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('ge', [self, other])
return variable(ret.source, dtype='bool')
def __le__(self, other: NumLike) -> 'variable[bool]':
ret = add_op('ge', [other, self])
return variable(ret.source, dtype='bool')
def __eq__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = add_op('eq', [self, other], True)
return variable(ret.source, dtype='bool')
def __ne__(self, other: NumLike) -> 'variable[bool]': # type: ignore
ret = add_op('ne', [self, other], True)
return variable(ret.source, dtype='bool')
@overload
def __mod__(self, other: TCPNum) -> TCPNum: ...
@overload
def __mod__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __mod__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __mod__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __mod__(self, other: NumLike) -> Any:
return add_op('mod', [self, other])
@overload
def __rmod__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rmod__(self, other: float) -> 'variable[float]': ...
def __rmod__(self, other: NumLike) -> Any:
return add_op('mod', [other, self])
@overload
def __pow__(self, other: TCPNum) -> TCPNum: ...
@overload
def __pow__(self: TCPNum, other: uniint) -> TCPNum: ...
@overload
def __pow__(self, other: unifloat) -> 'variable[float]': ...
@overload
def __pow__(self, other: NumLike) -> 'variable[float] | variable[int]': ...
def __pow__(self, other: NumLike) -> Any:
if not isinstance(other, variable):
if other == 2:
return self * self
if other == -1:
return 1 / self
return add_op('pow', [self, other])
@overload
def __rpow__(self: TCPNum, other: int) -> TCPNum: ...
@overload
def __rpow__(self, other: float) -> 'variable[float]': ...
def __rpow__(self, other: NumLike) -> Any:
return add_op('rpow', [other, self])
def __hash__(self) -> int:
return super().__hash__()
# Bitwise and shift operations for cp[int] # Bitwise and shift operations for cp[int]
def __lshift__(self, other: uniint) -> 'variable[int]': def __lshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [self, other]) return add_op('lshift', [self, other])
def __rlshift__(self, other: uniint) -> 'variable[int]': def __rlshift__(self, other: uniint) -> 'variable[int]':
return _add_op('lshift', [other, self]) return add_op('lshift', [other, self])
def __rshift__(self, other: uniint) -> 'variable[int]': def __rshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [self, other]) return add_op('rshift', [self, other])
def __rrshift__(self, other: uniint) -> 'variable[int]': def __rrshift__(self, other: uniint) -> 'variable[int]':
return _add_op('rshift', [other, self]) return add_op('rshift', [other, self])
def __and__(self, other: uniint) -> 'variable[int]': def __and__(self, other: uniint) -> 'variable[int]':
return _add_op('bwand', [self, other], True) return add_op('bwand', [self, other], True)
def __rand__(self, other: uniint) -> 'variable[int]': def __rand__(self, other: uniint) -> 'variable[int]':
return _add_op('rwand', [other, self], True) return add_op('rwand', [other, self], True)
def __or__(self, other: uniint) -> 'variable[int]': def __or__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [self, other], True) return add_op('bwor', [self, other], True)
def __ror__(self, other: uniint) -> 'variable[int]': def __ror__(self, other: uniint) -> 'variable[int]':
return _add_op('bwor', [other, self], True) return add_op('bwor', [other, self], True)
def __xor__(self, other: uniint) -> 'variable[int]': def __xor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [self, other], True) return add_op('bwxor', [self, other], True)
def __rxor__(self, other: uniint) -> 'variable[int]': def __rxor__(self, other: uniint) -> 'variable[int]':
return _add_op('bwxor', [other, self], True) return add_op('bwxor', [other, self], True)
class cpvector:
def __init__(self, *value: NumLike):
self.value = value
def __add__(self, other: 'cpvector') -> 'cpvector':
assert len(self.value) == len(other.value)
tup = (a + b for a, b in zip(self.value, other.value))
return cpvector(*(v for v in tup if isinstance(v, CPNumber)))
class CPConstant(Node): class CPConstant(Node):
@ -295,7 +262,7 @@ class CPConstant(Node):
class Write(Node): class Write(Node):
def __init__(self, input: NetAndNum): def __init__(self, input: Net | int | float):
if isinstance(input, Net): if isinstance(input, Net):
net = input net = input
else: else:
@ -319,35 +286,24 @@ def net_from_value(value: Any) -> Net:
@overload @overload
def iif(expression: CPNumber, true_result: unibool, false_result: unibool) -> variable[bool]: # pyright: ignore[reportOverlappingOverload] def iif(expression: variable[Any], true_result: unibool, false_result: unibool) -> variable[bool]: ... # pyright: ignore[reportOverlappingOverload]
...
@overload @overload
def iif(expression: CPNumber, true_result: uniint, false_result: uniint) -> variable[int]: def iif(expression: variable[Any], true_result: uniint, false_result: uniint) -> variable[int]: ...
...
@overload @overload
def iif(expression: CPNumber, true_result: unifloat, false_result: unifloat) -> variable[float]: def iif(expression: variable[Any], true_result: unifloat, false_result: unifloat) -> variable[float]: ...
...
@overload @overload
def iif(expression: NumLike, true_result: T, false_result: T) -> T: def iif(expression: float | int, true_result: TNum, false_result: TNum) -> TNum: ...
... @overload
def iif(expression: float | int, true_result: TNum, false_result: variable[TNum]) -> variable[TNum]: ...
@overload
def iif(expression: float | int, true_result: variable[TNum], false_result: TNum | variable[TNum]) -> variable[TNum]: ...
def iif(expression: Any, true_result: Any, false_result: Any) -> Any: def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
allowed_type = (variable, int, float, bool) allowed_type = (variable, int, float, bool)
assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported" assert isinstance(true_result, allowed_type) and isinstance(false_result, allowed_type), "Result type not supported"
if isinstance(expression, CPNumber): return (expression != 0) * true_result + (expression == 0) * false_result
return (expression != 0) * true_result + (expression == 0) * false_result
else:
return true_result if expression else false_result
def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = False) -> variable[Any]: def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool = False) -> variable[Any]:
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative: if commutative:
@ -360,10 +316,10 @@ def _add_op(op: str, args: list[CPNumber | int | float], commutative: bool = Fal
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
if result_type == 'int': if result_type == 'float':
return variable[int](Op(typed_op, arg_nets), result_type)
else:
return variable[float](Op(typed_op, arg_nets), result_type) return variable[float](Op(typed_op, arg_nets), result_type)
else:
return variable[int](Op(typed_op, arg_nets), result_type)
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:

View File

@ -14,7 +14,7 @@ COMMAND_SIZE = 4
class data_writer(): class data_writer():
def __init__(self, byteorder: ByteOrder): def __init__(self, byteorder: ByteOrder):
self._data: list[tuple[str, bytes, int]] = list() self._data: list[tuple[str, bytes, int]] = []
self.byteorder: ByteOrder = byteorder self.byteorder: ByteOrder = byteorder
def write_int(self, value: int, num_bytes: int = 4, signed: bool = False) -> None: def write_int(self, value: int, num_bytes: int = 4, signed: bool = False) -> None:

View File

@ -1,6 +1,6 @@
from typing import Generator, Iterable, Any from typing import Generator, Iterable, Any
from . import _binwrite as binw from . import _binwrite as binw
from ._stencils import stencil_database, patch_entry from ._stencils import stencil_database
from collections import defaultdict, deque from collections import defaultdict, deque
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
@ -166,8 +166,8 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
return object_list, offset return object_list, offset
def get_target_sym_lookup(function_names: Iterable[str], sdb: stencil_database) -> dict[str, patch_entry]: #def get_target_sym_lookup(function_names: Iterable[str], sdb: stencil_database) -> dict[str, patch_entry]:
return {patch.target_symbol_name: patch for name in set(function_names) for patch in sdb.get_patch_positions(name)} # return {patch.target_symbol_name: patch for name in set(function_names) for patch in sdb.get_patch_positions(name)}
def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]: def get_section_layout(section_indexes: Iterable[int], sdb: stencil_database, offset: int = 0) -> tuple[list[tuple[int, int, int]], int]:
@ -192,8 +192,8 @@ def get_aux_function_mem_layout(function_names: Iterable[str], sdb: stencil_data
return function_list, offset return function_list, offset
def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]: def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[binw.data_writer, dict[Net, tuple[int, int, str]]]:
variables: dict[Net, tuple[int, int, str]] = dict() variables: dict[Net, tuple[int, int, str]] = {}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int, binw.Command]] = [] patch_list: list[tuple[int, int, int, binw.Command]] = []
@ -221,18 +221,18 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_int(variables_data_lengths) dw.write_int(variables_data_lengths)
# Heap constants # Heap constants
for section_id, out_offs, lengths in section_mem_layout: for section_id, start, lengths in section_mem_layout:
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(start)
dw.write_int(lengths) dw.write_int(lengths)
dw.write_bytes(sdb.get_section_data(section_id)) dw.write_bytes(sdb.get_section_data(section_id))
# Heap variables # Heap variables
for net, out_offs, lengths in variable_mem_layout: for net, start, lengths in variable_mem_layout:
variables[net] = (out_offs, lengths, net.dtype) variables[net] = (start, lengths, net.dtype)
if isinstance(net.source, CPConstant): if isinstance(net.source, CPConstant):
dw.write_com(binw.Command.COPY_DATA) dw.write_com(binw.Command.COPY_DATA)
dw.write_int(out_offs) dw.write_int(start)
dw.write_int(lengths) dw.write_int(lengths)
dw.write_value(net.source.value, lengths) dw.write_value(net.source.value, lengths)
# print(f'+ {net.dtype} {net.source.value}') # print(f'+ {net.dtype} {net.source.value}')
@ -244,12 +244,11 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
# Prepare program code and relocations # Prepare program code and relocations
object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout} section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout}
offset = aux_function_lengths # offset in generated code chunk
# assemble stencils to main program # assemble stencils to main program and patch stencils
data = sdb.get_function_code('entry_function_shell', 'start') data = sdb.get_function_code('entry_function_shell', 'start')
data_list.append(data) data_list.append(data)
offset += len(data) offset = aux_function_lengths + len(data)
for associated_net, node in extended_output_ops: for associated_net, node in extended_output_ops:
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
@ -257,24 +256,30 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
data_list.append(data) data_list.append(data)
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) #print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name): for patch in sdb.get_patch_positions(node.name, stencil=True):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
if patch.target_symbol_name.startswith('dummy_'): if patch.target_symbol_name.startswith('dummy_'):
# Patch for write and read addresses to/from heap variables # Patch for write and read addresses to/from heap variables
assert associated_net, f"Relocation found but no net defined for operation {node.name}" assert associated_net, f"Relocation found but no net defined for operation {node.name}"
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
addr = object_addr_lookup[associated_net] addr = object_addr_lookup[associated_net]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.patch_address)
elif patch.target_symbol_name.startswith('result_'):
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
else: else:
# Patch constants addresses on heap # Patch constants addresses on heap
addr = section_addr_lookup[patch.target_symbol_section_index] section_addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = addr + patch.addend - (offset + patch.addr) obj_addr = section_addr + patch.target_symbol_address
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC': elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name] addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.patch_address)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else: else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
@ -288,13 +293,34 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
dw.write_com(binw.Command.ALLOCATE_CODE) dw.write_com(binw.Command.ALLOCATE_CODE)
dw.write_int(offset) dw.write_int(offset)
# write aux functions # write aux functions code
for name, out_offs, lengths in aux_function_mem_layout: for name, start, lengths in aux_function_mem_layout:
dw.write_com(binw.Command.COPY_CODE) dw.write_com(binw.Command.COPY_CODE)
dw.write_int(out_offs) dw.write_int(start)
dw.write_int(lengths) dw.write_int(lengths)
dw.write_bytes(sdb.get_function_code(name)) dw.write_bytes(sdb.get_function_code(name))
# Patch aux functions
for name, start, lengths in aux_function_mem_layout:
for patch in sdb.get_patch_positions(name):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index]
obj_addr = section_addr + patch.target_symbol_address
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
elif patch.target_symbol_info == 'STT_FUNC':
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
#assert False, aux_function_mem_layout
# write entry function code # write entry function code
dw.write_com(binw.Command.COPY_CODE) dw.write_com(binw.Command.COPY_CODE)
dw.write_int(aux_function_lengths) dw.write_int(aux_function_lengths)

39
src/copapy/_math.py Normal file
View File

@ -0,0 +1,39 @@
from . import variable, NumLike
from typing import TypeVar, Any, overload
from ._basic_types import add_op
T = TypeVar("T", int, float, variable[int], variable[float])
@overload
def sqrt(x: float | int) -> float: ...
@overload
def sqrt(x: variable[Any]) -> variable[float]: ...
def sqrt(x: NumLike) -> variable[float] | float:
"""Square root function"""
if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5)
@overload
def sqrt2(x: float | int) -> float: ...
@overload
def sqrt2(x: variable[Any]) -> variable[float]: ...
def sqrt2(x: NumLike) -> variable[float] | float:
"""Square root function"""
if isinstance(x, variable):
return add_op('sqrt2', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5)
def get_42() -> variable[float]:
"""Returns the variable representing the constant 42"""
return add_op('get_42', [0.0, 0.0])
def abs(x: T) -> T:
"""Absolute value function"""
ret = (x < 0) * -x + (x >= 0) * x
return ret # pyright: ignore[reportReturnType]

View File

@ -21,11 +21,12 @@ class patch_entry:
type (RelocationType): relocation type""" type (RelocationType): relocation type"""
type: RelocationType type: RelocationType
addr: int patch_address: int
addend: int addend: int
target_symbol_name: str target_symbol_name: str
target_symbol_info: str target_symbol_info: str
target_symbol_section_index: int target_symbol_section_index: int
target_symbol_address: int
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType: def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
@ -119,7 +120,7 @@ class stencil_database():
ret.add(sym.section.index) ret.add(sym.section.index)
return list(ret) return list(ret)
def get_patch_positions(self, symbol_name: str) -> Generator[patch_entry, None, None]: def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]:
"""Return patch positions for a provided symbol (function or object) """Return patch positions for a provided symbol (function or object)
Args: Args:
@ -129,7 +130,11 @@ class stencil_database():
patch_entry: every relocation for the symbol patch_entry: every relocation for the symbol
""" """
symbol = self.elf.symbols[symbol_name] symbol = self.elf.symbols[symbol_name]
start_index, end_index = get_stencil_position(symbol) if stencil:
start_index, end_index = get_stencil_position(symbol)
else:
start_index = 0
end_index = symbol.fields['st_size']
for reloc in symbol.relocations: for reloc in symbol.relocations:
@ -146,10 +151,11 @@ class stencil_database():
reloc.fields['r_addend'], reloc.fields['r_addend'],
reloc.symbol.name, reloc.symbol.name,
reloc.symbol.info, reloc.symbol.info,
reloc.symbol.fields['st_shndx']) reloc.symbol.fields['st_shndx'],
reloc.symbol.fields['st_value'])
# Exclude the call to the result_* function # Exclude the call to the result_* function
if patch.addr < end_index - start_index: if patch.patch_address < end_index - start_index:
yield patch yield patch
def get_stencil_code(self, name: str) -> bytes: def get_stencil_code(self, name: str) -> bytes:
@ -185,7 +191,7 @@ class stencil_database():
return self.elf.sections[id].data return self.elf.sections[id].data
def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes: def get_function_code(self, name: str, part: Literal['full', 'start', 'end'] = 'full') -> bytes:
"""Returns machine code for a specified function name""" """Returns machine code for a specified function name."""
func = self.elf.symbols[name] func = self.elf.symbols[name]
assert func.info == 'STT_FUNC', f"{name} is not a function" assert func.info == 'STT_FUNC', f"{name} is not a function"

View File

@ -4,7 +4,7 @@ from coparun_module import coparun, read_data_mem
import struct import struct
from ._basic_types import stencil_db_from_package from ._basic_types import stencil_db_from_package
from ._basic_types import variable, Net, Node, Write, NumLike from ._basic_types import variable, Net, Node, Write, NumLike
from ._compiler import compile_to_instruction_list from ._compiler import compile_to_dag
def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None: def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int, str]], net: Net) -> None:
@ -18,7 +18,7 @@ def add_read_command(dw: binw.data_writer, variables: dict[Net, tuple[int, int,
class Target(): class Target():
def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None: def __init__(self, arch: str = 'native', optimization: str = 'O3') -> None:
self.sdb = stencil_db_from_package(arch, optimization) self.sdb = stencil_db_from_package(arch, optimization)
self._variables: dict[Net, tuple[int, int, str]] = dict() self._variables: dict[Net, tuple[int, int, str]] = {}
def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None: def compile(self, *variables: int | float | variable[int] | variable[float] | variable[bool] | Iterable[int | float | variable[int] | variable[float] | variable[bool]]) -> None:
nodes: list[Node] = [] nodes: list[Node] = []
@ -30,7 +30,7 @@ class Target():
else: else:
nodes.append(Write(s)) nodes.append(Write(s))
dw, self._variables = compile_to_instruction_list(nodes, self.sdb) dw, self._variables = compile_to_dag(nodes, self.sdb)
dw.write_com(binw.Command.END_COM) dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0 assert coparun(dw.get_data()) > 0

158
src/copapy/_vectors.py Normal file
View File

@ -0,0 +1,158 @@
from . import variable
from typing import Generic, TypeVar, Iterable, Any, overload, TypeAlias
from ._math import sqrt
VecNumLike: TypeAlias = 'vector[int] | vector[float] | variable[int] | variable[float] | int | float'
VecIntLike: TypeAlias = 'vector[int] | variable[int] | int'
VecFloatLike: TypeAlias = 'vector[float] | variable[float] | float'
T = TypeVar("T", int, float)
epsilon = 1e-20
class vector(Generic[T]):
"""Type-safe vector supporting numeric promotion between vector types."""
def __init__(self, values: Iterable[T | variable[T]]):
self.values: tuple[variable[T] | T, ...] = tuple(values)
# ---------- Basic dunder methods ----------
def __repr__(self) -> str:
return f"vector({self.values})"
def __len__(self) -> int:
return len(self.values)
def __getitem__(self, index: int) -> variable[T] | T:
return self.values[index]
@overload
def __add__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
@overload
def __add__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __add__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __add__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __add__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a + b for a, b in zip(self.values, other.values))
return vector(a + other for a in self.values)
@overload
def __radd__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __radd__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
def __radd__(self, other: Any) -> Any:
return self + other
@overload
def __sub__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
@overload
def __sub__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __sub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __sub__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __sub__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a - b for a, b in zip(self.values, other.values))
return vector(a - other for a in self.values)
@overload
def __rsub__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rsub__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
def __rsub__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(b - a for a, b in zip(self.values, other.values))
return vector(other - a for a in self.values)
@overload
def __mul__(self: 'vector[int]', other: VecFloatLike) -> 'vector[float]': ...
@overload
def __mul__(self: 'vector[int]', other: VecIntLike) -> 'vector[int]': ...
@overload
def __mul__(self: 'vector[float]', other: 'vector[int] | float | int | variable[int]') -> 'vector[float]': ...
@overload
def __mul__(self, other: VecNumLike) -> 'vector[int] | vector[float]': ...
def __mul__(self, other: VecNumLike) -> Any:
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a * b for a, b in zip(self.values, other.values))
return vector(a * other for a in self.values)
@overload
def __rmul__(self: 'vector[float]', other: VecNumLike) -> 'vector[float]': ...
@overload
def __rmul__(self: 'vector[int]', other: variable[int] | int) -> 'vector[int]': ...
def __rmul__(self, other: VecNumLike) -> Any:
return self * other
def __truediv__(self, other: VecNumLike) -> 'vector[float]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(a / b for a, b in zip(self.values, other.values))
return vector(a / other for a in self.values)
def __rtruediv__(self, other: VecNumLike) -> 'vector[float]':
if isinstance(other, vector):
assert len(self.values) == len(other.values)
return vector(b / a for a, b in zip(self.values, other.values))
return vector(other / a for a in self.values)
@overload
def dot(self: 'vector[int]', other: 'vector[int]') -> int | variable[int]: ...
@overload
def dot(self, other: 'vector[float]') -> float | variable[float]: ...
@overload
def dot(self: 'vector[float]', other: 'vector[int] | vector[float]') -> float | variable[float]: ...
@overload
def dot(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ...
def dot(self, other: 'vector[int] | vector[float]') -> Any:
assert len(self.values) == len(other.values)
return sum(a * b for a, b in zip(self.values, other.values))
# @ operator
@overload
def __matmul__(self: 'vector[int]', other: 'vector[int]') -> int | variable[int]: ...
@overload
def __matmul__(self, other: 'vector[float]') -> float | variable[float]: ...
@overload
def __matmul__(self: 'vector[float]', other: 'vector[int] | vector[float]') -> float | variable[float]: ...
@overload
def __matmul__(self, other: 'vector[int] | vector[float]') -> float | int | variable[float] | variable[int]: ...
def __matmul__(self, other: 'vector[int] | vector[float]') -> Any:
return self.dot(other)
def cross(self: 'vector[float]', other: 'vector[float]') -> 'vector[float]':
"""3D cross product"""
assert len(self.values) == 3 and len(other.values) == 3
a1, a2, a3 = self.values
b1, b2, b3 = other.values
return vector([
a2 * b3 - a3 * b2,
a3 * b1 - a1 * b3,
a1 * b2 - a2 * b1
])
@overload
def sum(self: 'vector[int]') -> int | variable[int]: ...
@overload
def sum(self: 'vector[float]') -> float | variable[float]: ...
def sum(self) -> Any:
return sum(a for a in self.values if isinstance(a, variable)) +\
sum(a for a in self.values if not isinstance(a, variable))
def magnitude(self) -> 'float | variable[float]':
s = sum(a * a for a in self.values)
return sqrt(s) if isinstance(s, variable) else sqrt(s)
def normalize(self) -> 'vector[float]':
mag = self.magnitude() + epsilon
return self / mag
def __iter__(self) -> Iterable[variable[T] | T]:
return iter(self.values)

View File

@ -1,6 +1,6 @@
from ._target import add_read_command from ._target import add_read_command
from ._basic_types import Net, Op, Node, CPConstant, Write from ._basic_types import Net, Op, Node, CPConstant, Write
from ._compiler import compile_to_instruction_list, \ from ._compiler import compile_to_dag, \
stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \ stable_toposort, get_const_nets, get_all_dag_edges, add_read_ops, \
add_write_ops add_write_ops
@ -11,7 +11,7 @@ __all__ = [
"Node", "Node",
"CPConstant", "CPConstant",
"Write", "Write",
"compile_to_instruction_list", "compile_to_dag",
"stable_toposort", "stable_toposort",
"get_const_nets", "get_const_nets",
"get_all_dag_edges", "get_all_dag_edges",

View File

@ -12,6 +12,27 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
return i; return i;
} }
__attribute__((noinline)) float aux_sqrt(float n) {
if (n < 0) return -1;
float x = n; // initial guess
float epsilon = 0.00001; // desired accuracy
while ((x - n / x) > epsilon || (x - n / x) < -epsilon) {
x = 0.5 * (x + n / x);
}
return x;
}
__attribute__((noinline)) float aux_sqrt2(float n) {
return n * 20.5 + 4.5;
}
__attribute__((noinline)) float aux_get_42(float n) {
return n + 42.0;
}
float fast_pow_float(float base, float exponent) { float fast_pow_float(float base, float exponent) {
union { union {
float f; float f;
@ -24,3 +45,13 @@ float fast_pow_float(float base, float exponent) {
u.i = (uint32_t)y; u.i = (uint32_t)y;
return u.f; return u.f;
} }
int main() {
// Test aux functions
float a = 16.0f;
float sqrt_a = aux_sqrt(a);
float pow_a = fast_pow_float(a, 0.5f);
float sqrt2_a = aux_sqrt2(a);
float g42 = aux_get_42(0.0f);
return 0;
}

View File

@ -4,7 +4,7 @@ from pathlib import Path
import os import os
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**', op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**',
'gt': '>', 'eq': '==', 'ne': '!=', 'mod': '%'} 'gt': '>', 'eq': '==', 'ge': '>=', 'ne': '!=', 'mod': '%'}
entry_func_prefix = '' entry_func_prefix = ''
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
@ -78,6 +78,15 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
""" """
@norm_indent
def get_func2(func_name: str, type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}(aux_{func_name}((float)arg1), arg2);
}}
"""
@norm_indent @norm_indent
def get_conv_code(type1: str, type2: str, type_out: str) -> str: def get_conv_code(type1: str, type2: str, type_out: str) -> str:
return f""" return f"""
@ -189,7 +198,7 @@ if __name__ == "__main__":
# Scalar arithmetic: # Scalar arithmetic:
types = ['int', 'float'] types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq', 'ne', 'pow'] ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow']
for t1 in types: for t1 in types:
code += get_result_stubs1(t1) code += get_result_stubs1(t1)
@ -203,6 +212,11 @@ if __name__ == "__main__":
t_out = 'int' if t1 == 'float' else 'float' t_out = 'int' if t1 == 'float' else 'float'
code += get_cast(t1, t2, t_out) code += get_cast(t1, t2, t_out)
for t1, t2 in permutate(types, types):
code += get_func2('sqrt', t1, t2)
code += get_func2('sqrt2', t1, t2)
code += get_func2('get_42', t1, t2)
for op, t1, t2 in permutate(ops, types, types): for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float' t_out = t1 if t1 == t2 else 'float'
if op == 'floordiv': if op == 'floordiv':
@ -211,7 +225,7 @@ if __name__ == "__main__":
code += get_op_code_float(op, t1, t2) code += get_op_code_float(op, t1, t2)
elif op == 'pow': elif op == 'pow':
code += get_pow(t1, t2) code += get_pow(t1, t2)
elif op == 'gt' or op == 'eq' or op == 'ne': elif op in {'gt', 'eq', 'ge', 'ne'}:
code += get_op_code(op, t1, t2, 'int') code += get_op_code(op, t1, t2, 'int')
else: else:
code += get_op_code(op, t1, t2, t_out) code += get_op_code(op, t1, t2, t_out)

View File

@ -1,5 +1,5 @@
from copapy import variable, NumLike from copapy import variable, NumLike
from copapy.backend import Write, compile_to_instruction_list, add_read_command from copapy.backend import Write, compile_to_dag, add_read_command
import copapy import copapy
import subprocess import subprocess
import struct import struct
@ -49,7 +49,7 @@ def test_compile():
out = [Write(r) for r in ret] out = [Write(r) for r in ret]
il, variables = compile_to_instruction_list(out, copapy.generic_sdb) il, variables = compile_to_dag(out, copapy.generic_sdb)
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)

View File

@ -1,5 +1,5 @@
from copapy import variable, NumLike from copapy import variable, NumLike
from copapy.backend import Write, compile_to_instruction_list from copapy.backend import Write, compile_to_dag
import copapy import copapy
import subprocess import subprocess
from copapy import _binwrite from copapy import _binwrite
@ -26,7 +26,7 @@ def test_compile():
out = [Write(r) for r in ret] out = [Write(r) for r in ret]
il, _ = compile_to_instruction_list(out, copapy.generic_sdb) il, _ = compile_to_dag(out, copapy.generic_sdb)
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)

View File

@ -1,6 +1,6 @@
from coparun_module import coparun from coparun_module import coparun
from copapy import variable from copapy import variable
from copapy.backend import Write, compile_to_instruction_list, add_read_command from copapy.backend import Write, compile_to_dag, add_read_command
import copapy import copapy
from copapy import _binwrite from copapy import _binwrite
@ -15,7 +15,7 @@ def test_compile():
r2 = i1 + 9 r2 = i1 + 9
out = [Write(r1), Write(r2), Write(c2)] out = [Write(r1), Write(r2), Write(c2)]
il, variables = compile_to_instruction_list(out, copapy.generic_sdb) il, variables = compile_to_dag(out, copapy.generic_sdb)
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)

View File

@ -1,5 +1,5 @@
from copapy import NumLike, variable from copapy import NumLike, variable
from copapy.backend import Write, Net, compile_to_instruction_list, add_read_command from copapy.backend import Write, Net, compile_to_dag, add_read_command
import copapy import copapy
import subprocess import subprocess
from copapy import _binwrite from copapy import _binwrite
@ -29,7 +29,7 @@ def test_compile():
ret = function(c1, c2) ret = function(c1, c2)
dw, variable_list = compile_to_instruction_list([Write(net) for net in ret], copapy.generic_sdb) dw, variable_list = compile_to_dag([Write(net) for net in ret], copapy.generic_sdb)
# run program command # run program command
dw.write_com(_binwrite.Command.RUN_PROG) dw.write_com(_binwrite.Command.RUN_PROG)

View File

@ -1,31 +0,0 @@
from copapy import variable, Target
import pytest
import copapy
def test_compile():
c_i = variable(9)
c_f = variable(2.5)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i)
ret_ref = (2.5 ** 2.5, 9 ** 9)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_compile()

60
tests/test_math.py Normal file
View File

@ -0,0 +1,60 @@
from copapy import variable, Target
import pytest
import copapy as cp
def test_corse():
a_i = 9
a_f = 2.5
c_i = variable(a_i)
c_f = variable(a_f)
# c_b = variable(True)
ret_test = (c_f ** c_f, c_i ** c_i) # , c_i & 3)
ret_refe = (a_f ** a_f, a_i ** a_i) # , a_i & 3)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_refe):
assert isinstance(test, cp.variable)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 2), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
def test_fine():
a_i = 9
a_f = 2.5
c_i = variable(a_i)
c_f = variable(a_f)
# c_b = variable(True)
ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f)) # , c_i & 3)
ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f)) # , a_i & 3)
tg = Target()
print('* compile and copy ...')
tg.compile(ret_test)
print('* run and copy ...')
tg.run()
print('* finished')
for test, ref in zip(ret_test, ret_refe):
assert isinstance(test, cp.variable)
val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool):
# assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 0.001), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_corse()
test_fine()

View File

@ -57,7 +57,7 @@ def test_compile():
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber) assert isinstance(test, copapy.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, test.dtype) print('+', val, ref, test.dtype)
for t in (int, float, bool): for t in (int, float, bool):

View File

@ -19,7 +19,7 @@ def test_compile():
print('* finished') print('* finished')
for test, ref in zip(ret_test, ret_ref): for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, copapy.CPNumber) assert isinstance(test, copapy.variable)
val = tg.read_value(test) val = tg.read_value(test)
print('+', val, ref, type(val), test.dtype) print('+', val, ref, type(val), test.dtype)
#for t in (int, float, bool): #for t in (int, float, bool):

38
tests/test_vector.py Normal file
View File

@ -0,0 +1,38 @@
import copapy as cp
import pytest
def test_vectors_init():
tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3])
tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3))
tt3 = (cp.vector(range(3)) + 5.6)
tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3))
tt5 = cp.vector([1, 2, 3]).dot(tt4)
print(tt1, tt2, tt3, tt4, tt5)
def test_compiled_vectors():
t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
t4 = ((t3 * t1) * 2).sum()
t5 = ((t3 * t1) * 2).magnitude()
tg = cp.Target()
tg.compile(t2, t4, t5)
tg.run()
assert isinstance(t2, cp.variable)
assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
assert isinstance(t4, cp.variable)
assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType]
assert isinstance(t5, cp.variable)
assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
test_compiled_vectors()

View File

@ -47,6 +47,8 @@ if __name__ == "__main__":
offs = dr.read_int() offs = dr.read_int()
reloc_type = dr.read_int() reloc_type = dr.read_int()
value = dr.read_int(signed=True) value = dr.read_int(signed=True)
assert reloc_type == RelocationType.RELOC_RELATIVE_32.value
program_data[offs:offs + 4] = value.to_bytes(4, byteorder, signed=True)
print(f"PATCH_FUNC patch_offs={offs} reloc_type={reloc_type} value={value}") print(f"PATCH_FUNC patch_offs={offs} reloc_type={reloc_type} value={value}")
elif com == Command.PATCH_OBJECT: elif com == Command.PATCH_OBJECT:
offs = dr.read_int() offs = dr.read_int()

View File

@ -1,32 +1,43 @@
from copapy import _binwrite, variable from copapy import variable
from copapy.backend import Write, compile_to_instruction_list from copapy.backend import Write, compile_to_dag
import copapy import copapy as cp
from copapy._binwrite import Command
def test_compile() -> None: def test_compile() -> None:
"""Test compilation of a simple program."""
c1 = variable(9) c1 = variable(9.0)
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4] #ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
ret = [c1 // 3.3 + 5] ret = [c1 // 3.3 + 5]
#ret = [cp.sqrt(c1)]
#c2 = cp._math.get_42()
#ret = [c2]
out = [Write(r) for r in ret] out = [Write(r) for r in ret]
il, _ = compile_to_instruction_list(out, copapy.generic_sdb) dw, vars = compile_to_dag(out, cp.generic_sdb)
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) dw.write_com(Command.RUN_PROG)
il.write_com(_binwrite.Command.READ_DATA) # read first 32 byte
il.write_int(0) dw.write_com(Command.READ_DATA)
il.write_int(36) dw.write_int(0)
dw.write_int(32)
il.write_com(_binwrite.Command.END_COM) # read variables
for addr, lengths, _ in vars.values():
dw.write_com(Command.READ_DATA)
dw.write_int(addr)
dw.write_int(lengths)
dw.write_com(Command.END_COM)
print('* Data to runner:') print('* Data to runner:')
il.print() dw.print()
il.to_file('bin/test.copapy') dw.to_file('bin/test.copapy')
if __name__ == "__main__": if __name__ == "__main__":

19
tools/test_stencil_aux.sh Normal file
View File

@ -0,0 +1,19 @@
#!/bin/bash
set -e
set -v
mkdir -p bin
FILE=aux_functions
SRC=stencils/$FILE.c
DEST=bin
OPT=O3
mkdir -p $DEST
# Compile native x86_64
gcc -g -$OPT $SRC -o $DEST/$FILE
chmod +x $DEST/$FILE
# Run
$DEST/$FILE