mirror of https://github.com/Nonannet/copapy.git
Compiler rewritten for aarch64 support
This commit is contained in:
parent
c2d1fb7eea
commit
e8a73c088e
|
|
@ -39,7 +39,7 @@ class data_writer():
|
||||||
data = struct.pack(en + 'f', value)
|
data = struct.pack(en + 'f', value)
|
||||||
else:
|
else:
|
||||||
data = struct.pack(en + 'd', value)
|
data = struct.pack(en + 'd', value)
|
||||||
assert len(data) == num_bytes
|
assert len(data) == num_bytes, (len(data), num_bytes)
|
||||||
self.write_bytes(data)
|
self.write_bytes(data)
|
||||||
|
|
||||||
def print(self) -> None:
|
def print(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Generator, Iterable, Any
|
from typing import Generator, Iterable, Any
|
||||||
from . import _binwrite as binw
|
from . import _binwrite as binw
|
||||||
from ._stencils import stencil_database
|
from ._stencils import stencil_database, patch_entry
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
|
from ._basic_types import Net, Node, Write, CPConstant, Op, transl_type
|
||||||
|
|
||||||
|
|
@ -171,7 +171,7 @@ def get_data_layout(variable_list: Iterable[Net], sdb: stencil_database, offset:
|
||||||
object_list: list[tuple[Net, int, int]] = []
|
object_list: list[tuple[Net, int, int]] = []
|
||||||
|
|
||||||
for variable in variable_list:
|
for variable in variable_list:
|
||||||
lengths = sdb.get_symbol_size('dummy_' + transl_type(variable.dtype))
|
lengths = sdb.get_type_size(transl_type(variable.dtype))
|
||||||
object_list.append((variable, offset, lengths))
|
object_list.append((variable, offset, lengths))
|
||||||
offset += (lengths + 3) // 4 * 4
|
offset += (lengths + 3) // 4 * 4
|
||||||
|
|
||||||
|
|
@ -236,7 +236,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
"""
|
"""
|
||||||
variables: dict[Net, tuple[int, int, str]] = {}
|
variables: dict[Net, tuple[int, int, str]] = {}
|
||||||
data_list: list[bytes] = []
|
data_list: list[bytes] = []
|
||||||
patch_list: list[tuple[int, int, int, binw.Command]] = []
|
patch_list: list[patch_entry] = []
|
||||||
|
|
||||||
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
|
ordered_ops = list(stable_toposort(get_all_dag_edges(node_list)))
|
||||||
const_net_list = get_const_nets(ordered_ops)
|
const_net_list = get_const_nets(ordered_ops)
|
||||||
|
|
@ -297,32 +297,31 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
|
||||||
|
|
||||||
for patch in sdb.get_patch_positions(node.name, stencil=True):
|
for reloc in sdb.get_relocations(node.name, stencil=True):
|
||||||
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
||||||
if patch.target_symbol_name.startswith('dummy_'):
|
#print('-- ' + reloc.target_symbol_name + ' // ' + node.name)
|
||||||
|
if reloc.target_symbol_name.startswith('dummy_'):
|
||||||
# Patch for write and read addresses to/from heap variables
|
# Patch for write and read addresses to/from heap variables
|
||||||
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
|
||||||
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
||||||
addr = object_addr_lookup[associated_net]
|
obj_addr = object_addr_lookup[associated_net]
|
||||||
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
|
||||||
elif patch.target_symbol_name.startswith('result_'):
|
elif reloc.target_symbol_name.startswith('result_'):
|
||||||
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
|
raise Exception(f"Stencil {node.name} seams to branch to multiple result_* calls.")
|
||||||
else:
|
else:
|
||||||
# Patch constants addresses on heap
|
# Patch constants addresses on heap
|
||||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
|
||||||
obj_addr = section_addr + patch.target_symbol_address
|
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
|
||||||
patch_value = obj_addr + patch.addend - (offset + patch.patch_address)
|
|
||||||
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
|
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
|
||||||
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
|
||||||
#print(patch.type, patch.addr, binw.Command.PATCH_OBJECT, node.name)
|
|
||||||
|
|
||||||
elif patch.target_symbol_info == 'STT_FUNC':
|
elif reloc.target_symbol_info == 'STT_FUNC':
|
||||||
addr = aux_func_addr_lookup[patch.target_symbol_name]
|
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
|
||||||
patch_value = addr + patch.addend - (offset + patch.patch_address)
|
patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
|
||||||
patch_list.append((patch.mask, offset + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
|
||||||
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
|
||||||
|
|
||||||
|
patch_list.append(patch)
|
||||||
|
|
||||||
offset += len(data)
|
offset += len(data)
|
||||||
|
|
||||||
|
|
@ -334,6 +333,8 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
dw.write_com(binw.Command.ALLOCATE_CODE)
|
dw.write_com(binw.Command.ALLOCATE_CODE)
|
||||||
dw.write_int(offset)
|
dw.write_int(offset)
|
||||||
|
|
||||||
|
print('o aux: ', aux_function_mem_layout)
|
||||||
|
|
||||||
# write aux functions code
|
# write aux functions code
|
||||||
for name, start, lengths in aux_function_mem_layout:
|
for name, start, lengths in aux_function_mem_layout:
|
||||||
dw.write_com(binw.Command.COPY_CODE)
|
dw.write_com(binw.Command.COPY_CODE)
|
||||||
|
|
@ -343,22 +344,21 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
|
|
||||||
# Patch aux functions
|
# Patch aux functions
|
||||||
for name, start, lengths in aux_function_mem_layout:
|
for name, start, lengths in aux_function_mem_layout:
|
||||||
for patch in sdb.get_patch_positions(name):
|
for reloc in sdb.get_relocations(name):
|
||||||
if patch.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE'}:
|
||||||
# Patch constants/variable addresses on heap
|
# Patch constants/variable addresses on heap
|
||||||
section_addr = section_addr_lookup[patch.target_symbol_section_index]
|
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
|
||||||
obj_addr = section_addr + patch.target_symbol_address
|
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
|
||||||
patch_value = obj_addr + patch.addend - (start + patch.patch_address)
|
|
||||||
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_OBJECT))
|
|
||||||
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
|
#print('* constants aux', patch.type, patch.patch_address, obj_addr, binw.Command.PATCH_OBJECT, name)
|
||||||
|
|
||||||
elif patch.target_symbol_info == 'STT_FUNC':
|
elif reloc.target_symbol_info == 'STT_FUNC':
|
||||||
aux_func_addr = aux_func_addr_lookup[patch.target_symbol_name]
|
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
|
||||||
patch_value = aux_func_addr + patch.addend - (start + patch.patch_address)
|
patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
|
||||||
patch_list.append((patch.mask, start + patch.patch_address, patch_value, binw.Command.PATCH_FUNC))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported: {name} {patch.target_symbol_info} {patch.target_symbol_name}")
|
raise ValueError(f"Unsupported: {name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
|
||||||
|
|
||||||
|
patch_list.append(patch)
|
||||||
|
|
||||||
#assert False, aux_function_mem_layout
|
#assert False, aux_function_mem_layout
|
||||||
|
|
||||||
|
|
@ -369,11 +369,11 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
dw.write_bytes(b''.join(data_list))
|
dw.write_bytes(b''.join(data_list))
|
||||||
|
|
||||||
# write patch operations
|
# write patch operations
|
||||||
for mask, patch_addr, addr, patch_command in patch_list:
|
for patch in patch_list:
|
||||||
dw.write_com(patch_command)
|
dw.write_int(patch.patch_type)
|
||||||
dw.write_int(patch_addr)
|
dw.write_int(patch.address)
|
||||||
dw.write_int(mask)
|
dw.write_int(patch.mask)
|
||||||
dw.write_int(addr, signed=True)
|
dw.write_int(patch.value, signed=True)
|
||||||
|
|
||||||
dw.write_com(binw.Command.ENTRY_POINT)
|
dw.write_com(binw.Command.ENTRY_POINT)
|
||||||
dw.write_int(aux_function_lengths)
|
dw.write_int(aux_function_lengths)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,22 @@ import pelfy
|
||||||
|
|
||||||
ByteOrder = Literal['little', 'big']
|
ByteOrder = Literal['little', 'big']
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class relocation_entry:
|
||||||
|
"""
|
||||||
|
A dataclass for representing a relocation entry
|
||||||
|
"""
|
||||||
|
|
||||||
|
target_symbol_name: str
|
||||||
|
target_symbol_info: str
|
||||||
|
target_symbol_offset: int
|
||||||
|
target_section_index: int
|
||||||
|
function_offset: int
|
||||||
|
start: int
|
||||||
|
pelfy_reloc: pelfy.elf_relocation
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class patch_entry:
|
class patch_entry:
|
||||||
"""
|
"""
|
||||||
|
|
@ -15,36 +31,9 @@ class patch_entry:
|
||||||
type (RelocationType): relocation type"""
|
type (RelocationType): relocation type"""
|
||||||
|
|
||||||
mask: int
|
mask: int
|
||||||
patch_address: int
|
address: int
|
||||||
addend: int
|
value: int
|
||||||
target_symbol_name: str
|
patch_type: int
|
||||||
target_symbol_info: str
|
|
||||||
target_symbol_section_index: int
|
|
||||||
target_symbol_address: int
|
|
||||||
|
|
||||||
|
|
||||||
def translate_relocation(reloc: pelfy.elf_relocation, offset: int) -> patch_entry:
|
|
||||||
if reloc.type.endswith('_PLT32') or reloc.type.endswith('_PC32'):
|
|
||||||
# S + A - P
|
|
||||||
mask = 0xFFFFFFFF # 32 bit
|
|
||||||
imm = offset
|
|
||||||
|
|
||||||
elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'):
|
|
||||||
# S + A - P
|
|
||||||
assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
|
|
||||||
mask = 0x3ffffff # 26 bit
|
|
||||||
imm = offset >> 2
|
|
||||||
assert imm < mask, "Relocation immediate value too large"
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Relocation type {reloc.type} not implemented")
|
|
||||||
|
|
||||||
return patch_entry(mask, imm,
|
|
||||||
reloc.fields['r_addend'],
|
|
||||||
reloc.symbol.name,
|
|
||||||
reloc.symbol.info,
|
|
||||||
reloc.symbol.fields['st_shndx'],
|
|
||||||
reloc.symbol.fields['st_value'])
|
|
||||||
|
|
||||||
|
|
||||||
def get_return_function_type(symbol: elf_symbol) -> str:
|
def get_return_function_type(symbol: elf_symbol) -> str:
|
||||||
|
|
@ -72,7 +61,6 @@ def get_last_call_in_function(func: elf_symbol) -> int:
|
||||||
# Find last relocation in function
|
# Find last relocation in function
|
||||||
assert func.relocations, f'No call function in stencil function {func.name}.'
|
assert func.relocations, f'No call function in stencil function {func.name}.'
|
||||||
reloc = func.relocations[-1]
|
reloc = func.relocations[-1]
|
||||||
print(f"reloc.fields['r_addend'] {reloc.fields['r_addend']}")
|
|
||||||
|
|
||||||
instruction_lenghs = 4 if reloc.bits < 32 else 5
|
instruction_lenghs = 4 if reloc.bits < 32 else 5
|
||||||
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs
|
return reloc.fields['r_offset'] - func.fields['st_value'] - reloc.fields['r_addend'] - instruction_lenghs
|
||||||
|
|
@ -140,16 +128,7 @@ class stencil_database():
|
||||||
ret.add(sym.section.index)
|
ret.add(sym.section.index)
|
||||||
return list(ret)
|
return list(ret)
|
||||||
|
|
||||||
def get_patch_positions(self, symbol_name: str, stencil: bool = False) -> Generator[patch_entry, None, None]:
|
def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]:
|
||||||
"""Return patch positions for a provided symbol (function or object)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol_name: function or object name
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
patch_entry: every relocation for the symbol
|
|
||||||
"""
|
|
||||||
arm_hi_byte_flag: bool = False
|
|
||||||
symbol = self.elf.symbols[symbol_name]
|
symbol = self.elf.symbols[symbol_name]
|
||||||
if stencil:
|
if stencil:
|
||||||
start_index, end_index = get_stencil_position(symbol)
|
start_index, end_index = get_stencil_position(symbol)
|
||||||
|
|
@ -159,18 +138,58 @@ class stencil_database():
|
||||||
|
|
||||||
print('->', symbol_name)
|
print('->', symbol_name)
|
||||||
for reloc in symbol.relocations:
|
for reloc in symbol.relocations:
|
||||||
print(' ', symbol_name, arm_hi_byte_flag, reloc.symbol.info)
|
print(' ', symbol_name, reloc.symbol.info, reloc.symbol.name, reloc.type)
|
||||||
|
|
||||||
|
# address to fist byte to patch relative to the start of the symbol
|
||||||
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
|
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
|
||||||
|
|
||||||
if patch_offset < end_index - start_index: # Exclude the call to the result_* function
|
if patch_offset < end_index - start_index: # Exclude the call to the result_* function
|
||||||
if reloc.symbol.info == 'STT_SECTION':
|
yield relocation_entry(reloc.symbol.name,
|
||||||
arm_hi_byte_flag = True
|
reloc.symbol.info,
|
||||||
else:
|
reloc.symbol.fields['st_value'],
|
||||||
assert not arm_hi_byte_flag, "Page based relocation for ARM not supported"
|
reloc.symbol.fields['st_shndx'],
|
||||||
# address to fist byte to patch relative to the start of the symbol
|
symbol.fields['st_value'],
|
||||||
|
start_index,
|
||||||
|
reloc)
|
||||||
|
|
||||||
|
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
|
||||||
|
"""Return patch positions for a provided symbol (function or object)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relocation: relocation entry
|
||||||
|
symbol_address: absolute address of the target symbol
|
||||||
|
function_offset: absolute address of the first byte of the
|
||||||
|
function the patch is applied to
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
patch_entry: every relocation for the symbol
|
||||||
|
"""
|
||||||
|
pr = relocation.pelfy_reloc
|
||||||
|
|
||||||
|
# calculate absolut address to the first byte to patch
|
||||||
|
# relative to the start of the (stripped stencil) function:
|
||||||
|
patch_offset = pr.fields['r_offset'] - relocation.function_offset - relocation.start + function_offset
|
||||||
|
|
||||||
|
if pr.type.endswith('_PLT32') or pr.type.endswith('_PC32'):
|
||||||
|
# S + A - P
|
||||||
|
mask = 0xFFFFFFFF # 32 bit
|
||||||
|
patch_value = symbol_address + pr.fields['r_addend'] - patch_offset
|
||||||
|
|
||||||
|
print(f'** {patch_offset=} {relocation.target_symbol_name=} {pr.fields['r_offset']=} {relocation.function_offset=} {relocation.start=} {function_offset=}')
|
||||||
|
print(f' {patch_value=} {symbol_address=} {pr.fields['r_addend']=}, {function_offset=}')
|
||||||
|
|
||||||
|
#elif reloc.type.endswith('_JUMP26') or reloc.type.endswith('_CALL26'):
|
||||||
|
# # S + A - P
|
||||||
|
# assert reloc.file.byteorder == 'little', "Big endian not supported for ARM64"
|
||||||
|
# mask = 0x3ffffff # 26 bit
|
||||||
|
# imm = offset >> 2
|
||||||
|
# assert imm < mask, "Relocation immediate value too large"
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Relocation type {pr.type} not implemented")
|
||||||
|
|
||||||
|
return patch_entry(mask, patch_offset, patch_value, symbol_type)
|
||||||
|
|
||||||
yield translate_relocation(reloc, patch_offset)
|
|
||||||
|
|
||||||
def get_stencil_code(self, name: str) -> bytes:
|
def get_stencil_code(self, name: str) -> bytes:
|
||||||
"""Return the striped function code for a provided function name
|
"""Return the striped function code for a provided function name
|
||||||
|
|
@ -202,6 +221,10 @@ class stencil_database():
|
||||||
name_set |= self.get_sub_functions([r.symbol.name])
|
name_set |= self.get_sub_functions([r.symbol.name])
|
||||||
return name_set
|
return name_set
|
||||||
|
|
||||||
|
def get_type_size(self, type_name: str) -> int:
|
||||||
|
"""Returns the size of a variable type in bytes."""
|
||||||
|
return {'int': 4, 'float': 4}[type_name]
|
||||||
|
|
||||||
def get_symbol_size(self, name: str) -> int:
|
def get_symbol_size(self, name: str) -> int:
|
||||||
"""Returns the size of a specified symbol name."""
|
"""Returns the size of a specified symbol name."""
|
||||||
return self.elf.symbols[name].fields['st_size']
|
return self.elf.symbols[name].fields['st_size']
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue