stencil generation improved, relying on tail call optimization (TOC)

This commit is contained in:
Nicolas 2025-11-02 18:36:56 +01:00
parent 58038cef8b
commit f34795cac6
4 changed files with 41 additions and 11 deletions

View File

@ -1,18 +1,19 @@
#include <stdint.h>
#include "stencil_helper.h"
//double (*math_pow)(double, double);
volatile extern int dummy_int;
volatile extern float dummy_float;
__attribute__((noinline)) int floor_div(float arg1, float arg2) {
NOINLINE int floor_div(float arg1, float arg2) {
float x = arg1 / arg2;
int i = (int)x;
if (x < 0 && x != (float)i) i -= 1;
return i;
}
__attribute__((noinline)) float aux_sqrt(float x) {
NOINLINE float aux_sqrt(float x) {
if (x <= 0.0f) return 0.0f;
// --- Improved initial guess using bit-level trick ---
@ -28,7 +29,7 @@ __attribute__((noinline)) float aux_sqrt(float x) {
return y;
}
__attribute__((noinline)) float aux_get_42(float n) {
NOINLINE float aux_get_42(float n) {
return n + 42.0;
}

View File

@ -9,11 +9,11 @@ op_signs = {'add': '+', 'sub': '-', 'mul': '*', 'div': '/', 'pow': '**',
'bwand': '&', 'bwor': '|', 'bwxor': '^'}
entry_func_prefix = ''
stencil_func_prefix = '__attribute__((naked)) ' # Remove callee prolog
stencil_func_prefix = '' # Remove callee prolog
stack_size = 64
includes = ['aux_functions.c', 'trigonometry.c']
includes = ['stencil_helper.h', 'aux_functions.c', 'trigonometry.c']
def read_files(files: list[str]) -> str:
@ -25,6 +25,10 @@ def read_files(files: list[str]) -> str:
file_path = Path(file_name)
with open(file_path) as f:
ret += f.read().strip(' \n') + '\n\n'
for incl in includes:
ret = ret.replace(f'#include "{incl}"\n', '')
return ret
@ -66,6 +70,7 @@ def get_entry_function_shell() -> str:
def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START({op}_{type1}_{type2});
result_{type_out}_{type2}(arg1 {op_signs[op]} arg2, arg2);
}}
"""
@ -75,6 +80,7 @@ def get_op_code(op: str, type1: str, type2: str, type_out: str) -> str:
def get_cast(type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void cast_{type_out}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(cast_{type_out}_{type1}_{type2});
result_{type_out}_{type2}(({type1})arg1, arg2);
}}
"""
@ -84,6 +90,7 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
def get_func2(func_name: str, type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START({func_name}_{type1}_{type2});
result_float_{type2}(aux_{func_name}((float)arg1), arg2);
}}
"""
@ -93,6 +100,7 @@ def get_func2(func_name: str, type1: str, type2: str) -> str:
def get_conv_code(type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void conv_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(conv_{type1}_{type2});
result_{type_out}_{type2}(({type_out})arg1, arg2);
}}
"""
@ -102,6 +110,7 @@ def get_conv_code(type1: str, type2: str, type_out: str) -> str:
def get_op_code_float(op: str, type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START({op}_{type1}_{type2});
result_float_{type2}((float)arg1 {op_signs[op]} (float)arg2, arg2);
}}
"""
@ -111,7 +120,7 @@ def get_op_code_float(op: str, type1: str, type2: str) -> str:
def get_pow(type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void pow_{type1}_{type2}({type1} arg1, {type2} arg2) {{
//result_float_{type2}((float)math_pow((double)arg1, (double)arg2), arg2);
STENCIL_START(pow_{type1}_{type2});
result_float_{type2}(fast_pow_float((float)arg1, (float)arg2), arg2);
}}
"""
@ -121,13 +130,16 @@ def get_pow(type1: str, type2: str) -> str:
def get_floordiv(op: str, type1: str, type2: str) -> str:
if type1 == 'int' and type2 == 'int':
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_int_{type2}(floor_div((float)arg1, (float)arg2), arg2);
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} a, {type2} b) {{
STENCIL_START({op}_{type1}_{type2});
int result = a / b - ((a % b != 0) && ((a < 0) != (b < 0)));
result_int_{type2}(result, b);
}}
"""
else:
return f"""
{stencil_func_prefix}void {op}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START({op}_{type1}_{type2});
result_float_{type2}((float)floor_div((float)arg1, (float)arg2), arg2);
}}
"""
@ -151,6 +163,7 @@ def get_result_stubs2(type1: str, type2: str) -> str:
def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void read_{type_out}_reg0_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(read_{type_out}_reg0_{type1}_{type2});
result_{type_out}_{type2}(dummy_{type_out}, arg2);
}}
"""
@ -160,6 +173,7 @@ def get_read_reg0_code(type1: str, type2: str, type_out: str) -> str:
def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
return f"""
{stencil_func_prefix}void read_{type_out}_reg1_{type1}_{type2}({type1} arg1, {type2} arg2) {{
STENCIL_START(read_{type_out}_reg1_{type1}_{type2});
result_{type1}_{type_out}(arg1, dummy_{type_out});
}}
"""
@ -169,6 +183,7 @@ def get_read_reg1_code(type1: str, type2: str, type_out: str) -> str:
def get_write_code(type1: str) -> str:
return f"""
{stencil_func_prefix}void write_{type1}({type1} arg1) {{
STENCIL_START(write_{type1});
dummy_{type1} = arg1;
result_{type1}(arg1);
}}

12
stencils/stencil_helper.h Normal file
View File

@ -0,0 +1,12 @@
#if defined(__GNUC__)
#define NOINLINE __attribute__((noinline))
#define STENCIL_START_EX(funcname) \
__asm__ __volatile__( \
".global stencil_start_" #funcname "\n" \
"stencil_start_" #funcname ":\n" \
)
#define STENCIL_START(funcname)
#else
#define NOINLINE
#define STENCIL_START(funcname)
#endif

View File

@ -1,8 +1,10 @@
#include "stencil_helper.h"
const float PI = 3.14159265358979323846f;
const float PI_2 = 1.57079632679489661923f; // pi/2
const float TWO_OVER_PI = 0.63661977236758134308f; // 2/pi
__attribute__((noinline)) float aux_sin(float x) {
NOINLINE float aux_sin(float x) {
// convert to double for reduction (better precision)
double xd = (double)x;
@ -48,7 +50,7 @@ __attribute__((noinline)) float aux_sin(float x) {
}
}
__attribute__((noinline)) float aux_cos(float x) {
NOINLINE float aux_cos(float x) {
// convert to double for reduction (better precision)
double xd = (double)x;
@ -94,7 +96,7 @@ __attribute__((noinline)) float aux_cos(float x) {
}
}
__attribute__((noinline)) float aux_tan(float x) {
NOINLINE float aux_tan(float x) {
// Promote to double for argument reduction (improves precision)
double xd = (double)x;
double qd = xd * (double)TWO_OVER_PI; // how many half-pi multiples