Compiler rewritten for aarch64 support

This commit is contained in:
Nicolas 2025-10-31 16:56:51 +01:00
parent c2d1fb7eea
commit e8a73c088e
3 changed files with 107 additions and 84 deletions

View File

@ -39,7 +39,7 @@ class data_writer():
data = struct.pack(en + 'f', value) data = struct.pack(en + 'f', value)
else: else:
data = struct.pack(en + 'd', value) data = struct.pack(en + 'd', value)
assert len(data) == num_bytes assert len(data) == num_bytes, (len(data), num_bytes)
self.write_bytes(data) self.write_bytes(data)
def print(self) -> None: def print(self) -> None:

View File

@ -1,6 +1,6 @@
from typing import Generator, Iterable, Any from typing import Generator, Iterable, Any
from . import _binwrite as binw from . import _binwrite as binw
from ._stencils import stencil_database from ._stencils import stencil_database, patch_entry
from collections import defaultdict, deque from collections import defaultdict, deque
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
@ -171,7 +171,7 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
object_list: list[tuple[Net, int, int]] = [] object_list: list[tuple[Net, int, int]] = []
for variable in variable_list: for variable in variable_list:
lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype)) lengths = sdb.get_type_size(transl_type(variable.dtype))
object_list.append((variable, offset, lengths)) object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4 offset += (lengths + 3) // 4 * 4
@ -236,7 +236,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
""" """
variables: dict[Net, tuple[int, int, str]] = {} variables: dict[Net, tuple[int, int, str]] = {}
data_list: list[bytes] = [] data_list: list[bytes] = []
patch_list: list[tuple[int, int, int, binw.Command]] = [] patch_list: list[patch_entry] = []
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list))) ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
const_net_list = get_const_nets(ordered_ops) const_net_list = get_const_nets(ordered_ops)
@ -297,32 +297,31 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
data_list.append(data) data_list.append(data)
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) #print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name, stencil=True): for reloc in sdb.get_relocations(node.name, stencil=True):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
if patch.target_symbol_name.startswith('dummy_'): #print('-- ' + reloc.target_symbol_name + ' // ' + node.name)
if reloc.target_symbol_name.startswith('dummy_'):
# Patch for write and read addresses to/from heap variables # Patch for write and read addresses to/from heap variables
assert associated_net, f"Relocation found but no net defined for operation {node.name}" assert associated_net, f"Relocation found but no net defined for operation {node.name}"
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
addr = object_addr_lookup[associated_net] obj_addr = object_addr_lookup[associated_net]
patch_value = addr + patch.addend - (offset + patch.patch_address) patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
elif patch.target_symbol_name.startswith('result_'): elif reloc.target_symbol_name.startswith('result_'):
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.") raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
else: else:
# Patch constants addresses on heap # Patch constants addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index] obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
obj_addr = section_addr + patch.target_symbol_address patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name) #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC': elif reloc.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name] func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.patch_address) patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else: else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
patch_list.append(patch)
offset += len(data) offset += len(data)
@ -334,6 +333,8 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_com(binw.Command.ALLOCATE_CODE) dw.write_com(binw.Command.ALLOCATE_CODE)
dw.write_int(offset) dw.write_int(offset)
print('o aux: ', aux_function_mem_layout)
# write aux functions code # write aux functions code
for name, start, lengths in aux_function_mem_layout: for name, start, lengths in aux_function_mem_layout:
dw.write_com(binw.Command.COPY_CODE) dw.write_com(binw.Command.COPY_CODE)
@ -343,22 +344,21 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
# Patch aux functions # Patch aux functions
for name, start, lengths in aux_function_mem_layout: for name, start, lengths in aux_function_mem_layout:
for patch in sdb.get_patch_positions(name): for reloc in sdb.get_relocations(name):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}: if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap # Patch constants/variable addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index] obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
obj_addr = section_addr + patch.target_symbol_address patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name) #print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
elif patch.target_symbol_info == 'STT_FUNC': elif reloc.target_symbol_info == 'STT_FUNC':
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name] func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address) patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
else: else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}") raise ValueError(f"Unsupported: {name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
patch_list.append(patch)
#assert False, aux_function_mem_layout #assert False, aux_function_mem_layout
@ -369,11 +369,11 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_bytes(b''.join(data_list)) dw.write_bytes(b''.join(data_list))
# write patch operations # write patch operations
for mask, patch_addr, addr, patch_command in patch_list: for patch in patch_list:
dw.write_com(patch_command) dw.write_int(patch.patch_type)
dw.write_int(patch_addr) dw.write_int(patch.address)
dw.write_int(mask) dw.write_int(patch.mask)
dw.write_int(addr, signed=True) dw.write_int(patch.value, signed=True)
dw.write_com(binw.Command.ENTRY_POINT) dw.write_com(binw.Command.ENTRY_POINT)
dw.write_int(aux_function_lengths) dw.write_int(aux_function_lengths)

View File

@ -5,6 +5,22 @@ import pelfy
ByteOrder = Literal['little', 'big'] ByteOrder = Literal['little', 'big']
@dataclass
class relocation_entry:
    """
    A dataclass for representing a relocation entry

    Bundles one relocation found in a symbol with data about the
    relocation target and about the position of the symbol's code.
    """

    # name of the symbol the relocation points to
    target_symbol_name: str
    # info string of the target symbol (e.g. 'STT_OBJECT', 'STT_FUNC')
    target_symbol_info: str
    # 'st_value' field of the target symbol
    target_symbol_offset: int
    # 'st_shndx' field of the target symbol
    target_section_index: int
    # 'st_value' field of the symbol that contains the relocation
    function_offset: int
    # start index used by get_relocations for the containing symbol
    # (the stencil start position when called with stencil=True)
    start: int
    # the underlying pelfy relocation object
    pelfy_reloc: pelfy.elf_relocation
@dataclass @dataclass
class patch_entry: class patch_entry:
""" """
@ -15,36 +31,9 @@ class patch_entry:
type (RelocationType): relocation type""" type (RelocationType): relocation type"""
mask: int mask: int
patch_address: int address: int
addend: int value: int
target_symbol_name: str patch_type: int
target_symbol_info: str
target_symbol_section_index: int
target_symbol_address: int
def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry:
if reloc.type.endswith('_PLT32') or reloc.type.endswith('_PC32'):
# S + A - P
mask = 0xFFFFFFFF # 32 bit
imm = offset
elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'):
# S + A - P
assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
mask = 0x3ffffff # 26 bit
imm = offset >> 2
assert imm < mask, "Relocation immediate value too large"
else:
raise NotImplementedError(f"Relocation type {reloc.type} not implemented")
return patch_entry(mask, imm,
reloc.fields['r_addend'],
reloc.symbol.name,
reloc.symbol.info,
reloc.symbol.fields['st_shndx'],
reloc.symbol.fields['st_value'])
def get_return_function_type(symbol: elf_symbol) -> str: def get_return_function_type(symbol: elf_symbol) -> str:
@ -72,7 +61,6 @@ def get_last_call_in_function(func: elf_symbol) -> int:
# Find last relocation in function # Find last relocation in function
assert func.relocations, f'No call function in stencil function {func.name}.' assert func.relocations, f'No call function in stencil function {func.name}.'
reloc = func.relocations[-1] reloc = func.relocations[-1]
print(f"reloc.fields['r_addend'] {reloc.fields['r_addend']}")
instruction_lenghs = 4 if reloc.bits < 32 else 5 instruction_lenghs = 4 if reloc.bits < 32 else 5
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs
@ -140,16 +128,7 @@ class stencil_database():
ret.add(sym.section.index) ret.add(sym.section.index)
return list(ret) return list(ret)
def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]: def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]:
"""Return patch positions for a provided symbol (function or object)
Args:
symbol_name: function or object name
Yields:
patch_entry: every relocation for the symbol
"""
arm_hi_byte_flag: bool = False
symbol = self.elf.symbols[symbol_name] symbol = self.elf.symbols[symbol_name]
if stencil: if stencil:
start_index, end_index = get_stencil_position(symbol) start_index, end_index = get_stencil_position(symbol)
@ -159,18 +138,58 @@ class stencil_database():
print('->', symbol_name) print('->', symbol_name)
for reloc in symbol.relocations: for reloc in symbol.relocations:
print(' ', symbol_name, arm_hi_byte_flag, reloc.symbol.info) print(' ', symbol_name, reloc.symbol.info, reloc.symbol.name, reloc.type)
# address to fist byte to patch relative to the start of the symbol
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
if patch_offset < end_index - start_index: # Exclude the call to the result_* function if patch_offset < end_index - start_index: # Exclude the call to the result_* function
if reloc.symbol.info == 'STT_SECTION': yield relocation_entry(reloc.symbol.name,
arm_hi_byte_flag = True reloc.symbol.info,
else: reloc.symbol.fields['st_value'],
assert not arm_hi_byte_flag, "Page based relocation for ARM not supported" reloc.symbol.fields['st_shndx'],
# address to fist byte to patch relative to the start of the symbol symbol.fields['st_value'],
start_index,
reloc)
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
    """Build the patch entry that resolves a single relocation.

    Args:
        relocation: relocation entry, as yielded by get_relocations
        symbol_address: absolute address of the target symbol
        function_offset: absolute address of the first byte of the
            function the patch is applied to
        symbol_type: stored unchanged as patch_type of the returned
            entry (the compiler passes a binw.Command value here)

    Returns:
        patch_entry: mask, absolute patch address, patch value and
            patch type for the relocation

    Raises:
        NotImplementedError: if the relocation type is not supported
    """
    pr = relocation.pelfy_reloc

    # Absolute address of the first byte to patch: position of the
    # relocation relative to the start of the (stripped stencil)
    # function, shifted by the address the function is placed at.
    patch_offset = pr.fields['r_offset'] - relocation.function_offset - relocation.start + function_offset

    if pr.type.endswith(('_PLT32', '_PC32')):
        # S + A - P
        mask = 0xFFFFFFFF  # 32 bit
        patch_value = symbol_address + pr.fields['r_addend'] - patch_offset
        # NOTE(review): debug output kept as-is; the outer quotes are
        # double quotes so the f-strings also parse before Python 3.12.
        print(f"** {patch_offset=} {relocation.target_symbol_name=} {pr.fields['r_offset']=} {relocation.function_offset=} {relocation.start=} {function_offset=}")
        print(f" {patch_value=} {symbol_address=} {pr.fields['r_addend']=}, {function_offset=}")
    else:
        # TODO: '_JUMP26' / '_CALL26' (ARM64, S + A - P) are not handled
        # yet. The earlier draft used a 26 bit mask (0x3ffffff), a value
        # shifted right by 2 and supported little endian only.
        raise NotImplementedError(f"Relocation type {pr.type} not implemented")

    return patch_entry(mask, patch_offset, patch_value, symbol_type)
