mirror of https://github.com/Nonannet/copapy.git
example generation to track down sqrt issue
This commit is contained in:
parent
e400eff2b0
commit
501bd5bee3
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Generator, Iterable, Any
|
||||
from . import _binwrite as binw
|
||||
from ._stencils import stencil_database, patch_entry
|
||||
from ._stencils import stencil_database
|
||||
from collections import defaultdict, deque
|
||||
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
|
||||
|
||||
|
|
@ -244,12 +244,11 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
# Prepare program code and relocations
|
||||
object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
|
||||
section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout}
|
||||
offset = aux_function_lengths # offset in generated code chunk
|
||||
|
||||
# assemble stencils to main program and patch stencils
|
||||
data = sdb.get_function_code('entry_function_shell', 'start')
|
||||
data_list.append(data)
|
||||
offset += len(data)
|
||||
offset = aux_function_lengths + len(data)
|
||||
|
||||
for associated_net, node in extended_output_ops:
|
||||
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
|
||||
|
|
@ -273,11 +272,13 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
||||
print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
||||
|
||||
elif patch.target_symbol_info == 'STT_FUNC':
|
||||
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||
patch_value = addr + patch.addend - (offset + patch.addr)
|
||||
patch_list.append((patch.type.value, offset + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
||||
print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||
|
||||
|
|
@ -303,17 +304,22 @@ def compile_to_instruction_list(node_list: Iterable[Node], sdb: stencil_database
|
|||
for patch in sdb.get_patch_positions(name):
|
||||
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
||||
# Patch constants/variable addresses on heap
|
||||
addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
patch_value = addr + patch.addend - (start + patch.addr)
|
||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
||||
patch_value = section_addr + patch.addend - (start + patch.addr)
|
||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT))
|
||||
print(patch.type, patch.addr, section_addr, binw.Command.PATCH_OBJECT, name)
|
||||
#print(patch.type, start + patch.addr, patch_value, binw.Command.PATCH_OBJECT)
|
||||
|
||||
elif patch.target_symbol_info == 'STT_FUNC':
|
||||
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||
patch_value = addr + patch.addend - (start + patch.addr)
|
||||
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
|
||||
patch_value = aux_func_addr + patch.addend - (start + patch.addr)
|
||||
patch_list.append((patch.type.value, start + patch.addr, patch_value, binw.Command.PATCH_FUNC))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||
|
||||
assert False, aux_function_mem_layout
|
||||
|
||||
# write entry function code
|
||||
dw.write_com(binw.Command.COPY_CODE)
|
||||
dw.write_int(aux_function_lengths)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ def test_compiled_vectors():
|
|||
t2 = t1.sum()
|
||||
|
||||
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
||||
t4 = ((t3 * t1) * 2).magnitude()
|
||||
#t4 = ((t3 * t1) * 2).magnitude()
|
||||
t4 = ((t3 * t1) * 2).sum()
|
||||
|
||||
|
||||
tg = cp.Target()
|
||||
|
|
@ -23,7 +24,7 @@ def test_compiled_vectors():
|
|||
tg.run()
|
||||
|
||||
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||
assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
||||
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compiled_vectors()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
from copapy import _binwrite, variable
|
||||
from copapy.backend import Write, compile_to_instruction_list
|
||||
import copapy
|
||||
import copapy as cp
|
||||
|
||||
|
||||
def test_compile() -> None:
|
||||
|
||||
c1 = variable(9)
|
||||
c1 = variable(9.0)
|
||||
|
||||
#ret = [c1 / 4, c1 / -4, c1 // 4, c1 // -4, (c1 * -1) // 4]
|
||||
ret = [c1 // 3.3 + 5]
|
||||
#ret = [c1 // 3.3 + 5]
|
||||
ret = [cp.sqrt(c1)]
|
||||
|
||||
out = [Write(r) for r in ret]
|
||||
|
||||
il, _ = compile_to_instruction_list(out, copapy.generic_sdb)
|
||||
il, _ = compile_to_instruction_list(out, cp.generic_sdb)
|
||||
|
||||
# run program command
|
||||
il.write_com(_binwrite.Command.RUN_PROG)
|
||||
|
|
|
|||
Loading…
Reference in New Issue