diff options
author | Xusong Wang <xusongw@google.com> | 2020-05-26 17:43:13 +0000 |
---|---|---|
committer | Android (Google) Code Review <android-gerrit@google.com> | 2020-05-26 17:43:13 +0000 |
commit | 66e5923200afc965bf19b880737e9180e9f5c909 (patch) | |
tree | 744bfbff7a4b6062d8011b0721f4577e4c5a4a73 | |
parent | 0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 (diff) | |
parent | f0af901e251b46938ceca80658b5cefc67fc7b6d (diff) | |
download | ml-66e5923200afc965bf19b880737e9180e9f5c909.tar.gz |
Merge changes Ib3b191cc,I9afea607 into rvc-dev
* changes:
Fix FULLY_CONNECTED issue with unknown num_units.
Fix CAST issue with outputs of unknown rank.
-rw-r--r-- | nn/common/operations/Cast.cpp | 8 | ||||
-rw-r--r-- | nn/common/operations/FullyConnected.cpp | 5 |
2 files changed, 7 insertions, 6 deletions
diff --git a/nn/common/operations/Cast.cpp b/nn/common/operations/Cast.cpp index f8ca4022e..77e35afb0 100644 --- a/nn/common/operations/Cast.cpp +++ b/nn/common/operations/Cast.cpp @@ -17,12 +17,13 @@ #define LOG_TAG "Operations" #include "Cast.h" + +#include <algorithm> + #include "HalInterfaces.h" #include "Operations.h" #include "Tracing.h" -#include <algorithm> - namespace android { namespace nn { namespace cast { @@ -67,9 +68,6 @@ bool copyToTensor(const FromT* inputData, int numElements, uint8_t* outputData, } // namespace bool prepare(const Shape& input, Shape* output) { - if (input.dimensions.size() != output->dimensions.size()) { - return false; - } output->dimensions = input.dimensions; return true; } diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index 71808c0b7..9bdd0bab2 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -200,11 +200,14 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, uint32_t input_n_elements = getNumberOfElements(input); uint32_t num_units = getSizeOfDimension(weights, 0); uint32_t input_size = getSizeOfDimension(weights, 1); + uint32_t bias_len = getSizeOfDimension(bias, 0); uint32_t batch_size = input_size == 0 ? 0 : input_n_elements / input_size; if (batch_size != 0) { NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements); } - NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units); + if (num_units != 0 && bias_len != 0) { + NN_RET_CHECK_EQ(bias_len, num_units); + } if (output != nullptr) { // Only batch_size can be 0. NN_RET_CHECK_GT(num_units, 0); |