From 7d5990e2b27e466f33d66a53b66d1cbb8b074142 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Mon, 13 Oct 2025 22:58:55 +0200 Subject: [PATCH] added bool support --- src/copapy/__init__.py | 54 ++++++++++++++++++++++++-------------- src/copapy/stencil_db.py | 1 + tests/test_ops.py | 40 +++++++++++++++++++++------- tools/generate_stencils.py | 9 ++++--- 4 files changed, 70 insertions(+), 34 deletions(-) diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py index 144dc61..5620b0b 100644 --- a/src/copapy/__init__.py +++ b/src/copapy/__init__.py @@ -24,6 +24,9 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s generic_sdb = stencil_db_from_package() +def transl_type(t: str): + return {'bool': 'int'}.get(t, t) + class Node: def __init__(self) -> None: @@ -84,7 +87,16 @@ class Net: return _add_op('gt', [other, self]) def __eq__(self, other: Any) -> 'Net': # type: ignore - return _add_op('eq', [self, other]) + return _add_op('eq', [self, other], True) + + def __req__(self, other: Any) -> 'Net': # type: ignore + return _add_op('eq', [self, other], True) + + def __ne__(self, other: Any) -> 'Net': # type: ignore + return _add_op('ne', [self, other], True) + + def __rne__(self, other: Any) -> 'Net': # type: ignore + return _add_op('ne', [self, other], True) def __mod__(self, other: Any) -> 'Net': return _add_op('mod', [self, other]) @@ -109,7 +121,7 @@ class InitVar(Node): class Write(Node): def __init__(self, net: Net): - self.name = 'write_' + net.dtype + self.name = 'write_' + transl_type(net.dtype) self.args = [net] @@ -126,12 +138,15 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net: if commutative: arg_nets = sorted(arg_nets, key=lambda a: a.dtype) - typed_op = '_'.join([op] + [a.dtype for a in arg_nets]) + typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) if typed_op not in generic_sdb.stencil_definitions: raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}") - result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] + if op in {'eq', 'ne', 'ge'}: + result_type = 'bool' + else: + result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0] result_net = Net(result_type, Op(typed_op, arg_nets)) @@ -160,12 +175,12 @@ class CPBool(CPVariable): def _get_data_and_dtype(value: Any) -> tuple[str, float | int]: - if isinstance(value, int): + if isinstance(value, bool): + return ('bool', int(value)) + elif isinstance(value, int): return ('int', int(value)) elif isinstance(value, float): return ('float', float(value)) - elif isinstance(value, bool): - return ('bool', int(value)) else: raise ValueError(f'Non supported data type: {type(value).__name__}') @@ -281,8 +296,8 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No if id(net) != id(registers[i]): #if net in registers: # print('x swap registers') - type_list = ['int' if r is None else r.dtype for r in registers] - new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), []) + type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] + new_node = Op(f"read_{transl_type(net.dtype)}_reg{i}_" + '_'.join(type_list), []) yield net, new_node registers[i] = net @@ -339,7 +354,7 @@ def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database) object_list: list[tuple[Net, int, int]] = [] for variable in variable_list: - lengths = sdb.get_symbol_size('dummy_' + variable.dtype) + lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype)) object_list.append((variable, offset, lengths)) offset += (lengths + 3) // 4 * 4 @@ -421,7 +436,6 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database patch_value = addr + patch.addend - (offset + patch.addr) patch_list.append((patch.type.value, offset + patch.addr, patch_value)) - print('++ ', patch.target_symbol_info, patch.target_symbol_name) offset += len(data) @@ -473,7 +487,7 @@ class Target(): nodes.append(Write(s)) else: for net in s: - assert isinstance(net, Net) + assert isinstance(net, Net), f"The folowing element is not a Net: {net}" nodes.append(Write(net)) dw, self._variables = compile_to_instruction_list(nodes, self.sdb) @@ -487,10 +501,9 @@ class Target(): dw.write_com(binw.Command.END_COM) assert coparun(dw.get_data()) > 0 - def read_value(self, net: Net) -> float | int: + def read_value(self, net: Net) -> float | int | bool: assert net in self._variables, f"Variable {net} not found" addr, lengths, var_type = self._variables[net] - print('read_value', addr, lengths) assert lengths > 0 data = read_data_mem(addr, lengths) assert data is not None and len(data) == lengths, f"Failed to read variable {net}" @@ -505,12 +518,13 @@ class Target(): assert isinstance(value, float) return value elif var_type == 'int': - if lengths in (1, 2, 4, 8): - value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) - assert isinstance(value, int) - return value - else: - raise ValueError(f"Unsupported int length: {lengths} bytes") + assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" + value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) + return value + elif var_type == 'bool': + assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes" + value = bool.from_bytes(data, byteorder=self.sdb.byteorder, signed=True) + return value else: raise ValueError(f"Unsupported variable type: {var_type}") diff --git a/src/copapy/stencil_db.py b/src/copapy/stencil_db.py index e93fd31..7593b0b 100644 --- a/src/copapy/stencil_db.py +++ b/src/copapy/stencil_db.py @@ -149,6 +149,7 @@ class stencil_database(): name_set: set[str] = set() for name in names: if name not in name_set: + # assert name in self.elf.symbols, f"Stencil {name} not found" <-- see: https://github.com/Nonannet/pelfy/issues/1 func = self.elf.symbols[name] for r in func.relocations: if r.symbol.info == 'STT_FUNC': diff --git a/tests/test_ops.py b/tests/test_ops.py index 5bd7181..5da0319 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -3,34 +3,54 @@ from pytest import approx def function1(c1): - return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4] + return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4, + c1 * 4, c1 * -4, + c1 + 4, c1 - 4, + c1 > 2, c1 > 100, c1 < 4, c1 < 100] def function2(c1): - return [c1 / 4, c1 / -4, c1 / 4, c1 / -4, (c1 * -1) / 4] + return [c1 / 4.44, c1 / -4.44, c1 // 4.44, c1 // -4.44, (c1 * -1) // 4.44, + c1 * 4.44, c1 * -4.44, + c1 + 4.44, c1 - 4.44, + c1 > 2, c1 > 100.11, c1 < 4.44, c1 < 100.11] def function3(c1): return [c1 / 4] +def function4(c1): + return [c1 == 9, c1 == 4, c1 != 9, c1 != 4] + +def function5(c1): + return [c1 == True, c1 == False, c1 != True, c1 != False] + def test_compile(): c1 = CPVariable(9) + c2 = CPVariable(1.111) + c3 = CPVariable(False) - ret = function3(c1) + #ret_test = function1(c1) + function1(c2) + function2(c1) + function2(c2) + function3(c3) + function4(c1) + function5(c3) + [CPVariable(9) % 2] + #ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + + ret_test = [c1 / 4] + ret_ref = [9 / 4] + + print(ret_test) tg = Target() print('* compile and copy ...') - tg.compile(ret) + tg.compile(ret_test) #time.sleep(5) print('* run and copy ...') tg.run() - #print('* finished') + print('* finished') - ret_ref = function3(9) - - for test, ref, name in zip(ret, ret_ref, ['r1', 'r2', 'r3', 'r4', 'r5']): + for test, ref in zip(ret_test, ret_ref): val = tg.read_value(test) - print('+', name, val, ref) - assert val == approx(ref, 1e-5), name + print('+', val, ref) + for t in [int, float, bool]: + assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" + assert val == approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}" if __name__ == "__main__": diff --git a/tools/generate_stencils.py b/tools/generate_stencils.py index 750af4e..94520e1 100644 --- a/tools/generate_stencils.py +++ b/tools/generate_stencils.py @@ -1,9 +1,8 @@ from typing import Generator import argparse - op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', - 'gt': '>', 'eq': '==', 'mod': '%'} + 'gt': '>', 'eq': '==', 'ne': '!=', 'mod': '%'} entry_func_prefix = '' stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog @@ -117,13 +116,15 @@ if __name__ == "__main__": // Auto-generated stencils for copapy // Do not edit manually + #define bool int + volatile int dummy_int = 1337; volatile float dummy_float = 1337; """ # Scalar arithmetic: types = ['int', 'float'] - ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq'] + ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq', 'ne'] for t1 in types: code += get_result_stubs1(t1) @@ -139,7 +140,7 @@ if __name__ == "__main__": code += get_floordiv('floordiv', t1, t2) elif op == 'div': code += get_op_code_float(op, t1, t2) - elif op == 'gt' or op == 'eq': + elif op == 'gt' or op == 'eq' or op == 'ne': code += get_op_code(op, t1, t2, 'int') else: code += get_op_code(op, t1, t2, t_out)