diff options
author | Yi Kong <yikong@google.com> | 2022-02-25 16:32:14 +0800 |
---|---|---|
committer | Yi Kong <yikong@google.com> | 2022-02-25 15:08:55 +0000 |
commit | 2aab794c004027d008d6b0b64165bf1961d5d2bb (patch) | |
tree | 83bb8f19c67bcafdb2ca4a98414af1b17392ec36 /unsupported/test/cxx11_tensor_broadcasting.cpp | |
parent | ca5aa72016f062fd0712bcb86370478de332bca3 (diff) | |
download | eigen-2aab794c004027d008d6b0b64165bf1961d5d2bb.tar.gz |
Upgrade eigen to 3.4.0
Steps:
* Removed common files between Android copy and the matching upstream copy
* Obtained latest upstream tarball (see README.version)
* Extracted over the directory
Bug: 148287349
Test: presubmit
Change-Id: Iee2744719075fdf000b315e973645923da766111
Diffstat (limited to 'unsupported/test/cxx11_tensor_broadcasting.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_broadcasting.cpp | 141 |
1 files changed, 139 insertions, 2 deletions
diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp index 5c0ea5889..d3dab891f 100644 --- a/unsupported/test/cxx11_tensor_broadcasting.cpp +++ b/unsupported/test/cxx11_tensor_broadcasting.cpp @@ -91,7 +91,16 @@ static void test_vectorized_broadcasting() } } +#if EIGEN_HAS_VARIADIC_TEMPLATES tensor.resize(11,3,5); +#else + array<Index, 3> new_dims; + new_dims[0] = 11; + new_dims[1] = 3; + new_dims[2] = 5; + tensor.resize(new_dims); +#endif + tensor.setRandom(); broadcast = tensor.broadcast(broadcasts); @@ -115,7 +124,7 @@ static void test_static_broadcasting() Tensor<float, 3, DataLayout> tensor(8,3,5); tensor.setRandom(); -#if EIGEN_HAS_CONSTEXPR +#if defined(EIGEN_HAS_INDEX_LIST) Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts; #else Eigen::array<int, 3> broadcasts; @@ -139,7 +148,16 @@ static void test_static_broadcasting() } } +#if EIGEN_HAS_VARIADIC_TEMPLATES tensor.resize(11,3,5); +#else + array<Index, 3> new_dims; + new_dims[0] = 11; + new_dims[1] = 3; + new_dims[2] = 5; + tensor.resize(new_dims); +#endif + tensor.setRandom(); broadcast = tensor.broadcast(broadcasts); @@ -180,8 +198,119 @@ static void test_fixed_size_broadcasting() #endif } +template <int DataLayout> +static void test_simple_broadcasting_one_by_n() +{ + Tensor<float, 4, DataLayout> tensor(1,13,5,7); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 9; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 1; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 9); + VERIFY_IS_EQUAL(broadcast.dimension(1), 13); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 7); + + for (int i = 0; i < 9; ++i) { + for (int j = 0; j < 13; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l)); + } + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_n_by_one() +{ + Tensor<float, 4, DataLayout> tensor(7,3,5,1); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 1; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 7); + VERIFY_IS_EQUAL(broadcast.dimension(1), 3); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l)); + } + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_one_by_n_by_one_1d() +{ + Tensor<float, 3, DataLayout> tensor(1,7,1); + tensor.setRandom(); + array<ptrdiff_t, 3> broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 13; + Tensor<float, 3, DataLayout> broadcasted; + broadcasted = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcasted.dimension(0), 5); + VERIFY_IS_EQUAL(broadcasted.dimension(1), 7); + VERIFY_IS_EQUAL(broadcasted.dimension(2), 13); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k)); + } + } + } +} + +template <int DataLayout> +static void test_simple_broadcasting_one_by_n_by_one_2d() +{ + Tensor<float, 4, DataLayout> tensor(1,7,13,1); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 5); + VERIFY_IS_EQUAL(broadcast.dimension(1), 7); + VERIFY_IS_EQUAL(broadcast.dimension(2), 13); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l)); + } + } + } + } +} -void test_cxx11_tensor_broadcasting() +EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting) { CALL_SUBTEST(test_simple_broadcasting<ColMajor>()); CALL_SUBTEST(test_simple_broadcasting<RowMajor>()); @@ -191,4 +320,12 @@ void test_cxx11_tensor_broadcasting() CALL_SUBTEST(test_static_broadcasting<RowMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>()); } |