test stencils and aux functions added, including test

This commit is contained in:
Nicolas Kruse 2025-10-26 16:08:33 +01:00
parent fb4df412ce
commit ac6854ff9b
4 changed files with 36 additions and 12 deletions

View File

@ -16,6 +16,22 @@ def sqrt(x: NumLike) -> variable[float] | float:
return float(x ** 0.5)
@overload
def sqrt2(x: float | int) -> float: ...
@overload
def sqrt2(x: variable[Any]) -> variable[float]: ...
def sqrt2(x: NumLike) -> variable[float] | float:
"""Square root function"""
if isinstance(x, variable):
return add_op('sqrt2', [x, x]) # TODO: fix 2. dummy argument
return float(x ** 0.5)
def get_42() -> variable[float]:
"""Returns the variable representing the constant 42"""
return add_op('get_42', [0.0, 0.0])
def abs(x: T) -> T:
"""Absolute value function"""
ret = (x < 0) * -x + (x >= 0) * x

View File

@ -12,7 +12,7 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
return i;
}
__attribute__((noinline)) float fast_sqrt2(float n) {
__attribute__((noinline)) float sqrt(float n) {
if (n < 0) return -1;
float x = n; // initial guess
@ -25,8 +25,12 @@ __attribute__((noinline)) float fast_sqrt2(float n) {
return x;
}
__attribute__((noinline)) float fast_sqrt(float n) {
return n * 3.5 + 4.5;
__attribute__((noinline)) float sqrt2(float n) {
return n * 20.5 + 4.5;
}
__attribute__((noinline)) float get_42(float n) {
return n * + 42.0;
}
float fast_pow_float(float base, float exponent) {

View File

@ -79,10 +79,10 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
@norm_indent
def get_sqrt(type1: str, type2: str) -> str:
def get_func2(func_name: str, type1: str, type2: str) -> str:
return f"""
{stencil_func_prefix}void sqrt_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}(fast_sqrt((float)arg1), arg2);
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
result_float_{type2}({func_name}((float)arg1), arg2);
}}
"""
@ -213,7 +213,9 @@ if __name__ == "__main__":
code += get_cast(t1, t2, t_out)
for t1, t2 in permutate(types, types):
code += get_sqrt(t1, t2)
code += get_func2('sqrt', t1, t2)
code += get_func2('sqrt2', t1, t2)
code += get_func2('get_42', t1, t2)
for op, t1, t2 in permutate(ops, types, types):
t_out = t1 if t1 == t2 else 'float'

View File

@ -15,17 +15,19 @@ def test_compiled_vectors():
t2 = t1.sum()
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
t4 = ((t3 * t1) * 2).magnitude()
#t4 = ((t3 * t1) * 2).sum()
t5 = cp.sqrt(cp.variable(8.0))
#t4 = ((t3 * t1) * 2).magnitude()
t4 = ((t3 * t1) * 2).sum()
t5 = cp._math.sqrt2(cp.variable(8.0))
t6 = cp._math.get_42()
tg = cp.Target()
tg.compile(t2, t4, t5)
tg.compile(t2, t4, t5, t6)
tg.run()
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 3.5 + 4.5
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
assert tg.read_value(t6) == 42.0
if __name__ == "__main__":
test_compiled_vectors()