added bool support

This commit is contained in:
Nicolas 2025-10-13 22:58:55 +02:00
parent 889716642b
commit 7d5990e2b2
4 changed files with 70 additions and 34 deletions

View File

@ -24,6 +24,9 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s
generic_sdb = stencil_db_from_package()
def transl_type(t: str):
return {'bool': 'int'}.get(t, t)
class Node:
def __init__(self) -> None:
@ -84,7 +87,16 @@ class Net:
return _add_op('gt', [other, self])
def __eq__(self, other: Any) -> 'Net': # type: ignore
return _add_op('eq', [self, other])
return _add_op('eq', [self, other], True)
def __req__(self, other: Any) -> 'Net': # type: ignore
return _add_op('eq', [self, other], True)
def __ne__(self, other: Any) -> 'Net': # type: ignore
return _add_op('ne', [self, other], True)
def __rne__(self, other: Any) -> 'Net': # type: ignore
return _add_op('ne', [self, other], True)
def __mod__(self, other: Any) -> 'Net':
return _add_op('mod', [self, other])
@ -109,7 +121,7 @@ class InitVar(Node):
class Write(Node):
def __init__(self, net: Net):
self.name = 'write_' + net.dtype
self.name = 'write_' + transl_type(net.dtype)
self.args = [net]
@ -126,11 +138,14 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
if typed_op not in generic_sdb.stencil_definitions:
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
if op in {'eq', 'ne', 'ge'}:
result_type = 'bool'
else:
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
result_net = Net(result_type, Op(typed_op, arg_nets))
@ -160,12 +175,12 @@ class CPBool(CPVariable):
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
if isinstance(value, int):
if isinstance(value, bool):
return ('bool', int(value))
elif isinstance(value, int):
return ('int', int(value))
elif isinstance(value, float):
return ('float', float(value))
elif isinstance(value, bool):
return ('bool', int(value))
else:
raise ValueError(f'Non supported data type: {type(value).__name__}')
@ -281,8 +296,8 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
if id(net) != id(registers[i]):
#if net in registers:
# print('x swap registers')
type_list = ['int' if r is None else r.dtype for r in registers]
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
new_node = Op(f"read_{transl_type(net.dtype)}_reg{i}_" + '_'.join(type_list), [])
yield net, new_node
registers[i] = net
@ -339,7 +354,7 @@ def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database)
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.get_symbol_size('dummy_' + variable.dtype)
lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype))
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
@ -421,7 +436,6 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
patch_value = addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
print('++ ', patch.target_symbol_info, patch.target_symbol_name)
offset += len(data)
@ -473,7 +487,7 @@ class Target():
nodes.append(Write(s))
else:
for net in s:
assert isinstance(net, Net)
assert isinstance(net, Net), f"The folowing element is not a Net: {net}"
nodes.append(Write(net))
dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
@ -487,10 +501,9 @@ class Target():
dw.write_com(binw.Command.END_COM)
assert coparun(dw.get_data()) > 0
def read_value(self, net: Net) -> float | int:
def read_value(self, net: Net) -> float | int | bool:
assert net in self._variables, f"Variable {net} not found"
addr, lengths, var_type = self._variables[net]
print('read_value', addr, lengths)
assert lengths > 0
data = read_data_mem(addr, lengths)
assert data is not None and len(data) == lengths, f"Failed to read variable {net}"
@ -505,12 +518,13 @@ class Target():
assert isinstance(value, float)
return value
elif var_type == 'int':
if lengths in (1, 2, 4, 8):
assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes"
value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True)
assert isinstance(value, int)
return value
else:
raise ValueError(f"Unsupported int length: {lengths} bytes")
elif var_type == 'bool':
assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes"
value = bool.from_bytes(data, byteorder=self.sdb.byteorder, signed=True)
return value
else:
raise ValueError(f"Unsupported variable type: {var_type}")

View File

@ -149,6 +149,7 @@ class stencil_database():
name_set: set[str] = set()
for name in names:
if name not in name_set:
# assert name in self.elf.symbols, f"Stencil {name} not found" <-- see: https://github.com/Nonannet/pelfy/issues/1
func = self.elf.symbols[name]
for r in func.relocations:
if r.symbol.info == 'STT_FUNC':

View File

@ -3,34 +3,54 @@ from pytest import approx
def function1(c1):
return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4,
c1 * 4, c1 * -4,
c1 + 4, c1 - 4,
c1 > 2, c1 > 100, c1 < 4, c1 < 100]
def function2(c1):
return [c1 / 4, c1 / -4, c1 / 4, c1 / -4, (c1 * -1) / 4]
return [c1 / 4.44, c1 / -4.44, c1 // 4.44, c1 // -4.44, (c1 * -1) // 4.44,
c1 * 4.44, c1 * -4.44,
c1 + 4.44, c1 - 4.44,
c1 > 2, c1 > 100.11, c1 < 4.44, c1 < 100.11]
def function3(c1):
return [c1 / 4]
def function4(c1):
return [c1 == 9, c1 == 4, c1 != 9, c1 != 4]
def function5(c1):
return [c1 == True, c1 == False, c1 != True, c1 != False]
def test_compile():
c1 = CPVariable(9)
c2 = CPVariable(1.111)
c3 = CPVariable(False)
ret = function3(c1)
#ret_test = function1(c1) + function1(c2) + function2(c1) + function2(c2) + function3(c3) + function4(c1) + function5(c3) + [CPVariable(9) % 2]
#ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2]
ret_test = [c1 / 4]
ret_ref = [9 / 4]
print(ret_test)
tg = Target()
print('* compile and copy ...')
tg.compile(ret)
tg.compile(ret_test)
#time.sleep(5)
print('* run and copy ...')
tg.run()
#print('* finished')
print('* finished')
ret_ref = function3(9)
for test, ref, name in zip(ret, ret_ref, ['r1', 'r2', 'r3', 'r4', 'r5']):
for test, ref in zip(ret_test, ret_ref):
val = tg.read_value(test)
print('+', name, val, ref)
assert val == approx(ref, 1e-5), name
print('+', val, ref)
for t in [int, float, bool]:
assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
assert val == approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}"
if __name__ == "__main__":

View File

@ -1,9 +1,8 @@
from typing import Generator
import argparse
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
'gt': '>', 'eq': '==', 'mod': '%'}
'gt': '>', 'eq': '==', 'ne': '!=', 'mod': '%'}
entry_func_prefix = ''
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
@ -117,13 +116,15 @@ if __name__ == "__main__":
// Auto-generated stencils for copapy
// Do not edit manually
#define bool int
volatile int dummy_int = 1337;
volatile float dummy_float = 1337;
"""
# Scalar arithmetic:
types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq', 'ne']
for t1 in types:
code += get_result_stubs1(t1)
@ -139,7 +140,7 @@ if __name__ == "__main__":
code += get_floordiv('floordiv', t1, t2)
elif op == 'div':
code += get_op_code_float(op, t1, t2)
elif op == 'gt' or op == 'eq':
elif op == 'gt' or op == 'eq' or op == 'ne':
code += get_op_code(op, t1, t2, 'int')
else:
code += get_op_code(op, t1, t2, t_out)