diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_shuffling.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_shuffling.cpp | 67 |
1 files changed, 61 insertions, 6 deletions
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp index d11444a14..89a64c021 100644 --- a/unsupported/test/cxx11_tensor_shuffling.cpp +++ b/unsupported/test/cxx11_tensor_shuffling.cpp @@ -81,12 +81,12 @@ static void test_expr_shuffling() Tensor<float, 4, DataLayout> expected; expected = tensor.shuffle(shuffles); - Tensor<float, 4, DataLayout> result(5,7,3,2); + Tensor<float, 4, DataLayout> result(5, 7, 3, 2); - array<int, 4> src_slice_dim{{2,3,1,7}}; - array<int, 4> src_slice_start{{0,0,0,0}}; - array<int, 4> dst_slice_dim{{1,7,3,2}}; - array<int, 4> dst_slice_start{{0,0,0,0}}; + array<ptrdiff_t, 4> src_slice_dim{{2, 3, 1, 7}}; + array<ptrdiff_t, 4> src_slice_start{{0, 0, 0, 0}}; + array<ptrdiff_t, 4> dst_slice_dim{{1, 7, 3, 2}}; + array<ptrdiff_t, 4> dst_slice_start{{0, 0, 0, 0}}; for (int i = 0; i < 5; ++i) { result.slice(dst_slice_start, dst_slice_dim) = @@ -215,7 +215,60 @@ static void test_shuffle_unshuffle() } -void test_cxx11_tensor_shuffling() +template <int DataLayout> +static void test_empty_shuffling() +{ + Tensor<float, 4, DataLayout> tensor(2,3,0,7); + tensor.setRandom(); + array<ptrdiff_t, 4> shuffles; + shuffles[0] = 0; + shuffles[1] = 1; + shuffles[2] = 2; + shuffles[3] = 3; + + Tensor<float, 4, DataLayout> no_shuffle; + no_shuffle = tensor.shuffle(shuffles); + + VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2); + VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3); + VERIFY_IS_EQUAL(no_shuffle.dimension(2), 0); + VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 0; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l)); + } + } + } + } + + shuffles[0] = 2; + shuffles[1] = 3; + shuffles[2] = 1; + shuffles[3] = 0; + Tensor<float, 4, DataLayout> shuffle; + shuffle = tensor.shuffle(shuffles); + + VERIFY_IS_EQUAL(shuffle.dimension(0), 0); + VERIFY_IS_EQUAL(shuffle.dimension(1), 7); + VERIFY_IS_EQUAL(shuffle.dimension(2), 3); + VERIFY_IS_EQUAL(shuffle.dimension(3), 2); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 0; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i)); + } + } + } + } +} + + +EIGEN_DECLARE_TEST(cxx11_tensor_shuffling) { CALL_SUBTEST(test_simple_shuffling<ColMajor>()); CALL_SUBTEST(test_simple_shuffling<RowMajor>()); @@ -225,4 +278,6 @@ void test_cxx11_tensor_shuffling() CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>()); CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>()); + CALL_SUBTEST(test_empty_shuffling<ColMajor>()); + CALL_SUBTEST(test_empty_shuffling<RowMajor>()); } |