issue with wrong results on aarch64 fixed, by guarding registers for the write op

This commit is contained in:
Nicolas 2025-11-07 16:01:22 +01:00 committed by Nicolas Kruse
parent e3f40f94c0
commit 7c77c42b80
4 changed files with 35 additions and 153 deletions

View File

@ -129,18 +129,28 @@ def add_write_ops(net_node_list: list[tuple[Net | None, Node]], const_nets: list
read_back_nets = { read_back_nets = {
net for net, node in net_node_list net for net, node in net_node_list
if net and node.name.startswith('read_')} if net and node.name.startswith('read_')}
registers: list[Net | None] = [None, None]
for net, node in net_node_list: for net, node in net_node_list:
if isinstance(node, Write): if isinstance(node, Write):
yield node.args[0], node assert len(registers) == 2
type_list = [transl_type(r.dtype) if r else 'int' for r in registers]
yield node.args[0], Op(f"write_{type_list[0]}_reg0_" + '_'.join(type_list), node.args)
elif node.name.startswith('read_'): elif node.name.startswith('read_'):
yield net, node yield net, node
else: else:
yield None, node yield None, node
if net and net in read_back_nets and net not in stored_nets: if net:
yield net, Write(net) registers[0] = net
stored_nets.add(net) if len(node.args) > 1:
registers[1] = net
if net in read_back_nets and net not in stored_nets:
type_list = [transl_type(r.dtype) if r else 'int' for r in registers]
yield net, Op(f"write_{type_list[0]}_reg0_" + '_'.join(type_list), [])
stored_nets.add(net)
def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]: def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:

View File

@ -180,12 +180,12 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
@norm_indent @norm_indent
def get_write_code(type1: str) -> str: def get_write_code(type1: str, type2: str) -> str:
return f""" return f"""
{stencil_func_prefix}void write_{type1}({type1} arg1) {{ {stencil_func_prefix}void write_{type1}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(write_{type1}); STENCIL_START(write_{type1});
dummy_{type1} = arg1; dummy_{type1} = arg1;
result_{type1}(arg1); result_{type1}_{type2}(arg1, arg2);
}} }}
""" """
@ -256,8 +256,8 @@ if __name__ == "__main__":
code += get_read_reg0_code(t1, t2, t_out) code += get_read_reg0_code(t1, t2, t_out)
code += get_read_reg1_code(t1, t2, t_out) code += get_read_reg1_code(t1, t2, t_out)
for t1 in types: for t1, t2 in permutate(types, types):
code += get_write_code(t1) code += get_write_code(t1, t2)
print(f"Write file {args.path}...") print(f"Write file {args.path}...")
with open(args.path, 'w') as f: with open(args.path, 'w') as f:

View File

@ -3,12 +3,10 @@ from copapy.backend import Write, compile_to_dag, add_read_command
import subprocess import subprocess
from copapy import _binwrite from copapy import _binwrite
import copapy.backend as backend import copapy.backend as backend
import copapy as cp
import os import os
import warnings import warnings
import re import re
import struct import struct
import pytest
if os.name == 'nt': if os.name == 'nt':
# On Windows wsl and qemu-user is required: # On Windows wsl and qemu-user is required:
@ -83,12 +81,17 @@ def iiftests(c1: NumLike) -> list[NumLike]:
def test_compile(): def test_compile():
c_i = variable(5)
v2 = variable(0.0)
ret_test = [c_i + v2, c_i + v2] a1 = 0.0
ret_ref = [9 * 4.44, 9 * -4.44, 9 * -4.44] a2 = 3
c1 = variable(a1)
c2 = variable(a2)
ret_test = [c1 + c2, c2 + c2]
ret_ref: list[int | float] = [a1 + a2, a2 + a2]
#out = [Write(r) for r in ret_test]
out = [Write(r) for r in ret_test] out = [Write(r) for r in ret_test]
#ret_test += [c_i, v2] #ret_test += [c_i, v2]
@ -97,21 +100,21 @@ def test_compile():
sdb = backend.stencil_db_from_package('aarch64') sdb = backend.stencil_db_from_package('aarch64')
dw, variables = compile_to_dag(out, sdb) dw, variables = compile_to_dag(out, sdb)
dw.write_com(_binwrite.Command.READ_DATA) #dw.write_com(_binwrite.Command.READ_DATA)
dw.write_int(0) #dw.write_int(0)
dw.write_int(28) #dw.write_int(28)
# run program command # run program command
dw.write_com(_binwrite.Command.RUN_PROG) dw.write_com(_binwrite.Command.RUN_PROG)
#il.write_com(_binwrite.Command.DUMP_CODE) #dw.write_com(_binwrite.Command.DUMP_CODE)
for net in ret_test: for net in ret_test:
assert isinstance(net, backend.Net) assert isinstance(net, backend.Net)
add_read_command(dw, variables, net) add_read_command(dw, variables, net)
dw.write_com(_binwrite.Command.READ_DATA) #dw.write_com(_binwrite.Command.READ_DATA)
dw.write_int(0) #dw.write_int(0)
dw.write_int(28) #dw.write_int(28)
dw.write_com(_binwrite.Command.END_COM) dw.write_com(_binwrite.Command.END_COM)

