copapy/stencils/generate_stencils.py

245 lines
6.9 KiB
Python
Raw Normal View History

from typing import Generator, Callable
2025-10-07 21:23:14 +00:00
import argparse
from pathlib import Path
import os
2025-09-11 20:22:17 +00:00
2025-10-18 21:26:35 +00:00
op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**',
'gt': '>', 'eq': '==', 'ge': '>=', 'ne': '!=', 'mod': '%'}
2025-09-11 20:22:17 +00:00
2025-10-08 20:47:49 +00:00
entry_func_prefix = ''
2025-10-10 21:22:16 +00:00
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
2025-09-20 21:25:07 +00:00
2025-10-14 21:18:58 +00:00
stack_size = 64
includes = ['aux_functions.c']
2025-10-18 21:20:27 +00:00
2025-10-23 15:23:47 +00:00
def read_files(files: list[str]) -> str:
ret = ''
script_dir = Path(__file__).parent
for file_name in files:
file_path = script_dir / file_name
if not os.path.exists(file_path):
file_path = Path(file_name)
with open(file_path) as f:
ret += f.read().strip(' \n') + '\n\n'
return ret
def normalize_indent(text: str) -> str:
text_lines = text.splitlines()
if len(text_lines) > 1 and not text_lines[0].strip():
text_lines = text_lines[1:]
if not text_lines:
return ''
if len(text_lines) > 1 and text_lines[0] and text_lines[0][0] != ' ':
indent_amount = len(text_lines[1]) - len(text_lines[1].lstrip())
else:
indent_amount = len(text_lines[0]) - len(text_lines[0].lstrip())
return '\n' + '\n'.join(
[' ' * max(0, len(line) - len(line.strip()) - indent_amount) + line.strip()
for line in text_lines])
def norm_indent(f: Callable[..., str]) -> Callable[..., str]:
return lambda *x: normalize_indent(f(*x))
@norm_indent
def get_entry_function_shell() -> str:
2025-10-07 21:23:14 +00:00
return f"""
2025-10-12 20:22:30 +00:00
{entry_func_prefix}int entry_function_shell(){{
2025-10-14 21:18:58 +00:00
volatile char stack_place_holder[{stack_size}];
stack_place_holder[0] = 0;
2025-10-10 21:22:16 +00:00
result_int(0);
return 1;
2025-10-07 21:23:14 +00:00
}}
2025-09-30 21:11:14 +00:00
"""
@norm_indent
def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
2025-09-11 20:22:17 +00:00
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-09-11 20:22:17 +00:00
result_{type_out}_{type2}(arg1 {op_signs[op]} arg2, arg2);
}}
"""
@norm_indent
2025-10-18 21:26:35 +00:00
def get_cast(type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void cast_{type_out}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_{type_out}_{type2}(({type1})arg1, arg2);
}}
"""
@norm_indent
def get_func2(func_name: str, type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-10-26 15:09:02 +00:00
result_float_{type2}(aux_{func_name}((float)arg1), arg2);
}}
"""
@norm_indent
2025-10-08 20:21:33 +00:00
def get_conv_code(type1: str, type2: str, type_out: str) -> str:
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void conv_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-10-08 20:21:33 +00:00
result_{type_out}_{type2}(({type_out})arg1, arg2);
}}
"""
@norm_indent
2025-10-04 21:20:25 +00:00
def get_op_code_float(op: str, type1: str, type2: str) -> str:
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-10-08 20:21:33 +00:00
result_float_{type2}((float)arg1 {op_signs[op]} (float)arg2, arg2);
2025-10-04 21:20:25 +00:00
}}
"""
@norm_indent
2025-10-18 21:26:35 +00:00
def get_pow(type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void pow_{type1}_{type2}({type1} arg1, {type2} arg2) {{
//result_float_{type2}((float)math_pow((double)arg1, (double)arg2), arg2);
result_float_{type2}(fast_pow_float((float)arg1, (float)arg2), arg2);
2025-10-18 21:26:35 +00:00
}}
"""
@norm_indent
2025-10-09 20:54:11 +00:00
def get_floordiv(op: str, type1: str, type2: str) -> str:
if type1 == 'int' and type2 == 'int':
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_int_{type2}(floor_div((float)arg1, (float)arg2), arg2);
}}
"""
else:
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}((float)floor_div((float)arg1, (float)arg2), arg2);
}}
"""
2025-10-04 21:20:25 +00:00
@norm_indent
def get_result_stubs1(type1: str) -> str:
2025-09-11 20:46:53 +00:00
return f"""
void result_{type1}({type1} arg1);
"""
2025-09-20 21:25:07 +00:00
@norm_indent
def get_result_stubs2(type1: str, type2: str) -> str:
2025-09-11 20:22:17 +00:00
return f"""
void result_{type1}_{type2}({type1} arg1, {type2} arg2);
"""
2025-09-20 21:25:07 +00:00
@norm_indent
def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
2025-09-11 20:22:17 +00:00
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-09-11 20:22:17 +00:00
result_{type_out}_{type2}(dummy_{type_out}, arg2);
}}
"""
2025-09-20 21:25:07 +00:00
@norm_indent
def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
2025-09-11 20:22:17 +00:00
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
2025-09-11 20:22:17 +00:00
result_{type1}_{type_out}(arg1, dummy_{type_out});
}}
"""
2025-09-20 21:25:07 +00:00
@norm_indent
def get_write_code(type1: str) -> str:
2025-09-11 20:22:17 +00:00
return f"""
2025-10-08 20:47:49 +00:00
{stencil_func_prefix}void write_{type1}({type1} arg1) {{
2025-09-11 20:22:17 +00:00
dummy_{type1} = arg1;
2025-09-11 20:46:53 +00:00
result_{type1}(arg1);
2025-09-11 20:22:17 +00:00
}}
"""
2025-09-20 21:25:07 +00:00
2025-09-11 20:22:17 +00:00
def permutate(*lists: list[str]) -> Generator[list[str], None, None]:
if len(lists) == 0:
yield []
return
first, *rest = lists
for item in first:
for items in permutate(*rest):
yield [item, *items]
2025-09-20 21:25:07 +00:00
2025-09-11 20:22:17 +00:00
if __name__ == "__main__":
2025-10-07 21:23:14 +00:00
parser = argparse.ArgumentParser()
parser.add_argument("path", type=str, help="Output file path")
parser.add_argument("--abi", type=str, default="", help="Optionaler String (Standard: '')")
args = parser.parse_args()
if args.abi:
2025-10-08 20:47:49 +00:00
entry_func_prefix = f"__attribute__(({args.abi}_abi)) "
2025-10-07 21:23:14 +00:00
code = "// Auto-generated stencils for copapy - Do not edit manually\n\n"
2025-09-11 20:22:17 +00:00
code += read_files(includes)
2025-09-11 20:46:53 +00:00
2025-10-09 20:53:51 +00:00
# Scalar arithmetic:
types = ['int', 'float']
ops = ['add', 'sub', 'mul', 'div', 'floordiv', 'gt', 'ge', 'eq', 'ne', 'pow']
2025-10-09 20:53:51 +00:00
2025-09-11 20:46:53 +00:00
for t1 in types:
code += get_result_stubs1(t1)
2025-09-11 20:22:17 +00:00
for t1, t2 in permutate(types, types):
2025-09-11 20:46:53 +00:00
code += get_result_stubs2(t1, t2)
2025-09-11 20:22:17 +00:00
code += get_entry_function_shell()
2025-10-12 20:22:30 +00:00
2025-10-19 20:21:54 +00:00
for t1, t2 in permutate(types, types):
t_out = 'int' if t1 == 'float' else 'float'
code += get_cast(t1, t2, t_out)
for t1, t2 in permutate(types, types):
code += get_func2('sqrt', t1, t2)
code += get_func2('sqrt2', t1, t2)
code += get_func2('get_42', t1, t2)
2025-09-11 20:22:17 +00:00
for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float'
2025-10-04 21:20:25 +00:00
if op == 'floordiv':
2025-10-09 20:54:11 +00:00
code += get_floordiv('floordiv', t1, t2)
2025-10-08 20:21:33 +00:00
elif op == 'div':
2025-10-04 21:20:25 +00:00
code += get_op_code_float(op, t1, t2)
2025-10-18 21:26:35 +00:00
elif op == 'pow':
code += get_pow(t1, t2)
elif op in {'gt', 'eq', 'ge', 'ne'}:
2025-10-09 20:53:51 +00:00
code += get_op_code(op, t1, t2, 'int')
2025-10-04 21:20:25 +00:00
else:
code += get_op_code(op, t1, t2, t_out)
2025-09-11 20:22:17 +00:00
code += get_op_code('mod', 'int', 'int', 'int')
2025-09-11 20:22:17 +00:00
for t1, t2, t_out in permutate(types, types, types):
code += get_read_reg0_code(t1, t2, t_out)
code += get_read_reg1_code(t1, t2, t_out)
2025-09-11 20:46:53 +00:00
for t1 in types:
code += get_write_code(t1)
2025-09-11 20:22:17 +00:00
print(f"Write file {args.path}...")
2025-10-07 21:23:14 +00:00
with open(args.path, 'w') as f:
2025-09-11 20:22:17 +00:00
f.write(code)