From 7c77c42b80a857829573ac105e108269ef22d858 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Fri, 7 Nov 2025 16:01:22 +0100 Subject: [PATCH] issue with wrong results on aarch64 fixed, by guarding registers for the write op --- src/copapy/_compiler.py | 18 ++++- stencils/generate_stencils.py | 10 +-- tests/test_issue001_aarch64.py | 29 ++++---- tests/test_issue001_x86_64.py | 131 --------------------------------- 4 files changed, 35 insertions(+), 153 deletions(-) delete mode 100644 tests/test_issue001_x86_64.py diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 5fb90ee..1b95445 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -129,18 +129,28 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list read_back_nets = { net for net, node in net_node_list if net and node.name.startswith('read_')} + + registers: list[Net | None] = [None, None] for net, node in net_node_list: if isinstance(node, Write): - yield node.args[0], node + assert len(registers) == 2 + type_list = [transl_type(r.dtype) if r else 'int' for r in registers] + yield node.args[0], Op(f"write_{type_list[0]}_reg0_" + '_'.join(type_list), node.args) elif node.name.startswith('read_'): yield net, node else: yield None, node - if net and net in read_back_nets and net not in stored_nets: - yield net, Write(net) - stored_nets.add(net) + if net: + registers[0] = net + if len(node.args) > 1: + registers[1] = net + + if net in read_back_nets and net not in stored_nets: + type_list = [transl_type(r.dtype) if r else 'int' for r in registers] + yield net, Op(f"write_{type_list[0]}_reg0_" + '_'.join(type_list), []) + stored_nets.add(net) def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 0fb0e71..ce34eb0 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -180,12 +180,12 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str: @norm_indent -def get_write_code(type1: str) -> str: +def get_write_code(type1: str, type2: str) -> str: return f""" - {stencil_func_prefix}void write_{type1}({type1} arg1) {{ + {stencil_func_prefix}void write_{type1}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{ STENCIL_START(write_{type1}); dummy_{type1} = arg1; - result_{type1}(arg1); + result_{type1}_{type2}(arg1, arg2); }} """ @@ -256,8 +256,8 @@ if __name__ == "__main__": code += get_read_reg0_code(t1, t2, t_out) code += get_read_reg1_code(t1, t2, t_out) - for t1 in types: - code += get_write_code(t1) + for t1, t2 in permutate(types, types): + code += get_write_code(t1, t2) print(f"Write file {args.path}...") with open(args.path, 'w') as f: diff --git a/tests/test_issue001_aarch64.py b/tests/test_issue001_aarch64.py index afa41d2..04f038b 100644 --- a/tests/test_issue001_aarch64.py +++ b/tests/test_issue001_aarch64.py @@ -3,12 +3,10 @@ from copapy.backend import Write, compile_to_dag, add_read_command import subprocess from copapy import _binwrite import copapy.backend as backend -import copapy as cp import os import warnings import re import struct -import pytest if os.name == 'nt': # On Windows wsl and qemu-user is required: @@ -83,12 +81,17 @@ def iiftests(c1: NumLike) -> list[NumLike]: def test_compile(): - c_i = variable(5) - v2 = variable(0.0) - ret_test = [c_i + v2, c_i + v2] - ret_ref = [9 * 4.44, 9 * -4.44, 9 * -4.44] + a1 = 0.0 + a2 = 3 + + c1 = variable(a1) + c2 = variable(a2) + ret_test = [c1 + c2, c2 + c2] + ret_ref: list[int | float] = [a1 + a2, a2 + a2] + + #out = [Write(r) for r in ret_test] out = [Write(r) for r in ret_test] #ret_test += [c_i, v2] @@ -97,21 +100,21 @@ def test_compile(): sdb = backend.stencil_db_from_package('aarch64') dw, variables = compile_to_dag(out, sdb) - dw.write_com(_binwrite.Command.READ_DATA) - dw.write_int(0) - dw.write_int(28) + #dw.write_com(_binwrite.Command.READ_DATA) + #dw.write_int(0) + #dw.write_int(28) # run program command dw.write_com(_binwrite.Command.RUN_PROG) - #il.write_com(_binwrite.Command.DUMP_CODE) + #dw.write_com(_binwrite.Command.DUMP_CODE) for net in ret_test: assert isinstance(net, backend.Net) add_read_command(dw, variables, net) - dw.write_com(_binwrite.Command.READ_DATA) - dw.write_int(0) - dw.write_int(28) + #dw.write_com(_binwrite.Command.READ_DATA) + #dw.write_int(0) + #dw.write_int(28) dw.write_com(_binwrite.Command.END_COM) diff --git a/tests/test_issue001_x86_64.py b/tests/test_issue001_x86_64.py deleted file mode 100644 index 792c46f..0000000 --- a/tests/test_issue001_x86_64.py +++ /dev/null @@ -1,131 +0,0 @@ -from copapy import NumLike, iif, variable -from copapy.backend import Write, compile_to_dag, add_read_command -import subprocess -from copapy import _binwrite -import copapy.backend as backend -import copapy as cp -import os -import warnings -import re -import struct -import pytest - -def parse_results(log_text: str) -> dict[int, bytes]: - regex = r"^READ_DATA offs=(\d*) size=(\d*) data=(.*)$" - matches = re.finditer(regex, log_text, re.MULTILINE) - var_dict: dict[int, bytes] = {} - - for match in matches: - value_str: list[str] = match.group(3).strip().split(' ') - print('--', value_str) - value = bytes(int(v, base=16) for v in value_str) - var_dict[int(match.group(1))] = value - - return var_dict - - -def run_command(command: list[str]) -> str: - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf8', check=False) - assert result.returncode != 11, f"SIGSEGV (segmentation fault)\n -Error occurred: {result.stderr}\n -Output: {result.stdout}" - assert result.returncode == 0, f"\n -Error occurred: {result.stderr}\n -Output: {result.stdout}" - return result.stdout - - -def function1(c1: NumLike) -> list[NumLike]: - return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4, - c1 * 4, c1 * -4, - c1 + 4, c1 - 4, - c1 > 2, c1 > 100, c1 < 4, c1 < 100] - - -def function2(c1: NumLike) -> list[NumLike]: - return [c1 / 4.44, c1 / -4.44, c1 // 4.44, c1 // -4.44, (c1 * -1) // 4.44, - c1 * 4.44, c1 * -4.44, - c1 + 4.44, c1 - 4.44, - c1 > 2, c1 > 100.11, c1 < 4.44, c1 < 100.11] - - -def function3(c1: NumLike) -> list[NumLike]: - return [c1 / 4] - - -def function4(c1: NumLike) -> list[NumLike]: - return [c1 == 9, c1 == 4, c1 != 9, c1 != 4] - - -def function5(c1: NumLike) -> list[NumLike]: - return [c1 == True, c1 == False, c1 != True, c1 != False, c1 / 2, c1 + 2] - - -def function6(c1: NumLike) -> list[NumLike]: - return [c1 == True] - - -def iiftests(c1: NumLike) -> list[NumLike]: - return [iif(c1 > 5, 8, 9), - iif(c1 < 5, 8.5, 9.5), - iif(1 > 5, 3.3, 8.8) + c1, - iif(1 < 5, c1 * 3.3, 8.8), - iif(c1 < 5, c1 * 3.3, 8.8)] - - -def test_compile(): - c_i = variable(9) - c_f = variable(1.111) - c_b = variable(True) - - ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [variable(9) % 2] + iiftests(c_i) + iiftests(c_f) - ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111) - - out = [Write(r) for r in ret_test] - - sdb = backend.stencil_db_from_package('native') - il, variables = compile_to_dag(out, sdb) - - # run program command - il.write_com(_binwrite.Command.RUN_PROG) - - for net in ret_test: - assert isinstance(net, backend.Net) - add_read_command(il, variables, net) - - il.write_com(_binwrite.Command.END_COM) - - print('* Data to runner:') - il.print() - - il.to_file('bin/test.copapy') - - command = ['bin/coparun', 'bin/test.copapy'] - result = run_command(command) - print('* Output from runner:\n--') - print(result) - print('--') - - assert 'Return value: 1' in result - - result_data = parse_results(result) - - for test, ref in zip(ret_test, ret_ref): - assert isinstance(test, variable) - address = variables[test][0] - data = result_data[address] - if test.dtype == 'int': - val = int.from_bytes(data, sdb.byteorder, signed=True) - elif test.dtype == 'bool': - val = bool.from_bytes(data, sdb.byteorder) - elif test.dtype == 'float': - en = {'little': '<', 'big': '>'}[sdb.byteorder] - val = struct.unpack(en + 'f', data)[0] - assert isinstance(val, float) - else: - raise Exception(f"Unknown type: {test.dtype}") - print('+', val, ref, test.dtype) - for t in (int, float, bool): - assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}" - assert val == pytest.approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType] - - -if __name__ == "__main__": - #test_example() - test_compile()