diff options
Diffstat (limited to 'nn/common')
-rw-r--r-- | nn/common/OperationsUtils.cpp | 18 | ||||
-rw-r--r-- | nn/common/operations/StridedSlice.cpp | 5 |
2 files changed, 18 insertions, 5 deletions
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp index e5031473e..e8dd3e2ba 100644 --- a/nn/common/OperationsUtils.cpp +++ b/nn/common/OperationsUtils.cpp @@ -658,6 +658,10 @@ bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisS outDims[idx - numSkipAxis] = getSizeOfDimension(input, idx); } } + // Handle the case when all dimensions are removed + if (outDims.empty()) { + outDims.push_back(1); + } output->dimensions = outDims; } @@ -675,11 +679,15 @@ bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) { // Copy the input dimensions, omitting the axis dimension. output->dimensions.clear(); - output->dimensions.reserve(getNumberOfDimensions(input) - 1); - output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(), - input.dimensions.begin() + axis); - output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1, - input.dimensions.end()); + if (getNumberOfDimensions(input) > 1) { + output->dimensions.reserve(getNumberOfDimensions(input) - 1); + output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(), + input.dimensions.begin() + axis); + output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1, + input.dimensions.end()); + } else { + output->dimensions.push_back(1); + } return true; } diff --git a/nn/common/operations/StridedSlice.cpp b/nn/common/operations/StridedSlice.cpp index cd972d3de..5ff5aeca8 100644 --- a/nn/common/operations/StridedSlice.cpp +++ b/nn/common/operations/StridedSlice.cpp @@ -191,6 +191,11 @@ bool prepare(IOperationExecutionContext* context) { } } + // Handle the case when all dimensions are removed + if (outDims.empty()) { + outDims.push_back(1); + } + Shape outputShape = context->getOutputShape(kOutputTensor); NN_RET_CHECK(SetShape(inputShape, &outputShape)); outputShape.dimensions = outDims; |