View File

@ -1,131 +0,0 @@
from copapy import NumLike, iif, variable
from copapy.backend import Write, compile_to_dag, add_read_command
import subprocess
from copapy import _binwrite
import copapy.backend as backend
import copapy as cp
import os
import warnings
import re
import struct
import pytest
def parse_results(log_text: str) -> dict[int, bytes]:
regex = r"^READ_DATA offs=(\d*) size=(\d*) data=(.*)$"
matches = re.finditer(regex, log_text, re.MULTILINE)
var_dict: dict[int, bytes] = {}
for match in matches:
value_str: list[str] = match.group(3).strip().split(' ')
print('--', value_str)
value = bytes(int(v, base=16) for v in value_str)
var_dict[int(match.group(1))] = value
return var_dict
def run_command(command: list[str]) -> str:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf8', check=False)
assert result.returncode != 11, f"SIGSEGV (segmentation fault)\n -Error occurred: {result.stderr}\n -Output: {result.stdout}"
assert result.returncode == 0, f"\n -Error occurred: {result.stderr}\n -Output: {result.stdout}"
return result.stdout
def function1(c1: NumLike) -> list[NumLike]:
return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4,
c1 * 4, c1 * -4,
c1 + 4, c1 - 4,
c1 > 2, c1 > 100, c1 < 4, c1 < 100]
def function2(c1: NumLike) -> list[NumLike]:
return [c1 / 4.44, c1 / -4.44, c1 // 4.44, c1 // -4.44, (c1 * -1) // 4.44,
c1 * 4.44, c1 * -4.44,
c1 + 4.44, c1 - 4.44,
c1 > 2, c1 > 100.11, c1 < 4.44, c1 < 100.11]
def function3(c1: NumLike) -> list[NumLike]:
return [c1 / 4]
def function4(c1: NumLike) -> list[NumLike]:
return [c1 == 9, c1 == 4, c1 != 9, c1 != 4]
def function5(c1: NumLike) -> list[NumLike]:
return [c1 == True, c1 == False, c1 != True, c1 != False, c1 / 2, c1 + 2]
def function6(c1: NumLike) -> list[NumLike]:
return [c1 == True]
def iiftests(c1: NumLike) -> list[NumLike]:
return [iif(c1 > 5, 8, 9),
iif(c1 < 5, 8.5, 9.5),
iif(1 > 5, 3.3, 8.8) + c1,
iif(1 < 5, c1 * 3.3, 8.8),
iif(c1 < 5, c1 * 3.3, 8.8)]
def test_compile():
c_i = variable(9)
c_f = variable(1.111)
c_b = variable(True)
ret_test = function1(c_i) + function1(c_f) + function2(c_i) + function2(c_f) + function3(c_i) + function4(c_i) + function5(c_b) + [variable(9) % 2] + iiftests(c_i) + iiftests(c_f)
ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2] + iiftests(9) + iiftests(1.111)
out = [Write(r) for r in ret_test]
sdb = backend.stencil_db_from_package('native')
il, variables = compile_to_dag(out, sdb)
# run program command
il.write_com(_binwrite.Command.RUN_PROG)
for net in ret_test:
assert isinstance(net, backend.Net)
add_read_command(il, variables, net)
il.write_com(_binwrite.Command.END_COM)
print('* Data to runner:')
il.print()
il.to_file('bin/test.copapy')
command = ['bin/coparun', 'bin/test.copapy']
result = run_command(command)
print('* Output from runner:\n--')
print(result)
print('--')
assert 'Return value: 1' in result
result_data = parse_results(result)
for test, ref in zip(ret_test, ret_ref):
assert isinstance(test, variable)
address = variables[test][0]
data = result_data[address]
if test.dtype == 'int':
val = int.from_bytes(data, sdb.byteorder, signed=True)
elif test.dtype == 'bool':
val = bool.from_bytes(data, sdb.byteorder)
elif test.dtype == 'float':
en = {'little': '<', 'big': '>'}[sdb.byteorder]
val = struct.unpack(en + 'f', data)[0]
assert isinstance(val, float)
else:
raise Exception(f"Unknown type: {test.dtype}")
print('+', val, ref, test.dtype)
for t in (int, float, bool):
assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == pytest.approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}" # pyright: ignore[reportUnknownMemberType]
if __name__ == "__main__":
#test_example()
test_compile()