diff options
Diffstat (limited to 'test/numext.cpp')
-rw-r--r-- | test/numext.cpp | 262 |
1 files changed, 242 insertions, 20 deletions
diff --git a/test/numext.cpp b/test/numext.cpp index 3de33e2f9..8a2fde501 100644 --- a/test/numext.cpp +++ b/test/numext.cpp @@ -9,16 +9,44 @@ #include "main.h" +template<typename T, typename U> +bool check_if_equal_or_nans(const T& actual, const U& expected) { + return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected))); +} + +template<typename T, typename U> +bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) { + return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) + && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected)); +} + +template<typename T, typename U> +bool test_is_equal_or_nans(const T& actual, const U& expected) +{ + if (check_if_equal_or_nans(actual, expected)) { + return true; + } + + // false: + std::cerr + << "\n actual = " << actual + << "\n expected = " << expected << "\n\n"; + return false; +} + +#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b)) + template<typename T> void check_abs() { typedef typename NumTraits<T>::Real Real; + Real zero(0); if(NumTraits<T>::IsSigned) VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1)); VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); - for(int k=0; k<g_repeat*100; ++k) + for(int k=0; k<100; ++k) { T x = internal::random<T>(); if(!internal::is_same<T,bool>::value) @@ -26,28 +54,222 @@ void check_abs() { if(NumTraits<T>::IsSigned) { VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x)); - VERIFY( numext::abs(-x) >= Real(0)); + VERIFY( numext::abs(-x) >= zero ); } - VERIFY( numext::abs(x) >= Real(0)); + VERIFY( numext::abs(x) >= zero ); VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) ); } } -void test_numext() { - CALL_SUBTEST( check_abs<bool>() ); - CALL_SUBTEST( check_abs<signed char>() ); - CALL_SUBTEST( check_abs<unsigned char>() ); - CALL_SUBTEST( check_abs<short>() ); - CALL_SUBTEST( check_abs<unsigned short>() ); - CALL_SUBTEST( check_abs<int>() ); - CALL_SUBTEST( check_abs<unsigned int>() ); - CALL_SUBTEST( check_abs<long>() ); - CALL_SUBTEST( check_abs<unsigned long>() ); - CALL_SUBTEST( check_abs<half>() ); - CALL_SUBTEST( check_abs<float>() ); - CALL_SUBTEST( check_abs<double>() ); - CALL_SUBTEST( check_abs<long double>() ); - - CALL_SUBTEST( check_abs<std::complex<float> >() ); - CALL_SUBTEST( check_abs<std::complex<double> >() ); +template<typename T> +void check_arg() { + typedef typename NumTraits<T>::Real Real; + VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); + VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); + + for(int k=0; k<100; ++k) + { + T x = internal::random<T>(); + Real y = numext::arg(x); + VERIFY_IS_APPROX( y, std::arg(x) ); + } +} + +template<typename T> +struct check_sqrt_impl { + static void run() { + for (int i=0; i<1000; ++i) { + const T x = numext::abs(internal::random<T>()); + const T sqrtx = numext::sqrt(x); + VERIFY_IS_APPROX(sqrtx*sqrtx, x); + } + + // Corner cases. + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits<T>::infinity(); + const T nan = std::numeric_limits<T>::quiet_NaN(); + VERIFY_IS_EQUAL(numext::sqrt(zero), zero); + VERIFY_IS_EQUAL(numext::sqrt(inf), inf); + VERIFY((numext::isnan)(numext::sqrt(nan))); + VERIFY((numext::isnan)(numext::sqrt(-one))); + } +}; + +template<typename T> +struct check_sqrt_impl<std::complex<T> > { + static void run() { + typedef typename std::complex<T> ComplexT; + + for (int i=0; i<1000; ++i) { + const ComplexT x = internal::random<ComplexT>(); + const ComplexT sqrtx = numext::sqrt(x); + VERIFY_IS_APPROX(sqrtx*sqrtx, x); + } + + // Corner cases. + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits<T>::infinity(); + const T nan = std::numeric_limits<T>::quiet_NaN(); + + // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt + const int kNumCorners = 20; + const ComplexT corners[kNumCorners][2] = { + {ComplexT(zero, zero), ComplexT(zero, zero)}, + {ComplexT(-zero, zero), ComplexT(zero, zero)}, + {ComplexT(zero, -zero), ComplexT(zero, zero)}, + {ComplexT(-zero, -zero), ComplexT(zero, zero)}, + {ComplexT(one, inf), ComplexT(inf, inf)}, + {ComplexT(nan, inf), ComplexT(inf, inf)}, + {ComplexT(one, -inf), ComplexT(inf, -inf)}, + {ComplexT(nan, -inf), ComplexT(inf, -inf)}, + {ComplexT(-inf, one), ComplexT(zero, inf)}, + {ComplexT(inf, one), ComplexT(inf, zero)}, + {ComplexT(-inf, -one), ComplexT(zero, -inf)}, + {ComplexT(inf, -one), ComplexT(inf, -zero)}, + {ComplexT(-inf, nan), ComplexT(nan, inf)}, + {ComplexT(inf, nan), ComplexT(inf, nan)}, + {ComplexT(zero, nan), ComplexT(nan, nan)}, + {ComplexT(one, nan), ComplexT(nan, nan)}, + {ComplexT(nan, zero), ComplexT(nan, nan)}, + {ComplexT(nan, one), ComplexT(nan, nan)}, + {ComplexT(nan, -one), ComplexT(nan, nan)}, + {ComplexT(nan, nan), ComplexT(nan, nan)}, + }; + + for (int i=0; i<kNumCorners; ++i) { + const ComplexT& x = corners[i][0]; + const ComplexT sqrtx = corners[i][1]; + VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx); + } + } +}; + +template<typename T> +void check_sqrt() { + check_sqrt_impl<T>::run(); +} + +template<typename T> +struct check_rsqrt_impl { + static void run() { + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits<T>::infinity(); + const T nan = std::numeric_limits<T>::quiet_NaN(); + + for (int i=0; i<1000; ++i) { + const T x = numext::abs(internal::random<T>()); + const T rsqrtx = numext::rsqrt(x); + const T invx = one / x; + VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); + } + + // Corner cases. + VERIFY_IS_EQUAL(numext::rsqrt(zero), inf); + VERIFY_IS_EQUAL(numext::rsqrt(inf), zero); + VERIFY((numext::isnan)(numext::rsqrt(nan))); + VERIFY((numext::isnan)(numext::rsqrt(-one))); + } +}; + +template<typename T> +struct check_rsqrt_impl<std::complex<T> > { + static void run() { + typedef typename std::complex<T> ComplexT; + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits<T>::infinity(); + const T nan = std::numeric_limits<T>::quiet_NaN(); + + for (int i=0; i<1000; ++i) { + const ComplexT x = internal::random<ComplexT>(); + const ComplexT invx = ComplexT(one, zero) / x; + const ComplexT rsqrtx = numext::rsqrt(x); + VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); + } + + // GCC and MSVC differ in their treatment of 1/(0 + 0i) + // GCC/clang = (inf, nan) + // MSVC = (nan, nan) + // and 1 / (x + inf i) + // GCC/clang = (0, 0) + // MSVC = (nan, nan) + #if (EIGEN_COMP_GNUC) + { + const int kNumCorners = 20; + const ComplexT corners[kNumCorners][2] = { + // Only consistent across GCC, clang + {ComplexT(zero, zero), ComplexT(zero, zero)}, + {ComplexT(-zero, zero), ComplexT(zero, zero)}, + {ComplexT(zero, -zero), ComplexT(zero, zero)}, + {ComplexT(-zero, -zero), ComplexT(zero, zero)}, + {ComplexT(one, inf), ComplexT(inf, inf)}, + {ComplexT(nan, inf), ComplexT(inf, inf)}, + {ComplexT(one, -inf), ComplexT(inf, -inf)}, + {ComplexT(nan, -inf), ComplexT(inf, -inf)}, + // Consistent across GCC, clang, MSVC + {ComplexT(-inf, one), ComplexT(zero, inf)}, + {ComplexT(inf, one), ComplexT(inf, zero)}, + {ComplexT(-inf, -one), ComplexT(zero, -inf)}, + {ComplexT(inf, -one), ComplexT(inf, -zero)}, + {ComplexT(-inf, nan), ComplexT(nan, inf)}, + {ComplexT(inf, nan), ComplexT(inf, nan)}, + {ComplexT(zero, nan), ComplexT(nan, nan)}, + {ComplexT(one, nan), ComplexT(nan, nan)}, + {ComplexT(nan, zero), ComplexT(nan, nan)}, + {ComplexT(nan, one), ComplexT(nan, nan)}, + {ComplexT(nan, -one), ComplexT(nan, nan)}, + {ComplexT(nan, nan), ComplexT(nan, nan)}, + }; + + for (int i=0; i<kNumCorners; ++i) { + const ComplexT& x = corners[i][0]; + const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1]; + VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx); + } + } + #endif + } +}; + +template<typename T> +void check_rsqrt() { + check_rsqrt_impl<T>::run(); +} + +EIGEN_DECLARE_TEST(numext) { + for(int k=0; k<g_repeat; ++k) + { + CALL_SUBTEST( check_abs<bool>() ); + CALL_SUBTEST( check_abs<signed char>() ); + CALL_SUBTEST( check_abs<unsigned char>() ); + CALL_SUBTEST( check_abs<short>() ); + CALL_SUBTEST( check_abs<unsigned short>() ); + CALL_SUBTEST( check_abs<int>() ); + CALL_SUBTEST( check_abs<unsigned int>() ); + CALL_SUBTEST( check_abs<long>() ); + CALL_SUBTEST( check_abs<unsigned long>() ); + CALL_SUBTEST( check_abs<half>() ); + CALL_SUBTEST( check_abs<bfloat16>() ); + CALL_SUBTEST( check_abs<float>() ); + CALL_SUBTEST( check_abs<double>() ); + CALL_SUBTEST( check_abs<long double>() ); + CALL_SUBTEST( check_abs<std::complex<float> >() ); + CALL_SUBTEST( check_abs<std::complex<double> >() ); + + CALL_SUBTEST( check_arg<std::complex<float> >() ); + CALL_SUBTEST( check_arg<std::complex<double> >() ); + + CALL_SUBTEST( check_sqrt<float>() ); + CALL_SUBTEST( check_sqrt<double>() ); + CALL_SUBTEST( check_sqrt<std::complex<float> >() ); + CALL_SUBTEST( check_sqrt<std::complex<double> >() ); + + CALL_SUBTEST( check_rsqrt<float>() ); + CALL_SUBTEST( check_rsqrt<double>() ); + CALL_SUBTEST( check_rsqrt<std::complex<float> >() ); + CALL_SUBTEST( check_rsqrt<std::complex<double> >() ); + } } |