diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index c3c4115..535bde1 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -313,13 +313,13 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi obj_addr = section_addr + patch.target_symbol_address patch_value = obj_addr + patch.addend - (offset + patch.patch_address) #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name) - patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) + patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) #print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name) elif patch.target_symbol_info == 'STT_FUNC': addr = aux_func_addr_lookup[patch.target_symbol_name] patch_value = addr + patch.addend - (offset + patch.patch_address) - patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) + patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) else: raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") @@ -349,13 +349,13 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi section_addr = section_addr_lookup[patch.target_symbol_section_index] obj_addr = section_addr + patch.target_symbol_address patch_value = obj_addr + patch.addend - (start + patch.patch_address) - patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) + patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT)) #print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name) elif patch.target_symbol_info == 'STT_FUNC': aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name] patch_value = aux_func_addr + patch.addend - (start + patch.patch_address) - patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) + patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC)) else: raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") @@ -369,10 +369,10 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi dw.write_bytes(b''.join(data_list)) # write patch operations - for patch_type, patch_addr, addr, patch_command in patch_list: + for mask, patch_addr, addr, patch_command in patch_list: dw.write_com(patch_command) dw.write_int(patch_addr) - dw.write_int(patch_type) + dw.write_int(mask) dw.write_int(addr, signed=True) dw.write_com(binw.Command.ENTRY_POINT) diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 2c5e938..42d3100 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -1,16 +1,13 @@ from dataclasses import dataclass from pelfy import open_elf_file, elf_file, elf_symbol from typing import Generator, Literal, Iterable -from enum import Enum +import pelfy ByteOrder = Literal['little', 'big'] # on x86_64: call or jmp instruction when tail call optimized LENGTH_CALL_INSTRUCTION = 5 -RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)]) - - @dataclass class patch_entry: """ @@ -20,7 +17,7 @@ class patch_entry: addr (int): address of first byte to patch relative to the start of the symbol type (RelocationType): relocation type""" - type: RelocationType + mask: int patch_address: int addend: int target_symbol_name: str @@ -29,12 +26,27 @@ class patch_entry: target_symbol_address: int -def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType: - if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'): +def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry: + if reloc.type in ('R_AMD64_PLT32', 'R_AMD64_PC32'): # S + A - P - return RelocationType.RELOC_RELATIVE_32 + mask = 0xFFFFFFFF # 32 bit + imm = offset + + elif reloc.type.endswith('_JUMP26'): + assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64" + mask = 0x3ffffff # 26 bit + imm = offset >> 2 + assert imm < mask, "Relocation immediate value too large" + else: - raise Exception(f"Unknown relocation type: {reloc_type}") + raise NotImplementedError(f"Relocation type {reloc.type} not implemented") + + return patch_entry(mask, imm, + reloc.fields['r_addend'], + reloc.symbol.name, + reloc.symbol.info, + reloc.symbol.fields['st_shndx'], + reloc.symbol.fields['st_value']) def get_return_function_type(symbol: elf_symbol) -> str: @@ -129,6 +141,7 @@ class stencil_database(): Yields: patch_entry: every relocation for the symbol """ + arm_hi_byte_flag: bool = False symbol = self.elf.symbols[symbol_name] if stencil: start_index, end_index = get_stencil_position(symbol) @@ -136,27 +149,19 @@ class stencil_database(): start_index = 0 end_index = symbol.fields['st_size'] + print('->', symbol_name) for reloc in symbol.relocations: patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index - # address to fist byte to patch relative to the start of the symbol - rtype = translate_relocation( - patch_offset, - reloc.type, - reloc.bits, - reloc.fields['r_addend']) + if patch_offset < end_index - start_index: # Exclude the call to the result_* function + if reloc.symbol.info == 'STT_SECTION': + arm_hi_byte_flag = True + else: + assert not arm_hi_byte_flag, "Page based relocation for ARM not supported" + # address to fist byte to patch relative to the start of the symbol - patch = patch_entry(rtype, patch_offset, - reloc.fields['r_addend'], - reloc.symbol.name, - reloc.symbol.info, - reloc.symbol.fields['st_shndx'], - reloc.symbol.fields['st_value']) - - # Exclude the call to the result_* function - if patch.patch_address < end_index - start_index: - yield patch + yield translate_relocation(reloc, patch_offset) def get_stencil_code(self, name: str) -> bytes: """Return the striped function code for a provided function name diff --git a/src/coparun/runmem.c b/src/coparun/runmem.c index 74173e1..395fb94 100644 --- a/src/coparun/runmem.c +++ b/src/coparun/runmem.c @@ -23,19 +23,14 @@ uint8_t *executable_memory = NULL; uint32_t executable_memory_len = 0; entry_point_t entr_point = NULL; int data_offs = 0; - -void patch_mem_32(uint8_t *patch_addr, int32_t value) { - int32_t *val_ptr = (int32_t*)patch_addr; - *val_ptr = value; -} -int patch(uint8_t *patch_addr, uint32_t reloc_type, int32_t value) { - if (reloc_type == PATCH_RELATIVE_32) { - patch_mem_32(patch_addr, value); - }else{ - LOG("Not implemented"); - return 0; - } +int patch(uint8_t *patch_addr, uint32_t patch_mask, int32_t value) { + uint32_t *val_ptr = (uint32_t*)patch_addr; + uint32_t original = *val_ptr; + + uint32_t new_value = (original & ~patch_mask) | ((uint32_t)value & patch_mask); + + *val_ptr = new_value; return 1; } @@ -58,7 +53,7 @@ int update_data_offs() { int parse_commands(uint8_t *bytes) { int32_t value; uint32_t command; - uint32_t reloc_type; + uint32_t patch_mask; uint32_t offs; uint32_t size; int end_flag = 0; @@ -101,20 +96,20 @@ int parse_commands(uint8_t *bytes) { case PATCH_FUNC: offs = *(uint32_t*)bytes; bytes += 4; - reloc_type = *(uint32_t*)bytes; bytes += 4; + patch_mask = *(uint32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4; - LOG("PATCH_FUNC patch_offs=%i reloc_type=%i value=%i\n", - offs, reloc_type, value); - patch(executable_memory + offs, reloc_type, value); + LOG("PATCH_FUNC patch_offs=%i patch_mask=%#08x value=%i\n", + offs, patch_mask, value); + patch(executable_memory + offs, patch_mask, value); break; case PATCH_OBJECT: offs = *(uint32_t*)bytes; bytes += 4; - reloc_type = *(uint32_t*)bytes; bytes += 4; + patch_mask = *(uint32_t*)bytes; bytes += 4; value = *(int32_t*)bytes; bytes += 4; - LOG("PATCH_OBJECT patch_offs=%i reloc_type=%i value=%i\n", - offs, reloc_type, value); - patch(executable_memory + offs, reloc_type, value + data_offs); + LOG("PATCH_OBJECT patch_offs=%i patch_mask=%#08x value=%i\n", + offs, patch_mask, value); + patch(executable_memory + offs, patch_mask, value + data_offs); break; case PATCH_MATH_POW: diff --git a/src/coparun/runmem.h b/src/coparun/runmem.h index faed240..87ec751 100644 --- a/src/coparun/runmem.h +++ b/src/coparun/runmem.h @@ -17,9 +17,6 @@ #define FREE_MEMORY 257 #define PATCH_MATH_POW 512 -/* Relocation types */ -#define PATCH_RELATIVE_32 0 - /* Memory blobs accessible by other translation units */ extern uint8_t *data_memory; extern uint32_t data_memory_len; diff --git a/tests/test_vector_aarch64.py b/tests/test_vector_aarch64.py new file mode 100644 index 0000000..a249f87 --- /dev/null +++ b/tests/test_vector_aarch64.py @@ -0,0 +1,39 @@ +import copapy as cp +import pytest + + +def test_vectors_init(): + tt1 = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3]) + tt2 = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3)) + tt3 = (cp.vector(range(3)) + 5.6) + tt4 = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3)) + tt5 = cp.vector([1, 2, 3]).dot(tt4) + + print(tt1, tt2, tt3, tt4, tt5) + + +@pytest.mark.skip(reason="no way of currently testing this") +def test_compiled_vectors(): + t1 = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3)) + t2 = t1.sum() + + t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3)) + t4 = ((t3 * t1) * 2).sum() + t5 = ((t3 * t1) * 2).magnitude() + + tg = cp.Target('aarch64') + tg.compile(t2, t4, t5) + tg.run() + + assert isinstance(t2, cp.variable) + assert tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2 + + assert isinstance(t4, cp.variable) + assert tg.read_value(t4) == pytest.approx(((10/1*2) + (12/2*2) + (14/3*2)), 0.001) # pyright: ignore[reportUnknownMemberType] + + assert isinstance(t5, cp.variable) + assert tg.read_value(t5) == pytest.approx(((10/1*2)**2 + (12/2*2)**2 + (14/3*2)**2) ** 0.5, 0.001) # pyright: ignore[reportUnknownMemberType] + + +if __name__ == "__main__": + test_compiled_vectors() diff --git a/tools/extract_code.py b/tools/extract_code.py index eaca3bf..05496ad 100644 --- a/tools/extract_code.py +++ b/tools/extract_code.py @@ -1,5 +1,4 @@ from copapy._binwrite import data_reader, Command, ByteOrder -from copapy._stencils import RelocationType import argparse if __name__ == "__main__": @@ -45,18 +44,18 @@ if __name__ == "__main__": print(f"COPY_CODE offs={offs} size={size} data={' '.join(hex(d) for d in datab[:5])}...") elif com == Command.PATCH_FUNC: offs = dr.read_int() - reloc_type = dr.read_int() + mask = dr.read_int() value = dr.read_int(signed=True) - assert reloc_type == RelocationType.RELOC_RELATIVE_32.value + assert mask == 0xFFFFFFFF program_data[offs:offs + 4] = value.to_bytes(4, byteorder, signed=True) - print(f"PATCH_FUNC patch_offs={offs} reloc_type={reloc_type} value={value}") + print(f"PATCH_FUNC patch_offs={offs} mask=0x{mask:x} value={value}") elif com == Command.PATCH_OBJECT: offs = dr.read_int() - reloc_type = dr.read_int() + mask = dr.read_int() value = dr.read_int(signed=True) - assert reloc_type == RelocationType.RELOC_RELATIVE_32.value + assert mask == 0xFFFFFFFF program_data[offs:offs + 4] = (value + data_section_offset).to_bytes(4, byteorder, signed=True) - print(f"PATCH_OBJECT patch_offs={offs} reloc_type={reloc_type} value={value}") + print(f"PATCH_OBJECT patch_offs={offs} mask=ox{mask:x} value={value}") elif com == Command.ENTRY_POINT: rel_entr_point = dr.read_int() print(f"ENTRY_POINT rel_entr_point={rel_entr_point}")