aboutsummaryrefslogtreecommitdiff
path: root/test/numext.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/numext.cpp')
-rw-r--r--test/numext.cpp262
1 files changed, 242 insertions, 20 deletions
diff --git a/test/numext.cpp b/test/numext.cpp
index 3de33e2f9..8a2fde501 100644
--- a/test/numext.cpp
+++ b/test/numext.cpp
@@ -9,16 +9,44 @@
#include "main.h"
+template<typename T, typename U>
+bool check_if_equal_or_nans(const T& actual, const U& expected) {
+ return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
+}
+
+template<typename T, typename U>
+bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
+ return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
+ && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
+}
+
+template<typename T, typename U>
+bool test_is_equal_or_nans(const T& actual, const U& expected)
+{
+ if (check_if_equal_or_nans(actual, expected)) {
+ return true;
+ }
+
+ // false:
+ std::cerr
+ << "\n actual = " << actual
+ << "\n expected = " << expected << "\n\n";
+ return false;
+}
+
+#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
+
template<typename T>
void check_abs() {
typedef typename NumTraits<T>::Real Real;
+ Real zero(0);
if(NumTraits<T>::IsSigned)
VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
- for(int k=0; k<g_repeat*100; ++k)
+ for(int k=0; k<100; ++k)
{
T x = internal::random<T>();
if(!internal::is_same<T,bool>::value)
@@ -26,28 +54,222 @@ void check_abs() {
if(NumTraits<T>::IsSigned)
{
VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
- VERIFY( numext::abs(-x) >= Real(0));
+ VERIFY( numext::abs(-x) >= zero );
}
- VERIFY( numext::abs(x) >= Real(0));
+ VERIFY( numext::abs(x) >= zero );
VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
}
}
-void test_numext() {
- CALL_SUBTEST( check_abs<bool>() );
- CALL_SUBTEST( check_abs<signed char>() );
- CALL_SUBTEST( check_abs<unsigned char>() );
- CALL_SUBTEST( check_abs<short>() );
- CALL_SUBTEST( check_abs<unsigned short>() );
- CALL_SUBTEST( check_abs<int>() );
- CALL_SUBTEST( check_abs<unsigned int>() );
- CALL_SUBTEST( check_abs<long>() );
- CALL_SUBTEST( check_abs<unsigned long>() );
- CALL_SUBTEST( check_abs<half>() );
- CALL_SUBTEST( check_abs<float>() );
- CALL_SUBTEST( check_abs<double>() );
- CALL_SUBTEST( check_abs<long double>() );
-
- CALL_SUBTEST( check_abs<std::complex<float> >() );
- CALL_SUBTEST( check_abs<std::complex<double> >() );
+template<typename T>
+void check_arg() {
+ typedef typename NumTraits<T>::Real Real;
+ VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
+ VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
+
+ for(int k=0; k<100; ++k)
+ {
+ T x = internal::random<T>();
+ Real y = numext::arg(x);
+ VERIFY_IS_APPROX( y, std::arg(x) );
+ }
+}
+
+template<typename T>
+struct check_sqrt_impl {
+ static void run() {
+ for (int i=0; i<1000; ++i) {
+ const T x = numext::abs(internal::random<T>());
+ const T sqrtx = numext::sqrt(x);
+ VERIFY_IS_APPROX(sqrtx*sqrtx, x);
+ }
+
+ // Corner cases.
+ const T zero = T(0);
+ const T one = T(1);
+ const T inf = std::numeric_limits<T>::infinity();
+ const T nan = std::numeric_limits<T>::quiet_NaN();
+ VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
+ VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
+ VERIFY((numext::isnan)(numext::sqrt(nan)));
+ VERIFY((numext::isnan)(numext::sqrt(-one)));
+ }
+};
+
+template<typename T>
+struct check_sqrt_impl<std::complex<T> > {
+ static void run() {
+ typedef typename std::complex<T> ComplexT;
+
+ for (int i=0; i<1000; ++i) {
+ const ComplexT x = internal::random<ComplexT>();
+ const ComplexT sqrtx = numext::sqrt(x);
+ VERIFY_IS_APPROX(sqrtx*sqrtx, x);
+ }
+
+ // Corner cases.
+ const T zero = T(0);
+ const T one = T(1);
+ const T inf = std::numeric_limits<T>::infinity();
+ const T nan = std::numeric_limits<T>::quiet_NaN();
+
+ // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
+ const int kNumCorners = 20;
+ const ComplexT corners[kNumCorners][2] = {
+ {ComplexT(zero, zero), ComplexT(zero, zero)},
+ {ComplexT(-zero, zero), ComplexT(zero, zero)},
+ {ComplexT(zero, -zero), ComplexT(zero, zero)},
+ {ComplexT(-zero, -zero), ComplexT(zero, zero)},
+ {ComplexT(one, inf), ComplexT(inf, inf)},
+ {ComplexT(nan, inf), ComplexT(inf, inf)},
+ {ComplexT(one, -inf), ComplexT(inf, -inf)},
+ {ComplexT(nan, -inf), ComplexT(inf, -inf)},
+ {ComplexT(-inf, one), ComplexT(zero, inf)},
+ {ComplexT(inf, one), ComplexT(inf, zero)},
+ {ComplexT(-inf, -one), ComplexT(zero, -inf)},
+ {ComplexT(inf, -one), ComplexT(inf, -zero)},
+ {ComplexT(-inf, nan), ComplexT(nan, inf)},
+ {ComplexT(inf, nan), ComplexT(inf, nan)},
+ {ComplexT(zero, nan), ComplexT(nan, nan)},
+ {ComplexT(one, nan), ComplexT(nan, nan)},
+ {ComplexT(nan, zero), ComplexT(nan, nan)},
+ {ComplexT(nan, one), ComplexT(nan, nan)},
+ {ComplexT(nan, -one), ComplexT(nan, nan)},
+ {ComplexT(nan, nan), ComplexT(nan, nan)},
+ };
+
+ for (int i=0; i<kNumCorners; ++i) {
+ const ComplexT& x = corners[i][0];
+ const ComplexT sqrtx = corners[i][1];
+ VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
+ }
+ }
+};
+
+template<typename T>
+void check_sqrt() {
+ check_sqrt_impl<T>::run();
+}
+
+template<typename T>
+struct check_rsqrt_impl {
+ static void run() {
+ const T zero = T(0);
+ const T one = T(1);
+ const T inf = std::numeric_limits<T>::infinity();
+ const T nan = std::numeric_limits<T>::quiet_NaN();
+
+ for (int i=0; i<1000; ++i) {
+ const T x = numext::abs(internal::random<T>());
+ const T rsqrtx = numext::rsqrt(x);
+ const T invx = one / x;
+ VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
+ }
+
+ // Corner cases.
+ VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
+ VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
+ VERIFY((numext::isnan)(numext::rsqrt(nan)));
+ VERIFY((numext::isnan)(numext::rsqrt(-one)));
+ }
+};
+
+template<typename T>
+struct check_rsqrt_impl<std::complex<T> > {
+ static void run() {
+ typedef typename std::complex<T> ComplexT;
+ const T zero = T(0);
+ const T one = T(1);
+ const T inf = std::numeric_limits<T>::infinity();
+ const T nan = std::numeric_limits<T>::quiet_NaN();
+
+ for (int i=0; i<1000; ++i) {
+ const ComplexT x = internal::random<ComplexT>();
+ const ComplexT invx = ComplexT(one, zero) / x;
+ const ComplexT rsqrtx = numext::rsqrt(x);
+ VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
+ }
+
+ // GCC and MSVC differ in their treatment of 1/(0 + 0i)
+ // GCC/clang = (inf, nan)
+ // MSVC = (nan, nan)
+ // and 1 / (x + inf i)
+ // GCC/clang = (0, 0)
+ // MSVC = (nan, nan)
+ #if (EIGEN_COMP_GNUC)
+ {
+ const int kNumCorners = 20;
+ const ComplexT corners[kNumCorners][2] = {
+ // Only consistent across GCC, clang
+ {ComplexT(zero, zero), ComplexT(zero, zero)},
+ {ComplexT(-zero, zero), ComplexT(zero, zero)},
+ {ComplexT(zero, -zero), ComplexT(zero, zero)},
+ {ComplexT(-zero, -zero), ComplexT(zero, zero)},
+ {ComplexT(one, inf), ComplexT(inf, inf)},
+ {ComplexT(nan, inf), ComplexT(inf, inf)},
+ {ComplexT(one, -inf), ComplexT(inf, -inf)},
+ {ComplexT(nan, -inf), ComplexT(inf, -inf)},
+ // Consistent across GCC, clang, MSVC
+ {ComplexT(-inf, one), ComplexT(zero, inf)},
+ {ComplexT(inf, one), ComplexT(inf, zero)},
+ {ComplexT(-inf, -one), ComplexT(zero, -inf)},
+ {ComplexT(inf, -one), ComplexT(inf, -zero)},
+ {ComplexT(-inf, nan), ComplexT(nan, inf)},
+ {ComplexT(inf, nan), ComplexT(inf, nan)},
+ {ComplexT(zero, nan), ComplexT(nan, nan)},
+ {ComplexT(one, nan), ComplexT(nan, nan)},
+ {ComplexT(nan, zero), ComplexT(nan, nan)},
+ {ComplexT(nan, one), ComplexT(nan, nan)},
+ {ComplexT(nan, -one), ComplexT(nan, nan)},
+ {ComplexT(nan, nan), ComplexT(nan, nan)},
+ };
+
+ for (int i=0; i<kNumCorners; ++i) {
+ const ComplexT& x = corners[i][0];
+ const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
+ VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
+ }
+ }
+ #endif
+ }
+};
+
+template<typename T>
+void check_rsqrt() {
+ check_rsqrt_impl<T>::run();
+}
+
+EIGEN_DECLARE_TEST(numext) {
+ for(int k=0; k<g_repeat; ++k)
+ {
+ CALL_SUBTEST( check_abs<bool>() );
+ CALL_SUBTEST( check_abs<signed char>() );
+ CALL_SUBTEST( check_abs<unsigned char>() );
+ CALL_SUBTEST( check_abs<short>() );
+ CALL_SUBTEST( check_abs<unsigned short>() );
+ CALL_SUBTEST( check_abs<int>() );
+ CALL_SUBTEST( check_abs<unsigned int>() );
+ CALL_SUBTEST( check_abs<long>() );
+ CALL_SUBTEST( check_abs<unsigned long>() );
+ CALL_SUBTEST( check_abs<half>() );
+ CALL_SUBTEST( check_abs<bfloat16>() );
+ CALL_SUBTEST( check_abs<float>() );
+ CALL_SUBTEST( check_abs<double>() );
+ CALL_SUBTEST( check_abs<long double>() );
+ CALL_SUBTEST( check_abs<std::complex<float> >() );
+ CALL_SUBTEST( check_abs<std::complex<double> >() );
+
+ CALL_SUBTEST( check_arg<std::complex<float> >() );
+ CALL_SUBTEST( check_arg<std::complex<double> >() );
+
+ CALL_SUBTEST( check_sqrt<float>() );
+ CALL_SUBTEST( check_sqrt<double>() );
+ CALL_SUBTEST( check_sqrt<std::complex<float> >() );
+ CALL_SUBTEST( check_sqrt<std::complex<double> >() );
+
+ CALL_SUBTEST( check_rsqrt<float>() );
+ CALL_SUBTEST( check_rsqrt<double>() );
+ CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
+ CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
+ }
}