diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_concatenation.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_concatenation.cpp | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_concatenation.cpp b/unsupported/test/cxx11_tensor_concatenation.cpp new file mode 100644 index 000000000..03ef12e63 --- /dev/null +++ b/unsupported/test/cxx11_tensor_concatenation.cpp @@ -0,0 +1,137 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include <Eigen/CXX11/Tensor> + +using Eigen::Tensor; + +template<int DataLayout> +static void test_dimension_failures() +{ + Tensor<int, 3, DataLayout> left(2, 3, 1); + Tensor<int, 3, DataLayout> right(3, 3, 1); + left.setRandom(); + right.setRandom(); + + // Okay; other dimensions are equal. + Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); + + // Dimension mismatches. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); + + // Axis > NumDims or < 0. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); +} + +template<int DataLayout> +static void test_static_dimension_failure() +{ + Tensor<int, 2, DataLayout> left(2, 3); + Tensor<int, 3, DataLayout> right(2, 3, 1); + +#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE + // Technically compatible, but we static assert that the inputs have same + // NumDims. + Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); +#endif + + // This can be worked around in this case. + Tensor<int, 3, DataLayout> concatenation = left + .reshape(Tensor<int, 3>::Dimensions(2, 3, 1)) + .concatenate(right, 0); + Tensor<int, 2, DataLayout> alternative = left + .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0); +} + +template<int DataLayout> +static void test_simple_concatenation() +{ + Tensor<int, 3, DataLayout> left(2, 3, 1); + Tensor<int, 3, DataLayout> right(2, 3, 1); + left.setRandom(); + right.setRandom(); + + Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); + VERIFY_IS_EQUAL(concatenation.dimension(0), 4); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int j = 0; j < 3; ++j) { + for (int i = 0; i < 2; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int i = 2; i < 4; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); + } + } + + concatenation = left.concatenate(right, 1); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 6); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int j = 3; j < 6; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); + } + } + + concatenation = left.concatenate(right, 2); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 2); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); + } + } +} + + +// TODO(phli): Add test once we have a real vectorized implementation. +// static void test_vectorized_concatenation() {} + +static void test_concatenation_as_lvalue() +{ + Tensor<int, 2> t1(2, 3); + Tensor<int, 2> t2(2, 3); + t1.setRandom(); + t2.setRandom(); + + Tensor<int, 2> result(4, 3); + result.setRandom(); + t1.concatenate(t2, 0) = result; + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(t1(i, j), result(i, j)); + VERIFY_IS_EQUAL(t2(i, j), result(i+2, j)); + } + } +} + + +void test_cxx11_tensor_concatenation() +{ + CALL_SUBTEST(test_dimension_failures<ColMajor>()); + CALL_SUBTEST(test_dimension_failures<RowMajor>()); + CALL_SUBTEST(test_static_dimension_failure<ColMajor>()); + CALL_SUBTEST(test_static_dimension_failure<RowMajor>()); + CALL_SUBTEST(test_simple_concatenation<ColMajor>()); + CALL_SUBTEST(test_simple_concatenation<RowMajor>()); + // CALL_SUBTEST(test_vectorized_concatenation()); + CALL_SUBTEST(test_concatenation_as_lvalue()); + +} |