aboutsummaryrefslogtreecommitdiff
path: root/unsupported/test/cxx11_tensor_casts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/test/cxx11_tensor_casts.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_casts.cpp83
1 files changed, 77 insertions, 6 deletions
diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp
index 3c6d0d2ff..45456f3ef 100644
--- a/unsupported/test/cxx11_tensor_casts.cpp
+++ b/unsupported/test/cxx11_tensor_casts.cpp
@@ -8,6 +8,7 @@
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
+#include "random_without_cast_overflow.h"
#include <Eigen/CXX11/Tensor>
@@ -104,12 +105,82 @@ static void test_small_to_big_type_cast()
}
}
+template <typename FromType, typename ToType>
+static void test_type_cast() {
+ Tensor<FromType, 2> ftensor(100, 200);
+ // Generate random values for a valid cast.
+ for (int i = 0; i < 100; ++i) {
+ for (int j = 0; j < 200; ++j) {
+ ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value();
+ }
+ }
+
+ Tensor<ToType, 2> ttensor(100, 200);
+ ttensor = ftensor.template cast<ToType>();
+
+ for (int i = 0; i < 100; ++i) {
+ for (int j = 0; j < 200; ++j) {
+ const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j));
+ VERIFY_IS_APPROX(ttensor(i, j), ref);
+ }
+ }
+}
+
+template<typename Scalar, typename EnableIf = void>
+struct test_cast_runner {
+ static void run() {
+ test_type_cast<Scalar, bool>();
+ test_type_cast<Scalar, int8_t>();
+ test_type_cast<Scalar, int16_t>();
+ test_type_cast<Scalar, int32_t>();
+ test_type_cast<Scalar, int64_t>();
+ test_type_cast<Scalar, uint8_t>();
+ test_type_cast<Scalar, uint16_t>();
+ test_type_cast<Scalar, uint32_t>();
+ test_type_cast<Scalar, uint64_t>();
+ test_type_cast<Scalar, half>();
+ test_type_cast<Scalar, bfloat16>();
+ test_type_cast<Scalar, float>();
+ test_type_cast<Scalar, double>();
+ test_type_cast<Scalar, std::complex<float>>();
+ test_type_cast<Scalar, std::complex<double>>();
+ }
+};
+
+// Only certain types allow cast from std::complex<>.
+template<typename Scalar>
+struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
+ static void run() {
+ test_type_cast<Scalar, half>();
+ test_type_cast<Scalar, bfloat16>();
+ test_type_cast<Scalar, std::complex<float>>();
+ test_type_cast<Scalar, std::complex<double>>();
+ }
+};
+
-void test_cxx11_tensor_casts()
+EIGEN_DECLARE_TEST(cxx11_tensor_casts)
{
- CALL_SUBTEST(test_simple_cast());
- CALL_SUBTEST(test_vectorized_cast());
- CALL_SUBTEST(test_float_to_int_cast());
- CALL_SUBTEST(test_big_to_small_type_cast());
- CALL_SUBTEST(test_small_to_big_type_cast());
+ CALL_SUBTEST(test_simple_cast());
+ CALL_SUBTEST(test_vectorized_cast());
+ CALL_SUBTEST(test_float_to_int_cast());
+ CALL_SUBTEST(test_big_to_small_type_cast());
+ CALL_SUBTEST(test_small_to_big_type_cast());
+
+ CALL_SUBTEST(test_cast_runner<bool>::run());
+ CALL_SUBTEST(test_cast_runner<int8_t>::run());
+ CALL_SUBTEST(test_cast_runner<int16_t>::run());
+ CALL_SUBTEST(test_cast_runner<int32_t>::run());
+ CALL_SUBTEST(test_cast_runner<int64_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint8_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint16_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint32_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint64_t>::run());
+ CALL_SUBTEST(test_cast_runner<half>::run());
+ CALL_SUBTEST(test_cast_runner<bfloat16>::run());
+ CALL_SUBTEST(test_cast_runner<float>::run());
+ CALL_SUBTEST(test_cast_runner<double>::run());
+ CALL_SUBTEST(test_cast_runner<std::complex<float>>::run());
+ CALL_SUBTEST(test_cast_runner<std::complex<double>>::run());
+
}