mirror of https://github.com/Nonannet/copapy.git
test stencils and aux functions added, including test
This commit is contained in:
parent
fb4df412ce
commit
ac6854ff9b
|
|
@ -16,6 +16,22 @@ def sqrt(x: NumLike) -> variable[float] | float:
|
|||
return float(x ** 0.5)
|
||||
|
||||
|
||||
@overload
|
||||
def sqrt2(x: float | int) -> float: ...
|
||||
@overload
|
||||
def sqrt2(x: variable[Any]) -> variable[float]: ...
|
||||
def sqrt2(x: NumLike) -> variable[float] | float:
|
||||
"""Square root function"""
|
||||
if isinstance(x, variable):
|
||||
return add_op('sqrt2', [x, x]) # TODO: fix 2. dummy argument
|
||||
return float(x ** 0.5)
|
||||
|
||||
|
||||
def get_42() -> variable[float]:
|
||||
"""Returns the variable representing the constant 42"""
|
||||
return add_op('get_42', [0.0, 0.0])
|
||||
|
||||
|
||||
def abs(x: T) -> T:
|
||||
"""Absolute value function"""
|
||||
ret = (x < 0) * -x + (x >= 0) * x
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ __attribute__((noinline)) int floor_div(float arg1, float arg2) {
|
|||
return i;
|
||||
}
|
||||
|
||||
__attribute__((noinline)) float fast_sqrt2(float n) {
|
||||
__attribute__((noinline)) float sqrt(float n) {
|
||||
if (n < 0) return -1;
|
||||
|
||||
float x = n; // initial guess
|
||||
|
|
@ -25,8 +25,12 @@ __attribute__((noinline)) float fast_sqrt2(float n) {
|
|||
return x;
|
||||
}
|
||||
|
||||
__attribute__((noinline)) float fast_sqrt(float n) {
|
||||
return n * 3.5 + 4.5;
|
||||
__attribute__((noinline)) float sqrt2(float n) {
|
||||
return n * 20.5 + 4.5;
|
||||
}
|
||||
|
||||
__attribute__((noinline)) float get_42(float n) {
|
||||
return n * + 42.0;
|
||||
}
|
||||
|
||||
float fast_pow_float(float base, float exponent) {
|
||||
|
|
|
|||
|
|
@ -79,10 +79,10 @@ def get_cast(type1: str, type2: str, type_out: str) -> str:
|
|||
|
||||
|
||||
@norm_indent
|
||||
def get_sqrt(type1: str, type2: str) -> str:
|
||||
def get_func2(func_name: str, type1: str, type2: str) -> str:
|
||||
return f"""
|
||||
{stencil_func_prefix}void sqrt_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||
result_float_{type2}(fast_sqrt((float)arg1), arg2);
|
||||
{stencil_func_prefix}void {func_name}_{type1}_{type2}({type1} arg1, {type2} arg2) {{
|
||||
result_float_{type2}({func_name}((float)arg1), arg2);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
|
@ -213,7 +213,9 @@ if __name__ == "__main__":
|
|||
code += get_cast(t1, t2, t_out)
|
||||
|
||||
for t1, t2 in permutate(types, types):
|
||||
code += get_sqrt(t1, t2)
|
||||
code += get_func2('sqrt', t1, t2)
|
||||
code += get_func2('sqrt2', t1, t2)
|
||||
code += get_func2('get_42', t1, t2)
|
||||
|
||||
for op, t1, t2 in permutate(ops, types, types):
|
||||
t_out = t1 if t1 == t2 else 'float'
|
||||
|
|
|
|||
|
|
@ -15,17 +15,19 @@ def test_compiled_vectors():
|
|||
t2 = t1.sum()
|
||||
|
||||
t3 = cp.vector(cp.variable(1 / (v + 1)) for v in range(3))
|
||||
t4 = ((t3 * t1) * 2).magnitude()
|
||||
#t4 = ((t3 * t1) * 2).sum()
|
||||
t5 = cp.sqrt(cp.variable(8.0))
|
||||
#t4 = ((t3 * t1) * 2).magnitude()
|
||||
t4 = ((t3 * t1) * 2).sum()
|
||||
t5 = cp._math.sqrt2(cp.variable(8.0))
|
||||
t6 = cp._math.get_42()
|
||||
|
||||
tg = cp.Target()
|
||||
tg.compile(t2, t4, t5)
|
||||
tg.compile(t2, t4, t5, t6)
|
||||
tg.run()
|
||||
|
||||
assert isinstance(t2, cp.variable) and tg.read_value(t2) == 10 + 11 + 12 + 0 + 1 + 2
|
||||
#assert isinstance(t4, cp.variable) and tg.read_value(t4) == ((1/1*10 + 1/2*11 + 1/3*12) * 2)**0.5
|
||||
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 3.5 + 4.5
|
||||
assert isinstance(t5, cp.variable) and tg.read_value(t5) == 8.0 * 20.5 + 4.5
|
||||
assert tg.read_value(t6) == 42.0
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compiled_vectors()
|
||||
|
|
|
|||
Loading…
Reference in New Issue