example generation to track down sqrt issue

This commit is contained in:
Nicolas Kruse 2025-10-26 13:21:35 +01:00
parent e400eff2b0
commit 501bd5bee3
3 changed files with 21 additions and 13 deletions

View File

@ -1,6 +1,6 @@
from typing import Generator, Iterable, Any from typing import Generator, Iterable, Any
from . import _binwrite as binw from . import _binwrite as binw
from ._stencils import stencil_database, patch_entry from ._stencils import stencil_database
from collections import defaultdict, deque from collections import defaultdict, deque
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
@ -244,12 +244,11 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
# Prepare program code and relocations # Prepare program code and relocations
object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout} section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout}
offset = aux_function_lengths # offset in generated code chunk
# assemble stencils to main program and patch stencils # assemble stencils to main program and patch stencils
data = sdb.get_function_code('entry_function_shell', 'start') data = sdb.get_function_code('entry_function_shell', 'start')
data_list.append(data) data_list.append(data)
offset += len(data) offset = aux_function_lengths + len(data)
for associated_net, node in extended_output_ops: for associated_net, node in extended_output_ops:
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
@ -273,11 +272,13 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
addr = section_addr_lookup[patch.target_symbol_section_index] addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC': elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name] addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.addr) patch_value = addr + patch.addend - (offset + patch.addr)
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC)) patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC))
print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else: else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
@ -303,17 +304,22 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
for patch in sdb.get_patch_positions(name): for patch in sdb.get_patch_positions(name):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap # Patch constants/variable addresses on heap
addr = section_addr_lookup[patch.target_symbol_section_index] section_addr = section_addr_lookup[patch.target_symbol_section_index]
patch_value = addr + patch.addend - (start + patch.addr) patch_value = section_addr + patch.addend - (start + patch.addr)
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)) patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name)
#print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)
elif patch.target_symbol_info == 'STT_FUNC': elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name] aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (start + patch.addr) patch_value = aux_func_addr + patch.addend - (start + patch.addr)
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC)) patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC))
else: else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
assert False, aux_function_mem_layout
# write entry function code # write entry function code
dw.write_com(binw.Command.COPY_CODE) dw.write_com(binw.Command.COPY_CODE)
dw.write_int(aux_function_lengths) dw.write_int(aux_function_lengths)

View File

@ -15,7 +15,8 @@ def test_compiled_vectors():
t2 = t1.sum() t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3)) t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
t4 = ((t3 * t1) * 2).magnitude() #t4 = ((t3 * t1) * 2).magnitude()
t4 = ((t3 * t1) * 2).sum()
tg = cp.Target() tg = cp.Target()
@ -23,7 +24,7 @@ def test_compiled_vectors():
tg.run() tg.run()
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5 #assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
if __name__ == "__main__": if __name__ == "__main__":
test_compiled_vectors() test_compiled_vectors()

View File

@ -1,18 +1,19 @@
from copapy import _binwrite, variable from copapy import _binwrite, variable
from copapy.backend import Write, compile_to_instruction_list from copapy.backend import Write, compile_to_instruction_list
import copapy import copapy as cp
def test_compile() -> None: def test_compile() -> None:
c1 = variable(9) c1 = variable(9.0)
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4] #ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
ret = [c1 // 3.3 + 5] #ret = [c1 // 3.3 + 5]
ret = [cp.sqrt(c1)]
out = [Write(r) for r in ret] out = [Write(r) for r in ret]
il, _ = compile_to_instruction_list(out, copapy.generic_sdb) il, _ = compile_to_instruction_list(out, cp.generic_sdb)
# run program command # run program command
il.write_com(_binwrite.Command.RUN_PROG) il.write_com(_binwrite.Command.RUN_PROG)