This commit is contained in:
Nicolas 2025-10-12 22:22:30 +02:00
parent 72e46e06fb
commit 228daf5c9e
8 changed files with 88 additions and 43 deletions

View File

@ -38,7 +38,7 @@ dev = [
] ]
[tool.mypy] [tool.mypy]
files = ["src"] files = ["src", "tools"]
strict = true strict = true
warn_return_any = true warn_return_any = true
warn_unused_configs = true warn_unused_configs = true

View File

@ -387,26 +387,30 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
offset = 0 # offset in generated code chunk offset = 0 # offset in generated code chunk
# assemble stencils to main program # assemble stencils to main program
data = sdb.get_function_body('function_start', 'start') data = sdb.get_function_body('entry_function_shell', 'start')
data_list.append(data) data_list.append(data)
offset += len(data) offset += len(data)
for associated_net, node in extended_output_ops: for associated_net, node in extended_output_ops:
assert node.name in sdb.function_definitions, f"- Warning: {node.name} prototype not found" assert node.name in sdb.function_definitions, f"- Warning: {node.name} stencil not found"
data = sdb.get_stencil_code(node.name) data = sdb.get_stencil_code(node.name)
data_list.append(data) data_list.append(data)
# print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) # print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name): for patch in sdb.get_patch_positions(node.name):
assert associated_net, f"Relocation found but no net defined for operation {node.name}" if patch.target_symbol_info == 'STT_OBJECT':
object_addr = object_addr_lookp[associated_net] assert associated_net, f"Relocation found but no net defined for operation {node.name}"
patch_value = object_addr + patch.addend - (offset + patch.addr) object_addr = object_addr_lookp[associated_net]
# print('patch: ', patch, object_addr, patch_value) patch_value = object_addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value)) # print('patch: ', patch, object_addr, patch_value)
patch_list.append((patch.type.value, offset + patch.addr, patch_value))
print('++ ', patch.target_symbol_info, patch.target_symbol_name)
else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
offset += len(data) offset += len(data)
data = sdb.get_function_body('function_end', 'end') data = sdb.get_function_body('entry_function_shell', 'end')
data_list.append(data) data_list.append(data)
offset += len(data) offset += len(data)
# print('function_end', offset, data) # print('function_end', offset, data)

View File

@ -23,18 +23,15 @@ class patch_entry:
type: RelocationType type: RelocationType
addr: int addr: int
addend: int addend: int
target_symbol_name: str
target_symbol_info: str
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> patch_entry: def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'): if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
# S + A - P # S + A - P
patch_offset = relocation_addr return RelocationType.RELOC_RELATIVE_32
return patch_entry(RelocationType.RELOC_RELATIVE_32, patch_offset, r_addend)
else: else:
print('relocation_addr: ', relocation_addr)
print('reloc_type: ', reloc_type)
print('bits: ', bits)
print('r_addend: ', r_addend)
raise Exception(f"Unknown relocation type: {reloc_type}") raise Exception(f"Unknown relocation type: {reloc_type}")
@ -48,8 +45,11 @@ def get_return_function_type(symbol: elf_symbol) -> str:
def strip_function(func: elf_symbol) -> bytes: def strip_function(func: elf_symbol) -> bytes:
"""Return stencil code by striped stancil function""" """Return stencil code by striped stancil function"""
start_index, end_index = get_stencil_position(func) if func.relocations and func.relocations[-1].symbol.info == 'STT_NOTYPE':
return func.data[start_index:end_index] start_index, end_index = get_stencil_position(func)
return func.data[start_index:end_index]
else:
return func.data
def get_stencil_position(func: elf_symbol) -> tuple[int, int]: def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
@ -117,13 +117,17 @@ class stencil_database():
for reloc in symbol.relocations: for reloc in symbol.relocations:
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
# address to fist byte to patch relative to the start of the symbol # address to fist byte to patch relative to the start of the symbol
patch = translate_relocation( rtype = translate_relocation(
reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index, patch_offset,
reloc.type, reloc.type,
reloc.bits, reloc.bits,
reloc.fields['r_addend']) reloc.fields['r_addend'])
patch = patch_entry(rtype, patch_offset, reloc.fields['r_addend'], reloc.symbol.name, reloc.symbol.info)
# Exclude the call to the result_* function # Exclude the call to the result_* function
if patch.addr < end_index - start_index: if patch.addr < end_index - start_index:
yield patch yield patch

View File

@ -8,11 +8,14 @@ def function1(c1):
def function2(c1): def function2(c1):
return [c1 / 4, c1 / -4, c1 / 4, c1 / -4, (c1 * -1) / 4] return [c1 / 4, c1 / -4, c1 / 4, c1 / -4, (c1 * -1) / 4]
def function3(c1):
return [c1 / 4]
def test_compile(): def test_compile():
c1 = CPVariable(9) c1 = CPVariable(9)
ret = function2(c1) ret = function3(c1)
tg = Target() tg = Target()
print('* compile and copy ...') print('* compile and copy ...')
@ -22,7 +25,7 @@ def test_compile():
tg.run() tg.run()
#print('* finished') #print('* finished')
ret_ref = function2(9) ret_ref = function3(9)
for test, ref, name in zip(ret, ret_ref, ['r1', 'r2', 'r3', 'r4', 'r5']): for test, ref, name in zip(ret, ret_ref, ['r1', 'r2', 'r3', 'r4', 'r5']):
val = tg.read_value(test) val = tg.read_value(test)

