diff --git a/src/copapy/__init__.py b/src/copapy/__init__.py
index 5dee792..e6478f7 100644
--- a/src/copapy/__init__.py
+++ b/src/copapy/__init__.py
@@ -1,8 +1,7 @@
 from ._target import Target
-from ._basic_types import NumLike, variable, \
-    generic_sdb, iif
+from ._basic_types import NumLike, variable, generic_sdb, iif
 from ._vectors import vector
-from ._math import sqrt, abs
+from ._math import sqrt, abs, sin, cos, tan
 
 __all__ = [
     "Target",
@@ -13,4 +12,7 @@ __all__ = [
     "vector",
     "sqrt",
     "abs",
+    "sin",
+    "cos",
+    "tan"
 ]
diff --git a/src/copapy/_math.py b/src/copapy/_math.py
index f4bcd2f..6dbbc57 100644
--- a/src/copapy/_math.py
+++ b/src/copapy/_math.py
@@ -1,6 +1,7 @@
 from . import variable, NumLike
 from typing import TypeVar, Any, overload
 from ._basic_types import add_op
+import math
 
 T = TypeVar("T", int, float, variable[int], variable[float])
 
@@ -24,14 +25,55 @@ def sqrt(x: NumLike) -> variable[float] | float:
 
 
 @overload
-def sqrt2(x: float | int) -> float: ...
+def sin(x: float | int) -> float: ...
 @overload
-def sqrt2(x: variable[Any]) -> variable[float]: ...
-def sqrt2(x: NumLike) -> variable[float] | float:
-    """Square root function"""
+def sin(x: variable[Any]) -> variable[float]: ...
+def sin(x: NumLike) -> variable[float] | float:
+    """Sine function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Sine of x
+    """
     if isinstance(x, variable):
-        return add_op('sqrt2', [x, x])  # TODO: fix 2. dummy argument
-    return float(x ** 0.5)
+        return add_op('sin', [x, x])  # TODO: fix 2. dummy argument
+    return math.sin(x)
+
+@overload
+def cos(x: float | int) -> float: ...
+@overload
+def cos(x: variable[Any]) -> variable[float]: ...
+def cos(x: NumLike) -> variable[float] | float:
+    """Cosine function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Cosine of x
+    """
+    if isinstance(x, variable):
+        return add_op('cos', [x, x])  # TODO: fix 2. dummy argument
+    return math.cos(x)
+
+@overload
+def tan(x: float | int) -> float: ...
+@overload
+def tan(x: variable[Any]) -> variable[float]: ...
+def tan(x: NumLike) -> variable[float] | float:
+    """Tangent function
+
+    Arguments:
+        x: Input value
+
+    Returns:
+        Tangent of x
+    """
+    if isinstance(x, variable):
+        return add_op('tan', [x, x])  # TODO: fix 2. dummy argument
+    return math.tan(x)
 
 
 def get_42() -> variable[float]:
diff --git a/stencils/aux_functions.c b/stencils/aux_functions.c
index 2654b8f..89a61f9 100644
--- a/stencils/aux_functions.c
+++ b/stencils/aux_functions.c
@@ -25,10 +25,6 @@ __attribute__((noinline)) float aux_sqrt(float n) {
     return x;
 }
 
-__attribute__((noinline)) float aux_sqrt2(float n) {
-    return n * 20.5 + 4.5;
-}
-
 __attribute__((noinline)) float aux_get_42(float n) {
     return n + 42.0;
 }
diff --git a/stencils/generate_stencils.py b/stencils/generate_stencils.py
index cf597bf..1db9238 100644
--- a/stencils/generate_stencils.py
+++ b/stencils/generate_stencils.py
@@ -11,7 +11,7 @@ stencil_func_prefix = '__attribute__((naked)) '  # Remove callee prolog
 
 stack_size = 64
 
-includes = ['aux_functions.c']
+includes = ['aux_functions.c', 'trigonometry.c']
 
 
 def read_files(files: list[str]) -> str:
@@ -212,10 +212,9 @@ if __name__ == "__main__":
         t_out = 'int' if t1 == 'float' else 'float'
         code += get_cast(t1, t2, t_out)
 
-    for t1, t2 in permutate(types, types):
-        code += get_func2('sqrt', t1, t2)
-        code += get_func2('sqrt2', t1, t2)
-        code += get_func2('get_42', t1, t2)
+    fnames = ['sqrt', 'sin', 'cos', 'tan', 'get_42']
+    for fn, t1, t2 in permutate(fnames, types, types):
+        code += get_func2(fn, t1, t2)
 
     for op, t1, t2 in permutate(ops, types, types):
         t_out = t1 if t1 == t2 else 'float'
diff --git a/stencils/trigonometry.c b/stencils/trigonometry.c
new file mode 100644
index 0000000..da995bd
--- /dev/null
+++ b/stencils/trigonometry.c
@@ -0,0 +1,146 @@
+const float PI = 3.14159265358979323846f;
+const float PI_2 = 1.57079632679489661923f; // pi/2
+const float TWO_OVER_PI = 0.63661977236758134308f; // 2/pi
+
+__attribute__((noinline)) float aux_sin(float x) {
+    // convert to double for reduction (better precision)
+    double xd = (double)x;
+
+    // quadrant index q = nearest integer to x * 2/pi
+    double qd = xd * (double)TWO_OVER_PI;
+    // round to nearest integer (tie to even rounding not guaranteed)
+    int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5));
+
+    // range-reduced remainder r = x − q*(pi/2)
+    // use hi/lo parts for pi/2 to reduce error
+    const double PIO2_HI = 1.57079625129699707031; // ≈ first 24 bits
+    const double PIO2_LO = 7.54978941586159635335e-08; // remainder
+    double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
+    float r = (float)r_d;
+
+    // sin(r + q*pi/2): q mod 4 = 0 -> sin r, 1 -> cos r, 2 -> -sin r, 3 -> -cos r
+    int qm = q & 3;
+    int use_cos = (qm == 1 || qm == 3);
+    int sign = (qm == 0 || qm == 1) ? +1 : -1;
+
+    float r2 = r * r;
+
+    if (!use_cos) {
+        // sin(r) polynomial: r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
+        const float s3 = -1.6666667163e-1f;
+        const float s5 = 8.3333337680e-3f;
+        const float s7 = -1.9841270114e-4f;
+        const float s9 = 2.7557314297e-6f;
+
+        float p = ((s9 * r2 + s7) * r2 + s5) * r2 + s3;
+        float result = r + r * r2 * p;
+        return sign * result;
+    } else {
+        // cos(r) polynomial: 1 + c2*r2 + c4*r4 + c6*r6 + c8*r8
+        const float c2 = -0.5f;
+        const float c4 = 4.1666667908e-2f;
+        const float c6 = -1.3888889225e-3f;
+        const float c8 = 2.4801587642e-5f;
+
+        float p = ((c8 * r2 + c6) * r2 + c4) * r2 + c2;
+        float result = 1.0f + r2 * p;
+        return sign * result;
+    }
+}
+
+__attribute__((noinline)) float aux_cos(float x) {
+    // convert to double for reduction (better precision)
+    double xd = (double)x;
+
+    // quadrant index q = nearest integer to x * 2/pi
+    double qd = xd * (double)TWO_OVER_PI;
+    // round to nearest integer (tie to even rounding not guaranteed)
+    int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5));
+
+    // range-reduced remainder r = x − q*(pi/2)
+    // use hi/lo parts for pi/2 to reduce error
+    const double PIO2_HI = 1.57079625129699707031; // ≈ first 24 bits
+    const double PIO2_LO = 7.54978941586159635335e-08; // remainder
+    double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
+    float r = (float)r_d;
+
+    // cos(r + q*pi/2): q mod 4 = 0 -> cos r, 1 -> -sin r, 2 -> -cos r, 3 -> sin r
+    int qm = q & 3;
+    int use_sin = (qm == 1 || qm == 3);
+    int sign = (qm == 0 || qm == 3) ? +1 : -1;
+
+    float r2 = r * r;
+
+    if (use_sin) {
+        // sin(r) polynomial: r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
+        const float s3 = -1.6666667163e-1f;
+        const float s5 = 8.3333337680e-3f;
+        const float s7 = -1.9841270114e-4f;
+        const float s9 = 2.7557314297e-6f;
+
+        float p = ((s9 * r2 + s7) * r2 + s5) * r2 + s3;
+        float result = r + r * r2 * p;
+        return sign * result;
+    } else {
+        // cos(r) polynomial: 1 + c2*r2 + c4*r4 + c6*r6 + c8*r8
+        const float c2 = -0.5f;
+        const float c4 = 4.1666667908e-2f;
+        const float c6 = -1.3888889225e-3f;
+        const float c8 = 2.4801587642e-5f;
+
+        float p = ((c8 * r2 + c6) * r2 + c4) * r2 + c2;
+        float result = 1.0f + r2 * p;
+        return sign * result;
+    }
+}
+
+__attribute__((noinline)) float aux_tan(float x) {
+    // Promote to double for argument reduction (improves precision)
+    double xd = (double)x;
+    double qd = xd * (double)TWO_OVER_PI; // how many half-pi multiples
+    int q = (int)(qd + (qd >= 0.0 ? 0.5 : -0.5)); // nearest integer
+
+    // Range reduce: r = x - q*(pi/2)
+    const double PIO2_HI = 1.57079625129699707031; // π/2 high part
+    const double PIO2_LO = 7.54978941586159635335e-08; // π/2 low part
+    double r_d = xd - (double)q * PIO2_HI - (double)q * PIO2_LO;
+    float r = (float)r_d;
+
+    // tan has period π: even q -> tan(x) = +tan(r), odd q -> tan(x) = -cot(r)
+    int qm = q & 3;
+    int use_cot = (qm == 1 || qm == 3); // odd multiple of π/2: use cotangent
+    int sign = use_cot ? -1 : +1;
+
+    // Polynomial approximations
+    // sin(r) ≈ r + s3*r^3 + s5*r^5 + s7*r^7 + s9*r^9
+    const float s3 = -1.6666667163e-1f;
+    const float s5 = 8.3333337680e-3f;
+    const float s7 = -1.9841270114e-4f;
+    const float s9 = 2.7557314297e-6f;
+
+    // cos(r) ≈ 1 + c2*r^2 + c4*r^4 + c6*r^6 + c8*r^8
+    const float c2 = -0.5f;
+    const float c4 = 4.1666667908e-2f;
+    const float c6 = -1.3888889225e-3f;
+    const float c8 = 2.4801587642e-5f;
+
+    float r2 = r * r;
+    float sin_r = r + r * r2 * (((s9 * r2 + s7) * r2 + s5) * r2 + s3);
+    float cos_r = 1.0f + r2 * (((c8 * r2 + c6) * r2 + c4) * r2 + c2);
+
+    float t;
+    if (!use_cot) {
+        // tan(r) = sin(r)/cos(r)
+        t = sin_r / cos_r;
+    } else {
+        // cot(r) = cos(r)/sin(r)
+        t = cos_r / sin_r;
+    }
+
+    // Avoid catastrophic explosion near vertical asymptotes
+    // Clip to a large finite value (~1e8)
+    if (t > 1e8f) t = 1e8f;
+    if (t < -1e8f) t = -1e8f;
+
+    return sign * t;
+}
\ No newline at end of file
diff --git a/tests/test_math.py b/tests/test_math.py
index f1c30a7..acc0d1d 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -37,8 +37,8 @@ def test_fine():
     c_f = variable(a_f)
     # c_b = variable(True)
 
-    ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f))  # , c_i & 3)
-    ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f))  # , a_i & 3)
+    ret_test = (c_f ** 2, c_i ** -1, cp.sqrt(c_i), cp.sqrt(c_f), cp.sin(c_f), cp.cos(c_f), cp.tan(c_f))  # , c_i & 3)
+    ret_refe = (a_f ** 2, a_i ** -1, cp.sqrt(a_i), cp.sqrt(a_f), cp.sin(a_f), cp.cos(a_f), cp.tan(a_f))  # , a_i & 3)
 
     tg = Target()
     print('* compile and copy ...')