Merge pull request #8 from Nonannet/dev

Compilation speed drastically improved
This commit is contained in:
Nicolas Kruse 2025-11-14 23:23:21 +01:00 committed by GitHub
commit 2e3ececed2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 265 additions and 39 deletions

View File

@ -1,6 +1,6 @@
[project]
name = "copapy"
version = "0.0.0"
version = "0.0.1"
authors = [
{ name="Nicolas Kruse", email="nicolas.kruse@nonan.net" },
]

View File

@ -333,7 +333,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool =
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
if commutative:
arg_nets = sorted(arg_nets, key=lambda a: a.dtype)
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])

View File

@ -99,7 +99,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
for node in node_list:
if not isinstance(node, CPConstant):
for i, net in enumerate(node.args):
if id(net) != id(registers[i]):
if id(net) != id(registers[i]): # TODO: consider register swap and commutative ops
#if net in registers:
# print('x swap registers')
type_list = ['int' if r is None else transl_type(r.dtype) for r in registers]
@ -108,8 +108,11 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No
registers[i] = net
if node in net_lookup:
yield net_lookup[node], node
registers[0] = net_lookup[node]
result_net = net_lookup[node]
yield result_net, node
registers[0] = result_net
if len(node.args) < 2: # Reset virtual register for single argument functions
registers[1] = None
else:
yield None, node
@ -267,9 +270,9 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
# Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
stencil_names = [node.name for _, node in extended_output_ops]
stencil_names = {node.name for _, node in extended_output_ops}
aux_function_names = sdb.get_sub_functions(stencil_names)
used_sections = sdb.const_sections_from_functions(aux_function_names | set(stencil_names))
used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names)
# Write data
section_mem_layout, sections_length = get_section_layout(used_sections, sdb)

View File

@ -20,7 +20,7 @@ def exp(x: NumLike) -> variable[float] | float:
result of e**x
"""
if isinstance(x, variable):
return add_op('exp', [x, x]) # TODO: fix 2. dummy argument
return add_op('exp', [x])
return float(math.exp(x))
@ -38,7 +38,7 @@ def log(x: NumLike) -> variable[float] | float:
result of ln(x)
"""
if isinstance(x, variable):
return add_op('log', [x, x]) # TODO: fix 2. dummy argument
return add_op('log', [x])
return float(math.log(x))
@ -86,7 +86,7 @@ def sqrt(x: NumLike) -> variable[float] | float:
Square root of x
"""
if isinstance(x, variable):
return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument
return add_op('sqrt', [x])
return float(math.sqrt(x))
@ -104,7 +104,7 @@ def sin(x: NumLike) -> variable[float] | float:
Square root of x
"""
if isinstance(x, variable):
return add_op('sin', [x, x]) # TODO: fix 2. dummy argument
return add_op('sin', [x])
return math.sin(x)
@ -122,7 +122,7 @@ def cos(x: NumLike) -> variable[float] | float:
Cosine of x
"""
if isinstance(x, variable):
return add_op('cos', [x, x]) # TODO: fix 2. dummy argument
return add_op('cos', [x])
return math.cos(x)
@ -140,7 +140,7 @@ def tan(x: NumLike) -> variable[float] | float:
Tangent of x
"""
if isinstance(x, variable):
return add_op('tan', [x, x]) # TODO: fix 2. dummy argument
return add_op('tan', [x])
return math.tan(x)
@ -158,7 +158,7 @@ def atan(x: NumLike) -> variable[float] | float:
Inverse tangent of x
"""
if isinstance(x, variable):
return add_op('atan', [x, x]) # TODO: fix 2. dummy argument
return add_op('atan', [x])
return math.atan(x)
@ -177,7 +177,7 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float:
Result in radian
"""
if isinstance(x, variable) or isinstance(y, variable):
return add_op('atan2', [x, y]) # TODO: fix 2. dummy argument
return add_op('atan2', [x, y])
return math.atan2(x, y)
@ -195,7 +195,7 @@ def asin(x: NumLike) -> variable[float] | float:
Inverse sine of x
"""
if isinstance(x, variable):
return add_op('asin', [x, x]) # TODO: fix 2. dummy argument
return add_op('asin', [x])
return math.asin(x)

View File

@ -13,7 +13,6 @@ class relocation_entry:
"""
A dataclass for representing a relocation entry
"""
target_symbol_name: str
target_symbol_info: str
target_symbol_offset: int
@ -30,8 +29,8 @@ class patch_entry:
Attributes:
addr (int): address of first byte to patch relative to the start of the symbol
type (RelocationType): relocation type"""
type (RelocationType): relocation type
"""
mask: int
address: int
value: int
@ -70,13 +69,6 @@ def get_return_function_type(symbol: elf_symbol) -> str:
return 'void'
def strip_function(func: elf_symbol) -> bytes:
"""Return stencil code by striped stancil function"""
assert func.relocations and any(reloc.symbol.name.startswith('result_') for reloc in func.relocations), f"{func.name} is not a stencil function"
start_index, end_index = get_stencil_position(func)
return func.data[start_index:end_index]
def get_stencil_position(func: elf_symbol) -> tuple[int, int]:
start_index = 0 # There must be no prolog
# Find last relocation in function
@ -109,11 +101,6 @@ def get_op_after_last_call_in_function(func: elf_symbol) -> int:
return reloc.fields['r_offset'] - func.fields['st_value'] + 4
def symbol_is_stencil(sym: elf_symbol) -> bool:
return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and
sym.relocations[-1].symbol.info == 'STT_NOTYPE')
class stencil_database():
"""A class for loading and querying a stencil database from an ELF object file
@ -153,6 +140,9 @@ class stencil_database():
# sym.relocations
# self.elf.symbols[name].data
self._relocation_cache: dict[tuple[str, bool], list[relocation_entry]] = {}
self._stencil_cache: dict[str, tuple[int, int]] = {}
def const_sections_from_functions(self, symbol_names: Iterable[str]) -> list[int]:
ret: set[int] = set()
@ -165,6 +155,17 @@ class stencil_database():
return list(ret)
def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]:
cache_key = (symbol_name, stencil)
if cache_key in self._relocation_cache:
# cache hit:
for reloc_entry in self._relocation_cache[cache_key]:
yield reloc_entry
return
# cache miss:
cache: list[relocation_entry] = []
self._relocation_cache[cache_key] = cache
symbol = self.elf.symbols[symbol_name]
if stencil:
start_index, end_index = get_stencil_position(symbol)
@ -178,13 +179,15 @@ class stencil_database():
patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index
if patch_offset < end_index - start_index: # Exclude the call to the result_* function
yield relocation_entry(reloc.symbol.name,
reloc_entry = relocation_entry(reloc.symbol.name,
reloc.symbol.info,
reloc.symbol.fields['st_value'],
reloc.symbol.fields['st_shndx'],
symbol.fields['st_value'],
start_index,
reloc)
cache.append(reloc_entry)
yield reloc_entry
def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry:
"""Return patch positions for a provided symbol (function or object)
@ -275,7 +278,6 @@ class stencil_database():
return patch_entry(mask, patch_offset, patch_value, scale, symbol_type)
def get_stencil_code(self, name: str) -> bytes:
"""Return the striped function code for a provided function name
@ -285,7 +287,17 @@ class stencil_database():
Returns:
Striped function code
"""
return strip_function(self.elf.symbols[name])
if name in self._stencil_cache:
start_index, lengths = self._stencil_cache[name]
else:
func = self.elf.symbols[name]
start_stencil, end_stencil = get_stencil_position(func)
assert func.section
start_index = func.section['sh_offset'] + func['st_value'] + start_stencil
lengths = end_stencil - start_stencil
self._stencil_cache[name] = (start_index, lengths)
return self.elf.read_bytes(start_index, lengths)
def get_sub_functions(self, names: Iterable[str]) -> set[str]:
"""Return recursively all functions called by stencils or by other functions

View File

@ -102,10 +102,10 @@ def get_func2(func_name: str, type1: str, type2: str) -> str:
@norm_indent
def get_math_func1(func_name: str, type1: str, type2: str) -> str:
def get_math_func1(func_name: str, type1: str) -> str:
return f"""
STENCIL void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}({func_name}f((float)arg1), arg2);
STENCIL void {func_name}_{type1}({type1} arg1) {{
result_float({func_name}f((float)arg1));
}}
"""
@ -242,7 +242,7 @@ if __name__ == "__main__":
fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan']
for fn, t1 in permutate(fnames, types):
code += get_math_func1(fn, t1, t1)
code += get_math_func1(fn, t1)
fnames = ['atan2', 'pow']
for fn, t1, t2 in permutate(fnames, types, types):

209
tests/test_comp_timing.py Normal file
View File

@ -0,0 +1,209 @@
import time
from copapy import variable
from copapy import backend
from copapy.backend import Write, stencil_db_from_package
import copapy.backend as cpbe
import copapy as cp
import copapy._binwrite as binw
from copapy._compiler import get_nets, get_section_layout, get_data_layout
from copapy._compiler import patch_entry, CPConstant, get_aux_function_mem_layout
def test_timing_compiler():
t1 = cp.vector([10, 11]*128) + cp.vector(cp.variable(v) for v in range(256))
t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(256))
t5 = ((t3 * t1) * 2).magnitude()
out = [Write(t5)]
print(out)
print('-- get_edges:')
t0 = time.time()
edges = list(cpbe.get_all_dag_edges(out))
t1 = time.time()
print(f' found {len(edges)} edges')
#for p in edges:
# print('#', p)
print(f"get_edges time: {t1-t0:.6f}s")
print('-- get_ordered_ops:')
t0 = time.time()
ordered_ops = list(cpbe.stable_toposort(edges))
t1 = time.time()
print(f' found {len(ordered_ops)} ops')
#for p in ordered_ops:
# print('#', p)
print(f"get_ordered_ops time: {t1-t0:.6f}s")
print('-- get_consts:')
t0 = time.time()
const_net_list = cpbe.get_const_nets(ordered_ops)
t1 = time.time()
#for p in const_list:
# print('#', p)
print(f"get_consts time: {t1-t0:.6f}s")
print('-- add_read_ops:')
t0 = time.time()
output_ops = list(cpbe.add_read_ops(ordered_ops))
t1 = time.time()
#for p in output_ops:
# print('#', p)
print(f"add_read_ops time: {t1-t0:.6f}s")
print('-- add_write_ops:')
t0 = time.time()
extended_output_ops = list(cpbe.add_write_ops(output_ops, const_net_list))
t1 = time.time()
#for p in extended_output_ops:
# print('#', p)
print(f"add_write_ops time: {t1-t0:.6f}s")
print('--')
print('-- load_stencil_db:')
t0 = time.time()
sdb = stencil_db_from_package()
dw = binw.data_writer(sdb.byteorder)
t1 = time.time()
print(f"load_stencil_db time: {t1-t0:.6f}s")
# Get all nets/variables associated with heap memory
variable_list = get_nets([[const_net_list]], extended_output_ops)
stencil_names = {node.name for _, node in extended_output_ops}
print(f'-- get_sub_functions: {len(stencil_names)}')
t0 = time.time()
aux_function_names = sdb.get_sub_functions(stencil_names)
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
print('-- const_sections_from_functions:')
t0 = time.time()
used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names)
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
print('-- get_section_layout:')
t0 = time.time()
section_mem_layout, sections_length = get_section_layout(used_sections, sdb)
variable_mem_layout, variables_data_lengths = get_data_layout(variable_list, sdb, sections_length)
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
variables: dict[backend.Net, tuple[int, int, str]] = {}
data_list: list[bytes] = []
patch_list: list[patch_entry] = []
print('-- write_data:')
t0 = time.time()
# Heap constants
for section_id, start, lengths in section_mem_layout:
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(start)
dw.write_int(lengths)
dw.write_bytes(sdb.get_section_data(section_id))
# Heap variables
for net, start, lengths in variable_mem_layout:
variables[net] = (start, lengths, net.dtype)
if isinstance(net.source, CPConstant):
dw.write_com(binw.Command.COPY_DATA)
dw.write_int(start)
dw.write_int(lengths)
dw.write_value(net.source.value, lengths)
#print(f'+ {net.dtype} {net.source.value}')
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
# prep auxiliary_functions
aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb)
aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout}
# Prepare program code and relocations
object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout}
section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout}
# assemble stencils to main program and patch stencils
data = sdb.get_function_code('entry_function_shell', 'start')
data_list.append(data)
offset = aux_function_lengths + len(data)
print('-- relocate stencils:')
t0 = time.time()
for associated_net, node in extended_output_ops:
assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found"
data = sdb.get_stencil_code(node.name)
data_list.append(data)
#print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data))
for reloc in sdb.get_relocations(node.name, stencil=True):
if reloc.target_symbol_info in ('STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'):
#print('-- ' + reloc.target_symbol_name + ' // ' + node.name)
if reloc.target_symbol_name.startswith('dummy_'):
# Patch for write and read addresses to/from heap variables
assert associated_net, f"Relocation found but no net defined for operation {node.name}"
#print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}")
obj_addr = object_addr_lookup[associated_net]
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
elif reloc.target_symbol_name.startswith('result_'):
# Set return jump address to address of following stencil
patch = sdb.get_patch(reloc, offset + len(data), offset, binw.Command.PATCH_FUNC.value)
else:
# Patch constants addresses on heap
assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {node.name} missing: {reloc.pelfy_reloc.symbol.name}"
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value)
#print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name)
elif reloc.target_symbol_info == 'STT_FUNC':
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value)
#print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name)
else:
raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}")
patch_list.append(patch)
offset += len(data)
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
print('-- relocate aux functions:')
t0 = time.time()
# Patch aux functions
for name, start, _ in aux_function_mem_layout:
for reloc in sdb.get_relocations(name):
#assert reloc.target_symbol_info != 'STT_FUNC', "Not tested yet!"
if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'}:
# Patch constants/variable addresses on heap
#print('--> DATA ', name, reloc.pelfy_reloc.symbol.name, reloc.pelfy_reloc.symbol.info, reloc.pelfy_reloc.symbol.section.name)
assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {name} missing: {reloc.pelfy_reloc.symbol.name}"
obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index]
patch = sdb.get_patch(reloc, obj_addr, start, binw.Command.PATCH_OBJECT.value)
elif reloc.target_symbol_info == 'STT_FUNC':
#print('--> FUNC', name, reloc.pelfy_reloc.symbol.name, reloc.pelfy_reloc.symbol.info, reloc.pelfy_reloc.symbol.section.name)
func_addr = aux_func_addr_lookup[reloc.target_symbol_name]
patch = sdb.get_patch(reloc, func_addr, start, binw.Command.PATCH_FUNC.value)
#print(f' FUNC {func_addr=} {start=} {patch.address=}')
else:
raise ValueError(f"Unsupported: {name=} {reloc.target_symbol_info=} {reloc.target_symbol_name=} {reloc.target_section_index}")
patch_list.append(patch)
t1 = time.time()
print(f"time: {t1-t0:.6f}s")
if __name__ == "__main__":
test_timing_compiler()

View File

@ -35,8 +35,10 @@ def main() -> None:
url = asset["browser_download_url"]
name: str = asset["name"]
if name.endswith('.o'):
if name.startswith('stencils_'):
dest = 'src/copapy/obj'
elif name.startswith('musl_'):
dest = 'build/musl'
elif name == 'coparun.exe' and os.name == 'nt':
dest = 'build/runner'
elif name == 'coparun' and os.name == 'posix':