View File

@ -9,17 +9,20 @@ def test_list_symbols():
#print(sdb.function_definitions) #print(sdb.function_definitions)
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.function_definitions.keys():
print('\n-', sym_name) print('\n-', sym_name)
print(list(sdb.get_patch_positions(sym_name))) #print(list(sdb.get_patch_positions(sym_name)))
def test_start_end_function(): def test_start_end_function():
for sym_name in sdb.function_definitions.keys(): for sym_name in sdb.function_definitions.keys():
symbol = sdb.elf.symbols[sym_name] symbol = sdb.elf.symbols[sym_name]
print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data))
start, end = stencil_db.get_stencil_position(symbol) if symbol.relocations and symbol.relocations[-1].symbol.info == 'STT_NOTYPE':
print('-', sym_name, stencil_db.get_stencil_position(symbol), len(symbol.data))
assert start >= 0 and end >= start and end <= len(symbol.data) start, end = stencil_db.get_stencil_position(symbol)
assert start >= 0 and end >= start and end <= len(symbol.data)
def test_aux_functions(): def test_aux_functions():

View File

@ -8,21 +8,20 @@ op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
entry_func_prefix = '' entry_func_prefix = ''
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
def get_function_start() -> str: def get_aux_funcs() -> str:
return f""" return f"""
{entry_func_prefix}int function_start(){{ {entry_func_prefix}int entry_function_shell(){{
result_int(0); result_int(0);
return 1; return 1;
}} }}
""" + \
""" """
__attribute__((noinline)) int floor_div(float arg1, float arg2) {
float x = arg1 / arg2;
def get_function_end() -> str: int i = (int)x;
return f""" if (x < 0 && x != (float)i) i -= 1;
{entry_func_prefix}int function_end(){{ return i;
result_int(0); }
return 1;
}}
""" """
@ -53,10 +52,7 @@ def get_op_code_float(op: str, type1: str, type2: str) -> str:
def get_floordiv(op: str, type1: str, type2: str) -> str: def get_floordiv(op: str, type1: str, type2: str) -> str:
return f""" return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ {stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
float x = (float)arg1 / (float)arg2; result_int_{type2}(floor_div((float)arg1, (float)arg2), arg2);
int i = (int)x;
if (x < 0 && x != (float)i) i -= 1;
result_int_{type2}(i, arg2);
}} }}
""" """
@ -135,6 +131,8 @@ if __name__ == "__main__":
for t1, t2 in permutate(types, types): for t1, t2 in permutate(types, types):
code += get_result_stubs2(t1, t2) code += get_result_stubs2(t1, t2)
code += get_aux_funcs()
for op, t1, t2 in permutate(ops, types, types): for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float' t_out = t1 if t1 == t2 else 'float'
if op == 'floordiv': if op == 'floordiv':
@ -155,7 +153,5 @@ if __name__ == "__main__":
for t1 in types: for t1 in types:
code += get_write_code(t1) code += get_write_code(t1)
code += get_function_start() + get_function_end()
with open(args.path, 'w') as f: with open(args.path, 'w') as f:
f.write(code) f.write(code)

View File

@ -1,6 +1,7 @@
#!/bin/bash #!/bin/bash
source tools/build.sh source tools/build.sh
python tests/test_compile_div.py objdump -d -x src/copapy/obj/stencils_x86_64_O3.o > bin/stencils_x86_64_O3.asm
python tools/make_example.py
python tools/extract_code.py "bin/test.copapy" "bin/test.copapy.bin" python tools/extract_code.py "bin/test.copapy" "bin/test.copapy.bin"
objdump -D -b binary -m i386:x86-64 --adjust-vma=0x1000 bin/test.copapy.bin > bin/test.copapy.asm objdump -D -b binary -m i386:x86-64 --adjust-vma=0x1000 bin/test.copapy.bin > bin/test.copapy.asm

34
tools/make_example.py Normal file
View File

@ -0,0 +1,34 @@
from copapy import CPVariable, Target, Write, binwrite
import copapy
from pytest import approx
def test_compile() -> None:
c1 = CPVariable(9)
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
ret = [c1 / 4]
out = [Write(r) for r in ret]
il, _ = copapy.compile_to_instruction_list(out, copapy.generic_sdb)
# run program command
il.write_com(binwrite.Command.RUN_PROG)
il.write_int(0)
il.write_com(binwrite.Command.READ_DATA)
il.write_int(0)
il.write_int(36)
il.write_com(binwrite.Command.END_PROG)
print('* Data to runner:')
il.print()
il.to_file('bin/test.copapy')
if __name__ == "__main__":
test_compile()