diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_expr.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_expr.cpp | 180 |
1 files changed, 165 insertions, 15 deletions
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index 77e24cb67..169fc1898 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -7,6 +7,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include <numeric> + #include "main.h" #include <Eigen/CXX11/Tensor> @@ -193,26 +195,23 @@ static void test_constants() static void test_boolean() { - Tensor<int, 1> vec(6); - std::copy_n(std::begin({0, 1, 2, 3, 4, 5}), 6, vec.data()); + const int kSize = 31; + Tensor<int, 1> vec(kSize); + std::iota(vec.data(), vec.data() + kSize, 0); // Test ||. Tensor<bool, 1> bool1 = vec < vec.constant(1) || vec > vec.constant(4); - VERIFY_IS_EQUAL(bool1[0], true); - VERIFY_IS_EQUAL(bool1[1], false); - VERIFY_IS_EQUAL(bool1[2], false); - VERIFY_IS_EQUAL(bool1[3], false); - VERIFY_IS_EQUAL(bool1[4], false); - VERIFY_IS_EQUAL(bool1[5], true); + for (int i = 0; i < kSize; ++i) { + bool expected = i < 1 || i > 4; + VERIFY_IS_EQUAL(bool1[i], expected); + } // Test &&, including cast of operand vec. Tensor<bool, 1> bool2 = vec.cast<bool>() && vec < vec.constant(4); - VERIFY_IS_EQUAL(bool2[0], false); - VERIFY_IS_EQUAL(bool2[1], true); - VERIFY_IS_EQUAL(bool2[2], true); - VERIFY_IS_EQUAL(bool2[3], true); - VERIFY_IS_EQUAL(bool2[4], false); - VERIFY_IS_EQUAL(bool2[5], false); + for (int i = 0; i < kSize; ++i) { + bool expected = bool(i) && i < 4; + VERIFY_IS_EQUAL(bool2[i], expected); + } // Compilation tests: // Test Tensor<bool> against results of cast or comparison; verifies that @@ -300,8 +299,152 @@ static void test_select() } } +template <typename Scalar> +void test_minmax_nan_propagation_templ() { + for (int size = 1; size < 17; ++size) { + const Scalar kNaN = std::numeric_limits<Scalar>::quiet_NaN(); + const Scalar kInf = std::numeric_limits<Scalar>::infinity(); + const Scalar kZero(0); + Tensor<Scalar, 1> vec_all_nan(size); + Tensor<Scalar, 1> vec_one_nan(size); + Tensor<Scalar, 1> vec_zero(size); + vec_all_nan.setConstant(kNaN); + vec_zero.setZero(); + vec_one_nan.setZero(); + vec_one_nan(size/2) = kNaN; + + auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) { + for (int i = 0; i < size; ++i) { + VERIFY((numext::isnan)(v(i))); + } + }; + + auto verify_all_zero = [&](const Tensor<Scalar, 1>& v) { + for (int i = 0; i < size; ++i) { + VERIFY_IS_EQUAL(v(i), Scalar(0)); + } + }; + + // Test NaN propagating max. + // max(nan, nan) = nan + // max(nan, 0) = nan + // max(0, nan) = nan + // max(0, 0) = 0 + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kZero)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNaN)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_all_nan)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero)); + + // Test number propagating max. + // max(nan, nan) = nan + // max(nan, 0) = 0 + // max(0, nan) = 0 + // max(0, 0) = 0 + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(kZero)); + verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNaN)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero)); + + // Test NaN propagating min. + // min(nan, nan) = nan + // min(nan, 0) = nan + // min(0, nan) = nan + // min(0, 0) = 0 + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kZero)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNaN)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_all_nan)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero)); + + // Test number propagating min. + // min(nan, nan) = nan + // min(nan, 0) = 0 + // min(0, nan) = 0 + // min(0, 0) = 0 + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(kZero)); + verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNaN)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero)); + + // Test min and max reduction + Tensor<Scalar, 0> val; + val = vec_zero.minimum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum<PropagateNaN>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.maximum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum<PropagateNaN>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kZero); + + // Test NaN propagation for tensor of all NaNs. + val = vec_all_nan.template minimum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kInf); + val = vec_all_nan.template maximum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), -kInf); + + // Test NaN propagation for tensor with a single NaN. + val = vec_one_nan.template minimum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), (size == 1 ? kInf : kZero)); + val = vec_one_nan.template maximum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), (size == 1 ? -kInf : kZero)); + } +} + +static void test_clip() +{ + Tensor<float, 1> vec(6); + vec(0) = 4.0; + vec(1) = 8.0; + vec(2) = 15.0; + vec(3) = 16.0; + vec(4) = 23.0; + vec(5) = 42.0; + + float kMin = 20; + float kMax = 30; + + Tensor<float, 1> vec_clipped(6); + vec_clipped = vec.clip(kMin, kMax); + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(vec_clipped(i), numext::mini(numext::maxi(vec(i), kMin), kMax)); + } +} + +static void test_minmax_nan_propagation() +{ + test_minmax_nan_propagation_templ<float>(); + test_minmax_nan_propagation_templ<double>(); +} -void test_cxx11_tensor_expr() +EIGEN_DECLARE_TEST(cxx11_tensor_expr) { CALL_SUBTEST(test_1d()); CALL_SUBTEST(test_2d()); @@ -311,4 +454,11 @@ void test_cxx11_tensor_expr() CALL_SUBTEST(test_functors()); CALL_SUBTEST(test_type_casting()); CALL_SUBTEST(test_select()); + CALL_SUBTEST(test_clip()); + +// Nan propagation does currently not work like one would expect from std::max/std::min, +// so we disable it for now +#if !EIGEN_ARCH_ARM_OR_ARM64 + CALL_SUBTEST(test_minmax_nan_propagation()); +#endif } |