author     Michael Butler <butlermichael@google.com>    2019-07-22 18:59:46 -0700
committer  Slava Shklyaev <slavash@google.com>          2019-08-23 11:42:41 +0100
commit     43953b8f3976fe83c4b04322d4e855cba0688b1e (patch)
tree       0a6719d328cfe7adeed49f814412e03dde303ad9
parent     a1846f57b824acda3616a0053bda3912b3f591ac (diff)
clang-format for frameworks/ml/nn
This CL formats all of frameworks/ml/nn/* with the following commands:
$ $CLANG_DIR/clang-format --style=file -i `find $NNAPI_DIR -name "*.cpp"`
$ $CLANG_DIR/clang-format --style=file -i `find $NNAPI_DIR -name "*.h"`
where:
* "NNAPI_DIR" is "$ANDROID_BUILD_TOP/frameworks/ml/nn"
* "CLANG_DIR" is "$ANDROID_BUILD_TOP/prebuilts/clang/host/linux-x86/clang-stable/bin"
Bug: N/A
Test: mma
Change-Id: Idddbc7ecaeab76fb0bbee4250830333752a1f29b
Merged-In: Idddbc7ecaeab76fb0bbee4250830333752a1f29b
(cherry picked from commit 67e41a5467d7879b34f613069ade6cf61d5bd633)
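Note: the project's actual .clang-format file is not reproduced on this page. Judging only from the reformatted output above and below (Google-derived style, 4-space indentation, 2-space access specifiers, roughly 100-column lines, left-aligned pointers), a style file along the following lines would produce similar formatting. This is an illustrative sketch inferred from the diff, not the checked-in file:

    # Hypothetical .clang-format sketch -- settings inferred from the diff, not the real file
    BasedOnStyle: Google
    AccessModifierOffset: -2
    IndentWidth: 4
    ContinuationIndentWidth: 8
    ColumnLimit: 100
    DerivePointerAlignment: false
    PointerAlignment: Left

With --style=file, clang-format searches upward from each input file for the nearest .clang-format, so the invocation shown above picks up whatever style the project actually defines.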
57 files changed, 2650 insertions, 3073 deletions
diff --git a/nn/common/GraphDump.cpp b/nn/common/GraphDump.cpp index 9fe6bf31e..b057692df 100644 --- a/nn/common/GraphDump.cpp +++ b/nn/common/GraphDump.cpp @@ -46,8 +46,8 @@ using namespace hal; // namespace { class Dumper { -public: - Dumper(std::ostream* outStream) : mStream(outStream) { } + public: + Dumper(std::ostream* outStream) : mStream(outStream) {} Dumper(const Dumper&) = delete; void operator=(const Dumper&) = delete; @@ -58,7 +58,7 @@ public: return *this; } - class EndlType { }; + class EndlType {}; Dumper& operator<<(EndlType) { if (mStream) { @@ -81,27 +81,36 @@ public: } static const EndlType endl; -private: + + private: std::ostream* mStream; std::ostringstream mStringStream; }; const Dumper::EndlType Dumper::endl; -} - +} // namespace // Provide short name for OperandType value. static std::string translate(OperandType type) { switch (type) { - case OperandType::FLOAT32: return "F32"; - case OperandType::INT32: return "I32"; - case OperandType::UINT32: return "U32"; - case OperandType::TENSOR_FLOAT32: return "TF32"; - case OperandType::TENSOR_INT32: return "TI32"; - case OperandType::TENSOR_QUANT8_ASYMM: return "TQ8A"; - case OperandType::OEM: return "OEM"; - case OperandType::TENSOR_OEM_BYTE: return "TOEMB"; - default: return toString(type); + case OperandType::FLOAT32: + return "F32"; + case OperandType::INT32: + return "I32"; + case OperandType::UINT32: + return "U32"; + case OperandType::TENSOR_FLOAT32: + return "TF32"; + case OperandType::TENSOR_INT32: + return "TI32"; + case OperandType::TENSOR_QUANT8_ASYMM: + return "TQ8A"; + case OperandType::OEM: + return "OEM"; + case OperandType::TENSOR_OEM_BYTE: + return "TOEMB"; + default: + return toString(type); } } @@ -110,10 +119,9 @@ static std::string translate(OperandType type) { // OperandLifeTime::CONSTANT_COPY, then write the Operand's value to // the Dumper. namespace { -template<OperandType nnType, typename cppType> +template <OperandType nnType, typename cppType> void tryValueDump(Dumper& dump, const Model& model, const Operand& opnd) { - if (opnd.type != nnType || - opnd.lifetime != OperandLifeTime::CONSTANT_COPY || + if (opnd.type != nnType || opnd.lifetime != OperandLifeTime::CONSTANT_COPY || opnd.location.length != sizeof(cppType)) { return; } @@ -122,7 +130,7 @@ void tryValueDump(Dumper& dump, const Model& model, const Operand& opnd) { memcpy(&val, &model.operandValues[opnd.location.offset], sizeof(cppType)); dump << " = " << val; } -} +} // namespace void graphDump(const char* name, const Model& model, std::ostream* outStream) { // Operand nodes are named "d" (operanD) followed by operand index. 
@@ -182,8 +190,8 @@ void graphDump(const char* name, const Model& model, std::ostream* outStream) { dump << ": " << kind; } dump << "\\n" << translate(opnd.type); - tryValueDump<OperandType::FLOAT32, float>(dump, model, opnd); - tryValueDump<OperandType::INT32, int>(dump, model, opnd); + tryValueDump<OperandType::FLOAT32, float>(dump, model, opnd); + tryValueDump<OperandType::INT32, int>(dump, model, opnd); tryValueDump<OperandType::UINT32, unsigned>(dump, model, opnd); if (opnd.dimensions.size()) { dump << "("; @@ -210,8 +218,7 @@ void graphDump(const char* name, const Model& model, std::ostream* outStream) { dump << " ordering=out"; } } - dump << " label=\"" << i << ": " - << toString(operation.type) << "\"]" << Dumper::endl; + dump << " label=\"" << i << ": " << toString(operation.type) << "\"]" << Dumper::endl; { // operation inputs for (unsigned in = 0, inE = operation.inputs.size(); in < inE; in++) { diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp index 9219dddb7..cec15fc94 100644 --- a/nn/common/OperationsUtils.cpp +++ b/nn/common/OperationsUtils.cpp @@ -122,8 +122,7 @@ uint32_t getNumberOfElements(const Shape& shape) { return count; } -uint32_t getNumberOfElements(const Shape& shape, - size_t firstAxisInclusive, +uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive, size_t lastAxisExclusive) { nnAssert(0 <= firstAxisInclusive); nnAssert(firstAxisInclusive <= lastAxisExclusive); @@ -170,8 +169,7 @@ bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, return true; } -bool QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, +bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t* quantized_multiplier, int32_t* right_shift) { NN_OPS_CHECK(double_multiplier >= 0.); NN_OPS_CHECK(double_multiplier < 1.); @@ -195,8 +193,7 @@ bool QuantizeMultiplierSmallerThanOne(double double_multiplier, return true; } -bool QuantizeMultiplierGreaterThanOne(double double_multiplier, - int32_t* quantized_multiplier, +bool QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift) { NN_OPS_CHECK(double_multiplier > 1.); const double q = std::frexp(double_multiplier, left_shift); @@ -221,15 +218,13 @@ bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filt // The following conditions must be guaranteed by the training pipeline. 
NN_OPS_CHECK(std::abs(input_product_scale - bias_scale) <= - 1e-6 * std::min(input_product_scale, bias_scale)); + 1e-6 * std::min(input_product_scale, bias_scale)); NN_OPS_CHECK(input_product_scale >= 0); *multiplier = input_product_scale / outputShape.scale; return true; } -void CalculateActivationRangeUint8(int32_t activation, - const Shape& outputShape, - int32_t* act_min, +void CalculateActivationRangeUint8(int32_t activation, const Shape& outputShape, int32_t* act_min, int32_t* act_max) { const int32_t qmin = std::numeric_limits<uint8_t>::min(); const int32_t qmax = std::numeric_limits<uint8_t>::max(); @@ -250,7 +245,7 @@ void CalculateActivationRangeUint8(int32_t activation, } else if (activation == kActivationRelu1) { *act_min = std::max(qmin, quantize(-1.0)); *act_max = std::min(qmax, quantize(1.0)); - } else if (activation == kActivationNone){ + } else if (activation == kActivationNone) { *act_min = qmin; *act_max = qmax; } else { @@ -258,8 +253,7 @@ void CalculateActivationRangeUint8(int32_t activation, } } -void CalculateActivationRangeFloat(int32_t activation, - float* activation_min, +void CalculateActivationRangeFloat(int32_t activation, float* activation_min, float* activation_max) { if (activation == kActivationRelu) { *activation_min = 0.f; @@ -270,7 +264,7 @@ void CalculateActivationRangeFloat(int32_t activation, } else if (activation == kActivationRelu1) { *activation_min = -1.f; *activation_max = 1.f; - } else if (activation == kActivationNone){ + } else if (activation == kActivationNone) { *activation_min = std::numeric_limits<float>::lowest(); *activation_max = std::numeric_limits<float>::max(); } else { @@ -372,11 +366,11 @@ bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape& uint32_t channels_out = getSizeOfDimension(filter, 3); uint32_t channels_in = getSizeOfDimension(input, 3); - uint32_t width = getSizeOfDimension(input, 2); - uint32_t height = getSizeOfDimension(input, 1); - uint32_t filterWidth = getSizeOfDimension(filter, 2); + uint32_t width = getSizeOfDimension(input, 2); + uint32_t height = getSizeOfDimension(input, 1); + uint32_t filterWidth = getSizeOfDimension(filter, 2); uint32_t filterHeight = getSizeOfDimension(filter, 1); - uint32_t batches = getSizeOfDimension(input, 0); + uint32_t batches = getSizeOfDimension(input, 0); NN_OPS_CHECK(depth_multiplier * channels_in == channels_out); int32_t effectiveFilterWidth = (filterWidth - 1) * dilation_width_factor + 1; @@ -396,8 +390,7 @@ bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape& return true; } -bool genericActivationPrepare(const Shape& input, - Shape* output) { +bool genericActivationPrepare(const Shape& input, Shape* output) { NN_OPS_CHECK(getNumberOfDimensions(input) <= 4); return SetShape(input, output); } @@ -406,15 +399,13 @@ bool genericNormalizationPrepare(const Shape& input, Shape* output) { return SetShape(input, output); } -bool reshapePrepare(const Shape& input, - const int32_t* targetDims, - const int32_t targetDimsSize, +bool reshapePrepare(const Shape& input, const int32_t* targetDims, const int32_t targetDimsSize, Shape* output) { // Reshape allows one of the targetDims components to have the // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number // of output elements in the same as the number of input elements. 
- int32_t numInputElements = (int32_t) getNumberOfElements(input); + int32_t numInputElements = (int32_t)getNumberOfElements(input); std::vector<uint32_t> outDims(targetDimsSize); int32_t numOutputElements = 1; @@ -431,7 +422,7 @@ bool reshapePrepare(const Shape& input, } if (strechDim != -1) { int32_t strechValue = numInputElements / numOutputElements; - outDims[strechDim] = (uint32_t) strechValue; + outDims[strechDim] = (uint32_t)strechValue; numOutputElements *= strechValue; } @@ -445,22 +436,18 @@ bool reshapePrepare(const Shape& input, return true; } -bool depthToSpacePrepare(const Shape& input, - int32_t blockSize, - Shape* output) { +bool depthToSpacePrepare(const Shape& input, int32_t blockSize, Shape* output) { NN_OPS_CHECK(getNumberOfDimensions(input) == 4); NN_OPS_CHECK(blockSize > 0); - uint32_t batches = getSizeOfDimension(input, 0); - uint32_t height = getSizeOfDimension(input, 1); - uint32_t width = getSizeOfDimension(input, 2); + uint32_t batches = getSizeOfDimension(input, 0); + uint32_t height = getSizeOfDimension(input, 1); + uint32_t width = getSizeOfDimension(input, 2); uint32_t channels = getSizeOfDimension(input, 3); NN_OPS_CHECK(channels % (blockSize * blockSize) == 0); output->type = input.type; - output->dimensions = {batches, - height * blockSize, - width * blockSize, + output->dimensions = {batches, height * blockSize, width * blockSize, channels / (blockSize * blockSize)}; output->offset = input.offset; output->scale = input.scale; @@ -468,24 +455,20 @@ bool depthToSpacePrepare(const Shape& input, return true; } -bool spaceToDepthPrepare(const Shape& input, - int32_t blockSize, - Shape* output) { +bool spaceToDepthPrepare(const Shape& input, int32_t blockSize, Shape* output) { NN_OPS_CHECK(getNumberOfDimensions(input) == 4); NN_OPS_CHECK(blockSize > 0); - uint32_t batches = getSizeOfDimension(input, 0); - uint32_t height = getSizeOfDimension(input, 1); - uint32_t width = getSizeOfDimension(input, 2); + uint32_t batches = getSizeOfDimension(input, 0); + uint32_t height = getSizeOfDimension(input, 1); + uint32_t width = getSizeOfDimension(input, 2); uint32_t channels = getSizeOfDimension(input, 3); NN_OPS_CHECK(height % blockSize == 0); NN_OPS_CHECK(width % blockSize == 0); output->type = input.type; - output->dimensions = {batches, - height / blockSize, - width / blockSize, + output->dimensions = {batches, height / blockSize, width / blockSize, channels * (blockSize * blockSize)}; output->offset = input.offset; output->scale = input.scale; @@ -493,19 +476,17 @@ bool spaceToDepthPrepare(const Shape& input, return true; } -bool embeddingLookupPrepare(const Shape &valueShape, - const Shape &lookupShape, - Shape *outputShape) { +bool embeddingLookupPrepare(const Shape& valueShape, const Shape& lookupShape, Shape* outputShape) { NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 2); NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1); - const uint32_t rows = getSizeOfDimension(valueShape, 0); - const uint32_t columns = getSizeOfDimension(valueShape, 1); + const uint32_t rows = getSizeOfDimension(valueShape, 0); + const uint32_t columns = getSizeOfDimension(valueShape, 1); - const uint32_t lookups = getSizeOfDimension(lookupShape, 0); + const uint32_t lookups = getSizeOfDimension(lookupShape, 0); outputShape->type = valueShape.type; - outputShape->dimensions = { lookups, columns }; + outputShape->dimensions = {lookups, columns}; for (uint32_t i = 2; i < getNumberOfDimensions(valueShape); i++) { outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i)); } 
@@ -515,20 +496,17 @@ bool embeddingLookupPrepare(const Shape &valueShape, return true; } -bool hashtableLookupPrepare(const Shape &lookupShape, - const Shape &keyShape, - const Shape &valueShape, - Shape *outputShape, - Shape *hitShape) { +bool hashtableLookupPrepare(const Shape& lookupShape, const Shape& keyShape, + const Shape& valueShape, Shape* outputShape, Shape* hitShape) { NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1); NN_OPS_CHECK(getNumberOfDimensions(keyShape) == 1); NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 1); - const uint32_t lookups = getSizeOfDimension(lookupShape, 0); - const uint32_t keys = getSizeOfDimension(keyShape, 0); - const uint32_t rows = getSizeOfDimension(valueShape, 0); + const uint32_t lookups = getSizeOfDimension(lookupShape, 0); + const uint32_t keys = getSizeOfDimension(keyShape, 0); + const uint32_t rows = getSizeOfDimension(valueShape, 0); outputShape->type = valueShape.type; - outputShape->dimensions = { lookups }; + outputShape->dimensions = {lookups}; for (uint32_t i = 1; i < getNumberOfDimensions(valueShape); i++) { outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i)); } @@ -536,16 +514,14 @@ bool hashtableLookupPrepare(const Shape &lookupShape, outputShape->scale = valueShape.scale; hitShape->type = OperandType::TENSOR_QUANT8_ASYMM; - hitShape->dimensions = { lookups }; + hitShape->dimensions = {lookups}; hitShape->offset = 0; hitShape->scale = 1.f; return true; } -bool padPrepare(const Shape& input, - const int32_t* paddingsData, - const Shape& paddingsShape, +bool padPrepare(const Shape& input, const int32_t* paddingsData, const Shape& paddingsShape, Shape* output) { uint32_t numInputDims = getNumberOfDimensions(input); @@ -571,10 +547,8 @@ bool padPrepare(const Shape& input, return true; } -bool batchToSpacePrepare(const Shape& input, - const int32_t* blockSizeData, - const Shape& blockSizeShape, - Shape* output) { +bool batchToSpacePrepare(const Shape& input, const int32_t* blockSizeData, + const Shape& blockSizeShape, Shape* output) { // Only 4D NHWC tensors are supported. NN_OPS_CHECK(getNumberOfDimensions(input) == 4); @@ -584,29 +558,24 @@ bool batchToSpacePrepare(const Shape& input, // Only applies to spatial dimensions. NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2); - uint32_t batches = getSizeOfDimension(input, 0); - uint32_t height = getSizeOfDimension(input, 1); - uint32_t width = getSizeOfDimension(input, 2); + uint32_t batches = getSizeOfDimension(input, 0); + uint32_t height = getSizeOfDimension(input, 1); + uint32_t width = getSizeOfDimension(input, 2); uint32_t channels = getSizeOfDimension(input, 3); NN_OPS_CHECK(batches % (blockSizeData[0] * blockSizeData[1]) == 0); output->type = input.type; output->dimensions = {batches / (blockSizeData[0] * blockSizeData[1]), - height * blockSizeData[0], - width * blockSizeData[1], - channels}; + height * blockSizeData[0], width * blockSizeData[1], channels}; output->offset = input.offset; output->scale = input.scale; return true; } -bool spaceToBatchPrepare(const Shape& input, - const int32_t* blockSizeData, - const Shape& blockSizeShape, - const int32_t* paddingsData, - const Shape& paddingsShape, - Shape* output) { +bool spaceToBatchPrepare(const Shape& input, const int32_t* blockSizeData, + const Shape& blockSizeShape, const int32_t* paddingsData, + const Shape& paddingsShape, Shape* output) { // Only 4D NHWC tensors are supported. 
NN_OPS_CHECK(getNumberOfDimensions(input) == 4); @@ -622,9 +591,9 @@ bool spaceToBatchPrepare(const Shape& input, NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == 2); NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2); - uint32_t batches = getSizeOfDimension(input, 0); - uint32_t height = getSizeOfDimension(input, 1); - uint32_t width = getSizeOfDimension(input, 2); + uint32_t batches = getSizeOfDimension(input, 0); + uint32_t height = getSizeOfDimension(input, 1); + uint32_t width = getSizeOfDimension(input, 2); uint32_t channels = getSizeOfDimension(input, 3); uint32_t paddedHeight = paddingsData[0] + height + paddingsData[1]; @@ -635,8 +604,7 @@ bool spaceToBatchPrepare(const Shape& input, output->type = input.type; output->dimensions = {batches * (blockSizeData[0] * blockSizeData[1]), - paddedHeight / blockSizeData[0], - paddedWidth / blockSizeData[1], + paddedHeight / blockSizeData[0], paddedWidth / blockSizeData[1], channels}; output->offset = input.offset; output->scale = input.scale; @@ -644,9 +612,7 @@ bool spaceToBatchPrepare(const Shape& input, return true; } -bool squeezePrepare(const Shape& input, - const int32_t* squeezeDims, - const Shape& squeezeDimsShape, +bool squeezePrepare(const Shape& input, const int32_t* squeezeDims, const Shape& squeezeDimsShape, Shape* output) { int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input)); @@ -668,13 +634,13 @@ bool squeezePrepare(const Shape& input, } } else { for (int32_t idx = 0; idx < squeezeDimsSize; ++idx) { - int32_t current = squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims - : squeezeDims[idx]; + int32_t current = + squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims : squeezeDims[idx]; NN_OPS_CHECK(current >= 0 && current < numInputDims && getSizeOfDimension(input, current) == 1); if (!shouldSqueeze[current]) ++numDimsSqueezed; shouldSqueeze[current] = true; - } + } } // Sets output dimensions. @@ -693,12 +659,8 @@ bool squeezePrepare(const Shape& input, return true; } -bool meanPrepare(const Shape& input, - const int32_t* axisData, - const Shape& axisShape, - bool keepDims, +bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisShape, bool keepDims, Shape* output) { - // perm need to be provided as a 1-D int32 tensor. NN_OPS_CHECK(axisShape.type == OperandType::TENSOR_INT32); NN_OPS_CHECK(getNumberOfDimensions(axisShape) == 1); @@ -770,12 +732,10 @@ bool meanPrepare(const Shape& input, return true; } -bool stridedSlicePrepare(const Shape& input, - const int32_t* beginData, const Shape& beginShape, - const int32_t* endData, const Shape& endShape, - const int32_t* stridesData, const Shape& stridesShape, - int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask, - Shape* output) { +bool stridedSlicePrepare(const Shape& input, const int32_t* beginData, const Shape& beginShape, + const int32_t* endData, const Shape& endShape, const int32_t* stridesData, + const Shape& stridesShape, int32_t beginMask, int32_t endMask, + int32_t shrinkAxisMask, Shape* output) { uint32_t numInputDims = getNumberOfDimensions(input); // StridedSlice op only supports 1D-4D input arrays. 
NN_OPS_CHECK(numInputDims <= 4); @@ -795,30 +755,28 @@ bool stridedSlicePrepare(const Shape& input, // Determine size of output tensor and map indices std::vector<uint32_t> outDims; for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); idx++) { - int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx)); - int32_t stride = stridesData[idx]; - // stride value has to be non-zero - NN_OPS_CHECK(stride != 0); - bool positiveStride = stride > 0; - - int32_t begin = beginMask & (1 << idx) - ? positiveStride ? 0 : dim - 1 - : ClampedIndex(beginData[idx], dim, positiveStride); - int32_t end = endMask & (1 << idx) - ? positiveStride ? dim : -1 - : ClampedIndex(endData[idx], dim, positiveStride); - - // This is valid for both positive and negative strides - int32_t outDim = ceil((end - begin) / static_cast<float>(stride)); - outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim); - if (!(shrinkAxisMask & (1 << idx))) { - outDims.push_back(outDim); - } else { - if (outDim != 1) { - LOG(ERROR) << "Outdim " << idx << " is " << outDim << ", expected 1"; - NN_OPS_CHECK(outDim == 1); - } - } + int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx)); + int32_t stride = stridesData[idx]; + // stride value has to be non-zero + NN_OPS_CHECK(stride != 0); + bool positiveStride = stride > 0; + + int32_t begin = beginMask & (1 << idx) ? positiveStride ? 0 : dim - 1 + : ClampedIndex(beginData[idx], dim, positiveStride); + int32_t end = endMask & (1 << idx) ? positiveStride ? dim : -1 + : ClampedIndex(endData[idx], dim, positiveStride); + + // This is valid for both positive and negative strides + int32_t outDim = ceil((end - begin) / static_cast<float>(stride)); + outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim); + if (!(shrinkAxisMask & (1 << idx))) { + outDims.push_back(outDim); + } else { + if (outDim != 1) { + LOG(ERROR) << "Outdim " << idx << " is " << outDim << ", expected 1"; + NN_OPS_CHECK(outDim == 1); + } + } } output->type = input.type; @@ -837,11 +795,9 @@ bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) { // Copy the input dimensions, omitting the axis dimension. 
output->dimensions.clear(); output->dimensions.reserve(getNumberOfDimensions(input) - 1); - output->dimensions.insert(output->dimensions.end(), - input.dimensions.begin(), + output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(), input.dimensions.begin() + axis); - output->dimensions.insert(output->dimensions.end(), - input.dimensions.begin() + axis + 1, + output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1, input.dimensions.end()); return true; @@ -910,5 +866,5 @@ bool groupedConvPrepare(const Shape& input, const Shape& filter, const Shape& bi return true; } -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index fc5493270..11f3c2abd 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -54,15 +54,14 @@ void initVLogMask() { return; } - std::unordered_map<std::string, int> vLogFlags = { - {"1", -1}, - {"all", -1}, - {"model", MODEL}, - {"compilation", COMPILATION}, - {"execution", EXECUTION}, - {"cpuexe", CPUEXE}, - {"manager", MANAGER}, - {"driver", DRIVER}}; + std::unordered_map<std::string, int> vLogFlags = {{"1", -1}, + {"all", -1}, + {"model", MODEL}, + {"compilation", COMPILATION}, + {"execution", EXECUTION}, + {"cpuexe", CPUEXE}, + {"manager", MANAGER}, + {"driver", DRIVER}}; std::vector<std::string> elements = android::base::Split(vLogSetting, " ,:"); for (const auto& elem : elements) { @@ -103,8 +102,7 @@ namespace { template <typename EntryType, uint32_t entryCount, uint32_t entryCountOEM> EntryType tableLookup(const EntryType (&table)[entryCount], - const EntryType (&tableOEM)[entryCountOEM], - uint32_t code) { + const EntryType (&tableOEM)[entryCountOEM], uint32_t code) { if (code < entryCount) { return table[code]; } else if (code >= kOEMCodeBase && (code - kOEMCodeBase) < entryCountOEM) { @@ -253,16 +251,16 @@ const bool kScalarDataType[]{ static_assert(COUNT(kScalarDataType) == kNumberOfDataTypes, "kScalarDataType is incorrect"); const uint32_t kSizeOfDataTypeOEM[]{ - 0, // ANEURALNETWORKS_OEM - 1, // ANEURALNETWORKS_TENSOR_OEM_BYTE + 0, // ANEURALNETWORKS_OEM + 1, // ANEURALNETWORKS_TENSOR_OEM_BYTE }; static_assert(COUNT(kSizeOfDataTypeOEM) == kNumberOfDataTypesOEM, "kSizeOfDataTypeOEM is incorrect"); const bool kScalarDataTypeOEM[]{ - true, // ANEURALNETWORKS_OEM - false, // ANEURALNETWORKS_TENSOR_OEM_BYTE + true, // ANEURALNETWORKS_OEM + false, // ANEURALNETWORKS_TENSOR_OEM_BYTE }; static_assert(COUNT(kScalarDataTypeOEM) == kNumberOfDataTypesOEM, @@ -313,11 +311,11 @@ bool tensorHasUnspecifiedDimensions(const Operand& operand) { uint32_t alignBytesNeeded(uint32_t index, size_t length) { uint32_t pattern; if (length < 2) { - pattern = 0; // No alignment necessary + pattern = 0; // No alignment necessary } else if (length < 4) { - pattern = 1; // Align on 2-byte boundary + pattern = 1; // Align on 2-byte boundary } else { - pattern = 3; // Align on 4-byte boundary + pattern = 3; // Align on 4-byte boundary } uint32_t extra = (~(index - 1)) & pattern; return extra; @@ -479,8 +477,8 @@ int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCo return ANEURALNETWORKS_NO_ERROR; } -int validateOperationOperandTypes(const std::vector<Operand>& operands, - uint32_t inOperandCount, const uint32_t* inOperandIndexes, +int validateOperationOperandTypes(const std::vector<Operand>& operands, uint32_t inOperandCount, + const uint32_t* inOperandIndexes, const std::vector<OperandType>& inExpectedTypes, uint32_t 
outOperandCount, const uint32_t* outOperandIndexes, const std::vector<OperandType>& outExpectedInTypes) { @@ -494,16 +492,16 @@ int validateOperationOperandTypes(const std::vector<Operand>& operands, for (uint32_t i = 0; i < inOperandCount; i++) { if (operands[inOperandIndexes[i]].type != inExpectedTypes[i]) { LOG(ERROR) << "Invalid input tensor type " - << toString(operands[inOperandIndexes[i]].type) - << " for input " << i << ", expected " << toString(inExpectedTypes[i]); + << toString(operands[inOperandIndexes[i]].type) << " for input " << i + << ", expected " << toString(inExpectedTypes[i]); return ANEURALNETWORKS_BAD_DATA; } } for (uint32_t i = 0; i < outOperandCount; i++) { if (operands[outOperandIndexes[i]].type != outExpectedInTypes[i]) { LOG(ERROR) << "Invalid output tensor type " - << toString(operands[outOperandIndexes[i]].type) - << " for input " << i << ", expected " << toString(outExpectedInTypes[i]); + << toString(operands[outOperandIndexes[i]].type) << " for input " << i + << ", expected " << toString(outExpectedInTypes[i]); return ANEURALNETWORKS_BAD_DATA; } } @@ -575,10 +573,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_DEPTHWISE_CONV_2D: { @@ -683,10 +679,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } else { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION: { @@ -724,10 +718,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } else if (operands[inputIndexes[0]].dimensions.size() != 4) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_RESHAPE: { @@ -740,8 +732,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = {OperandType::TENSOR_FLOAT32, - OperandType::TENSOR_INT32}; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); @@ -749,18 +740,15 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = 
{OperandType::TENSOR_QUANT8_ASYMM, - OperandType::TENSOR_INT32}; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else { LOG(ERROR) << "Unsupported input tensor type for operation " << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_DEPTH_TO_SPACE: { @@ -775,8 +763,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = {OperandType::TENSOR_FLOAT32, - OperandType::INT32}; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); @@ -784,8 +771,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, - OperandType::INT32}; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else { LOG(ERROR) << "Unsupported input tensor type for operation " @@ -798,10 +784,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } else { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_SPACE_TO_DEPTH: { @@ -816,8 +800,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = {OperandType::TENSOR_FLOAT32, - OperandType::INT32}; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); @@ -825,8 +808,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, - OperandType::INT32}; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else { LOG(ERROR) << "Unsupported input tensor type for operation " @@ -839,10 +821,8 @@ int validateOperation(ANeuralNetworksOperationType opType, 
uint32_t inputCount, } else { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_EMBEDDING_LOOKUP: { @@ -858,14 +838,11 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } - std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, - inputType}; + std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType}; std::vector<OperandType> outExpectedTypes = {inputType}; NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_HASHTABLE_LOOKUP: { @@ -882,15 +859,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, return ANEURALNETWORKS_BAD_DATA; } std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, - OperandType::TENSOR_INT32, - inputType}; + OperandType::TENSOR_INT32, inputType}; std::vector<OperandType> outExpectedTypes = {inputType, OperandType::TENSOR_QUANT8_ASYMM}; NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_LSH_PROJECTION: { @@ -1174,10 +1148,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } else { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_SPACE_TO_BATCH_ND: { @@ -1226,10 +1198,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } else { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_PAD: { @@ -1359,8 +1329,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); - inExpectedTypes = {OperandType::TENSOR_FLOAT32, - OperandType::TENSOR_INT32}; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { NN_RETURN_IF_ERROR(validateHalVersion(opType, 
halVersion, HalVersion::V1_2)); @@ -1368,18 +1337,15 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); - inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, - OperandType::TENSOR_INT32}; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else { LOG(ERROR) << "Unsupported input tensor type for operation " << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_STRIDED_SLICE: { @@ -1439,8 +1405,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); - inExpectedTypes = {OperandType::TENSOR_FLOAT32, - OperandType::TENSOR_INT32, + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { @@ -1450,8 +1415,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1)); - inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, - OperandType::TENSOR_INT32, + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else { @@ -1459,10 +1423,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } - return validateOperationOperandTypes(operands, - inputCount, inputIndexes, - inExpectedTypes, - outputCount, outputIndexes, + return validateOperationOperandTypes(operands, inputCount, inputIndexes, + inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); } case ANEURALNETWORKS_ARGMAX: @@ -1951,8 +1913,8 @@ V1_0::Capabilities convertToV1_0(const V1_1::Capabilities& capabilities) { LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities) << " from V1_1::Capabilities to V1_0::Capabilities"; } - return { .float32Performance = capabilities.float32Performance, - .quantized8Performance = capabilities.quantized8Performance }; + return {.float32Performance = capabilities.float32Performance, + .quantized8Performance = capabilities.quantized8Performance}; } V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) { @@ -1967,9 +1929,9 @@ V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) { } V1_1::Capabilities convertToV1_1(const V1_0::Capabilities& capabilities) { - return { .float32Performance = capabilities.float32Performance, - .quantized8Performance = capabilities.quantized8Performance, - .relaxedFloat32toFloat16Performance = capabilities.float32Performance }; + return {.float32Performance = 
capabilities.float32Performance, + .quantized8Performance = capabilities.quantized8Performance, + .relaxedFloat32toFloat16Performance = capabilities.float32Performance}; } V1_1::Capabilities convertToV1_1(const V1_1::Capabilities& capabilities) { @@ -2361,5 +2323,5 @@ uint32_t getProp(const char* str, uint32_t defaultValue) { } #endif // NN_DEBUGGABLE -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp index da440d840..a74b6565b 100644 --- a/nn/common/ValidateHal.cpp +++ b/nn/common/ValidateHal.cpp @@ -45,7 +45,7 @@ struct ModelToHalVersion<V1_2::Model> { }; class MemoryAccessVerifier { -public: + public: MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools) : mPoolCount(pools.size()), mPoolSizes(mPoolCount) { for (size_t i = 0; i < mPoolCount; i++) { @@ -68,7 +68,7 @@ public: return true; } -private: + private: size_t mPoolCount; std::vector<size_t> mPoolSizes; }; @@ -567,10 +567,10 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg for (size_t i = 0; i < rank; i++) { if (requestArgument.dimensions[i] != operand.dimensions[i] && operand.dimensions[i] != 0) { - LOG(ERROR) << "Request " << type << " " << requestArgumentIndex - << " has dimension " << i << " of " - << requestArgument.dimensions[i] - << " different than the model's " << operand.dimensions[i]; + LOG(ERROR) + << "Request " << type << " " << requestArgumentIndex + << " has dimension " << i << " of " << requestArgument.dimensions[i] + << " different than the model's " << operand.dimensions[i]; return false; } if (requestArgument.dimensions[i] == 0 && !allowUnspecified) { diff --git a/nn/common/include/ActivationFunctor.h b/nn/common/include/ActivationFunctor.h index d667ae1c0..d4d4d3aae 100644 --- a/nn/common/include/ActivationFunctor.h +++ b/nn/common/include/ActivationFunctor.h @@ -33,31 +33,30 @@ enum ActivationFn { }; class ActivationFunctor { - public: - explicit ActivationFunctor(ActivationFn act) : act_(act) {} - - float operator()(float a) const { - switch (act_) { - case kActivationNone: - return a; - case kActivationRelu: - return a < 0.f ? 0.f : a; - case kActivationRelu6: - return std::max(0.f, std::min(a, 6.f)); - case kActivationTanh: - return std::tanh(a); - case kActivationSigmoid: - return 1.0f / (1.0f + std::exp(-a)); - default: - __android_log_print(ANDROID_LOG_ERROR, "NN API", - "Invalid enum value for activation function: 0x%0X", - act_); - abort(); + public: + explicit ActivationFunctor(ActivationFn act) : act_(act) {} + + float operator()(float a) const { + switch (act_) { + case kActivationNone: + return a; + case kActivationRelu: + return a < 0.f ? 
0.f : a; + case kActivationRelu6: + return std::max(0.f, std::min(a, 6.f)); + case kActivationTanh: + return std::tanh(a); + case kActivationSigmoid: + return 1.0f / (1.0f + std::exp(-a)); + default: + __android_log_print(ANDROID_LOG_ERROR, "NN API", + "Invalid enum value for activation function: 0x%0X", act_); + abort(); + } } - } - private: - ActivationFn act_; + private: + ActivationFn act_; }; #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_ACTIVATION_FUNCTOR_H diff --git a/nn/common/include/CpuExecutor.h b/nn/common/include/CpuExecutor.h index e589a654f..c0cb1e984 100644 --- a/nn/common/include/CpuExecutor.h +++ b/nn/common/include/CpuExecutor.h @@ -213,11 +213,12 @@ class CpuExecutor { // b/109953668, disable OpenMP #ifdef NNAPI_OPENMP class ScopedOpenmpSettings { -public: + public: ScopedOpenmpSettings(); ~ScopedOpenmpSettings(); DISALLOW_COPY_AND_ASSIGN(ScopedOpenmpSettings); -private: + + private: int mBlocktimeInitial; #if NNAPI_LIMIT_CPU_THREADS int mMaxThreadsInitial; @@ -225,7 +226,6 @@ private: }; #endif // NNAPI_OPENMP - namespace { template <typename T> @@ -235,7 +235,7 @@ T getScalarData(const RunTimeOperandInfo& info) { return data[0]; } -inline bool IsNullInput(const RunTimeOperandInfo *input) { +inline bool IsNullInput(const RunTimeOperandInfo* input) { return input->lifetime == hal::OperandLifeTime::NO_VALUE; } @@ -250,12 +250,12 @@ inline int NumOutputs(const hal::Operation& operation) { return operation.outputs.size(); } -inline size_t NumDimensions(const RunTimeOperandInfo *operand) { - return operand->shape().dimensions.size(); +inline size_t NumDimensions(const RunTimeOperandInfo* operand) { + return operand->shape().dimensions.size(); } -inline uint32_t SizeOfDimension(const RunTimeOperandInfo *operand, int i) { - return operand->shape().dimensions[i]; +inline uint32_t SizeOfDimension(const RunTimeOperandInfo* operand, int i) { + return operand->shape().dimensions[i]; } inline RunTimeOperandInfo* GetInput(const hal::Operation& operation, @@ -270,7 +270,7 @@ inline RunTimeOperandInfo* GetOutput(const hal::Operation& operation, } // anonymous namespace -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_EXECUTOR_H diff --git a/nn/common/include/GraphDump.h b/nn/common/include/GraphDump.h index 1bf02b90e..bee994b39 100644 --- a/nn/common/include/GraphDump.h +++ b/nn/common/include/GraphDump.h @@ -45,8 +45,7 @@ namespace nn { // A model input or output (operand) is shown in "reverse colors" -- // white text on a black background. // -void graphDump(const char* name, - const ::android::hardware::neuralnetworks::V1_2::Model& model, +void graphDump(const char* name, const ::android::hardware::neuralnetworks::V1_2::Model& model, std::ostream* outStream = nullptr); } // namespace nn diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h index 381bc8e16..9ae3aca8b 100644 --- a/nn/common/include/OperationsUtils.h +++ b/nn/common/include/OperationsUtils.h @@ -146,8 +146,7 @@ bool combineDimensions(const std::vector<uint32_t>& lhs, const std::vector<uint3 // Return the total number of elements, i.e. all the dimensions multiplied // together. For a scalar, returns one. 
uint32_t getNumberOfElements(const Shape& shape); -uint32_t getNumberOfElements(const Shape& shape, - size_t firstAxisInclusive, +uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive, size_t lastAxisExclusive); uint32_t getNumberOfDimensions(const Shape& shape); @@ -179,27 +178,20 @@ inline int32_t computeOutSizeTransposeConv(int32_t imageSize, int32_t filterSize __wur bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift); -__wur -bool QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int32_t* right_shift); +__wur bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t* quantized_multiplier, + int32_t* right_shift); -__wur -bool QuantizeMultiplierGreaterThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* left_shift); +__wur bool QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, + int* left_shift); __wur bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filterShape, const Shape& biasShape, const Shape& outputShape, double* multiplier); -void CalculateActivationRangeUint8(int32_t activation, - const Shape& outputShape, - int32_t* act_min, +void CalculateActivationRangeUint8(int32_t activation, const Shape& outputShape, int32_t* act_min, int32_t* act_max); -void CalculateActivationRangeFloat(int32_t activation, - float* activation_min, +void CalculateActivationRangeFloat(int32_t activation, float* activation_min, float* activation_max); int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift); @@ -231,11 +223,11 @@ inline void calculateExplicitPaddingTransposeConv(int32_t in_size, int32_t strid padding_tail); } -inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight, - int32_t strideWidth, int32_t strideHeight, - int32_t filterWidth, int32_t filterHeight, - int32_t paddingLeft, int32_t paddingRight, - int32_t paddingTop, int32_t paddingBottom) { +inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight, int32_t strideWidth, + int32_t strideHeight, int32_t filterWidth, + int32_t filterHeight, int32_t paddingLeft, + int32_t paddingRight, int32_t paddingTop, + int32_t paddingBottom) { if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) { return kPaddingValid; } @@ -243,8 +235,8 @@ inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight, int32_t expectedPaddingLeft, expectedPaddingRight; int32_t expectedPaddingTop, expectedPaddingBottom; - calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame, - &expectedPaddingLeft, &expectedPaddingRight); + calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame, &expectedPaddingLeft, + &expectedPaddingRight); calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame, &expectedPaddingTop, &expectedPaddingBottom); if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight && @@ -257,30 +249,28 @@ inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight, // Reverse order of bits in the mask to match the expected order in kernel inline int ReverseMaskBits(int mask, int num_dimensions) { - int out = 0; - for (int dim = 0; dim < num_dimensions; dim++) { - out <<= 1; - out += (mask & 1); - mask >>= 1; - } - return out; + int out = 0; + for (int dim = 0; dim < num_dimensions; dim++) { + out <<= 1; + out += (mask & 1); + mask >>= 1; + } + return out; } // Compute the positive 
remainder. inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) { - return (divisor + (dividend % divisor)) % divisor; + return (divisor + (dividend % divisor)) % divisor; } // Compute clamped index. inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { - return pos_stride - ? (index >= dim ? dim - : PositiveRemainder( - std::min(std::max(index, -dim), dim), dim)) - : (index < -dim - ? -1 - : PositiveRemainder( - std::min(std::max(index, -dim), dim - 1), dim)); + return pos_stride + ? (index >= dim ? dim + : PositiveRemainder(std::min(std::max(index, -dim), dim), dim)) + : (index < -dim + ? -1 + : PositiveRemainder(std::min(std::max(index, -dim), dim - 1), dim)); } // Broadcasts input shape against one another and puts the result into output @@ -303,63 +293,38 @@ bool genericActivationPrepare(const Shape& input, Shape* output); bool genericNormalizationPrepare(const Shape& input, Shape* output); -bool reshapePrepare(const Shape& input, - const int32_t* targetDims, - const int32_t targetDimsSize, +bool reshapePrepare(const Shape& input, const int32_t* targetDims, const int32_t targetDimsSize, Shape* output); -bool depthToSpacePrepare(const Shape& input, - int32_t blockSize, - Shape* output); +bool depthToSpacePrepare(const Shape& input, int32_t blockSize, Shape* output); -bool spaceToDepthPrepare(const Shape& input, - int32_t blockSize, - Shape* output); +bool spaceToDepthPrepare(const Shape& input, int32_t blockSize, Shape* output); -bool embeddingLookupPrepare(const Shape &valueShape, - const Shape &lookupShape, - Shape *outputShape); +bool embeddingLookupPrepare(const Shape& valueShape, const Shape& lookupShape, Shape* outputShape); -bool hashtableLookupPrepare(const Shape &lookupShape, - const Shape &keyShape, - const Shape &valueShape, - Shape *outputShape, - Shape *hitShape); +bool hashtableLookupPrepare(const Shape& lookupShape, const Shape& keyShape, + const Shape& valueShape, Shape* outputShape, Shape* hitShape); -bool padPrepare(const Shape& input, - const int32_t* paddingsData, - const Shape& paddingsShape, +bool padPrepare(const Shape& input, const int32_t* paddingsData, const Shape& paddingsShape, Shape* output); -bool batchToSpacePrepare(const Shape& input, - const int32_t* blockSizeData, - const Shape& blockSizeShape, - Shape* output); - -bool spaceToBatchPrepare(const Shape& input, - const int32_t* blockSizeData, - const Shape& blockSizeShape, - const int32_t* paddingsData, - const Shape& paddingsShape, - Shape* output); - -bool squeezePrepare(const Shape& input, - const int32_t* squeezeDims, - const Shape& squeezeDimsShape, +bool batchToSpacePrepare(const Shape& input, const int32_t* blockSizeData, + const Shape& blockSizeShape, Shape* output); + +bool spaceToBatchPrepare(const Shape& input, const int32_t* blockSizeData, + const Shape& blockSizeShape, const int32_t* paddingsData, + const Shape& paddingsShape, Shape* output); + +bool squeezePrepare(const Shape& input, const int32_t* squeezeDims, const Shape& squeezeDimsShape, Shape* output); -bool meanPrepare(const Shape& input, - const int32_t* axisData, - const Shape& axisShape, - bool keepDims, +bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisShape, bool keepDims, Shape* output); -bool stridedSlicePrepare(const Shape& input, - const int32_t* beginData, const Shape& beginShape, - const int32_t* endData, const Shape& endShape, - const int32_t* stridesData, const Shape& stridesShape, - int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask, - Shape* output); 
+bool stridedSlicePrepare(const Shape& input, const int32_t* beginData, const Shape& beginShape, + const int32_t* endData, const Shape& endShape, const int32_t* stridesData, + const Shape& stridesShape, int32_t beginMask, int32_t endMask, + int32_t shrinkAxisMask, Shape* output); bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output); @@ -428,7 +393,7 @@ inline bool mergeThirdDimension(const T* bufferA, const std::vector<uint32_t>& d return true; } -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_UTILS_H diff --git a/nn/common/include/Tracing.h b/nn/common/include/Tracing.h index 01535d831..e461b2bdd 100644 --- a/nn/common/include/Tracing.h +++ b/nn/common/include/Tracing.h @@ -100,43 +100,41 @@ // Layer Application - For native applications (e.g., unit tests) #define NNTRACE_APP(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_APPLICATION, phase, detail) #define NNTRACE_APP_SWITCH(phase, detail) \ - NNTRACE_FULL_SWITCH(NNTRACE_LAYER_APPLICATION, phase, detail) + NNTRACE_FULL_SWITCH(NNTRACE_LAYER_APPLICATION, phase, detail) // Layer Runtime - For the NNAPI runtime #define NNTRACE_RT(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_RUNTIME, phase, detail) #define NNTRACE_RT_SWITCH(phase, detail) NNTRACE_FULL_SWITCH(NNTRACE_LAYER_RUNTIME, phase, detail) // Layer CPU - CPU executor #define NNTRACE_CPU(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, phase, detail) -#define NNTRACE_COMP(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, \ - NNTRACE_PHASE_COMPUTATION, detail) -#define NNTRACE_COMP_SWITCH(detail) NNTRACE_FULL_SWITCH(NNTRACE_LAYER_CPU, \ - NNTRACE_PHASE_COMPUTATION, detail) -#define NNTRACE_TRANS(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, \ - NNTRACE_PHASE_TRANSFORMATION, detail) +#define NNTRACE_COMP(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, NNTRACE_PHASE_COMPUTATION, detail) +#define NNTRACE_COMP_SWITCH(detail) \ + NNTRACE_FULL_SWITCH(NNTRACE_LAYER_CPU, NNTRACE_PHASE_COMPUTATION, detail) +#define NNTRACE_TRANS(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, NNTRACE_PHASE_TRANSFORMATION, detail) // Fully specified macros to be used when no convenience wrapper exists for your // need. #define NNTRACE_FULL(layer, phase, detail) NNTRACE_NAME_1(("[NN_" layer "_" phase "]" detail)) #define NNTRACE_FULL_SWITCH(layer, phase, detail) \ - NNTRACE_NAME_SWITCH(("[SW][NN_" layer "_" phase "]" detail)) + NNTRACE_NAME_SWITCH(("[SW][NN_" layer "_" phase "]" detail)) #define NNTRACE_FULL_SUBTRACT(layer, phase, detail) \ - NNTRACE_NAME_1(("[SUB][NN_" layer "_" phase "]" detail)) + NNTRACE_NAME_1(("[SUB][NN_" layer "_" phase "]" detail)) // Raw macro without scoping requirements, for special cases -#define NNTRACE_FULL_RAW(layer, phase, detail) android::ScopedTrace PASTE(___tracer, __LINE__) \ - (ATRACE_TAG, ("[NN_" layer "_" phase "]" detail)) +#define NNTRACE_FULL_RAW(layer, phase, detail) \ + android::ScopedTrace PASTE(___tracer, __LINE__)(ATRACE_TAG, ("[NN_" layer "_" phase "]" detail)) // Tracing buckets - for calculating timing summaries over. 
// // Application-only phases -#define NNTRACE_PHASE_OVERALL "PO" // Overall program, e.g., one benchmark case -#define NNTRACE_PHASE_WARMUP "PWU" // Warmup (nesting multiple executions) -#define NNTRACE_PHASE_BENCHMARK "PBM" // Benchmark (nesting multiple executions) +#define NNTRACE_PHASE_OVERALL "PO" // Overall program, e.g., one benchmark case +#define NNTRACE_PHASE_WARMUP "PWU" // Warmup (nesting multiple executions) +#define NNTRACE_PHASE_BENCHMARK "PBM" // Benchmark (nesting multiple executions) // Main phases, usable by all layers -#define NNTRACE_PHASE_INITIALIZATION "PI" // Initialization - not related to a model -#define NNTRACE_PHASE_PREPARATION "PP" // Model construction -#define NNTRACE_PHASE_COMPILATION "PC" // Model compilation -#define NNTRACE_PHASE_EXECUTION "PE" // Executing the model -#define NNTRACE_PHASE_TERMINATION "PT" // Tearing down -#define NNTRACE_PHASE_UNSPECIFIED "PU" // Helper code called from multiple phases +#define NNTRACE_PHASE_INITIALIZATION "PI" // Initialization - not related to a model +#define NNTRACE_PHASE_PREPARATION "PP" // Model construction +#define NNTRACE_PHASE_COMPILATION "PC" // Model compilation +#define NNTRACE_PHASE_EXECUTION "PE" // Executing the model +#define NNTRACE_PHASE_TERMINATION "PT" // Tearing down +#define NNTRACE_PHASE_UNSPECIFIED "PU" // Helper code called from multiple phases // Subphases of execution #define NNTRACE_PHASE_INPUTS_AND_OUTPUTS "PIO" // Setting inputs/outputs and allocating buffers #define NNTRACE_PHASE_TRANSFORMATION "PTR" // Transforming data for computation @@ -149,8 +147,7 @@ #define NNTRACE_LAYER_DRIVER "LD" #define NNTRACE_LAYER_CPU "LC" #define NNTRACE_LAYER_OTHER "LO" -#define NNTRACE_LAYER_UTILITY "LU" // Code used from multiple layers - +#define NNTRACE_LAYER_UTILITY "LU" // Code used from multiple layers // Implementation // @@ -162,10 +159,9 @@ // Switching trace, more than one per scope allowed, translated by // systrace_parser.py. This is mainly useful for tracing multiple phases through // one function / scope. -#define NNTRACE_NAME_SWITCH(name) android::ScopedTrace PASTE(___tracer, __LINE__) \ - (ATRACE_TAG, name); \ - (void)___tracer_1 // ensure switch is only used after a basic trace - +#define NNTRACE_NAME_SWITCH(name) \ + android::ScopedTrace PASTE(___tracer, __LINE__)(ATRACE_TAG, name); \ + (void)___tracer_1 // ensure switch is only used after a basic trace // Disallow use of raw ATRACE macros #undef ATRACE_NAME diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h index 6c53cc940..595c7098a 100644 --- a/nn/common/include/Utils.h +++ b/nn/common/include/Utils.h @@ -50,22 +50,14 @@ const int kOEMCodeBase = 10000; * forget to update the corresponding 'tags' table in * the initVlogMask() function implemented in Utils.cpp. 
*/ -enum VLogFlags { - MODEL = 0, - COMPILATION, - EXECUTION, - CPUEXE, - MANAGER, - DRIVER -}; +enum VLogFlags { MODEL = 0, COMPILATION, EXECUTION, CPUEXE, MANAGER, DRIVER }; -#define VLOG_IS_ON(TAG) \ - ((vLogMask & (1 << (TAG))) != 0) +#define VLOG_IS_ON(TAG) ((vLogMask & (1 << (TAG))) != 0) -#define VLOG(TAG) \ +#define VLOG(TAG) \ if (LIKELY(!VLOG_IS_ON(TAG))) \ - ; \ - else \ + ; \ + else \ LOG(INFO) extern int vLogMask; diff --git a/nn/common/operations/ArgMinMax.cpp b/nn/common/operations/ArgMinMax.cpp index 64d4606d1..95b69bdd2 100644 --- a/nn/common/operations/ArgMinMax.cpp +++ b/nn/common/operations/ArgMinMax.cpp @@ -30,22 +30,19 @@ namespace nn { using namespace hal; template <typename In, typename Out> -static void argMinMaxImpl(const In* inputData, const Shape& inputShape, - int32_t axis, bool isArgMin, +static void argMinMaxImpl(const In* inputData, const Shape& inputShape, int32_t axis, bool isArgMin, Out* outputData, const Shape& outputShape) { const int outerSize = getNumberOfElements(inputShape, 0, axis); const int axisSize = getSizeOfDimension(inputShape, axis); - const int innerSize = getNumberOfElements( - inputShape, axis + 1, getNumberOfDimensions(inputShape)); + const int innerSize = + getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape)); for (int outer = 0; outer < outerSize; ++outer) { for (int inner = 0; inner < innerSize; ++inner) { auto minMaxValue = inputData[outer * axisSize * innerSize + inner]; int minMaxIndex = 0; for (int i = 1; i < axisSize; ++i) { - const auto& value = - inputData[(outer * axisSize + i) * innerSize + inner]; - if ((isArgMin && value < minMaxValue) || - (!isArgMin && value > minMaxValue)) { + const auto& value = inputData[(outer * axisSize + i) * innerSize + inner]; + if ((isArgMin && value < minMaxValue) || (!isArgMin && value > minMaxValue)) { minMaxValue = value; minMaxIndex = i; } @@ -55,23 +52,17 @@ static void argMinMaxImpl(const In* inputData, const Shape& inputShape, } } -bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, - int32 axis, bool isArgMin, +bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32 axis, bool isArgMin, uint8_t* outputData, const Shape& outputShape) { NNTRACE_TRANS("argMinMaxGeneric"); NN_CHECK(handleNegativeAxis(inputShape, &axis)); -#define NNAPI_IMPL_ARG_MIN_MAX(operandType, dataType) \ - if (inputShape.type == operandType) { \ - NNTRACE_COMP_SWITCH("argMinMaxImpl::" #dataType); \ - argMinMaxImpl( \ - reinterpret_cast<const dataType*>(inputData), \ - inputShape, \ - axis, \ - isArgMin, \ - reinterpret_cast<int32_t*>(outputData), \ - outputShape); \ - return true; \ +#define NNAPI_IMPL_ARG_MIN_MAX(operandType, dataType) \ + if (inputShape.type == operandType) { \ + NNTRACE_COMP_SWITCH("argMinMaxImpl::" #dataType); \ + argMinMaxImpl(reinterpret_cast<const dataType*>(inputData), inputShape, axis, isArgMin, \ + reinterpret_cast<int32_t*>(outputData), outputShape); \ + return true; \ } NNAPI_IMPL_ARG_MIN_MAX(OperandType::TENSOR_FLOAT16, _Float16); @@ -84,5 +75,5 @@ bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, return false; } -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp index 678e2d698..152d3d63c 100644 --- a/nn/common/operations/Conv2D.cpp +++ b/nn/common/operations/Conv2D.cpp @@ -120,49 +120,49 @@ struct Conv2dParam { } }; -#define ANDROID_NN_CONV_PARAMETERS(Type) \ - uint32_t height = 
getSizeOfDimension(inputShape, 1); \ - uint32_t width = getSizeOfDimension(inputShape, 2); \ - uint32_t filterHeight = getSizeOfDimension(filterShape, 1); \ - uint32_t filterWidth = getSizeOfDimension(filterShape, 2); \ - uint32_t outHeight = getSizeOfDimension(outputShape, 1); \ - uint32_t outWidth = getSizeOfDimension(outputShape, 2); \ - uint32_t inDepth = getSizeOfDimension(inputShape, 3); \ - \ - uint32_t paddingHeight = (uint32_t)padding_top; \ - uint32_t paddingWidth = (uint32_t)padding_left; \ - \ - tflite::Dims<4> im2colDim; \ - im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0); \ - im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1); \ - im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2); \ - im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth; \ - \ - im2colDim.strides[0] = 1; \ - for (int i=1; i<4; i++) { \ - im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1]; \ - } \ - \ - Type* im2colData = nullptr; \ - uint64_t im2colByteSize = sizeof(Type); \ - std::unique_ptr<Type[]> im2colGuard; \ - for (int i=0; i<4; i++) { \ - im2colByteSize *= im2colDim.sizes[i]; \ - } \ - /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \ - if (im2colByteSize >= 0x7fffffff) { \ - LOG(ERROR) << "Conv size is too large, not enough memory"; \ - return false; \ - } \ - if (im2colByteSize <= kStaticBufferSize) { \ - im2colData = reinterpret_cast<Type *>(static_scratch_buffer); \ - } else { \ - im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \ - if (im2colData == nullptr) { \ - LOG(ERROR) << "Conv size is too large, not enough memory"; \ - return false; \ - } \ - im2colGuard.reset(im2colData); \ +#define ANDROID_NN_CONV_PARAMETERS(Type) \ + uint32_t height = getSizeOfDimension(inputShape, 1); \ + uint32_t width = getSizeOfDimension(inputShape, 2); \ + uint32_t filterHeight = getSizeOfDimension(filterShape, 1); \ + uint32_t filterWidth = getSizeOfDimension(filterShape, 2); \ + uint32_t outHeight = getSizeOfDimension(outputShape, 1); \ + uint32_t outWidth = getSizeOfDimension(outputShape, 2); \ + uint32_t inDepth = getSizeOfDimension(inputShape, 3); \ + \ + uint32_t paddingHeight = (uint32_t)padding_top; \ + uint32_t paddingWidth = (uint32_t)padding_left; \ + \ + tflite::Dims<4> im2colDim; \ + im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0); \ + im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1); \ + im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2); \ + im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth; \ + \ + im2colDim.strides[0] = 1; \ + for (int i = 1; i < 4; i++) { \ + im2colDim.strides[i] = im2colDim.strides[i - 1] * im2colDim.sizes[i - 1]; \ + } \ + \ + Type* im2colData = nullptr; \ + uint64_t im2colByteSize = sizeof(Type); \ + std::unique_ptr<Type[]> im2colGuard; \ + for (int i = 0; i < 4; i++) { \ + im2colByteSize *= im2colDim.sizes[i]; \ + } \ + /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \ + if (im2colByteSize >= 0x7fffffff) { \ + LOG(ERROR) << "Conv size is too large, not enough memory"; \ + return false; \ + } \ + if (im2colByteSize <= kStaticBufferSize) { \ + im2colData = reinterpret_cast<Type*>(static_scratch_buffer); \ + } else { \ + im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \ + if (im2colData == nullptr) { \ + LOG(ERROR) << "Conv size is too large, not enough memory"; \ + return false; \ + } \ + im2colGuard.reset(im2colData); \ } bool convNhwc(const float* inputData, const Shape& 
inputShape, const float* filterData, diff --git a/nn/common/operations/EmbeddingLookup.cpp b/nn/common/operations/EmbeddingLookup.cpp index f3b2911e7..bf3be6317 100644 --- a/nn/common/operations/EmbeddingLookup.cpp +++ b/nn/common/operations/EmbeddingLookup.cpp @@ -31,30 +31,29 @@ using namespace hal; EmbeddingLookup::EmbeddingLookup(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) { - value_ = GetInput(operation, operands, kValueTensor); - lookup_ = GetInput(operation, operands, kLookupTensor); + value_ = GetInput(operation, operands, kValueTensor); + lookup_ = GetInput(operation, operands, kLookupTensor); - output_ = GetOutput(operation, operands, kOutputTensor); + output_ = GetOutput(operation, operands, kOutputTensor); } bool EmbeddingLookup::Eval() { - NNTRACE_COMP("EmbeddingLookup::Eval"); - const int row_size = value_->shape().dimensions[0]; - const int total_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions); - const int row_bytes = total_bytes/row_size; - - for (uint32_t i = 0; i < lookup_->shape().dimensions[0]; i++) { - int idx = (reinterpret_cast<int*>(lookup_->buffer))[i]; - if (idx >= row_size || idx < 0) { - LOG(ERROR) << "Embedding Lookup: index out of bounds."; - return false; - } else { - memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, - row_bytes); + NNTRACE_COMP("EmbeddingLookup::Eval"); + const int row_size = value_->shape().dimensions[0]; + const int total_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions); + const int row_bytes = total_bytes / row_size; + + for (uint32_t i = 0; i < lookup_->shape().dimensions[0]; i++) { + int idx = (reinterpret_cast<int*>(lookup_->buffer))[i]; + if (idx >= row_size || idx < 0) { + LOG(ERROR) << "Embedding Lookup: index out of bounds."; + return false; + } else { + memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, row_bytes); + } } - } - return true; + return true; } } // namespace nn diff --git a/nn/common/operations/EmbeddingLookup.h b/nn/common/operations/EmbeddingLookup.h index bb89e24b2..9109ddfdf 100644 --- a/nn/common/operations/EmbeddingLookup.h +++ b/nn/common/operations/EmbeddingLookup.h @@ -27,22 +27,22 @@ namespace nn { struct RunTimeOperandInfo; class EmbeddingLookup { - public: - EmbeddingLookup(const hardware::neuralnetworks::V1_2::Operation& operation, - std::vector<RunTimeOperandInfo>& operands); + public: + EmbeddingLookup(const hardware::neuralnetworks::V1_2::Operation& operation, + std::vector<RunTimeOperandInfo>& operands); - bool Eval(); + bool Eval(); - static constexpr int kLookupTensor = 0; - static constexpr int kValueTensor = 1; + static constexpr int kLookupTensor = 0; + static constexpr int kValueTensor = 1; - static constexpr int kOutputTensor = 0; + static constexpr int kOutputTensor = 0; - private: - const RunTimeOperandInfo *value_; - const RunTimeOperandInfo *lookup_; + private: + const RunTimeOperandInfo* value_; + const RunTimeOperandInfo* lookup_; - RunTimeOperandInfo *output_; + RunTimeOperandInfo* output_; }; } // namespace nn diff --git a/nn/common/operations/EmbeddingLookupTest.cpp b/nn/common/operations/EmbeddingLookupTest.cpp index d864ab73a..10e2e339a 100644 --- a/nn/common/operations/EmbeddingLookupTest.cpp +++ b/nn/common/operations/EmbeddingLookupTest.cpp @@ -31,13 +31,13 @@ namespace wrapper { namespace { std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, - float max_abs_error=1.e-6) { - std::vector<Matcher<float>> matchers; - 
matchers.reserve(values.size()); - for (const float& v : values) { - matchers.emplace_back(FloatNear(v, max_abs_error)); - } - return matchers; + float max_abs_error = 1.e-6) { + std::vector<Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; } } // namespace @@ -45,109 +45,108 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, using ::testing::ElementsAreArray; #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ - ACTION(Value, float) \ - ACTION(Lookup, int) + ACTION(Value, float) \ + ACTION(Lookup, int) // For all output and intermediate states -#define FOR_ALL_OUTPUT_TENSORS(ACTION) \ - ACTION(Output, float) +#define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, float) class EmbeddingLookupOpModel { - public: - EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape, - std::initializer_list<uint32_t> weight_shape) { - auto it = weight_shape.begin(); - rows_ = *it++; - columns_ = *it++; - features_ = *it; + public: + EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape, + std::initializer_list<uint32_t> weight_shape) { + auto it = weight_shape.begin(); + rows_ = *it++; + columns_ = *it++; + features_ = *it; - std::vector<uint32_t> inputs; + std::vector<uint32_t> inputs; - OperandType LookupTy(Type::TENSOR_INT32, index_shape); - inputs.push_back(model_.addOperand(&LookupTy)); + OperandType LookupTy(Type::TENSOR_INT32, index_shape); + inputs.push_back(model_.addOperand(&LookupTy)); - OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape); - inputs.push_back(model_.addOperand(&ValueTy)); + OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape); + inputs.push_back(model_.addOperand(&ValueTy)); - std::vector<uint32_t> outputs; + std::vector<uint32_t> outputs; - OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape); - outputs.push_back(model_.addOperand(&OutputOpndTy)); + OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape); + outputs.push_back(model_.addOperand(&OutputOpndTy)); - auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { - uint32_t sz = 1; - for (uint32_t d : dims) { sz *= d; } - return sz; - }; + auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { + uint32_t sz = 1; + for (uint32_t d : dims) { + sz *= d; + } + return sz; + }; - Value_.insert(Value_.end(), multiAll(weight_shape), 0.f); - Output_.insert(Output_.end(), multiAll(weight_shape), 0.f); + Value_.insert(Value_.end(), multiAll(weight_shape), 0.f); + Output_.insert(Output_.end(), multiAll(weight_shape), 0.f); - model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs); - model_.identifyInputsAndOutputs(inputs, outputs); + model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs); + model_.identifyInputsAndOutputs(inputs, outputs); - model_.finish(); - } + model_.finish(); + } - void Invoke() { - ASSERT_TRUE(model_.isValid()); + void Invoke() { + ASSERT_TRUE(model_.isValid()); - Compilation compilation(&model_); - compilation.finish(); - Execution execution(&compilation); + Compilation compilation(&model_); + compilation.finish(); + Execution execution(&compilation); #define SetInputOrWeight(X, T) \ - ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \ - sizeof(T) * X##_.size()), \ - Result::NO_ERROR); + ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \ + sizeof(T) * X##_.size()), \ + Result::NO_ERROR); - 
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight #define SetOutput(X, T) \ - ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \ - sizeof(T) * X##_.size()), \ - Result::NO_ERROR); + ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \ + sizeof(T) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_OUTPUT_TENSORS(SetOutput); + FOR_ALL_OUTPUT_TENSORS(SetOutput); #undef SetOutput - ASSERT_EQ(execution.compute(), Result::NO_ERROR); - } + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + } -#define DefineSetter(X, T) \ - void Set##X(const std::vector<T>& f) { \ - X##_.insert(X##_.end(), f.begin(), f.end()); \ - } +#define DefineSetter(X, T) \ + void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); #undef DefineSetter - void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) { - for (uint32_t i = 0; i < rows_; i++) { - for (uint32_t j = 0; j < columns_; j++) { - for (uint32_t k = 0; k < features_; k++) { - Value_[(i * columns_ + j) * features_ + k] = function(i, j, k); + void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) { + for (uint32_t i = 0; i < rows_; i++) { + for (uint32_t j = 0; j < columns_; j++) { + for (uint32_t k = 0; k < features_; k++) { + Value_[(i * columns_ + j) * features_ + k] = function(i, j, k); + } + } } - } } - } - const std::vector<float> &GetOutput() const { return Output_; } + const std::vector<float>& GetOutput() const { return Output_; } - private: - Model model_; - uint32_t rows_; - uint32_t columns_; - uint32_t features_; + private: + Model model_; + uint32_t rows_; + uint32_t columns_; + uint32_t features_; #define DefineTensor(X, T) std::vector<T> X##_; - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); - FOR_ALL_OUTPUT_TENSORS(DefineTensor); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); + FOR_ALL_OUTPUT_TENSORS(DefineTensor); #undef DefineTensor }; @@ -155,19 +154,17 @@ class EmbeddingLookupOpModel { // TODO: write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. 
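As a quick illustration of what the operation under test computes (an editor's sketch, not part of this change): EmbeddingLookup::Eval() above copies, for each index in the lookup tensor, one full row of the value tensor into the output, and fails with an error when an index is out of range. A minimal standalone version of that copy loop, using hypothetical names (embeddingLookupRows, rowSize), could look like:

#include <cstring>
#include <vector>

// Sketch of the EMBEDDING_LOOKUP semantics implemented by EmbeddingLookup::Eval():
// output row i is a copy of the value row selected by lookup[i].
// Returns false on an out-of-range index, mirroring the error path above
// (the real op additionally logs "index out of bounds").
bool embeddingLookupRows(const std::vector<int>& lookup, const std::vector<float>& value,
                         int rowSize, std::vector<float>* output) {
    const int numRows = static_cast<int>(value.size()) / rowSize;
    output->assign(lookup.size() * rowSize, 0.0f);
    for (size_t i = 0; i < lookup.size(); ++i) {
        const int idx = lookup[i];
        if (idx < 0 || idx >= numRows) {
            return false;
        }
        std::memcpy(output->data() + i * rowSize, value.data() + idx * rowSize,
                    rowSize * sizeof(float));
    }
    return true;
}

With the SimpleTest data below (a {3, 2, 4} weight tensor, i.e. rows of 8 floats, and lookups {1, 0, 2}), this produces the same three rows, in the same order, that the test expects.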
TEST(EmbeddingLookupOpTest, SimpleTest) { - EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.SetLookup({1, 0, 2}); - m.Set3DWeightMatrix( - [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); - - m.Invoke(); - - EXPECT_THAT(m.GetOutput(), - ElementsAreArray(ArrayFloatNear({ - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 - }))); + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetLookup({1, 0, 2}); + m.Set3DWeightMatrix([](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); } } // namespace wrapper diff --git a/nn/common/operations/GenerateProposals.cpp b/nn/common/operations/GenerateProposals.cpp index 67d614f6b..d70271504 100644 --- a/nn/common/operations/GenerateProposals.cpp +++ b/nn/common/operations/GenerateProposals.cpp @@ -480,10 +480,10 @@ bool boxWithNmsLimitFloat32Compute(float* scoresData, const Shape& scoresShape, NN_RET_CHECK_LE(roi[1], roi[3]); } std::vector<uint32_t> result; - softNmsMultiClass(scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold, - nmsScoreThreshold, maxNumDetections, maxNumDetections, - [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel, - &result); + softNmsMultiClass( + scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold, nmsScoreThreshold, + maxNumDetections, maxNumDetections, + [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel, &result); // Sort again by class. 
std::sort(result.begin(), result.end(), [&scoresBase, numClasses](const uint32_t& lhs, const uint32_t& rhs) { diff --git a/nn/common/operations/HashtableLookup.cpp b/nn/common/operations/HashtableLookup.cpp index 67cdffd2d..ed4de8e64 100644 --- a/nn/common/operations/HashtableLookup.cpp +++ b/nn/common/operations/HashtableLookup.cpp @@ -32,47 +32,46 @@ namespace { using namespace hal; int greater(const void* a, const void* b) { - return *static_cast<const int*>(a) - *static_cast<const int*>(b); + return *static_cast<const int*>(a) - *static_cast<const int*>(b); } } // anonymous namespace HashtableLookup::HashtableLookup(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) { - lookup_ = GetInput(operation, operands, kLookupTensor); - key_ = GetInput(operation, operands, kKeyTensor); - value_ = GetInput(operation, operands, kValueTensor); + lookup_ = GetInput(operation, operands, kLookupTensor); + key_ = GetInput(operation, operands, kKeyTensor); + value_ = GetInput(operation, operands, kValueTensor); - output_ = GetOutput(operation, operands, kOutputTensor); - hits_ = GetOutput(operation, operands, kHitsTensor); + output_ = GetOutput(operation, operands, kOutputTensor); + hits_ = GetOutput(operation, operands, kHitsTensor); } bool HashtableLookup::Eval() { - NNTRACE_COMP("HashtableLookup::Eval"); - const int num_rows = value_->shape().dimensions[0]; - const int row_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions) / num_rows; - void* pointer = nullptr; - - for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) { - int idx = -1; - pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer, - num_rows, sizeof(int), greater); - if (pointer != nullptr) { - idx = - (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float); + NNTRACE_COMP("HashtableLookup::Eval"); + const int num_rows = value_->shape().dimensions[0]; + const int row_bytes = + nonExtensionOperandSizeOfData(value_->type, value_->dimensions) / num_rows; + void* pointer = nullptr; + + for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) { + int idx = -1; + pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer, num_rows, sizeof(int), + greater); + if (pointer != nullptr) { + idx = (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float); + } + + if (idx >= num_rows || idx < 0) { + memset(output_->buffer + i * row_bytes, 0, row_bytes); + hits_->buffer[i] = 0; + } else { + memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, row_bytes); + hits_->buffer[i] = 1; + } } - if (idx >= num_rows || idx < 0) { - memset(output_->buffer + i * row_bytes, 0, row_bytes); - hits_->buffer[i] = 0; - } else { - memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, - row_bytes); - hits_->buffer[i] = 1; - } - } - - return true; + return true; } } // namespace nn diff --git a/nn/common/operations/HashtableLookup.h b/nn/common/operations/HashtableLookup.h index 52d9cc5bc..854e7dff7 100644 --- a/nn/common/operations/HashtableLookup.h +++ b/nn/common/operations/HashtableLookup.h @@ -27,26 +27,26 @@ namespace nn { struct RunTimeOperandInfo; class HashtableLookup { - public: - HashtableLookup(const hardware::neuralnetworks::V1_2::Operation& operation, - std::vector<RunTimeOperandInfo>& operands); + public: + HashtableLookup(const hardware::neuralnetworks::V1_2::Operation& operation, + std::vector<RunTimeOperandInfo>& operands); - bool Eval(); + bool Eval(); - static constexpr int kLookupTensor = 0; - 
static constexpr int kKeyTensor = 1; - static constexpr int kValueTensor = 2; + static constexpr int kLookupTensor = 0; + static constexpr int kKeyTensor = 1; + static constexpr int kValueTensor = 2; - static constexpr int kOutputTensor = 0; - static constexpr int kHitsTensor = 1; + static constexpr int kOutputTensor = 0; + static constexpr int kHitsTensor = 1; - private: - const RunTimeOperandInfo *lookup_; - const RunTimeOperandInfo *key_; - const RunTimeOperandInfo *value_; + private: + const RunTimeOperandInfo* lookup_; + const RunTimeOperandInfo* key_; + const RunTimeOperandInfo* value_; - RunTimeOperandInfo *output_; - RunTimeOperandInfo *hits_; + RunTimeOperandInfo* output_; + RunTimeOperandInfo* hits_; }; } // namespace nn diff --git a/nn/common/operations/HashtableLookupTest.cpp b/nn/common/operations/HashtableLookupTest.cpp index 7fbab58ce..ff62006b8 100644 --- a/nn/common/operations/HashtableLookupTest.cpp +++ b/nn/common/operations/HashtableLookupTest.cpp @@ -31,158 +31,160 @@ namespace wrapper { namespace { std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, - float max_abs_error=1.e-6) { - std::vector<Matcher<float>> matchers; - matchers.reserve(values.size()); - for (const float& v : values) { - matchers.emplace_back(FloatNear(v, max_abs_error)); - } - return matchers; + float max_abs_error = 1.e-6) { + std::vector<Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; } } // namespace using ::testing::ElementsAreArray; -#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ - ACTION(Lookup, int) \ - ACTION(Key, int) \ - ACTION(Value, float) +#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ + ACTION(Lookup, int) \ + ACTION(Key, int) \ + ACTION(Value, float) // For all output and intermediate states #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ - ACTION(Output, float) \ - ACTION(Hits, uint8_t) + ACTION(Output, float) \ + ACTION(Hits, uint8_t) class HashtableLookupOpModel { - public: + public: HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape, std::initializer_list<uint32_t> key_shape, std::initializer_list<uint32_t> value_shape) { - auto it_vs = value_shape.begin(); - rows_ = *it_vs++; - features_ = *it_vs; + auto it_vs = value_shape.begin(); + rows_ = *it_vs++; + features_ = *it_vs; - std::vector<uint32_t> inputs; + std::vector<uint32_t> inputs; - // Input and weights - OperandType LookupTy(Type::TENSOR_INT32, lookup_shape); - inputs.push_back(model_.addOperand(&LookupTy)); + // Input and weights + OperandType LookupTy(Type::TENSOR_INT32, lookup_shape); + inputs.push_back(model_.addOperand(&LookupTy)); - OperandType KeyTy(Type::TENSOR_INT32, key_shape); - inputs.push_back(model_.addOperand(&KeyTy)); + OperandType KeyTy(Type::TENSOR_INT32, key_shape); + inputs.push_back(model_.addOperand(&KeyTy)); - OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape); - inputs.push_back(model_.addOperand(&ValueTy)); + OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape); + inputs.push_back(model_.addOperand(&ValueTy)); - // Output and other intermediate state - std::vector<uint32_t> outputs; + // Output and other intermediate state + std::vector<uint32_t> outputs; - std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end()); - out_dim.push_back(features_); + std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end()); + out_dim.push_back(features_); - OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim); - 
outputs.push_back(model_.addOperand(&OutputOpndTy)); + OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim); + outputs.push_back(model_.addOperand(&OutputOpndTy)); - OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0); - outputs.push_back(model_.addOperand(&HitsOpndTy)); + OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0); + outputs.push_back(model_.addOperand(&HitsOpndTy)); - auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { - uint32_t sz = 1; - for (uint32_t d : dims) { sz *= d; } - return sz; - }; + auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { + uint32_t sz = 1; + for (uint32_t d : dims) { + sz *= d; + } + return sz; + }; - Value_.insert(Value_.end(), multiAll(value_shape), 0.f); - Output_.insert(Output_.end(), multiAll(out_dim), 0.f); - Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0); + Value_.insert(Value_.end(), multiAll(value_shape), 0.f); + Output_.insert(Output_.end(), multiAll(out_dim), 0.f); + Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0); - model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs); - model_.identifyInputsAndOutputs(inputs, outputs); + model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs); + model_.identifyInputsAndOutputs(inputs, outputs); - model_.finish(); - } + model_.finish(); + } - void Invoke() { - ASSERT_TRUE(model_.isValid()); + void Invoke() { + ASSERT_TRUE(model_.isValid()); - Compilation compilation(&model_); - compilation.finish(); - Execution execution(&compilation); + Compilation compilation(&model_); + compilation.finish(); + Execution execution(&compilation); -#define SetInputOrWeight(X, T) \ - ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \ - sizeof(T) * X##_.size()), \ - Result::NO_ERROR); +#define SetInputOrWeight(X, T) \ + ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \ + sizeof(T) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight -#define SetOutput(X, T) \ - ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \ - sizeof(T) * X##_.size()), \ - Result::NO_ERROR); +#define SetOutput(X, T) \ + ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \ + sizeof(T) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_OUTPUT_TENSORS(SetOutput); + FOR_ALL_OUTPUT_TENSORS(SetOutput); #undef SetOutput - ASSERT_EQ(execution.compute(), Result::NO_ERROR); - } + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + } -#define DefineSetter(X, T) \ - void Set##X(const std::vector<T>& f) { \ - X##_.insert(X##_.end(), f.begin(), f.end()); \ - } +#define DefineSetter(X, T) \ + void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); #undef DefineSetter - void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) { - for (uint32_t i = 0; i < rows_; i++) { - for (uint32_t j = 0; j < features_; j++) { - Value_[i * features_ + j] = function(i, j); - } + void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) { + for (uint32_t i = 0; i < rows_; i++) { + for (uint32_t j = 0; j < features_; j++) { + Value_[i * features_ + j] = function(i, j); + } + } } - } - const std::vector<float>& GetOutput() const { return Output_; } - const std::vector<uint8_t>& GetHits() 
const { return Hits_; } + const std::vector<float>& GetOutput() const { return Output_; } + const std::vector<uint8_t>& GetHits() const { return Hits_; } - private: - Model model_; - uint32_t rows_; - uint32_t features_; + private: + Model model_; + uint32_t rows_; + uint32_t features_; #define DefineTensor(X, T) std::vector<T> X##_; - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); - FOR_ALL_OUTPUT_TENSORS(DefineTensor); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); + FOR_ALL_OUTPUT_TENSORS(DefineTensor); #undef DefineTensor }; TEST(HashtableLookupOpTest, BlackBoxTest) { - HashtableLookupOpModel m({4}, {3}, {3, 2}); - - m.SetLookup({1234, -292, -11, 0}); - m.SetKey({-11, 0, 1234}); - m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); - - m.Invoke(); - - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 2.0, 2.1, // 2-rd item - 0, 0, // Not found - 0.0, 0.1, // 0-th item - 1.0, 1.1, // 1-st item - }))); - EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({ - 1, 0, 1, 1, + HashtableLookupOpModel m({4}, {3}, {3, 2}); + + m.SetLookup({1234, -292, -11, 0}); + m.SetKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-rd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({ + 1, + 0, + 1, + 1, })); - } } // namespace wrapper diff --git a/nn/common/operations/LSTMTest.cpp b/nn/common/operations/LSTMTest.cpp index edbf12841..35e5eded9 100644 --- a/nn/common/operations/LSTMTest.cpp +++ b/nn/common/operations/LSTMTest.cpp @@ -32,60 +32,61 @@ using ::testing::Matcher; namespace { std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, - float max_abs_error=1.e-6) { - std::vector<Matcher<float>> matchers; - matchers.reserve(values.size()); - for (const float& v : values) { - matchers.emplace_back(FloatNear(v, max_abs_error)); - } - return matchers; + float max_abs_error = 1.e-6) { + std::vector<Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; } } // anonymous namespace -#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ - ACTION(Input) \ - ACTION(InputToInputWeights) \ - ACTION(InputToCellWeights) \ - ACTION(InputToForgetWeights) \ - ACTION(InputToOutputWeights) \ - ACTION(RecurrentToInputWeights) \ - ACTION(RecurrentToCellWeights) \ - ACTION(RecurrentToForgetWeights) \ - ACTION(RecurrentToOutputWeights) \ - ACTION(CellToInputWeights) \ - ACTION(CellToForgetWeights) \ - ACTION(CellToOutputWeights) \ - ACTION(InputGateBias) \ - ACTION(CellGateBias) \ - ACTION(ForgetGateBias) \ - ACTION(OutputGateBias) \ - ACTION(ProjectionWeights) \ - ACTION(ProjectionBias) \ - ACTION(OutputStateIn) \ +#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ + ACTION(Input) \ + ACTION(InputToInputWeights) \ + ACTION(InputToCellWeights) \ + ACTION(InputToForgetWeights) \ + ACTION(InputToOutputWeights) \ + ACTION(RecurrentToInputWeights) \ + ACTION(RecurrentToCellWeights) \ + ACTION(RecurrentToForgetWeights) \ + ACTION(RecurrentToOutputWeights) \ + ACTION(CellToInputWeights) \ + ACTION(CellToForgetWeights) \ + ACTION(CellToOutputWeights) \ + ACTION(InputGateBias) \ + ACTION(CellGateBias) \ + ACTION(ForgetGateBias) \ + ACTION(OutputGateBias) \ + ACTION(ProjectionWeights) \ + ACTION(ProjectionBias) \ + ACTION(OutputStateIn) \ ACTION(CellStateIn) // For all 
output and intermediate states -#define FOR_ALL_OUTPUT_TENSORS(ACTION) \ - ACTION(ScratchBuffer) \ - ACTION(OutputStateOut) \ - ACTION(CellStateOut) \ - ACTION(Output) \ +#define FOR_ALL_OUTPUT_TENSORS(ACTION) \ + ACTION(ScratchBuffer) \ + ACTION(OutputStateOut) \ + ACTION(CellStateOut) \ + ACTION(Output) class LSTMOpModel { -public: - LSTMOpModel(uint32_t n_batch, uint32_t n_input, - uint32_t n_cell, uint32_t n_output, bool use_cifg, - bool use_peephole, bool use_projection_weights, + public: + LSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output, + bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector<std::vector<uint32_t>>& input_shapes0) : n_input_(n_input), n_output_(n_output), - use_cifg_(use_cifg), use_peephole_(use_peephole), + use_cifg_(use_cifg), + use_peephole_(use_peephole), use_projection_weights_(use_projection_weights), use_projection_bias_(use_projection_bias), activation_(ActivationFn::kActivationTanh), - cell_clip_(cell_clip), proj_clip_(proj_clip) { + cell_clip_(cell_clip), + proj_clip_(proj_clip) { std::vector<uint32_t> inputs; std::vector<std::vector<uint32_t>> input_shapes(input_shapes0); @@ -94,9 +95,9 @@ public: auto it = input_shapes.begin(); // Input and weights -#define AddInput(X) \ - OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \ - inputs.push_back(model_.addOperand(&X##OpndTy)); +#define AddInput(X) \ + OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \ + inputs.push_back(model_.addOperand(&X##OpndTy)); FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput); @@ -112,18 +113,18 @@ public: // Output and other intermediate state std::vector<std::vector<uint32_t>> output_shapes{ - {n_batch, n_cell * (use_cifg ? 3 : 4)}, - {n_batch, n_output}, - {n_batch, n_cell}, - {n_batch, n_output}, + {n_batch, n_cell * (use_cifg ? 
3 : 4)}, + {n_batch, n_output}, + {n_batch, n_cell}, + {n_batch, n_output}, }; std::vector<uint32_t> outputs; auto it2 = output_shapes.begin(); -#define AddOutput(X)\ - OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \ - outputs.push_back(model_.addOperand(&X##OpndTy)); +#define AddOutput(X) \ + OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \ + outputs.push_back(model_.addOperand(&X##OpndTy)); FOR_ALL_OUTPUT_TENSORS(AddOutput); @@ -136,9 +137,11 @@ public: OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f); CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f); - auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { + auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { uint32_t sz = 1; - for(uint32_t d:dims) { sz *= d; } + for (uint32_t d : dims) { + sz *= d; + } return sz; }; @@ -153,10 +156,8 @@ public: model_.finish(); } -#define DefineSetter(X) \ - void Set##X(const std::vector<float> &f) { \ - X##_.insert(X##_.end(), f.begin(), f.end()); \ - } +#define DefineSetter(X) \ + void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); @@ -172,8 +173,8 @@ public: std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f); } - void SetInput(int offset, float *begin, float *end) { - for (;begin != end; begin++, offset++) { + void SetInput(int offset, float* begin, float* end) { + for (; begin != end; begin++, offset++) { Input_[offset] = *begin; } } @@ -181,7 +182,7 @@ public: uint32_t num_inputs() const { return n_input_; } uint32_t num_outputs() const { return n_output_; } - const std::vector<float> &GetOutput() const { return Output_; } + const std::vector<float>& GetOutput() const { return Output_; } void Invoke() { ASSERT_TRUE(model_.isValid()); @@ -192,19 +193,19 @@ public: Compilation compilation(&model_); compilation.finish(); Execution execution(&compilation); -#define SetInputOrWeight(X) \ - ASSERT_EQ(execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), \ - sizeof(float)*X##_.size()), \ - Result::NO_ERROR); +#define SetInputOrWeight(X) \ + ASSERT_EQ( \ + execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight -#define SetOutput(X) \ - ASSERT_EQ(execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), \ - sizeof(float)*X##_.size()), \ - Result::NO_ERROR); +#define SetOutput(X) \ + ASSERT_EQ( \ + execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); FOR_ALL_OUTPUT_TENSORS(SetOutput); @@ -234,20 +235,17 @@ public: execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0); } - ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, - &activation_, sizeof(activation_)), + ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)), Result::NO_ERROR); - ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, - &cell_clip_, sizeof(cell_clip_)), + ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)), Result::NO_ERROR); - ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, - &proj_clip_, sizeof(proj_clip_)), + ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)), Result::NO_ERROR); ASSERT_EQ(execution.compute(), Result::NO_ERROR); } -private: + private: Model model_; // Execution execution_; const uint32_t n_input_; @@ -262,8 +260,7 @@ private: const 
float cell_clip_; const float proj_clip_; -#define DefineTensor(X) \ - std::vector<float> X##_; +#define DefineTensor(X) std::vector<float> X##_; FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); FOR_ALL_OUTPUT_TENSORS(DefineTensor); @@ -272,843 +269,741 @@ private: }; TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { - const int n_batch = 1; - const int n_input = 2; - // n_cell and n_output have the same size when there is no projection. - const int n_cell = 4; - const int n_output = 4; - - LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/false, - /*use_projection_weights=*/false, - /*use_projection_bias=*/false, - /*cell_clip=*/0.0, /*proj_clip=*/0.0, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {0}, // cell_to_input_weight tensor - {0}, // cell_to_forget_weight tensor - {0}, // cell_to_output_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {0, 0}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); - - lstm.SetInputGateBias({0., 0., 0., 0.}); - - lstm.SetCellGateBias({0., 0., 0., 0.}); - - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); - - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); - - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); - - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; - - // Resetting cell_state 
and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); - - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, + -0.15680569, -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, -0.20583314, + 0.44344562, 0.22077113, -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, -0.31343272, -0.40032279, + 0.44781327, 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 0.40525138, 0.44272184, + 0.03897077, -0.1556896, 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellGateBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights({-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, + 0.08183324, -0.16555229, 0.02286911, -0.13566875, 0.03034258, + 0.48091322, -0.12528998, 0.24077177, -0.51332325, -0.33502164, + 0.10629296}); + + lstm.SetRecurrentToCellWeights({-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, + -0.00123841, -0.4744786, -0.35869038, -0.06418842, -0.13502428, + -0.501764, 0.22830659, -0.46367589, 0.26016325, -0.03894562, + -0.16368064}); + + lstm.SetRecurrentToForgetWeights({-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, + 0.20864892, -0.07646349, 0.45877004, 0.00141793, -0.14609534, + 0.36447752, 0.09196436, 0.28053468, 0.01560611, -0.20127171, + -0.01140004}); + + lstm.SetRecurrentToOutputWeights({0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, + -0.39835793, 0.18212086, 0.01301402, 0.48572797, -0.50656658, + 0.20047462, 
-0.20607421, -0.51818722, -0.15390486, 0.0468148, + 0.39922136}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } - TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { - const int n_batch = 1; - const int n_input = 2; - // n_cell and n_output have the same size when there is no projection. - const int n_cell = 4; - const int n_output = 4; - - LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, - /*use_cifg=*/true, /*use_peephole=*/true, - /*use_projection_weights=*/false, - /*use_projection_bias=*/false, - /*cell_clip=*/0.0, /*proj_clip=*/0.0, - { - {n_batch, n_input}, // input tensor - - {0, 0}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {0, 0}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {0}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {0, 0}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellGateBias({0., 0., 0., 0.}); - - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); - - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); - - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 
0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); - - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); - - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; - - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); - - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, + 0.48944736, -0.38535351, -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, -0.3633365, -0.22755712, + 0.28253698, 0.24407166, 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, -0.09426838, -0.44257352, + 0.54939759, 0.01533556, 0.42751634}); + + lstm.SetCellGateBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights({0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights({-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, + 0.17751795, -0.34484994, -0.35874045, -0.11352962, 0.27268326, + 0.54058349}); + + lstm.SetRecurrentToOutputWeights({0.41613156, 0.42610586, -0.16495961, -0.5663873, 
0.30579174, + -0.05115908, -0.33941799, 0.23364776, 0.11178309, 0.09481031, + -0.26424935, 0.46261835, 0.50248802, 0.26114327, -0.43736315, + 0.33149987}); + + lstm.SetCellToForgetWeights({0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights({-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 20; - const int n_output = 16; - - LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, - /*cell_clip=*/0.0, /*proj_clip=*/0.0, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {n_cell}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, 
-0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 
0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellGateBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - 
-0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, 
-0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 
0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 
0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 
0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, 
-0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 
0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; - - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); - - lstm.Invoke(); - - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector<float> expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const 
int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, 0.09171803, + 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, -0.2726754, 0.10154029, + -0.018539885, 0.080349885, -0.10262385, -0.022599787, -0.09121155, -0.008675967, + -0.045206103, -0.0821282, -0.008045952, 0.015478081, 0.055217247, 0.038719587, + 0.044153627, -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, 0.25005487, + -0.22790983, 0.009855087, -0.028140958, -0.11200698, 0.11295408, -0.0035217577, + 0.054485075, 0.05184695, 0.064711206, 0.10989193, 0.11674786, 0.03490607, + 0.07727357, 0.11390585, -0.1863375, -0.1034451, -0.13945189, -0.049401227, + -0.18767063, 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, -0.042484224, + -0.11827596, -0.09171104, -0.10808628, -0.16327988, -0.2273378, -0.0993647, + -0.017155107, 0.0023917493, 0.049272764, 0.0038534778, 0.054764505, 0.089753784, + 0.06947234, 0.08014476, -0.04544234, -0.0497073, -0.07135631, -0.048929106, + -0.004042012, -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, -0.39292613, + -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, -0.016726194, + -0.05249759, -0.10204261, 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, -0.01909731, -0.017883534, + -0.0047269356, -0.045103323, 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, 0.06958324, 0.034257296, + 0.0482646, 0.06267997, 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, -0.06277783, -0.08833636, + -0.040120605, -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, -0.08402166, + -0.01901462, -0.044678304, -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, 0.052960336, 0.035755895, + 0.035839386, 
-0.012540553, 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, -0.043528453, 0.043018587, + -0.049152344, -0.12418144, -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, -0.025034338, -0.0028890965, + 0.048929527, 0.06235075, 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, 0.0047691227, -0.0025825808, + 0.066017866, 0.029991534, -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, -0.025174323, 0.0396852, + 0.081777506, 0.06157468, 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, 0.053568836, 0.06408714, + 0.12835667, -0.008714329, -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, -0.036999565, -0.028842626, + -0.0033637602, -0.017012902, -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, -0.021742892, -0.023377212, + -0.07221364, -0.06430552, 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, 0.07463894, 0.0075130584, + 0.012850982, 0.04555431, 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, -0.07650751, + 0.02359855, -0.075155355, -0.08037709, -0.15093534, 0.029517552, -0.04751393, + 0.010350531, -0.02664851, -0.016839722, -0.023121163, 0.0077019283, 0.012851257, + -0.05040649, -0.0129761, -0.021737747, -0.038305793, -0.06870586, -0.01481247, + -0.001285394, 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, -0.086665764, + -0.037162706, -0.038880914, -0.035832845, -0.014481564, -0.09825003, -0.12048569, + -0.097665586, -0.05287633, -0.0964047, -0.11366429, 0.035777505, 0.13568819, + 0.052451383, 0.050649304, 0.05798951, -0.021852335, -0.099848844, 0.014740475, + -0.078897946, 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, -0.078907564, + -0.06707616, -0.11844508, -0.09986688, -0.07509403, 0.06263226, 0.14925587, + 0.20188436, 0.12098451, 0.14639415, 0.0015017595, -0.014267382, -0.03417257, + 0.012711468, 0.0028300495, -0.024758482, -0.05098548, -0.0821182, 0.014225672, + 0.021544158, 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, -0.049097303, + -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias({0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, 0.11098921, + 0.15378423, 0.09263801, 0.09790885, 0.09508917, 0.061199076, + 0.07665568, -0.015443159, -0.03499149, 0.046190713, 
0.08895977, + 0.10899629, 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellGateBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, -0.1483596, + -0.10639995, -0.091433935, 0.058573797, -0.06809782, -0.07889636, + -0.043246906, -0.09829136, -0.4279842, 0.034901652, 0.18797937, + 0.0075234566, 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias({0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, -0.11585556, + 0.02557986, -0.13446963, -0.035785314, -0.01244275, 0.025961924, + -0.02337298, -0.044228926, -0.055839065, -0.046598054, -0.010546039, + -0.06900766, 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, -0.14390695, + -0.02916037, 0.000996957, 0.091420636, 0.14283475, -0.07390571, + -0.06402044, 0.062524505, -0.093129106, 0.04860203, -0.08364217, + -0.08119002, 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, 0.12162708, + -0.031923793, -0.014335606, 0.01790974, -0.10650317, -0.0724401, + 0.08554849, -0.05727212, 0.06556731, -0.042729504, -0.043227166, + 0.011683251, -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, -0.07787477, + -0.11576462, 0.017356863, 0.048673786, -0.017577527, -0.05527947, + -0.082487635, -0.040137455, -0.10820036, -0.04666372, 0.022746278, + -0.07851417, 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, 0.02032331, + -0.059686817, -0.0005566496, -0.086984694, 0.040414046, -0.1380399, + 0.094208956, -0.05722982, 0.012092817, -0.04989123, -0.086576, + -0.003399834, -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, 0.09504992, + 0.041799378, -0.049185462, -0.031518843, -0.10516937, 0.026374253, + 0.10058866, -0.0033195973, -0.041975245, 0.0073591834, 0.0033782164, + -0.004325073, -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, -0.011337001, + 0.035530265, -0.010912711, 0.0706555, -0.005894094, 0.051841937, + -0.1401738, -0.02351249, 0.0365468, 0.07590991, 0.08838724, + 0.021681072, -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, -0.048691444, + -0.009579111, 0.07595467, 0.11480546, -0.09801813, 0.019894179, + 0.08502348, 0.004032281, 0.037211012, 0.068537936, -0.048005626, + -0.091520436, -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, 0.010889619, + 0.0047078193, 0.038385306, 0.08540671, -0.017140968, -0.0035865551, + 0.016678626, 0.005633034, 0.015963363, 0.00871737, 0.060130805, + 0.028611384, 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, 0.030737216, + -0.0046374933, 0.14215417, -0.11823516, 0.019899689, 0.006106124, + -0.027092824, 0.0786356, 0.05052217, -0.058925, -0.011402121, + -0.024987547, -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, -0.033664223, + -0.07978348, -0.025200296, -0.017207067, -0.058403496, -0.055697463, + 0.005798788, 
0.12965427, -0.062582195, 0.0013350133, -0.10482091, + 0.0379771, 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, -0.017081132, + 0.019358726, 0.0027079724, 0.004635139, 0.062634714, -0.02338735, + -0.039547626, -0.02050681, 0.03385117, -0.083611414, 0.002862572, + -0.09421313, 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, -0.07314807, + -0.056307215, -0.10433547, -0.06440842, 0.04328182, 0.04389765, + -0.020006588, -0.09076438, -0.11652589, -0.021705797, 0.03345259, + -0.010329105, -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, 0.1305557, + 0.058638252, -0.03393652, 0.09622831, -0.16253184, -2.4580743e-06, + 0.079869635, -0.070196845, -0.005644518, 0.06857898, -0.12598175, + -0.035084512, 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, 0.040170413, + -0.062104587, -0.0037324072, 0.0554317, 0.08184801, -0.019164372, + 0.06791302, 0.034257166, -0.10307039, 0.021943003, 0.046745934, + 0.0790918, -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, -0.014512694, + -0.08251313, 0.08861942, 0.13589665, 0.026351685, 0.012641483, + 0.07466548, 0.044301085, -0.045414884, -0.051112458, 0.03444247, + -0.08502782, -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, 0.14811787, + 0.10826372, 0.09471067, 0.03987225, -0.0039523416, 0.00030638507, + 0.053185795, 0.10572994, 0.08414449, -0.022036452, -0.00066928595, + -0.09203576, 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, -0.06193199, + 0.055729095, 0.03736828, 0.020123724, 0.061878487, -0.04729229, + 0.034919553, -0.07585433, -0.04421272, -0.044019096, 0.085488975, + 0.04058006, -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, -0.04899627, + 0.0241671, 0.015736353, -0.095442444, -0.029564252, 0.016493602, + -0.035026584, 0.022337519, -0.026871363, 0.004780428, 0.0077918363, + -0.03601621, 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, -0.046250615, + -0.01847454, 0.047608484, 0.07339695, 0.034546845, -0.04881143, + 0.009128804, -0.08802852, 0.03761666, 0.008096139, -0.014454086, + 0.014361001, -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, 0.060212336, + 0.055259194, 0.06974018, 0.049454916, -0.027794661, -0.08077226, + -0.016179763, 0.1169753, 0.17213494, -0.0056326236, -0.053934924, + -0.0124349, -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, -0.05695512, + 0.047233116, 0.038937137, -0.06542224, 0.014429736, -0.09719407, + 0.13908425, -0.05379757, 0.012321099, 0.082840554, -0.029899208, + 0.044217527, 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, -0.13873616, + 0.040668588, 0.034832682, -0.015319203, -0.018715994, 0.046002675, + 0.0599172, -0.043107376, 0.0294216, -0.002314414, -0.022424703, + 0.0030315618, 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, 0.052958444, + 0.07558703, 0.04817258, 0.044462286, -0.015213451, -0.08783778, + 
-0.0561384, -0.003008196, 0.047060397, -0.002058388, 0.03429439, + -0.018839769, 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, -0.02558259, + -0.022822596, -0.023273505, -0.02464396, -0.10991725, -0.006240552, + 0.0074488563, 0.024044557, 0.04383914, -0.046476185, 0.028658995, + 0.060410924, 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, 0.015898481, + 0.021362653, -0.030262267, 0.016587038, -0.011442813, 0.041154444, + -0.007631438, -0.03423484, -0.010977775, 0.036152758, 0.0066366293, + 0.11915515, 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, -0.03672871, + 0.024019798, 0.014255957, -0.05221243, -0.00661567, -0.04630967, + 0.033188973, 0.10107534, -0.014027541, 0.030796422, -0.10270911, + -0.035999842, 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, -0.03858649, + 0.01849943, 0.13872518, 0.01503974, 0.069941424, -0.06948533, + -0.0088794185, 0.061282158, -0.047401894, 0.03100163, -0.041533746, + -0.10430945, 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, 0.02247216, + -0.0042998926, 0.061146557, -0.10250651, 0.020881841, -0.06747029, + 0.10062043, -0.0023941975, 0.03532124, -0.016341697, 0.09685456, + -0.016764693, 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, 0.010987891, + 0.02357273, -0.00010856845, 0.0099760275, -0.001845119, -0.03551521, + 0.0018358806, 0.05763657, -0.01769146, 0.040995963, 0.02235177, + -0.060430344, 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, -0.030519066, + 0.0060542435, 0.014653856, -0.038836084, 0.04096551, 0.032249358, + -0.08355519, -0.026823482, 0.056386515, -0.010401743, -0.028396193, + 0.08507674, 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, -0.081302024, + 0.017264642, -0.009585969, 0.09491168, -0.051313367, 0.054532815, + -0.014298593, 0.10657464, 0.007076659, 0.10964551, 0.0409152, + 0.008275321, -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, 0.055647098, + -0.05713207, -0.05626563, 0.005559383, 0.03375411, -0.025757805, + -0.088049285, 0.06017052, -0.06570978, 0.007384076, 0.035123326, + -0.07920549, 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, -0.027673481, + 0.044746667, 0.028349208, 0.020090483, -0.019443132, -0.030755889, + -0.0040000007, 0.04465846, -0.021585021, 0.0031670958, 0.0053199246, + -0.056117613, -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, -0.046195522, + 0.0422062, -0.005683705, -0.1253618, -0.012925729, -0.04890792, + 0.06985068, 0.037654128, 0.03398274, -0.004781977, 0.007032333, + -0.031787455, 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, -0.048885044, + -0.12722108, 0.035304096, 0.06554885, 0.00972396, -0.039238118, + -0.05159735, -0.11329045, 0.1613692, -0.03750952, 0.06529313, + -0.071974665, -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, 0.015983658, + 0.03406988, -0.06939408, 0.040699873, 0.02111075, 0.09669095, + 
0.041345075, -0.08316494, -0.07684199, -0.045768797, 0.032298047, + -0.041805092, 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, 0.06760663, + -0.027437469, 0.07216407, 0.06977076, -0.05438599, 0.034033038, + -0.028602652, 0.05346137, 0.043184172, -0.037189785, 0.10420091, + 0.00882477, -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, 0.04361412, + -0.007001822, 0.09631092, -0.06702025, -0.042049985, -0.035070654, + -0.04103342, -0.10273396, 0.0544271, 0.037184782, -0.13150354, + -0.0058036847, -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, -0.00093483756, + 0.048938446, -0.004952862, -0.007730018, -0.04043371, -0.017094059, + 0.07229206, -0.023670016, -0.052195564, -0.025616996, -0.01520939, + 0.045104615, -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, 0.11868599, + 0.07663395, -0.07323171, 0.03463402, -0.050708205, -0.04458982, + -0.11590894, 0.021273347, 0.1251325, -0.15313013, -0.12224372, + 0.17228661, 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, -0.006041347, + -0.09292539, -0.05678812, 0.03897832, 0.09465633, 0.008115513, + -0.02171956, 0.08304309, 0.071401566, 0.019622514, 0.032163795, + -0.004167056, 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, -0.1335546, + -0.030136576, 0.11584653, -0.014678886, 0.0020118146, -0.09688814, + -0.0790206, 0.039770417, -0.0329582, 0.07922767, 0.029322514, + 0.026405897, 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, -0.09606514, + -0.006669445, -0.043381985, 0.04240257, -0.06955775, -0.06769346, + 0.043903265, -0.026784198, -0.017840602, 0.024307009, -0.040079936, + -0.019946516, 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, -0.039390966, + -0.072174534, 0.0739445, -0.1211869, -0.0347889, -0.07943156, + 0.014809798, -0.12412325, -0.0030663363, 0.039695457, 0.0647603, + -0.08291318, -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, 0.012962922, + -0.031234352, 0.07029052, 0.016418684, 0.03618972, 0.055686004, + -0.08663945, -0.017404709, -0.054761406, 0.029065743, 0.052404847, + 0.020238016, 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, -0.0270329, + -0.03299152, -0.060088247, -0.015162964, -0.001828936, 0.12642565, + -0.056757294, 0.013586685, 0.09232601, -0.035886683, 0.06000002, + 0.05229691, -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, 0.031502828, + 0.036232427, -0.031581745, 0.023051167, -0.05325106, -0.03421577, + 0.028793324, -0.034633752, -0.009881397, -0.043551125, -0.018609839, + 0.0019097115, -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, 
-0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, 
-0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights({0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights({-0.01998659, -0.15568835, -0.24248174, -0.012770197, + 0.041331276, -0.072311886, -0.052123554, -0.0066330447, + -0.043891653, 0.036225766, -0.047248036, 0.021479502, + 0.033189066, 0.11952997, -0.020432774, 0.64658105, + -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights({0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, 0.060420845, + 0.08539281, 0.054285463, 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, -0.10959692, -0.013875541, + 0.18600968, -0.061274476, 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, -0.188994, -0.027265169, + 0.043389652, -0.022061434, 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, 0.13221376, 0.009441198, + -0.16715929, 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, 0.050794356, + 0.10770313, -0.20790008, -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, -0.007487823, 0.11000021, + -0.09542012, 0.22619456, -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, -0.004618631, -0.067624845, + -0.007576253, 0.10727444, 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, 0.13708383, 0.09134148, + -0.12617786, -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, -0.042050496, + 0.16842307, -0.060597885, 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, 0.05970546, 0.05197996, + 0.02839995, 0.10434969, -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, 0.07830115, -0.056184657, + 0.03277091, 0.025466874, 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 
0.07749552, 0.039474737, 0.1776665, -0.07409566, + -0.0477268, 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, 0.045594677, + 0.0635285, -0.0715442, -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, 0.09648481, 0.026996298, + -0.06419476, -0.04752702, -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, 0.10539724, -0.04383912, + -0.042349473, 0.08438151, -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, 0.12633002, 0.11335965, + -0.0088127935, -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, -0.0855457, + 0.099339016, -0.07580735, -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, 0.008753825, 0.03498998, + 0.030215185, 0.03907079, 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, -0.07051762, 0.19488361, + 0.02664691, -0.18104725, -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, 0.11156489, -0.07483202, + 0.06693364, -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, 0.010969227, + 0.11109743, 0.010919218, 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, -0.07000971, 0.026893595, + -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, -0.0211779, + 0.0283512, -0.0114597, 0.00907307, -0.0244004, -0.0152191, -0.0259063, + 0.00914318, 0.00415118, 0.017147, 0.0134203, -0.0166936, 0.0381209, + 0.000889694, 0.0143363, -0.0328911, -0.0234288, 0.0333051, -0.012229, + 0.0110322, -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, 0.0276012, + -0.0263374, -0.0371449, 0.0446149, -0.0205474, 0.0103729, -0.0576349, + -0.0150052, -0.0292043, 0.0376827, 0.0136115, 0.0243435, 0.0354492, + -0.0189322, 0.0464512, -0.00251373, 0.0225745, -0.0308346, -0.0317124, + 0.0460407, -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, -0.0186926, + 0.0193662, -0.0115437, 
0.00422612, -0.0345232, 0.00223253, -0.00957321, + 0.0210624, 0.013331, 0.0150954, 0.02168, -0.0141913, 0.0322082, + 0.00227024, 0.0260507, -0.0188721, -0.0296489, 0.0399134, -0.0160509, + 0.0116039, -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, 0.0167673, + -0.0375007, -0.0238314, 0.038784, -0.0174034, 0.0131743, -0.0506589, + -0.0048447, -0.0240239, 0.0325789, 0.00790065, 0.0220157, 0.0333314, + -0.0264787, 0.0387855, -0.000764675, 0.0217599, -0.037537, -0.0335206, + 0.0431679, -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + + lstm.Invoke(); + + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + std::vector<float> expected; + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } - } // namespace wrapper } // namespace nn } // namespace android diff --git a/nn/common/operations/RNN.cpp b/nn/common/operations/RNN.cpp index dcb5928fa..36473042e 100644 --- a/nn/common/operations/RNN.cpp +++ b/nn/common/operations/RNN.cpp @@ -29,61 +29,55 @@ namespace nn { using namespace hal; -RNN::RNN(const Operation& operation, - std::vector<RunTimeOperandInfo>& operands) { - NNTRACE_TRANS("RNN::RNN"); - input_ = GetInput(operation, operands, kInputTensor); - weights_ = GetInput(operation, operands, kWeightsTensor); - recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor); - hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor); - bias_ = GetInput(operation, operands, kBiasTensor); - - activation_ = static_cast<ActivationFn>( - getScalarData<int32_t>(operands[operation.inputs[kActivationParam]])); - - hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor); - output_ = GetOutput(operation, operands, kOutputTensor); +RNN::RNN(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) { + NNTRACE_TRANS("RNN::RNN"); + input_ = GetInput(operation, operands, kInputTensor); + weights_ = GetInput(operation, operands, kWeightsTensor); + recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor); + hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor); + bias_ = GetInput(operation, operands, kBiasTensor); + + activation_ = static_cast<ActivationFn>( + getScalarData<int32_t>(operands[operation.inputs[kActivationParam]])); + + hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor); + output_ = GetOutput(operation, operands, kOutputTensor); } -bool RNN::Prepare(const 
Operation &operation, - std::vector<RunTimeOperandInfo> &operands, - Shape *hiddenStateShape, - Shape *outputShape) { - NNTRACE_TRANS("RNN::Prepare"); - // Check we have all the inputs and outputs we need. - const int num_inputs = NumInputsWithValues(operation, operands); - NN_CHECK(num_inputs == 5 || num_inputs == 6); - NN_CHECK_EQ(NumOutputs(operation), 2); - - const RunTimeOperandInfo *input = - GetInput(operation, operands, kInputTensor); - const RunTimeOperandInfo *input_weights = - GetInput(operation, operands, kWeightsTensor); - const RunTimeOperandInfo *recurrent_weights = - GetInput(operation, operands, kRecurrentWeightsTensor); - const RunTimeOperandInfo *bias = - GetInput(operation, operands, kBiasTensor); - - // Check all the parameters of tensor match within themselves and match the - // input configuration. - const uint32_t batch_size = SizeOfDimension(input, 0); - const uint32_t num_units = SizeOfDimension(input_weights, 0); - NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1)); - NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0)); - NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0)); - NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0)); - - const Shape &inputShape = input->shape(); - - // Resize state. - hiddenStateShape->type = inputShape.type; - hiddenStateShape->dimensions = { batch_size, num_units }; - - // Resize output. - outputShape->type = inputShape.type; - outputShape->dimensions = { batch_size, num_units }; - - return true; +bool RNN::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands, + Shape* hiddenStateShape, Shape* outputShape) { + NNTRACE_TRANS("RNN::Prepare"); + // Check we have all the inputs and outputs we need. + const int num_inputs = NumInputsWithValues(operation, operands); + NN_CHECK(num_inputs == 5 || num_inputs == 6); + NN_CHECK_EQ(NumOutputs(operation), 2); + + const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor); + const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor); + const RunTimeOperandInfo* recurrent_weights = + GetInput(operation, operands, kRecurrentWeightsTensor); + const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor); + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const uint32_t batch_size = SizeOfDimension(input, 0); + const uint32_t num_units = SizeOfDimension(input_weights, 0); + NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1)); + NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0)); + NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0)); + NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0)); + + const Shape& inputShape = input->shape(); + + // Resize state. + hiddenStateShape->type = inputShape.type; + hiddenStateShape->dimensions = {batch_size, num_units}; + + // Resize output. 
+ outputShape->type = inputShape.type; + outputShape->dimensions = {batch_size, num_units}; + + return true; } bool RNN::Eval() { diff --git a/nn/common/operations/RNNTest.cpp b/nn/common/operations/RNNTest.cpp index 332885f95..66acac7cb 100644 --- a/nn/common/operations/RNNTest.cpp +++ b/nn/common/operations/RNNTest.cpp @@ -33,304 +33,271 @@ namespace { std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, float max_abs_error = 1.e-5) { - std::vector<Matcher<float>> matchers; - matchers.reserve(values.size()); - for (const float& v : values) { - matchers.emplace_back(FloatNear(v, max_abs_error)); - } - return matchers; + std::vector<Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; } static float rnn_input[] = { - 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, - 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, - -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, - 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, - 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, - 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, - -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, - -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, - 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, - 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, - 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, - -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, - 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, - -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, - -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, - -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, - 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, - -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, - -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, - 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, - -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, - 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, - 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, - 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, - -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, - 0.93455386, -0.6324693, -0.083922029}; + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 0.43773448, + 0.60379338, 0.35562468, -0.69424844, -0.93421471, -0.87287879, 0.37144363, + -0.62476718, 0.23791671, 0.40060222, 0.1356622, -0.99774903, -0.98858172, + -0.38952237, -0.47685933, 0.31073618, 0.71511042, -0.63767755, -0.31729108, + 0.33468103, 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, -0.61777675, + -0.21095741, 0.41213346, 0.73784804, 0.094794154, 0.47791874, 0.86496925, + -0.53376222, 0.85315156, 0.10288584, 0.86684, -0.011186242, 0.10513687, + 0.87825835, 0.59929144, 0.62827742, 0.18899453, 0.31440187, 0.99059987, + 0.87170351, -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, -0.66609079, + 0.59098077, 0.73017097, 0.74604273, 0.32882881, -0.17503482, 0.22396147, + 0.19379807, 0.29120302, 0.077113032, -0.70331609, 0.15804303, -0.93407321, + 0.40182066, 0.036301374, 0.66521823, 0.0300982, 
-0.7747041, -0.02038002, + 0.020698071, -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 0.43519354, + 0.14744234, 0.62589407, 0.1653645, -0.10651493, -0.045277178, 0.99032974, + -0.88255352, -0.85147917, 0.28153265, 0.19455957, -0.55479527, -0.56042433, + 0.26048636, 0.84702539, 0.47587705, -0.074295521, -0.12287641, 0.70117295, + 0.90532446, 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 0.93455386, + -0.6324693, -0.083922029}; static float rnn_golden_output[] = { - 0.496726, 0, 0.965996, 0, 0.0584254, 0, - 0, 0.12315, 0, 0, 0.612266, 0.456601, - 0, 0.52286, 1.16099, 0.0291232, + 0.496726, 0, 0.965996, 0, 0.0584254, 0, 0, 0.12315, + 0, 0, 0.612266, 0.456601, 0, 0.52286, 1.16099, 0.0291232, - 0, 0, 0.524901, 0, 0, 0, - 0, 1.02116, 0, 1.35762, 0, 0.356909, - 0.436415, 0.0355727, 0, 0, + 0, 0, 0.524901, 0, 0, 0, 0, 1.02116, + 0, 1.35762, 0, 0.356909, 0.436415, 0.0355727, 0, 0, - 0, 0, 0, 0.262335, 0, 0, - 0, 1.33992, 0, 2.9739, 0, 0, - 1.31914, 2.66147, 0, 0, + 0, 0, 0, 0.262335, 0, 0, 0, 1.33992, + 0, 2.9739, 0, 0, 1.31914, 2.66147, 0, 0, - 0.942568, 0, 0, 0, 0.025507, 0, - 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, - 0.8158, 1.21805, 0.586239, 0.25427, + 0.942568, 0, 0, 0, 0.025507, 0, 0, 0, + 0.321429, 0.569141, 1.25274, 1.57719, 0.8158, 1.21805, 0.586239, 0.25427, - 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, - 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, - 0, 1.22031, 1.30117, 0.495867, + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 0.363026, 0, + 0.533426, 0, 1.25926, 0.722707, 0, 1.22031, 1.30117, 0.495867, - 0.222187, 0, 0.72725, 0, 0.767003, 0, - 0, 0.147835, 0, 0, 0, 0.608758, - 0.469394, 0.00720298, 0.927537, 0, + 0.222187, 0, 0.72725, 0, 0.767003, 0, 0, 0.147835, + 0, 0, 0, 0.608758, 0.469394, 0.00720298, 0.927537, 0, - 0.856974, 0.424257, 0, 0, 0.937329, 0, - 0, 0, 0.476425, 0, 0.566017, 0.418462, - 0.141911, 0.996214, 1.13063, 0, + 0.856974, 0.424257, 0, 0, 0.937329, 0, 0, 0, + 0.476425, 0, 0.566017, 0.418462, 0.141911, 0.996214, 1.13063, 0, - 0.967899, 0, 0, 0, 0.0831304, 0, - 0, 1.00378, 0, 0, 0, 1.44818, - 1.01768, 0.943891, 0.502745, 0, + 0.967899, 0, 0, 0, 0.0831304, 0, 0, 1.00378, + 0, 0, 0, 1.44818, 1.01768, 0.943891, 0.502745, 0, - 0.940135, 0, 0, 0, 0, 0, - 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, - 1.30225, 1.59644, 0.70222, 0, + 0.940135, 0, 0, 0, 0, 0, 0, 2.13243, + 0, 0.71208, 0.123918, 1.53907, 1.30225, 1.59644, 0.70222, 0, - 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, - 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, - 0.0454298, 0.300267, 0.562784, 0.395095, + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 0.343448, 0, + 0.107756, 0.614544, 1.44549, 1.52311, 0.0454298, 0.300267, 0.562784, 0.395095, - 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, - 0, 0, 0, 0.735363, 0.0759267, 1.91017, - 0.941888, 0, 0, 0, + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 0, 0, + 0, 0.735363, 0.0759267, 1.91017, 0.941888, 0, 0, 0, - 0, 0, 1.5909, 0, 0, 0, - 0, 0.5755, 0, 0.184687, 0, 1.56296, - 0.625285, 0, 0, 0, + 0, 0, 1.5909, 0, 0, 0, 0, 0.5755, + 0, 0.184687, 0, 1.56296, 0.625285, 0, 0, 0, - 0, 0, 0.0857888, 0, 0, 0, - 0, 0.488383, 0.252786, 0, 0, 0, - 1.02817, 1.85665, 0, 0, + 0, 0, 0.0857888, 0, 0, 0, 0, 0.488383, + 0.252786, 0, 0, 0, 1.02817, 1.85665, 0, 0, - 0.00981836, 0, 1.06371, 0, 0, 0, - 0, 0, 0, 0.290445, 0.316406, 0, - 0.304161, 1.25079, 0.0707152, 0, + 0.00981836, 0, 1.06371, 0, 0, 0, 0, 0, + 0, 
0.290445, 0.316406, 0, 0.304161, 1.25079, 0.0707152, 0, - 0.986264, 0.309201, 0, 0, 0, 0, - 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, - 0.524981, 1.92076, 2.07013, 0.333244, + 0.986264, 0.309201, 0, 0, 0, 0, 0, 1.64896, + 0.346248, 0, 0.918175, 0.78884, 0.524981, 1.92076, 2.07013, 0.333244, - 0.415153, 0.210318, 0, 0, 0, 0, - 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0}; + 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, + 0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0}; } // anonymous namespace #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ - ACTION(Input) \ - ACTION(Weights) \ - ACTION(RecurrentWeights) \ - ACTION(Bias) \ - ACTION(HiddenStateIn) + ACTION(Input) \ + ACTION(Weights) \ + ACTION(RecurrentWeights) \ + ACTION(Bias) \ + ACTION(HiddenStateIn) // For all output and intermediate states #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ - ACTION(HiddenStateOut) \ - ACTION(Output) + ACTION(HiddenStateOut) \ + ACTION(Output) class BasicRNNOpModel { - public: - BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size) - : batches_(batches), - units_(units), - input_size_(size), - activation_(kActivationRelu) { - std::vector<uint32_t> inputs; - - OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_}); - inputs.push_back(model_.addOperand(&InputTy)); - OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_}); - inputs.push_back(model_.addOperand(&WeightTy)); - OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_}); - inputs.push_back(model_.addOperand(&RecurrentWeightTy)); - OperandType BiasTy(Type::TENSOR_FLOAT32, {units_}); - inputs.push_back(model_.addOperand(&BiasTy)); - OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_}); - inputs.push_back(model_.addOperand(&HiddenStateTy)); - OperandType ActionParamTy(Type::INT32, {}); - inputs.push_back(model_.addOperand(&ActionParamTy)); - - std::vector<uint32_t> outputs; - - outputs.push_back(model_.addOperand(&HiddenStateTy)); - OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_}); - outputs.push_back(model_.addOperand(&OutputTy)); - - Input_.insert(Input_.end(), batches_ * input_size_, 0.f); - HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f); - HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f); - Output_.insert(Output_.end(), batches_ * units_, 0.f); - - model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs); - model_.identifyInputsAndOutputs(inputs, outputs); - - model_.finish(); - } - -#define DefineSetter(X) \ - void Set##X(const std::vector<float>& f) { \ - X##_.insert(X##_.end(), f.begin(), f.end()); \ - } - - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); + public: + BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size) + : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) { + std::vector<uint32_t> inputs; + + OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_}); + inputs.push_back(model_.addOperand(&InputTy)); + OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_}); + inputs.push_back(model_.addOperand(&WeightTy)); + OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_}); + inputs.push_back(model_.addOperand(&RecurrentWeightTy)); + OperandType BiasTy(Type::TENSOR_FLOAT32, {units_}); + inputs.push_back(model_.addOperand(&BiasTy)); + OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_}); + inputs.push_back(model_.addOperand(&HiddenStateTy)); + OperandType ActionParamTy(Type::INT32, {}); + 
inputs.push_back(model_.addOperand(&ActionParamTy)); + + std::vector<uint32_t> outputs; + + outputs.push_back(model_.addOperand(&HiddenStateTy)); + OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_}); + outputs.push_back(model_.addOperand(&OutputTy)); + + Input_.insert(Input_.end(), batches_ * input_size_, 0.f); + HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f); + HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f); + Output_.insert(Output_.end(), batches_ * units_, 0.f); + + model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs); + model_.identifyInputsAndOutputs(inputs, outputs); + + model_.finish(); + } + +#define DefineSetter(X) \ + void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } + + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); #undef DefineSetter - void SetInput(int offset, float* begin, float* end) { - for (; begin != end; begin++, offset++) { - Input_[offset] = *begin; + void SetInput(int offset, float* begin, float* end) { + for (; begin != end; begin++, offset++) { + Input_[offset] = *begin; + } } - } - void ResetHiddenState() { - std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f); - std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f); - } + void ResetHiddenState() { + std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f); + std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f); + } - const std::vector<float>& GetOutput() const { return Output_; } + const std::vector<float>& GetOutput() const { return Output_; } - uint32_t input_size() const { return input_size_; } - uint32_t num_units() const { return units_; } - uint32_t num_batches() const { return batches_; } + uint32_t input_size() const { return input_size_; } + uint32_t num_units() const { return units_; } + uint32_t num_batches() const { return batches_; } - void Invoke() { - ASSERT_TRUE(model_.isValid()); + void Invoke() { + ASSERT_TRUE(model_.isValid()); - HiddenStateIn_.swap(HiddenStateOut_); + HiddenStateIn_.swap(HiddenStateOut_); - Compilation compilation(&model_); - compilation.finish(); - Execution execution(&compilation); -#define SetInputOrWeight(X) \ - ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), \ - sizeof(float) * X##_.size()), \ - Result::NO_ERROR); + Compilation compilation(&model_); + compilation.finish(); + Execution execution(&compilation); +#define SetInputOrWeight(X) \ + ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight -#define SetOutput(X) \ - ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), \ - sizeof(float) * X##_.size()), \ - Result::NO_ERROR); +#define SetOutput(X) \ + ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_OUTPUT_TENSORS(SetOutput); + FOR_ALL_OUTPUT_TENSORS(SetOutput); #undef SetOutput - ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, - sizeof(activation_)), - Result::NO_ERROR); + ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)), + Result::NO_ERROR); - ASSERT_EQ(execution.compute(), Result::NO_ERROR); - } + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + } - private: - Model model_; + private: + Model model_; - const uint32_t batches_; - const uint32_t units_; - const uint32_t input_size_; + const 
uint32_t batches_; + const uint32_t units_; + const uint32_t input_size_; - const int activation_; + const int activation_; #define DefineTensor(X) std::vector<float> X##_; - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); - FOR_ALL_OUTPUT_TENSORS(DefineTensor); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); + FOR_ALL_OUTPUT_TENSORS(DefineTensor); #undef DefineTensor }; TEST(RNNOpTest, BlackBoxTest) { - BasicRNNOpModel rnn(2, 16, 8); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); - - rnn.ResetHiddenState(); - const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / - (rnn.input_size() * rnn.num_batches()); - - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = rnn_input + i * rnn.input_size(); - float* batch_end = batch_start + rnn.input_size(); - rnn.SetInput(0, batch_start, batch_end); - rnn.SetInput(rnn.input_size(), batch_start, batch_end); - - rnn.Invoke(); - - float* golden_start = rnn_golden_output + i * rnn.num_units(); - float* golden_end = golden_start + rnn.num_units(); - std::vector<float> expected; - expected.insert(expected.end(), 
golden_start, golden_end); - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + BasicRNNOpModel rnn(2, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 0.317493, + 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 0.448504, 0.317662, + 0.523556, -0.323514, 0.480877, 0.333113, -0.757714, -0.674487, -0.643585, + 0.217766, -0.0251462, 0.79512, -0.595574, -0.422444, 0.371572, -0.452178, + -0.556069, -0.482188, -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, + 0.729158, -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 0.306261, + -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 0.0354295, 0.566564, + -0.485469, -0.620498, 0.832546, 0.697884, -0.279115, 0.294415, -0.584313, + 0.548772, 0.0648819, 0.968726, 0.723834, -0.0080452, -0.350386, -0.272803, + 0.115121, -0.412644, -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, + -0.423461, -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 0.0960841, + 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 0.37225, -0.623598, + -0.405423, 0.455101, 0.673656, -0.145345, -0.511346, -0.901675, -0.81252, + -0.127006, 0.809865, -0.721884, 0.636255, 0.868989, -0.347973, -0.10179, + -0.777449, 0.917274, 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, + 0.972934, -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268, + 0.61957061, 0.3956964, -0.37609905}); + + rnn.SetRecurrentWeights( + {0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = + sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } } // namespace wrapper diff --git a/nn/common/operations/SVDF.cpp 
b/nn/common/operations/SVDF.cpp index 844361da2..4009b7735 100644 --- a/nn/common/operations/SVDF.cpp +++ b/nn/common/operations/SVDF.cpp @@ -29,8 +29,7 @@ namespace nn { using namespace hal; -SVDF::SVDF(const Operation& operation, - std::vector<RunTimeOperandInfo>& operands) { +SVDF::SVDF(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) { NNTRACE_TRANS("SVDF::SVDF"); input_ = GetInput(operation, operands, kInputTensor); weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor); @@ -39,62 +38,58 @@ SVDF::SVDF(const Operation& operation, state_in_ = GetInput(operation, operands, kStateInTensor); params_.rank_ = getScalarData<int>(*GetInput(operation, operands, kRankParam)); - params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int>( - *GetInput(operation, operands, kActivationParam))); + params_.activation_ = static_cast<TfLiteFusedActivation>( + getScalarData<int>(*GetInput(operation, operands, kActivationParam))); state_out_ = GetOutput(operation, operands, kStateOutTensor); output_ = GetOutput(operation, operands, kOutputTensor); } -bool SVDF::Prepare(const Operation &operation, - std::vector<RunTimeOperandInfo> &operands, - Shape *stateShape, - Shape *outputShape) { - NNTRACE_TRANS("SVDF::Prepare"); - // Check we have all the inputs and outputs we need. - const int num_inputs = NumInputsWithValues(operation, operands); - - NN_CHECK(num_inputs == 6 || num_inputs == 7); - NN_CHECK_EQ(NumOutputs(operation), 2); - - const RunTimeOperandInfo *input = - GetInput(operation, operands, SVDF::kInputTensor); - const RunTimeOperandInfo *weights_feature = - GetInput(operation, operands, SVDF::kWeightsFeatureTensor); - const RunTimeOperandInfo *weights_time = - GetInput(operation, operands, SVDF::kWeightsTimeTensor); - - // Check all the parameters of tensor match within themselves and match the - // input configuration. - const int rank = getScalarData<int>(*GetInput(operation, operands, kRankParam)); - const uint32_t batch_size = SizeOfDimension(input, 0); - const uint32_t num_filters = SizeOfDimension(weights_feature, 0); - NN_CHECK_EQ(num_filters % rank, 0); - const uint32_t num_units = num_filters / rank; - const uint32_t memory_size = SizeOfDimension(weights_time, 1); - NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1)); - NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_filters); - - const RunTimeOperandInfo *bias = - GetInput(operation, operands, kBiasTensor); - if (!IsNullInput(bias)) { - NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units); - } - - // Resize state. - const Shape &inputShape = input->shape(); - stateShape->type = inputShape.type; - stateShape->dimensions = { batch_size, memory_size * num_filters }; - stateShape->offset = inputShape.offset; - stateShape->scale = inputShape.scale; - - // Resize output. - outputShape->type = inputShape.type; - outputShape->dimensions = { batch_size, num_units }; - outputShape->offset = inputShape.offset; - outputShape->scale = inputShape.scale; - - return true; +bool SVDF::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands, + Shape* stateShape, Shape* outputShape) { + NNTRACE_TRANS("SVDF::Prepare"); + // Check we have all the inputs and outputs we need. 
+ const int num_inputs = NumInputsWithValues(operation, operands); + + NN_CHECK(num_inputs == 6 || num_inputs == 7); + NN_CHECK_EQ(NumOutputs(operation), 2); + + const RunTimeOperandInfo* input = GetInput(operation, operands, SVDF::kInputTensor); + const RunTimeOperandInfo* weights_feature = + GetInput(operation, operands, SVDF::kWeightsFeatureTensor); + const RunTimeOperandInfo* weights_time = + GetInput(operation, operands, SVDF::kWeightsTimeTensor); + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int rank = getScalarData<int>(*GetInput(operation, operands, kRankParam)); + const uint32_t batch_size = SizeOfDimension(input, 0); + const uint32_t num_filters = SizeOfDimension(weights_feature, 0); + NN_CHECK_EQ(num_filters % rank, 0); + const uint32_t num_units = num_filters / rank; + const uint32_t memory_size = SizeOfDimension(weights_time, 1); + NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1)); + NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_filters); + + const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor); + if (!IsNullInput(bias)) { + NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units); + } + + // Resize state. + const Shape& inputShape = input->shape(); + stateShape->type = inputShape.type; + stateShape->dimensions = {batch_size, memory_size * num_filters}; + stateShape->offset = inputShape.offset; + stateShape->scale = inputShape.scale; + + // Resize output. + outputShape->type = inputShape.type; + outputShape->dimensions = {batch_size, num_units}; + outputShape->offset = inputShape.offset; + outputShape->scale = inputShape.scale; + + return true; } bool SVDF::Eval() { diff --git a/nn/common/operations/SVDFTest.cpp b/nn/common/operations/SVDFTest.cpp index 864c2eb8a..21f769fc2 100644 --- a/nn/common/operations/SVDFTest.cpp +++ b/nn/common/operations/SVDFTest.cpp @@ -30,418 +30,389 @@ namespace wrapper { namespace { std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, - float max_abs_error=1.e-6) { - std::vector<Matcher<float>> matchers; - matchers.reserve(values.size()); - for (const float& v : values) { - matchers.emplace_back(FloatNear(v, max_abs_error)); - } - return matchers; + float max_abs_error = 1.e-6) { + std::vector<Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; } } // namespace using ::testing::ElementsAreArray; -static float svdf_input[] = {0.12609188, -0.46347019, -0.89598465, - 0.12609188, -0.46347019, -0.89598465, +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, 0.12609188, -0.46347019, -0.89598465, - 0.14278367, -1.64410412, -0.75222826, - 0.14278367, -1.64410412, -0.75222826, + 0.14278367, -1.64410412, -0.75222826, 0.14278367, -1.64410412, -0.75222826, - 0.49837467, 0.19278903, 0.26584083, - 0.49837467, 0.19278903, 0.26584083, + 0.49837467, 0.19278903, 0.26584083, 0.49837467, 0.19278903, 0.26584083, - -0.11186574, 0.13164264, -0.05349274, - -0.11186574, 0.13164264, -0.05349274, + -0.11186574, 0.13164264, -0.05349274, -0.11186574, 0.13164264, -0.05349274, - -0.68892461, 0.37783599, 0.18263303, - -0.68892461, 0.37783599, 0.18263303, + -0.68892461, 0.37783599, 0.18263303, -0.68892461, 0.37783599, 0.18263303, - -0.81299269, -0.86831826, 1.43940818, - -0.81299269, -0.86831826, 1.43940818, + -0.81299269, -0.86831826, 1.43940818, -0.81299269, -0.86831826, 1.43940818, - -1.45006323, 
-0.82251364, -1.69082689, - -1.45006323, -0.82251364, -1.69082689, + -1.45006323, -0.82251364, -1.69082689, -1.45006323, -0.82251364, -1.69082689, - 0.03966608, -0.24936394, -0.77526885, - 0.03966608, -0.24936394, -0.77526885, + 0.03966608, -0.24936394, -0.77526885, 0.03966608, -0.24936394, -0.77526885, - 0.11771342, -0.23761693, -0.65898693, - 0.11771342, -0.23761693, -0.65898693, + 0.11771342, -0.23761693, -0.65898693, 0.11771342, -0.23761693, -0.65898693, - -0.89477462, 1.67204106, -0.53235275, - -0.89477462, 1.67204106, -0.53235275}; + -0.89477462, 1.67204106, -0.53235275, -0.89477462, 1.67204106, -0.53235275}; static float svdf_input_rank2[] = { - 0.12609188, -0.46347019, -0.89598465, - 0.35867718, 0.36897406, 0.73463392, + 0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406, 0.73463392, - 0.14278367, -1.64410412, -0.75222826, - -0.57290924, 0.12729003, 0.7567004, + 0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003, 0.7567004, - 0.49837467, 0.19278903, 0.26584083, - 0.17660543, 0.52949083, -0.77931279, + 0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279, - -0.11186574, 0.13164264, -0.05349274, - -0.72674477, -0.5683046, 0.55900657, + -0.11186574, 0.13164264, -0.05349274, -0.72674477, -0.5683046, 0.55900657, - -0.68892461, 0.37783599, 0.18263303, - -0.63690937, 0.44483393, -0.71817774, + -0.68892461, 0.37783599, 0.18263303, -0.63690937, 0.44483393, -0.71817774, - -0.81299269, -0.86831826, 1.43940818, - -0.95760226, 1.82078898, 0.71135032, + -0.81299269, -0.86831826, 1.43940818, -0.95760226, 1.82078898, 0.71135032, - -1.45006323, -0.82251364, -1.69082689, - -1.65087092, -1.89238167, 1.54172635, + -1.45006323, -0.82251364, -1.69082689, -1.65087092, -1.89238167, 1.54172635, - 0.03966608, -0.24936394, -0.77526885, - 2.06740379, -1.51439476, 1.43768692, + 0.03966608, -0.24936394, -0.77526885, 2.06740379, -1.51439476, 1.43768692, - 0.11771342, -0.23761693, -0.65898693, - 0.31088525, -1.55601168, -0.87661445, + 0.11771342, -0.23761693, -0.65898693, 0.31088525, -1.55601168, -0.87661445, - -0.89477462, 1.67204106, -0.53235275, - -0.6230064, 0.29819036, 1.06939757, + -0.89477462, 1.67204106, -0.53235275, -0.6230064, 0.29819036, 1.06939757, }; -static float svdf_golden_output[] = { - 0.014899, -0.0517661, -0.143725, -0.00271883, - 0.014899, -0.0517661, -0.143725, -0.00271883, +static float svdf_golden_output[] = {0.014899, -0.0517661, -0.143725, -0.00271883, + 0.014899, -0.0517661, -0.143725, -0.00271883, - 0.068281, -0.162217, -0.152268, 0.00323521, - 0.068281, -0.162217, -0.152268, 0.00323521, + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.068281, -0.162217, -0.152268, 0.00323521, - -0.0317821, -0.0333089, 0.0609602, 0.0333759, - -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.0317821, -0.0333089, 0.0609602, 0.0333759, - -0.00623099, -0.077701, -0.391193, -0.0136691, - -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.00623099, -0.077701, -0.391193, -0.0136691, - 0.201551, -0.164607, -0.179462, -0.0592739, - 0.201551, -0.164607, -0.179462, -0.0592739, + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.201551, -0.164607, -0.179462, -0.0592739, - 0.0886511, -0.0875401, -0.269283, 0.0281379, - 0.0886511, -0.0875401, -0.269283, 0.0281379, + 0.0886511, -0.0875401, -0.269283, 0.0281379, + 0.0886511, -0.0875401, -0.269283, 0.0281379, - -0.201174, -0.586145, -0.628624, -0.0330412, - -0.201174, -0.586145, -0.628624, -0.0330412, + -0.201174, -0.586145, 
-0.628624, -0.0330412, + -0.201174, -0.586145, -0.628624, -0.0330412, - -0.0839096, -0.299329, 0.108746, 0.109808, - -0.0839096, -0.299329, 0.108746, 0.109808, + -0.0839096, -0.299329, 0.108746, 0.109808, + -0.0839096, -0.299329, 0.108746, 0.109808, - 0.419114, -0.237824, -0.422627, 0.175115, - 0.419114, -0.237824, -0.422627, 0.175115, + 0.419114, -0.237824, -0.422627, 0.175115, + 0.419114, -0.237824, -0.422627, 0.175115, - 0.36726, -0.522303, -0.456502, -0.175475, - 0.36726, -0.522303, -0.456502, -0.175475}; + 0.36726, -0.522303, -0.456502, -0.175475, + 0.36726, -0.522303, -0.456502, -0.175475}; static float svdf_golden_output_rank_2[] = { - -0.09623547, -0.10193135, 0.11083051, -0.0347917, - 0.1141196, 0.12965347, -0.12652366, 0.01007236, + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, - -0.16396809, -0.21247184, 0.11259045, -0.04156673, - 0.10132131, -0.06143532, -0.00924693, 0.10084561, + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, - 0.01257364, 0.0506071, -0.19287863, -0.07162561, - -0.02033747, 0.22673416, 0.15487903, 0.02525555, + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, - -0.1411963, -0.37054959, 0.01774767, 0.05867489, - 0.09607603, -0.0141301, -0.08995658, 0.12867066, + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, - -0.27142537, -0.16955489, 0.18521598, -0.12528358, - 0.00331409, 0.11167502, 0.02218599, -0.07309391, + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, - 0.09593632, -0.28361851, -0.0773851, 0.17199151, - -0.00075242, 0.33691186, -0.1536046, 0.16572715, + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, - -0.27916506, -0.27626723, 0.42615682, 0.3225764, - -0.37472126, -0.55655634, -0.05013514, 0.289112, + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, - -0.24418658, 0.07540751, -0.1940318, -0.08911639, - 0.00732617, 0.46737891, 0.26449674, 0.24888524, + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, - -0.17225097, -0.54660404, -0.38795233, 0.08389944, - 0.07736043, -0.28260678, 0.15666828, 1.14949894, + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, - -0.57454878, -0.64704704, 0.73235172, -0.34616736, - 0.21120001, -0.22927976, 0.02455296, -0.35906726, + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, }; #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ - ACTION(Input) \ - ACTION(WeightsFeature) \ - ACTION(WeightsTime) \ - ACTION(Bias) \ - ACTION(StateIn) + ACTION(Input) \ + ACTION(WeightsFeature) \ + ACTION(WeightsTime) \ + ACTION(Bias) \ + ACTION(StateIn) // For all output and intermediate states #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ - ACTION(StateOut) \ - ACTION(Output) + ACTION(StateOut) \ + ACTION(Output) // Derived class of SingleOpModel, which is used to test SVDF TFLite op. 
class SVDFOpModel { - public: - SVDFOpModel(uint32_t batches, uint32_t units, uint32_t input_size, - uint32_t memory_size, uint32_t rank) - : batches_(batches), - units_(units), - input_size_(input_size), - memory_size_(memory_size), - rank_(rank) { - std::vector<std::vector<uint32_t>> input_shapes{ - {batches_, input_size_}, // Input tensor - {units_ * rank_, input_size_}, // weights_feature tensor - {units_ * rank_, memory_size_}, // weights_time tensor - {units_}, // bias tensor - {batches_, memory_size * units_ * rank_}, // state in tensor - }; - std::vector<uint32_t> inputs; - auto it = input_shapes.begin(); - - // Input and weights -#define AddInput(X) \ - OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \ - inputs.push_back(model_.addOperand(&X##OpndTy)); - - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput); + public: + SVDFOpModel(uint32_t batches, uint32_t units, uint32_t input_size, uint32_t memory_size, + uint32_t rank) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + std::vector<std::vector<uint32_t>> input_shapes{ + {batches_, input_size_}, // Input tensor + {units_ * rank_, input_size_}, // weights_feature tensor + {units_ * rank_, memory_size_}, // weights_time tensor + {units_}, // bias tensor + {batches_, memory_size * units_ * rank_}, // state in tensor + }; + std::vector<uint32_t> inputs; + auto it = input_shapes.begin(); + + // Input and weights +#define AddInput(X) \ + OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \ + inputs.push_back(model_.addOperand(&X##OpndTy)); + + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput); #undef AddInput - // Parameters - OperandType RankParamTy(Type::INT32, {}); - inputs.push_back(model_.addOperand(&RankParamTy)); - OperandType ActivationParamTy(Type::INT32, {}); - inputs.push_back(model_.addOperand(&ActivationParamTy)); + // Parameters + OperandType RankParamTy(Type::INT32, {}); + inputs.push_back(model_.addOperand(&RankParamTy)); + OperandType ActivationParamTy(Type::INT32, {}); + inputs.push_back(model_.addOperand(&ActivationParamTy)); - // Output and other intermediate state - std::vector<std::vector<uint32_t>> output_shapes{{batches_, memory_size_ * units_ * rank_}, - {batches_, units_}}; - std::vector<uint32_t> outputs; + // Output and other intermediate state + std::vector<std::vector<uint32_t>> output_shapes{{batches_, memory_size_ * units_ * rank_}, + {batches_, units_}}; + std::vector<uint32_t> outputs; - auto it2 = output_shapes.begin(); + auto it2 = output_shapes.begin(); -#define AddOutput(X) \ - OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \ - outputs.push_back(model_.addOperand(&X##OpndTy)); +#define AddOutput(X) \ + OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \ + outputs.push_back(model_.addOperand(&X##OpndTy)); - FOR_ALL_OUTPUT_TENSORS(AddOutput); + FOR_ALL_OUTPUT_TENSORS(AddOutput); #undef AddOutput - Input_.insert(Input_.end(), batches_ * input_size_, 0.f); - StateIn_.insert(StateIn_.end(), batches_ * units_ * rank_ * memory_size_, 0.f); + Input_.insert(Input_.end(), batches_ * input_size_, 0.f); + StateIn_.insert(StateIn_.end(), batches_ * units_ * rank_ * memory_size_, 0.f); - auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { - uint32_t sz = 1; - for(uint32_t d:dims) { sz *= d; } - return sz; - }; + auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { + uint32_t sz = 1; + for (uint32_t d : dims) { + sz *= d; + } + return sz; + }; - it2 = output_shapes.begin(); + it2 = output_shapes.begin(); #define 
ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f); - FOR_ALL_OUTPUT_TENSORS(ReserveOutput); + FOR_ALL_OUTPUT_TENSORS(ReserveOutput); - model_.addOperation(ANEURALNETWORKS_SVDF, inputs, outputs); - model_.identifyInputsAndOutputs(inputs, outputs); + model_.addOperation(ANEURALNETWORKS_SVDF, inputs, outputs); + model_.identifyInputsAndOutputs(inputs, outputs); - model_.finish(); - } + model_.finish(); + } - void Invoke() { - ASSERT_TRUE(model_.isValid()); + void Invoke() { + ASSERT_TRUE(model_.isValid()); - Compilation compilation(&model_); - compilation.finish(); - Execution execution(&compilation); + Compilation compilation(&model_); + compilation.finish(); + Execution execution(&compilation); - StateIn_.swap(StateOut_); + StateIn_.swap(StateOut_); -#define SetInputOrWeight(X) \ - ASSERT_EQ(execution.setInput(SVDF::k##X##Tensor, X##_.data(), \ - sizeof(float) * X##_.size()), \ - Result::NO_ERROR); +#define SetInputOrWeight(X) \ + ASSERT_EQ(execution.setInput(SVDF::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); #undef SetInputOrWeight -#define SetOutput(X) \ - EXPECT_TRUE(X##_.data() != nullptr); \ - ASSERT_EQ(execution.setOutput(SVDF::k##X##Tensor, X##_.data(), \ - sizeof(float) * X##_.size()), \ - Result::NO_ERROR); +#define SetOutput(X) \ + EXPECT_TRUE(X##_.data() != nullptr); \ + ASSERT_EQ(execution.setOutput(SVDF::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ + Result::NO_ERROR); - FOR_ALL_OUTPUT_TENSORS(SetOutput); + FOR_ALL_OUTPUT_TENSORS(SetOutput); #undef SetOutput - ASSERT_EQ(execution.setInput(SVDF::kRankParam, &rank_, sizeof(rank_)), - Result::NO_ERROR); + ASSERT_EQ(execution.setInput(SVDF::kRankParam, &rank_, sizeof(rank_)), Result::NO_ERROR); - int activation = TfLiteFusedActivation::kTfLiteActNone; - ASSERT_EQ(execution.setInput(SVDF::kActivationParam, &activation, - sizeof(activation)), - Result::NO_ERROR); + int activation = TfLiteFusedActivation::kTfLiteActNone; + ASSERT_EQ(execution.setInput(SVDF::kActivationParam, &activation, sizeof(activation)), + Result::NO_ERROR); - ASSERT_EQ(execution.compute(), Result::NO_ERROR); - } + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + } -#define DefineSetter(X) \ - void Set##X(const std::vector<float>& f) { \ - X##_.insert(X##_.end(), f.begin(), f.end()); \ - } +#define DefineSetter(X) \ + void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); #undef DefineSetter - void SetInput(int offset, float* begin, float* end) { - for (; begin != end; begin++, offset++) { - Input_[offset] = *begin; + void SetInput(int offset, float* begin, float* end) { + for (; begin != end; begin++, offset++) { + Input_[offset] = *begin; + } } - } - // Resets the state of SVDF op by filling it with 0's. - void ResetState() { - std::fill(StateIn_.begin(), StateIn_.end(), 0.f); - std::fill(StateOut_.begin(), StateOut_.end(), 0.f); - } + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + std::fill(StateIn_.begin(), StateIn_.end(), 0.f); + std::fill(StateOut_.begin(), StateOut_.end(), 0.f); + } - // Extracts the output tensor from the SVDF op. - const std::vector<float>& GetOutput() const { return Output_; } + // Extracts the output tensor from the SVDF op. 
+ const std::vector<float>& GetOutput() const { return Output_; } - int input_size() const { return input_size_; } - int num_units() const { return units_; } - int num_batches() const { return batches_; } + int input_size() const { return input_size_; } + int num_units() const { return units_; } + int num_batches() const { return batches_; } - private: - Model model_; + private: + Model model_; - const uint32_t batches_; - const uint32_t units_; - const uint32_t input_size_; - const uint32_t memory_size_; - const uint32_t rank_; + const uint32_t batches_; + const uint32_t units_; + const uint32_t input_size_; + const uint32_t memory_size_; + const uint32_t rank_; #define DefineTensor(X) std::vector<float> X##_; - FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); - FOR_ALL_OUTPUT_TENSORS(DefineTensor); + FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); + FOR_ALL_OUTPUT_TENSORS(DefineTensor); #undef DefineTensor }; TEST(SVDFOpTest, BlackBoxTest) { - SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/1); - svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, - 0.22197971, 0.12416199, 0.27901134, 0.27557442, - 0.3905206, -0.36137494, -0.06634006, -0.10640851}); - - svdf.SetWeightsTime( - {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, - 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, - - 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, - -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, - - -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, - 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, - - -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, - -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); - - svdf.SetBias({}); - - svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. 
- for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199, + 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006, + -0.10640851}); + + svdf.SetWeightsTime({-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.SetBias({}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. 
+ for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = svdf_golden_output + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } TEST(SVDFOpTest, BlackBoxTestRank2) { - SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/2); - svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, - 0.12416199, 0.15785322, 0.27901134, 0.3905206, - 0.21931258, -0.36137494, -0.10640851, 0.31053296, - -0.36118156, -0.0976817, -0.36916667, 0.22197971, - 0.15294972, 0.38031587, 0.27557442, 0.39635518, - -0.21580373, -0.06634006, -0.02702999, 0.27072677}); - - svdf.SetWeightsTime( - {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, - 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, - - 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, - -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, - - -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, - 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, - - -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, - -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, - - -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, - 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, - - -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, - 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, - - -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, - -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, - - 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, - 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); - - svdf.SetBias({}); - - svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input_rank2) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. 
- for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input_rank2 + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, + 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494, + -0.10640851, 0.31053296, -0.36118156, -0.0976817, -0.36916667, + 0.22197971, 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime({-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.SetBias({}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input_rank2) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. 
+ for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input_rank2 + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } } } // namespace wrapper diff --git a/nn/driver/cache/BlobCache/BlobCache.cpp b/nn/driver/cache/BlobCache/BlobCache.cpp index 61130b9ad..e3274da7a 100644 --- a/nn/driver/cache/BlobCache/BlobCache.cpp +++ b/nn/driver/cache/BlobCache/BlobCache.cpp @@ -28,7 +28,7 @@ #include <algorithm> static const char property_value[] = "[HOST]"; #define PROPERTY_VALUE_MAX (sizeof(property_value) - 1) -static int property_get(const char *key, char *value, const char *default_value) { +static int property_get(const char* key, char* value, const char* default_value) { if (!strcmp(key, "ro.build.id")) { memcpy(value, property_value, PROPERTY_VALUE_MAX); return PROPERTY_VALUE_MAX; @@ -57,14 +57,14 @@ static const uint32_t blobCacheVersion = 3; // BlobCache::Header::mDeviceVersion value static const uint32_t blobCacheDeviceVersion = 1; -BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize, Policy policy): - mMaxKeySize(maxKeySize), - mMaxValueSize(maxValueSize), - mMaxTotalSize(maxTotalSize), - mPolicySelect(policy.first), - mPolicyCapacity(policy.second), - mTotalSize(0), - mAccessCount(0) { +BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize, Policy policy) + : mMaxKeySize(maxKeySize), + mMaxValueSize(maxValueSize), + mMaxTotalSize(maxTotalSize), + mPolicySelect(policy.first), + mPolicyCapacity(policy.second), + mTotalSize(0), + mAccessCount(0) { int64_t now = std::chrono::steady_clock::now().time_since_epoch().count(); #ifdef _WIN32 srand(now); @@ -76,21 +76,21 @@ BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize ALOGV("initializing random seed using %lld", (unsigned long long)now); } -void BlobCache::set(const void* key, size_t keySize, const void* value, - size_t valueSize) { +void BlobCache::set(const void* key, size_t keySize, const void* value, size_t valueSize) { if (mMaxKeySize < keySize) { - ALOGV("set: not caching because the key is too large: %zu (limit: %zu)", - keySize, mMaxKeySize); + ALOGV("set: not caching because the key is too large: %zu (limit: %zu)", keySize, + mMaxKeySize); return; } if (mMaxValueSize < valueSize) { - ALOGV("set: not caching because the value is too large: %zu (limit: %zu)", - valueSize, mMaxValueSize); + ALOGV("set: not caching because the value is too large: %zu (limit: %zu)", valueSize, + mMaxValueSize); return; } if (mMaxTotalSize < keySize + valueSize) { ALOGV("set: not caching because the combined key/value size is too " - "large: %zu (limit: %zu)", keySize + valueSize, mMaxTotalSize); + "large: %zu (limit: %zu)", + keySize + valueSize, mMaxTotalSize); return; } if (keySize == 0) { @@ -127,16 +127,16 @@ void BlobCache::set(const void* key, size_t keySize, const void* value, continue; } else { ALOGV("set: not caching new key/value pair because the " - "total cache size limit would be exceeded: %zu " - "(limit: %zu)", - keySize + valueSize, mMaxTotalSize); + "total cache size limit 
would be exceeded: %zu " + "(limit: %zu)", + keySize + valueSize, mMaxTotalSize); break; } } mCacheEntries.insert(index, CacheEntry(keyBlob, valueBlob, ++mAccessCount)); mTotalSize = newTotalSize; - ALOGV("set: created new cache entry with %zu byte key and %zu byte value", - keySize, valueSize); + ALOGV("set: created new cache entry with %zu byte key and %zu byte value", keySize, + valueSize); } else { // Update the existing cache entry. std::shared_ptr<Blob> valueBlob(new Blob(value, valueSize, true)); @@ -157,8 +157,8 @@ void BlobCache::set(const void* key, size_t keySize, const void* value, continue; } else { ALOGV("set: not caching new value because the total cache " - "size limit would be exceeded: %zu (limit: %zu)", - keySize + valueSize, mMaxTotalSize); + "size limit would be exceeded: %zu (limit: %zu)", + keySize + valueSize, mMaxTotalSize); break; } } @@ -166,26 +166,25 @@ void BlobCache::set(const void* key, size_t keySize, const void* value, index->setRecency(++mAccessCount); mTotalSize = newTotalSize; ALOGV("set: updated existing cache entry with %zu byte key and %zu byte " - "value", keySize, valueSize); + "value", + keySize, valueSize); } break; } } -size_t BlobCache::get(const void* key, size_t keySize, void* value, - size_t valueSize) { - void *dummy; - return get(key, keySize, &dummy, - [value, valueSize](size_t allocSize) { - return (allocSize <= valueSize ? value : nullptr); - }); +size_t BlobCache::get(const void* key, size_t keySize, void* value, size_t valueSize) { + void* dummy; + return get(key, keySize, &dummy, [value, valueSize](size_t allocSize) { + return (allocSize <= valueSize ? value : nullptr); + }); } size_t BlobCache::get(const void* key, size_t keySize, void** value, - std::function<void*(size_t)> alloc) { + std::function<void*(size_t)> alloc) { if (mMaxKeySize < keySize) { - ALOGV("get: not searching because the key is too large: %zu (limit %zu)", - keySize, mMaxKeySize); + ALOGV("get: not searching because the key is too large: %zu (limit %zu)", keySize, + mMaxKeySize); *value = nullptr; return 0; } @@ -201,7 +200,7 @@ size_t BlobCache::get(const void* key, size_t keySize, void** value, // The key was found. Return the value if we can allocate a buffer. 
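The hunk above covers both get() overloads: the fixed-buffer form simply delegates to the allocator-based form. A minimal caller sketch, assuming the templated get() declared in BlobCache.h further down and a plain malloc-backed allocator (the key and sizes are illustrative):

    #include <cstdio>
    #include <cstdlib>
    #include "BlobCache.h"  // nn/driver/cache/BlobCache

    void lookupWithAllocator(android::BlobCache& cache) {
        unsigned char* value = nullptr;
        // The cache calls the allocator with the stored value's size and, on a
        // hit, copies the value into the returned buffer; the caller owns it.
        size_t size = cache.get("abcd", 4, &value, malloc);
        if (size > 0 && value != nullptr) {
            printf("cache hit: %zu bytes\n", size);
            free(value);
        }
    }
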
std::shared_ptr<Blob> valueBlob(index->getValue()); size_t valueBlobSize = valueBlob->getSize(); - void *buf = alloc(valueBlobSize); + void* buf = alloc(valueBlobSize); if (buf != nullptr) { ALOGV("get: copying %zu bytes to caller's buffer", valueBlobSize); memcpy(buf, valueBlob->getData(), valueBlobSize); @@ -220,7 +219,7 @@ static inline size_t align4(size_t size) { size_t BlobCache::getFlattenedSize() const { size_t size = align4(sizeof(Header) + PROPERTY_VALUE_MAX); - for (const CacheEntry& e : mCacheEntries) { + for (const CacheEntry& e : mCacheEntries) { std::shared_ptr<Blob> const& keyBlob = e.getKey(); std::shared_ptr<Blob> const& valueBlob = e.getValue(); size += align4(sizeof(EntryHeader) + keyBlob->getSize() + valueBlob->getSize()); @@ -246,7 +245,7 @@ int BlobCache::flatten(void* buffer, size_t size) const { // Write cache entries uint8_t* byteBuffer = reinterpret_cast<uint8_t*>(buffer); off_t byteOffset = align4(sizeof(Header) + header->mBuildIdLength); - for (const CacheEntry& e : mCacheEntries) { + for (const CacheEntry& e : mCacheEntries) { std::shared_ptr<Blob> const& keyBlob = e.getKey(); std::shared_ptr<Blob> const& valueBlob = e.getValue(); size_t keySize = keyBlob->getSize(); @@ -295,9 +294,8 @@ int BlobCache::unflatten(void const* buffer, size_t size) { char buildId[PROPERTY_VALUE_MAX]; int len = property_get("ro.build.id", buildId, ""); if (header->mBlobCacheVersion != blobCacheVersion || - header->mDeviceVersion != blobCacheDeviceVersion || - len != header->mBuildIdLength || - strncmp(buildId, header->mBuildId, len)) { + header->mDeviceVersion != blobCacheDeviceVersion || len != header->mBuildIdLength || + strncmp(buildId, header->mBuildId, len)) { // We treat version mismatches as an empty cache. return 0; } @@ -313,8 +311,7 @@ int BlobCache::unflatten(void const* buffer, size_t size) { return -EINVAL; } - const EntryHeader* eheader = reinterpret_cast<const EntryHeader*>( - &byteBuffer[byteOffset]); + const EntryHeader* eheader = reinterpret_cast<const EntryHeader*>(&byteBuffer[byteOffset]); size_t keySize = eheader->mKeySize; size_t valueSize = eheader->mValueSize; size_t entrySize = sizeof(EntryHeader) + keySize + valueSize; @@ -349,9 +346,10 @@ size_t BlobCache::findVictim() { return size_t(blob_random() % (mCacheEntries.size())); case Select::LRU: return std::min_element(mCacheEntries.begin(), mCacheEntries.end(), - [](const CacheEntry &a, const CacheEntry &b) { + [](const CacheEntry& a, const CacheEntry& b) { return a.getRecency() < b.getRecency(); - }) - mCacheEntries.begin(); + }) - + mCacheEntries.begin(); default: ALOGE("findVictim: unknown mPolicySelect: %d", mPolicySelect); return 0; @@ -360,9 +358,8 @@ size_t BlobCache::findVictim() { size_t BlobCache::findDownTo(size_t newEntrySize, size_t onBehalfOf) { auto oldEntrySize = [this, onBehalfOf]() -> size_t { - if (onBehalfOf == NoEntry) - return 0; - const auto &entry = mCacheEntries[onBehalfOf]; + if (onBehalfOf == NoEntry) return 0; + const auto& entry = mCacheEntries[onBehalfOf]; return entry.getKey()->getSize() + entry.getValue()->getSize(); }; switch (mPolicyCapacity) { @@ -421,10 +418,8 @@ bool BlobCache::isCleanable() const { } } -BlobCache::Blob::Blob(const void* data, size_t size, bool copyData) : - mData(copyData ? malloc(size) : data), - mSize(size), - mOwnsData(copyData) { +BlobCache::Blob::Blob(const void* data, size_t size, bool copyData) + : mData(copyData ? 
malloc(size) : data), mSize(size), mOwnsData(copyData) { if (data != NULL && copyData) { memcpy(const_cast<void*>(mData), data, size); } @@ -452,21 +447,14 @@ size_t BlobCache::Blob::getSize() const { return mSize; } -BlobCache::CacheEntry::CacheEntry(): mRecency(0) { -} +BlobCache::CacheEntry::CacheEntry() : mRecency(0) {} -BlobCache::CacheEntry::CacheEntry( - const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value, uint32_t recency): - mKey(key), - mValue(value), - mRecency(recency) { -} +BlobCache::CacheEntry::CacheEntry(const std::shared_ptr<Blob>& key, + const std::shared_ptr<Blob>& value, uint32_t recency) + : mKey(key), mValue(value), mRecency(recency) {} -BlobCache::CacheEntry::CacheEntry(const CacheEntry& ce): - mKey(ce.mKey), - mValue(ce.mValue), - mRecency(ce.mRecency) { -} +BlobCache::CacheEntry::CacheEntry(const CacheEntry& ce) + : mKey(ce.mKey), mValue(ce.mValue), mRecency(ce.mRecency) {} bool BlobCache::CacheEntry::operator<(const CacheEntry& rhs) const { return *mKey < *rhs.mKey; @@ -499,4 +487,4 @@ void BlobCache::CacheEntry::setRecency(uint32_t recency) { mRecency = recency; } -} // namespace android +} // namespace android diff --git a/nn/driver/cache/BlobCache/BlobCache.h b/nn/driver/cache/BlobCache/BlobCache.h index d79a23de4..9caa6beec 100644 --- a/nn/driver/cache/BlobCache/BlobCache.h +++ b/nn/driver/cache/BlobCache/BlobCache.h @@ -34,7 +34,7 @@ namespace android { // serialization is non-portable and the data should only be used by the device // that generated it. class BlobCache { -public: + public: enum class Select { RANDOM, // evict random entries LRU, // evict least-recently-used entries @@ -86,8 +86,7 @@ public: // 0 < keySize // value != NULL // 0 < valueSize - void set(const void* key, size_t keySize, const void* value, - size_t valueSize); + void set(const void* key, size_t keySize, const void* value, size_t valueSize); // get retrieves from the cache the binary value associated with a given // binary key. If the key is present in the cache then the length of the @@ -131,7 +130,7 @@ public: size_t get(const void* key, size_t keySize, void** value, std::function<void*(size_t)> alloc); template <typename T> size_t get(const void* key, size_t keySize, T** value, std::function<void*(size_t)> alloc) { - void *valueVoid; + void* valueVoid; const size_t size = get(key, keySize, &valueVoid, alloc); *value = static_cast<T*>(valueVoid); return size; @@ -158,7 +157,7 @@ public: // int unflatten(void const* buffer, size_t size); -private: + private: // Copying is disallowed. BlobCache(const BlobCache&); void operator=(const BlobCache&); @@ -204,7 +203,7 @@ private: // A Blob is an immutable sized unstructured data blob. class Blob { - public: + public: Blob(const void* data, size_t size, bool copyData); ~Blob(); @@ -213,7 +212,7 @@ private: const void* getData() const; size_t getSize() const; - private: + private: // Copying is not allowed. Blob(const Blob&); void operator=(const Blob&); @@ -231,9 +230,10 @@ private: // A CacheEntry is a single key/value pair in the cache. 
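The public interface reformatted above (the constructor taking a Policy, set()/get(), and the getFlattenedSize()/flatten()/unflatten() serialization trio) is exercised by BlobCache_test.cpp below. A minimal end-to-end sketch, assuming illustrative size limits, the LRU/HALVE policy used in the tests, and a zero return from flatten()/unflatten() on success as the test's OK constant suggests:

    #include <cstdint>
    #include <vector>
    #include "BlobCache.h"  // nn/driver/cache/BlobCache

    void cacheRoundTrip() {
        using android::BlobCache;
        const BlobCache::Policy policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE);
        BlobCache src(/*maxKeySize=*/16, /*maxValueSize=*/1024, /*maxTotalSize=*/64 * 1024,
                      policy);

        src.set("abcd", 4, "efgh", 4);
        unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
        src.get("abcd", 4, buf, sizeof(buf));  // returns 4 and fills buf on a hit

        // Serialize and restore into a second cache built with the same limits.
        std::vector<uint8_t> blob(src.getFlattenedSize());
        BlobCache dst(/*maxKeySize=*/16, /*maxValueSize=*/1024, /*maxTotalSize=*/64 * 1024,
                      policy);
        if (src.flatten(blob.data(), blob.size()) == 0 &&
            dst.unflatten(blob.data(), blob.size()) == 0) {
            dst.get("abcd", 4, buf, sizeof(buf));  // the entry survives the round trip
        }
    }
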
class CacheEntry { - public: + public: CacheEntry(); - CacheEntry(const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value, uint32_t recency); + CacheEntry(const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value, + uint32_t recency); CacheEntry(const CacheEntry& ce); bool operator<(const CacheEntry& rhs) const; @@ -247,8 +247,7 @@ private: uint32_t getRecency() const; void setRecency(uint32_t recency); - private: - + private: // mKey is the key that identifies the cache entry. std::shared_ptr<Blob> mKey; @@ -349,6 +348,6 @@ private: std::vector<CacheEntry> mCacheEntries; }; -} +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_CACHE_BLOB_CACHE_BLOB_CACHE_H diff --git a/nn/driver/cache/BlobCache/BlobCache_test.cpp b/nn/driver/cache/BlobCache/BlobCache_test.cpp index 26b915cbb..2635fcc06 100644 --- a/nn/driver/cache/BlobCache/BlobCache_test.cpp +++ b/nn/driver/cache/BlobCache/BlobCache_test.cpp @@ -29,15 +29,12 @@ namespace android { -template<typename T> using sp = std::shared_ptr<T>; +template <typename T> +using sp = std::shared_ptr<T>; class BlobCacheTest : public ::testing::TestWithParam<BlobCache::Policy> { -protected: - - enum { - OK = 0, - BAD_VALUE = -EINVAL - }; + protected: + enum { OK = 0, BAD_VALUE = -EINVAL }; enum { MAX_KEY_SIZE = 6, @@ -49,25 +46,25 @@ protected: mBC.reset(new BlobCache(MAX_KEY_SIZE, MAX_VALUE_SIZE, MAX_TOTAL_SIZE, GetParam())); } - virtual void TearDown() { - mBC.reset(); - } + virtual void TearDown() { mBC.reset(); } std::unique_ptr<BlobCache> mBC; }; -INSTANTIATE_TEST_CASE_P(Policy, BlobCacheTest, - ::testing::Values(BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE), +INSTANTIATE_TEST_CASE_P( + Policy, BlobCacheTest, + ::testing::Values( + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE), - BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT), + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT), - BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE))); + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE))); TEST_P(BlobCacheTest, CacheSingleValueSucceeds) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4)); ASSERT_EQ('e', buf[0]); @@ -77,7 +74,7 @@ TEST_P(BlobCacheTest, CacheSingleValueSucceeds) { } TEST_P(BlobCacheTest, CacheTwoValuesSucceeds) { - unsigned char buf[2] = { 0xee, 0xee }; + unsigned char buf[2] = {0xee, 0xee}; mBC->set("ab", 2, "cd", 2); mBC->set("ef", 2, "gh", 2); ASSERT_EQ(size_t(2), mBC->get("ab", 2, buf, 2)); @@ -89,7 +86,7 @@ TEST_P(BlobCacheTest, CacheTwoValuesSucceeds) { } TEST_P(BlobCacheTest, CacheTwoValuesMallocSucceeds) { - unsigned char *bufPtr; + unsigned char* bufPtr; mBC->set("ab", 2, "cd", 2); mBC->set("ef", 2, "gh", 2); @@ -109,9 +106,9 @@ TEST_P(BlobCacheTest, CacheTwoValuesMallocSucceeds) { } TEST_P(BlobCacheTest, GetOnlyWritesInsideBounds) { - unsigned char buf[6] = { 0xee, 0xee, 
0xee, 0xee, 0xee, 0xee }; + unsigned char buf[6] = {0xee, 0xee, 0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); - ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf+1, 4)); + ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf + 1, 4)); ASSERT_EQ(0xee, buf[0]); ASSERT_EQ('e', buf[1]); ASSERT_EQ('f', buf[2]); @@ -121,7 +118,7 @@ TEST_P(BlobCacheTest, GetOnlyWritesInsideBounds) { } TEST_P(BlobCacheTest, GetOnlyWritesIfBufferIsLargeEnough) { - unsigned char buf[3] = { 0xee, 0xee, 0xee }; + unsigned char buf[3] = {0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 3)); ASSERT_EQ(0xee, buf[0]); @@ -130,13 +127,13 @@ TEST_P(BlobCacheTest, GetOnlyWritesIfBufferIsLargeEnough) { } TEST_P(BlobCacheTest, GetWithFailedAllocator) { - unsigned char buf[3] = { 0xee, 0xee, 0xee }; + unsigned char buf[3] = {0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); // If allocator fails, verify that we set the value pointer to // nullptr, and that we do not modify the buffer that the value // pointer originally pointed to. - unsigned char *bufPtr = &buf[0]; + unsigned char* bufPtr = &buf[0]; ASSERT_EQ(size_t(4), mBC->get("abcd", 4, &bufPtr, [](size_t) -> void* { return nullptr; })); ASSERT_EQ(nullptr, bufPtr); ASSERT_EQ(0xee, buf[0]); @@ -150,7 +147,7 @@ TEST_P(BlobCacheTest, GetDoesntAccessNullBuffer) { } TEST_P(BlobCacheTest, MultipleSetsCacheLatestValue) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); mBC->set("abcd", 4, "ijkl", 4); ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4)); @@ -161,9 +158,9 @@ TEST_P(BlobCacheTest, MultipleSetsCacheLatestValue) { } TEST_P(BlobCacheTest, SecondSetKeepsFirstValueIfTooLarge) { - unsigned char buf[MAX_VALUE_SIZE+1] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[MAX_VALUE_SIZE + 1] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); - mBC->set("abcd", 4, buf, MAX_VALUE_SIZE+1); + mBC->set("abcd", 4, buf, MAX_VALUE_SIZE + 1); ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4)); ASSERT_EQ('e', buf[0]); ASSERT_EQ('f', buf[1]); @@ -172,14 +169,14 @@ TEST_P(BlobCacheTest, SecondSetKeepsFirstValueIfTooLarge) { } TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) { - char key[MAX_KEY_SIZE+1]; - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; - for (int i = 0; i < MAX_KEY_SIZE+1; i++) { + char key[MAX_KEY_SIZE + 1]; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; + for (int i = 0; i < MAX_KEY_SIZE + 1; i++) { key[i] = 'a'; } - mBC->set(key, MAX_KEY_SIZE+1, "bbbb", 4); + mBC->set(key, MAX_KEY_SIZE + 1, "bbbb", 4); - ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE+1, buf, 4)); + ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE + 1, buf, 4)); ASSERT_EQ(0xee, buf[0]); ASSERT_EQ(0xee, buf[1]); ASSERT_EQ(0xee, buf[2]); @@ -188,12 +185,12 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) { // If key is too large, verify that we do not call the allocator, // that we set the value pointer to nullptr, and that we do not // modify the buffer that the value pointer originally pointed to. 
- unsigned char *bufPtr = &buf[0]; + unsigned char* bufPtr = &buf[0]; bool calledAlloc = false; - ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE+1, &bufPtr, - [&calledAlloc](size_t) -> void* { - calledAlloc = true; - return nullptr; })); + ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE + 1, &bufPtr, [&calledAlloc](size_t) -> void* { + calledAlloc = true; + return nullptr; + })); ASSERT_EQ(false, calledAlloc); ASSERT_EQ(nullptr, bufPtr); ASSERT_EQ(0xee, buf[0]); @@ -203,16 +200,16 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) { } TEST_P(BlobCacheTest, DoesntCacheIfValueIsTooBig) { - unsigned char buf[MAX_VALUE_SIZE+1]; - for (int i = 0; i < MAX_VALUE_SIZE+1; i++) { + unsigned char buf[MAX_VALUE_SIZE + 1]; + for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) { buf[i] = 'b'; } - mBC->set("abcd", 4, buf, MAX_VALUE_SIZE+1); - for (int i = 0; i < MAX_VALUE_SIZE+1; i++) { + mBC->set("abcd", 4, buf, MAX_VALUE_SIZE + 1); + for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) { buf[i] = 0xee; } - ASSERT_EQ(size_t(0), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE+1)); - for (int i = 0; i < MAX_VALUE_SIZE+1; i++) { + ASSERT_EQ(size_t(0), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE + 1)); + for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) { SCOPED_TRACE(i); ASSERT_EQ(0xee, buf[i]); } @@ -240,7 +237,7 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyValuePairIsTooBig) { TEST_P(BlobCacheTest, CacheMaxKeySizeSucceeds) { char key[MAX_KEY_SIZE]; - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; for (int i = 0; i < MAX_KEY_SIZE; i++) { key[i] = 'a'; } @@ -261,8 +258,7 @@ TEST_P(BlobCacheTest, CacheMaxValueSizeSucceeds) { for (int i = 0; i < MAX_VALUE_SIZE; i++) { buf[i] = 0xee; } - ASSERT_EQ(size_t(MAX_VALUE_SIZE), mBC->get("abcd", 4, buf, - MAX_VALUE_SIZE)); + ASSERT_EQ(size_t(MAX_VALUE_SIZE), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE)); for (int i = 0; i < MAX_VALUE_SIZE; i++) { SCOPED_TRACE(i); ASSERT_EQ('b', buf[i]); @@ -289,7 +285,7 @@ TEST_P(BlobCacheTest, CacheMaxKeyValuePairSizeSucceeds) { } TEST_P(BlobCacheTest, CacheMinKeyAndValueSizeSucceeds) { - unsigned char buf[1] = { 0xee }; + unsigned char buf[1] = {0xee}; mBC->set("x", 1, "y", 1); ASSERT_EQ(size_t(1), mBC->get("x", 1, buf, 1)); ASSERT_EQ('y', buf[0]); @@ -328,17 +324,16 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitHalvesCacheSize) { // Count the number of entries in the cache; and check which // entries they are. int numCached = 0; - for (int i = 0; i < maxEntries+1; i++) { + for (int i = 0; i < maxEntries + 1; i++) { uint8_t k = i; bool found = (mBC->get(&k, 1, NULL, 0) == 1); - if (found) - numCached++; + if (found) numCached++; if (GetParam().first == BlobCache::Select::LRU) { SCOPED_TRACE(i); - ASSERT_EQ(found, i >= maxEntries/2); + ASSERT_EQ(found, i >= maxEntries / 2); } } - ASSERT_EQ(maxEntries/2 + 1, numCached); + ASSERT_EQ(maxEntries / 2 + 1, numCached); } TEST_P(BlobCacheTest, ExceedingTotalLimitJustFitsSmallEntry) { @@ -358,10 +353,9 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitJustFitsSmallEntry) { } // Count the number of entries in the cache. int numCached = 0; - for (int i = 0; i < maxEntries+1; i++) { + for (int i = 0; i < maxEntries + 1; i++) { uint8_t k = i; - if (mBC->get(&k, 1, NULL, 0) == 1) - numCached++; + if (mBC->get(&k, 1, NULL, 0) == 1) numCached++; } ASSERT_EQ(maxEntries, numCached); } @@ -376,18 +370,17 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitFitsBigEntry) { } // Insert one more entry, causing a cache overflow. 
const int bigValueSize = std::min((MAX_TOTAL_SIZE * 3) / 4 - 1, int(MAX_VALUE_SIZE)); - ASSERT_GT(bigValueSize+1, MAX_TOTAL_SIZE / 2); // Check testing assumption + ASSERT_GT(bigValueSize + 1, MAX_TOTAL_SIZE / 2); // Check testing assumption { unsigned char buf[MAX_VALUE_SIZE]; - for (int i = 0; i < bigValueSize; i++) - buf[i] = 0xee; + for (int i = 0; i < bigValueSize; i++) buf[i] = 0xee; uint8_t k = maxEntries; mBC->set(&k, 1, buf, bigValueSize); } // Count the number and size of entries in the cache. int numCached = 0; size_t sizeCached = 0; - for (int i = 0; i < maxEntries+1; i++) { + for (int i = 0; i < maxEntries + 1; i++) { uint8_t k = i; size_t size = mBC->get(&k, 1, NULL, 0); if (size) { @@ -399,8 +392,8 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitFitsBigEntry) { case BlobCache::Capacity::HALVE: // New value is too big for this cleaning algorithm. So // we cleaned the cache, but did not insert the new value. - ASSERT_EQ(maxEntries/2, numCached); - ASSERT_EQ(size_t((maxEntries/2)*2), sizeCached); + ASSERT_EQ(maxEntries / 2, numCached); + ASSERT_EQ(size_t((maxEntries / 2) * 2), sizeCached); break; case BlobCache::Capacity::FIT: case BlobCache::Capacity::FIT_HALVE: { @@ -433,21 +426,20 @@ TEST_P(BlobCacheTest, FailedGetWithAllocator) { // allocator, that we set the value pointer to nullptr, and that // we do not modify the buffer that the value pointer originally // pointed to. - unsigned char buf[1] = { 0xee }; - unsigned char *bufPtr = &buf[0]; + unsigned char buf[1] = {0xee}; + unsigned char* bufPtr = &buf[0]; bool calledAlloc = false; - ASSERT_EQ(size_t(0), mBC->get("a", 1, &bufPtr, - [&calledAlloc](size_t) -> void* { - calledAlloc = true; - return nullptr; })); + ASSERT_EQ(size_t(0), mBC->get("a", 1, &bufPtr, [&calledAlloc](size_t) -> void* { + calledAlloc = true; + return nullptr; + })); ASSERT_EQ(false, calledAlloc); ASSERT_EQ(nullptr, bufPtr); ASSERT_EQ(0xee, buf[0]); } TEST_P(BlobCacheTest, ExceedingTotalLimitRemovesLRUEntries) { - if (GetParam().first != BlobCache::Select::LRU) - return; // test doesn't apply for this policy + if (GetParam().first != BlobCache::Select::LRU) return; // test doesn't apply for this policy // Fill up the entire cache with 1 char key/value pairs. 
static const int maxEntries = MAX_TOTAL_SIZE / 2; @@ -487,20 +479,19 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitRemovesLRUEntries) { for (int i = 0; i < maxEntries; i++) { uint8_t k = accessSequence[i]; bool found = (mBC->get(&k, 1, NULL, 0) == 1); - if (foundAny == found) - continue; + if (foundAny == found) continue; if (!foundAny) { // found == true, so we just discovered j == i foundAny = true; } else { // foundAny == true, found == false -- oops - FAIL() << "found [" << i-1 << "]th entry but not [" << i << "]th entry"; + FAIL() << "found [" << i - 1 << "]th entry but not [" << i << "]th entry"; } } } class BlobCacheFlattenTest : public BlobCacheTest { -protected: + protected: virtual void SetUp() { BlobCacheTest::SetUp(); mBC2.reset(new BlobCache(MAX_KEY_SIZE, MAX_VALUE_SIZE, MAX_TOTAL_SIZE, GetParam())); @@ -522,18 +513,20 @@ protected: sp<BlobCache> mBC2; }; -INSTANTIATE_TEST_CASE_P(Policy, BlobCacheFlattenTest, - ::testing::Values(BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE), +INSTANTIATE_TEST_CASE_P( + Policy, BlobCacheFlattenTest, + ::testing::Values( + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE), - BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT), + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT), - BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE), - BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE))); + BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE), + BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE))); TEST_P(BlobCacheFlattenTest, FlattenOneValue) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); roundTrip(); ASSERT_EQ(size_t(4), mBC2->get("abcd", 4, buf, 4)); @@ -601,7 +594,7 @@ TEST_P(BlobCacheFlattenTest, FlattenCatchesBufferTooSmall) { } TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadMagic) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); size_t size = mBC->getFlattenedSize(); @@ -618,7 +611,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadMagic) { } TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheVersion) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); size_t size = mBC->getFlattenedSize(); @@ -637,7 +630,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheVersion) { } TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheDeviceVersion) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); size_t size = mBC->getFlattenedSize(); @@ -656,7 +649,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheDeviceVersion) { } TEST_P(BlobCacheFlattenTest, UnflattenCatchesBufferTooSmall) { - unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee }; + unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee}; mBC->set("abcd", 4, "efgh", 4); size_t size = mBC->getFlattenedSize(); @@ -673,4 +666,4 @@ TEST_P(BlobCacheFlattenTest, 
UnflattenCatchesBufferTooSmall) { ASSERT_EQ(size_t(0), mBC2->get("abcd", 4, buf, 4)); } -} // namespace android +} // namespace android diff --git a/nn/driver/cache/nnCache/nnCache.cpp b/nn/driver/cache/nnCache/nnCache.cpp index 702688bf4..9f9e9be8b 100644 --- a/nn/driver/cache/nnCache/nnCache.cpp +++ b/nn/driver/cache/nnCache/nnCache.cpp @@ -39,15 +39,15 @@ namespace android { // // NNCache definition // -NNCache::NNCache() : - mInitialized(false), - mMaxKeySize(0), mMaxValueSize(0), mMaxTotalSize(0), - mPolicy(defaultPolicy()), - mSavePending(false) { -} +NNCache::NNCache() + : mInitialized(false), + mMaxKeySize(0), + mMaxValueSize(0), + mMaxTotalSize(0), + mPolicy(defaultPolicy()), + mSavePending(false) {} -NNCache::~NNCache() { -} +NNCache::~NNCache() {} NNCache NNCache::sCache; @@ -72,8 +72,7 @@ void NNCache::terminate() { mInitialized = false; } -void NNCache::setBlob(const void* key, ssize_t keySize, - const void* value, ssize_t valueSize) { +void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) { std::lock_guard<std::mutex> lock(mMutex); if (keySize < 0 || valueSize < 0) { @@ -100,8 +99,7 @@ void NNCache::setBlob(const void* key, ssize_t keySize, } } -ssize_t NNCache::getBlob(const void* key, ssize_t keySize, - void* value, ssize_t valueSize) { +ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) { std::lock_guard<std::mutex> lock(mMutex); if (keySize < 0 || valueSize < 0) { @@ -116,8 +114,8 @@ ssize_t NNCache::getBlob(const void* key, ssize_t keySize, return 0; } -ssize_t NNCache::getBlob(const void* key, ssize_t keySize, - void** value, std::function<void*(size_t)> alloc) { +ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value, + std::function<void*(size_t)> alloc) { std::lock_guard<std::mutex> lock(mMutex); if (keySize < 0) { @@ -175,26 +173,23 @@ void NNCache::saveBlobCacheLocked() { // The file exists, delete it and try again. if (unlink(fname) == -1) { // No point in retrying if the unlink failed. - ALOGE("error unlinking cache file %s: %s (%d)", fname, - strerror(errno), errno); + ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno); return; } // Retry now that we've unlinked the file. 
fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0); } if (fd == -1) { - ALOGE("error creating cache file %s: %s (%d)", fname, - strerror(errno), errno); + ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno); return; } } size_t fileSize = headerSize + cacheSize; - uint8_t* buf = new uint8_t [fileSize]; + uint8_t* buf = new uint8_t[fileSize]; if (!buf) { - ALOGE("error allocating buffer for cache contents: %s (%d)", - strerror(errno), errno); + ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno); close(fd); unlink(fname); return; @@ -202,9 +197,8 @@ void NNCache::saveBlobCacheLocked() { int err = mBlobCache->flatten(buf + headerSize, cacheSize); if (err < 0) { - ALOGE("error writing cache contents: %s (%d)", strerror(-err), - -err); - delete [] buf; + ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err); + delete[] buf; close(fd); unlink(fname); return; @@ -216,15 +210,14 @@ void NNCache::saveBlobCacheLocked() { *crc = crc32c(buf + headerSize, cacheSize); if (write(fd, buf, fileSize) == -1) { - ALOGE("error writing cache file: %s (%d)", strerror(errno), - errno); - delete [] buf; + ALOGE("error writing cache file: %s (%d)", strerror(errno), errno); + delete[] buf; close(fd); unlink(fname); return; } - delete [] buf; + delete[] buf; fchmod(fd, S_IRUSR); close(fd); } @@ -237,8 +230,8 @@ void NNCache::loadBlobCacheLocked() { int fd = open(mFilename.c_str(), O_RDONLY, 0); if (fd == -1) { if (errno != ENOENT) { - ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), - strerror(errno), errno); + ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno), + errno); } return; } @@ -253,17 +246,15 @@ void NNCache::loadBlobCacheLocked() { // Sanity check the size before trying to mmap it. 
size_t fileSize = statBuf.st_size; if (fileSize > mMaxTotalSize * 2) { - ALOGE("cache file is too large: %#" PRIx64, - static_cast<off64_t>(statBuf.st_size)); + ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size)); close(fd); return; } - uint8_t* buf = reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, - PROT_READ, MAP_PRIVATE, fd, 0)); + uint8_t* buf = + reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0)); if (buf == MAP_FAILED) { - ALOGE("error mmaping cache file: %s (%d)", strerror(errno), - errno); + ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno); close(fd); return; } @@ -284,8 +275,7 @@ void NNCache::loadBlobCacheLocked() { int err = mBlobCache->unflatten(buf + headerSize, cacheSize); if (err < 0) { - ALOGE("error reading cache contents: %s (%d)", strerror(-err), - -err); + ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err); munmap(buf, fileSize); close(fd); return; @@ -297,5 +287,5 @@ void NNCache::loadBlobCacheLocked() { } // ---------------------------------------------------------------------------- -}; // namespace android +}; // namespace android // ---------------------------------------------------------------------------- diff --git a/nn/driver/cache/nnCache/nnCache.h b/nn/driver/cache/nnCache/nnCache.h index 43f4ec350..a0ec6ee20 100644 --- a/nn/driver/cache/nnCache/nnCache.h +++ b/nn/driver/cache/nnCache/nnCache.h @@ -29,8 +29,7 @@ namespace android { // ---------------------------------------------------------------------------- class NNCache { -public: - + public: typedef BlobCache::Select Select; typedef BlobCache::Capacity Capacity; typedef BlobCache::Policy Policy; @@ -59,19 +58,17 @@ public: void terminate(); // setBlob attempts to insert a new key/value blob pair into the cache. - void setBlob(const void* key, ssize_t keySize, const void* value, - ssize_t valueSize); + void setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize); // getBlob attempts to retrieve the value blob associated with a given key // blob from cache. - ssize_t getBlob(const void* key, ssize_t keySize, - void* value, ssize_t valueSize); - ssize_t getBlob(const void* key, ssize_t keySize, - void** value, std::function<void*(size_t)> alloc); + ssize_t getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize); + ssize_t getBlob(const void* key, ssize_t keySize, void** value, + std::function<void*(size_t)> alloc); template <typename T> - ssize_t getBlob(const void* key, size_t keySize, - T** value, std::function<void*(size_t)> alloc) { - void *valueVoid; + ssize_t getBlob(const void* key, size_t keySize, T** value, + std::function<void*(size_t)> alloc) { + void* valueVoid; const ssize_t size = getBlob(key, keySize, &valueVoid, alloc); *value = static_cast<T*>(valueVoid); return size; @@ -81,7 +78,7 @@ public: // cache contents from one program invocation to another. void setCacheFilename(const char* filename); -private: + private: // Creation and (the lack of) destruction is handled internally. 
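NNCache wraps a single process-wide BlobCache behind the get() singleton accessor; nnCache_test.cpp below drives it through get(), initialize(), setBlob()/getBlob(), and setCacheFilename(). A minimal sketch of that flow, with illustrative limits, the LRU/HALVE policy from the tests, and a hypothetical cache-file path:

    #include <sys/types.h>
    #include <cstdint>
    #include "nnCache.h"  // nn/driver/cache/nnCache

    void driverCacheUsage() {
        android::NNCache* cache = android::NNCache::get();
        // Optional: persist the cache across program invocations (path is illustrative).
        cache->setCacheFilename("/data/local/tmp/nnCache.bin");
        cache->initialize(/*maxKeySize=*/128, /*maxValueSize=*/4 * 1024,
                          /*maxTotalSize=*/64 * 1024,
                          android::NNCache::Policy(android::NNCache::Select::LRU,
                                                   android::NNCache::Capacity::HALVE));
        cache->setBlob("abcd", 4, "efgh", 4);
        uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee};
        ssize_t n = cache->getBlob("abcd", 4, buf, sizeof(buf));  // 4 on a hit
        (void)n;
        cache->terminate();
    }
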
NNCache(); ~NNCache(); @@ -153,7 +150,7 @@ private: }; // ---------------------------------------------------------------------------- -}; // namespace android +}; // namespace android // ---------------------------------------------------------------------------- #endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_CACHE_NN_CACHE_NN_CACHE_H diff --git a/nn/driver/cache/nnCache/nnCache_test.cpp b/nn/driver/cache/nnCache/nnCache_test.cpp index 059160361..7ef2ccc8a 100644 --- a/nn/driver/cache/nnCache/nnCache_test.cpp +++ b/nn/driver/cache/nnCache/nnCache_test.cpp @@ -36,10 +36,8 @@ static const size_t maxTotalSize = 2 * 1024 * 1024; namespace android { class NNCacheTest : public ::testing::TestWithParam<NNCache::Policy> { -protected: - virtual void SetUp() { - mCache = NNCache::get(); - } + protected: + virtual void SetUp() { mCache = NNCache::get(); } virtual void TearDown() { mCache->setCacheFilename(""); @@ -49,18 +47,19 @@ protected: NNCache* mCache; }; -INSTANTIATE_TEST_CASE_P(Policy, NNCacheTest, - ::testing::Values(NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::HALVE), - NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::HALVE), +INSTANTIATE_TEST_CASE_P( + Policy, NNCacheTest, + ::testing::Values(NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::HALVE), + NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::HALVE), - NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT), - NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT), + NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT), + NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT), - NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT_HALVE), - NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT_HALVE))); + NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT_HALVE), + NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT_HALVE))); TEST_P(NNCacheTest, UninitializedCacheAlwaysMisses) { - uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee }; + uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee}; mCache->setBlob("abcd", 4, "efgh", 4); ASSERT_EQ(0, mCache->getBlob("abcd", 4, buf, 4)); ASSERT_EQ(0xee, buf[0]); @@ -70,7 +69,7 @@ TEST_P(NNCacheTest, UninitializedCacheAlwaysMisses) { } TEST_P(NNCacheTest, InitializedCacheAlwaysHits) { - uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee }; + uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee}; mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam()); mCache->setBlob("abcd", 4, "efgh", 4); ASSERT_EQ(4, mCache->getBlob("abcd", 4, buf, 4)); @@ -81,7 +80,7 @@ TEST_P(NNCacheTest, InitializedCacheAlwaysHits) { } TEST_P(NNCacheTest, TerminatedCacheAlwaysMisses) { - uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee }; + uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee}; mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam()); mCache->setBlob("abcd", 4, "efgh", 4); @@ -122,18 +121,17 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) { } // Insert one more entry, causing a cache overflow. const int bigValueSize = std::min((MAX_TOTAL_SIZE * 3) / 4 - 1, int(MAX_VALUE_SIZE)); - ASSERT_GT(bigValueSize+1, MAX_TOTAL_SIZE / 2); // Check testing assumption + ASSERT_GT(bigValueSize + 1, MAX_TOTAL_SIZE / 2); // Check testing assumption { unsigned char buf[MAX_VALUE_SIZE]; - for (int i = 0; i < bigValueSize; i++) - buf[i] = 0xee; + for (int i = 0; i < bigValueSize; i++) buf[i] = 0xee; uint8_t k = maxEntries; mCache->setBlob(&k, 1, buf, bigValueSize); } // Count the number and size of entries in the cache. 
int numCached = 0; size_t sizeCached = 0; - for (int i = 0; i < maxEntries+1; i++) { + for (int i = 0; i < maxEntries + 1; i++) { uint8_t k = i; size_t size = mCache->getBlob(&k, 1, NULL, 0); if (size) { @@ -145,8 +143,8 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) { case NNCache::Capacity::HALVE: // New value is too big for this cleaning algorithm. So // we cleaned the cache, but did not insert the new value. - ASSERT_EQ(maxEntries/2, numCached); - ASSERT_EQ(size_t((maxEntries/2)*2), sizeCached); + ASSERT_EQ(maxEntries / 2, numCached); + ASSERT_EQ(size_t((maxEntries / 2) * 2), sizeCached); break; case NNCache::Capacity::FIT: case NNCache::Capacity::FIT_HALVE: { @@ -175,9 +173,7 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) { } class NNCacheSerializationTest : public NNCacheTest { - -protected: - + protected: virtual void SetUp() { NNCacheTest::SetUp(); mTempFile.reset(new TemporaryFile()); @@ -190,7 +186,7 @@ protected: std::unique_ptr<TemporaryFile> mTempFile; - void yesStringBlob(const char *key, const char *value) { + void yesStringBlob(const char* key, const char* value) { SCOPED_TRACE(key); uint8_t buf[10]; @@ -206,7 +202,7 @@ protected: } } - void noStringBlob(const char *key) { + void noStringBlob(const char* key) { SCOPED_TRACE(key); uint8_t buf[10]; @@ -219,11 +215,10 @@ protected: ASSERT_EQ(0xee, buf[i]); } } - }; TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValues) { - uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee }; + uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee}; mCache->setCacheFilename(&mTempFile->path[0]); mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam()); mCache->setBlob("abcd", 4, "efgh", 4); @@ -235,7 +230,7 @@ TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValues) { // - we do not modify the buffer that value pointer originally points to // - the value pointer gets set to something other than nullptr // - the newly-allocated buffer is set properly - uint8_t *bufPtr = &buf[0]; + uint8_t* bufPtr = &buf[0]; ASSERT_EQ(4, mCache->getBlob("abcd", 4, &bufPtr, malloc)); ASSERT_EQ(0xee, buf[0]); ASSERT_EQ(0xee, buf[1]); @@ -273,8 +268,8 @@ TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValuesSizeConstrained SCOPED_TRACE("after second initialize()"); yesStringBlob("abcd", "efgh"); noStringBlob("abcdef"); // key too large - noStringBlob("ab"); // value too large + noStringBlob("ab"); // value too large } } -} +} // namespace android diff --git a/nn/driver/sample/SampleDriver.cpp b/nn/driver/sample/SampleDriver.cpp index ab2c9db32..0cc2d256e 100644 --- a/nn/driver/sample/SampleDriver.cpp +++ b/nn/driver/sample/SampleDriver.cpp @@ -203,8 +203,7 @@ Return<ErrorStatus> SampleDriver::prepareModelFromCache( } Return<DeviceStatus> SampleDriver::getStatus() { - NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, - "SampleDriver::getStatus"); + NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus"); VLOG(DRIVER) << "getStatus()"; return DeviceStatus::AVAILABLE; } @@ -298,8 +297,7 @@ Return<ErrorStatus> executeBase(const Request& request, MeasureTiming measure, c // is expected to live forever. 
std::thread([&model, &driver, &poolInfos, request, measure, driverStart, callback] { asyncExecute(request, measure, driverStart, model, driver, poolInfos, callback); - }) - .detach(); + }).detach(); return ErrorStatus::NONE; } @@ -469,6 +467,6 @@ Return<void> SamplePreparedModel::configureExecutionBurst( return Void(); } -} // namespace sample_driver -} // namespace nn -} // namespace android +} // namespace sample_driver +} // namespace nn +} // namespace android diff --git a/nn/driver/sample/SampleDriverFloatFast.cpp b/nn/driver/sample/SampleDriverFloatFast.cpp index 329408fa6..3611bbaeb 100644 --- a/nn/driver/sample/SampleDriverFloatFast.cpp +++ b/nn/driver/sample/SampleDriverFloatFast.cpp @@ -33,7 +33,7 @@ namespace sample_driver { using namespace hal; class SampleDriverFloatFast : public SampleDriver { -public: + public: SampleDriverFloatFast() : SampleDriver("sample-float-fast") {} Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; Return<void> getSupportedOperations_1_2(const V1_2::Model& model, @@ -78,12 +78,12 @@ Return<void> SampleDriverFloatFast::getSupportedOperations_1_2(const V1_2::Model return Void(); } -} // namespace sample_driver -} // namespace nn -} // namespace android +} // namespace sample_driver +} // namespace nn +} // namespace android -using android::nn::sample_driver::SampleDriverFloatFast; using android::sp; +using android::nn::sample_driver::SampleDriverFloatFast; int main() { sp<SampleDriverFloatFast> driver(new SampleDriverFloatFast()); diff --git a/nn/driver/sample/SampleDriverFloatSlow.cpp b/nn/driver/sample/SampleDriverFloatSlow.cpp index c5571eb55..af498379f 100644 --- a/nn/driver/sample/SampleDriverFloatSlow.cpp +++ b/nn/driver/sample/SampleDriverFloatSlow.cpp @@ -33,7 +33,7 @@ namespace sample_driver { using namespace hal; class SampleDriverFloatSlow : public SampleDriver { -public: + public: SampleDriverFloatSlow() : SampleDriver("sample-float-slow") {} Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; Return<void> getSupportedOperations_1_2(const V1_2::Model& model, @@ -78,12 +78,12 @@ Return<void> SampleDriverFloatSlow::getSupportedOperations_1_2(const V1_2::Model return Void(); } -} // namespace sample_driver -} // namespace nn -} // namespace android +} // namespace sample_driver +} // namespace nn +} // namespace android -using android::nn::sample_driver::SampleDriverFloatSlow; using android::sp; +using android::nn::sample_driver::SampleDriverFloatSlow; int main() { sp<SampleDriverFloatSlow> driver(new SampleDriverFloatSlow()); diff --git a/nn/driver/sample/SampleDriverMinimal.cpp b/nn/driver/sample/SampleDriverMinimal.cpp index 6ff596458..e0420348e 100644 --- a/nn/driver/sample/SampleDriverMinimal.cpp +++ b/nn/driver/sample/SampleDriverMinimal.cpp @@ -34,7 +34,7 @@ namespace sample_driver { using namespace hal; class SampleDriverMinimal : public SampleDriver { -public: + public: SampleDriverMinimal() : SampleDriver("sample-minimal") {} Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; Return<void> getSupportedOperations_1_2(const V1_2::Model& model, @@ -90,12 +90,12 @@ Return<void> SampleDriverMinimal::getSupportedOperations_1_2(const V1_2::Model& return Void(); } -} // namespace sample_driver -} // namespace nn -} // namespace android +} // namespace sample_driver +} // namespace nn +} // namespace android -using android::nn::sample_driver::SampleDriverMinimal; using android::sp; +using android::nn::sample_driver::SampleDriverMinimal; int main() { sp<SampleDriverMinimal> driver(new 
SampleDriverMinimal()); diff --git a/nn/driver/sample/SampleDriverQuant.cpp b/nn/driver/sample/SampleDriverQuant.cpp index 8e4c1854e..83fc550fe 100644 --- a/nn/driver/sample/SampleDriverQuant.cpp +++ b/nn/driver/sample/SampleDriverQuant.cpp @@ -33,7 +33,7 @@ namespace sample_driver { using namespace hal; class SampleDriverQuant : public SampleDriver { -public: + public: SampleDriverQuant() : SampleDriver("sample-quant") {} Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; Return<void> getSupportedOperations_1_2(const V1_2::Model& model, @@ -74,12 +74,12 @@ Return<void> SampleDriverQuant::getSupportedOperations_1_2(const V1_2::Model& mo return Void(); } -} // namespace sample_driver -} // namespace nn -} // namespace android +} // namespace sample_driver +} // namespace nn +} // namespace android -using android::nn::sample_driver::SampleDriverQuant; using android::sp; +using android::nn::sample_driver::SampleDriverQuant; int main() { sp<SampleDriverQuant> driver(new SampleDriverQuant()); diff --git a/nn/runtime/CompilationBuilder.cpp b/nn/runtime/CompilationBuilder.cpp index 6de44d137..912f0087b 100644 --- a/nn/runtime/CompilationBuilder.cpp +++ b/nn/runtime/CompilationBuilder.cpp @@ -90,8 +90,8 @@ int CompilationBuilder::finish() { int CompilationBuilder::setPreference(int32_t preference) { if (mFinished) { - LOG(ERROR) << - "ANeuralNetworksCompilation_setPreference can't modify after compilation finished"; + LOG(ERROR) << "ANeuralNetworksCompilation_setPreference can't modify after compilation " + "finished"; return ANEURALNETWORKS_BAD_STATE; } if (preference >= kNumberOfPreferences) { @@ -121,8 +121,8 @@ int CompilationBuilder::setCaching(const std::string& cacheDir, const uint8_t* t int CompilationBuilder::setPartitioning(uint32_t partitioning) { if (mFinished) { - LOG(ERROR) << - "ANeuralNetworksCompilation_setPartitioning can't modify after compilation finished"; + LOG(ERROR) << "ANeuralNetworksCompilation_setPartitioning can't modify after compilation " + "finished"; return ANEURALNETWORKS_BAD_STATE; } @@ -130,7 +130,7 @@ int CompilationBuilder::setPartitioning(uint32_t partitioning) { return ANEURALNETWORKS_NO_ERROR; } -int CompilationBuilder::createExecution(ExecutionBuilder **execution) { +int CompilationBuilder::createExecution(ExecutionBuilder** execution) { if (!mFinished) { LOG(ERROR) << "ANeuralNetworksExecution_create passed an unfinished compilation"; *execution = nullptr; diff --git a/nn/runtime/CompilationBuilder.h b/nn/runtime/CompilationBuilder.h index 7c0b786f8..a47f99903 100644 --- a/nn/runtime/CompilationBuilder.h +++ b/nn/runtime/CompilationBuilder.h @@ -32,7 +32,7 @@ class ExecutionBuilder; class ModelBuilder; class CompilationBuilder { -public: + public: friend class ExecutionBuilder; // TODO remove this // explicitDeviceList is true if the list of devices was provided explicitly @@ -56,7 +56,7 @@ public: const ExecutionPlan& forTest_getExecutionPlan() const { return mPlan; } -private: + private: const ModelBuilder* mModel; ExecutionPlan mPlan; @@ -88,7 +88,7 @@ private: bool mIsCacheInfoProvided = false; }; -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_COMPILATION_BUILDER_H diff --git a/nn/runtime/ExecutionBuilder.h b/nn/runtime/ExecutionBuilder.h index 1c8b1d68c..a837238a1 100644 --- a/nn/runtime/ExecutionBuilder.h +++ b/nn/runtime/ExecutionBuilder.h @@ -44,7 +44,8 @@ class StepExecutor; class ExecutionBuilder { friend class StepExecutor; -public: + + public: 
ExecutionBuilder(const CompilationBuilder* compilation); int setInput(uint32_t index, const ANeuralNetworksOperandType* type, const void* buffer, @@ -171,21 +172,18 @@ class StepExecutor { mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex], &mOutputs[executorIndex]); } void mapOutputToInput(uint32_t builderIndex, uint32_t executorIndex) { - mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex], - &mInputs[executorIndex]); + mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex], &mInputs[executorIndex]); } // The input or output is assumed to have the size of the // corresponding operand. int setInputFromTemporaryMemory(uint32_t inputIndex, const Memory* memory, uint32_t offset) { - return setInputOrOutputFromTemporaryMemory(mModel->getInputOperand(inputIndex), - memory, offset, - &mInputs.at(inputIndex)); + return setInputOrOutputFromTemporaryMemory(mModel->getInputOperand(inputIndex), memory, + offset, &mInputs.at(inputIndex)); } int setOutputFromTemporaryMemory(uint32_t outputIndex, const Memory* memory, uint32_t offset) { - return setInputOrOutputFromTemporaryMemory(mModel->getOutputOperand(outputIndex), - memory, offset, - &mOutputs.at(outputIndex)); + return setInputOrOutputFromTemporaryMemory(mModel->getOutputOperand(outputIndex), memory, + offset, &mOutputs.at(outputIndex)); } // Executes using the (driver, preparedModel) specified at construction time. @@ -238,7 +236,7 @@ class StepExecutor { MemoryTracker mMemories; }; -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_BUILDER_H diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp index 4e5409da5..866bfd738 100644 --- a/nn/runtime/ExecutionPlan.cpp +++ b/nn/runtime/ExecutionPlan.cpp @@ -211,7 +211,7 @@ int copyOperandExtraParams(ModelBuilder& model, uint32_t toOperandIndex, // This class tracks whether we know the value of an operand as operations // are processed. class OperandTracker { -public: + public: // Creates the tracker for this model. Figure out which operations can be // executed right away and cb for each one of them. OperandTracker(const ModelBuilder* model, OperationReadyCallback cb); @@ -220,14 +220,14 @@ public: // able to run. Call cb for each one of them. 
void markProcessed(uint32_t operationIndex, OperationReadyCallback cb); -private: + private: const ModelBuilder* mModel; std::multimap<uint32_t, uint32_t> mOperandToOperations; std::vector<uint32_t> mUnknownInputCount; // For each operation }; -OperandTracker::OperandTracker(const ModelBuilder* model, OperationReadyCallback cb) : - mModel(model) { +OperandTracker::OperandTracker(const ModelBuilder* model, OperationReadyCallback cb) + : mModel(model) { const auto& operations = mModel->getOperations(); mUnknownInputCount.resize(operations.size()); for (uint32_t operationIndex = 0; operationIndex < operations.size(); operationIndex++) { @@ -319,9 +319,8 @@ int ExecutionStep::addOperand(uint32_t fromOperandIndex, uint32_t* toOperandInde } break; case OperandLifeTime::CONSTANT_REFERENCE: { const Memory* memory = fromModel.getMemories()[operand.location.poolIndex]; - n = mSubModel.setOperandValueFromMemory(*toOperandIndex, memory, - operand.location.offset, - operand.location.length); + n = mSubModel.setOperandValueFromMemory( + *toOperandIndex, memory, operand.location.offset, operand.location.length); if (n != ANEURALNETWORKS_NO_ERROR) { LOG(ERROR) << "Previous error occurred when partitioning the graph"; return n; @@ -355,7 +354,8 @@ int ExecutionStep::addOperand(uint32_t fromOperandIndex, uint32_t* toOperandInde // The first time we've seen this operand is as an // input. That means it must be defined by a // different partition, and is an input to this one. - mOutputsAsSubModelInputs.push_back(std::make_pair(fromOperandIndex, *toOperandIndex)); + mOutputsAsSubModelInputs.push_back( + std::make_pair(fromOperandIndex, *toOperandIndex)); } else { // The first time we've seen this operand is as an // output. @@ -396,8 +396,7 @@ int ExecutionStep::addOperation(int operationIndex, const ModelBuilder& fromMode for (uint32_t i = 0; i < operandCount; i++) { uint32_t localOperand = ~0U; int n = addOperand(globalOperands[i], &localOperand, fromModel, kind); - if (n != ANEURALNETWORKS_NO_ERROR) - return n; + if (n != ANEURALNETWORKS_NO_ERROR) return n; localOperands[i] = localOperand; } return ANEURALNETWORKS_NO_ERROR; @@ -410,7 +409,7 @@ int ExecutionStep::addOperation(int operationIndex, const ModelBuilder& fromMode } return mSubModel.addOperation(static_cast<uint32_t>(operation.type), inputCount, inputs.data(), - outputCount, outputs.data()); + outputCount, outputs.data()); } void ExecutionStep::mapInputsAndOutputs(std::shared_ptr<StepExecutor> stepExecutor) const { @@ -438,7 +437,7 @@ void ExecutionPlan::CompoundBody::findTempsAsSubModelOutputs() { void ExecutionStep::logSubModel() const { VLOG(COMPILATION) << "ExecutionStep::finishSubModel, step " << mIndex; - auto logRemapEntry = [](std::string &toLog, const std::pair<uint32_t, uint32_t>& e) { + auto logRemapEntry = [](std::string& toLog, const std::pair<uint32_t, uint32_t>& e) { if (!toLog.empty()) { toLog += ", "; } @@ -475,13 +474,13 @@ static void convertModelInputsOrOutputs( // IN: mModel{Inputs|Outputs} const ExecutionStep::RemapVectorType& myModelInputsOrOutputs, // IN: fromModel->{input|output}Count() - uint32_t fromModelInputOrOutputCount, + uint32_t fromModelInputOrOutputCount, // IN: fromModel->get{Input|Output}OperandIndex - std::function<uint32_t(uint32_t)> fromModelGetInputOrOutputOperandIndex, + std::function<uint32_t(uint32_t)> fromModelGetInputOrOutputOperandIndex, // OUT: for v : mModel{Inputs|Outputs} : v.second - std::vector<uint32_t>* inputsOrOutputs, + std::vector<uint32_t>* inputsOrOutputs, // OUT: submodel input-or-output 
index to original model input-or-output index - std::vector<uint32_t>* inputOrOutputIndexSubModelToFromModel) { + std::vector<uint32_t>* inputOrOutputIndexSubModelToFromModel) { std::map<uint32_t, uint32_t> fromModelIndexMap; // operand index to input-or-output index for (uint32_t i = 0; i < fromModelInputOrOutputCount; i++) { fromModelIndexMap[fromModelGetInputOrOutputOperandIndex(i)] = i; @@ -508,11 +507,10 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput // ExecutionPlan::next() depends on these orderings. std::vector<uint32_t> inputs; - convertModelInputsOrOutputs(mModelInputs, - fromModel->inputCount(), - [=](uint32_t i) { return fromModel->getInputOperandIndex(i); }, - &inputs, - &mInputIndexSubModelToFromModel); + convertModelInputsOrOutputs( + mModelInputs, fromModel->inputCount(), + [=](uint32_t i) { return fromModel->getInputOperandIndex(i); }, &inputs, + &mInputIndexSubModelToFromModel); for (const auto& subModelInput : mTempsAsSubModelInputs) { inputs.push_back(subModelInput.second); } @@ -521,11 +519,10 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput } std::vector<uint32_t> outputs; - convertModelInputsOrOutputs(mModelOutputs, - fromModel->outputCount(), - [=](uint32_t i) { return fromModel->getOutputOperandIndex(i); }, - &outputs, - &mOutputIndexSubModelToFromModel); + convertModelInputsOrOutputs( + mModelOutputs, fromModel->outputCount(), + [=](uint32_t i) { return fromModel->getOutputOperandIndex(i); }, &outputs, + &mOutputIndexSubModelToFromModel); for (const auto& subModelOutput : mTempsAsSubModelOutputs) { outputs.push_back(subModelOutput.second); const Operand& operand = mSubModel.getOperand(subModelOutput.second); @@ -546,7 +543,8 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput } { - int n = mSubModel.identifyInputsAndOutputs(inputs.size(), &inputs[0], outputs.size(), &outputs[0]); + int n = mSubModel.identifyInputsAndOutputs(inputs.size(), &inputs[0], outputs.size(), + &outputs[0]); if (n != ANEURALNETWORKS_NO_ERROR) { return n; } @@ -601,7 +599,8 @@ int ExecutionPlan::CompoundBody::finish(const ModelBuilder* fromModel, } } if (mHasSubModelOutputOfUnknownSize) { - VLOG(COMPILATION) << "ExecutionPlan::CompoundBody::finish -- mHasSubModelOutputOfUnknownSize"; + VLOG(COMPILATION) + << "ExecutionPlan::CompoundBody::finish -- mHasSubModelOutputOfUnknownSize"; return ANEURALNETWORKS_OP_FAILED; } @@ -709,7 +708,7 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController( if (mState == COMPOUND) { const ModelBuilder* fromModel = executionBuilder->getModel(); for (const auto& step : compound()->mSteps) { - for (const auto& output: step->getTempsAsSubModelOutputs()) { + for (const auto& output : step->getTempsAsSubModelOutputs()) { const uint32_t fromModelOperandIndex = output.first; const Operand& fromModelOperand = fromModel->getOperand(fromModelOperandIndex); if (subModelInputsAndOutputs == nullptr) { @@ -718,14 +717,14 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController( } const uint32_t size = TypeManager::get()->getSizeOfData(fromModelOperand); totalSizeOfTemporaries += alignBytesNeeded(totalSizeOfTemporaries, size); - subModelInputsAndOutputs->insert(std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries)); + subModelInputsAndOutputs->insert( + std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries)); totalSizeOfTemporaries += size; } } if (VLOG_IS_ON(EXECUTION) && (subModelInputsAndOutputs != nullptr)) { for 
(const auto& io : *subModelInputsAndOutputs) { - VLOG(EXECUTION) << "temp: origOpndIdx = " << io.first - << ", offset = " << io.second; + VLOG(EXECUTION) << "temp: origOpndIdx = " << io.first << ", offset = " << io.second; } } } @@ -735,7 +734,6 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController( totalSizeOfTemporaries)); } - // TODO: Find a better way to provide this functionality. int ExecutionPlan::fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const { @@ -766,8 +764,7 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller, *burstController = nullptr; } - VLOG(EXECUTION) << "ExecutionPlan::next(" - << SHOW_IF_DEBUG(controller << ", " << executor) + VLOG(EXECUTION) << "ExecutionPlan::next(" << SHOW_IF_DEBUG(controller << ", " << executor) << "): mNextStepIndex = " << controller->mNextStepIndex; if (controller->mNextStepIndex == Controller::kBadStepIndex) { @@ -832,11 +829,10 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller, for (auto I = subModelOutputs.begin(), E = subModelOutputs.end(); I != E; I++, idx++) { const uint32_t fromModelOperandIndex = I->first; const uint32_t offsetOfTemporary = - controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex); - int n = (*executor)->setOutputFromTemporaryMemory( - firstSubModelOutputIndex + idx, - &controller->mTemporaries, - offsetOfTemporary); + controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex); + int n = (*executor)->setOutputFromTemporaryMemory(firstSubModelOutputIndex + idx, + &controller->mTemporaries, + offsetOfTemporary); if (n != ANEURALNETWORKS_NO_ERROR) { controller->mNextStepIndex = Controller::kBadStepIndex; return n; @@ -853,11 +849,10 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller, for (auto I = subModelInputs.begin(), E = subModelInputs.end(); I != E; I++, idx++) { const uint32_t fromModelOperandIndex = I->first; const uint32_t offsetOfTemporary = - controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex); - int n = (*executor)->setInputFromTemporaryMemory( - firstSubModelInputIndex + idx, - &controller->mTemporaries, - offsetOfTemporary); + controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex); + int n = (*executor)->setInputFromTemporaryMemory(firstSubModelInputIndex + idx, + &controller->mTemporaries, + offsetOfTemporary); if (n != ANEURALNETWORKS_NO_ERROR) { controller->mNextStepIndex = Controller::kBadStepIndex; return n; @@ -1057,7 +1052,7 @@ PerformanceInfo ModelBuilder::getPerformanceInfo(const std::shared_ptr<Device> d // currently the case but is not a safe assumption to make in the long term. 
const uint32_t operandIndex = operation.inputs[0]; const OperandType operandType = mOperands[operandIndex].type; - switch(operandType) { + switch (operandType) { case OperandType::FLOAT32: if (mRelaxComputationFloat32toFloat16) { return device->getRelaxedFloat32toFloat16PerformanceScalar(); @@ -1079,7 +1074,7 @@ namespace { // This class determines whether a given device can execute a given operation class CanDo { -public: + public: CanDo() {} void initialize(const MetaModel& metaModel, std::shared_ptr<Device> device) { @@ -1088,7 +1083,7 @@ public: bool check(size_t operationIndex) const { return mSupportsOperationByIndex[operationIndex]; } -private: + private: hidl_vec<bool> mSupportsOperationByIndex; }; @@ -1116,8 +1111,8 @@ int ModelBuilder::findBestDeviceForEachOperation( if (canDo[deviceIndex].check(operationIndex)) { const PerformanceInfo perf = getPerformanceInfo(device, operationIndex); const float perfVal = - (preference == ANEURALNETWORKS_PREFER_LOW_POWER ? perf.powerUsage - : perf.execTime); + (preference == ANEURALNETWORKS_PREFER_LOW_POWER ? perf.powerUsage + : perf.execTime); if (bestChoice < 0 || perfVal < bestPerfVal || (perfVal == bestPerfVal && device == DeviceManager::getCpuDevice())) { bestChoice = deviceIndex; @@ -1129,8 +1124,7 @@ int ModelBuilder::findBestDeviceForEachOperation( // specific device. // Logs O(operationCount * deviceCount) times, but // typically deviceCount is very small. - VLOG(COMPILATION) << "Device " << device->getName() - << " can't do operation " + VLOG(COMPILATION) << "Device " << device->getName() << " can't do operation " << toString(getOperation(operationIndex).type); } } @@ -1147,5 +1141,5 @@ int ModelBuilder::findBestDeviceForEachOperation( return ANEURALNETWORKS_NO_ERROR; } -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android diff --git a/nn/runtime/ExecutionPlan.h b/nn/runtime/ExecutionPlan.h index be9b9aae0..6697c8293 100644 --- a/nn/runtime/ExecutionPlan.h +++ b/nn/runtime/ExecutionPlan.h @@ -51,7 +51,7 @@ class PreparedModel; class StepExecutor; class ExecutionStep { -public: + public: typedef std::vector<std::pair<uint32_t, uint32_t>> RemapVectorType; typedef std::set<std::pair<uint32_t, uint32_t>> SubModelOutputSetType; @@ -63,21 +63,13 @@ public: const ModelBuilder& fromModel, OperandKind kind); // Each container entry is of the form (fromModel index, subModel index) - const RemapVectorType& getModelInputs() const { - return mModelInputs; - } - const RemapVectorType& getModelOutputs() const { - return mModelOutputs; - } - const RemapVectorType& getTempsAsSubModelInputs() const { - return mTempsAsSubModelInputs; - } + const RemapVectorType& getModelInputs() const { return mModelInputs; } + const RemapVectorType& getModelOutputs() const { return mModelOutputs; } + const RemapVectorType& getTempsAsSubModelInputs() const { return mTempsAsSubModelInputs; } const SubModelOutputSetType& getTempsAsSubModelOutputs() const { return mTempsAsSubModelOutputs; } - const RemapVectorType& getOutputsAsSubModelInputs() const { - return mOutputsAsSubModelInputs; - } + const RemapVectorType& getOutputsAsSubModelInputs() const { return mOutputsAsSubModelInputs; } const std::vector<uint32_t>& getOutputIndexSubModelToFromModel() const { return mOutputIndexSubModelToFromModel; } @@ -170,11 +162,11 @@ public: }; class ExecutionPlan { -public: + public: ExecutionPlan(const ExecutionPlan&) = delete; ExecutionPlan& operator=(const ExecutionPlan&) = delete; - ExecutionPlan() { } + ExecutionPlan() {} ~ExecutionPlan() { delete 
mBody; } // Controller is part of the interface to a mechanism for @@ -190,7 +182,8 @@ public: // a problem has occurred. class Controller { friend class ExecutionPlan; - private: + + private: Controller(const Controller&) = delete; Controller& operator=(const Controller&) = delete; @@ -209,7 +202,8 @@ public: const ExecutionPlan* mPlan; ExecutionBuilder* mExecutionBuilder; const BurstBuilder* mBurstBuilder; - std::shared_ptr<const SubModelInputsAndOutputsType> mSubModelInputsAndOutputs; // may be nullptr + std::shared_ptr<const SubModelInputsAndOutputsType> + mSubModelInputsAndOutputs; // may be nullptr Memory mTemporaries; size_t mNextStepIndex; }; @@ -223,7 +217,8 @@ public: std::shared_ptr<ExecutionBurstController>* burstController = nullptr) const; // Create the same executor as the last one created by next(). - int fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const; + int fallback(std::shared_ptr<Controller> controller, + std::shared_ptr<StepExecutor>* executor) const; std::shared_ptr<ExecutionStep> createNewStep(const std::shared_ptr<Device> device); @@ -310,7 +305,8 @@ public: std::unordered_map<uint32_t, uint32_t> mTemporaryToDefiningStep; bool mHasSubModelOutputOfUnknownSize = false; - private: + + private: void findTempsAsSubModelOutputs(); }; diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 2511cd343..7f91b2059 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -228,7 +228,7 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel, } uint32_t accumulator = baseAccumulator; - const Operation &operation = hidlModel.operations[operationIndex]; + const Operation& operation = hidlModel.operations[operationIndex]; accumulator ^= static_cast<uint32_t>(operation.type); auto accumulateOperands = [&hidlModel, &accumulator](const hidl_vec<uint32_t>& operands) { for (uint32_t operandIndex : operands) { diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h index 01acdf391..2c2d6cf99 100644 --- a/nn/runtime/Manager.h +++ b/nn/runtime/Manager.h @@ -127,11 +127,7 @@ class DeviceManager { // 1 - Do graph partitioning; but fall back to non-partitioned // execution if there is a partitioning failure. // 2 - Do graph partitioning, and rely on it; there is no fallback. - enum { - kPartitioningNo = 0, - kPartitioningWithFallback = 1, - kPartitioningWithoutFallback = 2 - }; + enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 }; uint32_t getPartitioning() const { return mPartitioning; } static bool partitioningAllowsFallback(uint32_t partitioning) { return partitioning == kPartitioningWithFallback; @@ -209,7 +205,7 @@ class DeviceManager { bool mStrictSlicing = false; }; -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MANAGER_H diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp index a2e3de56c..20fabfa08 100644 --- a/nn/runtime/Memory.cpp +++ b/nn/runtime/Memory.cpp @@ -161,5 +161,5 @@ uint32_t MemoryTracker::add(const Memory* memory) { return idx; } -} // namespace nn -} // namespace android +} // namespace nn +} // namespace android diff --git a/nn/runtime/ModelBuilder.h b/nn/runtime/ModelBuilder.h index 6cf93b914..4bfcd04f9 100644 --- a/nn/runtime/ModelBuilder.h +++ b/nn/runtime/ModelBuilder.h @@ -170,7 +170,6 @@ class ModelBuilder { // No further modifications are allowed to the model. 
bool mInvalidModel = false; - // 'true' indicates TENSOR_FLOAT32 may be calculated with range and/or // precision as low as that of the IEEE 754 16-bit floating-point format. // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index 038068ee8..beaaa261e 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -76,7 +76,8 @@ class DeathHandler : public hidl_death_recipient { [this, callback] { unregisterCallback(callback); }); } - private : void registerCallback(const sp<ICallback>& callback) { + private: + void registerCallback(const sp<ICallback>& callback) { std::lock_guard<std::mutex> hold(mMutex); mCallbacks.push_back(callback); } diff --git a/nn/runtime/test/TestExecution.cpp b/nn/runtime/test/TestExecution.cpp index 1a753c73f..af62ab46c 100644 --- a/nn/runtime/test/TestExecution.cpp +++ b/nn/runtime/test/TestExecution.cpp @@ -308,9 +308,9 @@ class TestDriver10 : public V1_0::IDevice { // This class adds some simple utilities on top of WrapperCompilation in order // to provide access to certain features from CompilationBuilder that are not // exposed by the base class. -template<typename DriverClass> +template <typename DriverClass> class TestCompilation : public WrapperCompilation { -public: + public: // Allow dummying up the error status for all executions from this // compilation. If errorStatus is NONE, then execute behaves // normally (and sends back the actual execution status). @@ -433,20 +433,21 @@ class ExecutionTestTemplate private: static WrapperModel makeModel() { - static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, { 1 }); + static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1}); WrapperModel model; uint32_t input = model.addOperand(&tensorType); uint32_t output = model.addOperand(&tensorType); - model.addOperation(ANEURALNETWORKS_FLOOR, { input }, { output }); - model.identifyInputsAndOutputs({ input }, { output } ); + model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output}); + model.identifyInputsAndOutputs({input}, {output}); assert(model.finish() == Result::NO_ERROR); return model; } }; -template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() { +template <class DriverClass> +void ExecutionTestTemplate<DriverClass>::TestWait() { SCOPED_TRACE(kName); // Skip Introspection API tests when CPU only flag is forced on. if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) { diff --git a/nn/runtime/test/TestMemory.cpp b/nn/runtime/test/TestMemory.cpp index fd342dd7a..122bde2b6 100644 --- a/nn/runtime/test/TestMemory.cpp +++ b/nn/runtime/test/TestMemory.cpp @@ -35,9 +35,8 @@ namespace { // Tests the various ways to pass weights and input/output data. class MemoryTest : public ::testing::Test { -protected: + protected: void SetUp() override {} - }; TEST_F(MemoryTest, TestFd) { diff --git a/nn/runtime/test/TestMemory.h b/nn/runtime/test/TestMemory.h index 57df0fe68..3ce179026 100644 --- a/nn/runtime/test/TestMemory.h +++ b/nn/runtime/test/TestMemory.h @@ -43,9 +43,8 @@ const Matrix3x4 matrix1 = {{1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f}, {9.f, 10. 
const Matrix3x4 matrix2 = {{100.f, 200.f, 300.f, 400.f}, {500.f, 600.f, 700.f, 800.f}, {900.f, 1000.f, 1100.f, 1200.f}}; -const Matrix3x4 matrix3 = {{20.f, 30.f, 40.f, 50.f}, - {21.f, 22.f, 23.f, 24.f}, - {31.f, 32.f, 33.f, 34.f}}; +const Matrix3x4 matrix3 = { + {20.f, 30.f, 40.f, 50.f}, {21.f, 22.f, 23.f, 24.f}, {31.f, 32.f, 33.f, 34.f}}; const Matrix3x4 expected3 = {{121.f, 232.f, 343.f, 454.f}, {526.f, 628.f, 730.f, 832.f}, {940.f, 1042.f, 1144.f, 1246.f}}; diff --git a/nn/runtime/test/TestMemoryInternal.cpp b/nn/runtime/test/TestMemoryInternal.cpp index 755af09fd..8bba8253d 100644 --- a/nn/runtime/test/TestMemoryInternal.cpp +++ b/nn/runtime/test/TestMemoryInternal.cpp @@ -53,11 +53,11 @@ namespace { // (We can also get very unlucky and mask a memory leak by unrelated unmapping // somewhere else. This seems unlikely enough to not deal with.) class MemoryLeakTest : public ::testing::Test { -protected: + protected: void SetUp() override; void TearDown() override; -private: + private: size_t GetAshmemMappingsCount(); size_t mStartingMapCount = 0; @@ -77,7 +77,7 @@ void MemoryLeakTest::TearDown() { size_t MemoryLeakTest::GetAshmemMappingsCount() { std::ifstream mappingsStream("/proc/self/maps"); - if (! mappingsStream.good()) { + if (!mappingsStream.good()) { // errno is set by std::ifstream on Linux ADD_FAILURE() << "Failed to open /proc/self/maps: " << std::strerror(errno); return 0; @@ -85,9 +85,9 @@ size_t MemoryLeakTest::GetAshmemMappingsCount() { std::string line; int mapCount = 0; while (std::getline(mappingsStream, line)) { - if (line.find("/dev/ashmem") != std::string::npos) { - ++mapCount; - } + if (line.find("/dev/ashmem") != std::string::npos) { + ++mapCount; + } } return mapCount; } @@ -106,8 +106,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) { int weightsFd = ASharedMemory_create("weights", weightsSize); ASSERT_GT(weightsFd, -1); - uint8_t* weightsData = (uint8_t*)mmap(nullptr, weightsSize, PROT_READ | PROT_WRITE, - MAP_SHARED, weightsFd, 0); + uint8_t* weightsData = + (uint8_t*)mmap(nullptr, weightsSize, PROT_READ | PROT_WRITE, MAP_SHARED, weightsFd, 0); ASSERT_NE(weightsData, nullptr); memcpy(weightsData + offsetForMatrix2, matrix2, sizeof(matrix2)); memcpy(weightsData + offsetForMatrix3, matrix3, sizeof(matrix3)); @@ -139,8 +139,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) { constexpr size_t inputSize = offsetForMatrix1 + sizeof(Matrix3x4); int inputFd = ASharedMemory_create("input", inputSize); ASSERT_GT(inputFd, -1); - uint8_t* inputData = (uint8_t*)mmap(nullptr, inputSize, - PROT_READ | PROT_WRITE, MAP_SHARED, inputFd, 0); + uint8_t* inputData = + (uint8_t*)mmap(nullptr, inputSize, PROT_READ | PROT_WRITE, MAP_SHARED, inputFd, 0); ASSERT_NE(inputData, nullptr); memcpy(inputData + offsetForMatrix1, matrix1, sizeof(Matrix3x4)); WrapperMemory input(inputSize, PROT_READ, inputFd, 0); @@ -150,8 +150,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) { constexpr size_t outputSize = offsetForActual + sizeof(Matrix3x4); int outputFd = ASharedMemory_create("output", outputSize); ASSERT_GT(outputFd, -1); - uint8_t* outputData = (uint8_t*)mmap(nullptr, outputSize, - PROT_READ | PROT_WRITE, MAP_SHARED, outputFd, 0); + uint8_t* outputData = + (uint8_t*)mmap(nullptr, outputSize, PROT_READ | PROT_WRITE, MAP_SHARED, outputFd, 0); ASSERT_NE(outputData, nullptr); memset(outputData, 0, outputSize); WrapperMemory actual(outputSize, PROT_READ | PROT_WRITE, outputFd, 0); @@ -166,8 +166,9 @@ TEST_F(MemoryLeakTest, TestASharedMemory) { ASSERT_EQ(execution2.setOutputFromMemory(0, &actual, 
offsetForActual, sizeof(Matrix3x4)), WrapperResult::NO_ERROR); ASSERT_EQ(execution2.compute(), WrapperResult::NO_ERROR); - ASSERT_EQ(CompareMatrices(expected3, - *reinterpret_cast<Matrix3x4*>(outputData + offsetForActual)), 0); + ASSERT_EQ( + CompareMatrices(expected3, *reinterpret_cast<Matrix3x4*>(outputData + offsetForActual)), + 0); munmap(weightsData, weightsSize); munmap(inputData, inputSize); @@ -199,7 +200,7 @@ TEST_F(MemoryLeakTest, GetPointer) { ASSERT_TRUE(mem.isValid()); auto internalMem = reinterpret_cast<::android::nn::Memory*>(mem.get()); - uint8_t *dummy; + uint8_t* dummy; ASSERT_EQ(internalMem->getPointer(&dummy), ANEURALNETWORKS_NO_ERROR); (*dummy)++; } @@ -219,7 +220,7 @@ TEST_F(MemoryLeakTest, Instantiate) { ASSERT_TRUE(mem.isValid()); auto internalMem = reinterpret_cast<::android::nn::Memory*>(mem.get()); - uint8_t *dummy; + uint8_t* dummy; ASSERT_EQ(internalMem->getPointer(&dummy), ANEURALNETWORKS_NO_ERROR); close(fd); @@ -260,7 +261,8 @@ TEST_F(MemoryLeakTest, convTooLarge) { model.setOperandValue(act, act_init, sizeof(act_init)); int32_t stride_init[] = {1}; model.setOperandValue(stride, stride_init, sizeof(stride_init)); - model.addOperation(ANEURALNETWORKS_CONV_2D, {op1, op2, op3, pad0, pad0, pad0, pad0, stride, stride, act}, {op4}); + model.addOperation(ANEURALNETWORKS_CONV_2D, + {op1, op2, op3, pad0, pad0, pad0, pad0, stride, stride, act}, {op4}); // Inputs and outputs model.identifyInputsAndOutputs({op1}, {op4}); @@ -269,7 +271,7 @@ TEST_F(MemoryLeakTest, convTooLarge) { // Compilation WrapperCompilation compilation(&model); - ASSERT_EQ(WrapperResult::NO_ERROR,compilation.finish()); + ASSERT_EQ(WrapperResult::NO_ERROR, compilation.finish()); WrapperExecution execution(&compilation); // Set input and outputs @@ -283,6 +285,6 @@ TEST_F(MemoryLeakTest, convTooLarge) { ASSERT_EQ(WrapperResult::OP_FAILED, r); } -#endif // NNTEST_ONLY_PUBLIC_API +#endif // NNTEST_ONLY_PUBLIC_API } // end namespace diff --git a/nn/runtime/test/TestOpenmpSettings.cpp b/nn/runtime/test/TestOpenmpSettings.cpp index a021f8db1..d1dd2f526 100644 --- a/nn/runtime/test/TestOpenmpSettings.cpp +++ b/nn/runtime/test/TestOpenmpSettings.cpp @@ -16,19 +16,19 @@ #include "CpuExecutor.h" -#include <algorithm> #include <gtest/gtest.h> -#include <memory> #include <omp.h> +#include <unistd.h> +#include <algorithm> +#include <memory> #include <random> #include <thread> -#include <unistd.h> #include <vector> namespace { class OpenmpSettingsTest : public ::testing::Test { -protected: + protected: virtual void SetUp() override { const int blocktimeInitial = kmp_get_blocktime(); ASSERT_EQ(blocktimeInitial, kOpenmpDefaultBlockTime); @@ -84,9 +84,7 @@ TEST_F(OpenmpSettingsTest, TestThreaded) { ASSERT_EQ(blocktimeSet2, 1); })); } - std::for_each(threads.begin(), threads.end(), [](std::thread& t) { - t.join(); - }); + std::for_each(threads.begin(), threads.end(), [](std::thread& t) { t.join(); }); } } // end namespace diff --git a/nn/runtime/test/TestPartitioning.cpp b/nn/runtime/test/TestPartitioning.cpp index d2bea0e1f..c4f792eb6 100644 --- a/nn/runtime/test/TestPartitioning.cpp +++ b/nn/runtime/test/TestPartitioning.cpp @@ -237,9 +237,7 @@ uint32_t lookupOperation(std::function<const Operation&(uint32_t)> getOperation, (input2.lifetime == OperandLifeTime::CONSTANT_COPY)) { int32_t value; CHECK_EQ(sizeof(value), input2.location.length); - memcpy(&value, - getValue(input2.location.offset), - input2.location.length); + memcpy(&value, getValue(input2.location.offset), input2.location.length); return value + 
operationToFirstEncoding.at(operation.type); } break; @@ -257,14 +255,9 @@ uint32_t lookupOperation(std::function<const Operation&(uint32_t)> getOperation, uint32_t lookupOperation(const HidlModel& model, uint32_t operationIndex) { return lookupOperation( - [&model](uint32_t index) -> const Operation& { - return model.operations[index]; - }, - [&model](uint32_t index) -> const Operand& { - return model.operands[index]; - }, - [&model](uint32_t offset) {return &model.operandValues[offset];}, - operationIndex); + [&model](uint32_t index) -> const Operation& { return model.operations[index]; }, + [&model](uint32_t index) -> const Operand& { return model.operands[index]; }, + [&model](uint32_t offset) { return &model.operandValues[offset]; }, operationIndex); } #ifdef VERBOSE @@ -287,32 +280,33 @@ void dump(const char* name, const ModelBuilder* model) { // operation. The subset is represented with a bitmask, in which // operation kind K corresponds to the bit (1 << K). class PartitioningDriver : public SampleDriver { -private: + private: // Dummy class -- a prepared model must not be nullptr. class PartitioningPreparedModel : public IPreparedModel { - public: - Return<ErrorStatus> execute(const Request&, const sp<V1_0::IExecutionCallback>&) override { - return ErrorStatus::DEVICE_UNAVAILABLE; - } - Return<ErrorStatus> execute_1_2(const Request&, MeasureTiming, - const sp<V1_2::IExecutionCallback>&) override { - return ErrorStatus::DEVICE_UNAVAILABLE; - } - Return<void> executeSynchronously(const Request&, MeasureTiming, - executeSynchronously_cb cb) override { - cb(ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming); - return Void(); - } - Return<void> configureExecutionBurst( - const sp<V1_2::IBurstCallback>& /*callback*/, - const MQDescriptorSync<V1_2::FmqRequestDatum>& /*requestChannel*/, - const MQDescriptorSync<V1_2::FmqResultDatum>& /*resultChannel*/, - configureExecutionBurst_cb cb) override { - cb(ErrorStatus::DEVICE_UNAVAILABLE, nullptr); - return Void(); - } + public: + Return<ErrorStatus> execute(const Request&, const sp<V1_0::IExecutionCallback>&) override { + return ErrorStatus::DEVICE_UNAVAILABLE; + } + Return<ErrorStatus> execute_1_2(const Request&, MeasureTiming, + const sp<V1_2::IExecutionCallback>&) override { + return ErrorStatus::DEVICE_UNAVAILABLE; + } + Return<void> executeSynchronously(const Request&, MeasureTiming, + executeSynchronously_cb cb) override { + cb(ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming); + return Void(); + } + Return<void> configureExecutionBurst( + const sp<V1_2::IBurstCallback>& /*callback*/, + const MQDescriptorSync<V1_2::FmqRequestDatum>& /*requestChannel*/, + const MQDescriptorSync<V1_2::FmqResultDatum>& /*resultChannel*/, + configureExecutionBurst_cb cb) override { + cb(ErrorStatus::DEVICE_UNAVAILABLE, nullptr); + return Void(); + } }; -public: + + public: enum OEM { OEMNo, // rejected by getSupportedOperations and prepareModel OEMIndecisive, // accepted by getSupportedOperations but not prepareModel @@ -350,9 +344,7 @@ public: return status; } - Return<DeviceStatus> getStatus() override { - return DeviceStatus::AVAILABLE; - } + Return<DeviceStatus> getStatus() override { return DeviceStatus::AVAILABLE; } Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override { cb(ErrorStatus::NONE, mCapabilities); @@ -567,15 +559,15 @@ class PartitioningModel : private WrapperModel { uint32_t addOperationOEM1To1(const uint32_t input, Dimensioned dimensionedOutput = Dimensioned::YES) { uint32_t output = addOperandOfSameType(input, 
dimensionedOutput); - addOperation(ANEURALNETWORKS_OEM_OPERATION, { input }, { output }); + addOperation(ANEURALNETWORKS_OEM_OPERATION, {input}, {output}); return output; } // Run the partitioning algorithm to create an ExecutionPlan. int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices, ExecutePreference preference, ExecutionPlan* plan) { - return reinterpret_cast<ModelBuilder*>(getHandle())->partitionTheWork( - devices, static_cast<uint32_t>(preference), plan); + return reinterpret_cast<ModelBuilder*>(getHandle()) + ->partitionTheWork(devices, static_cast<uint32_t>(preference), plan); } #ifdef VERBOSE @@ -586,34 +578,34 @@ class PartitioningModel : private WrapperModel { } #endif -private: - // Create an operation with two inputs and one output, specifying - // the operation kind and the input operand indexes. - // Returns the output operand index. - uint32_t addOperation2To1(uint32_t operation, const uint32_t input0, const uint32_t input1, - Dimensioned dimensionedOutput = Dimensioned::YES) { - auto it = firstEncodingToOperation.lower_bound(operation); - CHECK(it != firstEncodingToOperation.end()); - ANeuralNetworksOperationType type = it->second.first; - if (it->second.second) { - int32_t fuseCode = operation - it->first; - uint32_t input2 = addIntOperand(fuseCode); - uint32_t output = addOperandOfSameType(input0, dimensionedOutput); - addOperation(type, {input0, input1, input2}, {output}); - return output; - } else { - uint32_t output = addOperandOfSameType(input0, dimensionedOutput); - addOperation(type, {input0, input1}, {output}); - return output; - } - } - - // Create a scalar integer operand of the specified value, and - // return the corresponding operand index. - uint32_t addIntOperand(int32_t value) { - uint32_t operand = addOperand(WrapperType::INT32); - setOperandValue(operand, &value, sizeof(value)); - return operand; + private: + // Create an operation with two inputs and one output, specifying + // the operation kind and the input operand indexes. + // Returns the output operand index. + uint32_t addOperation2To1(uint32_t operation, const uint32_t input0, const uint32_t input1, + Dimensioned dimensionedOutput = Dimensioned::YES) { + auto it = firstEncodingToOperation.lower_bound(operation); + CHECK(it != firstEncodingToOperation.end()); + ANeuralNetworksOperationType type = it->second.first; + if (it->second.second) { + int32_t fuseCode = operation - it->first; + uint32_t input2 = addIntOperand(fuseCode); + uint32_t output = addOperandOfSameType(input0, dimensionedOutput); + addOperation(type, {input0, input1, input2}, {output}); + return output; + } else { + uint32_t output = addOperandOfSameType(input0, dimensionedOutput); + addOperation(type, {input0, input1}, {output}); + return output; + } + } + + // Create a scalar integer operand of the specified value, and + // return the corresponding operand index. + uint32_t addIntOperand(int32_t value) { + uint32_t operand = addOperand(WrapperType::INT32); + setOperandValue(operand, &value, sizeof(value)); + return operand; } // Create an operand of the same type as the specified operand, @@ -633,30 +625,26 @@ private: // This class adds some utilities on top of WrapperCompilation. 
class PartitioningCompilation : public WrapperCompilation { -public: - PartitioningCompilation(const PartitioningModel* model, - const std::vector<std::shared_ptr<Device>>& devices) { - ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model->getHandle()); - CompilationBuilder* c = nullptr; - int result = m->createCompilation(&c, devices); - EXPECT_EQ(result, 0); - mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c); - } - - Result setPartitioning(uint32_t partitioning) { - return static_cast<Result>(builder()->setPartitioning(partitioning)); + public: + PartitioningCompilation(const PartitioningModel* model, + const std::vector<std::shared_ptr<Device>>& devices) { + ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model->getHandle()); + CompilationBuilder* c = nullptr; + int result = m->createCompilation(&c, devices); + EXPECT_EQ(result, 0); + mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c); + } + + Result setPartitioning(uint32_t partitioning) { + return static_cast<Result>(builder()->setPartitioning(partitioning)); } using WrapperCompilation::finish; - const ExecutionPlan& getExecutionPlan() const { - return builder()->forTest_getExecutionPlan(); - } + const ExecutionPlan& getExecutionPlan() const { return builder()->forTest_getExecutionPlan(); } -private: - CompilationBuilder* builder() { - return reinterpret_cast<CompilationBuilder*>(getHandle()); - } + private: + CompilationBuilder* builder() { return reinterpret_cast<CompilationBuilder*>(getHandle()); } const CompilationBuilder* builder() const { return reinterpret_cast<const CompilationBuilder*>(getHandle()); @@ -664,16 +652,14 @@ private: }; #ifdef VERBOSE -#define RETURN_TRUE() \ - { \ - std::cerr << "returning true from " << __LINE__ << std::endl; \ - return true; \ +#define RETURN_TRUE() \ + { \ + std::cerr << "returning true from " << __LINE__ << std::endl; \ + return true; \ } #else -#define RETURN_TRUE() \ - { \ - return true; \ - } +#define RETURN_TRUE() \ + { return true; } #endif #ifdef VERBOSE #define RETURN_FALSE(MESSAGE) \ @@ -682,19 +668,16 @@ private: return false; \ } #else -#define RETURN_FALSE(MESSAGE) \ - { \ - return false; \ - } +#define RETURN_FALSE(MESSAGE) \ + { return false; } #endif class PartitioningTest : public ::testing::Test { -protected: + protected: using RemapVectorType = ExecutionStep::RemapVectorType; using SubModelOutputSetType = ExecutionStep::SubModelOutputSetType; - virtual void SetUp() { - } + virtual void SetUp() {} // From a vector of DeviceSpecification, create a vector of // Devices. @@ -841,21 +824,20 @@ protected: // within its scope (actual operations, inputs, constants). enum PseudoDefiningOperationEncodings : uint32_t { - kPseudoDefiningOperationModelInput0 = 0x80000000U, + kPseudoDefiningOperationModelInput0 = 0x80000000U, kPseudoDefiningOperationConstantCopy0 = 0x90000000U, - kPseudoDefiningOperationNoValue = 0xeeeeeeeeU, + kPseudoDefiningOperationNoValue = 0xeeeeeeeeU, // lowest value for special encoding - kPseudoDefiningOperationBase = 0x80000000U, + kPseudoDefiningOperationBase = 0x80000000U, // range of encoded input or constant - kPseudoDefiningOperationRange = 0x10000000U, + kPseudoDefiningOperationRange = 0x10000000U, }; // Build a map from operand to defining operation. // TODO: Replace map with vector? 
- void buildDefinitionMap(const ModelBuilder* model, - std::map<uint32_t, uint32_t>* defMap) { + void buildDefinitionMap(const ModelBuilder* model, std::map<uint32_t, uint32_t>* defMap) { // actual definitions ASSERT_LT(model->operationCount(), kPseudoDefiningOperationBase); for (uint32_t i = 0, e = model->operationCount(); i < e; i++) { @@ -879,7 +861,8 @@ protected: case OperandLifeTime::CONSTANT_COPY: { ASSERT_EQ(operand.location.length, sizeof(uint32_t)); uint32_t value; - memcpy(&value, model->getPointerToOperandValue(operand.location.offset), sizeof(uint32_t)); + memcpy(&value, model->getPointerToOperandValue(operand.location.offset), + sizeof(uint32_t)); ASSERT_LT(value, kPseudoDefiningOperationNoValue); (*defMap)[i] = kPseudoDefiningOperationConstantCopy0 + value; break; @@ -927,11 +910,9 @@ protected: #endif bool compare(const Operand& operandA, const Operand& operandB) { - if (operandA.type != operandB.type || - operandA.dimensions != operandB.dimensions || + if (operandA.type != operandB.type || operandA.dimensions != operandB.dimensions || operandA.numberOfConsumers != operandB.numberOfConsumers || - operandA.scale != operandB.scale || - operandA.zeroPoint != operandB.zeroPoint) { + operandA.scale != operandB.scale || operandA.zeroPoint != operandB.zeroPoint) { return false; } return true; @@ -980,10 +961,10 @@ protected: ::dump("compare(B)", modelB); #endif - if (modelA->operandCount() != modelB->operandCount() || + if (modelA->operandCount() != modelB->operandCount() || modelA->operationCount() != modelB->operationCount() || - modelA->inputCount() != modelB->inputCount() || - modelA->outputCount() != modelB->outputCount()) { + modelA->inputCount() != modelB->inputCount() || + modelA->outputCount() != modelB->outputCount()) { RETURN_FALSE(); } @@ -1093,8 +1074,7 @@ protected: } // Sanity check - if (modelA->operandCount() != defsA.size() || - modelA->operandCount() != defsB.size() || + if (modelA->operandCount() != defsA.size() || modelA->operandCount() != defsB.size() || modelA->operandCount() != equivalentOperandsAToB.size() || modelA->operationCount() + pseudoDefinitionCount != equivalentOperationsAToB.size()) { RETURN_FALSE(); @@ -1180,7 +1160,7 @@ TEST_F(PartitioningTest, SimpleModel) { uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1); uint32_t opnd3 = model.addFloatOperand(); uint32_t opnd4 = model.addOperation2To1V1_0(1, opnd2, opnd3); - model.identifyInputsAndOutputs({ opnd0, opnd1, opnd3 }, { opnd4 }); + model.identifyInputsAndOutputs({opnd0, opnd1, opnd3}, {opnd4}); model.finish(); ASSERT_TRUE(model.isValid()); @@ -1222,7 +1202,7 @@ TEST_F(PartitioningTest, SimpleModel) { uint32_t b0Opnd0 = modelB0.addFloatOperand(); uint32_t b0Opnd1 = modelB0.addFloatOperand(); uint32_t b0Opnd2 = modelB0.addOperation2To1V1_0(0, b0Opnd0, b0Opnd1); - modelB0.identifyInputsAndOutputs({ b0Opnd0, b0Opnd1 }, { b0Opnd2 }); + modelB0.identifyInputsAndOutputs({b0Opnd0, b0Opnd1}, {b0Opnd2}); modelB0.finish(); ASSERT_TRUE(modelB0.isValid()); @@ -1245,7 +1225,7 @@ TEST_F(PartitioningTest, SimpleModel) { // an input; so in the submodel "modelB1", the corresponding // input b1Opnd2 is a submodel input, and must follow the // model input b1Opnd3. 
- modelB1.identifyInputsAndOutputs({ b1Opnd3, b1Opnd2 }, { b1Opnd4 }); + modelB1.identifyInputsAndOutputs({b1Opnd3, b1Opnd2}, {b1Opnd4}); modelB1.finish(); ASSERT_TRUE(modelB1.isValid()); @@ -1410,7 +1390,7 @@ TEST_F(PartitioningTest, Cpu) { uint32_t opnd7 = model.addOperation2To1V1_0(kDevOp, opnd3, opnd5); uint32_t opnd8 = model.addOperation2To1V1_0(kDevOp, opnd6, opnd7); - model.identifyInputsAndOutputs({ opnd0, opnd1, opnd6 }, { opnd4, opnd8 }); + model.identifyInputsAndOutputs({opnd0, opnd1, opnd6}, {opnd4, opnd8}); model.finish(); ASSERT_TRUE(model.isValid()); @@ -1429,7 +1409,7 @@ TEST_F(PartitioningTest, Cpu) { uint32_t m0Opnd1 = model0.addFloatOperand(); uint32_t m0Opnd2 = model0.addOperation2To1V1_0(kDevOp, m0Opnd0, m0Opnd1); uint32_t m0Opnd3 = model0.addOperation2To1V1_0(kDevOp, m0Opnd0, m0Opnd2); - model0.identifyInputsAndOutputs({ m0Opnd0, m0Opnd1 }, { m0Opnd2, m0Opnd3 }); + model0.identifyInputsAndOutputs({m0Opnd0, m0Opnd1}, {m0Opnd2, m0Opnd3}); model0.finish(); ASSERT_TRUE(model0.isValid()); @@ -1452,7 +1432,7 @@ TEST_F(PartitioningTest, Cpu) { uint32_t m1Opnd4 = model1.addOperation2To1V1_0(kCpuOp, m1Opnd0, m1Opnd3); uint32_t m1Opnd2 = model1.addFloatOperand(); uint32_t m1Opnd5 = model1.addOperation2To1V1_0(kCpuOp, m1Opnd2, m1Opnd4); - model1.identifyInputsAndOutputs({ m1Opnd0, m1Opnd3, m1Opnd2 }, { m1Opnd4, m1Opnd5 }); + model1.identifyInputsAndOutputs({m1Opnd0, m1Opnd3, m1Opnd2}, {m1Opnd4, m1Opnd5}); model1.finish(); ASSERT_TRUE(model1.isValid()); @@ -1474,7 +1454,7 @@ TEST_F(PartitioningTest, Cpu) { uint32_t m2Opnd7 = model2.addOperation2To1V1_0(kDevOp, m2Opnd3, m2Opnd5); uint32_t m2Opnd6 = model2.addFloatOperand(); uint32_t m2Opnd8 = model2.addOperation2To1V1_0(kDevOp, m2Opnd6, m2Opnd7); - model2.identifyInputsAndOutputs({ m2Opnd6, m2Opnd3, m2Opnd5 }, { m2Opnd8 }); + model2.identifyInputsAndOutputs({m2Opnd6, m2Opnd3, m2Opnd5}, {m2Opnd8}); model2.finish(); ASSERT_TRUE(model2.isValid()); @@ -1495,7 +1475,7 @@ TEST_F(PartitioningTest, SetPartitioning) { model.addOperation2To1V1_0(0, opnd0, opnd1, PartitioningModel::Dimensioned::NO); uint32_t opnd3 = model.addFloatOperand(); uint32_t opnd4 = model.addOperation2To1V1_0(1, opnd2, opnd3); - model.identifyInputsAndOutputs({ opnd0, opnd1, opnd3 }, { opnd4 }); + model.identifyInputsAndOutputs({opnd0, opnd1, opnd3}, {opnd4}); model.finish(); ASSERT_TRUE(model.isValid()); @@ -1524,7 +1504,8 @@ TEST_F(PartitioningTest, SetPartitioning) { // No need to compare the original model to the model from the plan -- we // didn't actually do any partitioning. PartitioningCompilation cPWithFallback(&model, devices); - ASSERT_EQ(cPWithFallback.setPartitioning(DeviceManager::kPartitioningWithFallback), Result::NO_ERROR); + ASSERT_EQ(cPWithFallback.setPartitioning(DeviceManager::kPartitioningWithFallback), + Result::NO_ERROR); ASSERT_EQ(cPWithFallback.finish(), Result::NO_ERROR); ASSERT_EQ(cPWithFallback.getExecutionPlan().forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_EQ(cPWithFallback.getExecutionPlan().forTest_simpleGetDevice(), @@ -1533,21 +1514,23 @@ TEST_F(PartitioningTest, SetPartitioning) { // Test kPartitioningWithoutFallback. We should attempt // partitioning, and fail. 
PartitioningCompilation cPWithoutFallback(&model, devices); - ASSERT_EQ(cPWithoutFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback), Result::NO_ERROR); + ASSERT_EQ(cPWithoutFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback), + Result::NO_ERROR); ASSERT_EQ(cPWithoutFallback.finish(), Result::OP_FAILED); ASSERT_TRUE(cPWithoutFallback.getExecutionPlan().forTest_hasSubModelOutputsOfUnknownSize()); ASSERT_EQ(cPWithoutFallback.getExecutionPlan().forTest_getKind(), ExecutionPlan::Kind::ERROR); } // Regression test for http://b/69166603: -// "partitioned compilation and execution yields wrong results when model output is submodel input" +// "partitioned compilation and execution yields wrong results when model output is submodel +// input" TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) { PartitioningModel model; uint32_t opnd0 = model.addFloatOperand(); uint32_t opnd1 = model.addFloatOperand(); uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1); uint32_t opnd3 = model.addOperation2To1V1_0(1, opnd2, opnd2); - model.identifyInputsAndOutputs({ opnd0, opnd1 }, { opnd2, opnd3 }); + model.identifyInputsAndOutputs({opnd0, opnd1}, {opnd2, opnd3}); model.finish(); ASSERT_TRUE(model.isValid()); @@ -1568,7 +1551,7 @@ TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) { uint32_t m0Opnd0 = model0.addFloatOperand(); uint32_t m0Opnd1 = model0.addFloatOperand(); uint32_t m0Opnd2 = model0.addOperation2To1V1_0(0, m0Opnd0, m0Opnd1); - model0.identifyInputsAndOutputs({ m0Opnd0, m0Opnd1 }, { m0Opnd2 }); + model0.identifyInputsAndOutputs({m0Opnd0, m0Opnd1}, {m0Opnd2}); model0.finish(); ASSERT_TRUE(model0.isValid()); ASSERT_NO_FATAL_FAILURE( @@ -1584,7 +1567,7 @@ TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) { PartitioningModel model1; uint32_t m1Opnd2 = model1.addFloatOperand(); uint32_t m1Opnd3 = model1.addOperation2To1V1_0(1, m1Opnd2, m1Opnd2); - model1.identifyInputsAndOutputs({ m1Opnd2 }, { m1Opnd3 }); + model1.identifyInputsAndOutputs({m1Opnd2}, {m1Opnd3}); model1.finish(); ASSERT_TRUE(model1.isValid()); @@ -1602,7 +1585,7 @@ TEST_F(PartitioningTest, OemOperations) { PartitioningModel model; uint32_t opndIn = model.addFloatOperand(); uint32_t opndOut = model.addOperationOEM1To1(opndIn); - model.identifyInputsAndOutputs({ opndIn }, { opndOut }); + model.identifyInputsAndOutputs({opndIn}, {opndOut}); model.finish(); ASSERT_TRUE(model.isValid()); @@ -1649,7 +1632,7 @@ TEST_F(PartitioningTest, RelaxedFP) { uint32_t opnd0 = model.addFloatOperand(); uint32_t opnd1 = model.addFloatOperand(); uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1); - model.identifyInputsAndOutputs({ opnd0, opnd1 }, { opnd2 }); + model.identifyInputsAndOutputs({opnd0, opnd1}, {opnd2}); model.relaxComputationFloat32toFloat16(doRelax); model.finish(); ASSERT_TRUE(model.isValid()); diff --git a/nn/runtime/test/TestPartitioningRandom.cpp b/nn/runtime/test/TestPartitioningRandom.cpp index d62a6ad97..b7326e562 100644 --- a/nn/runtime/test/TestPartitioningRandom.cpp +++ b/nn/runtime/test/TestPartitioningRandom.cpp @@ -147,8 +147,7 @@ typedef std::pair<ANeuralNetworksOperationType, int> Signature; // it provides access to certain features from ModelBuilder that are not exposed // by the base class (such as inputCount() and operation index). 
class TestModel : public WrapperModel { -public: - + public: uint32_t addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, const std::vector<uint32_t>& outputs) { const uint32_t operationIndex = operationCount(); @@ -157,16 +156,10 @@ public: return operationIndex; } - uint32_t operationCount() const { - return mOperations.size(); - } + uint32_t operationCount() const { return mOperations.size(); } - uint32_t inputCount() const { - return builder()->inputCount(); - } - uint32_t outputCount() const { - return builder()->outputCount(); - } + uint32_t inputCount() const { return builder()->inputCount(); } + uint32_t outputCount() const { return builder()->outputCount(); } const std::vector<uint32_t>& getOperationOutputs(uint32_t index) const { CHECK(index < mOperations.size()); @@ -198,8 +191,7 @@ public: WrapperModel::setOperandValue(index, &value, sizeof(value)); } -private: - + private: const ModelBuilder* builder() const { return reinterpret_cast<const ModelBuilder*>(getHandle()); } @@ -217,7 +209,7 @@ private: // to provide access to certain features from CompilationBuilder that are not // exposed by the base class. class TestCompilation : public WrapperCompilation { -public: + public: TestCompilation(const WrapperModel* model) : WrapperCompilation(model) {} TestCompilation(const WrapperModel* model, std::vector<std::shared_ptr<Device>> devices) { @@ -234,17 +226,13 @@ public: return static_cast<Result>(builder()->setPartitioning(partitioning)); } - const ExecutionPlan& getExecutionPlan() const { - return builder()->forTest_getExecutionPlan(); - } + const ExecutionPlan& getExecutionPlan() const { return builder()->forTest_getExecutionPlan(); } -private: + private: const CompilationBuilder* builder() const { return reinterpret_cast<const CompilationBuilder*>(getHandle()); } - CompilationBuilder* builder() { - return reinterpret_cast<CompilationBuilder*>(getHandle()); - } + CompilationBuilder* builder() { return reinterpret_cast<CompilationBuilder*>(getHandle()); } }; // This class is used to manage a collection of memory regions, @@ -262,7 +250,7 @@ private: // TestMemories instance, and are destroyed when the TestMemories // instance is destroyed. class TestMemories { -public: + public: TestMemories() = default; ~TestMemories(); @@ -274,9 +262,7 @@ public: mMemorySizes.push_back(0); return memoryCount() - 1; } - unsigned memoryCount() const { - return mMemorySizes.size(); - } + unsigned memoryCount() const { return mMemorySizes.size(); } unsigned addRegion(unsigned memoryIndex, uint32_t length) { CHECK(!mLayoutDone); @@ -287,14 +273,12 @@ public: memorySize += length; return regionCount() - 1; } - unsigned regionCount() const { - return mRegions.size(); - } + unsigned regionCount() const { return mRegions.size(); } void layout(); - void* getRegion(unsigned regionIndex, - const WrapperMemory** pMemory, uint32_t* pOffset, uint32_t* pLength) { + void* getRegion(unsigned regionIndex, const WrapperMemory** pMemory, uint32_t* pOffset, + uint32_t* pLength) { CHECK(mLayoutDone); CHECK(regionIndex < regionCount()); const auto& regionDescriptor = mRegions[regionIndex]; @@ -319,7 +303,7 @@ public: return getRegion(regionIndex, nullptr, nullptr, nullptr); } -private: + private: // Index is the memory index; value is the size of the memory // (aggregate size of all regions in the memory). 
std::vector<uint32_t> mMemorySizes; @@ -354,23 +338,23 @@ TestMemories::~TestMemories() { } class RandomPartitioningTest : public ::testing::TestWithParam<unsigned> { -public: + public: RandomPartitioningTest() : mRandNumEng(GetParam() /* seed */), mRandNumUnitDist(0.0, 1.0) {} static Signature getSignature(const HidlModel& model, const Operation& operation); -protected: - static V1_0::IDevice* makeTestDriver(HalVersion version, const char* name, - std::set<Signature> signatures); + protected: + static V1_0::IDevice* makeTestDriver(HalVersion version, const char* name, + std::set<Signature> signatures); - static HalVersion getMinHalVersion(ANeuralNetworksOperationType type); + static HalVersion getMinHalVersion(ANeuralNetworksOperationType type); - static std::string to_string(HalVersion version); + static std::string to_string(HalVersion version); - bool randBool() { return randUInt(2) == 1; } + bool randBool() { return randUInt(2) == 1; } - double randFrac() { // [0.0, 1.0) - return mRandNumUnitDist(mRandNumEng); + double randFrac() { // [0.0, 1.0) + return mRandNumUnitDist(mRandNumEng); } unsigned randUInt(unsigned limit) { // [0, limit) @@ -412,11 +396,10 @@ protected: } // input operand 3 is bias, a 1-D tensor - const WrapperOperandType biasType(WrapperType::TENSOR_FLOAT32, { problemSize }); + const WrapperOperandType biasType(WrapperType::TENSOR_FLOAT32, {problemSize}); const uint32_t operandIndex = model->addOperand(&biasType); std::vector<float> biasValue(problemSize); - std::generate(biasValue.begin(), biasValue.end(), - [this]{ return randFrac(); }); + std::generate(biasValue.begin(), biasValue.end(), [this] { return randFrac(); }); model->setOperandValue(operandIndex, biasValue); return int(operandIndex); } @@ -440,24 +423,23 @@ protected: #ifdef VERBOSE class ModelStats { - public: - ModelStats(const ModelBuilder* model) : - mBuilder(model) { } - ModelStats(const WrapperModel* model) : - mBuilder(reinterpret_cast<const ModelBuilder*>(model->getHandle())) { } + public: + ModelStats(const ModelBuilder* model) : mBuilder(model) {} + ModelStats(const WrapperModel* model) + : mBuilder(reinterpret_cast<const ModelBuilder*>(model->getHandle())) {} friend std::ostream& operator<<(std::ostream& out, const ModelStats& stats) { const uint32_t operandCount = stats.mBuilder->operandCount(); const uint32_t inputCount = stats.mBuilder->inputCount(); const uint32_t outputCount = stats.mBuilder->outputCount(); out << "operationCount = " << stats.mBuilder->operationCount() - << ", operandCount = " << operandCount - << ", inputCount = " << inputCount - << " (" << (double(inputCount) / operandCount) << ")" - << ", outputCount = " << outputCount - << " (" << (double(outputCount) / operandCount) << ")"; + << ", operandCount = " << operandCount << ", inputCount = " << inputCount << " (" + << (double(inputCount) / operandCount) << ")" + << ", outputCount = " << outputCount << " (" << (double(outputCount) / operandCount) + << ")"; return out; } - private: + + private: const ModelBuilder* mBuilder; }; @@ -526,9 +508,7 @@ Signature RandomPartitioningTest::getSignature(const HidlModel& model, const Ope CHECK(operand.lifetime == OperandLifeTime::CONSTANT_COPY); CHECK(operand.type == OperandType::INT32); int32_t value; - memcpy(&value, - &model.operandValues[operand.location.offset], - operand.location.length); + memcpy(&value, &model.operandValues[operand.location.offset], operand.location.length); return Signature(operationType, value); } @@ -546,11 +526,11 @@ std::string 
RandomPartitioningTest::to_string(HalVersion version) { }; class TestDriver : public SampleDriver { -public: + public: // Behaves like SampleDriver, except that it only supports // operations with the specified signatures. - TestDriver(const char* name, std::set<Signature> signatures) : - SampleDriver(name), mSignatures(std::move(signatures)) { } + TestDriver(const char* name, std::set<Signature> signatures) + : SampleDriver(name), mSignatures(std::move(signatures)) {} Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override { android::nn::initVLogMask(); @@ -569,11 +549,8 @@ public: const size_t count = model.operations.size(); std::vector<bool> supported(count); for (size_t i = 0; i < count; i++) { - supported[i] = - (mSignatures.count( - RandomPartitioningTest::getSignature( - model, - model.operations[i])) != 0); + supported[i] = (mSignatures.count(RandomPartitioningTest::getSignature( + model, model.operations[i])) != 0); } cb(ErrorStatus::NONE, supported); } else { @@ -591,15 +568,15 @@ public: // NOTE: We verify that all operations in the model are supported. ErrorStatus outStatus = ErrorStatus::INVALID_ARGUMENT; auto ret = getSupportedOperations_1_2( - model, - [&outStatus](ErrorStatus inStatus, const hidl_vec<bool>& supportedOperations) { - if (inStatus == ErrorStatus::NONE) { - if (std::all_of(supportedOperations.begin(), supportedOperations.end(), - [](bool v){ return v; })) { - outStatus = ErrorStatus::NONE; + model, + [&outStatus](ErrorStatus inStatus, const hidl_vec<bool>& supportedOperations) { + if (inStatus == ErrorStatus::NONE) { + if (std::all_of(supportedOperations.begin(), supportedOperations.end(), + [](bool v) { return v; })) { + outStatus = ErrorStatus::NONE; + } } - } - }); + }); if (ret.isOk() && (outStatus == ErrorStatus::NONE)) { return SampleDriver::prepareModel_1_2(model, preference, modelCache, dataCache, token, callback); @@ -609,7 +586,7 @@ public: } } -private: + private: const std::set<Signature> mSignatures; }; @@ -696,13 +673,13 @@ TEST_P(RandomPartitioningTest, Test) { std::cout << std::setprecision(2) << std::fixed << std::setw(4); #endif - const unsigned problemSize = 1+randUInt(kMaxProblemSize); - const WrapperOperandType problemType(WrapperType::TENSOR_FLOAT32, { problemSize, problemSize }); - const WrapperOperandType unknownDimensionsType(WrapperType::TENSOR_FLOAT32, { 0, 0 }); + const unsigned problemSize = 1 + randUInt(kMaxProblemSize); + const WrapperOperandType problemType(WrapperType::TENSOR_FLOAT32, {problemSize, problemSize}); + const WrapperOperandType unknownDimensionsType(WrapperType::TENSOR_FLOAT32, {0, 0}); - static const WrapperOperandType activationFunctionType(WrapperType::INT32, { }); + static const WrapperOperandType activationFunctionType(WrapperType::INT32, {}); - const unsigned numOperations = 2+randUInt(kMaxNumOperations-1); + const unsigned numOperations = 2 + randUInt(kMaxNumOperations - 1); const bool allowDeadOperations = (randFrac() < 0.2); const bool allowUnknownDimensions = (randFrac() < 0.25); @@ -783,7 +760,7 @@ TEST_P(RandomPartitioningTest, Test) { } if (operationPattern.mMakeSpecialInput != nullptr) { const int operandIndex = (this->*(operationPattern.mMakeSpecialInput))( - problemSize, &model, operationInputIndex); + problemSize, &model, operationInputIndex); if (operandIndex >= 0) { operationInputs[operationInputIndex] = operandIndex; continue; @@ -811,48 +788,46 @@ TEST_P(RandomPartitioningTest, Test) { // decision later. 
enum InputKind { IK_MODEL_INPUT, IK_OPERATION_OUTPUT, IK_VALUE }; std::vector<InputKind> normalOperationInputKinds(normalOperationInputCount); - std::generate(normalOperationInputKinds.begin(), normalOperationInputKinds.end(), - [this, &model, - numOperations, - normalOperationInputCount, - &normalOperationInputConstantCount, - &normalOperationInputModelInputCount]() -> InputKind { - // Constant? Becomes less likely the more - // constants we already have as inputs to - // this operation. - if (randFrac() < 0.3 * (1 - double(normalOperationInputConstantCount) / - normalOperationInputCount)) { - normalOperationInputConstantCount++; - return IK_VALUE; - } + std::generate( + normalOperationInputKinds.begin(), normalOperationInputKinds.end(), + [this, &model, numOperations, normalOperationInputCount, + &normalOperationInputConstantCount, + &normalOperationInputModelInputCount]() -> InputKind { + // Constant? Becomes less likely the more + // constants we already have as inputs to + // this operation. + if (randFrac() < 0.3 * (1 - double(normalOperationInputConstantCount) / + normalOperationInputCount)) { + normalOperationInputConstantCount++; + return IK_VALUE; + } - // Model input? Becomes less likely the - // more model inputs we already have as - // inputs to this operation, and the further - // along we are in generating this model - // (i.e., the more operations we have - // generated). - if ((model.operationCount() == 0) || - (randFrac() < 0.5 * - (1 - double(normalOperationInputModelInputCount) / - normalOperationInputCount) * - std::min(0.3, (1 - double(model.operationCount()) / - numOperations)))) { - normalOperationInputModelInputCount++; - return IK_MODEL_INPUT; - } + // Model input? Becomes less likely the + // more model inputs we already have as + // inputs to this operation, and the further + // along we are in generating this model + // (i.e., the more operations we have + // generated). + if ((model.operationCount() == 0) || + (randFrac() < 0.5 * + (1 - double(normalOperationInputModelInputCount) / + normalOperationInputCount) * + std::min(0.3, (1 - double(model.operationCount()) / + numOperations)))) { + normalOperationInputModelInputCount++; + return IK_MODEL_INPUT; + } - // Else output of an existing operation. - return IK_OPERATION_OUTPUT; - }); + // Else output of an existing operation. + return IK_OPERATION_OUTPUT; + }); // Now force common root or model input, if necessary. (A // model must have at least one input.) 
- auto force = - [this, &normalOperationInputKinds, normalOperationInputCount](InputKind forceKind){ - if (std::none_of(normalOperationInputKinds.begin(), - normalOperationInputKinds.end(), - [forceKind](InputKind kind){ return kind == forceKind; })) { + auto force = [this, &normalOperationInputKinds, + normalOperationInputCount](InputKind forceKind) { + if (std::none_of(normalOperationInputKinds.begin(), normalOperationInputKinds.end(), + [forceKind](InputKind kind) { return kind == forceKind; })) { normalOperationInputKinds[randUInt(normalOperationInputCount)] = forceKind; } }; @@ -889,7 +864,7 @@ TEST_P(RandomPartitioningTest, Test) { const auto& existingOperationOutputs = model.getOperationOutputs(existingOperationIndex); operandIndex = - existingOperationOutputs[randUInt(existingOperationOutputs.size())]; + existingOperationOutputs[randUInt(existingOperationOutputs.size())]; deadOperandI = deadOperands.find(operandIndex); CHECK(deadOperandI == deadOperands.end() || deadOperandI->second == existingOperationIndex); @@ -913,7 +888,8 @@ TEST_P(RandomPartitioningTest, Test) { operandIndex = model.addOperand(&problemType); if (randFrac() < 0.5) { std::vector<float> value(problemSize * problemSize); - std::generate(value.begin(), value.end(), [this]{ return randFrac(); }); + std::generate(value.begin(), value.end(), + [this] { return randFrac(); }); model.setOperandValue(operandIndex, value); valueOperands.push_back(std::make_pair(operandIndex, ~0U)); } else { @@ -945,7 +921,7 @@ TEST_P(RandomPartitioningTest, Test) { std::vector<uint32_t> operationOutputs(operationPattern.mNumOutputs); std::generate(operationOutputs.begin(), operationOutputs.end(), [&model, &problemType, &unknownDimensionsType, &hasUnknownDimensions, - allowUnknownDimensions, this]{ + allowUnknownDimensions, this] { // 3% unknowns causes ~35% of partitionings to fail // (determined by commenting out the fallback code, // running tests and noting number of failures). @@ -959,9 +935,8 @@ TEST_P(RandomPartitioningTest, Test) { // OPERATION /////////////////////////////////////////////////////////////////////////////// - const uint32_t operationIndex = - model.addOperation(operationPattern.mOperationType, - operationInputs, operationOutputs); + const uint32_t operationIndex = model.addOperation(operationPattern.mOperationType, + operationInputs, operationOutputs); deadOperations.insert(operationIndex); std::for_each(operationOutputs.begin(), operationOutputs.end(), [&deadOperands, operationIndex](uint32_t operandIndex) { @@ -984,7 +959,7 @@ TEST_P(RandomPartitioningTest, Test) { float* region = static_cast<float*>(weights.getRegion(regionIndex, &memory, &offset, &length)); CHECK(length == problemSize * problemSize * sizeof(float)); - std::generate(region, region + problemSize * problemSize, [this]{ return randFrac(); }); + std::generate(region, region + problemSize * problemSize, [this] { return randFrac(); }); model.setOperandValueFromMemory(operandIndex, memory, offset, length); } @@ -1005,7 +980,7 @@ TEST_P(RandomPartitioningTest, Test) { // more likely we are to classify this operation // output as a model output. 
const double probabilityOfModelOutput = - 0.50 * [](double x){ return x*x; }((operationIdx + 1) / operationCount); + 0.50 * [](double x) { return x * x; }((operationIdx + 1) / operationCount); modelOutput = (randFrac() < probabilityOfModelOutput); } else { // This is consumed within the model, so we'll rarely @@ -1044,8 +1019,7 @@ TEST_P(RandomPartitioningTest, Test) { #ifdef VERBOSE { std::cout << "Original model: " << ModelStats(&model) << std::endl; - std::cout << "rootOperationCount = " << rootOperationCount - << ", deadOperations = "; + std::cout << "rootOperationCount = " << rootOperationCount << ", deadOperations = "; if (allowDeadOperations) { std::cout << deadOperations.size(); } else { @@ -1072,8 +1046,8 @@ TEST_P(RandomPartitioningTest, Test) { } // Now remove each entry that has no signatures. auto firstExtra = - std::remove_if(signaturesForDriver.begin(), signaturesForDriver.end(), - [](const std::set<Signature>& sigSet) { return sigSet.empty(); }); + std::remove_if(signaturesForDriver.begin(), signaturesForDriver.end(), + [](const std::set<Signature>& sigSet) { return sigSet.empty(); }); if (firstExtra != signaturesForDriver.end()) { signaturesForDriver.erase(firstExtra, signaturesForDriver.end()); } @@ -1114,7 +1088,7 @@ TEST_P(RandomPartitioningTest, Test) { // the fallback to succeed. TestCompilation cNoFallback(&model, devices); TestCompilation cWithFallback(&model, devices); - TestCompilation *c2 = nullptr; + TestCompilation* c2 = nullptr; ASSERT_EQ(cNoFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback), Result::NO_ERROR); auto compilationResult = cNoFallback.finish(); @@ -1134,8 +1108,8 @@ TEST_P(RandomPartitioningTest, Test) { #ifdef VERBOSE { - std::cout << "signatures = " << signatures.size() - << ", devices = " << devices.size() << std::endl; + std::cout << "signatures = " << signatures.size() << ", devices = " << devices.size() + << std::endl; const ExecutionPlan& plan = c2->getExecutionPlan(); switch (plan.forTest_getKind()) { case ExecutionPlan::Kind::SIMPLE: @@ -1157,7 +1131,7 @@ TEST_P(RandomPartitioningTest, Test) { } default: std::cout << "Unexpected plan kind: " - << static_cast<unsigned>(plan.forTest_getKind()); + << static_cast<unsigned>(plan.forTest_getKind()); break; } } @@ -1187,7 +1161,7 @@ TEST_P(RandomPartitioningTest, Test) { // should not be dependent on the outputs; but we'll initialize the // outputs anyway. std::vector<float> masterInputs(problemSize * problemSize * model.inputCount()); - std::generate(masterInputs.begin(), masterInputs.end(), [this]{ return randFrac(); }); + std::generate(masterInputs.begin(), masterInputs.end(), [this] { return randFrac(); }); #ifdef VERBOSE { std::cout << "flat inputs = "; @@ -1213,9 +1187,8 @@ TEST_P(RandomPartitioningTest, Test) { }; std::vector<InputOutputDescriptor> ioDescriptors(model.inputCount() + model.outputCount()); for (unsigned i = 0; i < ioDescriptors.size(); i++) { - ioDescriptors[i].mKind = (i < model.inputCount() - ? InputOutputDescriptor::INPUT - : InputOutputDescriptor::OUTPUT); + ioDescriptors[i].mKind = (i < model.inputCount() ? InputOutputDescriptor::INPUT + : InputOutputDescriptor::OUTPUT); } // We randomly interleave inputs and outputs in creation // order, because when we we create memory regions in a @@ -1226,7 +1199,7 @@ TEST_P(RandomPartitioningTest, Test) { // they'll be interleaved. 
std::shuffle(ioDescriptors.begin(), ioDescriptors.end(), mRandNumEng); TestMemories ioMemories; - for (auto &desc : ioDescriptors) { + for (auto& desc : ioDescriptors) { if (randFrac() < 0.5) { desc.mVector.resize(problemSize * problemSize); } else { @@ -1245,11 +1218,10 @@ TEST_P(RandomPartitioningTest, Test) { // Function to set up actual inputs and outputs (initializing them // and telling the WrapperExecution about them). - auto prepareForExecution = - [&model, &ioDescriptors, &ioMemories, - &masterInputs, &masterOutput, problemSize, &problemType](WrapperExecution *e) { + auto prepareForExecution = [&model, &ioDescriptors, &ioMemories, &masterInputs, &masterOutput, + problemSize, &problemType](WrapperExecution* e) { uint32_t inputIndex = 0, outputIndex = 0; - for (auto &desc : ioDescriptors) { + for (auto& desc : ioDescriptors) { if (desc.getLocation() == InputOutputDescriptor::VECTOR) { if (desc.mKind == InputOutputDescriptor::INPUT) { const size_t inputOffset = inputIndex * problemSize * problemSize; @@ -1260,18 +1232,15 @@ TEST_P(RandomPartitioningTest, Test) { desc.mVector.size() * sizeof(float)); } else { std::fill(desc.mVector.begin(), - desc.mVector.begin() + problemSize * problemSize, - masterOutput); + desc.mVector.begin() + problemSize * problemSize, masterOutput); e->setOutput(outputIndex++, desc.mVector.data(), - desc.mVector.size() * sizeof(float), - &problemType.operandType); + desc.mVector.size() * sizeof(float), &problemType.operandType); } } else { const WrapperMemory* memory; uint32_t offset, length; - float* region = - static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion, - &memory, &offset, &length)); + float* region = static_cast<float*>( + ioMemories.getRegion(desc.mMemoryRegion, &memory, &offset, &length)); CHECK(length == problemSize * problemSize * sizeof(float)); if (desc.mKind == InputOutputDescriptor::INPUT) { const size_t inputOffset = inputIndex * problemSize * problemSize; @@ -1280,9 +1249,7 @@ TEST_P(RandomPartitioningTest, Test) { region); e->setInputFromMemory(inputIndex++, memory, offset, length); } else { - std::fill(region, - region + problemSize * problemSize, - masterOutput); + std::fill(region, region + problemSize * problemSize, masterOutput); e->setOutputFromMemory(outputIndex++, memory, offset, length, &problemType.operandType); } @@ -1307,13 +1274,11 @@ TEST_P(RandomPartitioningTest, Test) { } const size_t outputOffset = outputIndex * problemSize * problemSize; if (desc.getLocation() == InputOutputDescriptor::VECTOR) { - std::copy(desc.mVector.begin(), - desc.mVector.end(), + std::copy(desc.mVector.begin(), desc.mVector.end(), nonPartitionedOutputs.begin() + outputOffset); } else { float* region = static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion)); - std::copy(region, - region + problemSize * problemSize, + std::copy(region, region + problemSize * problemSize, nonPartitionedOutputs.begin() + outputOffset); } #ifdef VERBOSE @@ -1347,8 +1312,7 @@ TEST_P(RandomPartitioningTest, Test) { std::cout << " partitioned output[" << outputIndex << "] = "; dump(desc.mVector.begin(), desc.mVector.end()); #endif - ASSERT_TRUE(std::equal(desc.mVector.begin(), - desc.mVector.end(), + ASSERT_TRUE(std::equal(desc.mVector.begin(), desc.mVector.end(), nonPartitionedOutputs.begin() + outputOffset)); } else { float* region = static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion)); @@ -1356,8 +1320,7 @@ TEST_P(RandomPartitioningTest, Test) { std::cout << "part output[" << outputIndex << "] = "; dump(region, region + problemSize * 
problemSize); #endif - ASSERT_TRUE(std::equal(region, - region + problemSize * problemSize, + ASSERT_TRUE(std::equal(region, region + problemSize * problemSize, nonPartitionedOutputs.begin() + outputOffset)); } outputIndex++; diff --git a/nn/runtime/test/TestTrivialModel.cpp b/nn/runtime/test/TestTrivialModel.cpp index c5c065439..7280e6ae0 100644 --- a/nn/runtime/test/TestTrivialModel.cpp +++ b/nn/runtime/test/TestTrivialModel.cpp @@ -27,7 +27,7 @@ typedef float Matrix3x4[3][4]; typedef float Matrix4[4]; class TrivialTest : public ::testing::Test { -protected: + protected: virtual void SetUp() {} const Matrix3x4 matrix1 = {{1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f}, {9.f, 10.f, 11.f, 12.f}}; @@ -35,9 +35,8 @@ protected: {500.f, 600.f, 700.f, 800.f}, {900.f, 1000.f, 1100.f, 1200.f}}; const Matrix4 matrix2b = {100.f, 200.f, 300.f, 400.f}; - const Matrix3x4 matrix3 = {{20.f, 30.f, 40.f, 50.f}, - {21.f, 22.f, 23.f, 24.f}, - {31.f, 32.f, 33.f, 34.f}}; + const Matrix3x4 matrix3 = { + {20.f, 30.f, 40.f, 50.f}, {21.f, 22.f, 23.f, 24.f}, {31.f, 32.f, 33.f, 34.f}}; const Matrix3x4 expected2 = {{101.f, 202.f, 303.f, 404.f}, {505.f, 606.f, 707.f, 808.f}, {909.f, 1010.f, 1111.f, 1212.f}}; @@ -51,9 +50,8 @@ protected: const Matrix3x4 expected3 = {{121.f, 232.f, 343.f, 454.f}, {526.f, 628.f, 730.f, 832.f}, {940.f, 1042.f, 1144.f, 1246.f}}; - const Matrix3x4 expected3b = {{22.f, 34.f, 46.f, 58.f}, - {31.f, 34.f, 37.f, 40.f}, - {49.f, 52.f, 55.f, 58.f}}; + const Matrix3x4 expected3b = { + {22.f, 34.f, 46.f, 58.f}, {31.f, 34.f, 37.f, 40.f}, {49.f, 52.f, 55.f, 58.f}}; }; // Create a model that can add two tensors using a one node graph. diff --git a/nn/runtime/test/TestUnknownDimensions.cpp b/nn/runtime/test/TestUnknownDimensions.cpp index ada52c2b6..75465659a 100644 --- a/nn/runtime/test/TestUnknownDimensions.cpp +++ b/nn/runtime/test/TestUnknownDimensions.cpp @@ -28,8 +28,8 @@ using namespace test_helper; namespace { const uint32_t INTENDED_SIZE = 3; -const uint32_t OTHER_SIZE = 2; -const uint32_t UNKNOWN_SIZE = 0; +const uint32_t OTHER_SIZE = 2; +const uint32_t UNKNOWN_SIZE = 0; // We test three basic scenarios for each tensor dimension: // INTENDED_AT_COMPILE_AND_EXECUTE: set the dimension at compile @@ -58,19 +58,21 @@ const uint32_t UNKNOWN_SIZE = 0; // infrastructure to handle correctly. However, running all 16k in one test // makes the ASAN version take so long that the automatic test runner things the // command has become unresponsinve, so we split on the first level. 
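// The "16k" figure above comes from combining per-operand dimension scenarios
// defined in the hunk that follows: each input and the output draws an
// OperandParams pair from all four DimensionKind values (4 * 4 = 16 pairs),
// while the constant draws from the two kinds in constantDimensionValues
// (2 * 2 = 4 pairs, assuming constantValues is built with Combine the same way
// ioValues is). Four operands then give 16 * 16 * 4 * 16 = 16384 combinations,
// and splitting on the first operand via the test parameter leaves 1024 per
// instance. A quick check of that arithmetic:
#include <cassert>

int main() {
    const int ioPairs = 4 * 4;        // input/output operands: four DimensionKind choices per axis
    const int constantPairs = 2 * 2;  // constant operand: two DimensionKind choices per axis
    const int total = ioPairs * ioPairs * constantPairs * ioPairs;  // input0, input1, constant, output
    assert(total == 16384);           // the "16k" combinations
    assert(total / ioPairs == 1024);  // combinations left per instance after splitting on input0
    return 0;
}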
-enum class DimensionKind { INTENDED_AT_COMPILE_AND_EXECUTE, - INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE, - UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE, - UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE }; +enum class DimensionKind { + INTENDED_AT_COMPILE_AND_EXECUTE, + INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE, + UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE, + UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE +}; typedef std::tuple<DimensionKind, DimensionKind> OperandParams; std::vector<DimensionKind> ioDimensionValues = { - DimensionKind::INTENDED_AT_COMPILE_AND_EXECUTE, - DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE, - DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE, - DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE }; + DimensionKind::INTENDED_AT_COMPILE_AND_EXECUTE, + DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE, + DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE, + DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE}; std::vector<DimensionKind> constantDimensionValues = { DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE, - DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE }; + DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE}; std::vector<OperandParams> Combine(const std::vector<DimensionKind>& firsts, const std::vector<DimensionKind>& seconds); auto ioValues = Combine(ioDimensionValues, ioDimensionValues); @@ -143,18 +145,20 @@ void UnknownDimensionsTest::CompareResults<_Float16>(std::map<int, std::vector<_ EXPECT_EQ(size_t{0}, totalNumberOfErrors); } -template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( - const OperandParams& paramsForInput0, const OperandParams& paramsForInput1, - const OperandParams& paramsForConst, const OperandParams& paramsForOutput) { +template <class T, Type TensorType> +void UnknownDimensionsTest::TestOne(const OperandParams& paramsForInput0, + const OperandParams& paramsForInput1, + const OperandParams& paramsForConst, + const OperandParams& paramsForOutput) { typedef T IntendedMatrix[INTENDED_SIZE][INTENDED_SIZE]; - static const IntendedMatrix ones = { { 1, 1, 1 }, { 1, 1, 1 }, { 1, 1, 1 } }; - static const IntendedMatrix twos = { { 2, 2, 2 }, { 2, 2, 2 }, { 2, 2, 2 } }; - static const IntendedMatrix fives = { { 5, 5, 5 }, { 5, 5, 5 }, { 5, 5, 5 } }; + static const IntendedMatrix ones = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + static const IntendedMatrix twos = {{2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + static const IntendedMatrix fives = {{5, 5, 5}, {5, 5, 5}, {5, 5, 5}}; const float scale = TensorType == Type::TENSOR_QUANT8_ASYMM ? 
1.f : 0.f; Model model; - std::string input0Scope("Input 0:"), input1Scope("Input 1:"), - constantScope("Constant:"), outputScope("Output:"); + std::string input0Scope("Input 0:"), input1Scope("Input 1:"), constantScope("Constant:"), + outputScope("Output:"); auto getDimForCompile = [](DimensionKind kind, std::string* scope) { switch (kind) { @@ -176,8 +180,8 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( std::string* scope = nullptr) { OperandType matrixTypeWithPotentiallyUnknownDims( TensorType, - { getDimForCompile(std::get<0>(params), scope), - getDimForCompile(std::get<1>(params), scope) }, + {getDimForCompile(std::get<0>(params), scope), + getDimForCompile(std::get<1>(params), scope)}, scale); return model.addOperand(&matrixTypeWithPotentiallyUnknownDims); }; @@ -202,11 +206,9 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( model.setOperandValue(activationOpd0, &activation, sizeof(activation)); model.setOperandValue(constantOpd0, twos, sizeof(twos)); - model.addOperation(ANEURALNETWORKS_ADD, - {inputOpd0, inputOpd1, activationOpd0}, + model.addOperation(ANEURALNETWORKS_ADD, {inputOpd0, inputOpd1, activationOpd0}, {intermediateOpd0}); - model.addOperation(ANEURALNETWORKS_ADD, - {intermediateOpd0, constantOpd0, activationOpd0}, + model.addOperation(ANEURALNETWORKS_ADD, {intermediateOpd0, constantOpd0, activationOpd0}, {outputOpd0}); model.identifyInputsAndOutputs({inputOpd0, inputOpd1}, {outputOpd0}); if (std::get<0>(paramsForConst) == DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE && @@ -224,7 +226,7 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( Compilation compilation(&model); ASSERT_EQ(compilation.finish(), Result::NO_ERROR); - IntendedMatrix actual = { { 10, 10, 10 }, { 10, 10, 10 }, { 10, 10, 10 } }; + IntendedMatrix actual = {{10, 10, 10}, {10, 10, 10}, {10, 10, 10}}; Execution execution(&compilation); OperandType matrixTypeIntended(TensorType, {INTENDED_SIZE, INTENDED_SIZE}, scale); @@ -261,19 +263,21 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( // on OperandParams auto sizeAtSet = [](OperandParams params) { auto first = std::get<0>(params), second = std::get<1>(params); - size_t firstDim = (first == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) ? - OTHER_SIZE : INTENDED_SIZE; - size_t secondDim = (second == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) ? - OTHER_SIZE : INTENDED_SIZE; + size_t firstDim = (first == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) + ? OTHER_SIZE + : INTENDED_SIZE; + size_t secondDim = (second == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) + ? 
OTHER_SIZE + : INTENDED_SIZE; return firstDim * secondDim * sizeof(fives[0][0]); }; ASSERT_EQ(execution.setInput(0, ones, sizeAtSet(paramsForInput0), typeAtSet(paramsForInput0)), Result::NO_ERROR); ASSERT_EQ(execution.setInput(1, twos, sizeAtSet(paramsForInput1), typeAtSet(paramsForInput1)), Result::NO_ERROR); - ASSERT_EQ(execution.setOutput(0, actual, sizeAtSet(paramsForOutput), - typeAtSet(paramsForOutput)), - Result::NO_ERROR); + ASSERT_EQ( + execution.setOutput(0, actual, sizeAtSet(paramsForOutput), typeAtSet(paramsForOutput)), + Result::NO_ERROR); if (allAreIntendedSizeAtExecution) { ASSERT_EQ(execution.compute(), Result::NO_ERROR); @@ -295,21 +299,22 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne( std::vector<OperandParams> Combine(const std::vector<DimensionKind>& firsts, const std::vector<DimensionKind>& seconds) { std::vector<OperandParams> ret; - for (auto first: firsts) { - for (auto second: seconds) { + for (auto first : firsts) { + for (auto second : seconds) { ret.push_back({first, second}); } } return ret; } -template<class T, Type TensorType> void UnknownDimensionsTest::TestAll() { +template <class T, Type TensorType> +void UnknownDimensionsTest::TestAll() { const OperandParams paramsForInput0 = GetParam(); - for (auto paramsForInput1: ioValues) { - for (auto paramsForConst: constantValues) { - for (auto paramsForOutput: ioValues) { - TestOne<T, TensorType>(paramsForInput0, paramsForInput1, - paramsForConst, paramsForOutput); + for (auto paramsForInput1 : ioValues) { + for (auto paramsForConst : constantValues) { + for (auto paramsForOutput : ioValues) { + TestOne<T, TensorType>(paramsForInput0, paramsForInput1, paramsForConst, + paramsForOutput); } } } diff --git a/nn/runtime/test/TestWrapper.cpp b/nn/runtime/test/TestWrapper.cpp index d80d46950..1ab8f9540 100644 --- a/nn/runtime/test/TestWrapper.cpp +++ b/nn/runtime/test/TestWrapper.cpp @@ -23,7 +23,7 @@ using namespace ::android::nn::wrapper; // This file tests certain aspects of the interfaces from NeuralNetworksWrapper.h. 
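// Combine(), shown in the TestUnknownDimensions.cpp hunk above, is a plain
// cartesian product: every DimensionKind from the first list paired with every
// DimensionKind from the second. The same pattern written generically, offered
// only as a sketch (the template and the example values are not part of the
// test sources):
#include <tuple>
#include <vector>

template <typename A, typename B>
std::vector<std::tuple<A, B>> cartesianProduct(const std::vector<A>& firsts,
                                               const std::vector<B>& seconds) {
    std::vector<std::tuple<A, B>> ret;
    ret.reserve(firsts.size() * seconds.size());
    for (const auto& first : firsts) {
        for (const auto& second : seconds) {
            ret.emplace_back(first, second);
        }
    }
    return ret;
}
// Example: cartesianProduct<int, char>({1, 2}, {'a', 'b'}) yields
// (1,'a'), (1,'b'), (2,'a'), (2,'b'), matching |firsts| * |seconds| pairs.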
class WrapperTestModelFinish : public ::testing::Test { -protected: + protected: void SetUp() override { OperandType type(Type::TENSOR_FLOAT32, {1}); mIn = mModel.addOperand(&type); diff --git a/nn/tools/ion_watcher/ion_watcher.cpp b/nn/tools/ion_watcher/ion_watcher.cpp index 0061086dc..1a79b38b3 100644 --- a/nn/tools/ion_watcher/ion_watcher.cpp +++ b/nn/tools/ion_watcher/ion_watcher.cpp @@ -16,12 +16,12 @@ #define LOG_TAG "IonWatcher" +#include <stdio.h> +#include <unistd.h> #include <fstream> #include <iostream> #include <sstream> -#include <stdio.h> #include <string> -#include <unistd.h> #include <android/log.h> #define ATRACE_TAG ATRACE_TAG_NNAPI @@ -54,16 +54,16 @@ int main(void) { } int size = 0; while (true) { - const int newSize = parseMemInfo("ION_heap"); - if (newSize < 0) { - return newSize; - } - if (newSize != size) { - size = newSize; - std::cout << size << "\n"; - ATRACE_INT("ION_heap", size); - __android_log_print(ANDROID_LOG_INFO, "ion", "ION_heap %d", size); - } - usleep(10); + const int newSize = parseMemInfo("ION_heap"); + if (newSize < 0) { + return newSize; + } + if (newSize != size) { + size = newSize; + std::cout << size << "\n"; + ATRACE_INT("ION_heap", size); + __android_log_print(ANDROID_LOG_INFO, "ion", "ION_heap %d", size); + } + usleep(10); } } diff --git a/nn/tools/test_generator/include/TestHarness.h b/nn/tools/test_generator/include/TestHarness.h index e43658d03..3b4b26b16 100644 --- a/nn/tools/test_generator/include/TestHarness.h +++ b/nn/tools/test_generator/include/TestHarness.h @@ -224,8 +224,7 @@ void filter_internal(const std::map<int, std::vector<T>>& golden, }); } -inline MixedTyped filter(const MixedTyped& golden, - std::function<bool(int)> is_ignored) { +inline MixedTyped filter(const MixedTyped& golden, std::function<bool(int)> is_ignored) { MixedTyped filtered; filter_internal(golden.operandDimensions, &filtered.operandDimensions, is_ignored); filter_internal(golden.float32Operands, &filtered.float32Operands, is_ignored); @@ -275,8 +274,8 @@ inline int getQuant8AllowedError() { } } -inline void compare(const MixedTyped& golden, const MixedTyped& test, - float fpAtol = 1e-5f, float fpRtol = 1e-5f) { +inline void compare(const MixedTyped& golden, const MixedTyped& test, float fpAtol = 1e-5f, + float fpRtol = 1e-5f) { int quant8AllowedError = getQuant8AllowedError(); for_each<uint32_t>( golden.operandDimensions, test.operandDimensions, |
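// compare() in the TestHarness.h hunk above takes separate absolute (fpAtol)
// and relative (fpRtol) tolerances for floating-point outputs. The hunk does
// not show how the two are combined, so the check below is only a hedged
// sketch of a common formulation, not the harness's exact rule: accept a test
// value when it lies within atol + rtol * |golden| of the golden value.
#include <cmath>

bool withinTolerance(float golden, float test, float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
    // The absolute term covers goldens near zero; the relative term scales with magnitude.
    const float allowed = fpAtol + fpRtol * std::fabs(golden);
    return std::fabs(test - golden) <= allowed;
}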