Compiler rewritten for aarch64 support

This commit is contained in:
Nicolas 2025-10-31 16:56:51 +01:00
parent c2d1fb7eea
commit e8a73c088e
3 changed files with 107 additions and 84 deletions

View File

@ -39,7 +39,7 @@ class data_writer():
data = struct.pack(en + 'f', value)
else:
data = struct.pack(en + 'd', value)
assert len(data) == num_bytes
assert len(data) == num_bytes, (len(data), num_bytes)
self.write_bytes(data)
def print(self) -> None:

View File

@ -1,6 +1,6 @@
from typing import Generator, Iterable, Any
from . import _binwrite as binw
from ._stencils import stencil_database
from ._stencils import stencil_database, patch_entry
from collections import defaultdict, deque
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
@ -171,7 +171,7 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
object_list: list[tuple[Net, int, int]] = []
for variable in variable_list:
lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype))
lengths = sdb.get_type_size(transl_type(variable.dtype))
object_list.append((variable, offset, lengths))
offset += (lengths + 3) // 4 * 4
@ -236,7 +236,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
"""
variables: dict[Net, tuple[int, int, str]] = {}
data_list: list[bytes] = []
patch_list: list[tuple[int, int, int, binw.Command]] = []
patch_list: list[patch_entry] = []
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
const_net_list = get_const_nets(ordered_ops)
@ -297,32 +297,31 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
data_list.append(data)
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for patch in sdb.get_patch_positions(node.name, stencil=True):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
if patch.target_symbol_name.startswith('dummy_'):
for reloc in sdb.get_relocations(node.name, stencil=True):
if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
#print('-- ' + reloc.target_symbol_name + ' // ' + node.name)
if reloc.target_symbol_name.startswith('dummy_'):
# Patch for write and read addresses to/from heap variables
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
addr = object_addr_lookup[associated_net]
patch_value = addr + patch.addend - (offset + patch.patch_address)
elif patch.target_symbol_name.startswith('result_'):
obj_addr = object_addr_lookup[associated_net]
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
elif reloc.target_symbol_name.startswith('result_'):
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
else:
# Patch constants addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index]
obj_addr = section_addr + patch.target_symbol_address
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
elif patch.target_symbol_info == 'STT_FUNC':
addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = addr + patch.addend - (offset + patch.patch_address)
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
elif reloc.target_symbol_info == 'STT_FUNC':
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else:
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
patch_list.append(patch)
offset += len(data)
@ -334,6 +333,8 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_com(binw.Command.ALLOCATE_CODE)
dw.write_int(offset)
print('o aux: ', aux_function_mem_layout)
# write aux functions code
for name, start, lengths in aux_function_mem_layout:
dw.write_com(binw.Command.COPY_CODE)
@ -343,22 +344,21 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
# Patch aux functions
for name, start, lengths in aux_function_mem_layout:
for patch in sdb.get_patch_positions(name):
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
for reloc in sdb.get_relocations(name):
if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
# Patch constants/variable addresses on heap
section_addr = section_addr_lookup[patch.target_symbol_section_index]
obj_addr = section_addr + patch.target_symbol_address
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
elif patch.target_symbol_info == 'STT_FUNC':
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
elif reloc.target_symbol_info == 'STT_FUNC':
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
else:
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
raise ValueError(f"Unsupported: {name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
patch_list.append(patch)
#assert False, aux_function_mem_layout
@ -369,11 +369,11 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
dw.write_bytes(b''.join(data_list))
# write patch operations
for mask, patch_addr, addr, patch_command in patch_list:
dw.write_com(patch_command)
dw.write_int(patch_addr)
dw.write_int(mask)
dw.write_int(addr, signed=True)
for patch in patch_list:
dw.write_int(patch.patch_type)
dw.write_int(patch.address)
dw.write_int(patch.mask)
dw.write_int(patch.value, signed=True)
dw.write_com(binw.Command.ENTRY_POINT)
dw.write_int(aux_function_lengths)

View File

@ -5,6 +5,22 @@ import pelfy
ByteOrder = Literal['little', 'big']
@dataclass
class relocation_entry:
    """
    A dataclass for representing a relocation entry

    Bundles one raw pelfy relocation with the symbol data needed to turn
    it into a patch later on. Fields are filled positionally, so their
    order must not change.
    """

    # Name of the symbol the relocation refers to (reloc.symbol.name)
    target_symbol_name: str
    # Symbol info string of the target, e.g. 'STT_OBJECT', 'STT_NOTYPE'
    # or 'STT_FUNC' (reloc.symbol.info)
    target_symbol_info: str
    # 'st_value' field of the target symbol
    target_symbol_offset: int
    # 'st_shndx' field of the target symbol
    target_section_index: int
    # 'st_value' field of the symbol that owns the relocation
    function_offset: int
    # Start index subtracted from the relocation offset when the patch
    # address is computed (for stencils: first value returned by
    # get_stencil_position)
    start: int
    # The unmodified pelfy relocation object ('type' and 'fields' are read
    # from it when the patch is built)
    pelfy_reloc: pelfy.elf_relocation
@dataclass
class patch_entry:
"""
@ -15,36 +31,9 @@ class patch_entry:
type (RelocationType): relocation type"""
mask: int
patch_address: int
addend: int
target_symbol_name: str
target_symbol_info: str
target_symbol_section_index: int
target_symbol_address: int
def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry:
    """Convert a pelfy relocation into a ``patch_entry``.

    Args:
        reloc: relocation to translate; its ``type`` suffix selects the
            mask and the encoding of the immediate
        offset: patch offset, stored as is for 32 bit relocations and
            shifted right by two bits for 26 bit relocations

    Returns:
        patch_entry: mask, immediate, addend and target symbol data

    Raises:
        NotImplementedError: for any relocation type other than
            ``*_PLT32``, ``*_PC32``, ``*_JUMP26`` or ``*_CALL26``
    """
    kind = reloc.type

    if kind.endswith(('_PLT32', '_PC32')):
        # S + A - P, 32 bit field
        bit_mask = 0xFFFFFFFF
        immediate = offset
    elif kind.endswith(('_JUMP26', '_CALL26')):
        # S + A - P, 26 bit field
        assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
        bit_mask = 0x3ffffff
        immediate = offset >> 2
        assert immediate < bit_mask, "Relocation immediate value too large"
    else:
        raise NotImplementedError(f"Relocation type {reloc.type} not implemented")

    addend = reloc.fields['r_addend']
    target = reloc.symbol
    return patch_entry(
        bit_mask,
        immediate,
        addend,
        target.name,
        target.info,
        target.fields['st_shndx'],
        target.fields['st_value'],
    )
address: int
value: int
patch_type: int
def get_return_function_type(symbol: elf_symbol) -> str:
@ -72,7 +61,6 @@ def get_last_call_in_function(func: elf_symbol) -> int:
# Find last relocation in function
assert func.relocations, f'No call function in stencil function {func.name}.'
reloc = func.relocations[-1]
print(f"reloc.fields['r_addend'] {reloc.fields['r_addend']}")
instruction_lenghs = 4 if reloc.bits < 32 else 5
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs
@ -140,16 +128,7 @@ class stencil_database():
ret.add(sym.section.index)
return list(ret)
def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]:
"""Return patch positions for a provided symbol (function or object)
Args:
symbol_name: function or object name
Yields:
patch_entry: every relocation for the symbol
"""
arm_hi_byte_flag: bool = False
def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]:
symbol = self.elf.symbols[symbol_name]
if stencil:
start_index, end_index = get_stencil_position(symbol)
@ -159,18 +138,58 @@ class stencil_database():
print('->', symbol_name)
for reloc in symbol.relocations:
print(' ', symbol_name, arm_hi_byte_flag, reloc.symbol.info)
print(' ', symbol_name, reloc.symbol.info, reloc.symbol.name, reloc.type)
# address to first byte to patch relative to the start of the symbol
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
if patch_offset < end_index - start_index: # Exclude the call to the result_* function
if reloc.symbol.info == 'STT_SECTION':
arm_hi_byte_flag = True
else:
assert not arm_hi_byte_flag, "Page based relocation for ARM not supported"
# address to first byte to patch relative to the start of the symbol
yield relocation_entry(reloc.symbol.name,
reloc.symbol.info,
reloc.symbol.fields['st_value'],
reloc.symbol.fields['st_shndx'],
symbol.fields['st_value'],
start_index,
reloc)
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
    """Build the patch entry for a single relocation

    Args:
        relocation: relocation entry to translate
        symbol_address: absolute address of the target symbol
        function_offset: absolute address of the first byte of the
            function the patch is applied to
        symbol_type: value passed through unchanged as the last field
            of the returned patch entry

    Returns:
        patch_entry: mask, absolute patch address, patch value and
            symbol_type for the relocation

    Raises:
        NotImplementedError: if the relocation type is neither
            ``*_PLT32`` nor ``*_PC32``
    """
    pr = relocation.pelfy_reloc

    # calculate absolute address of the first byte to patch:
    # position of the relocation relative to the start of the (stripped
    # stencil) function, shifted by the function's absolute address
    patch_offset = pr.fields['r_offset'] - relocation.function_offset - relocation.start + function_offset

    if pr.type.endswith(('_PLT32', '_PC32')):
        # S + A - P
        mask = 0xFFFFFFFF  # 32 bit
        patch_value = symbol_address + pr.fields['r_addend'] - patch_offset
        # NOTE(review): debug traces. The outer quotes are double quotes so
        # that the nested ['...'] subscripts also parse on Python < 3.12;
        # the printed text is unchanged.
        print(f"** {patch_offset=} {relocation.target_symbol_name=} {pr.fields['r_offset']=} {relocation.function_offset=} {relocation.start=} {function_offset=}")
        print(f" {patch_value=} {symbol_address=} {pr.fields['r_addend']=}, {function_offset=}")

    # TODO: 26 bit branch relocations are not ported to this method yet;
    # previous implementation kept for reference:
    #elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'):
    #    # S + A - P
    #    assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
    #    mask = 0x3ffffff  # 26 bit
    #    imm = offset >> 2
    #    assert imm < mask, "Relocation immediate value too large"

    else:
        raise NotImplementedError(f"Relocation type {pr.type} not implemented")

    return patch_entry(mask, patch_offset, patch_value, symbol_type)
yield translate_relocation(reloc, patch_offset)
def get_stencil_code(self, name: str) -> bytes:
"""Return the striped function code for a provided function name
@ -202,6 +221,10 @@ class stencil_database():
name_set |= self.get_sub_functions([r.symbol.name])
return name_set
def get_type_size(self, type_name: str) -> int:
    """Look up the byte size of a variable type.

    Args:
        type_name: type name, either 'int' or 'float'

    Returns:
        Size of the type in bytes

    Raises:
        KeyError: if type_name is not a known type
    """
    byte_sizes = {'int': 4, 'float': 4}
    return byte_sizes[type_name]
def get_symbol_size(self, name: str) -> int:
    """Look up a symbol by name and return its 'st_size' field.

    Args:
        name: symbol name

    Returns:
        Value of the symbol's 'st_size' field
    """
    symbol = self.elf.symbols[name]
    return symbol.fields['st_size']