patch command updated for arm support

This commit is contained in:
Nicolas Kruse 2025-10-29 22:29:15 +01:00
parent 7584b316fc
commit f60df09fa7
6 changed files with 97 additions and 62 deletions

View File

@ -313,13 +313,13 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
obj_addr = section_addr + patch.target_symbol_address
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
#print('* constants stencils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.patch_address)
patch_list.append((patch.type.value, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
@ -349,13 +349,13 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
section_addr = section_addr_lookup[patch.target_symbol_section_index]
obj_addr = section_addr + patch.target_symbol_address
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
elif patch.target_symbol_info == 'STT_FUNC':
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.type.value, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
@ -369,10 +369,10 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_bytes(b''.join(data_list))
# write patch operations
for patch_type, patch_addr, addr, patch_command in patch_list:
for mask, patch_addr, addr, patch_command in patch_list:
dw.write_com(patch_command)
dw.write_int(patch_addr)
dw.write_int(patch_type)
dw.write_int(mask)
dw.write_int(addr, signed=True)
dw.write_com(binw.Command.ENTRY_POINT)

View File

@ -1,16 +1,13 @@
from dataclasses import dataclass
from pelfy import open_elf_file, elf_file, elf_symbol
from typing import Generator, Literal, Iterable
from enum import Enum
import pelfy
ByteOrder = Literal['little', 'big']
# on x86_64: call or jmp instruction when tail call optimized
LENGTH_CALL_INSTRUCTION = 5
RelocationType = Enum('RelocationType', [('RELOC_RELATIVE_32', 0)])
@dataclass
class patch_entry:
"""
@ -20,7 +17,7 @@ class patch_entry:
addr (int): address of first byte to patch relative to the start of the symbol
mask (int): bit mask selecting which bits of the 32 bit word at the patch address are overwritten"""
type: RelocationType
mask: int
patch_address: int
addend: int
target_symbol_name: str
@ -29,12 +26,27 @@ class patch_entry:
target_symbol_address: int
def translate_relocation(relocation_addr: int, reloc_type: str, bits: int, r_addend: int) -> RelocationType:
if reloc_type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry:
if reloc.type in ('R_AMD64_PLT32', 'R_AMD64_PC32'):
# S + A - P
return RelocationType.RELOC_RELATIVE_32
mask = 0xFFFFFFFF # 32 bit
imm = offset
elif reloc.type.endswith('_JUMP26'):
assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
mask = 0x3ffffff # 26 bit
imm = offset >> 2
assert imm < mask, "Relocation immediate value too large"
else:
raise Exception(f"Unknown relocation type: {reloc_type}")
raise NotImplementedError(f"Relocation type {reloc.type} not implemented")
return patch_entry(mask, imm,
reloc.fields['r_addend'],
reloc.symbol.name,
reloc.symbol.info,
reloc.symbol.fields['st_shndx'],
reloc.symbol.fields['st_value'])
def get_return_function_type(symbol: elf_symbol) -> str:
@ -129,6 +141,7 @@ class stencil_database():
Yields:
patch_entry: every relocation for the symbol
"""
arm_hi_byte_flag: bool = False
symbol = self.elf.symbols[symbol_name]
if stencil:
start_index, end_index = get_stencil_position(symbol)
@ -136,27 +149,19 @@ class stencil_database():
start_index = 0
end_index = symbol.fields['st_size']
print('->', symbol_name)
for reloc in symbol.relocations:
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
# address of the first byte to patch relative to the start of the symbol
rtype = translate_relocation(
patch_offset,
reloc.type,
reloc.bits,
reloc.fields['r_addend'])
if patch_offset < end_index - start_index: # Exclude the call to the result_* function
if reloc.symbol.info == 'STT_SECTION':
arm_hi_byte_flag = True
else:
assert not arm_hi_byte_flag, "Page based relocation for ARM not supported"
# address of the first byte to patch relative to the start of the symbol
patch = patch_entry(rtype, patch_offset,
reloc.fields['r_addend'],
reloc.symbol.name,
reloc.symbol.info,
reloc.symbol.fields['st_shndx'],
reloc.symbol.fields['st_value'])
# Exclude the call to the result_* function
if patch.patch_address < end_index - start_index:
yield patch
yield translate_relocation(reloc, patch_offset)
def get_stencil_code(self, name: str) -> bytes:
"""Return the stripped function code for a provided function name

View File

@ -24,18 +24,13 @@ uint32_t executable_memory_len = 0;
entry_point_t entr_point = NULL;
int data_offs = 0;
void patch_mem_32(uint8_t *patch_addr, int32_t value) {
int32_t *val_ptr = (int32_t*)patch_addr;
*val_ptr = value;
}
int patch(uint8_t *patch_addr, uint32_t patch_mask, int32_t value) {
uint32_t *val_ptr = (uint32_t*)patch_addr;
uint32_t original = *val_ptr;
int patch(uint8_t *patch_addr, uint32_t reloc_type, int32_t value) {
if (reloc_type == PATCH_RELATIVE_32) {
patch_mem_32(patch_addr, value);
}else{
LOG("Not implemented");
return 0;
}
uint32_t new_value = (original & ~patch_mask) | ((uint32_t)value & patch_mask);
*val_ptr = new_value;
return 1;
}
@ -58,7 +53,7 @@ int update_data_offs() {
int parse_commands(uint8_t *bytes) {
int32_t value;
uint32_t command;
uint32_t reloc_type;
uint32_t patch_mask;
uint32_t offs;
uint32_t size;
int end_flag = 0;
@ -101,20 +96,20 @@ int parse_commands(uint8_t *bytes) {
case PATCH_FUNC:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
patch_mask = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_FUNC patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
patch(executable_memory + offs, reloc_type, value);
LOG("PATCH_FUNC patch_offs=%i patch_mask=%#08x value=%i\n",
offs, patch_mask, value);
patch(executable_memory + offs, patch_mask, value);
break;
case PATCH_OBJECT:
offs = *(uint32_t*)bytes; bytes += 4;
reloc_type = *(uint32_t*)bytes; bytes += 4;
patch_mask = *(uint32_t*)bytes; bytes += 4;
value = *(int32_t*)bytes; bytes += 4;
LOG("PATCH_OBJECT patch_offs=%i reloc_type=%i value=%i\n",
offs, reloc_type, value);
patch(executable_memory + offs, reloc_type, value + data_offs);
LOG("PATCH_OBJECT patch_offs=%i patch_mask=%#08x value=%i\n",
offs, patch_mask, value);
patch(executable_memory + offs, patch_mask, value + data_offs);
break;
case PATCH_MATH_POW:

View File

@ -17,9 +17,6 @@
#define FREE_MEMORY 257
#define PATCH_MATH_POW 512
/* Relocation types */
#define PATCH_RELATIVE_32 0
/* Memory blobs accessible by other translation units */
extern uint8_t *data_memory;
extern uint32_t data_memory_len;

View File

@ -0,0 +1,39 @@
import copapy as cp
import pytest
def test_vectors_init():
    """Smoke-test vector construction from several kinds of input.

    Vectors are built from ranges, float lists, lists mixing plain numbers
    with ``cp.variable`` instances, and a generator of variables, and then
    combined with ``+`` and ``dot``. Nothing is asserted; the results are
    only printed, so the test fails only if one of the expressions raises.
    """
    range_plus_floats = cp.vector(range(3)) + cp.vector([1.1, 2.2, 3.3])
    mixed_plus_range = cp.vector([1.1, 2, cp.variable(5)]) + cp.vector(range(3))
    range_plus_scalar = cp.vector(range(3)) + 5.6
    floats_plus_variables = cp.vector([1.1, 2, 3]) + cp.vector(cp.variable(v) for v in range(3))
    dot_product = cp.vector([1, 2, 3]).dot(floats_plus_variables)

    results = (
        range_plus_floats,
        mixed_plus_range,
        range_plus_scalar,
        floats_plus_variables,
        dot_product,
    )
    print(*results)
@pytest.mark.skip(reason="no way of currently testing this")
def test_compiled_vectors():
    """Compile vector reductions for an 'aarch64' target and check the values read back.

    Three results are compiled and run: the sum of a vector, the sum of a
    scaled element-wise product, and the magnitude of that same product.
    """
    summed_vec = cp.vector([10, 11, 12]) + cp.vector(cp.variable(v) for v in range(3))
    total = summed_vec.sum()

    reciprocals = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
    # The product is deliberately rebuilt for each reduction, as two separate expressions.
    scaled_sum = ((reciprocals * summed_vec) * 2).sum()
    scaled_magnitude = ((reciprocals * summed_vec) * 2).magnitude()

    target = cp.Target('aarch64')
    target.compile(total, scaled_sum, scaled_magnitude)
    target.run()

    # Reference values computed in plain Python.
    expected_total = 10 + 11 + 12 + 0 + 1 + 2
    expected_terms = ((10/1*2), (12/2*2), (14/3*2))
    expected_scaled_sum = expected_terms[0] + expected_terms[1] + expected_terms[2]
    expected_magnitude = (expected_terms[0]**2 + expected_terms[1]**2 + expected_terms[2]**2) ** 0.5

    assert isinstance(total, cp.variable)
    assert tg_value_equals(target, total, expected_total) if False else target.read_value(total) == expected_total

    assert isinstance(scaled_sum, cp.variable)
    assert target.read_value(scaled_sum) == pytest.approx(expected_scaled_sum, 0.001)  # pyright: ignore[reportUnknownMemberType]

    assert isinstance(scaled_magnitude, cp.variable)
    assert target.read_value(scaled_magnitude) == pytest.approx(expected_magnitude, 0.001)  # pyright: ignore[reportUnknownMemberType]
# Script entry point: calls the compiled-vector test directly as a plain
# function. NOTE(review): the skip marker on that test is presumably only
# honoured by the pytest runner, so this call would still execute it — confirm.
if __name__ == "__main__":
    test_compiled_vectors()

View File

@ -1,5 +1,4 @@
from copapy._binwrite import data_reader, Command, ByteOrder
from copapy._stencils import RelocationType
import argparse
if __name__ == "__main__":
@ -45,18 +44,18 @@ if __name__ == "__main__":
print(f"COPY_CODE offs={offs} size={size} data={' '.join(hex(d) for d in datab[:5])}...")
elif com == Command.PATCH_FUNC:
offs = dr.read_int()
reloc_type = dr.read_int()
mask = dr.read_int()
value = dr.read_int(signed=True)
assert reloc_type == RelocationType.RELOC_RELATIVE_32.value
assert mask == 0xFFFFFFFF
program_data[offs:offs + 4] = value.to_bytes(4, byteorder, signed=True)
print(f"PATCH_FUNC patch_offs={offs} reloc_type={reloc_type} value={value}")
print(f"PATCH_FUNC patch_offs={offs} mask=0x{mask:x} value={value}")
elif com == Command.PATCH_OBJECT:
offs = dr.read_int()
reloc_type = dr.read_int()
mask = dr.read_int()
value = dr.read_int(signed=True)
assert reloc_type == RelocationType.RELOC_RELATIVE_32.value
assert mask == 0xFFFFFFFF
program_data[offs:offs + 4] = (value + data_section_offset).to_bytes(4, byteorder, signed=True)
print(f"PATCH_OBJECT patch_offs={offs} reloc_type={reloc_type} value={value}")
print(f"PATCH_OBJECT patch_offs={offs} mask=0x{mask:x} value={value}")
elif com == Command.ENTRY_POINT:
rel_entr_point = dr.read_int()
print(f"ENTRY_POINT rel_entr_point={rel_entr_point}")