diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_broadcasting.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_broadcasting.cpp | 141 |
1 files changed, 139 insertions, 2 deletions
diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp index 5c0ea5889..d3dab891f 100644 --- a/unsupported/test/cxx11_tensor_broadcasting.cpp +++ b/unsupported/test/cxx11_tensor_broadcasting.cpp @@ -91,7 +91,16 @@ static void test_vectorized_broadcasting() } } +#if EIGEN_HAS_VARIADIC_TEMPLATES tensor.resize(11,3,5); +#else + array<Index, 3> new_dims; + new_dims[0] = 11; + new_dims[1] = 3; + new_dims[2] = 5; + tensor.resize(new_dims); +#endif + tensor.setRandom(); broadcast = tensor.broadcast(broadcasts); @@ -115,7 +124,7 @@ static void test_static_broadcasting() Tensor<float, 3, DataLayout> tensor(8,3,5); tensor.setRandom(); -#if EIGEN_HAS_CONSTEXPR +#if defined(EIGEN_HAS_INDEX_LIST) Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts; #else Eigen::array<int, 3> broadcasts; @@ -139,7 +148,16 @@ static void test_static_broadcasting() } } +#if EIGEN_HAS_VARIADIC_TEMPLATES tensor.resize(11,3,5); +#else + array<Index, 3> new_dims; + new_dims[0] = 11; + new_dims[1] = 3; + new_dims[2] = 5; + tensor.resize(new_dims); +#endif + tensor.setRandom(); broadcast = tensor.broadcast(broadcasts); @@ -180,8 +198,119 @@ static void test_fixed_size_broadcasting() #endif } +template <int DataLayout> +static void test_simple_broadcasting_one_by_n() +{ + Tensor<float, 4, DataLayout> tensor(1,13,5,7); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 9; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 1; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 9); + VERIFY_IS_EQUAL(broadcast.dimension(1), 13); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 7); + + for (int i = 0; i < 9; ++i) { + for (int j = 0; j < 13; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l)); + } + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_n_by_one() +{ + Tensor<float, 4, DataLayout> tensor(7,3,5,1); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 1; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 7); + VERIFY_IS_EQUAL(broadcast.dimension(1), 3); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l)); + } + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_one_by_n_by_one_1d() +{ + Tensor<float, 3, DataLayout> tensor(1,7,1); + tensor.setRandom(); + array<ptrdiff_t, 3> broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 13; + Tensor<float, 3, DataLayout> broadcasted; + broadcasted = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcasted.dimension(0), 5); + VERIFY_IS_EQUAL(broadcasted.dimension(1), 7); + VERIFY_IS_EQUAL(broadcasted.dimension(2), 13); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k)); + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_one_by_n_by_one_2d() +{ + Tensor<float, 4, DataLayout> tensor(1,7,13,1); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 5); + VERIFY_IS_EQUAL(broadcast.dimension(1), 7); + VERIFY_IS_EQUAL(broadcast.dimension(2), 13); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l)); + } + } + } + } +} -void test_cxx11_tensor_broadcasting() +EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting) { CALL_SUBTEST(test_simple_broadcasting<ColMajor>()); CALL_SUBTEST(test_simple_broadcasting<RowMajor>()); @@ -191,4 +320,12 @@ void test_cxx11_tensor_broadcasting() CALL_SUBTEST(test_static_broadcasting<RowMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>()); } |