summaryrefslogtreecommitdiff
path: root/nn/common
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common')
-rw-r--r--nn/common/OperationsUtils.cpp18
-rw-r--r--nn/common/operations/StridedSlice.cpp5
2 files changed, 18 insertions, 5 deletions
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp
index e5031473e..e8dd3e2ba 100644
--- a/nn/common/OperationsUtils.cpp
+++ b/nn/common/OperationsUtils.cpp
@@ -658,6 +658,10 @@ bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisS
outDims[idx - numSkipAxis] = getSizeOfDimension(input, idx);
}
}
+ // Handle the case when all dimensions are removed
+ if (outDims.empty()) {
+ outDims.push_back(1);
+ }
output->dimensions = outDims;
}
@@ -675,11 +679,15 @@ bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) {
// Copy the input dimensions, omitting the axis dimension.
output->dimensions.clear();
- output->dimensions.reserve(getNumberOfDimensions(input) - 1);
- output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(),
- input.dimensions.begin() + axis);
- output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1,
- input.dimensions.end());
+ if (getNumberOfDimensions(input) > 1) {
+ output->dimensions.reserve(getNumberOfDimensions(input) - 1);
+ output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(),
+ input.dimensions.begin() + axis);
+ output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1,
+ input.dimensions.end());
+ } else {
+ output->dimensions.push_back(1);
+ }
return true;
}
diff --git a/nn/common/operations/StridedSlice.cpp b/nn/common/operations/StridedSlice.cpp
index cd972d3de..5ff5aeca8 100644
--- a/nn/common/operations/StridedSlice.cpp
+++ b/nn/common/operations/StridedSlice.cpp
@@ -191,6 +191,11 @@ bool prepare(IOperationExecutionContext* context) {
}
}
+ // Handle the case when all dimensions are removed
+ if (outDims.empty()) {
+ outDims.push_back(1);
+ }
+
Shape outputShape = context->getOutputShape(kOutputTensor);
NN_RET_CHECK(SetShape(inputShape, &outputShape));
outputShape.dimensions = outDims;