diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_image_patch.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_image_patch.cpp | 54 |
1 files changed, 53 insertions, 1 deletions
diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp index 475c59651..862f1f7f0 100644 --- a/unsupported/test/cxx11_tensor_image_patch.cpp +++ b/unsupported/test/cxx11_tensor_image_patch.cpp @@ -405,6 +405,57 @@ void test_patch_padding_same() } } +// Verifies that SAME padding, when computed as negative values, will be clipped +// to zero. +void test_patch_padding_same_negative_padding_clip_to_zero() { + int input_depth = 1; + int input_rows = 15; + int input_cols = 1; + int input_batches = 1; + int ksize = 1; // Corresponds to the Rows and Cols for + // tensor.extract_image_patches<>. + int row_stride = 5; + int col_stride = 1; + // ColMajor + Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches); + // Initializes tensor with incrementing numbers. + for (int i = 0; i < tensor.size(); ++i) { + tensor.data()[i] = i + 1; + } + Tensor<float, 5> result = tensor.extract_image_patches( + ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME); + // row padding will be computed as -2 originally and then be clipped to 0. + VERIFY_IS_EQUAL(result.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), input_depth); // depth + VERIFY_IS_EQUAL(result.dimension(1), ksize); // kernel rows + VERIFY_IS_EQUAL(result.dimension(2), ksize); // kernel cols + VERIFY_IS_EQUAL(result.dimension(3), 3); // number of patches + VERIFY_IS_EQUAL(result.dimension(4), input_batches); // number of batches + + // RowMajor + Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout(); + VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3)); + VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2)); + VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1)); + VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0)); + + Tensor<float, 5, RowMajor> result_row_major = + tensor_row_major.extract_image_patches(ksize, ksize, row_stride, + col_stride, 1, 1, PADDING_SAME); + VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4)); + VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3)); + VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2)); + VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1)); + VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0)); +} + void test_patch_no_extra_dim() { Tensor<float, 3> tensor(2,3,5); @@ -746,7 +797,7 @@ void test_imagenet_patches() } } -void test_cxx11_tensor_image_patch() +EIGEN_DECLARE_TEST(cxx11_tensor_image_patch) { CALL_SUBTEST_1(test_simple_patch()); CALL_SUBTEST_2(test_patch_no_extra_dim()); @@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch() CALL_SUBTEST_4(test_patch_padding_valid_same_value()); CALL_SUBTEST_5(test_patch_padding_same()); CALL_SUBTEST_6(test_imagenet_patches()); + CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero()); } |