summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTreeHugger Robot <treehugger-gerrit@google.com>2017-10-10 23:41:35 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2017-10-10 23:41:35 +0000
commit07e922eea3f97a7154e967c4c2b17189b519fe77 (patch)
treeeff63fe50cb33f48adc4c772e31964aebbbfe1cf
parent56511a7eb83287dfbbcae7fe1d38be8aa43dc166 (diff)
parente38ec877d1f3a47d889567ffc391236fceb1be30 (diff)
downloadml-07e922eea3f97a7154e967c4c2b17189b519fe77.tar.gz
Merge "Fix the helper function converting explict padding to implicit padding" into oc-mr1-dev
-rw-r--r--nn/common/include/OperationsUtils.h44
1 files changed, 24 insertions, 20 deletions
diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h
index 80efacdfd..aaca0c083 100644
--- a/nn/common/include/OperationsUtils.h
+++ b/nn/common/include/OperationsUtils.h
@@ -45,26 +45,6 @@ enum PaddingScheme {
kPaddingValid = 2,
};
-inline PaddingScheme getPaddingScheme(uint32_t filterWidth, uint32_t filterHeight,
- uint32_t paddingLeft, uint32_t paddingRight,
- uint32_t paddingTop, uint32_t paddingBottom) {
- if (paddingLeft > paddingRight || paddingTop > paddingBottom) {
- return kPaddingUnknown;
- }
-
- uint32_t totolPaddingWidth = paddingLeft + paddingRight;
- uint32_t totolPaddingHeight = paddingTop + paddingBottom;
- if (totolPaddingWidth == filterWidth - 1 &&
- totolPaddingHeight == filterHeight -1) {
- return kPaddingSame;
- } else if (totolPaddingWidth == 0 &&
- totolPaddingHeight == 0) {
- return kPaddingValid;
- } else {
- return kPaddingUnknown;
- }
-}
-
// The type and dimensions of an operand.
struct Shape {
OperandType type;
@@ -132,6 +112,30 @@ inline void calculateExplicitPadding(int32_t in_size, int32_t stride,
}
}
+inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
+ int32_t strideWidth, int32_t strideHeight,
+ int32_t filterWidth, int32_t filterHeight,
+ int32_t paddingLeft, int32_t paddingRight,
+ int32_t paddingTop, int32_t paddingBottom) {
+ if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) {
+ return kPaddingValid;
+ }
+
+ int32_t expectedPaddingLeft, expectedPaddingRight;
+ int32_t expectedPaddingTop, expectedPaddingBottom;
+
+ calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame,
+ &expectedPaddingLeft, &expectedPaddingRight);
+ calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame,
+ &expectedPaddingTop, &expectedPaddingBottom);
+ if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight &&
+ expectedPaddingTop == paddingTop && expectedPaddingBottom == paddingBottom) {
+ return kPaddingSame;
+ } else {
+ return kPaddingUnknown;
+ }
+}
+
// Preparation functions for the corresponding ops
bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out1);