diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_map.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_map.cpp | 68 |
1 files changed, 59 insertions, 9 deletions
diff --git a/unsupported/test/cxx11_tensor_map.cpp b/unsupported/test/cxx11_tensor_map.cpp index 3db0ee7c0..4d4f68911 100644 --- a/unsupported/test/cxx11_tensor_map.cpp +++ b/unsupported/test/cxx11_tensor_map.cpp @@ -19,8 +19,8 @@ static void test_0d() Tensor<int, 0> scalar1; Tensor<int, 0, RowMajor> scalar2; - TensorMap<Tensor<const int, 0> > scalar3(scalar1.data()); - TensorMap<Tensor<const int, 0, RowMajor> > scalar4(scalar2.data()); + TensorMap<const Tensor<int, 0> > scalar3(scalar1.data()); + TensorMap<const Tensor<int, 0, RowMajor> > scalar4(scalar2.data()); scalar1() = 7; scalar2() = 13; @@ -37,8 +37,8 @@ static void test_1d() Tensor<int, 1> vec1(6); Tensor<int, 1, RowMajor> vec2(6); - TensorMap<Tensor<const int, 1> > vec3(vec1.data(), 6); - TensorMap<Tensor<const int, 1, RowMajor> > vec4(vec2.data(), 6); + TensorMap<const Tensor<int, 1> > vec3(vec1.data(), 6); + TensorMap<const Tensor<int, 1, RowMajor> > vec4(vec2.data(), 6); vec1(0) = 4; vec2(0) = 0; vec1(1) = 8; vec2(1) = 1; @@ -85,8 +85,8 @@ static void test_2d() mat2(1,1) = 4; mat2(1,2) = 5; - TensorMap<Tensor<const int, 2> > mat3(mat1.data(), 2, 3); - TensorMap<Tensor<const int, 2, RowMajor> > mat4(mat2.data(), 2, 3); + TensorMap<const Tensor<int, 2> > mat3(mat1.data(), 2, 3); + TensorMap<const Tensor<int, 2, RowMajor> > mat4(mat2.data(), 2, 3); VERIFY_IS_EQUAL(mat3.rank(), 2); VERIFY_IS_EQUAL(mat3.size(), 6); @@ -129,8 +129,8 @@ static void test_3d() } } - TensorMap<Tensor<const int, 3> > mat3(mat1.data(), 2, 3, 7); - TensorMap<Tensor<const int, 3, RowMajor> > mat4(mat2.data(), 2, 3, 7); + TensorMap<const Tensor<int, 3> > mat3(mat1.data(), 2, 3, 7); + TensorMap<const Tensor<int, 3, RowMajor> > mat4(mat2.data(), 2, 3, 7); VERIFY_IS_EQUAL(mat3.rank(), 3); VERIFY_IS_EQUAL(mat3.size(), 2*3*7); @@ -265,7 +265,54 @@ static void test_casting() VERIFY_IS_EQUAL(sum1, 861); } -void test_cxx11_tensor_map() +template<typename T> +static const T& add_const(T& value) { + return value; +} + +static void test_0d_const_tensor() +{ + Tensor<int, 0> scalar1; + Tensor<int, 0, RowMajor> scalar2; + + TensorMap<const Tensor<int, 0> > scalar3(add_const(scalar1).data()); + TensorMap<const Tensor<int, 0, RowMajor> > scalar4(add_const(scalar2).data()); + + scalar1() = 7; + scalar2() = 13; + + VERIFY_IS_EQUAL(scalar1.rank(), 0); + VERIFY_IS_EQUAL(scalar1.size(), 1); + + VERIFY_IS_EQUAL(scalar3(), 7); + VERIFY_IS_EQUAL(scalar4(), 13); +} + +static void test_0d_const_tensor_map() +{ + Tensor<int, 0> scalar1; + Tensor<int, 0, RowMajor> scalar2; + + const TensorMap<Tensor<int, 0> > scalar3(scalar1.data()); + const TensorMap<Tensor<int, 0, RowMajor> > scalar4(scalar2.data()); + + // Although TensorMap is constant, we still can write to the underlying + // storage, because we map over non-constant Tensor. + scalar3() = 7; + scalar4() = 13; + + VERIFY_IS_EQUAL(scalar1(), 7); + VERIFY_IS_EQUAL(scalar2(), 13); + + // Pointer to the underlying storage is also non-const. + scalar3.data()[0] = 8; + scalar4.data()[0] = 14; + + VERIFY_IS_EQUAL(scalar1(), 8); + VERIFY_IS_EQUAL(scalar2(), 14); +} + +EIGEN_DECLARE_TEST(cxx11_tensor_map) { CALL_SUBTEST(test_0d()); CALL_SUBTEST(test_1d()); @@ -274,4 +321,7 @@ void test_cxx11_tensor_map() CALL_SUBTEST(test_from_tensor()); CALL_SUBTEST(test_casting()); + + CALL_SUBTEST(test_0d_const_tensor()); + CALL_SUBTEST(test_0d_const_tensor_map()); } |