mirror of https://github.com/Nonannet/copapy.git
test stencils and aux functions added, including test
This commit is contained in:
parent
fb4df412ce
commit
ac6854ff9b
|
|
@ -16,6 +16,22 @@ def sqrt(x: NumLike) -> variable[float] | float:
|
||||||
return float(x ** 0.5)
|
return float(x ** 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def sqrt2(x: float | int) -> float: ...
|
||||||
|
@overload
|
||||||
|
def sqrt2(x: variable[Any]) -> variable[float]: ...
|
||||||
|
def sqrt2(x: NumLike) -> variable[float] | float:
|
||||||
|
"""Square root function"""
|
||||||
|
if isinstance(x, variable):
|
||||||
|
return add_op('sqrt2', [x, x]) # TODO: fix 2. dummy argument
|
||||||
|
return float(x ** 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def get_42() -> variable[float]:
|
||||||
|
"""Returns the variable representing the constant 42"""
|
||||||
|
return add_op('get_42', [0.0, 0.0])
|
||||||
|
|
||||||
|
|
||||||
def abs(x: T) -> T:
|
def abs(x: T) -> T:
|
||||||
"""Absolute value function"""
|
"""Absolute value function"""
|
||||||
ret = (x < 0) * -x + (x >= 0) * x
|
ret = (x < 0) * -x + (x >= 0) * x
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
__attribute__((noinline)) float fast_sqrt2(float n) {
|
__attribute__((noinline)) float sqrt(float n) {
|
||||||
if (n < 0) return -1;
|
if (n < 0) return -1;
|
||||||
|
|
||||||
float x = n; // initial guess
|
float x = n; // initial guess
|
||||||
|
|
@ -25,8 +25,12 @@ __attribute__((noinline)) float fast_sqrt2(float n) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
__attribute__((noinline)) float fast_sqrt(float n) {
|
__attribute__((noinline)) float sqrt2(float n) {
|
||||||
return n * 3.5 + 4.5;
|
return n * 20.5 + 4.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
__attribute__((noinline)) float get_42(float n) {
|
||||||
|
return n * + 42.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
float fast_pow_float(float base, float exponent) {
|
float fast_pow_float(float base, float exponent) {
|
||||||
|
|
|
||||||
|
|
@ -79,10 +79,10 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
@norm_indent
|
@norm_indent
|
||||||
def get_sqrt(type1: str, type2: str) -> str:
|
def get_func2(func_name: str, type1: str, type2: str) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
{stencil_func_prefix}void sqrt_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||||
result_float_{type2}(fast_sqrt((float)arg1), arg2);
|
result_float_{type2}({func_name}((float)arg1), arg2);
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -213,7 +213,9 @@ if __name__ == "__main__":
|
||||||
code += get_cast(t1, t2, t_out)
|
code += get_cast(t1, t2, t_out)
|
||||||
|
|
||||||
for t1, t2 in permutate(types, types):
|
for t1, t2 in permutate(types, types):
|
||||||
code += get_sqrt(t1, t2)
|
code += get_func2('sqrt', t1, t2)
|
||||||
|
code += get_func2('sqrt2', t1, t2)
|
||||||
|
code += get_func2('get_42', t1, t2)
|
||||||
|
|
||||||
for op, t1, t2 in permutate(ops, types, types):
|
for op, t1, t2 in permutate(ops, types, types):
|
||||||
t_out = t1 if t1 == t2 else 'float'
|
t_out = t1 if t1 == t2 else 'float'
|
||||||
|
|
|
||||||
|
|
@ -15,17 +15,19 @@ def test_compiled_vectors():
|
||||||
t2 = t1.sum()
|
t2 = t1.sum()
|
||||||
|
|
||||||
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
||||||
t4 = ((t3 * t1) * 2).magnitude()
|
#t4 = ((t3 * t1) * 2).magnitude()
|
||||||
#t4 = ((t3 * t1) * 2).sum()
|
t4 = ((t3 * t1) * 2).sum()
|
||||||
t5 = cp.sqrt(cp.variable(8.0))
|
t5 = cp._math.sqrt2(cp.variable(8.0))
|
||||||
|
t6 = cp._math.get_42()
|
||||||
|
|
||||||
tg = cp.Target()
|
tg = cp.Target()
|
||||||
tg.compile(t2, t4, t5)
|
tg.compile(t2, t4, t5, t6)
|
||||||
tg.run()
|
tg.run()
|
||||||
|
|
||||||
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||||
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
||||||
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 3.5 + 4.5
|
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
|
||||||
|
assert tg.read_value(t6) == 42.0
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_compiled_vectors()
|
test_compiled_vectors()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue