mirror of https://github.com/Nonannet/copapy.git
added bool support
This commit is contained in:
parent
889716642b
commit
7d5990e2b2
|
|
@ -24,6 +24,9 @@ def stencil_db_from_package(arch: str = 'native', optimization: str = 'O3') -> s
|
||||||
|
|
||||||
generic_sdb = stencil_db_from_package()
|
generic_sdb = stencil_db_from_package()
|
||||||
|
|
||||||
|
def transl_type(t: str):
|
||||||
|
return {'bool': 'int'}.get(t, t)
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
@ -84,7 +87,16 @@ class Net:
|
||||||
return _add_op('gt', [other, self])
|
return _add_op('gt', [other, self])
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> 'Net': # type: ignore
|
def __eq__(self, other: Any) -> 'Net': # type: ignore
|
||||||
return _add_op('eq', [self, other])
|
return _add_op('eq', [self, other], True)
|
||||||
|
|
||||||
|
def __req__(self, other: Any) -> 'Net': # type: ignore
|
||||||
|
return _add_op('eq', [self, other], True)
|
||||||
|
|
||||||
|
def __ne__(self, other: Any) -> 'Net': # type: ignore
|
||||||
|
return _add_op('ne', [self, other], True)
|
||||||
|
|
||||||
|
def __rne__(self, other: Any) -> 'Net': # type: ignore
|
||||||
|
return _add_op('ne', [self, other], True)
|
||||||
|
|
||||||
def __mod__(self, other: Any) -> 'Net':
|
def __mod__(self, other: Any) -> 'Net':
|
||||||
return _add_op('mod', [self, other])
|
return _add_op('mod', [self, other])
|
||||||
|
|
@ -109,7 +121,7 @@ class InitVar(Node):
|
||||||
|
|
||||||
class Write(Node):
|
class Write(Node):
|
||||||
def __init__(self, net: Net):
|
def __init__(self, net: Net):
|
||||||
self.name = 'write_' + net.dtype
|
self.name = 'write_' + transl_type(net.dtype)
|
||||||
self.args = [net]
|
self.args = [net]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -126,12 +138,15 @@ def _add_op(op: str, args: list[Any], commutative: bool = False) -> Net:
|
||||||
if commutative:
|
if commutative:
|
||||||
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [a.dtype for a in arg_nets])
|
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
|
||||||
|
|
||||||
if typed_op not in generic_sdb.stencil_definitions:
|
if typed_op not in generic_sdb.stencil_definitions:
|
||||||
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
raise ValueError(f"Unsupported operand type(s) for {op}: {' and '.join([a.dtype for a in arg_nets])}")
|
||||||
|
|
||||||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
if op in {'eq', 'ne', 'ge'}:
|
||||||
|
result_type = 'bool'
|
||||||
|
else:
|
||||||
|
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||||
|
|
||||||
result_net = Net(result_type, Op(typed_op, arg_nets))
|
result_net = Net(result_type, Op(typed_op, arg_nets))
|
||||||
|
|
||||||
|
|
@ -160,12 +175,12 @@ class CPBool(CPVariable):
|
||||||
|
|
||||||
|
|
||||||
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
||||||
if isinstance(value, int):
|
if isinstance(value, bool):
|
||||||
|
return ('bool', int(value))
|
||||||
|
elif isinstance(value, int):
|
||||||
return ('int', int(value))
|
return ('int', int(value))
|
||||||
elif isinstance(value, float):
|
elif isinstance(value, float):
|
||||||
return ('float', float(value))
|
return ('float', float(value))
|
||||||
elif isinstance(value, bool):
|
|
||||||
return ('bool', int(value))
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
||||||
|
|
||||||
|
|
@ -281,8 +296,8 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
|
||||||
if id(net) != id(registers[i]):
|
if id(net) != id(registers[i]):
|
||||||
#if net in registers:
|
#if net in registers:
|
||||||
# print('x swap registers')
|
# print('x swap registers')
|
||||||
type_list = ['int' if r is None else r.dtype for r in registers]
|
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
|
||||||
new_node = Op(f"read_{net.dtype}_reg{i}_" + '_'.join(type_list), [])
|
new_node = Op(f"read_{transl_type(net.dtype)}_reg{i}_" + '_'.join(type_list), [])
|
||||||
yield net, new_node
|
yield net, new_node
|
||||||
registers[i] = net
|
registers[i] = net
|
||||||
|
|
||||||
|
|
@ -339,7 +354,7 @@ def get_variable_mem_layout(variable_list: Iterable[Net], sdb: stencil_database)
|
||||||
object_list: list[tuple[Net, int, int]] = []
|
object_list: list[tuple[Net, int, int]] = []
|
||||||
|
|
||||||
for variable in variable_list:
|
for variable in variable_list:
|
||||||
lengths = sdb.get_symbol_size('dummy_' + variable.dtype)
|
lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype))
|
||||||
object_list.append((variable, offset, lengths))
|
object_list.append((variable, offset, lengths))
|
||||||
offset += (lengths + 3) // 4 * 4
|
offset += (lengths + 3) // 4 * 4
|
||||||
|
|
||||||
|
|
@ -421,7 +436,6 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
||||||
|
|
||||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
|
||||||
print('++ ', patch.target_symbol_info, patch.target_symbol_name)
|
|
||||||
|
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
||||||
|
|
@ -473,7 +487,7 @@ class Target():
|
||||||
nodes.append(Write(s))
|
nodes.append(Write(s))
|
||||||
else:
|
else:
|
||||||
for net in s:
|
for net in s:
|
||||||
assert isinstance(net, Net)
|
assert isinstance(net, Net), f"The folowing element is not a Net: {net}"
|
||||||
nodes.append(Write(net))
|
nodes.append(Write(net))
|
||||||
|
|
||||||
dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
|
dw, self._variables = compile_to_instruction_list(nodes, self.sdb)
|
||||||
|
|
@ -487,10 +501,9 @@ class Target():
|
||||||
dw.write_com(binw.Command.END_COM)
|
dw.write_com(binw.Command.END_COM)
|
||||||
assert coparun(dw.get_data()) > 0
|
assert coparun(dw.get_data()) > 0
|
||||||
|
|
||||||
def read_value(self, net: Net) -> float | int:
|
def read_value(self, net: Net) -> float | int | bool:
|
||||||
assert net in self._variables, f"Variable {net} not found"
|
assert net in self._variables, f"Variable {net} not found"
|
||||||
addr, lengths, var_type = self._variables[net]
|
addr, lengths, var_type = self._variables[net]
|
||||||
print('read_value', addr, lengths)
|
|
||||||
assert lengths > 0
|
assert lengths > 0
|
||||||
data = read_data_mem(addr, lengths)
|
data = read_data_mem(addr, lengths)
|
||||||
assert data is not None and len(data) == lengths, f"Failed to read variable {net}"
|
assert data is not None and len(data) == lengths, f"Failed to read variable {net}"
|
||||||
|
|
@ -505,12 +518,13 @@ class Target():
|
||||||
assert isinstance(value, float)
|
assert isinstance(value, float)
|
||||||
return value
|
return value
|
||||||
elif var_type == 'int':
|
elif var_type == 'int':
|
||||||
if lengths in (1, 2, 4, 8):
|
assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes"
|
||||||
value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True)
|
value = int.from_bytes(data, byteorder=self.sdb.byteorder, signed=True)
|
||||||
assert isinstance(value, int)
|
return value
|
||||||
return value
|
elif var_type == 'bool':
|
||||||
else:
|
assert lengths in (1, 2, 4, 8), f"Unsupported int length: {lengths} bytes"
|
||||||
raise ValueError(f"Unsupported int length: {lengths} bytes")
|
value = bool.from_bytes(data, byteorder=self.sdb.byteorder, signed=True)
|
||||||
|
return value
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported variable type: {var_type}")
|
raise ValueError(f"Unsupported variable type: {var_type}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,7 @@ class stencil_database():
|
||||||
name_set: set[str] = set()
|
name_set: set[str] = set()
|
||||||
for name in names:
|
for name in names:
|
||||||
if name not in name_set:
|
if name not in name_set:
|
||||||
|
# assert name in self.elf.symbols, f"Stencil {name} not found" <-- see: https://github.com/Nonannet/pelfy/issues/1
|
||||||
func = self.elf.symbols[name]
|
func = self.elf.symbols[name]
|
||||||
for r in func.relocations:
|
for r in func.relocations:
|
||||||
if r.symbol.info == 'STT_FUNC':
|
if r.symbol.info == 'STT_FUNC':
|
||||||
|
|
|
||||||
|
|
@ -3,34 +3,54 @@ from pytest import approx
|
||||||
|
|
||||||
|
|
||||||
def function1(c1):
|
def function1(c1):
|
||||||
return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
|
return [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4,
|
||||||
|
c1 * 4, c1 * -4,
|
||||||
|
c1 + 4, c1 - 4,
|
||||||
|
c1 > 2, c1 > 100, c1 < 4, c1 < 100]
|
||||||
|
|
||||||
def function2(c1):
|
def function2(c1):
|
||||||
return [c1 / 4, c1 / -4, c1 / 4, c1 / -4, (c1 * -1) / 4]
|
return [c1 / 4.44, c1 / -4.44, c1 // 4.44, c1 // -4.44, (c1 * -1) // 4.44,
|
||||||
|
c1 * 4.44, c1 * -4.44,
|
||||||
|
c1 + 4.44, c1 - 4.44,
|
||||||
|
c1 > 2, c1 > 100.11, c1 < 4.44, c1 < 100.11]
|
||||||
|
|
||||||
def function3(c1):
|
def function3(c1):
|
||||||
return [c1 / 4]
|
return [c1 / 4]
|
||||||
|
|
||||||
|
def function4(c1):
|
||||||
|
return [c1 == 9, c1 == 4, c1 != 9, c1 != 4]
|
||||||
|
|
||||||
|
def function5(c1):
|
||||||
|
return [c1 == True, c1 == False, c1 != True, c1 != False]
|
||||||
|
|
||||||
def test_compile():
|
def test_compile():
|
||||||
|
|
||||||
c1 = CPVariable(9)
|
c1 = CPVariable(9)
|
||||||
|
c2 = CPVariable(1.111)
|
||||||
|
c3 = CPVariable(False)
|
||||||
|
|
||||||
ret = function3(c1)
|
#ret_test = function1(c1) + function1(c2) + function2(c1) + function2(c2) + function3(c3) + function4(c1) + function5(c3) + [CPVariable(9) % 2]
|
||||||
|
#ret_ref = function1(9) + function1(1.111) + function2(9) + function2(1.111) + function3(9) + function4(9) + function5(True) + [9 % 2]
|
||||||
|
|
||||||
|
ret_test = [c1 / 4]
|
||||||
|
ret_ref = [9 / 4]
|
||||||
|
|
||||||
|
print(ret_test)
|
||||||
|
|
||||||
tg = Target()
|
tg = Target()
|
||||||
print('* compile and copy ...')
|
print('* compile and copy ...')
|
||||||
tg.compile(ret)
|
tg.compile(ret_test)
|
||||||
#time.sleep(5)
|
#time.sleep(5)
|
||||||
print('* run and copy ...')
|
print('* run and copy ...')
|
||||||
tg.run()
|
tg.run()
|
||||||
#print('* finished')
|
print('* finished')
|
||||||
|
|
||||||
ret_ref = function3(9)
|
for test, ref in zip(ret_test, ret_ref):
|
||||||
|
|
||||||
for test, ref, name in zip(ret, ret_ref, ['r1', 'r2', 'r3', 'r4', 'r5']):
|
|
||||||
val = tg.read_value(test)
|
val = tg.read_value(test)
|
||||||
print('+', name, val, ref)
|
print('+', val, ref)
|
||||||
assert val == approx(ref, 1e-5), name
|
for t in [int, float, bool]:
|
||||||
|
assert isinstance(val, t) == isinstance(ref, t), f"Result type does not match for {val} and {ref}"
|
||||||
|
assert val == approx(ref, 1e-5), f"Result does not match: {val} and reference: {ref}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
|
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
|
||||||
'gt': '>', 'eq': '==', 'mod': '%'}
|
'gt': '>', 'eq': '==', 'ne': '!=', 'mod': '%'}
|
||||||
|
|
||||||
entry_func_prefix = ''
|
entry_func_prefix = ''
|
||||||
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
|
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
|
||||||
|
|
@ -117,13 +116,15 @@ if __name__ == "__main__":
|
||||||
// Auto-generated stencils for copapy
|
// Auto-generated stencils for copapy
|
||||||
// Do not edit manually
|
// Do not edit manually
|
||||||
|
|
||||||
|
#define bool int
|
||||||
|
|
||||||
volatile int dummy_int = 1337;
|
volatile int dummy_int = 1337;
|
||||||
volatile float dummy_float = 1337;
|
volatile float dummy_float = 1337;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Scalar arithmetic:
|
# Scalar arithmetic:
|
||||||
types = ['int', 'float']
|
types = ['int', 'float']
|
||||||
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq']
|
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'eq', 'ne']
|
||||||
|
|
||||||
for t1 in types:
|
for t1 in types:
|
||||||
code += get_result_stubs1(t1)
|
code += get_result_stubs1(t1)
|
||||||
|
|
@ -139,7 +140,7 @@ if __name__ == "__main__":
|
||||||
code += get_floordiv('floordiv', t1, t2)
|
code += get_floordiv('floordiv', t1, t2)
|
||||||
elif op == 'div':
|
elif op == 'div':
|
||||||
code += get_op_code_float(op, t1, t2)
|
code += get_op_code_float(op, t1, t2)
|
||||||
elif op == 'gt' or op == 'eq':
|
elif op == 'gt' or op == 'eq' or op == 'ne':
|
||||||
code += get_op_code(op, t1, t2, 'int')
|
code += get_op_code(op, t1, t2, 'int')
|
||||||
else:
|
else:
|
||||||
code += get_op_code(op, t1, t2, t_out)
|
code += get_op_code(op, t1, t2, t_out)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue