diff --git a/pyproject.toml b/pyproject.toml index aba2e87..e2332e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "copapy" -version = "0.0.0" +version = "0.0.1" authors = [ { name="Nicolas Kruse", email="nicolas.kruse@nonan.net" }, ] diff --git a/src/copapy/_basic_types.py b/src/copapy/_basic_types.py index 044203e..c77819e 100644 --- a/src/copapy/_basic_types.py +++ b/src/copapy/_basic_types.py @@ -333,7 +333,7 @@ def add_op(op: str, args: list[variable[Any] | int | float], commutative: bool = arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args] if commutative: - arg_nets = sorted(arg_nets, key=lambda a: a.dtype) + arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets]) diff --git a/src/copapy/_compiler.py b/src/copapy/_compiler.py index 7143a19..64585e1 100644 --- a/src/copapy/_compiler.py +++ b/src/copapy/_compiler.py @@ -99,7 +99,7 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No for node in node_list: if not isinstance(node, CPConstant): for i, net in enumerate(node.args): - if id(net) != id(registers[i]): + if id(net) != id(registers[i]): # TODO: consider register swap and commutative ops #if net in registers: # print('x swap registers') type_list = ['int' if r is None else transl_type(r.dtype) for r in registers] @@ -108,8 +108,11 @@ def add_read_ops(node_list: list[Node]) -> Generator[tuple[Net | None, Node], No registers[i] = net if node in net_lookup: - yield net_lookup[node], node - registers[0] = net_lookup[node] + result_net = net_lookup[node] + yield result_net, node + registers[0] = result_net + if len(node.args) < 2: # Reset virtual register for single argument functions + registers[1] = None else: yield None, node @@ -267,9 +270,9 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi # Get all nets/variables associated with heap memory variable_list = get_nets([[const_net_list]], extended_output_ops) - stencil_names = [node.name for _, node in extended_output_ops] + stencil_names = {node.name for _, node in extended_output_ops} aux_function_names = sdb.get_sub_functions(stencil_names) - used_sections = sdb.const_sections_from_functions(aux_function_names | set(stencil_names)) + used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names) # Write data section_mem_layout, sections_length = get_section_layout(used_sections, sdb) diff --git a/src/copapy/_math.py b/src/copapy/_math.py index 02ac4ae..4fed4fd 100644 --- a/src/copapy/_math.py +++ b/src/copapy/_math.py @@ -20,7 +20,7 @@ def exp(x: NumLike) -> variable[float] | float: result of e**x """ if isinstance(x, variable): - return add_op('exp', [x, x]) # TODO: fix 2. dummy argument + return add_op('exp', [x]) return float(math.exp(x)) @@ -38,7 +38,7 @@ def log(x: NumLike) -> variable[float] | float: result of ln(x) """ if isinstance(x, variable): - return add_op('log', [x, x]) # TODO: fix 2. dummy argument + return add_op('log', [x]) return float(math.log(x)) @@ -86,7 +86,7 @@ def sqrt(x: NumLike) -> variable[float] | float: Square root of x """ if isinstance(x, variable): - return add_op('sqrt', [x, x]) # TODO: fix 2. dummy argument + return add_op('sqrt', [x]) return float(math.sqrt(x)) @@ -104,7 +104,7 @@ def sin(x: NumLike) -> variable[float] | float: Square root of x """ if isinstance(x, variable): - return add_op('sin', [x, x]) # TODO: fix 2. dummy argument + return add_op('sin', [x]) return math.sin(x) @@ -122,7 +122,7 @@ def cos(x: NumLike) -> variable[float] | float: Cosine of x """ if isinstance(x, variable): - return add_op('cos', [x, x]) # TODO: fix 2. dummy argument + return add_op('cos', [x]) return math.cos(x) @@ -140,7 +140,7 @@ def tan(x: NumLike) -> variable[float] | float: Tangent of x """ if isinstance(x, variable): - return add_op('tan', [x, x]) # TODO: fix 2. dummy argument + return add_op('tan', [x]) return math.tan(x) @@ -158,7 +158,7 @@ def atan(x: NumLike) -> variable[float] | float: Inverse tangent of x """ if isinstance(x, variable): - return add_op('atan', [x, x]) # TODO: fix 2. dummy argument + return add_op('atan', [x]) return math.atan(x) @@ -177,7 +177,7 @@ def atan2(x: NumLike, y: NumLike) -> variable[float] | float: Result in radian """ if isinstance(x, variable) or isinstance(y, variable): - return add_op('atan2', [x, y]) # TODO: fix 2. dummy argument + return add_op('atan2', [x, y]) return math.atan2(x, y) @@ -195,7 +195,7 @@ def asin(x: NumLike) -> variable[float] | float: Inverse sine of x """ if isinstance(x, variable): - return add_op('asin', [x, x]) # TODO: fix 2. dummy argument + return add_op('asin', [x]) return math.asin(x) diff --git a/src/copapy/_stencils.py b/src/copapy/_stencils.py index 6890af2..35e7483 100644 --- a/src/copapy/_stencils.py +++ b/src/copapy/_stencils.py @@ -13,7 +13,6 @@ class relocation_entry: """ A dataclass for representing a relocation entry """ - target_symbol_name: str target_symbol_info: str target_symbol_offset: int @@ -30,8 +29,8 @@ class patch_entry: Attributes: addr (int): address of first byte to patch relative to the start of the symbol - type (RelocationType): relocation type""" - + type (RelocationType): relocation type + """ mask: int address: int value: int @@ -70,13 +69,6 @@ def get_return_function_type(symbol: elf_symbol) -> str: return 'void' -def strip_function(func: elf_symbol) -> bytes: - """Return stencil code by striped stancil function""" - assert func.relocations and any(reloc.symbol.name.startswith('result_') for reloc in func.relocations), f"{func.name} is not a stencil function" - start_index, end_index = get_stencil_position(func) - return func.data[start_index:end_index] - - def get_stencil_position(func: elf_symbol) -> tuple[int, int]: start_index = 0 # There must be no prolog # Find last relocation in function @@ -109,11 +101,6 @@ def get_op_after_last_call_in_function(func: elf_symbol) -> int: return reloc.fields['r_offset'] - func.fields['st_value'] + 4 -def symbol_is_stencil(sym: elf_symbol) -> bool: - return (sym.info == 'STT_FUNC' and len(sym.relocations) > 0 and - sym.relocations[-1].symbol.info == 'STT_NOTYPE') - - class stencil_database(): """A class for loading and querying a stencil database from an ELF object file @@ -153,6 +140,9 @@ class stencil_database(): # sym.relocations # self.elf.symbols[name].data + self._relocation_cache: dict[tuple[str, bool], list[relocation_entry]] = {} + self._stencil_cache: dict[str, tuple[int, int]] = {} + def const_sections_from_functions(self, symbol_names: Iterable[str]) -> list[int]: ret: set[int] = set() @@ -165,6 +155,17 @@ class stencil_database(): return list(ret) def get_relocations(self, symbol_name: str, stencil: bool = False) -> Generator[relocation_entry, None, None]: + cache_key = (symbol_name, stencil) + if cache_key in self._relocation_cache: + # cache hit: + for reloc_entry in self._relocation_cache[cache_key]: + yield reloc_entry + return + + # cache miss: + cache: list[relocation_entry] = [] + self._relocation_cache[cache_key] = cache + symbol = self.elf.symbols[symbol_name] if stencil: start_index, end_index = get_stencil_position(symbol) @@ -178,13 +179,15 @@ class stencil_database(): patch_offset = reloc.fields['r_offset'] - symbol.fields['st_value'] - start_index if patch_offset < end_index - start_index: # Exclude the call to the result_* function - yield relocation_entry(reloc.symbol.name, + reloc_entry = relocation_entry(reloc.symbol.name, reloc.symbol.info, reloc.symbol.fields['st_value'], reloc.symbol.fields['st_shndx'], symbol.fields['st_value'], start_index, reloc) + cache.append(reloc_entry) + yield reloc_entry def get_patch(self, relocation: relocation_entry, symbol_address: int, function_offset: int, symbol_type: int) -> patch_entry: """Return patch positions for a provided symbol (function or object) @@ -275,7 +278,6 @@ class stencil_database(): return patch_entry(mask, patch_offset, patch_value, scale, symbol_type) - def get_stencil_code(self, name: str) -> bytes: """Return the striped function code for a provided function name @@ -285,7 +287,17 @@ class stencil_database(): Returns: Striped function code """ - return strip_function(self.elf.symbols[name]) + if name in self._stencil_cache: + start_index, lengths = self._stencil_cache[name] + else: + func = self.elf.symbols[name] + start_stencil, end_stencil = get_stencil_position(func) + assert func.section + start_index = func.section['sh_offset'] + func['st_value'] + start_stencil + lengths = end_stencil - start_stencil + self._stencil_cache[name] = (start_index, lengths) + + return self.elf.read_bytes(start_index, lengths) def get_sub_functions(self, names: Iterable[str]) -> set[str]: """Return recursively all functions called by stencils or by other functions diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py index 33a74c9..78051c3 100644 --- a/stencils/generate_stencils.py +++ b/stencils/generate_stencils.py @@ -102,10 +102,10 @@ def get_func2(func_name: str, type1: str, type2: str) -> str: @norm_indent -def get_math_func1(func_name: str, type1: str, type2: str) -> str: +def get_math_func1(func_name: str, type1: str) -> str: return f""" - STENCIL void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{ - result_float_{type2}({func_name}f((float)arg1), arg2); + STENCIL void {func_name}_{type1}({type1} arg1) {{ + result_float({func_name}f((float)arg1)); }} """ @@ -242,7 +242,7 @@ if __name__ == "__main__": fnames = ['sqrt', 'exp', 'log', 'sin', 'cos', 'tan', 'asin', 'atan'] for fn, t1 in permutate(fnames, types): - code += get_math_func1(fn, t1, t1) + code += get_math_func1(fn, t1) fnames = ['atan2', 'pow'] for fn, t1, t2 in permutate(fnames, types, types): diff --git a/tests/test_comp_timing.py b/tests/test_comp_timing.py new file mode 100644 index 0000000..d13dd3b --- /dev/null +++ b/tests/test_comp_timing.py @@ -0,0 +1,209 @@ +import time +from copapy import variable +from copapy import backend +from copapy.backend import Write, stencil_db_from_package +import copapy.backend as cpbe +import copapy as cp +import copapy._binwrite as binw +from copapy._compiler import get_nets, get_section_layout, get_data_layout +from copapy._compiler import patch_entry, CPConstant, get_aux_function_mem_layout + +def test_timing_compiler(): + t1 = cp.vector([10, 11]*128) + cp.vector(cp.variable(v) for v in range(256)) + t2 = t1.sum() + t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(256)) + t5 = ((t3 * t1) * 2).magnitude() + out = [Write(t5)] + + print(out) + + print('-- get_edges:') + t0 = time.time() + edges = list(cpbe.get_all_dag_edges(out)) + t1 = time.time() + print(f' found {len(edges)} edges') + #for p in edges: + # print('#', p) + print(f"get_edges time: {t1-t0:.6f}s") + + print('-- get_ordered_ops:') + t0 = time.time() + ordered_ops = list(cpbe.stable_toposort(edges)) + t1 = time.time() + print(f' found {len(ordered_ops)} ops') + #for p in ordered_ops: + # print('#', p) + print(f"get_ordered_ops time: {t1-t0:.6f}s") + + print('-- get_consts:') + t0 = time.time() + const_net_list = cpbe.get_const_nets(ordered_ops) + t1 = time.time() + #for p in const_list: + # print('#', p) + print(f"get_consts time: {t1-t0:.6f}s") + + print('-- add_read_ops:') + t0 = time.time() + output_ops = list(cpbe.add_read_ops(ordered_ops)) + t1 = time.time() + #for p in output_ops: + # print('#', p) + print(f"add_read_ops time: {t1-t0:.6f}s") + + print('-- add_write_ops:') + t0 = time.time() + extended_output_ops = list(cpbe.add_write_ops(output_ops, const_net_list)) + t1 = time.time() + #for p in extended_output_ops: + # print('#', p) + print(f"add_write_ops time: {t1-t0:.6f}s") + print('--') + + print('-- load_stencil_db:') + t0 = time.time() + sdb = stencil_db_from_package() + dw = binw.data_writer(sdb.byteorder) + t1 = time.time() + print(f"load_stencil_db time: {t1-t0:.6f}s") + + + + # Get all nets/variables associated with heap memory + variable_list = get_nets([[const_net_list]], extended_output_ops) + stencil_names = {node.name for _, node in extended_output_ops} + + print(f'-- get_sub_functions: {len(stencil_names)}') + t0 = time.time() + aux_function_names = sdb.get_sub_functions(stencil_names) + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + print('-- const_sections_from_functions:') + t0 = time.time() + used_sections = sdb.const_sections_from_functions(aux_function_names | stencil_names) + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + print('-- get_section_layout:') + t0 = time.time() + section_mem_layout, sections_length = get_section_layout(used_sections, sdb) + variable_mem_layout, variables_data_lengths = get_data_layout(variable_list, sdb, sections_length) + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + + variables: dict[backend.Net, tuple[int, int, str]] = {} + data_list: list[bytes] = [] + patch_list: list[patch_entry] = [] + + + print('-- write_data:') + t0 = time.time() + # Heap constants + for section_id, start, lengths in section_mem_layout: + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(start) + dw.write_int(lengths) + dw.write_bytes(sdb.get_section_data(section_id)) + + # Heap variables + for net, start, lengths in variable_mem_layout: + variables[net] = (start, lengths, net.dtype) + if isinstance(net.source, CPConstant): + dw.write_com(binw.Command.COPY_DATA) + dw.write_int(start) + dw.write_int(lengths) + dw.write_value(net.source.value, lengths) + #print(f'+ {net.dtype} {net.source.value}') + + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + + + # prep auxiliary_functions + aux_function_mem_layout, aux_function_lengths = get_aux_function_mem_layout(aux_function_names, sdb) + aux_func_addr_lookup = {name: offs for name, offs, _ in aux_function_mem_layout} + + # Prepare program code and relocations + object_addr_lookup = {net: offs for net, offs, _ in variable_mem_layout} + section_addr_lookup = {id: offs for id, offs, _ in section_mem_layout} + + # assemble stencils to main program and patch stencils + data = sdb.get_function_code('entry_function_shell', 'start') + data_list.append(data) + offset = aux_function_lengths + len(data) + + + print('-- relocate stencils:') + t0 = time.time() + for associated_net, node in extended_output_ops: + assert node.name in sdb.stencil_definitions, f"- Warning: {node.name} stencil not found" + data = sdb.get_stencil_code(node.name) + data_list.append(data) + #print(f"* {node.name} ({offset}) " + ' '.join(f'{d:02X}' for d in data)) + + for reloc in sdb.get_relocations(node.name, stencil=True): + if reloc.target_symbol_info in ('STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'): + #print('-- ' + reloc.target_symbol_name + ' // ' + node.name) + if reloc.target_symbol_name.startswith('dummy_'): + # Patch for write and read addresses to/from heap variables + assert associated_net, f"Relocation found but no net defined for operation {node.name}" + #print(f"Patch for write and read addresses to/from heap variables: {node.name} {patch.target_symbol_info} {patch.target_symbol_name}") + obj_addr = object_addr_lookup[associated_net] + patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value) + elif reloc.target_symbol_name.startswith('result_'): + # Set return jump address to address of following stencil + patch = sdb.get_patch(reloc, offset + len(data), offset, binw.Command.PATCH_FUNC.value) + else: + # Patch constants addresses on heap + assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {node.name} missing: {reloc.pelfy_reloc.symbol.name}" + obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index] + patch = sdb.get_patch(reloc, obj_addr, offset, binw.Command.PATCH_OBJECT.value) + #print('* constants stancils', patch.type, patch.patch_address, binw.Command.PATCH_OBJECT, node.name) + + elif reloc.target_symbol_info == 'STT_FUNC': + func_addr = aux_func_addr_lookup[reloc.target_symbol_name] + patch = sdb.get_patch(reloc, func_addr, offset, binw.Command.PATCH_FUNC.value) + #print(patch.type, patch.addr, binw.Command.PATCH_FUNC, node.name, '->', patch.target_symbol_name) + else: + raise ValueError(f"Unsupported: {node.name} {reloc.target_symbol_info} {reloc.target_symbol_name}") + + patch_list.append(patch) + + offset += len(data) + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + print('-- relocate aux functions:') + t0 = time.time() + # Patch aux functions + for name, start, _ in aux_function_mem_layout: + for reloc in sdb.get_relocations(name): + + #assert reloc.target_symbol_info != 'STT_FUNC', "Not tested yet!" + + if reloc.target_symbol_info in {'STT_OBJECT', 'STT_NOTYPE', 'STT_SECTION'}: + # Patch constants/variable addresses on heap + #print('--> DATA ', name, reloc.pelfy_reloc.symbol.name, reloc.pelfy_reloc.symbol.info, reloc.pelfy_reloc.symbol.section.name) + assert reloc.target_section_index in section_addr_lookup, f"- Function or object in {name} missing: {reloc.pelfy_reloc.symbol.name}" + obj_addr = reloc.target_symbol_offset + section_addr_lookup[reloc.target_section_index] + patch = sdb.get_patch(reloc, obj_addr, start, binw.Command.PATCH_OBJECT.value) + + elif reloc.target_symbol_info == 'STT_FUNC': + #print('--> FUNC', name, reloc.pelfy_reloc.symbol.name, reloc.pelfy_reloc.symbol.info, reloc.pelfy_reloc.symbol.section.name) + func_addr = aux_func_addr_lookup[reloc.target_symbol_name] + patch = sdb.get_patch(reloc, func_addr, start, binw.Command.PATCH_FUNC.value) + #print(f' FUNC {func_addr=} {start=} {patch.address=}') + + else: + raise ValueError(f"Unsupported: {name=} {reloc.target_symbol_info=} {reloc.target_symbol_name=} {reloc.target_section_index}") + + patch_list.append(patch) + t1 = time.time() + print(f"time: {t1-t0:.6f}s") + + +if __name__ == "__main__": + test_timing_compiler() diff --git a/tools/get_binaries.py b/tools/get_binaries.py index 8aa2e94..8de612c 100644 --- a/tools/get_binaries.py +++ b/tools/get_binaries.py @@ -35,8 +35,10 @@ def main() -> None: url = asset["browser_download_url"] name: str = asset["name"] - if name.endswith('.o'): + if name.startswith('stencils_'): dest = 'src/copapy/obj' + elif name.startswith('musl_'): + dest = 'build/musl' elif name == 'coparun.exe' and os.name == 'nt': dest = 'build/runner' elif name == 'coparun' and os.name == 'posix':