yield translate_relocation(reloc, patch_offset)
def get_stencil_code(self, name: str) -> bytes: def get_stencil_code(self, name: str) -> bytes:
"""Return the striped function code for a provided function name """Return the striped function code for a provided function name
@ -202,6 +221,10 @@ class stencil_database():
name_set |= self.get_sub_functions([r.symbol.name]) name_set |= self.get_sub_functions([r.symbol.name])
return name_set return name_set
def get_type_size(self, type_name: str) -> int:
    """Return the size of a variable type in bytes.

    Args:
        type_name: name of the variable type; 'int' and 'float' are
            supported

    Returns:
        Size of the type in bytes.

    Raises:
        KeyError: if type_name is not a supported type
    """
    type_sizes = {'int': 4, 'float': 4}
    try:
        return type_sizes[type_name]
    except KeyError:
        # Still a KeyError for existing callers, but with a message that
        # names the offending type and the supported ones.
        raise KeyError(f"Unsupported variable type {type_name!r}; expected one of {sorted(type_sizes)}") from None
def get_symbol_size(self, name: str) -> int: def get_symbol_size(self, name: str) -> int:
"""Returns the size of a specified symbol name.""" """Returns the size of a specified symbol name."""
return self.elf.symbols[name].fields['st_size'] return self.elf.symbols[name].fields['st_size']