author     Michael Butler <butlermichael@google.com>  2019-07-22 18:59:46 -0700
committer  Slava Shklyaev <slavash@google.com>        2019-08-23 11:42:41 +0100
commit     43953b8f3976fe83c4b04322d4e855cba0688b1e (patch)
tree       0a6719d328cfe7adeed49f814412e03dde303ad9
parent     a1846f57b824acda3616a0053bda3912b3f591ac (diff)
clang-format for frameworks/ml/nn
This CL formats all of frameworks/ml/nn/* with the following commands:

    $ $CLANG_DIR/clang-format --style=file -i `find $NNAPI_DIR -name "*.cpp"`
    $ $CLANG_DIR/clang-format --style=file -i `find $NNAPI_DIR -name "*.h"`

where:
* "NNAPI_DIR" is "$ANDROID_BUILD_TOP/frameworks/ml/nn"
* "CLANG_DIR" is "$ANDROID_BUILD_TOP/prebuilts/clang/host/linux-x86/clang-stable/bin"

Bug: N/A
Test: mma
Change-Id: Idddbc7ecaeab76fb0bbee4250830333752a1f29b
Merged-In: Idddbc7ecaeab76fb0bbee4250830333752a1f29b
(cherry picked from commit 67e41a5467d7879b34f613069ade6cf61d5bd633)
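The .clang-format picked up by --style=file is not part of this change, so the exact options are not visible here. Judging from the reformatted hunks below (4-space indent, ~100-column limit, left-bound pointers and references, access specifiers indented one space, two spaces before trailing comments), a config along these lines would reproduce the style. This is a sketch inferred from the diff, not the file from the tree, and every value in it is an assumption:

    # Hypothetical .clang-format reconstructed from the formatting in this CL.
    BasedOnStyle: Google          # assumed baseline; Android framework code commonly starts from Google style
    AccessModifierOffset: -3      # matches "public:"/"private:" ending up indented one space
    ColumnLimit: 100              # long declarations in the hunks wrap near column 100
    IndentWidth: 4                # body statements use 4 spaces
    ContinuationIndentWidth: 8    # wrapped expressions (e.g. the ternaries in stridedSlicePrepare) indent 8
    DerivePointerAlignment: false
    PointerAlignment: Left        # "const Shape& input", "Shape* output"

Any of these values could differ in the real file; the diff itself is the authoritative record of the resulting style.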
-rw-r--r--  nn/common/GraphDump.cpp  53
-rw-r--r--  nn/common/OperationsUtils.cpp  218
-rw-r--r--  nn/common/Utils.cpp  168
-rw-r--r--  nn/common/ValidateHal.cpp  12
-rw-r--r--  nn/common/include/ActivationFunctor.h  45
-rw-r--r--  nn/common/include/CpuExecutor.h  20
-rw-r--r--  nn/common/include/GraphDump.h  3
-rw-r--r--  nn/common/include/OperationsUtils.h  135
-rw-r--r--  nn/common/include/Tracing.h  48
-rw-r--r--  nn/common/include/Utils.h  18
-rw-r--r--  nn/common/operations/ArgMinMax.cpp  37
-rw-r--r--  nn/common/operations/Conv2D.cpp  86
-rw-r--r--  nn/common/operations/EmbeddingLookup.cpp  35
-rw-r--r--  nn/common/operations/EmbeddingLookup.h  22
-rw-r--r--  nn/common/operations/EmbeddingLookupTest.cpp  167
-rw-r--r--  nn/common/operations/GenerateProposals.cpp  8
-rw-r--r--  nn/common/operations/HashtableLookup.cpp  57
-rw-r--r--  nn/common/operations/HashtableLookup.h  30
-rw-r--r--  nn/common/operations/HashtableLookupTest.cpp  198
-rw-r--r--  nn/common/operations/LSTMTest.cpp  1701
-rw-r--r--  nn/common/operations/RNN.cpp  100
-rw-r--r--  nn/common/operations/RNNTest.cpp  445
-rw-r--r--  nn/common/operations/SVDF.cpp  101
-rw-r--r--  nn/common/operations/SVDFTest.cpp  575
-rw-r--r--  nn/driver/cache/BlobCache/BlobCache.cpp  118
-rw-r--r--  nn/driver/cache/BlobCache/BlobCache.h  23
-rw-r--r--  nn/driver/cache/BlobCache/BlobCache_test.cpp  167
-rw-r--r--  nn/driver/cache/nnCache/nnCache.cpp  68
-rw-r--r--  nn/driver/cache/nnCache/nnCache.h  23
-rw-r--r--  nn/driver/cache/nnCache/nnCache_test.cpp  55
-rw-r--r--  nn/driver/sample/SampleDriver.cpp  12
-rw-r--r--  nn/driver/sample/SampleDriverFloatFast.cpp  10
-rw-r--r--  nn/driver/sample/SampleDriverFloatSlow.cpp  10
-rw-r--r--  nn/driver/sample/SampleDriverMinimal.cpp  10
-rw-r--r--  nn/driver/sample/SampleDriverQuant.cpp  10
-rw-r--r--  nn/runtime/CompilationBuilder.cpp  10
-rw-r--r--  nn/runtime/CompilationBuilder.h  8
-rw-r--r--  nn/runtime/ExecutionBuilder.h  20
-rw-r--r--  nn/runtime/ExecutionPlan.cpp  102
-rw-r--r--  nn/runtime/ExecutionPlan.h  34
-rw-r--r--  nn/runtime/Manager.cpp  2
-rw-r--r--  nn/runtime/Manager.h  10
-rw-r--r--  nn/runtime/Memory.cpp  4
-rw-r--r--  nn/runtime/ModelBuilder.h  1
-rw-r--r--  nn/runtime/VersionedInterfaces.cpp  3
-rw-r--r--  nn/runtime/test/TestExecution.cpp  13
-rw-r--r--  nn/runtime/test/TestMemory.cpp  3
-rw-r--r--  nn/runtime/test/TestMemory.h  5
-rw-r--r--  nn/runtime/test/TestMemoryInternal.cpp  40
-rw-r--r--  nn/runtime/test/TestOpenmpSettings.cpp  12
-rw-r--r--  nn/runtime/test/TestPartitioning.cpp  251
-rw-r--r--  nn/runtime/test/TestPartitioningRandom.cpp  283
-rw-r--r--  nn/runtime/test/TestTrivialModel.cpp  12
-rw-r--r--  nn/runtime/test/TestUnknownDimensions.cpp  87
-rw-r--r--  nn/runtime/test/TestWrapper.cpp  2
-rw-r--r--  nn/tools/ion_watcher/ion_watcher.cpp  26
-rw-r--r--  nn/tools/test_generator/include/TestHarness.h  7
57 files changed, 2650 insertions, 3073 deletions
diff --git a/nn/common/GraphDump.cpp b/nn/common/GraphDump.cpp
index 9fe6bf31e..b057692df 100644
--- a/nn/common/GraphDump.cpp
+++ b/nn/common/GraphDump.cpp
@@ -46,8 +46,8 @@ using namespace hal;
//
namespace {
class Dumper {
-public:
- Dumper(std::ostream* outStream) : mStream(outStream) { }
+ public:
+ Dumper(std::ostream* outStream) : mStream(outStream) {}
Dumper(const Dumper&) = delete;
void operator=(const Dumper&) = delete;
@@ -58,7 +58,7 @@ public:
return *this;
}
- class EndlType { };
+ class EndlType {};
Dumper& operator<<(EndlType) {
if (mStream) {
@@ -81,27 +81,36 @@ public:
}
static const EndlType endl;
-private:
+
+ private:
std::ostream* mStream;
std::ostringstream mStringStream;
};
const Dumper::EndlType Dumper::endl;
-}
-
+} // namespace
// Provide short name for OperandType value.
static std::string translate(OperandType type) {
switch (type) {
- case OperandType::FLOAT32: return "F32";
- case OperandType::INT32: return "I32";
- case OperandType::UINT32: return "U32";
- case OperandType::TENSOR_FLOAT32: return "TF32";
- case OperandType::TENSOR_INT32: return "TI32";
- case OperandType::TENSOR_QUANT8_ASYMM: return "TQ8A";
- case OperandType::OEM: return "OEM";
- case OperandType::TENSOR_OEM_BYTE: return "TOEMB";
- default: return toString(type);
+ case OperandType::FLOAT32:
+ return "F32";
+ case OperandType::INT32:
+ return "I32";
+ case OperandType::UINT32:
+ return "U32";
+ case OperandType::TENSOR_FLOAT32:
+ return "TF32";
+ case OperandType::TENSOR_INT32:
+ return "TI32";
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ return "TQ8A";
+ case OperandType::OEM:
+ return "OEM";
+ case OperandType::TENSOR_OEM_BYTE:
+ return "TOEMB";
+ default:
+ return toString(type);
}
}
@@ -110,10 +119,9 @@ static std::string translate(OperandType type) {
// OperandLifeTime::CONSTANT_COPY, then write the Operand's value to
// the Dumper.
namespace {
-template<OperandType nnType, typename cppType>
+template <OperandType nnType, typename cppType>
void tryValueDump(Dumper& dump, const Model& model, const Operand& opnd) {
- if (opnd.type != nnType ||
- opnd.lifetime != OperandLifeTime::CONSTANT_COPY ||
+ if (opnd.type != nnType || opnd.lifetime != OperandLifeTime::CONSTANT_COPY ||
opnd.location.length != sizeof(cppType)) {
return;
}
@@ -122,7 +130,7 @@ void tryValueDump(Dumper& dump, const Model& model, const Operand& opnd) {
memcpy(&val, &model.operandValues[opnd.location.offset], sizeof(cppType));
dump << " = " << val;
}
-}
+} // namespace
void graphDump(const char* name, const Model& model, std::ostream* outStream) {
// Operand nodes are named "d" (operanD) followed by operand index.
@@ -182,8 +190,8 @@ void graphDump(const char* name, const Model& model, std::ostream* outStream) {
dump << ": " << kind;
}
dump << "\\n" << translate(opnd.type);
- tryValueDump<OperandType::FLOAT32, float>(dump, model, opnd);
- tryValueDump<OperandType::INT32, int>(dump, model, opnd);
+ tryValueDump<OperandType::FLOAT32, float>(dump, model, opnd);
+ tryValueDump<OperandType::INT32, int>(dump, model, opnd);
tryValueDump<OperandType::UINT32, unsigned>(dump, model, opnd);
if (opnd.dimensions.size()) {
dump << "(";
@@ -210,8 +218,7 @@ void graphDump(const char* name, const Model& model, std::ostream* outStream) {
dump << " ordering=out";
}
}
- dump << " label=\"" << i << ": "
- << toString(operation.type) << "\"]" << Dumper::endl;
+ dump << " label=\"" << i << ": " << toString(operation.type) << "\"]" << Dumper::endl;
{
// operation inputs
for (unsigned in = 0, inE = operation.inputs.size(); in < inE; in++) {
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp
index 9219dddb7..cec15fc94 100644
--- a/nn/common/OperationsUtils.cpp
+++ b/nn/common/OperationsUtils.cpp
@@ -122,8 +122,7 @@ uint32_t getNumberOfElements(const Shape& shape) {
return count;
}
-uint32_t getNumberOfElements(const Shape& shape,
- size_t firstAxisInclusive,
+uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive,
size_t lastAxisExclusive) {
nnAssert(0 <= firstAxisInclusive);
nnAssert(firstAxisInclusive <= lastAxisExclusive);
@@ -170,8 +169,7 @@ bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
return true;
}
-bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
+bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t* quantized_multiplier,
int32_t* right_shift) {
NN_OPS_CHECK(double_multiplier >= 0.);
NN_OPS_CHECK(double_multiplier < 1.);
@@ -195,8 +193,7 @@ bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
return true;
}
-bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
+bool QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier,
int* left_shift) {
NN_OPS_CHECK(double_multiplier > 1.);
const double q = std::frexp(double_multiplier, left_shift);
@@ -221,15 +218,13 @@ bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filt
// The following conditions must be guaranteed by the training pipeline.
NN_OPS_CHECK(std::abs(input_product_scale - bias_scale) <=
- 1e-6 * std::min(input_product_scale, bias_scale));
+ 1e-6 * std::min(input_product_scale, bias_scale));
NN_OPS_CHECK(input_product_scale >= 0);
*multiplier = input_product_scale / outputShape.scale;
return true;
}
-void CalculateActivationRangeUint8(int32_t activation,
- const Shape& outputShape,
- int32_t* act_min,
+void CalculateActivationRangeUint8(int32_t activation, const Shape& outputShape, int32_t* act_min,
int32_t* act_max) {
const int32_t qmin = std::numeric_limits<uint8_t>::min();
const int32_t qmax = std::numeric_limits<uint8_t>::max();
@@ -250,7 +245,7 @@ void CalculateActivationRangeUint8(int32_t activation,
} else if (activation == kActivationRelu1) {
*act_min = std::max(qmin, quantize(-1.0));
*act_max = std::min(qmax, quantize(1.0));
- } else if (activation == kActivationNone){
+ } else if (activation == kActivationNone) {
*act_min = qmin;
*act_max = qmax;
} else {
@@ -258,8 +253,7 @@ void CalculateActivationRangeUint8(int32_t activation,
}
}
-void CalculateActivationRangeFloat(int32_t activation,
- float* activation_min,
+void CalculateActivationRangeFloat(int32_t activation, float* activation_min,
float* activation_max) {
if (activation == kActivationRelu) {
*activation_min = 0.f;
@@ -270,7 +264,7 @@ void CalculateActivationRangeFloat(int32_t activation,
} else if (activation == kActivationRelu1) {
*activation_min = -1.f;
*activation_max = 1.f;
- } else if (activation == kActivationNone){
+ } else if (activation == kActivationNone) {
*activation_min = std::numeric_limits<float>::lowest();
*activation_max = std::numeric_limits<float>::max();
} else {
@@ -372,11 +366,11 @@ bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape&
uint32_t channels_out = getSizeOfDimension(filter, 3);
uint32_t channels_in = getSizeOfDimension(input, 3);
- uint32_t width = getSizeOfDimension(input, 2);
- uint32_t height = getSizeOfDimension(input, 1);
- uint32_t filterWidth = getSizeOfDimension(filter, 2);
+ uint32_t width = getSizeOfDimension(input, 2);
+ uint32_t height = getSizeOfDimension(input, 1);
+ uint32_t filterWidth = getSizeOfDimension(filter, 2);
uint32_t filterHeight = getSizeOfDimension(filter, 1);
- uint32_t batches = getSizeOfDimension(input, 0);
+ uint32_t batches = getSizeOfDimension(input, 0);
NN_OPS_CHECK(depth_multiplier * channels_in == channels_out);
int32_t effectiveFilterWidth = (filterWidth - 1) * dilation_width_factor + 1;
@@ -396,8 +390,7 @@ bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape&
return true;
}
-bool genericActivationPrepare(const Shape& input,
- Shape* output) {
+bool genericActivationPrepare(const Shape& input, Shape* output) {
NN_OPS_CHECK(getNumberOfDimensions(input) <= 4);
return SetShape(input, output);
}
@@ -406,15 +399,13 @@ bool genericNormalizationPrepare(const Shape& input, Shape* output) {
return SetShape(input, output);
}
-bool reshapePrepare(const Shape& input,
- const int32_t* targetDims,
- const int32_t targetDimsSize,
+bool reshapePrepare(const Shape& input, const int32_t* targetDims, const int32_t targetDimsSize,
Shape* output) {
// Reshape allows one of the targetDims components to have the
// special -1 value, meaning it will be calculated automatically based on the
// input. Here we calculate what that dimension should be so that the number
// of output elements in the same as the number of input elements.
- int32_t numInputElements = (int32_t) getNumberOfElements(input);
+ int32_t numInputElements = (int32_t)getNumberOfElements(input);
std::vector<uint32_t> outDims(targetDimsSize);
int32_t numOutputElements = 1;
@@ -431,7 +422,7 @@ bool reshapePrepare(const Shape& input,
}
if (strechDim != -1) {
int32_t strechValue = numInputElements / numOutputElements;
- outDims[strechDim] = (uint32_t) strechValue;
+ outDims[strechDim] = (uint32_t)strechValue;
numOutputElements *= strechValue;
}
@@ -445,22 +436,18 @@ bool reshapePrepare(const Shape& input,
return true;
}
-bool depthToSpacePrepare(const Shape& input,
- int32_t blockSize,
- Shape* output) {
+bool depthToSpacePrepare(const Shape& input, int32_t blockSize, Shape* output) {
NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
NN_OPS_CHECK(blockSize > 0);
- uint32_t batches = getSizeOfDimension(input, 0);
- uint32_t height = getSizeOfDimension(input, 1);
- uint32_t width = getSizeOfDimension(input, 2);
+ uint32_t batches = getSizeOfDimension(input, 0);
+ uint32_t height = getSizeOfDimension(input, 1);
+ uint32_t width = getSizeOfDimension(input, 2);
uint32_t channels = getSizeOfDimension(input, 3);
NN_OPS_CHECK(channels % (blockSize * blockSize) == 0);
output->type = input.type;
- output->dimensions = {batches,
- height * blockSize,
- width * blockSize,
+ output->dimensions = {batches, height * blockSize, width * blockSize,
channels / (blockSize * blockSize)};
output->offset = input.offset;
output->scale = input.scale;
@@ -468,24 +455,20 @@ bool depthToSpacePrepare(const Shape& input,
return true;
}
-bool spaceToDepthPrepare(const Shape& input,
- int32_t blockSize,
- Shape* output) {
+bool spaceToDepthPrepare(const Shape& input, int32_t blockSize, Shape* output) {
NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
NN_OPS_CHECK(blockSize > 0);
- uint32_t batches = getSizeOfDimension(input, 0);
- uint32_t height = getSizeOfDimension(input, 1);
- uint32_t width = getSizeOfDimension(input, 2);
+ uint32_t batches = getSizeOfDimension(input, 0);
+ uint32_t height = getSizeOfDimension(input, 1);
+ uint32_t width = getSizeOfDimension(input, 2);
uint32_t channels = getSizeOfDimension(input, 3);
NN_OPS_CHECK(height % blockSize == 0);
NN_OPS_CHECK(width % blockSize == 0);
output->type = input.type;
- output->dimensions = {batches,
- height / blockSize,
- width / blockSize,
+ output->dimensions = {batches, height / blockSize, width / blockSize,
channels * (blockSize * blockSize)};
output->offset = input.offset;
output->scale = input.scale;
@@ -493,19 +476,17 @@ bool spaceToDepthPrepare(const Shape& input,
return true;
}
-bool embeddingLookupPrepare(const Shape &valueShape,
- const Shape &lookupShape,
- Shape *outputShape) {
+bool embeddingLookupPrepare(const Shape& valueShape, const Shape& lookupShape, Shape* outputShape) {
NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 2);
NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
- const uint32_t rows = getSizeOfDimension(valueShape, 0);
- const uint32_t columns = getSizeOfDimension(valueShape, 1);
+ const uint32_t rows = getSizeOfDimension(valueShape, 0);
+ const uint32_t columns = getSizeOfDimension(valueShape, 1);
- const uint32_t lookups = getSizeOfDimension(lookupShape, 0);
+ const uint32_t lookups = getSizeOfDimension(lookupShape, 0);
outputShape->type = valueShape.type;
- outputShape->dimensions = { lookups, columns };
+ outputShape->dimensions = {lookups, columns};
for (uint32_t i = 2; i < getNumberOfDimensions(valueShape); i++) {
outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
}
@@ -515,20 +496,17 @@ bool embeddingLookupPrepare(const Shape &valueShape,
return true;
}
-bool hashtableLookupPrepare(const Shape &lookupShape,
- const Shape &keyShape,
- const Shape &valueShape,
- Shape *outputShape,
- Shape *hitShape) {
+bool hashtableLookupPrepare(const Shape& lookupShape, const Shape& keyShape,
+ const Shape& valueShape, Shape* outputShape, Shape* hitShape) {
NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
NN_OPS_CHECK(getNumberOfDimensions(keyShape) == 1);
NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 1);
- const uint32_t lookups = getSizeOfDimension(lookupShape, 0);
- const uint32_t keys = getSizeOfDimension(keyShape, 0);
- const uint32_t rows = getSizeOfDimension(valueShape, 0);
+ const uint32_t lookups = getSizeOfDimension(lookupShape, 0);
+ const uint32_t keys = getSizeOfDimension(keyShape, 0);
+ const uint32_t rows = getSizeOfDimension(valueShape, 0);
outputShape->type = valueShape.type;
- outputShape->dimensions = { lookups };
+ outputShape->dimensions = {lookups};
for (uint32_t i = 1; i < getNumberOfDimensions(valueShape); i++) {
outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
}
@@ -536,16 +514,14 @@ bool hashtableLookupPrepare(const Shape &lookupShape,
outputShape->scale = valueShape.scale;
hitShape->type = OperandType::TENSOR_QUANT8_ASYMM;
- hitShape->dimensions = { lookups };
+ hitShape->dimensions = {lookups};
hitShape->offset = 0;
hitShape->scale = 1.f;
return true;
}
-bool padPrepare(const Shape& input,
- const int32_t* paddingsData,
- const Shape& paddingsShape,
+bool padPrepare(const Shape& input, const int32_t* paddingsData, const Shape& paddingsShape,
Shape* output) {
uint32_t numInputDims = getNumberOfDimensions(input);
@@ -571,10 +547,8 @@ bool padPrepare(const Shape& input,
return true;
}
-bool batchToSpacePrepare(const Shape& input,
- const int32_t* blockSizeData,
- const Shape& blockSizeShape,
- Shape* output) {
+bool batchToSpacePrepare(const Shape& input, const int32_t* blockSizeData,
+ const Shape& blockSizeShape, Shape* output) {
// Only 4D NHWC tensors are supported.
NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
@@ -584,29 +558,24 @@ bool batchToSpacePrepare(const Shape& input,
// Only applies to spatial dimensions.
NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2);
- uint32_t batches = getSizeOfDimension(input, 0);
- uint32_t height = getSizeOfDimension(input, 1);
- uint32_t width = getSizeOfDimension(input, 2);
+ uint32_t batches = getSizeOfDimension(input, 0);
+ uint32_t height = getSizeOfDimension(input, 1);
+ uint32_t width = getSizeOfDimension(input, 2);
uint32_t channels = getSizeOfDimension(input, 3);
NN_OPS_CHECK(batches % (blockSizeData[0] * blockSizeData[1]) == 0);
output->type = input.type;
output->dimensions = {batches / (blockSizeData[0] * blockSizeData[1]),
- height * blockSizeData[0],
- width * blockSizeData[1],
- channels};
+ height * blockSizeData[0], width * blockSizeData[1], channels};
output->offset = input.offset;
output->scale = input.scale;
return true;
}
-bool spaceToBatchPrepare(const Shape& input,
- const int32_t* blockSizeData,
- const Shape& blockSizeShape,
- const int32_t* paddingsData,
- const Shape& paddingsShape,
- Shape* output) {
+bool spaceToBatchPrepare(const Shape& input, const int32_t* blockSizeData,
+ const Shape& blockSizeShape, const int32_t* paddingsData,
+ const Shape& paddingsShape, Shape* output) {
// Only 4D NHWC tensors are supported.
NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
@@ -622,9 +591,9 @@ bool spaceToBatchPrepare(const Shape& input,
NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == 2);
NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2);
- uint32_t batches = getSizeOfDimension(input, 0);
- uint32_t height = getSizeOfDimension(input, 1);
- uint32_t width = getSizeOfDimension(input, 2);
+ uint32_t batches = getSizeOfDimension(input, 0);
+ uint32_t height = getSizeOfDimension(input, 1);
+ uint32_t width = getSizeOfDimension(input, 2);
uint32_t channels = getSizeOfDimension(input, 3);
uint32_t paddedHeight = paddingsData[0] + height + paddingsData[1];
@@ -635,8 +604,7 @@ bool spaceToBatchPrepare(const Shape& input,
output->type = input.type;
output->dimensions = {batches * (blockSizeData[0] * blockSizeData[1]),
- paddedHeight / blockSizeData[0],
- paddedWidth / blockSizeData[1],
+ paddedHeight / blockSizeData[0], paddedWidth / blockSizeData[1],
channels};
output->offset = input.offset;
output->scale = input.scale;
@@ -644,9 +612,7 @@ bool spaceToBatchPrepare(const Shape& input,
return true;
}
-bool squeezePrepare(const Shape& input,
- const int32_t* squeezeDims,
- const Shape& squeezeDimsShape,
+bool squeezePrepare(const Shape& input, const int32_t* squeezeDims, const Shape& squeezeDimsShape,
Shape* output) {
int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input));
@@ -668,13 +634,13 @@ bool squeezePrepare(const Shape& input,
}
} else {
for (int32_t idx = 0; idx < squeezeDimsSize; ++idx) {
- int32_t current = squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims
- : squeezeDims[idx];
+ int32_t current =
+ squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims : squeezeDims[idx];
NN_OPS_CHECK(current >= 0 && current < numInputDims &&
getSizeOfDimension(input, current) == 1);
if (!shouldSqueeze[current]) ++numDimsSqueezed;
shouldSqueeze[current] = true;
- }
+ }
}
// Sets output dimensions.
@@ -693,12 +659,8 @@ bool squeezePrepare(const Shape& input,
return true;
}
-bool meanPrepare(const Shape& input,
- const int32_t* axisData,
- const Shape& axisShape,
- bool keepDims,
+bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisShape, bool keepDims,
Shape* output) {
-
// perm need to be provided as a 1-D int32 tensor.
NN_OPS_CHECK(axisShape.type == OperandType::TENSOR_INT32);
NN_OPS_CHECK(getNumberOfDimensions(axisShape) == 1);
@@ -770,12 +732,10 @@ bool meanPrepare(const Shape& input,
return true;
}
-bool stridedSlicePrepare(const Shape& input,
- const int32_t* beginData, const Shape& beginShape,
- const int32_t* endData, const Shape& endShape,
- const int32_t* stridesData, const Shape& stridesShape,
- int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
- Shape* output) {
+bool stridedSlicePrepare(const Shape& input, const int32_t* beginData, const Shape& beginShape,
+ const int32_t* endData, const Shape& endShape, const int32_t* stridesData,
+ const Shape& stridesShape, int32_t beginMask, int32_t endMask,
+ int32_t shrinkAxisMask, Shape* output) {
uint32_t numInputDims = getNumberOfDimensions(input);
// StridedSlice op only supports 1D-4D input arrays.
NN_OPS_CHECK(numInputDims <= 4);
@@ -795,30 +755,28 @@ bool stridedSlicePrepare(const Shape& input,
// Determine size of output tensor and map indices
std::vector<uint32_t> outDims;
for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); idx++) {
- int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx));
- int32_t stride = stridesData[idx];
- // stride value has to be non-zero
- NN_OPS_CHECK(stride != 0);
- bool positiveStride = stride > 0;
-
- int32_t begin = beginMask & (1 << idx)
- ? positiveStride ? 0 : dim - 1
- : ClampedIndex(beginData[idx], dim, positiveStride);
- int32_t end = endMask & (1 << idx)
- ? positiveStride ? dim : -1
- : ClampedIndex(endData[idx], dim, positiveStride);
-
- // This is valid for both positive and negative strides
- int32_t outDim = ceil((end - begin) / static_cast<float>(stride));
- outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim);
- if (!(shrinkAxisMask & (1 << idx))) {
- outDims.push_back(outDim);
- } else {
- if (outDim != 1) {
- LOG(ERROR) << "Outdim " << idx << " is " << outDim << ", expected 1";
- NN_OPS_CHECK(outDim == 1);
- }
- }
+ int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx));
+ int32_t stride = stridesData[idx];
+ // stride value has to be non-zero
+ NN_OPS_CHECK(stride != 0);
+ bool positiveStride = stride > 0;
+
+ int32_t begin = beginMask & (1 << idx) ? positiveStride ? 0 : dim - 1
+ : ClampedIndex(beginData[idx], dim, positiveStride);
+ int32_t end = endMask & (1 << idx) ? positiveStride ? dim : -1
+ : ClampedIndex(endData[idx], dim, positiveStride);
+
+ // This is valid for both positive and negative strides
+ int32_t outDim = ceil((end - begin) / static_cast<float>(stride));
+ outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim);
+ if (!(shrinkAxisMask & (1 << idx))) {
+ outDims.push_back(outDim);
+ } else {
+ if (outDim != 1) {
+ LOG(ERROR) << "Outdim " << idx << " is " << outDim << ", expected 1";
+ NN_OPS_CHECK(outDim == 1);
+ }
+ }
}
output->type = input.type;
@@ -837,11 +795,9 @@ bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) {
// Copy the input dimensions, omitting the axis dimension.
output->dimensions.clear();
output->dimensions.reserve(getNumberOfDimensions(input) - 1);
- output->dimensions.insert(output->dimensions.end(),
- input.dimensions.begin(),
+ output->dimensions.insert(output->dimensions.end(), input.dimensions.begin(),
input.dimensions.begin() + axis);
- output->dimensions.insert(output->dimensions.end(),
- input.dimensions.begin() + axis + 1,
+ output->dimensions.insert(output->dimensions.end(), input.dimensions.begin() + axis + 1,
input.dimensions.end());
return true;
@@ -910,5 +866,5 @@ bool groupedConvPrepare(const Shape& input, const Shape& filter, const Shape& bi
return true;
}
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index fc5493270..11f3c2abd 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -54,15 +54,14 @@ void initVLogMask() {
return;
}
- std::unordered_map<std::string, int> vLogFlags = {
- {"1", -1},
- {"all", -1},
- {"model", MODEL},
- {"compilation", COMPILATION},
- {"execution", EXECUTION},
- {"cpuexe", CPUEXE},
- {"manager", MANAGER},
- {"driver", DRIVER}};
+ std::unordered_map<std::string, int> vLogFlags = {{"1", -1},
+ {"all", -1},
+ {"model", MODEL},
+ {"compilation", COMPILATION},
+ {"execution", EXECUTION},
+ {"cpuexe", CPUEXE},
+ {"manager", MANAGER},
+ {"driver", DRIVER}};
std::vector<std::string> elements = android::base::Split(vLogSetting, " ,:");
for (const auto& elem : elements) {
@@ -103,8 +102,7 @@ namespace {
template <typename EntryType, uint32_t entryCount, uint32_t entryCountOEM>
EntryType tableLookup(const EntryType (&table)[entryCount],
- const EntryType (&tableOEM)[entryCountOEM],
- uint32_t code) {
+ const EntryType (&tableOEM)[entryCountOEM], uint32_t code) {
if (code < entryCount) {
return table[code];
} else if (code >= kOEMCodeBase && (code - kOEMCodeBase) < entryCountOEM) {
@@ -253,16 +251,16 @@ const bool kScalarDataType[]{
static_assert(COUNT(kScalarDataType) == kNumberOfDataTypes, "kScalarDataType is incorrect");
const uint32_t kSizeOfDataTypeOEM[]{
- 0, // ANEURALNETWORKS_OEM
- 1, // ANEURALNETWORKS_TENSOR_OEM_BYTE
+ 0, // ANEURALNETWORKS_OEM
+ 1, // ANEURALNETWORKS_TENSOR_OEM_BYTE
};
static_assert(COUNT(kSizeOfDataTypeOEM) == kNumberOfDataTypesOEM,
"kSizeOfDataTypeOEM is incorrect");
const bool kScalarDataTypeOEM[]{
- true, // ANEURALNETWORKS_OEM
- false, // ANEURALNETWORKS_TENSOR_OEM_BYTE
+ true, // ANEURALNETWORKS_OEM
+ false, // ANEURALNETWORKS_TENSOR_OEM_BYTE
};
static_assert(COUNT(kScalarDataTypeOEM) == kNumberOfDataTypesOEM,
@@ -313,11 +311,11 @@ bool tensorHasUnspecifiedDimensions(const Operand& operand) {
uint32_t alignBytesNeeded(uint32_t index, size_t length) {
uint32_t pattern;
if (length < 2) {
- pattern = 0; // No alignment necessary
+ pattern = 0; // No alignment necessary
} else if (length < 4) {
- pattern = 1; // Align on 2-byte boundary
+ pattern = 1; // Align on 2-byte boundary
} else {
- pattern = 3; // Align on 4-byte boundary
+ pattern = 3; // Align on 4-byte boundary
}
uint32_t extra = (~(index - 1)) & pattern;
return extra;
@@ -479,8 +477,8 @@ int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCo
return ANEURALNETWORKS_NO_ERROR;
}
-int validateOperationOperandTypes(const std::vector<Operand>& operands,
- uint32_t inOperandCount, const uint32_t* inOperandIndexes,
+int validateOperationOperandTypes(const std::vector<Operand>& operands, uint32_t inOperandCount,
+ const uint32_t* inOperandIndexes,
const std::vector<OperandType>& inExpectedTypes,
uint32_t outOperandCount, const uint32_t* outOperandIndexes,
const std::vector<OperandType>& outExpectedInTypes) {
@@ -494,16 +492,16 @@ int validateOperationOperandTypes(const std::vector<Operand>& operands,
for (uint32_t i = 0; i < inOperandCount; i++) {
if (operands[inOperandIndexes[i]].type != inExpectedTypes[i]) {
LOG(ERROR) << "Invalid input tensor type "
- << toString(operands[inOperandIndexes[i]].type)
- << " for input " << i << ", expected " << toString(inExpectedTypes[i]);
+ << toString(operands[inOperandIndexes[i]].type) << " for input " << i
+ << ", expected " << toString(inExpectedTypes[i]);
return ANEURALNETWORKS_BAD_DATA;
}
}
for (uint32_t i = 0; i < outOperandCount; i++) {
if (operands[outOperandIndexes[i]].type != outExpectedInTypes[i]) {
LOG(ERROR) << "Invalid output tensor type "
- << toString(operands[outOperandIndexes[i]].type)
- << " for input " << i << ", expected " << toString(outExpectedInTypes[i]);
+ << toString(operands[outOperandIndexes[i]].type) << " for input " << i
+ << ", expected " << toString(outExpectedInTypes[i]);
return ANEURALNETWORKS_BAD_DATA;
}
}
@@ -575,10 +573,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_DEPTHWISE_CONV_2D: {
@@ -683,10 +679,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION: {
@@ -724,10 +718,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else if (operands[inputIndexes[0]].dimensions.size() != 4) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_RESHAPE: {
@@ -740,8 +732,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_INT32};
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
@@ -749,18 +740,15 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
- OperandType::TENSOR_INT32};
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
} else {
LOG(ERROR) << "Unsupported input tensor type for operation "
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_DEPTH_TO_SPACE: {
@@ -775,8 +763,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::INT32};
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
@@ -784,8 +771,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
- OperandType::INT32};
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
} else {
LOG(ERROR) << "Unsupported input tensor type for operation "
@@ -798,10 +784,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_SPACE_TO_DEPTH: {
@@ -816,8 +800,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::INT32};
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
@@ -825,8 +808,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
- OperandType::INT32};
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
} else {
LOG(ERROR) << "Unsupported input tensor type for operation "
@@ -839,10 +821,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_EMBEDDING_LOOKUP: {
@@ -858,14 +838,11 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
- inputType};
+ std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
std::vector<OperandType> outExpectedTypes = {inputType};
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_HASHTABLE_LOOKUP: {
@@ -882,15 +859,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
return ANEURALNETWORKS_BAD_DATA;
}
std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
- OperandType::TENSOR_INT32,
- inputType};
+ OperandType::TENSOR_INT32, inputType};
std::vector<OperandType> outExpectedTypes = {inputType,
OperandType::TENSOR_QUANT8_ASYMM};
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_LSH_PROJECTION: {
@@ -1174,10 +1148,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_SPACE_TO_BATCH_ND: {
@@ -1226,10 +1198,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
} else {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_PAD: {
@@ -1359,8 +1329,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
- inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_INT32};
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
@@ -1368,18 +1337,15 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
- inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
- OperandType::TENSOR_INT32};
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
} else {
LOG(ERROR) << "Unsupported input tensor type for operation "
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_STRIDED_SLICE: {
@@ -1439,8 +1405,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
- inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_INT32,
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32,
OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
@@ -1450,8 +1415,7 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
- inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
- OperandType::TENSOR_INT32,
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
} else {
@@ -1459,10 +1423,8 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_ARGMAX:
@@ -1951,8 +1913,8 @@ V1_0::Capabilities convertToV1_0(const V1_1::Capabilities& capabilities) {
LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
<< " from V1_1::Capabilities to V1_0::Capabilities";
}
- return { .float32Performance = capabilities.float32Performance,
- .quantized8Performance = capabilities.quantized8Performance };
+ return {.float32Performance = capabilities.float32Performance,
+ .quantized8Performance = capabilities.quantized8Performance};
}
V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) {
@@ -1967,9 +1929,9 @@ V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) {
}
V1_1::Capabilities convertToV1_1(const V1_0::Capabilities& capabilities) {
- return { .float32Performance = capabilities.float32Performance,
- .quantized8Performance = capabilities.quantized8Performance,
- .relaxedFloat32toFloat16Performance = capabilities.float32Performance };
+ return {.float32Performance = capabilities.float32Performance,
+ .quantized8Performance = capabilities.quantized8Performance,
+ .relaxedFloat32toFloat16Performance = capabilities.float32Performance};
}
V1_1::Capabilities convertToV1_1(const V1_1::Capabilities& capabilities) {
@@ -2361,5 +2323,5 @@ uint32_t getProp(const char* str, uint32_t defaultValue) {
}
#endif // NN_DEBUGGABLE
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp
index da440d840..a74b6565b 100644
--- a/nn/common/ValidateHal.cpp
+++ b/nn/common/ValidateHal.cpp
@@ -45,7 +45,7 @@ struct ModelToHalVersion<V1_2::Model> {
};
class MemoryAccessVerifier {
-public:
+ public:
MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
: mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
for (size_t i = 0; i < mPoolCount; i++) {
@@ -68,7 +68,7 @@ public:
return true;
}
-private:
+ private:
size_t mPoolCount;
std::vector<size_t> mPoolSizes;
};
@@ -567,10 +567,10 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg
for (size_t i = 0; i < rank; i++) {
if (requestArgument.dimensions[i] != operand.dimensions[i] &&
operand.dimensions[i] != 0) {
- LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
- << " has dimension " << i << " of "
- << requestArgument.dimensions[i]
- << " different than the model's " << operand.dimensions[i];
+ LOG(ERROR)
+ << "Request " << type << " " << requestArgumentIndex
+ << " has dimension " << i << " of " << requestArgument.dimensions[i]
+ << " different than the model's " << operand.dimensions[i];
return false;
}
if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
diff --git a/nn/common/include/ActivationFunctor.h b/nn/common/include/ActivationFunctor.h
index d667ae1c0..d4d4d3aae 100644
--- a/nn/common/include/ActivationFunctor.h
+++ b/nn/common/include/ActivationFunctor.h
@@ -33,31 +33,30 @@ enum ActivationFn {
};
class ActivationFunctor {
- public:
- explicit ActivationFunctor(ActivationFn act) : act_(act) {}
-
- float operator()(float a) const {
- switch (act_) {
- case kActivationNone:
- return a;
- case kActivationRelu:
- return a < 0.f ? 0.f : a;
- case kActivationRelu6:
- return std::max(0.f, std::min(a, 6.f));
- case kActivationTanh:
- return std::tanh(a);
- case kActivationSigmoid:
- return 1.0f / (1.0f + std::exp(-a));
- default:
- __android_log_print(ANDROID_LOG_ERROR, "NN API",
- "Invalid enum value for activation function: 0x%0X",
- act_);
- abort();
+ public:
+ explicit ActivationFunctor(ActivationFn act) : act_(act) {}
+
+ float operator()(float a) const {
+ switch (act_) {
+ case kActivationNone:
+ return a;
+ case kActivationRelu:
+ return a < 0.f ? 0.f : a;
+ case kActivationRelu6:
+ return std::max(0.f, std::min(a, 6.f));
+ case kActivationTanh:
+ return std::tanh(a);
+ case kActivationSigmoid:
+ return 1.0f / (1.0f + std::exp(-a));
+ default:
+ __android_log_print(ANDROID_LOG_ERROR, "NN API",
+ "Invalid enum value for activation function: 0x%0X", act_);
+ abort();
+ }
}
- }
- private:
- ActivationFn act_;
+ private:
+ ActivationFn act_;
};
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_ACTIVATION_FUNCTOR_H
diff --git a/nn/common/include/CpuExecutor.h b/nn/common/include/CpuExecutor.h
index e589a654f..c0cb1e984 100644
--- a/nn/common/include/CpuExecutor.h
+++ b/nn/common/include/CpuExecutor.h
@@ -213,11 +213,12 @@ class CpuExecutor {
// b/109953668, disable OpenMP
#ifdef NNAPI_OPENMP
class ScopedOpenmpSettings {
-public:
+ public:
ScopedOpenmpSettings();
~ScopedOpenmpSettings();
DISALLOW_COPY_AND_ASSIGN(ScopedOpenmpSettings);
-private:
+
+ private:
int mBlocktimeInitial;
#if NNAPI_LIMIT_CPU_THREADS
int mMaxThreadsInitial;
@@ -225,7 +226,6 @@ private:
};
#endif // NNAPI_OPENMP
-
namespace {
template <typename T>
@@ -235,7 +235,7 @@ T getScalarData(const RunTimeOperandInfo& info) {
return data[0];
}
-inline bool IsNullInput(const RunTimeOperandInfo *input) {
+inline bool IsNullInput(const RunTimeOperandInfo* input) {
return input->lifetime == hal::OperandLifeTime::NO_VALUE;
}
@@ -250,12 +250,12 @@ inline int NumOutputs(const hal::Operation& operation) {
return operation.outputs.size();
}
-inline size_t NumDimensions(const RunTimeOperandInfo *operand) {
- return operand->shape().dimensions.size();
+inline size_t NumDimensions(const RunTimeOperandInfo* operand) {
+ return operand->shape().dimensions.size();
}
-inline uint32_t SizeOfDimension(const RunTimeOperandInfo *operand, int i) {
- return operand->shape().dimensions[i];
+inline uint32_t SizeOfDimension(const RunTimeOperandInfo* operand, int i) {
+ return operand->shape().dimensions[i];
}
inline RunTimeOperandInfo* GetInput(const hal::Operation& operation,
@@ -270,7 +270,7 @@ inline RunTimeOperandInfo* GetOutput(const hal::Operation& operation,
} // anonymous namespace
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_EXECUTOR_H
diff --git a/nn/common/include/GraphDump.h b/nn/common/include/GraphDump.h
index 1bf02b90e..bee994b39 100644
--- a/nn/common/include/GraphDump.h
+++ b/nn/common/include/GraphDump.h
@@ -45,8 +45,7 @@ namespace nn {
// A model input or output (operand) is shown in "reverse colors" --
// white text on a black background.
//
-void graphDump(const char* name,
- const ::android::hardware::neuralnetworks::V1_2::Model& model,
+void graphDump(const char* name, const ::android::hardware::neuralnetworks::V1_2::Model& model,
std::ostream* outStream = nullptr);
} // namespace nn
diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h
index 381bc8e16..9ae3aca8b 100644
--- a/nn/common/include/OperationsUtils.h
+++ b/nn/common/include/OperationsUtils.h
@@ -146,8 +146,7 @@ bool combineDimensions(const std::vector<uint32_t>& lhs, const std::vector<uint3
// Return the total number of elements, i.e. all the dimensions multiplied
// together. For a scalar, returns one.
uint32_t getNumberOfElements(const Shape& shape);
-uint32_t getNumberOfElements(const Shape& shape,
- size_t firstAxisInclusive,
+uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive,
size_t lastAxisExclusive);
uint32_t getNumberOfDimensions(const Shape& shape);
@@ -179,27 +178,20 @@ inline int32_t computeOutSizeTransposeConv(int32_t imageSize, int32_t filterSize
__wur bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift);
-__wur
-bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int32_t* right_shift);
+__wur bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t* quantized_multiplier,
+ int32_t* right_shift);
-__wur
-bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int* left_shift);
+__wur bool QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier,
+ int* left_shift);
__wur bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filterShape,
const Shape& biasShape, const Shape& outputShape,
double* multiplier);
-void CalculateActivationRangeUint8(int32_t activation,
- const Shape& outputShape,
- int32_t* act_min,
+void CalculateActivationRangeUint8(int32_t activation, const Shape& outputShape, int32_t* act_min,
int32_t* act_max);
-void CalculateActivationRangeFloat(int32_t activation,
- float* activation_min,
+void CalculateActivationRangeFloat(int32_t activation, float* activation_min,
float* activation_max);
int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift);
@@ -231,11 +223,11 @@ inline void calculateExplicitPaddingTransposeConv(int32_t in_size, int32_t strid
padding_tail);
}
-inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
- int32_t strideWidth, int32_t strideHeight,
- int32_t filterWidth, int32_t filterHeight,
- int32_t paddingLeft, int32_t paddingRight,
- int32_t paddingTop, int32_t paddingBottom) {
+inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight, int32_t strideWidth,
+ int32_t strideHeight, int32_t filterWidth,
+ int32_t filterHeight, int32_t paddingLeft,
+ int32_t paddingRight, int32_t paddingTop,
+ int32_t paddingBottom) {
if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) {
return kPaddingValid;
}
@@ -243,8 +235,8 @@ inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
int32_t expectedPaddingLeft, expectedPaddingRight;
int32_t expectedPaddingTop, expectedPaddingBottom;
- calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame,
- &expectedPaddingLeft, &expectedPaddingRight);
+ calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame, &expectedPaddingLeft,
+ &expectedPaddingRight);
calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame,
&expectedPaddingTop, &expectedPaddingBottom);
if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight &&
@@ -257,30 +249,28 @@ inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
// Reverse order of bits in the mask to match the expected order in kernel
inline int ReverseMaskBits(int mask, int num_dimensions) {
- int out = 0;
- for (int dim = 0; dim < num_dimensions; dim++) {
- out <<= 1;
- out += (mask & 1);
- mask >>= 1;
- }
- return out;
+ int out = 0;
+ for (int dim = 0; dim < num_dimensions; dim++) {
+ out <<= 1;
+ out += (mask & 1);
+ mask >>= 1;
+ }
+ return out;
}
// Compute the positive remainder.
inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
- return (divisor + (dividend % divisor)) % divisor;
+ return (divisor + (dividend % divisor)) % divisor;
}
// Compute clamped index.
inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
- return pos_stride
- ? (index >= dim ? dim
- : PositiveRemainder(
- std::min(std::max(index, -dim), dim), dim))
- : (index < -dim
- ? -1
- : PositiveRemainder(
- std::min(std::max(index, -dim), dim - 1), dim));
+ return pos_stride
+ ? (index >= dim ? dim
+ : PositiveRemainder(std::min(std::max(index, -dim), dim), dim))
+ : (index < -dim
+ ? -1
+ : PositiveRemainder(std::min(std::max(index, -dim), dim - 1), dim));
}
// Broadcasts input shape against one another and puts the result into output
@@ -303,63 +293,38 @@ bool genericActivationPrepare(const Shape& input, Shape* output);
bool genericNormalizationPrepare(const Shape& input, Shape* output);
-bool reshapePrepare(const Shape& input,
- const int32_t* targetDims,
- const int32_t targetDimsSize,
+bool reshapePrepare(const Shape& input, const int32_t* targetDims, const int32_t targetDimsSize,
Shape* output);
-bool depthToSpacePrepare(const Shape& input,
- int32_t blockSize,
- Shape* output);
+bool depthToSpacePrepare(const Shape& input, int32_t blockSize, Shape* output);
-bool spaceToDepthPrepare(const Shape& input,
- int32_t blockSize,
- Shape* output);
+bool spaceToDepthPrepare(const Shape& input, int32_t blockSize, Shape* output);
-bool embeddingLookupPrepare(const Shape &valueShape,
- const Shape &lookupShape,
- Shape *outputShape);
+bool embeddingLookupPrepare(const Shape& valueShape, const Shape& lookupShape, Shape* outputShape);
-bool hashtableLookupPrepare(const Shape &lookupShape,
- const Shape &keyShape,
- const Shape &valueShape,
- Shape *outputShape,
- Shape *hitShape);
+bool hashtableLookupPrepare(const Shape& lookupShape, const Shape& keyShape,
+ const Shape& valueShape, Shape* outputShape, Shape* hitShape);
-bool padPrepare(const Shape& input,
- const int32_t* paddingsData,
- const Shape& paddingsShape,
+bool padPrepare(const Shape& input, const int32_t* paddingsData, const Shape& paddingsShape,
Shape* output);
-bool batchToSpacePrepare(const Shape& input,
- const int32_t* blockSizeData,
- const Shape& blockSizeShape,
- Shape* output);
-
-bool spaceToBatchPrepare(const Shape& input,
- const int32_t* blockSizeData,
- const Shape& blockSizeShape,
- const int32_t* paddingsData,
- const Shape& paddingsShape,
- Shape* output);
-
-bool squeezePrepare(const Shape& input,
- const int32_t* squeezeDims,
- const Shape& squeezeDimsShape,
+bool batchToSpacePrepare(const Shape& input, const int32_t* blockSizeData,
+ const Shape& blockSizeShape, Shape* output);
+
+bool spaceToBatchPrepare(const Shape& input, const int32_t* blockSizeData,
+ const Shape& blockSizeShape, const int32_t* paddingsData,
+ const Shape& paddingsShape, Shape* output);
+
+bool squeezePrepare(const Shape& input, const int32_t* squeezeDims, const Shape& squeezeDimsShape,
Shape* output);
-bool meanPrepare(const Shape& input,
- const int32_t* axisData,
- const Shape& axisShape,
- bool keepDims,
+bool meanPrepare(const Shape& input, const int32_t* axisData, const Shape& axisShape, bool keepDims,
Shape* output);
-bool stridedSlicePrepare(const Shape& input,
- const int32_t* beginData, const Shape& beginShape,
- const int32_t* endData, const Shape& endShape,
- const int32_t* stridesData, const Shape& stridesShape,
- int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
- Shape* output);
+bool stridedSlicePrepare(const Shape& input, const int32_t* beginData, const Shape& beginShape,
+ const int32_t* endData, const Shape& endShape, const int32_t* stridesData,
+ const Shape& stridesShape, int32_t beginMask, int32_t endMask,
+ int32_t shrinkAxisMask, Shape* output);
bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output);
@@ -428,7 +393,7 @@ inline bool mergeThirdDimension(const T* bufferA, const std::vector<uint32_t>& d
return true;
}
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_UTILS_H
diff --git a/nn/common/include/Tracing.h b/nn/common/include/Tracing.h
index 01535d831..e461b2bdd 100644
--- a/nn/common/include/Tracing.h
+++ b/nn/common/include/Tracing.h
@@ -100,43 +100,41 @@
// Layer Application - For native applications (e.g., unit tests)
#define NNTRACE_APP(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_APPLICATION, phase, detail)
#define NNTRACE_APP_SWITCH(phase, detail) \
- NNTRACE_FULL_SWITCH(NNTRACE_LAYER_APPLICATION, phase, detail)
+ NNTRACE_FULL_SWITCH(NNTRACE_LAYER_APPLICATION, phase, detail)
// Layer Runtime - For the NNAPI runtime
#define NNTRACE_RT(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_RUNTIME, phase, detail)
#define NNTRACE_RT_SWITCH(phase, detail) NNTRACE_FULL_SWITCH(NNTRACE_LAYER_RUNTIME, phase, detail)
// Layer CPU - CPU executor
#define NNTRACE_CPU(phase, detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, phase, detail)
-#define NNTRACE_COMP(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, \
- NNTRACE_PHASE_COMPUTATION, detail)
-#define NNTRACE_COMP_SWITCH(detail) NNTRACE_FULL_SWITCH(NNTRACE_LAYER_CPU, \
- NNTRACE_PHASE_COMPUTATION, detail)
-#define NNTRACE_TRANS(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, \
- NNTRACE_PHASE_TRANSFORMATION, detail)
+#define NNTRACE_COMP(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, NNTRACE_PHASE_COMPUTATION, detail)
+#define NNTRACE_COMP_SWITCH(detail) \
+ NNTRACE_FULL_SWITCH(NNTRACE_LAYER_CPU, NNTRACE_PHASE_COMPUTATION, detail)
+#define NNTRACE_TRANS(detail) NNTRACE_FULL(NNTRACE_LAYER_CPU, NNTRACE_PHASE_TRANSFORMATION, detail)
// Fully specified macros to be used when no convenience wrapper exists for your
// need.
#define NNTRACE_FULL(layer, phase, detail) NNTRACE_NAME_1(("[NN_" layer "_" phase "]" detail))
#define NNTRACE_FULL_SWITCH(layer, phase, detail) \
- NNTRACE_NAME_SWITCH(("[SW][NN_" layer "_" phase "]" detail))
+ NNTRACE_NAME_SWITCH(("[SW][NN_" layer "_" phase "]" detail))
#define NNTRACE_FULL_SUBTRACT(layer, phase, detail) \
- NNTRACE_NAME_1(("[SUB][NN_" layer "_" phase "]" detail))
+ NNTRACE_NAME_1(("[SUB][NN_" layer "_" phase "]" detail))
// Raw macro without scoping requirements, for special cases
-#define NNTRACE_FULL_RAW(layer, phase, detail) android::ScopedTrace PASTE(___tracer, __LINE__) \
- (ATRACE_TAG, ("[NN_" layer "_" phase "]" detail))
+#define NNTRACE_FULL_RAW(layer, phase, detail) \
+ android::ScopedTrace PASTE(___tracer, __LINE__)(ATRACE_TAG, ("[NN_" layer "_" phase "]" detail))
// Tracing buckets - for calculating timing summaries over.
//
// Application-only phases
-#define NNTRACE_PHASE_OVERALL "PO" // Overall program, e.g., one benchmark case
-#define NNTRACE_PHASE_WARMUP "PWU" // Warmup (nesting multiple executions)
-#define NNTRACE_PHASE_BENCHMARK "PBM" // Benchmark (nesting multiple executions)
+#define NNTRACE_PHASE_OVERALL "PO" // Overall program, e.g., one benchmark case
+#define NNTRACE_PHASE_WARMUP "PWU" // Warmup (nesting multiple executions)
+#define NNTRACE_PHASE_BENCHMARK "PBM" // Benchmark (nesting multiple executions)
// Main phases, usable by all layers
-#define NNTRACE_PHASE_INITIALIZATION "PI" // Initialization - not related to a model
-#define NNTRACE_PHASE_PREPARATION "PP" // Model construction
-#define NNTRACE_PHASE_COMPILATION "PC" // Model compilation
-#define NNTRACE_PHASE_EXECUTION "PE" // Executing the model
-#define NNTRACE_PHASE_TERMINATION "PT" // Tearing down
-#define NNTRACE_PHASE_UNSPECIFIED "PU" // Helper code called from multiple phases
+#define NNTRACE_PHASE_INITIALIZATION "PI" // Initialization - not related to a model
+#define NNTRACE_PHASE_PREPARATION "PP" // Model construction
+#define NNTRACE_PHASE_COMPILATION "PC" // Model compilation
+#define NNTRACE_PHASE_EXECUTION "PE" // Executing the model
+#define NNTRACE_PHASE_TERMINATION "PT" // Tearing down
+#define NNTRACE_PHASE_UNSPECIFIED "PU" // Helper code called from multiple phases
// Subphases of execution
#define NNTRACE_PHASE_INPUTS_AND_OUTPUTS "PIO" // Setting inputs/outputs and allocating buffers
#define NNTRACE_PHASE_TRANSFORMATION "PTR" // Transforming data for computation
@@ -149,8 +147,7 @@
#define NNTRACE_LAYER_DRIVER "LD"
#define NNTRACE_LAYER_CPU "LC"
#define NNTRACE_LAYER_OTHER "LO"
-#define NNTRACE_LAYER_UTILITY "LU" // Code used from multiple layers
-
+#define NNTRACE_LAYER_UTILITY "LU" // Code used from multiple layers
// Implementation
//
@@ -162,10 +159,9 @@
// Switching trace, more than one per scope allowed, translated by
// systrace_parser.py. This is mainly useful for tracing multiple phases through
// one function / scope.
-#define NNTRACE_NAME_SWITCH(name) android::ScopedTrace PASTE(___tracer, __LINE__) \
- (ATRACE_TAG, name); \
- (void)___tracer_1 // ensure switch is only used after a basic trace
-
+#define NNTRACE_NAME_SWITCH(name) \
+ android::ScopedTrace PASTE(___tracer, __LINE__)(ATRACE_TAG, name); \
+ (void)___tracer_1 // ensure switch is only used after a basic trace
// Disallow use of raw ATRACE macros
#undef ATRACE_NAME
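
The NNTRACE macros above are thin wrappers around RAII scoped tracing: each declares an android::ScopedTrace local whose constructor opens a trace slice and whose destructor closes it, and the _SWITCH variants open a second slice in the same scope while referencing ___tracer_1, so they fail to compile unless a base trace was opened first. A runnable sketch of the same mechanism with a stand-in ScopedTrace (the real one also takes an ATRACE_TAG, and the phase strings here are illustrative):

#include <iostream>
#include <string>
#include <utility>

// Stand-in for android::ScopedTrace: opens a slice on construction and
// closes it on destruction.
class ScopedTrace {
  public:
    explicit ScopedTrace(std::string name) : name_(std::move(name)) {
        std::cout << "begin " << name_ << '\n';
    }
    ~ScopedTrace() { std::cout << "end   " << name_ << '\n'; }

  private:
    std::string name_;
};

#define PASTE_IMPL(a, b) a##b
#define PASTE(a, b) PASTE_IMPL(a, b)

// Mirrors NNTRACE_NAME_1 / NNTRACE_NAME_SWITCH: the base macro names its
// tracer ___tracer_1; the switch macro makes a fresh tracer per line and
// references ___tracer_1 so it cannot appear without a preceding base trace.
#define TRACE_NAME_1(name) ScopedTrace ___tracer_1(name)
#define TRACE_NAME_SWITCH(name)                   \
    ScopedTrace PASTE(___tracer, __LINE__)(name); \
    (void)___tracer_1

void transformThenCompute() {
    TRACE_NAME_1("[NN_LC_PTR]transform");  // transformation phase
    // ... data layout transformation would run here ...
    TRACE_NAME_SWITCH("[SW][NN_LC_PCO]compute");  // switch to the computation phase
    // ... the actual computation would run here ...
}

int main() { transformThenCompute(); }
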
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 6c53cc940..595c7098a 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -50,22 +50,14 @@ const int kOEMCodeBase = 10000;
* forget to update the corresponding 'tags' table in
* the initVlogMask() function implemented in Utils.cpp.
*/
-enum VLogFlags {
- MODEL = 0,
- COMPILATION,
- EXECUTION,
- CPUEXE,
- MANAGER,
- DRIVER
-};
+enum VLogFlags { MODEL = 0, COMPILATION, EXECUTION, CPUEXE, MANAGER, DRIVER };
-#define VLOG_IS_ON(TAG) \
- ((vLogMask & (1 << (TAG))) != 0)
+#define VLOG_IS_ON(TAG) ((vLogMask & (1 << (TAG))) != 0)
-#define VLOG(TAG) \
+#define VLOG(TAG) \
if (LIKELY(!VLOG_IS_ON(TAG))) \
- ; \
- else \
+ ; \
+ else \
LOG(INFO)
extern int vLogMask;
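
The VLOG definition above is the classic dangling-else logging idiom: when the tag is disabled, the statement reduces to the empty if-branch and the streamed expression after else is never evaluated, yet VLOG(TAG) << ...; still parses as a single statement. A stand-alone sketch with std::cout standing in for LOG(INFO), and without the LIKELY hint:

#include <iostream>
#include <string>

// Stand-ins for the runtime's vlog mask machinery.
static int vLogMask = 1 << 0;  // pretend only tag 0 (MODEL) is enabled
#define VLOG_IS_ON(TAG) ((vLogMask & (1 << (TAG))) != 0)

// Same shape as the real macro: the empty if-branch swallows the disabled
// case, the else binds the stream expression to the enabled case, and the
// streamed arguments are not evaluated at all when the tag is off.
#define VLOG(TAG)         \
    if (!VLOG_IS_ON(TAG)) \
        ;                 \
    else                  \
        std::cout

static std::string expensiveToFormat() {
    std::cout << "(side effect)";  // would appear if the argument were evaluated
    return "compilation event\n";
}

int main() {
    VLOG(0) << "model event\n";      // printed: tag 0 is enabled
    VLOG(1) << expensiveToFormat();  // neither printed nor evaluated
}
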
diff --git a/nn/common/operations/ArgMinMax.cpp b/nn/common/operations/ArgMinMax.cpp
index 64d4606d1..95b69bdd2 100644
--- a/nn/common/operations/ArgMinMax.cpp
+++ b/nn/common/operations/ArgMinMax.cpp
@@ -30,22 +30,19 @@ namespace nn {
using namespace hal;
template <typename In, typename Out>
-static void argMinMaxImpl(const In* inputData, const Shape& inputShape,
- int32_t axis, bool isArgMin,
+static void argMinMaxImpl(const In* inputData, const Shape& inputShape, int32_t axis, bool isArgMin,
Out* outputData, const Shape& outputShape) {
const int outerSize = getNumberOfElements(inputShape, 0, axis);
const int axisSize = getSizeOfDimension(inputShape, axis);
- const int innerSize = getNumberOfElements(
- inputShape, axis + 1, getNumberOfDimensions(inputShape));
+ const int innerSize =
+ getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
for (int outer = 0; outer < outerSize; ++outer) {
for (int inner = 0; inner < innerSize; ++inner) {
auto minMaxValue = inputData[outer * axisSize * innerSize + inner];
int minMaxIndex = 0;
for (int i = 1; i < axisSize; ++i) {
- const auto& value =
- inputData[(outer * axisSize + i) * innerSize + inner];
- if ((isArgMin && value < minMaxValue) ||
- (!isArgMin && value > minMaxValue)) {
+ const auto& value = inputData[(outer * axisSize + i) * innerSize + inner];
+ if ((isArgMin && value < minMaxValue) || (!isArgMin && value > minMaxValue)) {
minMaxValue = value;
minMaxIndex = i;
}
@@ -55,23 +52,17 @@ static void argMinMaxImpl(const In* inputData, const Shape& inputShape,
}
}
-bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape,
- int32 axis, bool isArgMin,
+bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32 axis, bool isArgMin,
uint8_t* outputData, const Shape& outputShape) {
NNTRACE_TRANS("argMinMaxGeneric");
NN_CHECK(handleNegativeAxis(inputShape, &axis));
-#define NNAPI_IMPL_ARG_MIN_MAX(operandType, dataType) \
- if (inputShape.type == operandType) { \
- NNTRACE_COMP_SWITCH("argMinMaxImpl::" #dataType); \
- argMinMaxImpl( \
- reinterpret_cast<const dataType*>(inputData), \
- inputShape, \
- axis, \
- isArgMin, \
- reinterpret_cast<int32_t*>(outputData), \
- outputShape); \
- return true; \
+#define NNAPI_IMPL_ARG_MIN_MAX(operandType, dataType) \
+ if (inputShape.type == operandType) { \
+ NNTRACE_COMP_SWITCH("argMinMaxImpl::" #dataType); \
+ argMinMaxImpl(reinterpret_cast<const dataType*>(inputData), inputShape, axis, isArgMin, \
+ reinterpret_cast<int32_t*>(outputData), outputShape); \
+ return true; \
}
NNAPI_IMPL_ARG_MIN_MAX(OperandType::TENSOR_FLOAT16, _Float16);
@@ -84,5 +75,5 @@ bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape,
return false;
}
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
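
argMinMaxImpl reduces an arbitrary-axis argmin/argmax to three nested loops by flattening the tensor into (outer, axis, inner) extents; element (outer, i, inner) lives at offset (outer * axisSize + i) * innerSize + inner. A self-contained argmax sketch using the same decomposition:

#include <cstdint>
#include <iostream>
#include <vector>

// Argmax along `axis` of a tensor with the given dims, using the same
// (outer, axis, inner) flattening as argMinMaxImpl.
std::vector<int32_t> argMaxAlongAxis(const std::vector<float>& data,
                                     const std::vector<uint32_t>& dims, uint32_t axis) {
    uint32_t outerSize = 1, innerSize = 1;
    for (uint32_t d = 0; d < axis; ++d) outerSize *= dims[d];
    const uint32_t axisSize = dims[axis];
    for (uint32_t d = axis + 1; d < dims.size(); ++d) innerSize *= dims[d];

    std::vector<int32_t> out(outerSize * innerSize);
    for (uint32_t outer = 0; outer < outerSize; ++outer) {
        for (uint32_t inner = 0; inner < innerSize; ++inner) {
            float best = data[outer * axisSize * innerSize + inner];
            int32_t bestIndex = 0;
            for (uint32_t i = 1; i < axisSize; ++i) {
                const float v = data[(outer * axisSize + i) * innerSize + inner];
                if (v > best) {
                    best = v;
                    bestIndex = i;
                }
            }
            out[outer * innerSize + inner] = bestIndex;
        }
    }
    return out;
}

int main() {
    // 2x3 tensor, argmax along axis 1: rows {1,5,2} and {7,0,3} -> {1, 0}.
    auto result = argMaxAlongAxis({1, 5, 2, 7, 0, 3}, {2, 3}, 1);
    for (int32_t i : result) std::cout << i << ' ';
    std::cout << '\n';
}
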
diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp
index 678e2d698..152d3d63c 100644
--- a/nn/common/operations/Conv2D.cpp
+++ b/nn/common/operations/Conv2D.cpp
@@ -120,49 +120,49 @@ struct Conv2dParam {
}
};
-#define ANDROID_NN_CONV_PARAMETERS(Type) \
- uint32_t height = getSizeOfDimension(inputShape, 1); \
- uint32_t width = getSizeOfDimension(inputShape, 2); \
- uint32_t filterHeight = getSizeOfDimension(filterShape, 1); \
- uint32_t filterWidth = getSizeOfDimension(filterShape, 2); \
- uint32_t outHeight = getSizeOfDimension(outputShape, 1); \
- uint32_t outWidth = getSizeOfDimension(outputShape, 2); \
- uint32_t inDepth = getSizeOfDimension(inputShape, 3); \
- \
- uint32_t paddingHeight = (uint32_t)padding_top; \
- uint32_t paddingWidth = (uint32_t)padding_left; \
- \
- tflite::Dims<4> im2colDim; \
- im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0); \
- im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1); \
- im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2); \
- im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth; \
- \
- im2colDim.strides[0] = 1; \
- for (int i=1; i<4; i++) { \
- im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1]; \
- } \
- \
- Type* im2colData = nullptr; \
- uint64_t im2colByteSize = sizeof(Type); \
- std::unique_ptr<Type[]> im2colGuard; \
- for (int i=0; i<4; i++) { \
- im2colByteSize *= im2colDim.sizes[i]; \
- } \
- /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \
- if (im2colByteSize >= 0x7fffffff) { \
- LOG(ERROR) << "Conv size is too large, not enough memory"; \
- return false; \
- } \
- if (im2colByteSize <= kStaticBufferSize) { \
- im2colData = reinterpret_cast<Type *>(static_scratch_buffer); \
- } else { \
- im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \
- if (im2colData == nullptr) { \
- LOG(ERROR) << "Conv size is too large, not enough memory"; \
- return false; \
- } \
- im2colGuard.reset(im2colData); \
+#define ANDROID_NN_CONV_PARAMETERS(Type) \
+ uint32_t height = getSizeOfDimension(inputShape, 1); \
+ uint32_t width = getSizeOfDimension(inputShape, 2); \
+ uint32_t filterHeight = getSizeOfDimension(filterShape, 1); \
+ uint32_t filterWidth = getSizeOfDimension(filterShape, 2); \
+ uint32_t outHeight = getSizeOfDimension(outputShape, 1); \
+ uint32_t outWidth = getSizeOfDimension(outputShape, 2); \
+ uint32_t inDepth = getSizeOfDimension(inputShape, 3); \
+ \
+ uint32_t paddingHeight = (uint32_t)padding_top; \
+ uint32_t paddingWidth = (uint32_t)padding_left; \
+ \
+ tflite::Dims<4> im2colDim; \
+ im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0); \
+ im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1); \
+ im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2); \
+ im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth; \
+ \
+ im2colDim.strides[0] = 1; \
+ for (int i = 1; i < 4; i++) { \
+ im2colDim.strides[i] = im2colDim.strides[i - 1] * im2colDim.sizes[i - 1]; \
+ } \
+ \
+ Type* im2colData = nullptr; \
+ uint64_t im2colByteSize = sizeof(Type); \
+ std::unique_ptr<Type[]> im2colGuard; \
+ for (int i = 0; i < 4; i++) { \
+ im2colByteSize *= im2colDim.sizes[i]; \
+ } \
+ /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \
+ if (im2colByteSize >= 0x7fffffff) { \
+ LOG(ERROR) << "Conv size is too large, not enough memory"; \
+ return false; \
+ } \
+ if (im2colByteSize <= kStaticBufferSize) { \
+ im2colData = reinterpret_cast<Type*>(static_scratch_buffer); \
+ } else { \
+ im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \
+ if (im2colData == nullptr) { \
+ LOG(ERROR) << "Conv size is too large, not enough memory"; \
+ return false; \
+ } \
+ im2colGuard.reset(im2colData); \
}
bool convNhwc(const float* inputData, const Shape& inputShape, const float* filterData,
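
The ANDROID_NN_CONV_PARAMETERS macro sizes the im2col scratch buffer as batches * outH * outW * (inDepth * kH * kW) elements, accumulates the byte count in a uint64_t, rejects anything at or above 0x7fffffff because tflite's Conv indexes with int (b/77982879), and falls back to the heap only when the static scratch buffer is too small. A sketch of just that sizing logic; the kStaticBufferSize value below is assumed for illustration:

#include <cstdint>
#include <iostream>
#include <memory>

constexpr uint64_t kStaticBufferSize = 1605632;  // assumed value, for illustration only

// Returns the number of floats the im2col buffer needs, or 0 if the byte
// size would overflow the int offsets used by tflite's Conv kernels.
uint64_t im2colFloatCount(uint32_t batches, uint32_t outHeight, uint32_t outWidth,
                          uint32_t inDepth, uint32_t filterHeight, uint32_t filterWidth) {
    uint64_t bytes = sizeof(float);
    const uint64_t sizes[4] = {uint64_t(inDepth) * filterHeight * filterWidth, outWidth,
                               outHeight, batches};
    for (uint64_t s : sizes) bytes *= s;  // 64-bit product: no silent wraparound
    if (bytes >= 0x7fffffff) return 0;    // too large for int offsets
    return bytes / sizeof(float);
}

int main() {
    uint64_t n = im2colFloatCount(/*batches=*/1, /*outHeight=*/112, /*outWidth=*/112,
                                  /*inDepth=*/32, /*filterHeight=*/3, /*filterWidth=*/3);
    if (n == 0) {
        std::cout << "conv too large\n";
    } else if (n * sizeof(float) <= kStaticBufferSize) {
        std::cout << "fits in static scratch buffer (" << n << " floats)\n";
    } else {
        auto heapBuffer = std::make_unique<float[]>(n);  // heap fallback, like im2colGuard
        std::cout << "heap-allocated " << n << " floats\n";
    }
}
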
diff --git a/nn/common/operations/EmbeddingLookup.cpp b/nn/common/operations/EmbeddingLookup.cpp
index f3b2911e7..bf3be6317 100644
--- a/nn/common/operations/EmbeddingLookup.cpp
+++ b/nn/common/operations/EmbeddingLookup.cpp
@@ -31,30 +31,29 @@ using namespace hal;
EmbeddingLookup::EmbeddingLookup(const Operation& operation,
std::vector<RunTimeOperandInfo>& operands) {
- value_ = GetInput(operation, operands, kValueTensor);
- lookup_ = GetInput(operation, operands, kLookupTensor);
+ value_ = GetInput(operation, operands, kValueTensor);
+ lookup_ = GetInput(operation, operands, kLookupTensor);
- output_ = GetOutput(operation, operands, kOutputTensor);
+ output_ = GetOutput(operation, operands, kOutputTensor);
}
bool EmbeddingLookup::Eval() {
- NNTRACE_COMP("EmbeddingLookup::Eval");
- const int row_size = value_->shape().dimensions[0];
- const int total_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions);
- const int row_bytes = total_bytes/row_size;
-
- for (uint32_t i = 0; i < lookup_->shape().dimensions[0]; i++) {
- int idx = (reinterpret_cast<int*>(lookup_->buffer))[i];
- if (idx >= row_size || idx < 0) {
- LOG(ERROR) << "Embedding Lookup: index out of bounds.";
- return false;
- } else {
- memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes,
- row_bytes);
+ NNTRACE_COMP("EmbeddingLookup::Eval");
+ const int row_size = value_->shape().dimensions[0];
+ const int total_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions);
+ const int row_bytes = total_bytes / row_size;
+
+ for (uint32_t i = 0; i < lookup_->shape().dimensions[0]; i++) {
+ int idx = (reinterpret_cast<int*>(lookup_->buffer))[i];
+ if (idx >= row_size || idx < 0) {
+ LOG(ERROR) << "Embedding Lookup: index out of bounds.";
+ return false;
+ } else {
+ memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, row_bytes);
+ }
}
- }
- return true;
+ return true;
}
} // namespace nn
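
EmbeddingLookup::Eval is a bounds-checked gather: the row stride in bytes is derived from the value tensor's total size, and one memcpy moves each looked-up row. A standalone sketch over raw byte buffers:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Gathers rows of `value` (numRows x rowBytes) into `output` according to
// `lookup`, failing on any out-of-range index, like EmbeddingLookup::Eval.
bool embeddingLookupSketch(const std::vector<int>& lookup, const uint8_t* value, int numRows,
                           int rowBytes, uint8_t* output) {
    for (size_t i = 0; i < lookup.size(); ++i) {
        const int idx = lookup[i];
        if (idx < 0 || idx >= numRows) {
            std::cerr << "Embedding Lookup: index out of bounds.\n";
            return false;
        }
        std::memcpy(output + i * rowBytes, value + idx * rowBytes, rowBytes);
    }
    return true;
}

int main() {
    // 3 rows x 2 floats per row.
    float value[6] = {0.f, 0.1f, 1.f, 1.1f, 2.f, 2.1f};
    float out[4];
    std::vector<int> lookup = {2, 0};
    if (embeddingLookupSketch(lookup, reinterpret_cast<uint8_t*>(value), 3, 2 * sizeof(float),
                              reinterpret_cast<uint8_t*>(out))) {
        for (float f : out) std::cout << f << ' ';  // prints: 2 2.1 0 0.1
        std::cout << '\n';
    }
}
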
diff --git a/nn/common/operations/EmbeddingLookup.h b/nn/common/operations/EmbeddingLookup.h
index bb89e24b2..9109ddfdf 100644
--- a/nn/common/operations/EmbeddingLookup.h
+++ b/nn/common/operations/EmbeddingLookup.h
@@ -27,22 +27,22 @@ namespace nn {
struct RunTimeOperandInfo;
class EmbeddingLookup {
- public:
- EmbeddingLookup(const hardware::neuralnetworks::V1_2::Operation& operation,
- std::vector<RunTimeOperandInfo>& operands);
+ public:
+ EmbeddingLookup(const hardware::neuralnetworks::V1_2::Operation& operation,
+ std::vector<RunTimeOperandInfo>& operands);
- bool Eval();
+ bool Eval();
- static constexpr int kLookupTensor = 0;
- static constexpr int kValueTensor = 1;
+ static constexpr int kLookupTensor = 0;
+ static constexpr int kValueTensor = 1;
- static constexpr int kOutputTensor = 0;
+ static constexpr int kOutputTensor = 0;
- private:
- const RunTimeOperandInfo *value_;
- const RunTimeOperandInfo *lookup_;
+ private:
+ const RunTimeOperandInfo* value_;
+ const RunTimeOperandInfo* lookup_;
- RunTimeOperandInfo *output_;
+ RunTimeOperandInfo* output_;
};
} // namespace nn
diff --git a/nn/common/operations/EmbeddingLookupTest.cpp b/nn/common/operations/EmbeddingLookupTest.cpp
index d864ab73a..10e2e339a 100644
--- a/nn/common/operations/EmbeddingLookupTest.cpp
+++ b/nn/common/operations/EmbeddingLookupTest.cpp
@@ -31,13 +31,13 @@ namespace wrapper {
namespace {
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
- float max_abs_error=1.e-6) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
+ float max_abs_error = 1.e-6) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
}
} // namespace
@@ -45,109 +45,108 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
using ::testing::ElementsAreArray;
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
- ACTION(Value, float) \
- ACTION(Lookup, int)
+ ACTION(Value, float) \
+ ACTION(Lookup, int)
// For all output and intermediate states
-#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(Output, float)
+#define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, float)
class EmbeddingLookupOpModel {
- public:
- EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,
- std::initializer_list<uint32_t> weight_shape) {
- auto it = weight_shape.begin();
- rows_ = *it++;
- columns_ = *it++;
- features_ = *it;
+ public:
+ EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,
+ std::initializer_list<uint32_t> weight_shape) {
+ auto it = weight_shape.begin();
+ rows_ = *it++;
+ columns_ = *it++;
+ features_ = *it;
- std::vector<uint32_t> inputs;
+ std::vector<uint32_t> inputs;
- OperandType LookupTy(Type::TENSOR_INT32, index_shape);
- inputs.push_back(model_.addOperand(&LookupTy));
+ OperandType LookupTy(Type::TENSOR_INT32, index_shape);
+ inputs.push_back(model_.addOperand(&LookupTy));
- OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape);
- inputs.push_back(model_.addOperand(&ValueTy));
+ OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape);
+ inputs.push_back(model_.addOperand(&ValueTy));
- std::vector<uint32_t> outputs;
+ std::vector<uint32_t> outputs;
- OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape);
- outputs.push_back(model_.addOperand(&OutputOpndTy));
+ OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape);
+ outputs.push_back(model_.addOperand(&OutputOpndTy));
- auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
- uint32_t sz = 1;
- for (uint32_t d : dims) { sz *= d; }
- return sz;
- };
+ auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
+ uint32_t sz = 1;
+ for (uint32_t d : dims) {
+ sz *= d;
+ }
+ return sz;
+ };
- Value_.insert(Value_.end(), multiAll(weight_shape), 0.f);
- Output_.insert(Output_.end(), multiAll(weight_shape), 0.f);
+ Value_.insert(Value_.end(), multiAll(weight_shape), 0.f);
+ Output_.insert(Output_.end(), multiAll(weight_shape), 0.f);
- model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
+ model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs);
+ model_.identifyInputsAndOutputs(inputs, outputs);
- model_.finish();
- }
+ model_.finish();
+ }
- void Invoke() {
- ASSERT_TRUE(model_.isValid());
+ void Invoke() {
+ ASSERT_TRUE(model_.isValid());
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
+ Compilation compilation(&model_);
+ compilation.finish();
+ Execution execution(&compilation);
#define SetInputOrWeight(X, T) \
- ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
- sizeof(T) * X##_.size()), \
- Result::NO_ERROR);
+ ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
+ sizeof(T) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
#undef SetInputOrWeight
#define SetOutput(X, T) \
- ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
- sizeof(T) * X##_.size()), \
- Result::NO_ERROR);
+ ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
+ sizeof(T) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_OUTPUT_TENSORS(SetOutput);
+ FOR_ALL_OUTPUT_TENSORS(SetOutput);
#undef SetOutput
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- }
+ ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+ }
-#define DefineSetter(X, T) \
- void Set##X(const std::vector<T>& f) { \
- X##_.insert(X##_.end(), f.begin(), f.end()); \
- }
+#define DefineSetter(X, T) \
+ void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
#undef DefineSetter
- void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
- for (uint32_t i = 0; i < rows_; i++) {
- for (uint32_t j = 0; j < columns_; j++) {
- for (uint32_t k = 0; k < features_; k++) {
- Value_[(i * columns_ + j) * features_ + k] = function(i, j, k);
+ void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+ for (uint32_t i = 0; i < rows_; i++) {
+ for (uint32_t j = 0; j < columns_; j++) {
+ for (uint32_t k = 0; k < features_; k++) {
+ Value_[(i * columns_ + j) * features_ + k] = function(i, j, k);
+ }
+ }
}
- }
}
- }
- const std::vector<float> &GetOutput() const { return Output_; }
+ const std::vector<float>& GetOutput() const { return Output_; }
- private:
- Model model_;
- uint32_t rows_;
- uint32_t columns_;
- uint32_t features_;
+ private:
+ Model model_;
+ uint32_t rows_;
+ uint32_t columns_;
+ uint32_t features_;
#define DefineTensor(X, T) std::vector<T> X##_;
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
- FOR_ALL_OUTPUT_TENSORS(DefineTensor);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
+ FOR_ALL_OUTPUT_TENSORS(DefineTensor);
#undef DefineTensor
};
@@ -155,19 +154,17 @@ class EmbeddingLookupOpModel {
// TODO: write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
- EmbeddingLookupOpModel m({3}, {3, 2, 4});
- m.SetLookup({1, 0, 2});
- m.Set3DWeightMatrix(
- [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
-
- m.Invoke();
-
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray(ArrayFloatNear({
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
- })));
+ EmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.SetLookup({1, 0, 2});
+ m.Set3DWeightMatrix([](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ })));
}
} // namespace wrapper
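
The test fixture is generated almost entirely from X-macro lists: FOR_ALL_INPUT_AND_WEIGHT_TENSORS and FOR_ALL_OUTPUT_TENSORS apply an ACTION macro to each (name, type) pair, so the setters, the storage vectors, and the setInput/setOutput plumbing can never drift out of sync. A minimal reduction of the pattern:

#include <iostream>
#include <vector>

// X-macro list: applies ACTION to each (Name, type) pair.
#define FOR_ALL_TENSORS(ACTION) \
    ACTION(Value, float)        \
    ACTION(Lookup, int)

class Fixture {
  public:
    // One setter per tensor, generated from the list.
#define DEFINE_SETTER(X, T) \
    void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
    FOR_ALL_TENSORS(DEFINE_SETTER)
#undef DEFINE_SETTER

    void Dump() const {
        std::cout << "Value has " << Value_.size() << " entries, Lookup has " << Lookup_.size()
                  << " entries\n";
    }

  private:
    // One storage vector per tensor, generated from the same list.
#define DEFINE_TENSOR(X, T) std::vector<T> X##_;
    FOR_ALL_TENSORS(DEFINE_TENSOR)
#undef DEFINE_TENSOR
};

int main() {
    Fixture f;
    f.SetValue({1.0f, 2.0f});
    f.SetLookup({0});
    f.Dump();  // Value has 2 entries, Lookup has 1 entries
}
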
diff --git a/nn/common/operations/GenerateProposals.cpp b/nn/common/operations/GenerateProposals.cpp
index 67d614f6b..d70271504 100644
--- a/nn/common/operations/GenerateProposals.cpp
+++ b/nn/common/operations/GenerateProposals.cpp
@@ -480,10 +480,10 @@ bool boxWithNmsLimitFloat32Compute(float* scoresData, const Shape& scoresShape,
NN_RET_CHECK_LE(roi[1], roi[3]);
}
std::vector<uint32_t> result;
- softNmsMultiClass(scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold,
- nmsScoreThreshold, maxNumDetections, maxNumDetections,
- [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel,
- &result);
+ softNmsMultiClass(
+ scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold, nmsScoreThreshold,
+ maxNumDetections, maxNumDetections,
+ [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel, &result);
// Sort again by class.
std::sort(result.begin(), result.end(),
[&scoresBase, numClasses](const uint32_t& lhs, const uint32_t& rhs) {
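
The reformatted call hands softNmsMultiClass an index-to-ROI accessor lambda, and the std::sort that follows reorders the surviving score indices by class. Assuming the flattened class-major score layout implied by the surrounding code (class = index % numClasses), a small sketch of that sort key:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Indices into a flattened [numBoxes x numClasses] score array; the class
    // of a score index is index % numClasses (layout assumed for illustration).
    const uint32_t numClasses = 3;
    std::vector<uint32_t> result = {7, 2, 3, 5};

    // Stable sort keeps the per-class ordering produced by NMS intact.
    std::stable_sort(result.begin(), result.end(), [numClasses](uint32_t lhs, uint32_t rhs) {
        return lhs % numClasses < rhs % numClasses;
    });

    for (uint32_t i : result) std::cout << i << "(c" << i % numClasses << ") ";
    std::cout << '\n';  // prints: 3(c0) 7(c1) 2(c2) 5(c2)
}
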
diff --git a/nn/common/operations/HashtableLookup.cpp b/nn/common/operations/HashtableLookup.cpp
index 67cdffd2d..ed4de8e64 100644
--- a/nn/common/operations/HashtableLookup.cpp
+++ b/nn/common/operations/HashtableLookup.cpp
@@ -32,47 +32,46 @@ namespace {
using namespace hal;
int greater(const void* a, const void* b) {
- return *static_cast<const int*>(a) - *static_cast<const int*>(b);
+ return *static_cast<const int*>(a) - *static_cast<const int*>(b);
}
} // anonymous namespace
HashtableLookup::HashtableLookup(const Operation& operation,
std::vector<RunTimeOperandInfo>& operands) {
- lookup_ = GetInput(operation, operands, kLookupTensor);
- key_ = GetInput(operation, operands, kKeyTensor);
- value_ = GetInput(operation, operands, kValueTensor);
+ lookup_ = GetInput(operation, operands, kLookupTensor);
+ key_ = GetInput(operation, operands, kKeyTensor);
+ value_ = GetInput(operation, operands, kValueTensor);
- output_ = GetOutput(operation, operands, kOutputTensor);
- hits_ = GetOutput(operation, operands, kHitsTensor);
+ output_ = GetOutput(operation, operands, kOutputTensor);
+ hits_ = GetOutput(operation, operands, kHitsTensor);
}
bool HashtableLookup::Eval() {
- NNTRACE_COMP("HashtableLookup::Eval");
- const int num_rows = value_->shape().dimensions[0];
- const int row_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions) / num_rows;
- void* pointer = nullptr;
-
- for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) {
- int idx = -1;
- pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer,
- num_rows, sizeof(int), greater);
- if (pointer != nullptr) {
- idx =
- (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float);
+ NNTRACE_COMP("HashtableLookup::Eval");
+ const int num_rows = value_->shape().dimensions[0];
+ const int row_bytes =
+ nonExtensionOperandSizeOfData(value_->type, value_->dimensions) / num_rows;
+ void* pointer = nullptr;
+
+ for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) {
+ int idx = -1;
+ pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer, num_rows, sizeof(int),
+ greater);
+ if (pointer != nullptr) {
+ idx = (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float);
+ }
+
+ if (idx >= num_rows || idx < 0) {
+ memset(output_->buffer + i * row_bytes, 0, row_bytes);
+ hits_->buffer[i] = 0;
+ } else {
+ memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, row_bytes);
+ hits_->buffer[i] = 1;
+ }
}
- if (idx >= num_rows || idx < 0) {
- memset(output_->buffer + i * row_bytes, 0, row_bytes);
- hits_->buffer[i] = 0;
- } else {
- memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes,
- row_bytes);
- hits_->buffer[i] = 1;
- }
- }
-
- return true;
+ return true;
}
} // namespace nn
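
HashtableLookup::Eval is a sorted-array lookup in disguise: std::bsearch over the int key tensor with a subtraction comparator, pointer arithmetic to recover the row index, then either a row copy or a zero-fill with the hit flag cleared. (The index recovery in the code divides the byte offset by sizeof(float), which happens to work because sizeof(float) equals sizeof(int) on these targets.) A standalone sketch of the search step:

#include <cstdlib>
#include <iostream>

// Same idea as the `greater` comparator above: subtraction works here
// because the keys are small enough not to overflow int.
static int compareInts(const void* a, const void* b) {
    return *static_cast<const int*>(a) - *static_cast<const int*>(b);
}

// Returns the row index of `key` in the sorted array `keys`, or -1 on a miss.
int lookupRow(int key, const int* keys, int numRows) {
    const void* hit = std::bsearch(&key, keys, numRows, sizeof(int), compareInts);
    if (hit == nullptr) return -1;
    return static_cast<int>(static_cast<const int*>(hit) - keys);
}

int main() {
    const int keys[] = {-11, 0, 1234};  // must be sorted ascending
    for (int k : {1234, -292, -11, 0}) {
        const int row = lookupRow(k, keys, 3);
        if (row < 0)
            std::cout << "key " << k << " -> miss\n";
        else
            std::cout << "key " << k << " -> row " << row << '\n';
    }
}
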
diff --git a/nn/common/operations/HashtableLookup.h b/nn/common/operations/HashtableLookup.h
index 52d9cc5bc..854e7dff7 100644
--- a/nn/common/operations/HashtableLookup.h
+++ b/nn/common/operations/HashtableLookup.h
@@ -27,26 +27,26 @@ namespace nn {
struct RunTimeOperandInfo;
class HashtableLookup {
- public:
- HashtableLookup(const hardware::neuralnetworks::V1_2::Operation& operation,
- std::vector<RunTimeOperandInfo>& operands);
+ public:
+ HashtableLookup(const hardware::neuralnetworks::V1_2::Operation& operation,
+ std::vector<RunTimeOperandInfo>& operands);
- bool Eval();
+ bool Eval();
- static constexpr int kLookupTensor = 0;
- static constexpr int kKeyTensor = 1;
- static constexpr int kValueTensor = 2;
+ static constexpr int kLookupTensor = 0;
+ static constexpr int kKeyTensor = 1;
+ static constexpr int kValueTensor = 2;
- static constexpr int kOutputTensor = 0;
- static constexpr int kHitsTensor = 1;
+ static constexpr int kOutputTensor = 0;
+ static constexpr int kHitsTensor = 1;
- private:
- const RunTimeOperandInfo *lookup_;
- const RunTimeOperandInfo *key_;
- const RunTimeOperandInfo *value_;
+ private:
+ const RunTimeOperandInfo* lookup_;
+ const RunTimeOperandInfo* key_;
+ const RunTimeOperandInfo* value_;
- RunTimeOperandInfo *output_;
- RunTimeOperandInfo *hits_;
+ RunTimeOperandInfo* output_;
+ RunTimeOperandInfo* hits_;
};
} // namespace nn
diff --git a/nn/common/operations/HashtableLookupTest.cpp b/nn/common/operations/HashtableLookupTest.cpp
index 7fbab58ce..ff62006b8 100644
--- a/nn/common/operations/HashtableLookupTest.cpp
+++ b/nn/common/operations/HashtableLookupTest.cpp
@@ -31,158 +31,160 @@ namespace wrapper {
namespace {
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
- float max_abs_error=1.e-6) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
+ float max_abs_error = 1.e-6) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
}
} // namespace
using ::testing::ElementsAreArray;
-#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
- ACTION(Lookup, int) \
- ACTION(Key, int) \
- ACTION(Value, float)
+#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
+ ACTION(Lookup, int) \
+ ACTION(Key, int) \
+ ACTION(Value, float)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(Output, float) \
- ACTION(Hits, uint8_t)
+ ACTION(Output, float) \
+ ACTION(Hits, uint8_t)
class HashtableLookupOpModel {
- public:
+ public:
HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
std::initializer_list<uint32_t> key_shape,
std::initializer_list<uint32_t> value_shape) {
- auto it_vs = value_shape.begin();
- rows_ = *it_vs++;
- features_ = *it_vs;
+ auto it_vs = value_shape.begin();
+ rows_ = *it_vs++;
+ features_ = *it_vs;
- std::vector<uint32_t> inputs;
+ std::vector<uint32_t> inputs;
- // Input and weights
- OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
- inputs.push_back(model_.addOperand(&LookupTy));
+ // Input and weights
+ OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
+ inputs.push_back(model_.addOperand(&LookupTy));
- OperandType KeyTy(Type::TENSOR_INT32, key_shape);
- inputs.push_back(model_.addOperand(&KeyTy));
+ OperandType KeyTy(Type::TENSOR_INT32, key_shape);
+ inputs.push_back(model_.addOperand(&KeyTy));
- OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
- inputs.push_back(model_.addOperand(&ValueTy));
+ OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
+ inputs.push_back(model_.addOperand(&ValueTy));
- // Output and other intermediate state
- std::vector<uint32_t> outputs;
+ // Output and other intermediate state
+ std::vector<uint32_t> outputs;
- std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
- out_dim.push_back(features_);
+ std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
+ out_dim.push_back(features_);
- OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
- outputs.push_back(model_.addOperand(&OutputOpndTy));
+ OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
+ outputs.push_back(model_.addOperand(&OutputOpndTy));
- OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
- outputs.push_back(model_.addOperand(&HitsOpndTy));
+ OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
+ outputs.push_back(model_.addOperand(&HitsOpndTy));
- auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
- uint32_t sz = 1;
- for (uint32_t d : dims) { sz *= d; }
- return sz;
- };
+ auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
+ uint32_t sz = 1;
+ for (uint32_t d : dims) {
+ sz *= d;
+ }
+ return sz;
+ };
- Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
- Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
- Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
+ Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
+ Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
+ Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
- model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
+ model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
+ model_.identifyInputsAndOutputs(inputs, outputs);
- model_.finish();
- }
+ model_.finish();
+ }
- void Invoke() {
- ASSERT_TRUE(model_.isValid());
+ void Invoke() {
+ ASSERT_TRUE(model_.isValid());
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
+ Compilation compilation(&model_);
+ compilation.finish();
+ Execution execution(&compilation);
-#define SetInputOrWeight(X, T) \
- ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
- sizeof(T) * X##_.size()), \
- Result::NO_ERROR);
+#define SetInputOrWeight(X, T) \
+ ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
+ sizeof(T) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
#undef SetInputOrWeight
-#define SetOutput(X, T) \
- ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
- sizeof(T) * X##_.size()), \
- Result::NO_ERROR);
+#define SetOutput(X, T) \
+ ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
+ sizeof(T) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_OUTPUT_TENSORS(SetOutput);
+ FOR_ALL_OUTPUT_TENSORS(SetOutput);
#undef SetOutput
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- }
+ ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+ }
-#define DefineSetter(X, T) \
- void Set##X(const std::vector<T>& f) { \
- X##_.insert(X##_.end(), f.begin(), f.end()); \
- }
+#define DefineSetter(X, T) \
+ void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
#undef DefineSetter
- void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
- for (uint32_t i = 0; i < rows_; i++) {
- for (uint32_t j = 0; j < features_; j++) {
- Value_[i * features_ + j] = function(i, j);
- }
+ void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
+ for (uint32_t i = 0; i < rows_; i++) {
+ for (uint32_t j = 0; j < features_; j++) {
+ Value_[i * features_ + j] = function(i, j);
+ }
+ }
}
- }
- const std::vector<float>& GetOutput() const { return Output_; }
- const std::vector<uint8_t>& GetHits() const { return Hits_; }
+ const std::vector<float>& GetOutput() const { return Output_; }
+ const std::vector<uint8_t>& GetHits() const { return Hits_; }
- private:
- Model model_;
- uint32_t rows_;
- uint32_t features_;
+ private:
+ Model model_;
+ uint32_t rows_;
+ uint32_t features_;
#define DefineTensor(X, T) std::vector<T> X##_;
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
- FOR_ALL_OUTPUT_TENSORS(DefineTensor);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
+ FOR_ALL_OUTPUT_TENSORS(DefineTensor);
#undef DefineTensor
};
TEST(HashtableLookupOpTest, BlackBoxTest) {
- HashtableLookupOpModel m({4}, {3}, {3, 2});
-
- m.SetLookup({1234, -292, -11, 0});
- m.SetKey({-11, 0, 1234});
- m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
-
- m.Invoke();
-
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 2.0, 2.1, // 2-nd item
- 0, 0, // Not found
- 0.0, 0.1, // 0-th item
- 1.0, 1.1, // 1-st item
- })));
- EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
- 1, 0, 1, 1,
+ HashtableLookupOpModel m({4}, {3}, {3, 2});
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetKey({-11, 0, 1234});
+ m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 2.0, 2.1, // 2-nd item
+ 0, 0, // Not found
+ 0.0, 0.1, // 0-th item
+ 1.0, 1.1, // 1-st item
+ })));
+ EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
+ 1,
+ 0,
+ 1,
+ 1,
}));
-
}
} // namespace wrapper
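
Both lookup tests define the same ArrayFloatNear helper: one FloatNear matcher per expected element, so ElementsAreArray compares float vectors with a tolerance rather than exact equality. A minimal gtest/gmock reproduction (link against gtest_main to supply main):

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

// One FloatNear matcher per expected element, as in the NNAPI op tests.
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

TEST(ArrayFloatNearSketch, ToleratesSmallError) {
    std::vector<float> actual = {1.0000001f, 2.0f};
    EXPECT_THAT(actual, ElementsAreArray(ArrayFloatNear({1.0f, 2.0f})));
}
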
diff --git a/nn/common/operations/LSTMTest.cpp b/nn/common/operations/LSTMTest.cpp
index edbf12841..35e5eded9 100644
--- a/nn/common/operations/LSTMTest.cpp
+++ b/nn/common/operations/LSTMTest.cpp
@@ -32,60 +32,61 @@ using ::testing::Matcher;
namespace {
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
- float max_abs_error=1.e-6) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
+ float max_abs_error = 1.e-6) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
}
} // anonymous namespace
-#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
- ACTION(Input) \
- ACTION(InputToInputWeights) \
- ACTION(InputToCellWeights) \
- ACTION(InputToForgetWeights) \
- ACTION(InputToOutputWeights) \
- ACTION(RecurrentToInputWeights) \
- ACTION(RecurrentToCellWeights) \
- ACTION(RecurrentToForgetWeights) \
- ACTION(RecurrentToOutputWeights) \
- ACTION(CellToInputWeights) \
- ACTION(CellToForgetWeights) \
- ACTION(CellToOutputWeights) \
- ACTION(InputGateBias) \
- ACTION(CellGateBias) \
- ACTION(ForgetGateBias) \
- ACTION(OutputGateBias) \
- ACTION(ProjectionWeights) \
- ACTION(ProjectionBias) \
- ACTION(OutputStateIn) \
+#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
+ ACTION(Input) \
+ ACTION(InputToInputWeights) \
+ ACTION(InputToCellWeights) \
+ ACTION(InputToForgetWeights) \
+ ACTION(InputToOutputWeights) \
+ ACTION(RecurrentToInputWeights) \
+ ACTION(RecurrentToCellWeights) \
+ ACTION(RecurrentToForgetWeights) \
+ ACTION(RecurrentToOutputWeights) \
+ ACTION(CellToInputWeights) \
+ ACTION(CellToForgetWeights) \
+ ACTION(CellToOutputWeights) \
+ ACTION(InputGateBias) \
+ ACTION(CellGateBias) \
+ ACTION(ForgetGateBias) \
+ ACTION(OutputGateBias) \
+ ACTION(ProjectionWeights) \
+ ACTION(ProjectionBias) \
+ ACTION(OutputStateIn) \
ACTION(CellStateIn)
// For all output and intermediate states
-#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(ScratchBuffer) \
- ACTION(OutputStateOut) \
- ACTION(CellStateOut) \
- ACTION(Output) \
+#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
+ ACTION(ScratchBuffer) \
+ ACTION(OutputStateOut) \
+ ACTION(CellStateOut) \
+ ACTION(Output)
class LSTMOpModel {
-public:
- LSTMOpModel(uint32_t n_batch, uint32_t n_input,
- uint32_t n_cell, uint32_t n_output, bool use_cifg,
- bool use_peephole, bool use_projection_weights,
+ public:
+ LSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
+ bool use_cifg, bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<uint32_t>>& input_shapes0)
: n_input_(n_input),
n_output_(n_output),
- use_cifg_(use_cifg), use_peephole_(use_peephole),
+ use_cifg_(use_cifg),
+ use_peephole_(use_peephole),
use_projection_weights_(use_projection_weights),
use_projection_bias_(use_projection_bias),
activation_(ActivationFn::kActivationTanh),
- cell_clip_(cell_clip), proj_clip_(proj_clip) {
+ cell_clip_(cell_clip),
+ proj_clip_(proj_clip) {
std::vector<uint32_t> inputs;
std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
@@ -94,9 +95,9 @@ public:
auto it = input_shapes.begin();
// Input and weights
-#define AddInput(X) \
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
- inputs.push_back(model_.addOperand(&X##OpndTy));
+#define AddInput(X) \
+ OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
+ inputs.push_back(model_.addOperand(&X##OpndTy));
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
@@ -112,18 +113,18 @@ public:
// Output and other intermediate state
std::vector<std::vector<uint32_t>> output_shapes{
- {n_batch, n_cell * (use_cifg ? 3 : 4)},
- {n_batch, n_output},
- {n_batch, n_cell},
- {n_batch, n_output},
+ {n_batch, n_cell * (use_cifg ? 3 : 4)},
+ {n_batch, n_output},
+ {n_batch, n_cell},
+ {n_batch, n_output},
};
std::vector<uint32_t> outputs;
auto it2 = output_shapes.begin();
-#define AddOutput(X)\
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
- outputs.push_back(model_.addOperand(&X##OpndTy));
+#define AddOutput(X) \
+ OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
+ outputs.push_back(model_.addOperand(&X##OpndTy));
FOR_ALL_OUTPUT_TENSORS(AddOutput);
@@ -136,9 +137,11 @@ public:
OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
- auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
+ auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
uint32_t sz = 1;
- for(uint32_t d:dims) { sz *= d; }
+ for (uint32_t d : dims) {
+ sz *= d;
+ }
return sz;
};
@@ -153,10 +156,8 @@ public:
model_.finish();
}
-#define DefineSetter(X) \
- void Set##X(const std::vector<float> &f) { \
- X##_.insert(X##_.end(), f.begin(), f.end()); \
- }
+#define DefineSetter(X) \
+ void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
@@ -172,8 +173,8 @@ public:
std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
}
- void SetInput(int offset, float *begin, float *end) {
- for (;begin != end; begin++, offset++) {
+ void SetInput(int offset, float* begin, float* end) {
+ for (; begin != end; begin++, offset++) {
Input_[offset] = *begin;
}
}
@@ -181,7 +182,7 @@ public:
uint32_t num_inputs() const { return n_input_; }
uint32_t num_outputs() const { return n_output_; }
- const std::vector<float> &GetOutput() const { return Output_; }
+ const std::vector<float>& GetOutput() const { return Output_; }
void Invoke() {
ASSERT_TRUE(model_.isValid());
@@ -192,19 +193,19 @@ public:
Compilation compilation(&model_);
compilation.finish();
Execution execution(&compilation);
-#define SetInputOrWeight(X) \
- ASSERT_EQ(execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), \
- sizeof(float)*X##_.size()), \
- Result::NO_ERROR);
+#define SetInputOrWeight(X) \
+ ASSERT_EQ( \
+ execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
#undef SetInputOrWeight
-#define SetOutput(X) \
- ASSERT_EQ(execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), \
- sizeof(float)*X##_.size()), \
- Result::NO_ERROR);
+#define SetOutput(X) \
+ ASSERT_EQ( \
+ execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
FOR_ALL_OUTPUT_TENSORS(SetOutput);
@@ -234,20 +235,17 @@ public:
execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
}
- ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam,
- &activation_, sizeof(activation_)),
+ ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
Result::NO_ERROR);
- ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam,
- &cell_clip_, sizeof(cell_clip_)),
+ ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
Result::NO_ERROR);
- ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam,
- &proj_clip_, sizeof(proj_clip_)),
+ ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
Result::NO_ERROR);
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
}
-private:
+ private:
Model model_;
// Execution execution_;
const uint32_t n_input_;
@@ -262,8 +260,7 @@ private:
const float cell_clip_;
const float proj_clip_;
-#define DefineTensor(X) \
- std::vector<float> X##_;
+#define DefineTensor(X) std::vector<float> X##_;
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
FOR_ALL_OUTPUT_TENSORS(DefineTensor);
@@ -272,843 +269,741 @@ private:
};
TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
- const int n_batch = 1;
- const int n_input = 2;
- // n_cell and n_output have the same size when there is no projection.
- const int n_cell = 4;
- const int n_output = 4;
-
- LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/false, /*use_peephole=*/false,
- /*use_projection_weights=*/false,
- /*use_projection_bias=*/false,
- /*cell_clip=*/0.0, /*proj_clip=*/0.0,
- {
- {n_batch, n_input}, // input tensor
-
- {n_cell, n_input}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {0}, // cell_to_input_weight tensor
- {0}, // cell_to_forget_weight tensor
- {0}, // cell_to_output_weight tensor
-
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {0, 0}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
-
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
-
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
-
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
-
- lstm.SetInputGateBias({0., 0., 0., 0.});
-
- lstm.SetCellGateBias({0., 0., 0., 0.});
-
- lstm.SetForgetGateBias({1., 1., 1., 1.});
-
- lstm.SetOutputGateBias({0., 0., 0., 0.});
-
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
-
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
-
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
-
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- lstm.Invoke();
-
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912,
+ -0.15680569, -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, -0.20583314,
+ 0.44344562, 0.22077113, -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, -0.31343272, -0.40032279,
+ 0.44781327, 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 0.40525138, 0.44272184,
+ 0.03897077, -0.1556896, 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellGateBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights({-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304,
+ 0.08183324, -0.16555229, 0.02286911, -0.13566875, 0.03034258,
+ 0.48091322, -0.12528998, 0.24077177, -0.51332325, -0.33502164,
+ 0.10629296});
+
+ lstm.SetRecurrentToCellWeights({-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659,
+ -0.00123841, -0.4744786, -0.35869038, -0.06418842, -0.13502428,
+ -0.501764, 0.22830659, -0.46367589, 0.26016325, -0.03894562,
+ -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights({-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213,
+ 0.20864892, -0.07646349, 0.45877004, 0.00141793, -0.14609534,
+ 0.36447752, 0.09196436, 0.28053468, 0.01560611, -0.20127171,
+ -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights({0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647,
+ -0.39835793, 0.18212086, 0.01301402, 0.48572797, -0.50656658,
+ 0.20047462, -0.20607421, -0.51818722, -0.15390486, 0.0468148,
+ 0.39922136});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size = sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
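
Every LSTM black-box test below uses the same driver: reset the recurrent state, then feed a static input sequence one batch at a time and compare each step's output against a golden slice. A condensed, runnable sketch of that loop, with a stubbed step function standing in for lstm.Invoke():

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for one LSTM step: the real test calls lstm.Invoke() and reads
// lstm.GetOutput(); this stub echoes the input so the harness is runnable.
std::vector<float> stepStub(const std::vector<float>& input, size_t numOutputs) {
    return std::vector<float>(input.begin(), input.begin() + numOutputs);
}

int main() {
    const size_t numInputs = 2, numOutputs = 2;
    static const float input[] = {2.f, 3.f, 3.f, 4.f, 1.f, 1.f};
    static const float golden[] = {2.f, 3.f, 3.f, 4.f, 1.f, 1.f};  // matches the stub

    const size_t steps = sizeof(input) / sizeof(float) / numInputs;
    for (size_t i = 0; i < steps; ++i) {
        std::vector<float> batch(input + i * numInputs, input + (i + 1) * numInputs);
        std::vector<float> out = stepStub(batch, numOutputs);

        // Compare against the golden slice for this step.
        for (size_t j = 0; j < numOutputs; ++j) {
            if (out[j] != golden[i * numOutputs + j]) {
                std::cout << "step " << i << " mismatch at " << j << '\n';
                return 1;
            }
        }
    }
    std::cout << "all " << steps << " steps matched\n";
}
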
-
TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
- const int n_batch = 1;
- const int n_input = 2;
- // n_cell and n_output have the same size when there is no projection.
- const int n_cell = 4;
- const int n_output = 4;
-
- LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/true, /*use_peephole=*/true,
- /*use_projection_weights=*/false,
- /*use_projection_bias=*/false,
- /*cell_clip=*/0.0, /*proj_clip=*/0.0,
- {
- {n_batch, n_input}, // input tensor
-
- {0, 0}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {0, 0}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {0}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {0, 0}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
-
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
-
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
-
- lstm.SetCellGateBias({0., 0., 0., 0.});
-
- lstm.SetForgetGateBias({1., 1., 1., 1.});
-
- lstm.SetOutputGateBias({0., 0., 0., 0.});
-
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
-
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
-
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
-
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- lstm.Invoke();
-
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243,
+ 0.48944736, -0.38535351, -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, -0.3633365, -0.22755712,
+ 0.28253698, 0.24407166, 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, -0.09426838, -0.44257352,
+ 0.54939759, 0.01533556, 0.42751634});
+
+ lstm.SetCellGateBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights({0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights({-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671,
+ 0.17751795, -0.34484994, -0.35874045, -0.11352962, 0.27268326,
+ 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights({0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174,
+ -0.05115908, -0.33941799, 0.23364776, 0.11178309, 0.09481031,
+ -0.26424935, 0.46261835, 0.50248802, 0.26114327, -0.43736315,
+ 0.33149987});
+
+ lstm.SetCellToForgetWeights({0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights({-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size = sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
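For reference, both black-box tests in this hunk share one drive pattern: inject weights, reset state, then feed a timestep-sized slice of a flattened input array per Invoke() and compare against the matching slice of golden data (the projection test below simply repeats the slicing once per batch). A minimal standalone sketch of that slice arithmetic, reusing the sizes of the CIFG test above (an illustration only, not part of the formatted file):

#include <cassert>

int main() {
    const int n_input = 2;  // n_input of the CIFG test above
    static const float input[] = {2.f, 3.f, 3.f, 4.f, 1.f, 1.f};  // lstm_input
    const int input_sequence_size = sizeof(input) / sizeof(float) / n_input;
    assert(input_sequence_size == 3);
    for (int i = 0; i < input_sequence_size; ++i) {
        const float* begin = input + i * n_input;  // timestep i, all inputs
        const float* end = begin + n_input;
        // lstm.SetInput(0, begin, end) would consume exactly this window; the
        // golden comparison walks lstm_golden_output the same way, with
        // num_outputs() in place of n_input.
        assert(end <= input + sizeof(input) / sizeof(float));
    }
    return 0;
}
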
TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
- const int n_batch = 2;
- const int n_input = 5;
- const int n_cell = 20;
- const int n_output = 16;
-
- LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/false, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false,
- /*cell_clip=*/0.0, /*proj_clip=*/0.0,
- {
- {n_batch, n_input}, // input tensor
-
- {n_cell, n_input}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {n_cell}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellGateBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- const int input_sequence_size =
- sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
-
- lstm.Invoke();
-
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(
+ {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, 0.09171803,
+ 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, -0.2726754, 0.10154029,
+ -0.018539885, 0.080349885, -0.10262385, -0.022599787, -0.09121155, -0.008675967,
+ -0.045206103, -0.0821282, -0.008045952, 0.015478081, 0.055217247, 0.038719587,
+ 0.044153627, -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, 0.25005487,
+ -0.22790983, 0.009855087, -0.028140958, -0.11200698, 0.11295408, -0.0035217577,
+ 0.054485075, 0.05184695, 0.064711206, 0.10989193, 0.11674786, 0.03490607,
+ 0.07727357, 0.11390585, -0.1863375, -0.1034451, -0.13945189, -0.049401227,
+ -0.18767063, 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, -0.042484224,
+ -0.11827596, -0.09171104, -0.10808628, -0.16327988, -0.2273378, -0.0993647,
+ -0.017155107, 0.0023917493, 0.049272764, 0.0038534778, 0.054764505, 0.089753784,
+ 0.06947234, 0.08014476, -0.04544234, -0.0497073, -0.07135631, -0.048929106,
+ -0.004042012, -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, -0.39292613,
+ -0.18519334, -0.11651281, -0.06809892, 0.011373677});
+
+ lstm.SetInputToForgetWeights(
+ {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, -0.016726194,
+ -0.05249759, -0.10204261, 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567, -0.01909731, -0.017883534,
+ -0.0047269356, -0.045103323, 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
+ -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059, 0.06958324, 0.034257296,
+ 0.0482646, 0.06267997, 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294, -0.06277783, -0.08833636,
+ -0.040120605, -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, -0.08402166,
+ -0.01901462, -0.044678304, -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491, 0.052960336, 0.035755895,
+ 0.035839386, -0.012540553, 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
+ -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
+
+ lstm.SetInputToCellWeights(
+ {-0.04580283, -0.09549462, -0.032418985, -0.06454633, -0.043528453, 0.043018587,
+ -0.049152344, -0.12418144, -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, -0.025034338, -0.0028890965,
+ 0.048929527, 0.06235075, 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239, 0.0047691227, -0.0025825808,
+ 0.066017866, 0.029991534, -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339, -0.025174323, 0.0396852,
+ 0.081777506, 0.06157468, 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598, 0.053568836, 0.06408714,
+ 0.12835667, -0.008714329, -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786, -0.036999565, -0.028842626,
+ -0.0033637602, -0.017012902, -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112, -0.021742892, -0.023377212,
+ -0.07221364, -0.06430552, 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911, 0.07463894, 0.0075130584,
+ 0.012850982, 0.04555431, 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
+
+ lstm.SetInputToOutputWeights(
+ {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, -0.07650751,
+ 0.02359855, -0.075155355, -0.08037709, -0.15093534, 0.029517552, -0.04751393,
+ 0.010350531, -0.02664851, -0.016839722, -0.023121163, 0.0077019283, 0.012851257,
+ -0.05040649, -0.0129761, -0.021737747, -0.038305793, -0.06870586, -0.01481247,
+ -0.001285394, 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, -0.086665764,
+ -0.037162706, -0.038880914, -0.035832845, -0.014481564, -0.09825003, -0.12048569,
+ -0.097665586, -0.05287633, -0.0964047, -0.11366429, 0.035777505, 0.13568819,
+ 0.052451383, 0.050649304, 0.05798951, -0.021852335, -0.099848844, 0.014740475,
+ -0.078897946, 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, -0.078907564,
+ -0.06707616, -0.11844508, -0.09986688, -0.07509403, 0.06263226, 0.14925587,
+ 0.20188436, 0.12098451, 0.14639415, 0.0015017595, -0.014267382, -0.03417257,
+ 0.012711468, 0.0028300495, -0.024758482, -0.05098548, -0.0821182, 0.014225672,
+ 0.021544158, 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, -0.049097303,
+ -0.017121866, -0.083368234, -0.02332002, -0.0840956});
+
+ lstm.SetInputGateBias({0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
+ -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
+ -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
+ 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
+
+ lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, 0.11098921,
+ 0.15378423, 0.09263801, 0.09790885, 0.09508917, 0.061199076,
+ 0.07665568, -0.015443159, -0.03499149, 0.046190713, 0.08895977,
+ 0.10899629, 0.40694186, 0.06030037, 0.012413437, -0.06108739});
+
+ lstm.SetCellGateBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, -0.1483596,
+ -0.10639995, -0.091433935, 0.058573797, -0.06809782, -0.07889636,
+ -0.043246906, -0.09829136, -0.4279842, 0.034901652, 0.18797937,
+ 0.0075234566, 0.016178843, 0.1749513, 0.13975595, 0.92058027});
+
+ lstm.SetOutputGateBias({0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
+ 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
+ 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
+ -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.001374326, -0.078856036, 0.10672688, 0.029162422, -0.11585556,
+ 0.02557986, -0.13446963, -0.035785314, -0.01244275, 0.025961924,
+ -0.02337298, -0.044228926, -0.055839065, -0.046598054, -0.010546039,
+ -0.06900766, 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011, -0.14390695,
+ -0.02916037, 0.000996957, 0.091420636, 0.14283475, -0.07390571,
+ -0.06402044, 0.062524505, -0.093129106, 0.04860203, -0.08364217,
+ -0.08119002, 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048, 0.12162708,
+ -0.031923793, -0.014335606, 0.01790974, -0.10650317, -0.0724401,
+ 0.08554849, -0.05727212, 0.06556731, -0.042729504, -0.043227166,
+ 0.011683251, -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106, -0.07787477,
+ -0.11576462, 0.017356863, 0.048673786, -0.017577527, -0.05527947,
+ -0.082487635, -0.040137455, -0.10820036, -0.04666372, 0.022746278,
+ -0.07851417, 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705, 0.02032331,
+ -0.059686817, -0.0005566496, -0.086984694, 0.040414046, -0.1380399,
+ 0.094208956, -0.05722982, 0.012092817, -0.04989123, -0.086576,
+ -0.003399834, -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843, 0.09504992,
+ 0.041799378, -0.049185462, -0.031518843, -0.10516937, 0.026374253,
+ 0.10058866, -0.0033195973, -0.041975245, 0.0073591834, 0.0033782164,
+ -0.004325073, -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335, -0.011337001,
+ 0.035530265, -0.010912711, 0.0706555, -0.005894094, 0.051841937,
+ -0.1401738, -0.02351249, 0.0365468, 0.07590991, 0.08838724,
+ 0.021681072, -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986, -0.048691444,
+ -0.009579111, 0.07595467, 0.11480546, -0.09801813, 0.019894179,
+ 0.08502348, 0.004032281, 0.037211012, 0.068537936, -0.048005626,
+ -0.091520436, -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261, 0.010889619,
+ 0.0047078193, 0.038385306, 0.08540671, -0.017140968, -0.0035865551,
+ 0.016678626, 0.005633034, 0.015963363, 0.00871737, 0.060130805,
+ 0.028611384, 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358, 0.030737216,
+ -0.0046374933, 0.14215417, -0.11823516, 0.019899689, 0.006106124,
+ -0.027092824, 0.0786356, 0.05052217, -0.058925, -0.011402121,
+ -0.024987547, -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559, -0.033664223,
+ -0.07978348, -0.025200296, -0.017207067, -0.058403496, -0.055697463,
+ 0.005798788, 0.12965427, -0.062582195, 0.0013350133, -0.10482091,
+ 0.0379771, 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066, -0.017081132,
+ 0.019358726, 0.0027079724, 0.004635139, 0.062634714, -0.02338735,
+ -0.039547626, -0.02050681, 0.03385117, -0.083611414, 0.002862572,
+ -0.09421313, 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887, -0.07314807,
+ -0.056307215, -0.10433547, -0.06440842, 0.04328182, 0.04389765,
+ -0.020006588, -0.09076438, -0.11652589, -0.021705797, 0.03345259,
+ -0.010329105, -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465, 0.1305557,
+ 0.058638252, -0.03393652, 0.09622831, -0.16253184, -2.4580743e-06,
+ 0.079869635, -0.070196845, -0.005644518, 0.06857898, -0.12598175,
+ -0.035084512, 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372, 0.040170413,
+ -0.062104587, -0.0037324072, 0.0554317, 0.08184801, -0.019164372,
+ 0.06791302, 0.034257166, -0.10307039, 0.021943003, 0.046745934,
+ 0.0790918, -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321, -0.014512694,
+ -0.08251313, 0.08861942, 0.13589665, 0.026351685, 0.012641483,
+ 0.07466548, 0.044301085, -0.045414884, -0.051112458, 0.03444247,
+ -0.08502782, -0.04106223, -0.028126027, 0.028473156, 0.10467447});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.057784554, -0.026057621, -0.068447545, -0.022581743, 0.14811787,
+ 0.10826372, 0.09471067, 0.03987225, -0.0039523416, 0.00030638507,
+ 0.053185795, 0.10572994, 0.08414449, -0.022036452, -0.00066928595,
+ -0.09203576, 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116, -0.06193199,
+ 0.055729095, 0.03736828, 0.020123724, 0.061878487, -0.04729229,
+ 0.034919553, -0.07585433, -0.04421272, -0.044019096, 0.085488975,
+ 0.04058006, -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368, -0.04899627,
+ 0.0241671, 0.015736353, -0.095442444, -0.029564252, 0.016493602,
+ -0.035026584, 0.022337519, -0.026871363, 0.004780428, 0.0077918363,
+ -0.03601621, 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644, -0.046250615,
+ -0.01847454, 0.047608484, 0.07339695, 0.034546845, -0.04881143,
+ 0.009128804, -0.08802852, 0.03761666, 0.008096139, -0.014454086,
+ 0.014361001, -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793, 0.060212336,
+ 0.055259194, 0.06974018, 0.049454916, -0.027794661, -0.08077226,
+ -0.016179763, 0.1169753, 0.17213494, -0.0056326236, -0.053934924,
+ -0.0124349, -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196, -0.05695512,
+ 0.047233116, 0.038937137, -0.06542224, 0.014429736, -0.09719407,
+ 0.13908425, -0.05379757, 0.012321099, 0.082840554, -0.029899208,
+ 0.044217527, 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985, -0.13873616,
+ 0.040668588, 0.034832682, -0.015319203, -0.018715994, 0.046002675,
+ 0.0599172, -0.043107376, 0.0294216, -0.002314414, -0.022424703,
+ 0.0030315618, 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465, 0.052958444,
+ 0.07558703, 0.04817258, 0.044462286, -0.015213451, -0.08783778,
+ -0.0561384, -0.003008196, 0.047060397, -0.002058388, 0.03429439,
+ -0.018839769, 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, -0.02558259,
+ -0.022822596, -0.023273505, -0.02464396, -0.10991725, -0.006240552,
+ 0.0074488563, 0.024044557, 0.04383914, -0.046476185, 0.028658995,
+ 0.060410924, 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, 0.015898481,
+ 0.021362653, -0.030262267, 0.016587038, -0.011442813, 0.041154444,
+ -0.007631438, -0.03423484, -0.010977775, 0.036152758, 0.0066366293,
+ 0.11915515, 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105, -0.03672871,
+ 0.024019798, 0.014255957, -0.05221243, -0.00661567, -0.04630967,
+ 0.033188973, 0.10107534, -0.014027541, 0.030796422, -0.10270911,
+ -0.035999842, 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351, -0.03858649,
+ 0.01849943, 0.13872518, 0.01503974, 0.069941424, -0.06948533,
+ -0.0088794185, 0.061282158, -0.047401894, 0.03100163, -0.041533746,
+ -0.10430945, 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947, 0.02247216,
+ -0.0042998926, 0.061146557, -0.10250651, 0.020881841, -0.06747029,
+ 0.10062043, -0.0023941975, 0.03532124, -0.016341697, 0.09685456,
+ -0.016764693, 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895, 0.010987891,
+ 0.02357273, -0.00010856845, 0.0099760275, -0.001845119, -0.03551521,
+ 0.0018358806, 0.05763657, -0.01769146, 0.040995963, 0.02235177,
+ -0.060430344, 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733, -0.030519066,
+ 0.0060542435, 0.014653856, -0.038836084, 0.04096551, 0.032249358,
+ -0.08355519, -0.026823482, 0.056386515, -0.010401743, -0.028396193,
+ 0.08507674, 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837, -0.081302024,
+ 0.017264642, -0.009585969, 0.09491168, -0.051313367, 0.054532815,
+ -0.014298593, 0.10657464, 0.007076659, 0.10964551, 0.0409152,
+ 0.008275321, -0.07283536, 0.07937492, 0.04192024, -0.1075027});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, 0.055647098,
+ -0.05713207, -0.05626563, 0.005559383, 0.03375411, -0.025757805,
+ -0.088049285, 0.06017052, -0.06570978, 0.007384076, 0.035123326,
+ -0.07920549, 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287, -0.027673481,
+ 0.044746667, 0.028349208, 0.020090483, -0.019443132, -0.030755889,
+ -0.0040000007, 0.04465846, -0.021585021, 0.0031670958, 0.0053199246,
+ -0.056117613, -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149, -0.046195522,
+ 0.0422062, -0.005683705, -0.1253618, -0.012925729, -0.04890792,
+ 0.06985068, 0.037654128, 0.03398274, -0.004781977, 0.007032333,
+ -0.031787455, 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466, -0.048885044,
+ -0.12722108, 0.035304096, 0.06554885, 0.00972396, -0.039238118,
+ -0.05159735, -0.11329045, 0.1613692, -0.03750952, 0.06529313,
+ -0.071974665, -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344, 0.015983658,
+ 0.03406988, -0.06939408, 0.040699873, 0.02111075, 0.09669095,
+ 0.041345075, -0.08316494, -0.07684199, -0.045768797, 0.032298047,
+ -0.041805092, 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324, 0.06760663,
+ -0.027437469, 0.07216407, 0.06977076, -0.05438599, 0.034033038,
+ -0.028602652, 0.05346137, 0.043184172, -0.037189785, 0.10420091,
+ 0.00882477, -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, 0.04361412,
+ -0.007001822, 0.09631092, -0.06702025, -0.042049985, -0.035070654,
+ -0.04103342, -0.10273396, 0.0544271, 0.037184782, -0.13150354,
+ -0.0058036847, -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513, -0.00093483756,
+ 0.048938446, -0.004952862, -0.007730018, -0.04043371, -0.017094059,
+ 0.07229206, -0.023670016, -0.052195564, -0.025616996, -0.01520939,
+ 0.045104615, -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856, 0.11868599,
+ 0.07663395, -0.07323171, 0.03463402, -0.050708205, -0.04458982,
+ -0.11590894, 0.021273347, 0.1251325, -0.15313013, -0.12224372,
+ 0.17228661, 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414, -0.006041347,
+ -0.09292539, -0.05678812, 0.03897832, 0.09465633, 0.008115513,
+ -0.02171956, 0.08304309, 0.071401566, 0.019622514, 0.032163795,
+ -0.004167056, 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207, -0.1335546,
+ -0.030136576, 0.11584653, -0.014678886, 0.0020118146, -0.09688814,
+ -0.0790206, 0.039770417, -0.0329582, 0.07922767, 0.029322514,
+ 0.026405897, 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226, -0.09606514,
+ -0.006669445, -0.043381985, 0.04240257, -0.06955775, -0.06769346,
+ 0.043903265, -0.026784198, -0.017840602, 0.024307009, -0.040079936,
+ -0.019946516, 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589, -0.039390966,
+ -0.072174534, 0.0739445, -0.1211869, -0.0347889, -0.07943156,
+ 0.014809798, -0.12412325, -0.0030663363, 0.039695457, 0.0647603,
+ -0.08291318, -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386, 0.012962922,
+ -0.031234352, 0.07029052, 0.016418684, 0.03618972, 0.055686004,
+ -0.08663945, -0.017404709, -0.054761406, 0.029065743, 0.052404847,
+ 0.020238016, 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778, -0.0270329,
+ -0.03299152, -0.060088247, -0.015162964, -0.001828936, 0.12642565,
+ -0.056757294, 0.013586685, 0.09232601, -0.035886683, 0.06000002,
+ 0.05229691, -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139, 0.031502828,
+ 0.036232427, -0.031581745, 0.023051167, -0.05325106, -0.03421577,
+ 0.028793324, -0.034633752, -0.009881397, -0.043551125, -0.018609839,
+ 0.0019097115, -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
+
+ lstm.SetRecurrentToOutputWeights({
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
+ -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
+ -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
+ -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
+ -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
+ -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
+ 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
+ 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
+ -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
+ -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
+ 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
+ -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
+ 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
+ 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
+ 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
+ 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
+ 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
+ -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
+ 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
+ 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
+ -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
+ -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
+ -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
+ -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
+ -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
+ 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
+ -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
+ 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
+ -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
+ -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
+ 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
+ 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
+ -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
+ 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
+ -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
+ -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
+ -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
+ -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
+ 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
+ -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
+ -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
+ -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
+ 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
+ -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
+ 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
+ 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
+ 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
+ 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
+ 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ });
+
+ lstm.SetCellToInputWeights({0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
+
+ lstm.SetCellToForgetWeights({-0.01998659, -0.15568835, -0.24248174, -0.012770197,
+ 0.041331276, -0.072311886, -0.052123554, -0.0066330447,
+ -0.043891653, 0.036225766, -0.047248036, 0.021479502,
+ 0.033189066, 0.11952997, -0.020432774, 0.64658105,
+ -0.06650122, -0.03467612, 0.095340036, 0.23647355});
+
+ lstm.SetCellToOutputWeights({0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
+
+ lstm.SetProjectionWeights(
+ {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, 0.060420845,
+ 0.08539281, 0.054285463, 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018, -0.10959692, -0.013875541,
+ 0.18600968, -0.061274476, 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
+ 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285, -0.188994, -0.027265169,
+ 0.043389652, -0.022061434, 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664, 0.13221376, 0.009441198,
+ -0.16715929, 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, 0.050794356,
+ 0.10770313, -0.20790008, -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461, -0.007487823, 0.11000021,
+ -0.09542012, 0.22619456, -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
+ -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165, -0.004618631, -0.067624845,
+ -0.007576253, 0.10727444, 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581, 0.13708383, 0.09134148,
+ -0.12617786, -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, -0.042050496,
+ 0.16842307, -0.060597885, 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757, 0.05970546, 0.05197996,
+ 0.02839995, 0.10434969, -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
+ -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723, 0.07830115, -0.056184657,
+ 0.03277091, 0.025466874, 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737, 0.1776665, -0.07409566,
+ -0.0477268, 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, 0.045594677,
+ 0.0635285, -0.0715442, -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807, 0.09648481, 0.026996298,
+ -0.06419476, -0.04752702, -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
+ -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709, 0.10539724, -0.04383912,
+ -0.042349473, 0.08438151, -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982, 0.12633002, 0.11335965,
+ -0.0088127935, -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, -0.0855457,
+ 0.099339016, -0.07580735, -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874, 0.008753825, 0.03498998,
+ 0.030215185, 0.03907079, 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
+ -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943, -0.07051762, 0.19488361,
+ 0.02664691, -0.18104725, -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008, 0.11156489, -0.07483202,
+ 0.06693364, -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, 0.010969227,
+ 0.11109743, 0.010919218, 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654, -0.07000971, 0.026893595,
+ -0.038844477, 0.14543656});
+
+ static float lstm_input[][20] = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
+ 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
+ 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
+ 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
+ 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
+
+ static float lstm_golden_output[][64] = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, -0.0211779,
+ 0.0283512, -0.0114597, 0.00907307, -0.0244004, -0.0152191, -0.0259063,
+ 0.00914318, 0.00415118, 0.017147, 0.0134203, -0.0166936, 0.0381209,
+ 0.000889694, 0.0143363, -0.0328911, -0.0234288, 0.0333051, -0.012229,
+ 0.0110322, -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, 0.0276012,
+ -0.0263374, -0.0371449, 0.0446149, -0.0205474, 0.0103729, -0.0576349,
+ -0.0150052, -0.0292043, 0.0376827, 0.0136115, 0.0243435, 0.0354492,
+ -0.0189322, 0.0464512, -0.00251373, 0.0225745, -0.0308346, -0.0317124,
+ 0.0460407, -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, -0.0186926,
+ 0.0193662, -0.0115437, 0.00422612, -0.0345232, 0.00223253, -0.00957321,
+ 0.0210624, 0.013331, 0.0150954, 0.02168, -0.0141913, 0.0322082,
+ 0.00227024, 0.0260507, -0.0188721, -0.0296489, 0.0399134, -0.0160509,
+ 0.0116039, -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, 0.0167673,
+ -0.0375007, -0.0238314, 0.038784, -0.0174034, 0.0131743, -0.0506589,
+ -0.0048447, -0.0240239, 0.0325789, 0.00790065, 0.0220157, 0.0333314,
+ -0.0264787, 0.0387855, -0.000764675, 0.0217599, -0.037537, -0.0335206,
+ 0.0431679, -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size = sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
+ float* batch1_end = batch1_start + lstm.num_inputs();
+ lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
+
+ lstm.Invoke();
+
+ float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
+ float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
+ float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
+ float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
+ expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
-
} // namespace wrapper
} // namespace nn
} // namespace android
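The loop closing the LSTM test above derives input_sequence_size from the row width of lstm_input, so the per-test constants have to stay mutually consistent. A standalone check of that arithmetic, as a sketch built from the sizes quoted in the test's own comments (n_input = 5, n_cell = 20, n_output = 16) rather than from code in this CL:

// Sanity arithmetic for the LSTM black-box test above. The constants are
// taken from the test's comments; none of these names exist in the CL itself.
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t n_input = 5, n_cell = 20, n_output = 16;
    const uint32_t input_row = 20;   // floats per batch row of lstm_input
    const uint32_t golden_row = 64;  // floats per batch row of lstm_golden_output
    assert(input_row / n_input == 4);    // input_sequence_size, as the loop computes it
    assert(golden_row / n_output == 4);  // one golden vector per Invoke()
    // Peephole vectors (SetCellTo*Weights) hold one weight per cell; the
    // projection matrix maps the cell state down to the output width.
    assert(n_cell == 20 && n_output * n_cell == 320);
    return 0;
}

The same division explains why each golden row carries 64 floats: four invocations of a 16-wide output.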
diff --git a/nn/common/operations/RNN.cpp b/nn/common/operations/RNN.cpp
index dcb5928fa..36473042e 100644
--- a/nn/common/operations/RNN.cpp
+++ b/nn/common/operations/RNN.cpp
@@ -29,61 +29,55 @@ namespace nn {
using namespace hal;
-RNN::RNN(const Operation& operation,
- std::vector<RunTimeOperandInfo>& operands) {
- NNTRACE_TRANS("RNN::RNN");
- input_ = GetInput(operation, operands, kInputTensor);
- weights_ = GetInput(operation, operands, kWeightsTensor);
- recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
- hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
- bias_ = GetInput(operation, operands, kBiasTensor);
-
- activation_ = static_cast<ActivationFn>(
- getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
-
- hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
- output_ = GetOutput(operation, operands, kOutputTensor);
+RNN::RNN(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
+ NNTRACE_TRANS("RNN::RNN");
+ input_ = GetInput(operation, operands, kInputTensor);
+ weights_ = GetInput(operation, operands, kWeightsTensor);
+ recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
+ hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
+ bias_ = GetInput(operation, operands, kBiasTensor);
+
+ activation_ = static_cast<ActivationFn>(
+ getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
+
+ hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
+ output_ = GetOutput(operation, operands, kOutputTensor);
}
-bool RNN::Prepare(const Operation &operation,
- std::vector<RunTimeOperandInfo> &operands,
- Shape *hiddenStateShape,
- Shape *outputShape) {
- NNTRACE_TRANS("RNN::Prepare");
- // Check we have all the inputs and outputs we need.
- const int num_inputs = NumInputsWithValues(operation, operands);
- NN_CHECK(num_inputs == 5 || num_inputs == 6);
- NN_CHECK_EQ(NumOutputs(operation), 2);
-
- const RunTimeOperandInfo *input =
- GetInput(operation, operands, kInputTensor);
- const RunTimeOperandInfo *input_weights =
- GetInput(operation, operands, kWeightsTensor);
- const RunTimeOperandInfo *recurrent_weights =
- GetInput(operation, operands, kRecurrentWeightsTensor);
- const RunTimeOperandInfo *bias =
- GetInput(operation, operands, kBiasTensor);
-
- // Check all the parameters of tensor match within themselves and match the
- // input configuration.
- const uint32_t batch_size = SizeOfDimension(input, 0);
- const uint32_t num_units = SizeOfDimension(input_weights, 0);
- NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
- NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
- NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
- NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
-
- const Shape &inputShape = input->shape();
-
- // Resize state.
- hiddenStateShape->type = inputShape.type;
- hiddenStateShape->dimensions = { batch_size, num_units };
-
- // Resize output.
- outputShape->type = inputShape.type;
- outputShape->dimensions = { batch_size, num_units };
-
- return true;
+bool RNN::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
+ Shape* hiddenStateShape, Shape* outputShape) {
+ NNTRACE_TRANS("RNN::Prepare");
+ // Check we have all the inputs and outputs we need.
+ const int num_inputs = NumInputsWithValues(operation, operands);
+ NN_CHECK(num_inputs == 5 || num_inputs == 6);
+ NN_CHECK_EQ(NumOutputs(operation), 2);
+
+ const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
+ const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
+ const RunTimeOperandInfo* recurrent_weights =
+ GetInput(operation, operands, kRecurrentWeightsTensor);
+ const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const uint32_t batch_size = SizeOfDimension(input, 0);
+ const uint32_t num_units = SizeOfDimension(input_weights, 0);
+ NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
+ NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
+ NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
+ NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
+
+ const Shape& inputShape = input->shape();
+
+ // Resize state.
+ hiddenStateShape->type = inputShape.type;
+ hiddenStateShape->dimensions = {batch_size, num_units};
+
+ // Resize output.
+ outputShape->type = inputShape.type;
+ outputShape->dimensions = {batch_size, num_units};
+
+ return true;
}
bool RNN::Eval() {
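The shape checks in RNN::Prepare above mirror the recurrence the op evaluates: for each batch row, h_t = activation(W·x_t + R·h_{t-1} + b), which is why W must be {num_units, input_size}, R {num_units, num_units}, and the bias {num_units}. A naive single-step reference — a sketch under those shape assumptions, not the NNAPI implementation:

// Naive reference for one RNN step; the names here are illustrative and do
// not come from this CL. Matrices are row-major, matching the checked shapes.
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> RnnStep(const std::vector<float>& x,       // [input_size]
                           const std::vector<float>& h_prev,  // [num_units]
                           const std::vector<float>& W,       // [num_units x input_size]
                           const std::vector<float>& R,       // [num_units x num_units]
                           const std::vector<float>& b,       // [num_units]
                           size_t input_size, size_t num_units) {
    std::vector<float> h(num_units);
    for (size_t u = 0; u < num_units; ++u) {
        float acc = b[u];
        for (size_t i = 0; i < input_size; ++i) acc += W[u * input_size + i] * x[i];
        for (size_t j = 0; j < num_units; ++j) acc += R[u * num_units + j] * h_prev[j];
        h[u] = std::max(0.0f, acc);  // kActivationRelu, as the test below uses
    }
    return h;
}

In the BlackBoxTest that follows, R is 0.1 times the identity, so each unit feeds back only its own previous activation, and the zeros scattered through rnn_golden_output are where ReLU clamps a negative accumulation.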
diff --git a/nn/common/operations/RNNTest.cpp b/nn/common/operations/RNNTest.cpp
index 332885f95..66acac7cb 100644
--- a/nn/common/operations/RNNTest.cpp
+++ b/nn/common/operations/RNNTest.cpp
@@ -33,304 +33,271 @@ namespace {
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
float max_abs_error = 1.e-5) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
}
static float rnn_input[] = {
- 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
- 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
- -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
- 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
- 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
- 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
- -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
- -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
- 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
- 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
- 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
- -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
- 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
- -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
- -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
- -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
- 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
- -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
- -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
- 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
- -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
- 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
- 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
- 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
- -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
- 0.93455386, -0.6324693, -0.083922029};
+ 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 0.43773448,
+ 0.60379338, 0.35562468, -0.69424844, -0.93421471, -0.87287879, 0.37144363,
+ -0.62476718, 0.23791671, 0.40060222, 0.1356622, -0.99774903, -0.98858172,
+ -0.38952237, -0.47685933, 0.31073618, 0.71511042, -0.63767755, -0.31729108,
+ 0.33468103, 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
+ -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, -0.61777675,
+ -0.21095741, 0.41213346, 0.73784804, 0.094794154, 0.47791874, 0.86496925,
+ -0.53376222, 0.85315156, 0.10288584, 0.86684, -0.011186242, 0.10513687,
+ 0.87825835, 0.59929144, 0.62827742, 0.18899453, 0.31440187, 0.99059987,
+ 0.87170351, -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
+ 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, -0.66609079,
+ 0.59098077, 0.73017097, 0.74604273, 0.32882881, -0.17503482, 0.22396147,
+ 0.19379807, 0.29120302, 0.077113032, -0.70331609, 0.15804303, -0.93407321,
+ 0.40182066, 0.036301374, 0.66521823, 0.0300982, -0.7747041, -0.02038002,
+ 0.020698071, -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
+ -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 0.43519354,
+ 0.14744234, 0.62589407, 0.1653645, -0.10651493, -0.045277178, 0.99032974,
+ -0.88255352, -0.85147917, 0.28153265, 0.19455957, -0.55479527, -0.56042433,
+ 0.26048636, 0.84702539, 0.47587705, -0.074295521, -0.12287641, 0.70117295,
+ 0.90532446, 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
+ -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 0.93455386,
+ -0.6324693, -0.083922029};
static float rnn_golden_output[] = {
- 0.496726, 0, 0.965996, 0, 0.0584254, 0,
- 0, 0.12315, 0, 0, 0.612266, 0.456601,
- 0, 0.52286, 1.16099, 0.0291232,
+ 0.496726, 0, 0.965996, 0, 0.0584254, 0, 0, 0.12315,
+ 0, 0, 0.612266, 0.456601, 0, 0.52286, 1.16099, 0.0291232,
- 0, 0, 0.524901, 0, 0, 0,
- 0, 1.02116, 0, 1.35762, 0, 0.356909,
- 0.436415, 0.0355727, 0, 0,
+ 0, 0, 0.524901, 0, 0, 0, 0, 1.02116,
+ 0, 1.35762, 0, 0.356909, 0.436415, 0.0355727, 0, 0,
- 0, 0, 0, 0.262335, 0, 0,
- 0, 1.33992, 0, 2.9739, 0, 0,
- 1.31914, 2.66147, 0, 0,
+ 0, 0, 0, 0.262335, 0, 0, 0, 1.33992,
+ 0, 2.9739, 0, 0, 1.31914, 2.66147, 0, 0,
- 0.942568, 0, 0, 0, 0.025507, 0,
- 0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
- 0.8158, 1.21805, 0.586239, 0.25427,
+ 0.942568, 0, 0, 0, 0.025507, 0, 0, 0,
+ 0.321429, 0.569141, 1.25274, 1.57719, 0.8158, 1.21805, 0.586239, 0.25427,
- 1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
- 0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
- 0, 1.22031, 1.30117, 0.495867,
+ 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 0.363026, 0,
+ 0.533426, 0, 1.25926, 0.722707, 0, 1.22031, 1.30117, 0.495867,
- 0.222187, 0, 0.72725, 0, 0.767003, 0,
- 0, 0.147835, 0, 0, 0, 0.608758,
- 0.469394, 0.00720298, 0.927537, 0,
+ 0.222187, 0, 0.72725, 0, 0.767003, 0, 0, 0.147835,
+ 0, 0, 0, 0.608758, 0.469394, 0.00720298, 0.927537, 0,
- 0.856974, 0.424257, 0, 0, 0.937329, 0,
- 0, 0, 0.476425, 0, 0.566017, 0.418462,
- 0.141911, 0.996214, 1.13063, 0,
+ 0.856974, 0.424257, 0, 0, 0.937329, 0, 0, 0,
+ 0.476425, 0, 0.566017, 0.418462, 0.141911, 0.996214, 1.13063, 0,
- 0.967899, 0, 0, 0, 0.0831304, 0,
- 0, 1.00378, 0, 0, 0, 1.44818,
- 1.01768, 0.943891, 0.502745, 0,
+ 0.967899, 0, 0, 0, 0.0831304, 0, 0, 1.00378,
+ 0, 0, 0, 1.44818, 1.01768, 0.943891, 0.502745, 0,
- 0.940135, 0, 0, 0, 0, 0,
- 0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
- 1.30225, 1.59644, 0.70222, 0,
+ 0.940135, 0, 0, 0, 0, 0, 0, 2.13243,
+ 0, 0.71208, 0.123918, 1.53907, 1.30225, 1.59644, 0.70222, 0,
- 0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
- 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
- 0.0454298, 0.300267, 0.562784, 0.395095,
+ 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 0.343448, 0,
+ 0.107756, 0.614544, 1.44549, 1.52311, 0.0454298, 0.300267, 0.562784, 0.395095,
- 0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
- 0, 0, 0, 0.735363, 0.0759267, 1.91017,
- 0.941888, 0, 0, 0,
+ 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 0, 0,
+ 0, 0.735363, 0.0759267, 1.91017, 0.941888, 0, 0, 0,
- 0, 0, 1.5909, 0, 0, 0,
- 0, 0.5755, 0, 0.184687, 0, 1.56296,
- 0.625285, 0, 0, 0,
+ 0, 0, 1.5909, 0, 0, 0, 0, 0.5755,
+ 0, 0.184687, 0, 1.56296, 0.625285, 0, 0, 0,
- 0, 0, 0.0857888, 0, 0, 0,
- 0, 0.488383, 0.252786, 0, 0, 0,
- 1.02817, 1.85665, 0, 0,
+ 0, 0, 0.0857888, 0, 0, 0, 0, 0.488383,
+ 0.252786, 0, 0, 0, 1.02817, 1.85665, 0, 0,
- 0.00981836, 0, 1.06371, 0, 0, 0,
- 0, 0, 0, 0.290445, 0.316406, 0,
- 0.304161, 1.25079, 0.0707152, 0,
+ 0.00981836, 0, 1.06371, 0, 0, 0, 0, 0,
+ 0, 0.290445, 0.316406, 0, 0.304161, 1.25079, 0.0707152, 0,
- 0.986264, 0.309201, 0, 0, 0, 0,
- 0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
- 0.524981, 1.92076, 2.07013, 0.333244,
+ 0.986264, 0.309201, 0, 0, 0, 0, 0, 1.64896,
+ 0.346248, 0, 0.918175, 0.78884, 0.524981, 1.92076, 2.07013, 0.333244,
- 0.415153, 0.210318, 0, 0, 0, 0,
- 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
- 0.628881, 3.58099, 1.49974, 0};
+ 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616,
+ 0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0};
} // anonymous namespace
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
- ACTION(Input) \
- ACTION(Weights) \
- ACTION(RecurrentWeights) \
- ACTION(Bias) \
- ACTION(HiddenStateIn)
+ ACTION(Input) \
+ ACTION(Weights) \
+ ACTION(RecurrentWeights) \
+ ACTION(Bias) \
+ ACTION(HiddenStateIn)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(HiddenStateOut) \
- ACTION(Output)
+ ACTION(HiddenStateOut) \
+ ACTION(Output)
class BasicRNNOpModel {
- public:
- BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
- : batches_(batches),
- units_(units),
- input_size_(size),
- activation_(kActivationRelu) {
- std::vector<uint32_t> inputs;
-
- OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
- inputs.push_back(model_.addOperand(&InputTy));
- OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
- inputs.push_back(model_.addOperand(&WeightTy));
- OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
- inputs.push_back(model_.addOperand(&RecurrentWeightTy));
- OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
- inputs.push_back(model_.addOperand(&BiasTy));
- OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
- inputs.push_back(model_.addOperand(&HiddenStateTy));
- OperandType ActionParamTy(Type::INT32, {});
- inputs.push_back(model_.addOperand(&ActionParamTy));
-
- std::vector<uint32_t> outputs;
-
- outputs.push_back(model_.addOperand(&HiddenStateTy));
- OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
- outputs.push_back(model_.addOperand(&OutputTy));
-
- Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
- HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
- HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
- Output_.insert(Output_.end(), batches_ * units_, 0.f);
-
- model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
-
- model_.finish();
- }
-
-#define DefineSetter(X) \
- void Set##X(const std::vector<float>& f) { \
- X##_.insert(X##_.end(), f.begin(), f.end()); \
- }
-
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
+ public:
+ BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
+ : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
+ std::vector<uint32_t> inputs;
+
+ OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
+ inputs.push_back(model_.addOperand(&InputTy));
+ OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
+ inputs.push_back(model_.addOperand(&WeightTy));
+ OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
+ inputs.push_back(model_.addOperand(&RecurrentWeightTy));
+ OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
+ inputs.push_back(model_.addOperand(&BiasTy));
+ OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
+ inputs.push_back(model_.addOperand(&HiddenStateTy));
+ OperandType ActionParamTy(Type::INT32, {});
+ inputs.push_back(model_.addOperand(&ActionParamTy));
+
+ std::vector<uint32_t> outputs;
+
+ outputs.push_back(model_.addOperand(&HiddenStateTy));
+ OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
+ outputs.push_back(model_.addOperand(&OutputTy));
+
+ Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
+ HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
+ HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
+ Output_.insert(Output_.end(), batches_ * units_, 0.f);
+
+ model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
+ model_.identifyInputsAndOutputs(inputs, outputs);
+
+ model_.finish();
+ }
+
+#define DefineSetter(X) \
+ void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
+
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
#undef DefineSetter
- void SetInput(int offset, float* begin, float* end) {
- for (; begin != end; begin++, offset++) {
- Input_[offset] = *begin;
+ void SetInput(int offset, float* begin, float* end) {
+ for (; begin != end; begin++, offset++) {
+ Input_[offset] = *begin;
+ }
}
- }
- void ResetHiddenState() {
- std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
- std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
- }
+ void ResetHiddenState() {
+ std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
+ std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
+ }
- const std::vector<float>& GetOutput() const { return Output_; }
+ const std::vector<float>& GetOutput() const { return Output_; }
- uint32_t input_size() const { return input_size_; }
- uint32_t num_units() const { return units_; }
- uint32_t num_batches() const { return batches_; }
+ uint32_t input_size() const { return input_size_; }
+ uint32_t num_units() const { return units_; }
+ uint32_t num_batches() const { return batches_; }
- void Invoke() {
- ASSERT_TRUE(model_.isValid());
+ void Invoke() {
+ ASSERT_TRUE(model_.isValid());
- HiddenStateIn_.swap(HiddenStateOut_);
+ HiddenStateIn_.swap(HiddenStateOut_);
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
-#define SetInputOrWeight(X) \
- ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), \
- sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
+ Compilation compilation(&model_);
+ compilation.finish();
+ Execution execution(&compilation);
+#define SetInputOrWeight(X) \
+ ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
#undef SetInputOrWeight
-#define SetOutput(X) \
- ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), \
- sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
+#define SetOutput(X) \
+ ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_OUTPUT_TENSORS(SetOutput);
+ FOR_ALL_OUTPUT_TENSORS(SetOutput);
#undef SetOutput
- ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
- sizeof(activation_)),
- Result::NO_ERROR);
+ ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
+ Result::NO_ERROR);
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- }
+ ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+ }
- private:
- Model model_;
+ private:
+ Model model_;
- const uint32_t batches_;
- const uint32_t units_;
- const uint32_t input_size_;
+ const uint32_t batches_;
+ const uint32_t units_;
+ const uint32_t input_size_;
- const int activation_;
+ const int activation_;
#define DefineTensor(X) std::vector<float> X##_;
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
- FOR_ALL_OUTPUT_TENSORS(DefineTensor);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
+ FOR_ALL_OUTPUT_TENSORS(DefineTensor);
#undef DefineTensor
};
TEST(RNNOpTest, BlackBoxTest) {
- BasicRNNOpModel rnn(2, 16, 8);
- rnn.SetWeights(
- {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
- 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
- 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
- -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
- -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
- -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
- -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
- 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
- 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
- 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
- -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
- 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
- -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
- -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
- 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
- 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
- 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
- -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
- 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
- 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
- -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
- 0.277308, 0.415818});
-
- rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
- -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
- 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
- -0.37609905});
-
- rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1});
-
- rnn.ResetHiddenState();
- const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
- (rnn.input_size() * rnn.num_batches());
-
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = rnn_input + i * rnn.input_size();
- float* batch_end = batch_start + rnn.input_size();
- rnn.SetInput(0, batch_start, batch_end);
- rnn.SetInput(rnn.input_size(), batch_start, batch_end);
-
- rnn.Invoke();
-
- float* golden_start = rnn_golden_output + i * rnn.num_units();
- float* golden_end = golden_start + rnn.num_units();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ BasicRNNOpModel rnn(2, 16, 8);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 0.317493,
+ 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 0.448504, 0.317662,
+ 0.523556, -0.323514, 0.480877, 0.333113, -0.757714, -0.674487, -0.643585,
+ 0.217766, -0.0251462, 0.79512, -0.595574, -0.422444, 0.371572, -0.452178,
+ -0.556069, -0.482188, -0.685456, -0.727851, 0.841829, 0.551535, -0.232336,
+ 0.729158, -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 0.306261,
+ -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 0.0354295, 0.566564,
+ -0.485469, -0.620498, 0.832546, 0.697884, -0.279115, 0.294415, -0.584313,
+ 0.548772, 0.0648819, 0.968726, 0.723834, -0.0080452, -0.350386, -0.272803,
+ 0.115121, -0.412644, -0.824713, -0.992843, -0.592904, -0.417893, 0.863791,
+ -0.423461, -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 0.0960841,
+ 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 0.37225, -0.623598,
+ -0.405423, 0.455101, 0.673656, -0.145345, -0.511346, -0.901675, -0.81252,
+ -0.127006, 0.809865, -0.721884, 0.636255, 0.868989, -0.347973, -0.10179,
+ -0.777449, 0.917274, 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872,
+ 0.972934, -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
+ -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268,
+ 0.61957061, 0.3956964, -0.37609905});
+
+ rnn.SetRecurrentWeights(
+ {0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0.1});
+
+ rnn.ResetHiddenState();
+ const int input_sequence_size =
+ sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches());
+
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_end = golden_start + rnn.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
} // namespace wrapper
diff --git a/nn/common/operations/SVDF.cpp b/nn/common/operations/SVDF.cpp
index 844361da2..4009b7735 100644
--- a/nn/common/operations/SVDF.cpp
+++ b/nn/common/operations/SVDF.cpp
@@ -29,8 +29,7 @@ namespace nn {
using namespace hal;
-SVDF::SVDF(const Operation& operation,
- std::vector<RunTimeOperandInfo>& operands) {
+SVDF::SVDF(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
NNTRACE_TRANS("SVDF::SVDF");
input_ = GetInput(operation, operands, kInputTensor);
weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor);
@@ -39,62 +38,58 @@ SVDF::SVDF(const Operation& operation,
state_in_ = GetInput(operation, operands, kStateInTensor);
params_.rank_ = getScalarData<int>(*GetInput(operation, operands, kRankParam));
- params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int>(
- *GetInput(operation, operands, kActivationParam)));
+ params_.activation_ = static_cast<TfLiteFusedActivation>(
+ getScalarData<int>(*GetInput(operation, operands, kActivationParam)));
state_out_ = GetOutput(operation, operands, kStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
}
-bool SVDF::Prepare(const Operation &operation,
- std::vector<RunTimeOperandInfo> &operands,
- Shape *stateShape,
- Shape *outputShape) {
- NNTRACE_TRANS("SVDF::Prepare");
- // Check we have all the inputs and outputs we need.
- const int num_inputs = NumInputsWithValues(operation, operands);
-
- NN_CHECK(num_inputs == 6 || num_inputs == 7);
- NN_CHECK_EQ(NumOutputs(operation), 2);
-
- const RunTimeOperandInfo *input =
- GetInput(operation, operands, SVDF::kInputTensor);
- const RunTimeOperandInfo *weights_feature =
- GetInput(operation, operands, SVDF::kWeightsFeatureTensor);
- const RunTimeOperandInfo *weights_time =
- GetInput(operation, operands, SVDF::kWeightsTimeTensor);
-
- // Check all the parameters of tensor match within themselves and match the
- // input configuration.
- const int rank = getScalarData<int>(*GetInput(operation, operands, kRankParam));
- const uint32_t batch_size = SizeOfDimension(input, 0);
- const uint32_t num_filters = SizeOfDimension(weights_feature, 0);
- NN_CHECK_EQ(num_filters % rank, 0);
- const uint32_t num_units = num_filters / rank;
- const uint32_t memory_size = SizeOfDimension(weights_time, 1);
- NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1));
- NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_filters);
-
- const RunTimeOperandInfo *bias =
- GetInput(operation, operands, kBiasTensor);
- if (!IsNullInput(bias)) {
- NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
- }
-
- // Resize state.
- const Shape &inputShape = input->shape();
- stateShape->type = inputShape.type;
- stateShape->dimensions = { batch_size, memory_size * num_filters };
- stateShape->offset = inputShape.offset;
- stateShape->scale = inputShape.scale;
-
- // Resize output.
- outputShape->type = inputShape.type;
- outputShape->dimensions = { batch_size, num_units };
- outputShape->offset = inputShape.offset;
- outputShape->scale = inputShape.scale;
-
- return true;
+bool SVDF::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
+ Shape* stateShape, Shape* outputShape) {
+ NNTRACE_TRANS("SVDF::Prepare");
+ // Check we have all the inputs and outputs we need.
+ const int num_inputs = NumInputsWithValues(operation, operands);
+
+ NN_CHECK(num_inputs == 6 || num_inputs == 7);
+ NN_CHECK_EQ(NumOutputs(operation), 2);
+
+ const RunTimeOperandInfo* input = GetInput(operation, operands, SVDF::kInputTensor);
+ const RunTimeOperandInfo* weights_feature =
+ GetInput(operation, operands, SVDF::kWeightsFeatureTensor);
+ const RunTimeOperandInfo* weights_time =
+ GetInput(operation, operands, SVDF::kWeightsTimeTensor);
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int rank = getScalarData<int>(*GetInput(operation, operands, kRankParam));
+ const uint32_t batch_size = SizeOfDimension(input, 0);
+ const uint32_t num_filters = SizeOfDimension(weights_feature, 0);
+ NN_CHECK_EQ(num_filters % rank, 0);
+ const uint32_t num_units = num_filters / rank;
+ const uint32_t memory_size = SizeOfDimension(weights_time, 1);
+ NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1));
+ NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_filters);
+
+ const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
+ if (!IsNullInput(bias)) {
+ NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
+ }
+
+ // Resize state.
+ const Shape& inputShape = input->shape();
+ stateShape->type = inputShape.type;
+ stateShape->dimensions = {batch_size, memory_size * num_filters};
+ stateShape->offset = inputShape.offset;
+ stateShape->scale = inputShape.scale;
+
+ // Resize output.
+ outputShape->type = inputShape.type;
+ outputShape->dimensions = {batch_size, num_units};
+ outputShape->offset = inputShape.offset;
+ outputShape->scale = inputShape.scale;
+
+ return true;
}
bool SVDF::Eval() {
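SVDF::Prepare above recovers num_units by dividing the feature-weight rows by the rank (hence the num_filters % rank check) and sizes the state tensor as memory_size activations per filter per batch. A worked check of that arithmetic under the dimensions the black-box test below constructs (batches = 2, units = 4, input_size = 3, memory_size = 10, rank = 1) — again a sketch, not code from this CL:

// Illustrative shape arithmetic for SVDF::Prepare; the constants match the
// SVDFOpTest.BlackBoxTest parameters below.
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t batches = 2, units = 4, memory_size = 10, rank = 1;
    const uint32_t num_filters = units * rank;      // rows of weights_feature
    assert(num_filters % rank == 0);                // enforced by Prepare()
    const uint32_t num_units = num_filters / rank;  // recovered output width
    assert(num_units == units);
    // State tensor: {batch_size, memory_size * num_filters} => 80 floats here.
    assert(batches * memory_size * num_filters == 80);
    return 0;
}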
diff --git a/nn/common/operations/SVDFTest.cpp b/nn/common/operations/SVDFTest.cpp
index 864c2eb8a..21f769fc2 100644
--- a/nn/common/operations/SVDFTest.cpp
+++ b/nn/common/operations/SVDFTest.cpp
@@ -30,418 +30,389 @@ namespace wrapper {
namespace {
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
- float max_abs_error=1.e-6) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
+ float max_abs_error = 1.e-6) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
}
} // namespace
using ::testing::ElementsAreArray;
-static float svdf_input[] = {0.12609188, -0.46347019, -0.89598465,
- 0.12609188, -0.46347019, -0.89598465,
+static float svdf_input[] = {
+ 0.12609188, -0.46347019, -0.89598465, 0.12609188, -0.46347019, -0.89598465,
- 0.14278367, -1.64410412, -0.75222826,
- 0.14278367, -1.64410412, -0.75222826,
+ 0.14278367, -1.64410412, -0.75222826, 0.14278367, -1.64410412, -0.75222826,
- 0.49837467, 0.19278903, 0.26584083,
- 0.49837467, 0.19278903, 0.26584083,
+ 0.49837467, 0.19278903, 0.26584083, 0.49837467, 0.19278903, 0.26584083,
- -0.11186574, 0.13164264, -0.05349274,
- -0.11186574, 0.13164264, -0.05349274,
+ -0.11186574, 0.13164264, -0.05349274, -0.11186574, 0.13164264, -0.05349274,
- -0.68892461, 0.37783599, 0.18263303,
- -0.68892461, 0.37783599, 0.18263303,
+ -0.68892461, 0.37783599, 0.18263303, -0.68892461, 0.37783599, 0.18263303,
- -0.81299269, -0.86831826, 1.43940818,
- -0.81299269, -0.86831826, 1.43940818,
+ -0.81299269, -0.86831826, 1.43940818, -0.81299269, -0.86831826, 1.43940818,
- -1.45006323, -0.82251364, -1.69082689,
- -1.45006323, -0.82251364, -1.69082689,
+ -1.45006323, -0.82251364, -1.69082689, -1.45006323, -0.82251364, -1.69082689,
- 0.03966608, -0.24936394, -0.77526885,
- 0.03966608, -0.24936394, -0.77526885,
+ 0.03966608, -0.24936394, -0.77526885, 0.03966608, -0.24936394, -0.77526885,
- 0.11771342, -0.23761693, -0.65898693,
- 0.11771342, -0.23761693, -0.65898693,
+ 0.11771342, -0.23761693, -0.65898693, 0.11771342, -0.23761693, -0.65898693,
- -0.89477462, 1.67204106, -0.53235275,
- -0.89477462, 1.67204106, -0.53235275};
+ -0.89477462, 1.67204106, -0.53235275, -0.89477462, 1.67204106, -0.53235275};
static float svdf_input_rank2[] = {
- 0.12609188, -0.46347019, -0.89598465,
- 0.35867718, 0.36897406, 0.73463392,
+ 0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406, 0.73463392,
- 0.14278367, -1.64410412, -0.75222826,
- -0.57290924, 0.12729003, 0.7567004,
+ 0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003, 0.7567004,
- 0.49837467, 0.19278903, 0.26584083,
- 0.17660543, 0.52949083, -0.77931279,
+ 0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279,
- -0.11186574, 0.13164264, -0.05349274,
- -0.72674477, -0.5683046, 0.55900657,
+ -0.11186574, 0.13164264, -0.05349274, -0.72674477, -0.5683046, 0.55900657,
- -0.68892461, 0.37783599, 0.18263303,
- -0.63690937, 0.44483393, -0.71817774,
+ -0.68892461, 0.37783599, 0.18263303, -0.63690937, 0.44483393, -0.71817774,
- -0.81299269, -0.86831826, 1.43940818,
- -0.95760226, 1.82078898, 0.71135032,
+ -0.81299269, -0.86831826, 1.43940818, -0.95760226, 1.82078898, 0.71135032,
- -1.45006323, -0.82251364, -1.69082689,
- -1.65087092, -1.89238167, 1.54172635,
+ -1.45006323, -0.82251364, -1.69082689, -1.65087092, -1.89238167, 1.54172635,
- 0.03966608, -0.24936394, -0.77526885,
- 2.06740379, -1.51439476, 1.43768692,
+ 0.03966608, -0.24936394, -0.77526885, 2.06740379, -1.51439476, 1.43768692,
- 0.11771342, -0.23761693, -0.65898693,
- 0.31088525, -1.55601168, -0.87661445,
+ 0.11771342, -0.23761693, -0.65898693, 0.31088525, -1.55601168, -0.87661445,
- -0.89477462, 1.67204106, -0.53235275,
- -0.6230064, 0.29819036, 1.06939757,
+ -0.89477462, 1.67204106, -0.53235275, -0.6230064, 0.29819036, 1.06939757,
};
-static float svdf_golden_output[] = {
- 0.014899, -0.0517661, -0.143725, -0.00271883,
- 0.014899, -0.0517661, -0.143725, -0.00271883,
+static float svdf_golden_output[] = {0.014899, -0.0517661, -0.143725, -0.00271883,
+ 0.014899, -0.0517661, -0.143725, -0.00271883,
- 0.068281, -0.162217, -0.152268, 0.00323521,
- 0.068281, -0.162217, -0.152268, 0.00323521,
+ 0.068281, -0.162217, -0.152268, 0.00323521,
+ 0.068281, -0.162217, -0.152268, 0.00323521,
- -0.0317821, -0.0333089, 0.0609602, 0.0333759,
- -0.0317821, -0.0333089, 0.0609602, 0.0333759,
+ -0.0317821, -0.0333089, 0.0609602, 0.0333759,
+ -0.0317821, -0.0333089, 0.0609602, 0.0333759,
- -0.00623099, -0.077701, -0.391193, -0.0136691,
- -0.00623099, -0.077701, -0.391193, -0.0136691,
+ -0.00623099, -0.077701, -0.391193, -0.0136691,
+ -0.00623099, -0.077701, -0.391193, -0.0136691,
- 0.201551, -0.164607, -0.179462, -0.0592739,
- 0.201551, -0.164607, -0.179462, -0.0592739,
+ 0.201551, -0.164607, -0.179462, -0.0592739,
+ 0.201551, -0.164607, -0.179462, -0.0592739,
- 0.0886511, -0.0875401, -0.269283, 0.0281379,
- 0.0886511, -0.0875401, -0.269283, 0.0281379,
+ 0.0886511, -0.0875401, -0.269283, 0.0281379,
+ 0.0886511, -0.0875401, -0.269283, 0.0281379,
- -0.201174, -0.586145, -0.628624, -0.0330412,
- -0.201174, -0.586145, -0.628624, -0.0330412,
+ -0.201174, -0.586145, -0.628624, -0.0330412,
+ -0.201174, -0.586145, -0.628624, -0.0330412,
- -0.0839096, -0.299329, 0.108746, 0.109808,
- -0.0839096, -0.299329, 0.108746, 0.109808,
+ -0.0839096, -0.299329, 0.108746, 0.109808,
+ -0.0839096, -0.299329, 0.108746, 0.109808,
- 0.419114, -0.237824, -0.422627, 0.175115,
- 0.419114, -0.237824, -0.422627, 0.175115,
+ 0.419114, -0.237824, -0.422627, 0.175115,
+ 0.419114, -0.237824, -0.422627, 0.175115,
- 0.36726, -0.522303, -0.456502, -0.175475,
- 0.36726, -0.522303, -0.456502, -0.175475};
+ 0.36726, -0.522303, -0.456502, -0.175475,
+ 0.36726, -0.522303, -0.456502, -0.175475};
static float svdf_golden_output_rank_2[] = {
- -0.09623547, -0.10193135, 0.11083051, -0.0347917,
- 0.1141196, 0.12965347, -0.12652366, 0.01007236,
+ -0.09623547, -0.10193135, 0.11083051, -0.0347917,
+ 0.1141196, 0.12965347, -0.12652366, 0.01007236,
- -0.16396809, -0.21247184, 0.11259045, -0.04156673,
- 0.10132131, -0.06143532, -0.00924693, 0.10084561,
+ -0.16396809, -0.21247184, 0.11259045, -0.04156673,
+ 0.10132131, -0.06143532, -0.00924693, 0.10084561,
- 0.01257364, 0.0506071, -0.19287863, -0.07162561,
- -0.02033747, 0.22673416, 0.15487903, 0.02525555,
+ 0.01257364, 0.0506071, -0.19287863, -0.07162561,
+ -0.02033747, 0.22673416, 0.15487903, 0.02525555,
- -0.1411963, -0.37054959, 0.01774767, 0.05867489,
- 0.09607603, -0.0141301, -0.08995658, 0.12867066,
+ -0.1411963, -0.37054959, 0.01774767, 0.05867489,
+ 0.09607603, -0.0141301, -0.08995658, 0.12867066,
- -0.27142537, -0.16955489, 0.18521598, -0.12528358,
- 0.00331409, 0.11167502, 0.02218599, -0.07309391,
+ -0.27142537, -0.16955489, 0.18521598, -0.12528358,
+ 0.00331409, 0.11167502, 0.02218599, -0.07309391,
- 0.09593632, -0.28361851, -0.0773851, 0.17199151,
- -0.00075242, 0.33691186, -0.1536046, 0.16572715,
+ 0.09593632, -0.28361851, -0.0773851, 0.17199151,
+ -0.00075242, 0.33691186, -0.1536046, 0.16572715,
- -0.27916506, -0.27626723, 0.42615682, 0.3225764,
- -0.37472126, -0.55655634, -0.05013514, 0.289112,
+ -0.27916506, -0.27626723, 0.42615682, 0.3225764,
+ -0.37472126, -0.55655634, -0.05013514, 0.289112,
- -0.24418658, 0.07540751, -0.1940318, -0.08911639,
- 0.00732617, 0.46737891, 0.26449674, 0.24888524,
+ -0.24418658, 0.07540751, -0.1940318, -0.08911639,
+ 0.00732617, 0.46737891, 0.26449674, 0.24888524,
- -0.17225097, -0.54660404, -0.38795233, 0.08389944,
- 0.07736043, -0.28260678, 0.15666828, 1.14949894,
+ -0.17225097, -0.54660404, -0.38795233, 0.08389944,
+ 0.07736043, -0.28260678, 0.15666828, 1.14949894,
- -0.57454878, -0.64704704, 0.73235172, -0.34616736,
- 0.21120001, -0.22927976, 0.02455296, -0.35906726,
+ -0.57454878, -0.64704704, 0.73235172, -0.34616736,
+ 0.21120001, -0.22927976, 0.02455296, -0.35906726,
};
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
- ACTION(Input) \
- ACTION(WeightsFeature) \
- ACTION(WeightsTime) \
- ACTION(Bias) \
- ACTION(StateIn)
+ ACTION(Input) \
+ ACTION(WeightsFeature) \
+ ACTION(WeightsTime) \
+ ACTION(Bias) \
+ ACTION(StateIn)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(StateOut) \
- ACTION(Output)
+ ACTION(StateOut) \
+ ACTION(Output)
// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
class SVDFOpModel {
- public:
- SVDFOpModel(uint32_t batches, uint32_t units, uint32_t input_size,
- uint32_t memory_size, uint32_t rank)
- : batches_(batches),
- units_(units),
- input_size_(input_size),
- memory_size_(memory_size),
- rank_(rank) {
- std::vector<std::vector<uint32_t>> input_shapes{
- {batches_, input_size_}, // Input tensor
- {units_ * rank_, input_size_}, // weights_feature tensor
- {units_ * rank_, memory_size_}, // weights_time tensor
- {units_}, // bias tensor
- {batches_, memory_size * units_ * rank_}, // state in tensor
- };
- std::vector<uint32_t> inputs;
- auto it = input_shapes.begin();
-
- // Input and weights
-#define AddInput(X) \
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
- inputs.push_back(model_.addOperand(&X##OpndTy));
-
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
+ public:
+ SVDFOpModel(uint32_t batches, uint32_t units, uint32_t input_size, uint32_t memory_size,
+ uint32_t rank)
+ : batches_(batches),
+ units_(units),
+ input_size_(input_size),
+ memory_size_(memory_size),
+ rank_(rank) {
+ std::vector<std::vector<uint32_t>> input_shapes{
+ {batches_, input_size_}, // Input tensor
+ {units_ * rank_, input_size_}, // weights_feature tensor
+ {units_ * rank_, memory_size_}, // weights_time tensor
+ {units_}, // bias tensor
+ {batches_, memory_size * units_ * rank_}, // state in tensor
+ };
+ std::vector<uint32_t> inputs;
+ auto it = input_shapes.begin();
+
+ // Input and weights
+#define AddInput(X) \
+ OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
+ inputs.push_back(model_.addOperand(&X##OpndTy));
+
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
#undef AddInput
- // Parameters
- OperandType RankParamTy(Type::INT32, {});
- inputs.push_back(model_.addOperand(&RankParamTy));
- OperandType ActivationParamTy(Type::INT32, {});
- inputs.push_back(model_.addOperand(&ActivationParamTy));
+ // Parameters
+ OperandType RankParamTy(Type::INT32, {});
+ inputs.push_back(model_.addOperand(&RankParamTy));
+ OperandType ActivationParamTy(Type::INT32, {});
+ inputs.push_back(model_.addOperand(&ActivationParamTy));
- // Output and other intermediate state
- std::vector<std::vector<uint32_t>> output_shapes{{batches_, memory_size_ * units_ * rank_},
- {batches_, units_}};
- std::vector<uint32_t> outputs;
+ // Output and other intermediate state
+ std::vector<std::vector<uint32_t>> output_shapes{{batches_, memory_size_ * units_ * rank_},
+ {batches_, units_}};
+ std::vector<uint32_t> outputs;
- auto it2 = output_shapes.begin();
+ auto it2 = output_shapes.begin();
-#define AddOutput(X) \
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
- outputs.push_back(model_.addOperand(&X##OpndTy));
+#define AddOutput(X) \
+ OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
+ outputs.push_back(model_.addOperand(&X##OpndTy));
- FOR_ALL_OUTPUT_TENSORS(AddOutput);
+ FOR_ALL_OUTPUT_TENSORS(AddOutput);
#undef AddOutput
- Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
- StateIn_.insert(StateIn_.end(), batches_ * units_ * rank_ * memory_size_, 0.f);
+ Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
+ StateIn_.insert(StateIn_.end(), batches_ * units_ * rank_ * memory_size_, 0.f);
- auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
- uint32_t sz = 1;
- for(uint32_t d:dims) { sz *= d; }
- return sz;
- };
+ auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
+ uint32_t sz = 1;
+ for (uint32_t d : dims) {
+ sz *= d;
+ }
+ return sz;
+ };
- it2 = output_shapes.begin();
+ it2 = output_shapes.begin();
#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
- FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
+ FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
- model_.addOperation(ANEURALNETWORKS_SVDF, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
+ model_.addOperation(ANEURALNETWORKS_SVDF, inputs, outputs);
+ model_.identifyInputsAndOutputs(inputs, outputs);
- model_.finish();
- }
+ model_.finish();
+ }
- void Invoke() {
- ASSERT_TRUE(model_.isValid());
+ void Invoke() {
+ ASSERT_TRUE(model_.isValid());
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
+ Compilation compilation(&model_);
+ compilation.finish();
+ Execution execution(&compilation);
- StateIn_.swap(StateOut_);
+ StateIn_.swap(StateOut_);
-#define SetInputOrWeight(X) \
- ASSERT_EQ(execution.setInput(SVDF::k##X##Tensor, X##_.data(), \
- sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
+#define SetInputOrWeight(X) \
+ ASSERT_EQ(execution.setInput(SVDF::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
#undef SetInputOrWeight
-#define SetOutput(X) \
- EXPECT_TRUE(X##_.data() != nullptr); \
- ASSERT_EQ(execution.setOutput(SVDF::k##X##Tensor, X##_.data(), \
- sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
+#define SetOutput(X) \
+ EXPECT_TRUE(X##_.data() != nullptr); \
+ ASSERT_EQ(execution.setOutput(SVDF::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+ Result::NO_ERROR);
- FOR_ALL_OUTPUT_TENSORS(SetOutput);
+ FOR_ALL_OUTPUT_TENSORS(SetOutput);
#undef SetOutput
- ASSERT_EQ(execution.setInput(SVDF::kRankParam, &rank_, sizeof(rank_)),
- Result::NO_ERROR);
+ ASSERT_EQ(execution.setInput(SVDF::kRankParam, &rank_, sizeof(rank_)), Result::NO_ERROR);
- int activation = TfLiteFusedActivation::kTfLiteActNone;
- ASSERT_EQ(execution.setInput(SVDF::kActivationParam, &activation,
- sizeof(activation)),
- Result::NO_ERROR);
+ int activation = TfLiteFusedActivation::kTfLiteActNone;
+ ASSERT_EQ(execution.setInput(SVDF::kActivationParam, &activation, sizeof(activation)),
+ Result::NO_ERROR);
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- }
+ ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+ }
-#define DefineSetter(X) \
- void Set##X(const std::vector<float>& f) { \
- X##_.insert(X##_.end(), f.begin(), f.end()); \
- }
+#define DefineSetter(X) \
+ void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
#undef DefineSetter
- void SetInput(int offset, float* begin, float* end) {
- for (; begin != end; begin++, offset++) {
- Input_[offset] = *begin;
+ void SetInput(int offset, float* begin, float* end) {
+ for (; begin != end; begin++, offset++) {
+ Input_[offset] = *begin;
+ }
}
- }
- // Resets the state of SVDF op by filling it with 0's.
- void ResetState() {
- std::fill(StateIn_.begin(), StateIn_.end(), 0.f);
- std::fill(StateOut_.begin(), StateOut_.end(), 0.f);
- }
+ // Resets the state of SVDF op by filling it with 0's.
+ void ResetState() {
+ std::fill(StateIn_.begin(), StateIn_.end(), 0.f);
+ std::fill(StateOut_.begin(), StateOut_.end(), 0.f);
+ }
- // Extracts the output tensor from the SVDF op.
- const std::vector<float>& GetOutput() const { return Output_; }
+ // Extracts the output tensor from the SVDF op.
+ const std::vector<float>& GetOutput() const { return Output_; }
- int input_size() const { return input_size_; }
- int num_units() const { return units_; }
- int num_batches() const { return batches_; }
+ int input_size() const { return input_size_; }
+ int num_units() const { return units_; }
+ int num_batches() const { return batches_; }
- private:
- Model model_;
+ private:
+ Model model_;
- const uint32_t batches_;
- const uint32_t units_;
- const uint32_t input_size_;
- const uint32_t memory_size_;
- const uint32_t rank_;
+ const uint32_t batches_;
+ const uint32_t units_;
+ const uint32_t input_size_;
+ const uint32_t memory_size_;
+ const uint32_t rank_;
#define DefineTensor(X) std::vector<float> X##_;
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
- FOR_ALL_OUTPUT_TENSORS(DefineTensor);
+ FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
+ FOR_ALL_OUTPUT_TENSORS(DefineTensor);
#undef DefineTensor
};
TEST(SVDFOpTest, BlackBoxTest) {
- SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
- /*memory_size=*/10, /*rank=*/1);
- svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
- 0.22197971, 0.12416199, 0.27901134, 0.27557442,
- 0.3905206, -0.36137494, -0.06634006, -0.10640851});
-
- svdf.SetWeightsTime(
- {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
- 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
-
- 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
- -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
-
- -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
- 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
-
- -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
- -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
-
- svdf.SetBias({});
-
- svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199,
+ 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006,
+ -0.10640851});
+
+ svdf.SetWeightsTime({-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ svdf.SetBias({});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF op
+ // and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start = svdf_golden_output + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
TEST(SVDFOpTest, BlackBoxTestRank2) {
- SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
- /*memory_size=*/10, /*rank=*/2);
- svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
- 0.12416199, 0.15785322, 0.27901134, 0.3905206,
- 0.21931258, -0.36137494, -0.10640851, 0.31053296,
- -0.36118156, -0.0976817, -0.36916667, 0.22197971,
- 0.15294972, 0.38031587, 0.27557442, 0.39635518,
- -0.21580373, -0.06634006, -0.02702999, 0.27072677});
-
- svdf.SetWeightsTime(
- {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
- 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
-
- 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
- -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
-
- -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
- 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
-
- -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
- -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
-
- -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
- 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
-
- -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
- 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
-
- -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
- -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
-
- 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
- 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
-
- svdf.SetBias({});
-
- svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input_rank2) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input_rank2 + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199,
+ 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494,
+ -0.10640851, 0.31053296, -0.36118156, -0.0976817, -0.36916667,
+ 0.22197971, 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime({-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ svdf.SetBias({});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input_rank2) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF op
+ // and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input_rank2 + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start = svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
}
} // namespace wrapper
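Aside: the SVDFOpModel fixture in the diff above generates its per-tensor storage and setters with X-macros (FOR_ALL_INPUT_AND_WEIGHT_TENSORS, FOR_ALL_OUTPUT_TENSORS, DefineSetter, DefineTensor). A minimal, self-contained sketch of that pattern, with a hypothetical FOR_ALL_TENSORS list standing in for the real tensor lists:

    #include <vector>

    // Stand-in for the real FOR_ALL_INPUT_AND_WEIGHT_TENSORS /
    // FOR_ALL_OUTPUT_TENSORS lists; the tensor names here are hypothetical.
    #define FOR_ALL_TENSORS(ACTION) \
        ACTION(Input)               \
        ACTION(WeightsFeature)      \
        ACTION(Output)

    class TensorHolder {
       public:
        // Expands to SetInput(), SetWeightsFeature(), SetOutput().
    #define DEFINE_SETTER(X) \
        void Set##X(const std::vector<float>& f) { X##_.assign(f.begin(), f.end()); }
        FOR_ALL_TENSORS(DEFINE_SETTER)
    #undef DEFINE_SETTER

       private:
        // Expands to Input_, WeightsFeature_, Output_.
    #define DEFINE_TENSOR(X) std::vector<float> X##_;
        FOR_ALL_TENSORS(DEFINE_TENSOR)
    #undef DEFINE_TENSOR
    };

    int main() {
        TensorHolder t;
        t.SetInput({1.f, 2.f, 3.f});  // generated by the macro expansion above
        return 0;
    }

Each use of the list macro stamps out one declaration per tensor name, which is why the fixture can add a tensor by editing a single list.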
diff --git a/nn/driver/cache/BlobCache/BlobCache.cpp b/nn/driver/cache/BlobCache/BlobCache.cpp
index 61130b9ad..e3274da7a 100644
--- a/nn/driver/cache/BlobCache/BlobCache.cpp
+++ b/nn/driver/cache/BlobCache/BlobCache.cpp
@@ -28,7 +28,7 @@
#include <algorithm>
static const char property_value[] = "[HOST]";
#define PROPERTY_VALUE_MAX (sizeof(property_value) - 1)
-static int property_get(const char *key, char *value, const char *default_value) {
+static int property_get(const char* key, char* value, const char* default_value) {
if (!strcmp(key, "ro.build.id")) {
memcpy(value, property_value, PROPERTY_VALUE_MAX);
return PROPERTY_VALUE_MAX;
@@ -57,14 +57,14 @@ static const uint32_t blobCacheVersion = 3;
// BlobCache::Header::mDeviceVersion value
static const uint32_t blobCacheDeviceVersion = 1;
-BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize, Policy policy):
- mMaxKeySize(maxKeySize),
- mMaxValueSize(maxValueSize),
- mMaxTotalSize(maxTotalSize),
- mPolicySelect(policy.first),
- mPolicyCapacity(policy.second),
- mTotalSize(0),
- mAccessCount(0) {
+BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize, Policy policy)
+ : mMaxKeySize(maxKeySize),
+ mMaxValueSize(maxValueSize),
+ mMaxTotalSize(maxTotalSize),
+ mPolicySelect(policy.first),
+ mPolicyCapacity(policy.second),
+ mTotalSize(0),
+ mAccessCount(0) {
int64_t now = std::chrono::steady_clock::now().time_since_epoch().count();
#ifdef _WIN32
srand(now);
@@ -76,21 +76,21 @@ BlobCache::BlobCache(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize
ALOGV("initializing random seed using %lld", (unsigned long long)now);
}
-void BlobCache::set(const void* key, size_t keySize, const void* value,
- size_t valueSize) {
+void BlobCache::set(const void* key, size_t keySize, const void* value, size_t valueSize) {
if (mMaxKeySize < keySize) {
- ALOGV("set: not caching because the key is too large: %zu (limit: %zu)",
- keySize, mMaxKeySize);
+ ALOGV("set: not caching because the key is too large: %zu (limit: %zu)", keySize,
+ mMaxKeySize);
return;
}
if (mMaxValueSize < valueSize) {
- ALOGV("set: not caching because the value is too large: %zu (limit: %zu)",
- valueSize, mMaxValueSize);
+ ALOGV("set: not caching because the value is too large: %zu (limit: %zu)", valueSize,
+ mMaxValueSize);
return;
}
if (mMaxTotalSize < keySize + valueSize) {
ALOGV("set: not caching because the combined key/value size is too "
- "large: %zu (limit: %zu)", keySize + valueSize, mMaxTotalSize);
+ "large: %zu (limit: %zu)",
+ keySize + valueSize, mMaxTotalSize);
return;
}
if (keySize == 0) {
@@ -127,16 +127,16 @@ void BlobCache::set(const void* key, size_t keySize, const void* value,
continue;
} else {
ALOGV("set: not caching new key/value pair because the "
- "total cache size limit would be exceeded: %zu "
- "(limit: %zu)",
- keySize + valueSize, mMaxTotalSize);
+ "total cache size limit would be exceeded: %zu "
+ "(limit: %zu)",
+ keySize + valueSize, mMaxTotalSize);
break;
}
}
mCacheEntries.insert(index, CacheEntry(keyBlob, valueBlob, ++mAccessCount));
mTotalSize = newTotalSize;
- ALOGV("set: created new cache entry with %zu byte key and %zu byte value",
- keySize, valueSize);
+ ALOGV("set: created new cache entry with %zu byte key and %zu byte value", keySize,
+ valueSize);
} else {
// Update the existing cache entry.
std::shared_ptr<Blob> valueBlob(new Blob(value, valueSize, true));
@@ -157,8 +157,8 @@ void BlobCache::set(const void* key, size_t keySize, const void* value,
continue;
} else {
ALOGV("set: not caching new value because the total cache "
- "size limit would be exceeded: %zu (limit: %zu)",
- keySize + valueSize, mMaxTotalSize);
+ "size limit would be exceeded: %zu (limit: %zu)",
+ keySize + valueSize, mMaxTotalSize);
break;
}
}
@@ -166,26 +166,25 @@ void BlobCache::set(const void* key, size_t keySize, const void* value,
index->setRecency(++mAccessCount);
mTotalSize = newTotalSize;
ALOGV("set: updated existing cache entry with %zu byte key and %zu byte "
- "value", keySize, valueSize);
+ "value",
+ keySize, valueSize);
}
break;
}
}
-size_t BlobCache::get(const void* key, size_t keySize, void* value,
- size_t valueSize) {
- void *dummy;
- return get(key, keySize, &dummy,
- [value, valueSize](size_t allocSize) {
- return (allocSize <= valueSize ? value : nullptr);
- });
+size_t BlobCache::get(const void* key, size_t keySize, void* value, size_t valueSize) {
+ void* dummy;
+ return get(key, keySize, &dummy, [value, valueSize](size_t allocSize) {
+ return (allocSize <= valueSize ? value : nullptr);
+ });
}
size_t BlobCache::get(const void* key, size_t keySize, void** value,
- std::function<void*(size_t)> alloc) {
+ std::function<void*(size_t)> alloc) {
if (mMaxKeySize < keySize) {
- ALOGV("get: not searching because the key is too large: %zu (limit %zu)",
- keySize, mMaxKeySize);
+ ALOGV("get: not searching because the key is too large: %zu (limit %zu)", keySize,
+ mMaxKeySize);
*value = nullptr;
return 0;
}
@@ -201,7 +200,7 @@ size_t BlobCache::get(const void* key, size_t keySize, void** value,
// The key was found. Return the value if we can allocate a buffer.
std::shared_ptr<Blob> valueBlob(index->getValue());
size_t valueBlobSize = valueBlob->getSize();
- void *buf = alloc(valueBlobSize);
+ void* buf = alloc(valueBlobSize);
if (buf != nullptr) {
ALOGV("get: copying %zu bytes to caller's buffer", valueBlobSize);
memcpy(buf, valueBlob->getData(), valueBlobSize);
@@ -220,7 +219,7 @@ static inline size_t align4(size_t size) {
size_t BlobCache::getFlattenedSize() const {
size_t size = align4(sizeof(Header) + PROPERTY_VALUE_MAX);
- for (const CacheEntry& e : mCacheEntries) {
+ for (const CacheEntry& e : mCacheEntries) {
std::shared_ptr<Blob> const& keyBlob = e.getKey();
std::shared_ptr<Blob> const& valueBlob = e.getValue();
size += align4(sizeof(EntryHeader) + keyBlob->getSize() + valueBlob->getSize());
@@ -246,7 +245,7 @@ int BlobCache::flatten(void* buffer, size_t size) const {
// Write cache entries
uint8_t* byteBuffer = reinterpret_cast<uint8_t*>(buffer);
off_t byteOffset = align4(sizeof(Header) + header->mBuildIdLength);
- for (const CacheEntry& e : mCacheEntries) {
+ for (const CacheEntry& e : mCacheEntries) {
std::shared_ptr<Blob> const& keyBlob = e.getKey();
std::shared_ptr<Blob> const& valueBlob = e.getValue();
size_t keySize = keyBlob->getSize();
@@ -295,9 +294,8 @@ int BlobCache::unflatten(void const* buffer, size_t size) {
char buildId[PROPERTY_VALUE_MAX];
int len = property_get("ro.build.id", buildId, "");
if (header->mBlobCacheVersion != blobCacheVersion ||
- header->mDeviceVersion != blobCacheDeviceVersion ||
- len != header->mBuildIdLength ||
- strncmp(buildId, header->mBuildId, len)) {
+ header->mDeviceVersion != blobCacheDeviceVersion || len != header->mBuildIdLength ||
+ strncmp(buildId, header->mBuildId, len)) {
// We treat version mismatches as an empty cache.
return 0;
}
@@ -313,8 +311,7 @@ int BlobCache::unflatten(void const* buffer, size_t size) {
return -EINVAL;
}
- const EntryHeader* eheader = reinterpret_cast<const EntryHeader*>(
- &byteBuffer[byteOffset]);
+ const EntryHeader* eheader = reinterpret_cast<const EntryHeader*>(&byteBuffer[byteOffset]);
size_t keySize = eheader->mKeySize;
size_t valueSize = eheader->mValueSize;
size_t entrySize = sizeof(EntryHeader) + keySize + valueSize;
@@ -349,9 +346,10 @@ size_t BlobCache::findVictim() {
return size_t(blob_random() % (mCacheEntries.size()));
case Select::LRU:
return std::min_element(mCacheEntries.begin(), mCacheEntries.end(),
- [](const CacheEntry &a, const CacheEntry &b) {
+ [](const CacheEntry& a, const CacheEntry& b) {
return a.getRecency() < b.getRecency();
- }) - mCacheEntries.begin();
+ }) -
+ mCacheEntries.begin();
default:
ALOGE("findVictim: unknown mPolicySelect: %d", mPolicySelect);
return 0;
@@ -360,9 +358,8 @@ size_t BlobCache::findVictim() {
size_t BlobCache::findDownTo(size_t newEntrySize, size_t onBehalfOf) {
auto oldEntrySize = [this, onBehalfOf]() -> size_t {
- if (onBehalfOf == NoEntry)
- return 0;
- const auto &entry = mCacheEntries[onBehalfOf];
+ if (onBehalfOf == NoEntry) return 0;
+ const auto& entry = mCacheEntries[onBehalfOf];
return entry.getKey()->getSize() + entry.getValue()->getSize();
};
switch (mPolicyCapacity) {
@@ -421,10 +418,8 @@ bool BlobCache::isCleanable() const {
}
}
-BlobCache::Blob::Blob(const void* data, size_t size, bool copyData) :
- mData(copyData ? malloc(size) : data),
- mSize(size),
- mOwnsData(copyData) {
+BlobCache::Blob::Blob(const void* data, size_t size, bool copyData)
+ : mData(copyData ? malloc(size) : data), mSize(size), mOwnsData(copyData) {
if (data != NULL && copyData) {
memcpy(const_cast<void*>(mData), data, size);
}
@@ -452,21 +447,14 @@ size_t BlobCache::Blob::getSize() const {
return mSize;
}
-BlobCache::CacheEntry::CacheEntry(): mRecency(0) {
-}
+BlobCache::CacheEntry::CacheEntry() : mRecency(0) {}
-BlobCache::CacheEntry::CacheEntry(
- const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value, uint32_t recency):
- mKey(key),
- mValue(value),
- mRecency(recency) {
-}
+BlobCache::CacheEntry::CacheEntry(const std::shared_ptr<Blob>& key,
+ const std::shared_ptr<Blob>& value, uint32_t recency)
+ : mKey(key), mValue(value), mRecency(recency) {}
-BlobCache::CacheEntry::CacheEntry(const CacheEntry& ce):
- mKey(ce.mKey),
- mValue(ce.mValue),
- mRecency(ce.mRecency) {
-}
+BlobCache::CacheEntry::CacheEntry(const CacheEntry& ce)
+ : mKey(ce.mKey), mValue(ce.mValue), mRecency(ce.mRecency) {}
bool BlobCache::CacheEntry::operator<(const CacheEntry& rhs) const {
return *mKey < *rhs.mKey;
@@ -499,4 +487,4 @@ void BlobCache::CacheEntry::setRecency(uint32_t recency) {
mRecency = recency;
}
-} // namespace android
+} // namespace android
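The Select::LRU branch of findVictim() reformatted above picks the entry with the smallest recency counter via std::min_element. A standalone sketch of the same idea, with a simplified Entry standing in for BlobCache::CacheEntry:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct Entry {
        uint32_t recency;  // last-access counter, like CacheEntry::getRecency()
    };

    // Index of the least-recently-used entry: the Select::LRU branch of
    // BlobCache::findVictim() above, minus the cache plumbing.
    size_t FindLruVictim(const std::vector<Entry>& entries) {
        return std::min_element(entries.begin(), entries.end(),
                                [](const Entry& a, const Entry& b) {
                                    return a.recency < b.recency;
                                }) -
               entries.begin();
    }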
diff --git a/nn/driver/cache/BlobCache/BlobCache.h b/nn/driver/cache/BlobCache/BlobCache.h
index d79a23de4..9caa6beec 100644
--- a/nn/driver/cache/BlobCache/BlobCache.h
+++ b/nn/driver/cache/BlobCache/BlobCache.h
@@ -34,7 +34,7 @@ namespace android {
// serialization is non-portable and the data should only be used by the device
// that generated it.
class BlobCache {
-public:
+ public:
enum class Select {
RANDOM, // evict random entries
LRU, // evict least-recently-used entries
@@ -86,8 +86,7 @@ public:
// 0 < keySize
// value != NULL
// 0 < valueSize
- void set(const void* key, size_t keySize, const void* value,
- size_t valueSize);
+ void set(const void* key, size_t keySize, const void* value, size_t valueSize);
// get retrieves from the cache the binary value associated with a given
// binary key. If the key is present in the cache then the length of the
@@ -131,7 +130,7 @@ public:
size_t get(const void* key, size_t keySize, void** value, std::function<void*(size_t)> alloc);
template <typename T>
size_t get(const void* key, size_t keySize, T** value, std::function<void*(size_t)> alloc) {
- void *valueVoid;
+ void* valueVoid;
const size_t size = get(key, keySize, &valueVoid, alloc);
*value = static_cast<T*>(valueVoid);
return size;
@@ -158,7 +157,7 @@ public:
//
int unflatten(void const* buffer, size_t size);
-private:
+ private:
// Copying is disallowed.
BlobCache(const BlobCache&);
void operator=(const BlobCache&);
@@ -204,7 +203,7 @@ private:
// A Blob is an immutable sized unstructured data blob.
class Blob {
- public:
+ public:
Blob(const void* data, size_t size, bool copyData);
~Blob();
@@ -213,7 +212,7 @@ private:
const void* getData() const;
size_t getSize() const;
- private:
+ private:
// Copying is not allowed.
Blob(const Blob&);
void operator=(const Blob&);
@@ -231,9 +230,10 @@ private:
// A CacheEntry is a single key/value pair in the cache.
class CacheEntry {
- public:
+ public:
CacheEntry();
- CacheEntry(const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value, uint32_t recency);
+ CacheEntry(const std::shared_ptr<Blob>& key, const std::shared_ptr<Blob>& value,
+ uint32_t recency);
CacheEntry(const CacheEntry& ce);
bool operator<(const CacheEntry& rhs) const;
@@ -247,8 +247,7 @@ private:
uint32_t getRecency() const;
void setRecency(uint32_t recency);
- private:
-
+ private:
// mKey is the key that identifies the cache entry.
std::shared_ptr<Blob> mKey;
@@ -349,6 +348,6 @@ private:
std::vector<CacheEntry> mCacheEntries;
};
-}
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_CACHE_BLOB_CACHE_BLOB_CACHE_H
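A hedged usage sketch for the allocator-callback get() overload declared in the header above: the callback receives the stored value's size and returns a destination buffer, or nullptr to decline, so the caller controls allocation. The cache instance and key are hypothetical; the calling pattern follows the declared signature.

    #include <cstdlib>

    #include "BlobCache.h"  // assumed available; declares android::BlobCache

    void ReadCachedValue(android::BlobCache* cache) {
        void* value = nullptr;
        // On a hit, get() calls the lambda with the stored value's size and
        // copies the value into whatever buffer the lambda returns.
        size_t size = cache->get("abcd", 4, &value,
                                 [](size_t n) -> void* { return malloc(n); });
        if (value != nullptr) {
            // `value` now holds `size` bytes owned by this caller.
            free(value);
        }
        (void)size;
    }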
diff --git a/nn/driver/cache/BlobCache/BlobCache_test.cpp b/nn/driver/cache/BlobCache/BlobCache_test.cpp
index 26b915cbb..2635fcc06 100644
--- a/nn/driver/cache/BlobCache/BlobCache_test.cpp
+++ b/nn/driver/cache/BlobCache/BlobCache_test.cpp
@@ -29,15 +29,12 @@
namespace android {
-template<typename T> using sp = std::shared_ptr<T>;
+template <typename T>
+using sp = std::shared_ptr<T>;
class BlobCacheTest : public ::testing::TestWithParam<BlobCache::Policy> {
-protected:
-
- enum {
- OK = 0,
- BAD_VALUE = -EINVAL
- };
+ protected:
+ enum { OK = 0, BAD_VALUE = -EINVAL };
enum {
MAX_KEY_SIZE = 6,
@@ -49,25 +46,25 @@ protected:
mBC.reset(new BlobCache(MAX_KEY_SIZE, MAX_VALUE_SIZE, MAX_TOTAL_SIZE, GetParam()));
}
- virtual void TearDown() {
- mBC.reset();
- }
+ virtual void TearDown() { mBC.reset(); }
std::unique_ptr<BlobCache> mBC;
};
-INSTANTIATE_TEST_CASE_P(Policy, BlobCacheTest,
- ::testing::Values(BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE),
+INSTANTIATE_TEST_CASE_P(
+ Policy, BlobCacheTest,
+ ::testing::Values(
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE),
- BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT),
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT),
- BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE)));
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE)));
TEST_P(BlobCacheTest, CacheSingleValueSucceeds) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4));
ASSERT_EQ('e', buf[0]);
@@ -77,7 +74,7 @@ TEST_P(BlobCacheTest, CacheSingleValueSucceeds) {
}
TEST_P(BlobCacheTest, CacheTwoValuesSucceeds) {
- unsigned char buf[2] = { 0xee, 0xee };
+ unsigned char buf[2] = {0xee, 0xee};
mBC->set("ab", 2, "cd", 2);
mBC->set("ef", 2, "gh", 2);
ASSERT_EQ(size_t(2), mBC->get("ab", 2, buf, 2));
@@ -89,7 +86,7 @@ TEST_P(BlobCacheTest, CacheTwoValuesSucceeds) {
}
TEST_P(BlobCacheTest, CacheTwoValuesMallocSucceeds) {
- unsigned char *bufPtr;
+ unsigned char* bufPtr;
mBC->set("ab", 2, "cd", 2);
mBC->set("ef", 2, "gh", 2);
@@ -109,9 +106,9 @@ TEST_P(BlobCacheTest, CacheTwoValuesMallocSucceeds) {
}
TEST_P(BlobCacheTest, GetOnlyWritesInsideBounds) {
- unsigned char buf[6] = { 0xee, 0xee, 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[6] = {0xee, 0xee, 0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
- ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf+1, 4));
+ ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf + 1, 4));
ASSERT_EQ(0xee, buf[0]);
ASSERT_EQ('e', buf[1]);
ASSERT_EQ('f', buf[2]);
@@ -121,7 +118,7 @@ TEST_P(BlobCacheTest, GetOnlyWritesInsideBounds) {
}
TEST_P(BlobCacheTest, GetOnlyWritesIfBufferIsLargeEnough) {
- unsigned char buf[3] = { 0xee, 0xee, 0xee };
+ unsigned char buf[3] = {0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 3));
ASSERT_EQ(0xee, buf[0]);
@@ -130,13 +127,13 @@ TEST_P(BlobCacheTest, GetOnlyWritesIfBufferIsLargeEnough) {
}
TEST_P(BlobCacheTest, GetWithFailedAllocator) {
- unsigned char buf[3] = { 0xee, 0xee, 0xee };
+ unsigned char buf[3] = {0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
// If allocator fails, verify that we set the value pointer to
// nullptr, and that we do not modify the buffer that the value
// pointer originally pointed to.
- unsigned char *bufPtr = &buf[0];
+ unsigned char* bufPtr = &buf[0];
ASSERT_EQ(size_t(4), mBC->get("abcd", 4, &bufPtr, [](size_t) -> void* { return nullptr; }));
ASSERT_EQ(nullptr, bufPtr);
ASSERT_EQ(0xee, buf[0]);
@@ -150,7 +147,7 @@ TEST_P(BlobCacheTest, GetDoesntAccessNullBuffer) {
}
TEST_P(BlobCacheTest, MultipleSetsCacheLatestValue) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
mBC->set("abcd", 4, "ijkl", 4);
ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4));
@@ -161,9 +158,9 @@ TEST_P(BlobCacheTest, MultipleSetsCacheLatestValue) {
}
TEST_P(BlobCacheTest, SecondSetKeepsFirstValueIfTooLarge) {
- unsigned char buf[MAX_VALUE_SIZE+1] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[MAX_VALUE_SIZE + 1] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
- mBC->set("abcd", 4, buf, MAX_VALUE_SIZE+1);
+ mBC->set("abcd", 4, buf, MAX_VALUE_SIZE + 1);
ASSERT_EQ(size_t(4), mBC->get("abcd", 4, buf, 4));
ASSERT_EQ('e', buf[0]);
ASSERT_EQ('f', buf[1]);
@@ -172,14 +169,14 @@ TEST_P(BlobCacheTest, SecondSetKeepsFirstValueIfTooLarge) {
}
TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) {
- char key[MAX_KEY_SIZE+1];
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
- for (int i = 0; i < MAX_KEY_SIZE+1; i++) {
+ char key[MAX_KEY_SIZE + 1];
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
+ for (int i = 0; i < MAX_KEY_SIZE + 1; i++) {
key[i] = 'a';
}
- mBC->set(key, MAX_KEY_SIZE+1, "bbbb", 4);
+ mBC->set(key, MAX_KEY_SIZE + 1, "bbbb", 4);
- ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE+1, buf, 4));
+ ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE + 1, buf, 4));
ASSERT_EQ(0xee, buf[0]);
ASSERT_EQ(0xee, buf[1]);
ASSERT_EQ(0xee, buf[2]);
@@ -188,12 +185,12 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) {
// If key is too large, verify that we do not call the allocator,
// that we set the value pointer to nullptr, and that we do not
// modify the buffer that the value pointer originally pointed to.
- unsigned char *bufPtr = &buf[0];
+ unsigned char* bufPtr = &buf[0];
bool calledAlloc = false;
- ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE+1, &bufPtr,
- [&calledAlloc](size_t) -> void* {
- calledAlloc = true;
- return nullptr; }));
+ ASSERT_EQ(size_t(0), mBC->get(key, MAX_KEY_SIZE + 1, &bufPtr, [&calledAlloc](size_t) -> void* {
+ calledAlloc = true;
+ return nullptr;
+ }));
ASSERT_EQ(false, calledAlloc);
ASSERT_EQ(nullptr, bufPtr);
ASSERT_EQ(0xee, buf[0]);
@@ -203,16 +200,16 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyIsTooBig) {
}
TEST_P(BlobCacheTest, DoesntCacheIfValueIsTooBig) {
- unsigned char buf[MAX_VALUE_SIZE+1];
- for (int i = 0; i < MAX_VALUE_SIZE+1; i++) {
+ unsigned char buf[MAX_VALUE_SIZE + 1];
+ for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) {
buf[i] = 'b';
}
- mBC->set("abcd", 4, buf, MAX_VALUE_SIZE+1);
- for (int i = 0; i < MAX_VALUE_SIZE+1; i++) {
+ mBC->set("abcd", 4, buf, MAX_VALUE_SIZE + 1);
+ for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) {
buf[i] = 0xee;
}
- ASSERT_EQ(size_t(0), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE+1));
- for (int i = 0; i < MAX_VALUE_SIZE+1; i++) {
+ ASSERT_EQ(size_t(0), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE + 1));
+ for (int i = 0; i < MAX_VALUE_SIZE + 1; i++) {
SCOPED_TRACE(i);
ASSERT_EQ(0xee, buf[i]);
}
@@ -240,7 +237,7 @@ TEST_P(BlobCacheTest, DoesntCacheIfKeyValuePairIsTooBig) {
TEST_P(BlobCacheTest, CacheMaxKeySizeSucceeds) {
char key[MAX_KEY_SIZE];
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
for (int i = 0; i < MAX_KEY_SIZE; i++) {
key[i] = 'a';
}
@@ -261,8 +258,7 @@ TEST_P(BlobCacheTest, CacheMaxValueSizeSucceeds) {
for (int i = 0; i < MAX_VALUE_SIZE; i++) {
buf[i] = 0xee;
}
- ASSERT_EQ(size_t(MAX_VALUE_SIZE), mBC->get("abcd", 4, buf,
- MAX_VALUE_SIZE));
+ ASSERT_EQ(size_t(MAX_VALUE_SIZE), mBC->get("abcd", 4, buf, MAX_VALUE_SIZE));
for (int i = 0; i < MAX_VALUE_SIZE; i++) {
SCOPED_TRACE(i);
ASSERT_EQ('b', buf[i]);
@@ -289,7 +285,7 @@ TEST_P(BlobCacheTest, CacheMaxKeyValuePairSizeSucceeds) {
}
TEST_P(BlobCacheTest, CacheMinKeyAndValueSizeSucceeds) {
- unsigned char buf[1] = { 0xee };
+ unsigned char buf[1] = {0xee};
mBC->set("x", 1, "y", 1);
ASSERT_EQ(size_t(1), mBC->get("x", 1, buf, 1));
ASSERT_EQ('y', buf[0]);
@@ -328,17 +324,16 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitHalvesCacheSize) {
// Count the number of entries in the cache; and check which
// entries they are.
int numCached = 0;
- for (int i = 0; i < maxEntries+1; i++) {
+ for (int i = 0; i < maxEntries + 1; i++) {
uint8_t k = i;
bool found = (mBC->get(&k, 1, NULL, 0) == 1);
- if (found)
- numCached++;
+ if (found) numCached++;
if (GetParam().first == BlobCache::Select::LRU) {
SCOPED_TRACE(i);
- ASSERT_EQ(found, i >= maxEntries/2);
+ ASSERT_EQ(found, i >= maxEntries / 2);
}
}
- ASSERT_EQ(maxEntries/2 + 1, numCached);
+ ASSERT_EQ(maxEntries / 2 + 1, numCached);
}
TEST_P(BlobCacheTest, ExceedingTotalLimitJustFitsSmallEntry) {
@@ -358,10 +353,9 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitJustFitsSmallEntry) {
}
// Count the number of entries in the cache.
int numCached = 0;
- for (int i = 0; i < maxEntries+1; i++) {
+ for (int i = 0; i < maxEntries + 1; i++) {
uint8_t k = i;
- if (mBC->get(&k, 1, NULL, 0) == 1)
- numCached++;
+ if (mBC->get(&k, 1, NULL, 0) == 1) numCached++;
}
ASSERT_EQ(maxEntries, numCached);
}
@@ -376,18 +370,17 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitFitsBigEntry) {
}
// Insert one more entry, causing a cache overflow.
const int bigValueSize = std::min((MAX_TOTAL_SIZE * 3) / 4 - 1, int(MAX_VALUE_SIZE));
- ASSERT_GT(bigValueSize+1, MAX_TOTAL_SIZE / 2); // Check testing assumption
+ ASSERT_GT(bigValueSize + 1, MAX_TOTAL_SIZE / 2); // Check testing assumption
{
unsigned char buf[MAX_VALUE_SIZE];
- for (int i = 0; i < bigValueSize; i++)
- buf[i] = 0xee;
+ for (int i = 0; i < bigValueSize; i++) buf[i] = 0xee;
uint8_t k = maxEntries;
mBC->set(&k, 1, buf, bigValueSize);
}
// Count the number and size of entries in the cache.
int numCached = 0;
size_t sizeCached = 0;
- for (int i = 0; i < maxEntries+1; i++) {
+ for (int i = 0; i < maxEntries + 1; i++) {
uint8_t k = i;
size_t size = mBC->get(&k, 1, NULL, 0);
if (size) {
@@ -399,8 +392,8 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitFitsBigEntry) {
case BlobCache::Capacity::HALVE:
// New value is too big for this cleaning algorithm. So
// we cleaned the cache, but did not insert the new value.
- ASSERT_EQ(maxEntries/2, numCached);
- ASSERT_EQ(size_t((maxEntries/2)*2), sizeCached);
+ ASSERT_EQ(maxEntries / 2, numCached);
+ ASSERT_EQ(size_t((maxEntries / 2) * 2), sizeCached);
break;
case BlobCache::Capacity::FIT:
case BlobCache::Capacity::FIT_HALVE: {
@@ -433,21 +426,20 @@ TEST_P(BlobCacheTest, FailedGetWithAllocator) {
// allocator, that we set the value pointer to nullptr, and that
// we do not modify the buffer that the value pointer originally
// pointed to.
- unsigned char buf[1] = { 0xee };
- unsigned char *bufPtr = &buf[0];
+ unsigned char buf[1] = {0xee};
+ unsigned char* bufPtr = &buf[0];
bool calledAlloc = false;
- ASSERT_EQ(size_t(0), mBC->get("a", 1, &bufPtr,
- [&calledAlloc](size_t) -> void* {
- calledAlloc = true;
- return nullptr; }));
+ ASSERT_EQ(size_t(0), mBC->get("a", 1, &bufPtr, [&calledAlloc](size_t) -> void* {
+ calledAlloc = true;
+ return nullptr;
+ }));
ASSERT_EQ(false, calledAlloc);
ASSERT_EQ(nullptr, bufPtr);
ASSERT_EQ(0xee, buf[0]);
}
TEST_P(BlobCacheTest, ExceedingTotalLimitRemovesLRUEntries) {
- if (GetParam().first != BlobCache::Select::LRU)
- return; // test doesn't apply for this policy
+ if (GetParam().first != BlobCache::Select::LRU) return; // test doesn't apply for this policy
// Fill up the entire cache with 1 char key/value pairs.
static const int maxEntries = MAX_TOTAL_SIZE / 2;
@@ -487,20 +479,19 @@ TEST_P(BlobCacheTest, ExceedingTotalLimitRemovesLRUEntries) {
for (int i = 0; i < maxEntries; i++) {
uint8_t k = accessSequence[i];
bool found = (mBC->get(&k, 1, NULL, 0) == 1);
- if (foundAny == found)
- continue;
+ if (foundAny == found) continue;
if (!foundAny) {
// found == true, so we just discovered j == i
foundAny = true;
} else {
// foundAny == true, found == false -- oops
- FAIL() << "found [" << i-1 << "]th entry but not [" << i << "]th entry";
+ FAIL() << "found [" << i - 1 << "]th entry but not [" << i << "]th entry";
}
}
}
class BlobCacheFlattenTest : public BlobCacheTest {
-protected:
+ protected:
virtual void SetUp() {
BlobCacheTest::SetUp();
mBC2.reset(new BlobCache(MAX_KEY_SIZE, MAX_VALUE_SIZE, MAX_TOTAL_SIZE, GetParam()));
@@ -522,18 +513,20 @@ protected:
sp<BlobCache> mBC2;
};
-INSTANTIATE_TEST_CASE_P(Policy, BlobCacheFlattenTest,
- ::testing::Values(BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE),
+INSTANTIATE_TEST_CASE_P(
+ Policy, BlobCacheFlattenTest,
+ ::testing::Values(
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::HALVE),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::HALVE),
- BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT),
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT),
- BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE),
- BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE)));
+ BlobCache::Policy(BlobCache::Select::RANDOM, BlobCache::Capacity::FIT_HALVE),
+ BlobCache::Policy(BlobCache::Select::LRU, BlobCache::Capacity::FIT_HALVE)));
TEST_P(BlobCacheFlattenTest, FlattenOneValue) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
roundTrip();
ASSERT_EQ(size_t(4), mBC2->get("abcd", 4, buf, 4));
@@ -601,7 +594,7 @@ TEST_P(BlobCacheFlattenTest, FlattenCatchesBufferTooSmall) {
}
TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadMagic) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
size_t size = mBC->getFlattenedSize();
@@ -618,7 +611,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadMagic) {
}
TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheVersion) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
size_t size = mBC->getFlattenedSize();
@@ -637,7 +630,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheVersion) {
}
TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheDeviceVersion) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
size_t size = mBC->getFlattenedSize();
@@ -656,7 +649,7 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBadBlobCacheDeviceVersion) {
}
TEST_P(BlobCacheFlattenTest, UnflattenCatchesBufferTooSmall) {
- unsigned char buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ unsigned char buf[4] = {0xee, 0xee, 0xee, 0xee};
mBC->set("abcd", 4, "efgh", 4);
size_t size = mBC->getFlattenedSize();
@@ -673,4 +666,4 @@ TEST_P(BlobCacheFlattenTest, UnflattenCatchesBufferTooSmall) {
ASSERT_EQ(size_t(0), mBC2->get("abcd", 4, buf, 4));
}
-} // namespace android
+} // namespace android
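The INSTANTIATE_TEST_CASE_P blocks reformatted above run every test once per (Select, Capacity) policy. A minimal value-parameterized test in the same style, with stand-in enums rather than the real BlobCache types:

    #include <gtest/gtest.h>

    #include <utility>

    enum class Select { RANDOM, LRU };
    enum class Capacity { HALVE, FIT };
    using Policy = std::pair<Select, Capacity>;

    // The fixture reads its policy through GetParam(), exactly as
    // BlobCacheTest reads the eviction policy above.
    class PolicyTest : public ::testing::TestWithParam<Policy> {};

    TEST_P(PolicyTest, PolicyIsWellFormed) {
        const Policy p = GetParam();
        EXPECT_TRUE(p.first == Select::RANDOM || p.first == Select::LRU);
    }

    // One instantiation per policy combination, mirroring the style above.
    INSTANTIATE_TEST_CASE_P(AllPolicies, PolicyTest,
                            ::testing::Values(Policy(Select::RANDOM, Capacity::HALVE),
                                              Policy(Select::LRU, Capacity::FIT)));

With gtest's usual gtest_main linked in, PolicyIsWellFormed runs once per listed policy.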
diff --git a/nn/driver/cache/nnCache/nnCache.cpp b/nn/driver/cache/nnCache/nnCache.cpp
index 702688bf4..9f9e9be8b 100644
--- a/nn/driver/cache/nnCache/nnCache.cpp
+++ b/nn/driver/cache/nnCache/nnCache.cpp
@@ -39,15 +39,15 @@ namespace android {
//
// NNCache definition
//
-NNCache::NNCache() :
- mInitialized(false),
- mMaxKeySize(0), mMaxValueSize(0), mMaxTotalSize(0),
- mPolicy(defaultPolicy()),
- mSavePending(false) {
-}
+NNCache::NNCache()
+ : mInitialized(false),
+ mMaxKeySize(0),
+ mMaxValueSize(0),
+ mMaxTotalSize(0),
+ mPolicy(defaultPolicy()),
+ mSavePending(false) {}
-NNCache::~NNCache() {
-}
+NNCache::~NNCache() {}
NNCache NNCache::sCache;
@@ -72,8 +72,7 @@ void NNCache::terminate() {
mInitialized = false;
}
-void NNCache::setBlob(const void* key, ssize_t keySize,
- const void* value, ssize_t valueSize) {
+void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) {
std::lock_guard<std::mutex> lock(mMutex);
if (keySize < 0 || valueSize < 0) {
@@ -100,8 +99,7 @@ void NNCache::setBlob(const void* key, ssize_t keySize,
}
}
-ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
- void* value, ssize_t valueSize) {
+ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) {
std::lock_guard<std::mutex> lock(mMutex);
if (keySize < 0 || valueSize < 0) {
@@ -116,8 +114,8 @@ ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
return 0;
}
-ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
- void** value, std::function<void*(size_t)> alloc) {
+ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value,
+ std::function<void*(size_t)> alloc) {
std::lock_guard<std::mutex> lock(mMutex);
if (keySize < 0) {
@@ -175,26 +173,23 @@ void NNCache::saveBlobCacheLocked() {
// The file exists, delete it and try again.
if (unlink(fname) == -1) {
// No point in retrying if the unlink failed.
- ALOGE("error unlinking cache file %s: %s (%d)", fname,
- strerror(errno), errno);
+ ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno);
return;
}
// Retry now that we've unlinked the file.
fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
}
if (fd == -1) {
- ALOGE("error creating cache file %s: %s (%d)", fname,
- strerror(errno), errno);
+ ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno);
return;
}
}
size_t fileSize = headerSize + cacheSize;
- uint8_t* buf = new uint8_t [fileSize];
+ uint8_t* buf = new uint8_t[fileSize];
if (!buf) {
- ALOGE("error allocating buffer for cache contents: %s (%d)",
- strerror(errno), errno);
+ ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno);
close(fd);
unlink(fname);
return;
@@ -202,9 +197,8 @@ void NNCache::saveBlobCacheLocked() {
int err = mBlobCache->flatten(buf + headerSize, cacheSize);
if (err < 0) {
- ALOGE("error writing cache contents: %s (%d)", strerror(-err),
- -err);
- delete [] buf;
+ ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err);
+ delete[] buf;
close(fd);
unlink(fname);
return;
@@ -216,15 +210,14 @@ void NNCache::saveBlobCacheLocked() {
*crc = crc32c(buf + headerSize, cacheSize);
if (write(fd, buf, fileSize) == -1) {
- ALOGE("error writing cache file: %s (%d)", strerror(errno),
- errno);
- delete [] buf;
+ ALOGE("error writing cache file: %s (%d)", strerror(errno), errno);
+ delete[] buf;
close(fd);
unlink(fname);
return;
}
- delete [] buf;
+ delete[] buf;
fchmod(fd, S_IRUSR);
close(fd);
}
@@ -237,8 +230,8 @@ void NNCache::loadBlobCacheLocked() {
int fd = open(mFilename.c_str(), O_RDONLY, 0);
if (fd == -1) {
if (errno != ENOENT) {
- ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(),
- strerror(errno), errno);
+ ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno),
+ errno);
}
return;
}
@@ -253,17 +246,15 @@ void NNCache::loadBlobCacheLocked() {
// Sanity check the size before trying to mmap it.
size_t fileSize = statBuf.st_size;
if (fileSize > mMaxTotalSize * 2) {
- ALOGE("cache file is too large: %#" PRIx64,
- static_cast<off64_t>(statBuf.st_size));
+ ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size));
close(fd);
return;
}
- uint8_t* buf = reinterpret_cast<uint8_t*>(mmap(NULL, fileSize,
- PROT_READ, MAP_PRIVATE, fd, 0));
+ uint8_t* buf =
+ reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0));
if (buf == MAP_FAILED) {
- ALOGE("error mmaping cache file: %s (%d)", strerror(errno),
- errno);
+ ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno);
close(fd);
return;
}
@@ -284,8 +275,7 @@ void NNCache::loadBlobCacheLocked() {
int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
if (err < 0) {
- ALOGE("error reading cache contents: %s (%d)", strerror(-err),
- -err);
+ ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err);
munmap(buf, fileSize);
close(fd);
return;
@@ -297,5 +287,5 @@ void NNCache::loadBlobCacheLocked() {
}
// ----------------------------------------------------------------------------
-}; // namespace android
+}; // namespace android
// ----------------------------------------------------------------------------
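saveBlobCacheLocked() above serializes into one [header | payload] buffer, stamps a checksum of the payload into the header, and writes the file in a single call, unlinking on failure. A reduced sketch of that layout, with an FNV-1a hash standing in for the crc32c() the real code uses:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // FNV-1a, illustration only; the real code computes crc32c() of the payload.
    static uint32_t SimpleChecksum(const uint8_t* data, size_t n) {
        uint32_t h = 2166136261u;
        for (size_t i = 0; i < n; ++i) h = (h ^ data[i]) * 16777619u;
        return h;
    }

    // Layout: [4-byte checksum | payload], written with one fwrite(), so a
    // corrupt or partially written file is detectable on reload.
    bool SaveCacheFile(const char* path, const std::vector<uint8_t>& payload) {
        const size_t headerSize = sizeof(uint32_t);
        std::vector<uint8_t> buf(headerSize + payload.size());
        std::copy(payload.begin(), payload.end(), buf.begin() + headerSize);
        const uint32_t crc = SimpleChecksum(buf.data() + headerSize, payload.size());
        std::copy(reinterpret_cast<const uint8_t*>(&crc),
                  reinterpret_cast<const uint8_t*>(&crc) + sizeof(crc), buf.begin());
        std::FILE* f = std::fopen(path, "wb");
        if (f == nullptr) return false;
        const bool ok = std::fwrite(buf.data(), 1, buf.size(), f) == buf.size();
        std::fclose(f);
        return ok;
    }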
diff --git a/nn/driver/cache/nnCache/nnCache.h b/nn/driver/cache/nnCache/nnCache.h
index 43f4ec350..a0ec6ee20 100644
--- a/nn/driver/cache/nnCache/nnCache.h
+++ b/nn/driver/cache/nnCache/nnCache.h
@@ -29,8 +29,7 @@ namespace android {
// ----------------------------------------------------------------------------
class NNCache {
-public:
-
+ public:
typedef BlobCache::Select Select;
typedef BlobCache::Capacity Capacity;
typedef BlobCache::Policy Policy;
@@ -59,19 +58,17 @@ public:
void terminate();
// setBlob attempts to insert a new key/value blob pair into the cache.
- void setBlob(const void* key, ssize_t keySize, const void* value,
- ssize_t valueSize);
+ void setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize);
// getBlob attempts to retrieve the value blob associated with a given key
// blob from cache.
- ssize_t getBlob(const void* key, ssize_t keySize,
- void* value, ssize_t valueSize);
- ssize_t getBlob(const void* key, ssize_t keySize,
- void** value, std::function<void*(size_t)> alloc);
+ ssize_t getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize);
+ ssize_t getBlob(const void* key, ssize_t keySize, void** value,
+ std::function<void*(size_t)> alloc);
template <typename T>
- ssize_t getBlob(const void* key, size_t keySize,
- T** value, std::function<void*(size_t)> alloc) {
- void *valueVoid;
+ ssize_t getBlob(const void* key, size_t keySize, T** value,
+ std::function<void*(size_t)> alloc) {
+ void* valueVoid;
const ssize_t size = getBlob(key, keySize, &valueVoid, alloc);
*value = static_cast<T*>(valueVoid);
return size;
@@ -81,7 +78,7 @@ public:
// cache contents from one program invocation to another.
void setCacheFilename(const char* filename);
-private:
+ private:
// Creation and (the lack of) destruction is handled internally.
NNCache();
~NNCache();
@@ -153,7 +150,7 @@ private:
};
// ----------------------------------------------------------------------------
-}; // namespace android
+}; // namespace android
// ----------------------------------------------------------------------------
#endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_CACHE_NN_CACHE_NN_CACHE_H
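NNCache, per the header above, is a process-wide singleton: the constructor and destructor are private, a static sCache instance lives inside the class, and callers reach it through NNCache::get(). A bare-bones sketch of that shape, with illustrative names:

    #include <mutex>
    #include <utility>

    class Cache {
       public:
        static Cache* get() { return &sCache; }  // sole access point, like NNCache::get()
        void put(int key, int value) {
            std::lock_guard<std::mutex> lock(mMutex);  // mMutex guards all state
            mLast = {key, value};
        }

       private:
        Cache() = default;   // creation and destruction handled internally
        ~Cache() = default;
        std::mutex mMutex;
        std::pair<int, int> mLast{0, 0};
        static Cache sCache;
    };

    Cache Cache::sCache;  // the one instance, mirroring NNCache::sCache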
diff --git a/nn/driver/cache/nnCache/nnCache_test.cpp b/nn/driver/cache/nnCache/nnCache_test.cpp
index 059160361..7ef2ccc8a 100644
--- a/nn/driver/cache/nnCache/nnCache_test.cpp
+++ b/nn/driver/cache/nnCache/nnCache_test.cpp
@@ -36,10 +36,8 @@ static const size_t maxTotalSize = 2 * 1024 * 1024;
namespace android {
class NNCacheTest : public ::testing::TestWithParam<NNCache::Policy> {
-protected:
- virtual void SetUp() {
- mCache = NNCache::get();
- }
+ protected:
+ virtual void SetUp() { mCache = NNCache::get(); }
virtual void TearDown() {
mCache->setCacheFilename("");
@@ -49,18 +47,19 @@ protected:
NNCache* mCache;
};
-INSTANTIATE_TEST_CASE_P(Policy, NNCacheTest,
- ::testing::Values(NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::HALVE),
- NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::HALVE),
+INSTANTIATE_TEST_CASE_P(
+ Policy, NNCacheTest,
+ ::testing::Values(NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::HALVE),
+ NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::HALVE),
- NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT),
- NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT),
+ NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT),
+ NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT),
- NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT_HALVE),
- NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT_HALVE)));
+ NNCache::Policy(NNCache::Select::RANDOM, NNCache::Capacity::FIT_HALVE),
+ NNCache::Policy(NNCache::Select::LRU, NNCache::Capacity::FIT_HALVE)));
TEST_P(NNCacheTest, UninitializedCacheAlwaysMisses) {
- uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee};
mCache->setBlob("abcd", 4, "efgh", 4);
ASSERT_EQ(0, mCache->getBlob("abcd", 4, buf, 4));
ASSERT_EQ(0xee, buf[0]);
@@ -70,7 +69,7 @@ TEST_P(NNCacheTest, UninitializedCacheAlwaysMisses) {
}
TEST_P(NNCacheTest, InitializedCacheAlwaysHits) {
- uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee};
mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam());
mCache->setBlob("abcd", 4, "efgh", 4);
ASSERT_EQ(4, mCache->getBlob("abcd", 4, buf, 4));
@@ -81,7 +80,7 @@ TEST_P(NNCacheTest, InitializedCacheAlwaysHits) {
}
TEST_P(NNCacheTest, TerminatedCacheAlwaysMisses) {
- uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee};
mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam());
mCache->setBlob("abcd", 4, "efgh", 4);
@@ -122,18 +121,17 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) {
}
// Insert one more entry, causing a cache overflow.
const int bigValueSize = std::min((MAX_TOTAL_SIZE * 3) / 4 - 1, int(MAX_VALUE_SIZE));
- ASSERT_GT(bigValueSize+1, MAX_TOTAL_SIZE / 2); // Check testing assumption
+ ASSERT_GT(bigValueSize + 1, MAX_TOTAL_SIZE / 2); // Check testing assumption
{
unsigned char buf[MAX_VALUE_SIZE];
- for (int i = 0; i < bigValueSize; i++)
- buf[i] = 0xee;
+ for (int i = 0; i < bigValueSize; i++) buf[i] = 0xee;
uint8_t k = maxEntries;
mCache->setBlob(&k, 1, buf, bigValueSize);
}
// Count the number and size of entries in the cache.
int numCached = 0;
size_t sizeCached = 0;
- for (int i = 0; i < maxEntries+1; i++) {
+ for (int i = 0; i < maxEntries + 1; i++) {
uint8_t k = i;
size_t size = mCache->getBlob(&k, 1, NULL, 0);
if (size) {
@@ -145,8 +143,8 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) {
case NNCache::Capacity::HALVE:
// New value is too big for this cleaning algorithm. So
// we cleaned the cache, but did not insert the new value.
- ASSERT_EQ(maxEntries/2, numCached);
- ASSERT_EQ(size_t((maxEntries/2)*2), sizeCached);
+ ASSERT_EQ(maxEntries / 2, numCached);
+ ASSERT_EQ(size_t((maxEntries / 2) * 2), sizeCached);
break;
case NNCache::Capacity::FIT:
case NNCache::Capacity::FIT_HALVE: {
@@ -175,9 +173,7 @@ TEST_P(NNCacheTest, ExceedingTotalLimitFitsBigEntry) {
}
class NNCacheSerializationTest : public NNCacheTest {
-
-protected:
-
+ protected:
virtual void SetUp() {
NNCacheTest::SetUp();
mTempFile.reset(new TemporaryFile());
@@ -190,7 +186,7 @@ protected:
std::unique_ptr<TemporaryFile> mTempFile;
- void yesStringBlob(const char *key, const char *value) {
+ void yesStringBlob(const char* key, const char* value) {
SCOPED_TRACE(key);
uint8_t buf[10];
@@ -206,7 +202,7 @@ protected:
}
}
- void noStringBlob(const char *key) {
+ void noStringBlob(const char* key) {
SCOPED_TRACE(key);
uint8_t buf[10];
@@ -219,11 +215,10 @@ protected:
ASSERT_EQ(0xee, buf[i]);
}
}
-
};
TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValues) {
- uint8_t buf[4] = { 0xee, 0xee, 0xee, 0xee };
+ uint8_t buf[4] = {0xee, 0xee, 0xee, 0xee};
mCache->setCacheFilename(&mTempFile->path[0]);
mCache->initialize(maxKeySize, maxValueSize, maxTotalSize, GetParam());
mCache->setBlob("abcd", 4, "efgh", 4);
@@ -235,7 +230,7 @@ TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValues) {
// - we do not modify the buffer that value pointer originally points to
// - the value pointer gets set to something other than nullptr
// - the newly-allocated buffer is set properly
- uint8_t *bufPtr = &buf[0];
+ uint8_t* bufPtr = &buf[0];
ASSERT_EQ(4, mCache->getBlob("abcd", 4, &bufPtr, malloc));
ASSERT_EQ(0xee, buf[0]);
ASSERT_EQ(0xee, buf[1]);
@@ -273,8 +268,8 @@ TEST_P(NNCacheSerializationTest, ReinitializedCacheContainsValuesSizeConstrained
SCOPED_TRACE("after second initialize()");
yesStringBlob("abcd", "efgh");
noStringBlob("abcdef"); // key too large
- noStringBlob("ab"); // value too large
+ noStringBlob("ab"); // value too large
}
}
-}
+} // namespace android
diff --git a/nn/driver/sample/SampleDriver.cpp b/nn/driver/sample/SampleDriver.cpp
index ab2c9db32..0cc2d256e 100644
--- a/nn/driver/sample/SampleDriver.cpp
+++ b/nn/driver/sample/SampleDriver.cpp
@@ -203,8 +203,7 @@ Return<ErrorStatus> SampleDriver::prepareModelFromCache(
}
Return<DeviceStatus> SampleDriver::getStatus() {
- NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED,
- "SampleDriver::getStatus");
+ NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
VLOG(DRIVER) << "getStatus()";
return DeviceStatus::AVAILABLE;
}
@@ -298,8 +297,7 @@ Return<ErrorStatus> executeBase(const Request& request, MeasureTiming measure, c
// is expected to live forever.
std::thread([&model, &driver, &poolInfos, request, measure, driverStart, callback] {
asyncExecute(request, measure, driverStart, model, driver, poolInfos, callback);
- })
- .detach();
+ }).detach();
return ErrorStatus::NONE;
}
@@ -469,6 +467,6 @@ Return<void> SamplePreparedModel::configureExecutionBurst(
return Void();
}
-} // namespace sample_driver
-} // namespace nn
-} // namespace android
+} // namespace sample_driver
+} // namespace nn
+} // namespace android
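The executeBase() hunk above must return to the HIDL caller immediately, so it runs the work on a detached thread and reports completion through the callback. A stripped-down sketch of that hand-off; the comments map the pieces onto the real driver:

    #include <functional>
    #include <thread>
    #include <utility>

    // Fire-and-forget: the thread owns copies of its work and callback, so
    // the launching function can return right away, as executeBase() must.
    void RunAsync(std::function<void()> work, std::function<void()> onDone) {
        std::thread([work = std::move(work), onDone = std::move(onDone)] {
            work();    // asyncExecute(...) in the real driver
            onDone();  // callback->notify(...) in the real driver
        }).detach();
    }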
diff --git a/nn/driver/sample/SampleDriverFloatFast.cpp b/nn/driver/sample/SampleDriverFloatFast.cpp
index 329408fa6..3611bbaeb 100644
--- a/nn/driver/sample/SampleDriverFloatFast.cpp
+++ b/nn/driver/sample/SampleDriverFloatFast.cpp
@@ -33,7 +33,7 @@ namespace sample_driver {
using namespace hal;
class SampleDriverFloatFast : public SampleDriver {
-public:
+ public:
SampleDriverFloatFast() : SampleDriver("sample-float-fast") {}
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
@@ -78,12 +78,12 @@ Return<void> SampleDriverFloatFast::getSupportedOperations_1_2(const V1_2::Model
return Void();
}
-} // namespace sample_driver
-} // namespace nn
-} // namespace android
+} // namespace sample_driver
+} // namespace nn
+} // namespace android
-using android::nn::sample_driver::SampleDriverFloatFast;
using android::sp;
+using android::nn::sample_driver::SampleDriverFloatFast;
int main() {
sp<SampleDriverFloatFast> driver(new SampleDriverFloatFast());
diff --git a/nn/driver/sample/SampleDriverFloatSlow.cpp b/nn/driver/sample/SampleDriverFloatSlow.cpp
index c5571eb55..af498379f 100644
--- a/nn/driver/sample/SampleDriverFloatSlow.cpp
+++ b/nn/driver/sample/SampleDriverFloatSlow.cpp
@@ -33,7 +33,7 @@ namespace sample_driver {
using namespace hal;
class SampleDriverFloatSlow : public SampleDriver {
-public:
+ public:
SampleDriverFloatSlow() : SampleDriver("sample-float-slow") {}
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
@@ -78,12 +78,12 @@ Return<void> SampleDriverFloatSlow::getSupportedOperations_1_2(const V1_2::Model
return Void();
}
-} // namespace sample_driver
-} // namespace nn
-} // namespace android
+} // namespace sample_driver
+} // namespace nn
+} // namespace android
-using android::nn::sample_driver::SampleDriverFloatSlow;
using android::sp;
+using android::nn::sample_driver::SampleDriverFloatSlow;
int main() {
sp<SampleDriverFloatSlow> driver(new SampleDriverFloatSlow());
diff --git a/nn/driver/sample/SampleDriverMinimal.cpp b/nn/driver/sample/SampleDriverMinimal.cpp
index 6ff596458..e0420348e 100644
--- a/nn/driver/sample/SampleDriverMinimal.cpp
+++ b/nn/driver/sample/SampleDriverMinimal.cpp
@@ -34,7 +34,7 @@ namespace sample_driver {
using namespace hal;
class SampleDriverMinimal : public SampleDriver {
-public:
+ public:
SampleDriverMinimal() : SampleDriver("sample-minimal") {}
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
@@ -90,12 +90,12 @@ Return<void> SampleDriverMinimal::getSupportedOperations_1_2(const V1_2::Model&
return Void();
}
-} // namespace sample_driver
-} // namespace nn
-} // namespace android
+} // namespace sample_driver
+} // namespace nn
+} // namespace android
-using android::nn::sample_driver::SampleDriverMinimal;
using android::sp;
+using android::nn::sample_driver::SampleDriverMinimal;
int main() {
sp<SampleDriverMinimal> driver(new SampleDriverMinimal());
diff --git a/nn/driver/sample/SampleDriverQuant.cpp b/nn/driver/sample/SampleDriverQuant.cpp
index 8e4c1854e..83fc550fe 100644
--- a/nn/driver/sample/SampleDriverQuant.cpp
+++ b/nn/driver/sample/SampleDriverQuant.cpp
@@ -33,7 +33,7 @@ namespace sample_driver {
using namespace hal;
class SampleDriverQuant : public SampleDriver {
-public:
+ public:
SampleDriverQuant() : SampleDriver("sample-quant") {}
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
@@ -74,12 +74,12 @@ Return<void> SampleDriverQuant::getSupportedOperations_1_2(const V1_2::Model& mo
return Void();
}
-} // namespace sample_driver
-} // namespace nn
-} // namespace android
+} // namespace sample_driver
+} // namespace nn
+} // namespace android
-using android::nn::sample_driver::SampleDriverQuant;
using android::sp;
+using android::nn::sample_driver::SampleDriverQuant;
int main() {
sp<SampleDriverQuant> driver(new SampleDriverQuant());
diff --git a/nn/runtime/CompilationBuilder.cpp b/nn/runtime/CompilationBuilder.cpp
index 6de44d137..912f0087b 100644
--- a/nn/runtime/CompilationBuilder.cpp
+++ b/nn/runtime/CompilationBuilder.cpp
@@ -90,8 +90,8 @@ int CompilationBuilder::finish() {
int CompilationBuilder::setPreference(int32_t preference) {
if (mFinished) {
- LOG(ERROR) <<
- "ANeuralNetworksCompilation_setPreference can't modify after compilation finished";
+ LOG(ERROR) << "ANeuralNetworksCompilation_setPreference can't modify after compilation "
+ "finished";
return ANEURALNETWORKS_BAD_STATE;
}
if (preference >= kNumberOfPreferences) {
@@ -121,8 +121,8 @@ int CompilationBuilder::setCaching(const std::string& cacheDir, const uint8_t* t
int CompilationBuilder::setPartitioning(uint32_t partitioning) {
if (mFinished) {
- LOG(ERROR) <<
- "ANeuralNetworksCompilation_setPartitioning can't modify after compilation finished";
+ LOG(ERROR) << "ANeuralNetworksCompilation_setPartitioning can't modify after compilation "
+ "finished";
return ANEURALNETWORKS_BAD_STATE;
}
@@ -130,7 +130,7 @@ int CompilationBuilder::setPartitioning(uint32_t partitioning) {
return ANEURALNETWORKS_NO_ERROR;
}
-int CompilationBuilder::createExecution(ExecutionBuilder **execution) {
+int CompilationBuilder::createExecution(ExecutionBuilder** execution) {
if (!mFinished) {
LOG(ERROR) << "ANeuralNetworksExecution_create passed an unfinished compilation";
*execution = nullptr;
diff --git a/nn/runtime/CompilationBuilder.h b/nn/runtime/CompilationBuilder.h
index 7c0b786f8..a47f99903 100644
--- a/nn/runtime/CompilationBuilder.h
+++ b/nn/runtime/CompilationBuilder.h
@@ -32,7 +32,7 @@ class ExecutionBuilder;
class ModelBuilder;
class CompilationBuilder {
-public:
+ public:
friend class ExecutionBuilder; // TODO remove this
// explicitDeviceList is true if the list of devices was provided explicitly
@@ -56,7 +56,7 @@ public:
const ExecutionPlan& forTest_getExecutionPlan() const { return mPlan; }
-private:
+ private:
const ModelBuilder* mModel;
ExecutionPlan mPlan;
@@ -88,7 +88,7 @@ private:
bool mIsCacheInfoProvided = false;
};
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_COMPILATION_BUILDER_H
diff --git a/nn/runtime/ExecutionBuilder.h b/nn/runtime/ExecutionBuilder.h
index 1c8b1d68c..a837238a1 100644
--- a/nn/runtime/ExecutionBuilder.h
+++ b/nn/runtime/ExecutionBuilder.h
@@ -44,7 +44,8 @@ class StepExecutor;
class ExecutionBuilder {
friend class StepExecutor;
-public:
+
+ public:
ExecutionBuilder(const CompilationBuilder* compilation);
int setInput(uint32_t index, const ANeuralNetworksOperandType* type, const void* buffer,
@@ -171,21 +172,18 @@ class StepExecutor {
mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex], &mOutputs[executorIndex]);
}
void mapOutputToInput(uint32_t builderIndex, uint32_t executorIndex) {
- mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex],
- &mInputs[executorIndex]);
+ mapInputOrOutput(mExecutionBuilder->mOutputs[builderIndex], &mInputs[executorIndex]);
}
// The input or output is assumed to have the size of the
// corresponding operand.
int setInputFromTemporaryMemory(uint32_t inputIndex, const Memory* memory, uint32_t offset) {
- return setInputOrOutputFromTemporaryMemory(mModel->getInputOperand(inputIndex),
- memory, offset,
- &mInputs.at(inputIndex));
+ return setInputOrOutputFromTemporaryMemory(mModel->getInputOperand(inputIndex), memory,
+ offset, &mInputs.at(inputIndex));
}
int setOutputFromTemporaryMemory(uint32_t outputIndex, const Memory* memory, uint32_t offset) {
- return setInputOrOutputFromTemporaryMemory(mModel->getOutputOperand(outputIndex),
- memory, offset,
- &mOutputs.at(outputIndex));
+ return setInputOrOutputFromTemporaryMemory(mModel->getOutputOperand(outputIndex), memory,
+ offset, &mOutputs.at(outputIndex));
}
// Executes using the (driver, preparedModel) specified at construction time.
@@ -238,7 +236,7 @@ class StepExecutor {
MemoryTracker mMemories;
};
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_BUILDER_H
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp
index 4e5409da5..866bfd738 100644
--- a/nn/runtime/ExecutionPlan.cpp
+++ b/nn/runtime/ExecutionPlan.cpp
@@ -211,7 +211,7 @@ int copyOperandExtraParams(ModelBuilder& model, uint32_t toOperandIndex,
// This class tracks whether we know the value of an operand as operations
// are processed.
class OperandTracker {
-public:
+ public:
// Creates the tracker for this model. Figure out which operations can be
// executed right away and cb for each one of them.
OperandTracker(const ModelBuilder* model, OperationReadyCallback cb);
@@ -220,14 +220,14 @@ public:
// able to run. Call cb for each one of them.
void markProcessed(uint32_t operationIndex, OperationReadyCallback cb);
-private:
+ private:
const ModelBuilder* mModel;
std::multimap<uint32_t, uint32_t> mOperandToOperations;
std::vector<uint32_t> mUnknownInputCount; // For each operation
};
-OperandTracker::OperandTracker(const ModelBuilder* model, OperationReadyCallback cb) :
- mModel(model) {
+OperandTracker::OperandTracker(const ModelBuilder* model, OperationReadyCallback cb)
+ : mModel(model) {
const auto& operations = mModel->getOperations();
mUnknownInputCount.resize(operations.size());
for (uint32_t operationIndex = 0; operationIndex < operations.size(); operationIndex++) {
@@ -319,9 +319,8 @@ int ExecutionStep::addOperand(uint32_t fromOperandIndex, uint32_t* toOperandInde
} break;
case OperandLifeTime::CONSTANT_REFERENCE: {
const Memory* memory = fromModel.getMemories()[operand.location.poolIndex];
- n = mSubModel.setOperandValueFromMemory(*toOperandIndex, memory,
- operand.location.offset,
- operand.location.length);
+ n = mSubModel.setOperandValueFromMemory(
+ *toOperandIndex, memory, operand.location.offset, operand.location.length);
if (n != ANEURALNETWORKS_NO_ERROR) {
LOG(ERROR) << "Previous error occurred when partitioning the graph";
return n;
@@ -355,7 +354,8 @@ int ExecutionStep::addOperand(uint32_t fromOperandIndex, uint32_t* toOperandInde
// The first time we've seen this operand is as an
// input. That means it must be defined by a
// different partition, and is an input to this one.
- mOutputsAsSubModelInputs.push_back(std::make_pair(fromOperandIndex, *toOperandIndex));
+ mOutputsAsSubModelInputs.push_back(
+ std::make_pair(fromOperandIndex, *toOperandIndex));
} else {
// The first time we've seen this operand is as an
// output.
@@ -396,8 +396,7 @@ int ExecutionStep::addOperation(int operationIndex, const ModelBuilder& fromMode
for (uint32_t i = 0; i < operandCount; i++) {
uint32_t localOperand = ~0U;
int n = addOperand(globalOperands[i], &localOperand, fromModel, kind);
- if (n != ANEURALNETWORKS_NO_ERROR)
- return n;
+ if (n != ANEURALNETWORKS_NO_ERROR) return n;
localOperands[i] = localOperand;
}
return ANEURALNETWORKS_NO_ERROR;
@@ -410,7 +409,7 @@ int ExecutionStep::addOperation(int operationIndex, const ModelBuilder& fromMode
}
return mSubModel.addOperation(static_cast<uint32_t>(operation.type), inputCount, inputs.data(),
- outputCount, outputs.data());
+ outputCount, outputs.data());
}
void ExecutionStep::mapInputsAndOutputs(std::shared_ptr<StepExecutor> stepExecutor) const {
@@ -438,7 +437,7 @@ void ExecutionPlan::CompoundBody::findTempsAsSubModelOutputs() {
void ExecutionStep::logSubModel() const {
VLOG(COMPILATION) << "ExecutionStep::finishSubModel, step " << mIndex;
- auto logRemapEntry = [](std::string &toLog, const std::pair<uint32_t, uint32_t>& e) {
+ auto logRemapEntry = [](std::string& toLog, const std::pair<uint32_t, uint32_t>& e) {
if (!toLog.empty()) {
toLog += ", ";
}
@@ -475,13 +474,13 @@ static void convertModelInputsOrOutputs(
// IN: mModel{Inputs|Outputs}
const ExecutionStep::RemapVectorType& myModelInputsOrOutputs,
// IN: fromModel->{input|output}Count()
- uint32_t fromModelInputOrOutputCount,
+ uint32_t fromModelInputOrOutputCount,
// IN: fromModel->get{Input|Output}OperandIndex
- std::function<uint32_t(uint32_t)> fromModelGetInputOrOutputOperandIndex,
+ std::function<uint32_t(uint32_t)> fromModelGetInputOrOutputOperandIndex,
// OUT: for v : mModel{Inputs|Outputs} : v.second
- std::vector<uint32_t>* inputsOrOutputs,
+ std::vector<uint32_t>* inputsOrOutputs,
// OUT: submodel input-or-output index to original model input-or-output index
- std::vector<uint32_t>* inputOrOutputIndexSubModelToFromModel) {
+ std::vector<uint32_t>* inputOrOutputIndexSubModelToFromModel) {
std::map<uint32_t, uint32_t> fromModelIndexMap; // operand index to input-or-output index
for (uint32_t i = 0; i < fromModelInputOrOutputCount; i++) {
fromModelIndexMap[fromModelGetInputOrOutputOperandIndex(i)] = i;
@@ -508,11 +507,10 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput
// ExecutionPlan::next() depends on these orderings.
std::vector<uint32_t> inputs;
- convertModelInputsOrOutputs(mModelInputs,
- fromModel->inputCount(),
- [=](uint32_t i) { return fromModel->getInputOperandIndex(i); },
- &inputs,
- &mInputIndexSubModelToFromModel);
+ convertModelInputsOrOutputs(
+ mModelInputs, fromModel->inputCount(),
+ [=](uint32_t i) { return fromModel->getInputOperandIndex(i); }, &inputs,
+ &mInputIndexSubModelToFromModel);
for (const auto& subModelInput : mTempsAsSubModelInputs) {
inputs.push_back(subModelInput.second);
}
@@ -521,11 +519,10 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput
}
std::vector<uint32_t> outputs;
- convertModelInputsOrOutputs(mModelOutputs,
- fromModel->outputCount(),
- [=](uint32_t i) { return fromModel->getOutputOperandIndex(i); },
- &outputs,
- &mOutputIndexSubModelToFromModel);
+ convertModelInputsOrOutputs(
+ mModelOutputs, fromModel->outputCount(),
+ [=](uint32_t i) { return fromModel->getOutputOperandIndex(i); }, &outputs,
+ &mOutputIndexSubModelToFromModel);
for (const auto& subModelOutput : mTempsAsSubModelOutputs) {
outputs.push_back(subModelOutput.second);
const Operand& operand = mSubModel.getOperand(subModelOutput.second);
@@ -546,7 +543,8 @@ int ExecutionStep::finishSubModel(const ModelBuilder* fromModel, bool* hasOutput
}
{
- int n = mSubModel.identifyInputsAndOutputs(inputs.size(), &inputs[0], outputs.size(), &outputs[0]);
+ int n = mSubModel.identifyInputsAndOutputs(inputs.size(), &inputs[0], outputs.size(),
+ &outputs[0]);
if (n != ANEURALNETWORKS_NO_ERROR) {
return n;
}
@@ -601,7 +599,8 @@ int ExecutionPlan::CompoundBody::finish(const ModelBuilder* fromModel,
}
}
if (mHasSubModelOutputOfUnknownSize) {
- VLOG(COMPILATION) << "ExecutionPlan::CompoundBody::finish -- mHasSubModelOutputOfUnknownSize";
+ VLOG(COMPILATION)
+ << "ExecutionPlan::CompoundBody::finish -- mHasSubModelOutputOfUnknownSize";
return ANEURALNETWORKS_OP_FAILED;
}
@@ -709,7 +708,7 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController(
if (mState == COMPOUND) {
const ModelBuilder* fromModel = executionBuilder->getModel();
for (const auto& step : compound()->mSteps) {
- for (const auto& output: step->getTempsAsSubModelOutputs()) {
+ for (const auto& output : step->getTempsAsSubModelOutputs()) {
const uint32_t fromModelOperandIndex = output.first;
const Operand& fromModelOperand = fromModel->getOperand(fromModelOperandIndex);
if (subModelInputsAndOutputs == nullptr) {
@@ -718,14 +717,14 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController(
}
const uint32_t size = TypeManager::get()->getSizeOfData(fromModelOperand);
totalSizeOfTemporaries += alignBytesNeeded(totalSizeOfTemporaries, size);
- subModelInputsAndOutputs->insert(std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries));
+ subModelInputsAndOutputs->insert(
+ std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries));
totalSizeOfTemporaries += size;
}
}
if (VLOG_IS_ON(EXECUTION) && (subModelInputsAndOutputs != nullptr)) {
for (const auto& io : *subModelInputsAndOutputs) {
- VLOG(EXECUTION) << "temp: origOpndIdx = " << io.first
- << ", offset = " << io.second;
+ VLOG(EXECUTION) << "temp: origOpndIdx = " << io.first << ", offset = " << io.second;
}
}
}
@@ -735,7 +734,6 @@ std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController(
totalSizeOfTemporaries));
}
-
// TODO: Find a better way to provide this functionality.
int ExecutionPlan::fallback(std::shared_ptr<Controller> controller,
std::shared_ptr<StepExecutor>* executor) const {
@@ -766,8 +764,7 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller,
*burstController = nullptr;
}
- VLOG(EXECUTION) << "ExecutionPlan::next("
- << SHOW_IF_DEBUG(controller << ", " << executor)
+ VLOG(EXECUTION) << "ExecutionPlan::next(" << SHOW_IF_DEBUG(controller << ", " << executor)
<< "): mNextStepIndex = " << controller->mNextStepIndex;
if (controller->mNextStepIndex == Controller::kBadStepIndex) {
@@ -832,11 +829,10 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller,
for (auto I = subModelOutputs.begin(), E = subModelOutputs.end(); I != E; I++, idx++) {
const uint32_t fromModelOperandIndex = I->first;
const uint32_t offsetOfTemporary =
- controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex);
- int n = (*executor)->setOutputFromTemporaryMemory(
- firstSubModelOutputIndex + idx,
- &controller->mTemporaries,
- offsetOfTemporary);
+ controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex);
+ int n = (*executor)->setOutputFromTemporaryMemory(firstSubModelOutputIndex + idx,
+ &controller->mTemporaries,
+ offsetOfTemporary);
if (n != ANEURALNETWORKS_NO_ERROR) {
controller->mNextStepIndex = Controller::kBadStepIndex;
return n;
@@ -853,11 +849,10 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller,
for (auto I = subModelInputs.begin(), E = subModelInputs.end(); I != E; I++, idx++) {
const uint32_t fromModelOperandIndex = I->first;
const uint32_t offsetOfTemporary =
- controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex);
- int n = (*executor)->setInputFromTemporaryMemory(
- firstSubModelInputIndex + idx,
- &controller->mTemporaries,
- offsetOfTemporary);
+ controller->mSubModelInputsAndOutputs->at(fromModelOperandIndex);
+ int n = (*executor)->setInputFromTemporaryMemory(firstSubModelInputIndex + idx,
+ &controller->mTemporaries,
+ offsetOfTemporary);
if (n != ANEURALNETWORKS_NO_ERROR) {
controller->mNextStepIndex = Controller::kBadStepIndex;
return n;
@@ -1057,7 +1052,7 @@ PerformanceInfo ModelBuilder::getPerformanceInfo(const std::shared_ptr<Device> d
// currently the case but is not a safe assumption to make in the long term.
const uint32_t operandIndex = operation.inputs[0];
const OperandType operandType = mOperands[operandIndex].type;
- switch(operandType) {
+ switch (operandType) {
case OperandType::FLOAT32:
if (mRelaxComputationFloat32toFloat16) {
return device->getRelaxedFloat32toFloat16PerformanceScalar();
@@ -1079,7 +1074,7 @@ namespace {
// This class determines whether a given device can execute a given operation
class CanDo {
-public:
+ public:
CanDo() {}
void initialize(const MetaModel& metaModel, std::shared_ptr<Device> device) {
@@ -1088,7 +1083,7 @@ public:
bool check(size_t operationIndex) const { return mSupportsOperationByIndex[operationIndex]; }
-private:
+ private:
hidl_vec<bool> mSupportsOperationByIndex;
};
@@ -1116,8 +1111,8 @@ int ModelBuilder::findBestDeviceForEachOperation(
if (canDo[deviceIndex].check(operationIndex)) {
const PerformanceInfo perf = getPerformanceInfo(device, operationIndex);
const float perfVal =
- (preference == ANEURALNETWORKS_PREFER_LOW_POWER ? perf.powerUsage
- : perf.execTime);
+ (preference == ANEURALNETWORKS_PREFER_LOW_POWER ? perf.powerUsage
+ : perf.execTime);
if (bestChoice < 0 || perfVal < bestPerfVal ||
(perfVal == bestPerfVal && device == DeviceManager::getCpuDevice())) {
bestChoice = deviceIndex;
@@ -1129,8 +1124,7 @@ int ModelBuilder::findBestDeviceForEachOperation(
// specific device.
// Logs O(operationCount * deviceCount) times, but
// typically deviceCount is very small.
- VLOG(COMPILATION) << "Device " << device->getName()
- << " can't do operation "
+ VLOG(COMPILATION) << "Device " << device->getName() << " can't do operation "
<< toString(getOperation(operationIndex).type);
}
}
@@ -1147,5 +1141,5 @@ int ModelBuilder::findBestDeviceForEachOperation(
return ANEURALNETWORKS_NO_ERROR;
}
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
diff --git a/nn/runtime/ExecutionPlan.h b/nn/runtime/ExecutionPlan.h
index be9b9aae0..6697c8293 100644
--- a/nn/runtime/ExecutionPlan.h
+++ b/nn/runtime/ExecutionPlan.h
@@ -51,7 +51,7 @@ class PreparedModel;
class StepExecutor;
class ExecutionStep {
-public:
+ public:
typedef std::vector<std::pair<uint32_t, uint32_t>> RemapVectorType;
typedef std::set<std::pair<uint32_t, uint32_t>> SubModelOutputSetType;
@@ -63,21 +63,13 @@ public:
const ModelBuilder& fromModel, OperandKind kind);
// Each container entry is of the form (fromModel index, subModel index)
- const RemapVectorType& getModelInputs() const {
- return mModelInputs;
- }
- const RemapVectorType& getModelOutputs() const {
- return mModelOutputs;
- }
- const RemapVectorType& getTempsAsSubModelInputs() const {
- return mTempsAsSubModelInputs;
- }
+ const RemapVectorType& getModelInputs() const { return mModelInputs; }
+ const RemapVectorType& getModelOutputs() const { return mModelOutputs; }
+ const RemapVectorType& getTempsAsSubModelInputs() const { return mTempsAsSubModelInputs; }
const SubModelOutputSetType& getTempsAsSubModelOutputs() const {
return mTempsAsSubModelOutputs;
}
- const RemapVectorType& getOutputsAsSubModelInputs() const {
- return mOutputsAsSubModelInputs;
- }
+ const RemapVectorType& getOutputsAsSubModelInputs() const { return mOutputsAsSubModelInputs; }
const std::vector<uint32_t>& getOutputIndexSubModelToFromModel() const {
return mOutputIndexSubModelToFromModel;
}
@@ -170,11 +162,11 @@ public:
};
class ExecutionPlan {
-public:
+ public:
ExecutionPlan(const ExecutionPlan&) = delete;
ExecutionPlan& operator=(const ExecutionPlan&) = delete;
- ExecutionPlan() { }
+ ExecutionPlan() {}
~ExecutionPlan() { delete mBody; }
// Controller is part of the interface to a mechanism for
@@ -190,7 +182,8 @@ public:
// a problem has occurred.
class Controller {
friend class ExecutionPlan;
- private:
+
+ private:
Controller(const Controller&) = delete;
Controller& operator=(const Controller&) = delete;
@@ -209,7 +202,8 @@ public:
const ExecutionPlan* mPlan;
ExecutionBuilder* mExecutionBuilder;
const BurstBuilder* mBurstBuilder;
- std::shared_ptr<const SubModelInputsAndOutputsType> mSubModelInputsAndOutputs; // may be nullptr
+ std::shared_ptr<const SubModelInputsAndOutputsType>
+ mSubModelInputsAndOutputs; // may be nullptr
Memory mTemporaries;
size_t mNextStepIndex;
};
@@ -223,7 +217,8 @@ public:
std::shared_ptr<ExecutionBurstController>* burstController = nullptr) const;
// Create the same executor as the last one created by next().
- int fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
+ int fallback(std::shared_ptr<Controller> controller,
+ std::shared_ptr<StepExecutor>* executor) const;
std::shared_ptr<ExecutionStep> createNewStep(const std::shared_ptr<Device> device);
@@ -310,7 +305,8 @@ public:
std::unordered_map<uint32_t, uint32_t> mTemporaryToDefiningStep;
bool mHasSubModelOutputOfUnknownSize = false;
- private:
+
+ private:
void findTempsAsSubModelOutputs();
};
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 2511cd343..7f91b2059 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -228,7 +228,7 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel,
}
uint32_t accumulator = baseAccumulator;
- const Operation &operation = hidlModel.operations[operationIndex];
+ const Operation& operation = hidlModel.operations[operationIndex];
accumulator ^= static_cast<uint32_t>(operation.type);
auto accumulateOperands = [&hidlModel, &accumulator](const hidl_vec<uint32_t>& operands) {
for (uint32_t operandIndex : operands) {
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index 01acdf391..2c2d6cf99 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -127,11 +127,7 @@ class DeviceManager {
// 1 - Do graph partitioning; but fall back to non-partitioned
// execution if there is a partitioning failure.
// 2 - Do graph partitioning, and rely on it; there is no fallback.
- enum {
- kPartitioningNo = 0,
- kPartitioningWithFallback = 1,
- kPartitioningWithoutFallback = 2
- };
+ enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 };
uint32_t getPartitioning() const { return mPartitioning; }
static bool partitioningAllowsFallback(uint32_t partitioning) {
return partitioning == kPartitioningWithFallback;
@@ -209,7 +205,7 @@ class DeviceManager {
bool mStrictSlicing = false;
};
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MANAGER_H
diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp
index a2e3de56c..20fabfa08 100644
--- a/nn/runtime/Memory.cpp
+++ b/nn/runtime/Memory.cpp
@@ -161,5 +161,5 @@ uint32_t MemoryTracker::add(const Memory* memory) {
return idx;
}
-} // namespace nn
-} // namespace android
+} // namespace nn
+} // namespace android
diff --git a/nn/runtime/ModelBuilder.h b/nn/runtime/ModelBuilder.h
index 6cf93b914..4bfcd04f9 100644
--- a/nn/runtime/ModelBuilder.h
+++ b/nn/runtime/ModelBuilder.h
@@ -170,7 +170,6 @@ class ModelBuilder {
// No further modifications are allowed to the model.
bool mInvalidModel = false;
-
// 'true' indicates TENSOR_FLOAT32 may be calculated with range and/or
// precision as low as that of the IEEE 754 16-bit floating-point format.
// 'false' indicates TENSOR_FLOAT32 must be calculated using at least the
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index 038068ee8..beaaa261e 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -76,7 +76,8 @@ class DeathHandler : public hidl_death_recipient {
[this, callback] { unregisterCallback(callback); });
}
- private : void registerCallback(const sp<ICallback>& callback) {
+ private:
+ void registerCallback(const sp<ICallback>& callback) {
std::lock_guard<std::mutex> hold(mMutex);
mCallbacks.push_back(callback);
}
diff --git a/nn/runtime/test/TestExecution.cpp b/nn/runtime/test/TestExecution.cpp
index 1a753c73f..af62ab46c 100644
--- a/nn/runtime/test/TestExecution.cpp
+++ b/nn/runtime/test/TestExecution.cpp
@@ -308,9 +308,9 @@ class TestDriver10 : public V1_0::IDevice {
// This class adds some simple utilities on top of WrapperCompilation in order
// to provide access to certain features from CompilationBuilder that are not
// exposed by the base class.
-template<typename DriverClass>
+template <typename DriverClass>
class TestCompilation : public WrapperCompilation {
-public:
+ public:
// Allow dummying up the error status for all executions from this
// compilation. If errorStatus is NONE, then execute behaves
// normally (and sends back the actual execution status).
@@ -433,20 +433,21 @@ class ExecutionTestTemplate
private:
static WrapperModel makeModel() {
- static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, { 1 });
+ static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});
WrapperModel model;
uint32_t input = model.addOperand(&tensorType);
uint32_t output = model.addOperand(&tensorType);
- model.addOperation(ANEURALNETWORKS_FLOOR, { input }, { output });
- model.identifyInputsAndOutputs({ input }, { output } );
+ model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
+ model.identifyInputsAndOutputs({input}, {output});
assert(model.finish() == Result::NO_ERROR);
return model;
}
};
-template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() {
+template <class DriverClass>
+void ExecutionTestTemplate<DriverClass>::TestWait() {
SCOPED_TRACE(kName);
// Skip Introspection API tests when CPU only flag is forced on.
if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
diff --git a/nn/runtime/test/TestMemory.cpp b/nn/runtime/test/TestMemory.cpp
index fd342dd7a..122bde2b6 100644
--- a/nn/runtime/test/TestMemory.cpp
+++ b/nn/runtime/test/TestMemory.cpp
@@ -35,9 +35,8 @@ namespace {
// Tests the various ways to pass weights and input/output data.
class MemoryTest : public ::testing::Test {
-protected:
+ protected:
void SetUp() override {}
-
};
TEST_F(MemoryTest, TestFd) {
diff --git a/nn/runtime/test/TestMemory.h b/nn/runtime/test/TestMemory.h
index 57df0fe68..3ce179026 100644
--- a/nn/runtime/test/TestMemory.h
+++ b/nn/runtime/test/TestMemory.h
@@ -43,9 +43,8 @@ const Matrix3x4 matrix1 = {{1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f}, {9.f, 10.
const Matrix3x4 matrix2 = {{100.f, 200.f, 300.f, 400.f},
{500.f, 600.f, 700.f, 800.f},
{900.f, 1000.f, 1100.f, 1200.f}};
-const Matrix3x4 matrix3 = {{20.f, 30.f, 40.f, 50.f},
- {21.f, 22.f, 23.f, 24.f},
- {31.f, 32.f, 33.f, 34.f}};
+const Matrix3x4 matrix3 = {
+ {20.f, 30.f, 40.f, 50.f}, {21.f, 22.f, 23.f, 24.f}, {31.f, 32.f, 33.f, 34.f}};
const Matrix3x4 expected3 = {{121.f, 232.f, 343.f, 454.f},
{526.f, 628.f, 730.f, 832.f},
{940.f, 1042.f, 1144.f, 1246.f}};
diff --git a/nn/runtime/test/TestMemoryInternal.cpp b/nn/runtime/test/TestMemoryInternal.cpp
index 755af09fd..8bba8253d 100644
--- a/nn/runtime/test/TestMemoryInternal.cpp
+++ b/nn/runtime/test/TestMemoryInternal.cpp
@@ -53,11 +53,11 @@ namespace {
// (We can also get very unlucky and mask a memory leak by unrelated unmapping
// somewhere else. This seems unlikely enough to not deal with.)
class MemoryLeakTest : public ::testing::Test {
-protected:
+ protected:
void SetUp() override;
void TearDown() override;
-private:
+ private:
size_t GetAshmemMappingsCount();
size_t mStartingMapCount = 0;
@@ -77,7 +77,7 @@ void MemoryLeakTest::TearDown() {
size_t MemoryLeakTest::GetAshmemMappingsCount() {
std::ifstream mappingsStream("/proc/self/maps");
- if (! mappingsStream.good()) {
+ if (!mappingsStream.good()) {
// errno is set by std::ifstream on Linux
ADD_FAILURE() << "Failed to open /proc/self/maps: " << std::strerror(errno);
return 0;
@@ -85,9 +85,9 @@ size_t MemoryLeakTest::GetAshmemMappingsCount() {
std::string line;
int mapCount = 0;
while (std::getline(mappingsStream, line)) {
- if (line.find("/dev/ashmem") != std::string::npos) {
- ++mapCount;
- }
+ if (line.find("/dev/ashmem") != std::string::npos) {
+ ++mapCount;
+ }
}
return mapCount;
}
@@ -106,8 +106,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) {
int weightsFd = ASharedMemory_create("weights", weightsSize);
ASSERT_GT(weightsFd, -1);
- uint8_t* weightsData = (uint8_t*)mmap(nullptr, weightsSize, PROT_READ | PROT_WRITE,
- MAP_SHARED, weightsFd, 0);
+ uint8_t* weightsData =
+ (uint8_t*)mmap(nullptr, weightsSize, PROT_READ | PROT_WRITE, MAP_SHARED, weightsFd, 0);
ASSERT_NE(weightsData, nullptr);
memcpy(weightsData + offsetForMatrix2, matrix2, sizeof(matrix2));
memcpy(weightsData + offsetForMatrix3, matrix3, sizeof(matrix3));
@@ -139,8 +139,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) {
constexpr size_t inputSize = offsetForMatrix1 + sizeof(Matrix3x4);
int inputFd = ASharedMemory_create("input", inputSize);
ASSERT_GT(inputFd, -1);
- uint8_t* inputData = (uint8_t*)mmap(nullptr, inputSize,
- PROT_READ | PROT_WRITE, MAP_SHARED, inputFd, 0);
+ uint8_t* inputData =
+ (uint8_t*)mmap(nullptr, inputSize, PROT_READ | PROT_WRITE, MAP_SHARED, inputFd, 0);
ASSERT_NE(inputData, nullptr);
memcpy(inputData + offsetForMatrix1, matrix1, sizeof(Matrix3x4));
WrapperMemory input(inputSize, PROT_READ, inputFd, 0);
@@ -150,8 +150,8 @@ TEST_F(MemoryLeakTest, TestASharedMemory) {
constexpr size_t outputSize = offsetForActual + sizeof(Matrix3x4);
int outputFd = ASharedMemory_create("output", outputSize);
ASSERT_GT(outputFd, -1);
- uint8_t* outputData = (uint8_t*)mmap(nullptr, outputSize,
- PROT_READ | PROT_WRITE, MAP_SHARED, outputFd, 0);
+ uint8_t* outputData =
+ (uint8_t*)mmap(nullptr, outputSize, PROT_READ | PROT_WRITE, MAP_SHARED, outputFd, 0);
ASSERT_NE(outputData, nullptr);
memset(outputData, 0, outputSize);
WrapperMemory actual(outputSize, PROT_READ | PROT_WRITE, outputFd, 0);
@@ -166,8 +166,9 @@ TEST_F(MemoryLeakTest, TestASharedMemory) {
ASSERT_EQ(execution2.setOutputFromMemory(0, &actual, offsetForActual, sizeof(Matrix3x4)),
WrapperResult::NO_ERROR);
ASSERT_EQ(execution2.compute(), WrapperResult::NO_ERROR);
- ASSERT_EQ(CompareMatrices(expected3,
- *reinterpret_cast<Matrix3x4*>(outputData + offsetForActual)), 0);
+ ASSERT_EQ(
+ CompareMatrices(expected3, *reinterpret_cast<Matrix3x4*>(outputData + offsetForActual)),
+ 0);
munmap(weightsData, weightsSize);
munmap(inputData, inputSize);
@@ -199,7 +200,7 @@ TEST_F(MemoryLeakTest, GetPointer) {
ASSERT_TRUE(mem.isValid());
auto internalMem = reinterpret_cast<::android::nn::Memory*>(mem.get());
- uint8_t *dummy;
+ uint8_t* dummy;
ASSERT_EQ(internalMem->getPointer(&dummy), ANEURALNETWORKS_NO_ERROR);
(*dummy)++;
}
@@ -219,7 +220,7 @@ TEST_F(MemoryLeakTest, Instantiate) {
ASSERT_TRUE(mem.isValid());
auto internalMem = reinterpret_cast<::android::nn::Memory*>(mem.get());
- uint8_t *dummy;
+ uint8_t* dummy;
ASSERT_EQ(internalMem->getPointer(&dummy), ANEURALNETWORKS_NO_ERROR);
close(fd);
@@ -260,7 +261,8 @@ TEST_F(MemoryLeakTest, convTooLarge) {
model.setOperandValue(act, act_init, sizeof(act_init));
int32_t stride_init[] = {1};
model.setOperandValue(stride, stride_init, sizeof(stride_init));
- model.addOperation(ANEURALNETWORKS_CONV_2D, {op1, op2, op3, pad0, pad0, pad0, pad0, stride, stride, act}, {op4});
+ model.addOperation(ANEURALNETWORKS_CONV_2D,
+ {op1, op2, op3, pad0, pad0, pad0, pad0, stride, stride, act}, {op4});
// Inputs and outputs
model.identifyInputsAndOutputs({op1}, {op4});
@@ -269,7 +271,7 @@ TEST_F(MemoryLeakTest, convTooLarge) {
// Compilation
WrapperCompilation compilation(&model);
- ASSERT_EQ(WrapperResult::NO_ERROR,compilation.finish());
+ ASSERT_EQ(WrapperResult::NO_ERROR, compilation.finish());
WrapperExecution execution(&compilation);
// Set input and outputs
@@ -283,6 +285,6 @@ TEST_F(MemoryLeakTest, convTooLarge) {
ASSERT_EQ(WrapperResult::OP_FAILED, r);
}
-#endif // NNTEST_ONLY_PUBLIC_API
+#endif // NNTEST_ONLY_PUBLIC_API
} // end namespace
diff --git a/nn/runtime/test/TestOpenmpSettings.cpp b/nn/runtime/test/TestOpenmpSettings.cpp
index a021f8db1..d1dd2f526 100644
--- a/nn/runtime/test/TestOpenmpSettings.cpp
+++ b/nn/runtime/test/TestOpenmpSettings.cpp
@@ -16,19 +16,19 @@
#include "CpuExecutor.h"
-#include <algorithm>
#include <gtest/gtest.h>
-#include <memory>
#include <omp.h>
+#include <unistd.h>
+#include <algorithm>
+#include <memory>
#include <random>
#include <thread>
-#include <unistd.h>
#include <vector>
namespace {
class OpenmpSettingsTest : public ::testing::Test {
-protected:
+ protected:
virtual void SetUp() override {
const int blocktimeInitial = kmp_get_blocktime();
ASSERT_EQ(blocktimeInitial, kOpenmpDefaultBlockTime);
@@ -84,9 +84,7 @@ TEST_F(OpenmpSettingsTest, TestThreaded) {
ASSERT_EQ(blocktimeSet2, 1);
}));
}
- std::for_each(threads.begin(), threads.end(), [](std::thread& t) {
- t.join();
- });
+ std::for_each(threads.begin(), threads.end(), [](std::thread& t) { t.join(); });
}
} // end namespace
diff --git a/nn/runtime/test/TestPartitioning.cpp b/nn/runtime/test/TestPartitioning.cpp
index d2bea0e1f..c4f792eb6 100644
--- a/nn/runtime/test/TestPartitioning.cpp
+++ b/nn/runtime/test/TestPartitioning.cpp
@@ -237,9 +237,7 @@ uint32_t lookupOperation(std::function<const Operation&(uint32_t)> getOperation,
(input2.lifetime == OperandLifeTime::CONSTANT_COPY)) {
int32_t value;
CHECK_EQ(sizeof(value), input2.location.length);
- memcpy(&value,
- getValue(input2.location.offset),
- input2.location.length);
+ memcpy(&value, getValue(input2.location.offset), input2.location.length);
return value + operationToFirstEncoding.at(operation.type);
}
break;
@@ -257,14 +255,9 @@ uint32_t lookupOperation(std::function<const Operation&(uint32_t)> getOperation,
uint32_t lookupOperation(const HidlModel& model, uint32_t operationIndex) {
return lookupOperation(
- [&model](uint32_t index) -> const Operation& {
- return model.operations[index];
- },
- [&model](uint32_t index) -> const Operand& {
- return model.operands[index];
- },
- [&model](uint32_t offset) {return &model.operandValues[offset];},
- operationIndex);
+ [&model](uint32_t index) -> const Operation& { return model.operations[index]; },
+ [&model](uint32_t index) -> const Operand& { return model.operands[index]; },
+ [&model](uint32_t offset) { return &model.operandValues[offset]; }, operationIndex);
}
#ifdef VERBOSE
@@ -287,32 +280,33 @@ void dump(const char* name, const ModelBuilder* model) {
// operation. The subset is represented with a bitmask, in which
// operation kind K corresponds to the bit (1 << K).
class PartitioningDriver : public SampleDriver {
-private:
+ private:
// Dummy class -- a prepared model must not be nullptr.
class PartitioningPreparedModel : public IPreparedModel {
- public:
- Return<ErrorStatus> execute(const Request&, const sp<V1_0::IExecutionCallback>&) override {
- return ErrorStatus::DEVICE_UNAVAILABLE;
- }
- Return<ErrorStatus> execute_1_2(const Request&, MeasureTiming,
- const sp<V1_2::IExecutionCallback>&) override {
- return ErrorStatus::DEVICE_UNAVAILABLE;
- }
- Return<void> executeSynchronously(const Request&, MeasureTiming,
- executeSynchronously_cb cb) override {
- cb(ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
- return Void();
- }
- Return<void> configureExecutionBurst(
- const sp<V1_2::IBurstCallback>& /*callback*/,
- const MQDescriptorSync<V1_2::FmqRequestDatum>& /*requestChannel*/,
- const MQDescriptorSync<V1_2::FmqResultDatum>& /*resultChannel*/,
- configureExecutionBurst_cb cb) override {
- cb(ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
- return Void();
- }
+ public:
+ Return<ErrorStatus> execute(const Request&, const sp<V1_0::IExecutionCallback>&) override {
+ return ErrorStatus::DEVICE_UNAVAILABLE;
+ }
+ Return<ErrorStatus> execute_1_2(const Request&, MeasureTiming,
+ const sp<V1_2::IExecutionCallback>&) override {
+ return ErrorStatus::DEVICE_UNAVAILABLE;
+ }
+ Return<void> executeSynchronously(const Request&, MeasureTiming,
+ executeSynchronously_cb cb) override {
+ cb(ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
+ return Void();
+ }
+ Return<void> configureExecutionBurst(
+ const sp<V1_2::IBurstCallback>& /*callback*/,
+ const MQDescriptorSync<V1_2::FmqRequestDatum>& /*requestChannel*/,
+ const MQDescriptorSync<V1_2::FmqResultDatum>& /*resultChannel*/,
+ configureExecutionBurst_cb cb) override {
+ cb(ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
+ return Void();
+ }
};
-public:
+
+ public:
enum OEM {
OEMNo, // rejected by getSupportedOperations and prepareModel
OEMIndecisive, // accepted by getSupportedOperations but not prepareModel
@@ -350,9 +344,7 @@ public:
return status;
}
- Return<DeviceStatus> getStatus() override {
- return DeviceStatus::AVAILABLE;
- }
+ Return<DeviceStatus> getStatus() override { return DeviceStatus::AVAILABLE; }
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override {
cb(ErrorStatus::NONE, mCapabilities);
@@ -567,15 +559,15 @@ class PartitioningModel : private WrapperModel {
uint32_t addOperationOEM1To1(const uint32_t input,
Dimensioned dimensionedOutput = Dimensioned::YES) {
uint32_t output = addOperandOfSameType(input, dimensionedOutput);
- addOperation(ANEURALNETWORKS_OEM_OPERATION, { input }, { output });
+ addOperation(ANEURALNETWORKS_OEM_OPERATION, {input}, {output});
return output;
}
// Run the partitioning algorithm to create an ExecutionPlan.
int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices,
ExecutePreference preference, ExecutionPlan* plan) {
- return reinterpret_cast<ModelBuilder*>(getHandle())->partitionTheWork(
- devices, static_cast<uint32_t>(preference), plan);
+ return reinterpret_cast<ModelBuilder*>(getHandle())
+ ->partitionTheWork(devices, static_cast<uint32_t>(preference), plan);
}
#ifdef VERBOSE
@@ -586,34 +578,34 @@ class PartitioningModel : private WrapperModel {
}
#endif
-private:
- // Create an operation with two inputs and one output, specifying
- // the operation kind and the input operand indexes.
- // Returns the output operand index.
- uint32_t addOperation2To1(uint32_t operation, const uint32_t input0, const uint32_t input1,
- Dimensioned dimensionedOutput = Dimensioned::YES) {
- auto it = firstEncodingToOperation.lower_bound(operation);
- CHECK(it != firstEncodingToOperation.end());
- ANeuralNetworksOperationType type = it->second.first;
- if (it->second.second) {
- int32_t fuseCode = operation - it->first;
- uint32_t input2 = addIntOperand(fuseCode);
- uint32_t output = addOperandOfSameType(input0, dimensionedOutput);
- addOperation(type, {input0, input1, input2}, {output});
- return output;
- } else {
- uint32_t output = addOperandOfSameType(input0, dimensionedOutput);
- addOperation(type, {input0, input1}, {output});
- return output;
- }
- }
-
- // Create a scalar integer operand of the specified value, and
- // return the corresponding operand index.
- uint32_t addIntOperand(int32_t value) {
- uint32_t operand = addOperand(WrapperType::INT32);
- setOperandValue(operand, &value, sizeof(value));
- return operand;
+ private:
+ // Create an operation with two inputs and one output, specifying
+ // the operation kind and the input operand indexes.
+ // Returns the output operand index.
+ uint32_t addOperation2To1(uint32_t operation, const uint32_t input0, const uint32_t input1,
+ Dimensioned dimensionedOutput = Dimensioned::YES) {
+ auto it = firstEncodingToOperation.lower_bound(operation);
+ CHECK(it != firstEncodingToOperation.end());
+ ANeuralNetworksOperationType type = it->second.first;
+ if (it->second.second) {
+ int32_t fuseCode = operation - it->first;
+ uint32_t input2 = addIntOperand(fuseCode);
+ uint32_t output = addOperandOfSameType(input0, dimensionedOutput);
+ addOperation(type, {input0, input1, input2}, {output});
+ return output;
+ } else {
+ uint32_t output = addOperandOfSameType(input0, dimensionedOutput);
+ addOperation(type, {input0, input1}, {output});
+ return output;
+ }
+ }
+
+ // Create a scalar integer operand of the specified value, and
+ // return the corresponding operand index.
+ uint32_t addIntOperand(int32_t value) {
+ uint32_t operand = addOperand(WrapperType::INT32);
+ setOperandValue(operand, &value, sizeof(value));
+ return operand;
}
// Create an operand of the same type as the specified operand,
@@ -633,30 +625,26 @@ private:
// This class adds some utilities on top of WrapperCompilation.
class PartitioningCompilation : public WrapperCompilation {
-public:
- PartitioningCompilation(const PartitioningModel* model,
- const std::vector<std::shared_ptr<Device>>& devices) {
- ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model->getHandle());
- CompilationBuilder* c = nullptr;
- int result = m->createCompilation(&c, devices);
- EXPECT_EQ(result, 0);
- mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
- }
-
- Result setPartitioning(uint32_t partitioning) {
- return static_cast<Result>(builder()->setPartitioning(partitioning));
+ public:
+ PartitioningCompilation(const PartitioningModel* model,
+ const std::vector<std::shared_ptr<Device>>& devices) {
+ ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model->getHandle());
+ CompilationBuilder* c = nullptr;
+ int result = m->createCompilation(&c, devices);
+ EXPECT_EQ(result, 0);
+ mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
+ }
+
+ Result setPartitioning(uint32_t partitioning) {
+ return static_cast<Result>(builder()->setPartitioning(partitioning));
}
using WrapperCompilation::finish;
- const ExecutionPlan& getExecutionPlan() const {
- return builder()->forTest_getExecutionPlan();
- }
+ const ExecutionPlan& getExecutionPlan() const { return builder()->forTest_getExecutionPlan(); }
-private:
- CompilationBuilder* builder() {
- return reinterpret_cast<CompilationBuilder*>(getHandle());
- }
+ private:
+ CompilationBuilder* builder() { return reinterpret_cast<CompilationBuilder*>(getHandle()); }
const CompilationBuilder* builder() const {
return reinterpret_cast<const CompilationBuilder*>(getHandle());
@@ -664,16 +652,14 @@ private:
};
#ifdef VERBOSE
-#define RETURN_TRUE() \
- { \
- std::cerr << "returning true from " << __LINE__ << std::endl; \
- return true; \
+#define RETURN_TRUE() \
+ { \
+ std::cerr << "returning true from " << __LINE__ << std::endl; \
+ return true; \
}
#else
-#define RETURN_TRUE() \
- { \
- return true; \
- }
+#define RETURN_TRUE() \
+ { return true; }
#endif
#ifdef VERBOSE
#define RETURN_FALSE(MESSAGE) \
@@ -682,19 +668,16 @@ private:
return false; \
}
#else
-#define RETURN_FALSE(MESSAGE) \
- { \
- return false; \
- }
+#define RETURN_FALSE(MESSAGE) \
+ { return false; }
#endif
class PartitioningTest : public ::testing::Test {
-protected:
+ protected:
using RemapVectorType = ExecutionStep::RemapVectorType;
using SubModelOutputSetType = ExecutionStep::SubModelOutputSetType;
- virtual void SetUp() {
- }
+ virtual void SetUp() {}
// From a vector of DeviceSpecification, create a vector of
// Devices.
@@ -841,21 +824,20 @@ protected:
// within its scope (actual operations, inputs, constants).
enum PseudoDefiningOperationEncodings : uint32_t {
- kPseudoDefiningOperationModelInput0 = 0x80000000U,
+ kPseudoDefiningOperationModelInput0 = 0x80000000U,
kPseudoDefiningOperationConstantCopy0 = 0x90000000U,
- kPseudoDefiningOperationNoValue = 0xeeeeeeeeU,
+ kPseudoDefiningOperationNoValue = 0xeeeeeeeeU,
// lowest value for special encoding
- kPseudoDefiningOperationBase = 0x80000000U,
+ kPseudoDefiningOperationBase = 0x80000000U,
// range of encoded input or constant
- kPseudoDefiningOperationRange = 0x10000000U,
+ kPseudoDefiningOperationRange = 0x10000000U,
};
// Build a map from operand to defining operation.
// TODO: Replace map with vector?
- void buildDefinitionMap(const ModelBuilder* model,
- std::map<uint32_t, uint32_t>* defMap) {
+ void buildDefinitionMap(const ModelBuilder* model, std::map<uint32_t, uint32_t>* defMap) {
// actual definitions
ASSERT_LT(model->operationCount(), kPseudoDefiningOperationBase);
for (uint32_t i = 0, e = model->operationCount(); i < e; i++) {
@@ -879,7 +861,8 @@ protected:
case OperandLifeTime::CONSTANT_COPY: {
ASSERT_EQ(operand.location.length, sizeof(uint32_t));
uint32_t value;
- memcpy(&value, model->getPointerToOperandValue(operand.location.offset), sizeof(uint32_t));
+ memcpy(&value, model->getPointerToOperandValue(operand.location.offset),
+ sizeof(uint32_t));
ASSERT_LT(value, kPseudoDefiningOperationNoValue);
(*defMap)[i] = kPseudoDefiningOperationConstantCopy0 + value;
break;
@@ -927,11 +910,9 @@ protected:
#endif
bool compare(const Operand& operandA, const Operand& operandB) {
- if (operandA.type != operandB.type ||
- operandA.dimensions != operandB.dimensions ||
+ if (operandA.type != operandB.type || operandA.dimensions != operandB.dimensions ||
operandA.numberOfConsumers != operandB.numberOfConsumers ||
- operandA.scale != operandB.scale ||
- operandA.zeroPoint != operandB.zeroPoint) {
+ operandA.scale != operandB.scale || operandA.zeroPoint != operandB.zeroPoint) {
return false;
}
return true;
@@ -980,10 +961,10 @@ protected:
::dump("compare(B)", modelB);
#endif
- if (modelA->operandCount() != modelB->operandCount() ||
+ if (modelA->operandCount() != modelB->operandCount() ||
modelA->operationCount() != modelB->operationCount() ||
- modelA->inputCount() != modelB->inputCount() ||
- modelA->outputCount() != modelB->outputCount()) {
+ modelA->inputCount() != modelB->inputCount() ||
+ modelA->outputCount() != modelB->outputCount()) {
RETURN_FALSE();
}
@@ -1093,8 +1074,7 @@ protected:
}
// Sanity check
- if (modelA->operandCount() != defsA.size() ||
- modelA->operandCount() != defsB.size() ||
+ if (modelA->operandCount() != defsA.size() || modelA->operandCount() != defsB.size() ||
modelA->operandCount() != equivalentOperandsAToB.size() ||
modelA->operationCount() + pseudoDefinitionCount != equivalentOperationsAToB.size()) {
RETURN_FALSE();
@@ -1180,7 +1160,7 @@ TEST_F(PartitioningTest, SimpleModel) {
uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1);
uint32_t opnd3 = model.addFloatOperand();
uint32_t opnd4 = model.addOperation2To1V1_0(1, opnd2, opnd3);
- model.identifyInputsAndOutputs({ opnd0, opnd1, opnd3 }, { opnd4 });
+ model.identifyInputsAndOutputs({opnd0, opnd1, opnd3}, {opnd4});
model.finish();
ASSERT_TRUE(model.isValid());
@@ -1222,7 +1202,7 @@ TEST_F(PartitioningTest, SimpleModel) {
uint32_t b0Opnd0 = modelB0.addFloatOperand();
uint32_t b0Opnd1 = modelB0.addFloatOperand();
uint32_t b0Opnd2 = modelB0.addOperation2To1V1_0(0, b0Opnd0, b0Opnd1);
- modelB0.identifyInputsAndOutputs({ b0Opnd0, b0Opnd1 }, { b0Opnd2 });
+ modelB0.identifyInputsAndOutputs({b0Opnd0, b0Opnd1}, {b0Opnd2});
modelB0.finish();
ASSERT_TRUE(modelB0.isValid());
@@ -1245,7 +1225,7 @@ TEST_F(PartitioningTest, SimpleModel) {
// an input; so in the submodel "modelB1", the corresponding
// input b1Opnd2 is a submodel input, and must follow the
// model input b1Opnd3.
- modelB1.identifyInputsAndOutputs({ b1Opnd3, b1Opnd2 }, { b1Opnd4 });
+ modelB1.identifyInputsAndOutputs({b1Opnd3, b1Opnd2}, {b1Opnd4});
modelB1.finish();
ASSERT_TRUE(modelB1.isValid());
@@ -1410,7 +1390,7 @@ TEST_F(PartitioningTest, Cpu) {
uint32_t opnd7 = model.addOperation2To1V1_0(kDevOp, opnd3, opnd5);
uint32_t opnd8 = model.addOperation2To1V1_0(kDevOp, opnd6, opnd7);
- model.identifyInputsAndOutputs({ opnd0, opnd1, opnd6 }, { opnd4, opnd8 });
+ model.identifyInputsAndOutputs({opnd0, opnd1, opnd6}, {opnd4, opnd8});
model.finish();
ASSERT_TRUE(model.isValid());
@@ -1429,7 +1409,7 @@ TEST_F(PartitioningTest, Cpu) {
uint32_t m0Opnd1 = model0.addFloatOperand();
uint32_t m0Opnd2 = model0.addOperation2To1V1_0(kDevOp, m0Opnd0, m0Opnd1);
uint32_t m0Opnd3 = model0.addOperation2To1V1_0(kDevOp, m0Opnd0, m0Opnd2);
- model0.identifyInputsAndOutputs({ m0Opnd0, m0Opnd1 }, { m0Opnd2, m0Opnd3 });
+ model0.identifyInputsAndOutputs({m0Opnd0, m0Opnd1}, {m0Opnd2, m0Opnd3});
model0.finish();
ASSERT_TRUE(model0.isValid());
@@ -1452,7 +1432,7 @@ TEST_F(PartitioningTest, Cpu) {
uint32_t m1Opnd4 = model1.addOperation2To1V1_0(kCpuOp, m1Opnd0, m1Opnd3);
uint32_t m1Opnd2 = model1.addFloatOperand();
uint32_t m1Opnd5 = model1.addOperation2To1V1_0(kCpuOp, m1Opnd2, m1Opnd4);
- model1.identifyInputsAndOutputs({ m1Opnd0, m1Opnd3, m1Opnd2 }, { m1Opnd4, m1Opnd5 });
+ model1.identifyInputsAndOutputs({m1Opnd0, m1Opnd3, m1Opnd2}, {m1Opnd4, m1Opnd5});
model1.finish();
ASSERT_TRUE(model1.isValid());
@@ -1474,7 +1454,7 @@ TEST_F(PartitioningTest, Cpu) {
uint32_t m2Opnd7 = model2.addOperation2To1V1_0(kDevOp, m2Opnd3, m2Opnd5);
uint32_t m2Opnd6 = model2.addFloatOperand();
uint32_t m2Opnd8 = model2.addOperation2To1V1_0(kDevOp, m2Opnd6, m2Opnd7);
- model2.identifyInputsAndOutputs({ m2Opnd6, m2Opnd3, m2Opnd5 }, { m2Opnd8 });
+ model2.identifyInputsAndOutputs({m2Opnd6, m2Opnd3, m2Opnd5}, {m2Opnd8});
model2.finish();
ASSERT_TRUE(model2.isValid());
@@ -1495,7 +1475,7 @@ TEST_F(PartitioningTest, SetPartitioning) {
model.addOperation2To1V1_0(0, opnd0, opnd1, PartitioningModel::Dimensioned::NO);
uint32_t opnd3 = model.addFloatOperand();
uint32_t opnd4 = model.addOperation2To1V1_0(1, opnd2, opnd3);
- model.identifyInputsAndOutputs({ opnd0, opnd1, opnd3 }, { opnd4 });
+ model.identifyInputsAndOutputs({opnd0, opnd1, opnd3}, {opnd4});
model.finish();
ASSERT_TRUE(model.isValid());
@@ -1524,7 +1504,8 @@ TEST_F(PartitioningTest, SetPartitioning) {
// No need to compare the original model to the model from the plan -- we
// didn't actually do any partitioning.
PartitioningCompilation cPWithFallback(&model, devices);
- ASSERT_EQ(cPWithFallback.setPartitioning(DeviceManager::kPartitioningWithFallback), Result::NO_ERROR);
+ ASSERT_EQ(cPWithFallback.setPartitioning(DeviceManager::kPartitioningWithFallback),
+ Result::NO_ERROR);
ASSERT_EQ(cPWithFallback.finish(), Result::NO_ERROR);
ASSERT_EQ(cPWithFallback.getExecutionPlan().forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
ASSERT_EQ(cPWithFallback.getExecutionPlan().forTest_simpleGetDevice(),
@@ -1533,21 +1514,23 @@ TEST_F(PartitioningTest, SetPartitioning) {
// Test kPartitioningWithoutFallback. We should attempt
// partitioning, and fail.
PartitioningCompilation cPWithoutFallback(&model, devices);
- ASSERT_EQ(cPWithoutFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback), Result::NO_ERROR);
+ ASSERT_EQ(cPWithoutFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback),
+ Result::NO_ERROR);
ASSERT_EQ(cPWithoutFallback.finish(), Result::OP_FAILED);
ASSERT_TRUE(cPWithoutFallback.getExecutionPlan().forTest_hasSubModelOutputsOfUnknownSize());
ASSERT_EQ(cPWithoutFallback.getExecutionPlan().forTest_getKind(), ExecutionPlan::Kind::ERROR);
}
// Regression test for http://b/69166603:
-// "partitioned compilation and execution yields wrong results when model output is submodel input"
+// "partitioned compilation and execution yields wrong results when model output is submodel
+// input"
TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) {
PartitioningModel model;
uint32_t opnd0 = model.addFloatOperand();
uint32_t opnd1 = model.addFloatOperand();
uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1);
uint32_t opnd3 = model.addOperation2To1V1_0(1, opnd2, opnd2);
- model.identifyInputsAndOutputs({ opnd0, opnd1 }, { opnd2, opnd3 });
+ model.identifyInputsAndOutputs({opnd0, opnd1}, {opnd2, opnd3});
model.finish();
ASSERT_TRUE(model.isValid());
@@ -1568,7 +1551,7 @@ TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) {
uint32_t m0Opnd0 = model0.addFloatOperand();
uint32_t m0Opnd1 = model0.addFloatOperand();
uint32_t m0Opnd2 = model0.addOperation2To1V1_0(0, m0Opnd0, m0Opnd1);
- model0.identifyInputsAndOutputs({ m0Opnd0, m0Opnd1 }, { m0Opnd2 });
+ model0.identifyInputsAndOutputs({m0Opnd0, m0Opnd1}, {m0Opnd2});
model0.finish();
ASSERT_TRUE(model0.isValid());
ASSERT_NO_FATAL_FAILURE(
@@ -1584,7 +1567,7 @@ TEST_F(PartitioningTest, ModelOutputAsSubmodelInput) {
PartitioningModel model1;
uint32_t m1Opnd2 = model1.addFloatOperand();
uint32_t m1Opnd3 = model1.addOperation2To1V1_0(1, m1Opnd2, m1Opnd2);
- model1.identifyInputsAndOutputs({ m1Opnd2 }, { m1Opnd3 });
+ model1.identifyInputsAndOutputs({m1Opnd2}, {m1Opnd3});
model1.finish();
ASSERT_TRUE(model1.isValid());
@@ -1602,7 +1585,7 @@ TEST_F(PartitioningTest, OemOperations) {
PartitioningModel model;
uint32_t opndIn = model.addFloatOperand();
uint32_t opndOut = model.addOperationOEM1To1(opndIn);
- model.identifyInputsAndOutputs({ opndIn }, { opndOut });
+ model.identifyInputsAndOutputs({opndIn}, {opndOut});
model.finish();
ASSERT_TRUE(model.isValid());
@@ -1649,7 +1632,7 @@ TEST_F(PartitioningTest, RelaxedFP) {
uint32_t opnd0 = model.addFloatOperand();
uint32_t opnd1 = model.addFloatOperand();
uint32_t opnd2 = model.addOperation2To1V1_0(0, opnd0, opnd1);
- model.identifyInputsAndOutputs({ opnd0, opnd1 }, { opnd2 });
+ model.identifyInputsAndOutputs({opnd0, opnd1}, {opnd2});
model.relaxComputationFloat32toFloat16(doRelax);
model.finish();
ASSERT_TRUE(model.isValid());
diff --git a/nn/runtime/test/TestPartitioningRandom.cpp b/nn/runtime/test/TestPartitioningRandom.cpp
index d62a6ad97..b7326e562 100644
--- a/nn/runtime/test/TestPartitioningRandom.cpp
+++ b/nn/runtime/test/TestPartitioningRandom.cpp
@@ -147,8 +147,7 @@ typedef std::pair<ANeuralNetworksOperationType, int> Signature;
// it provides access to certain features from ModelBuilder that are not exposed
// by the base class (such as inputCount() and operation index).
class TestModel : public WrapperModel {
-public:
-
+ public:
uint32_t addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
const std::vector<uint32_t>& outputs) {
const uint32_t operationIndex = operationCount();
@@ -157,16 +156,10 @@ public:
return operationIndex;
}
- uint32_t operationCount() const {
- return mOperations.size();
- }
+ uint32_t operationCount() const { return mOperations.size(); }
- uint32_t inputCount() const {
- return builder()->inputCount();
- }
- uint32_t outputCount() const {
- return builder()->outputCount();
- }
+ uint32_t inputCount() const { return builder()->inputCount(); }
+ uint32_t outputCount() const { return builder()->outputCount(); }
const std::vector<uint32_t>& getOperationOutputs(uint32_t index) const {
CHECK(index < mOperations.size());
@@ -198,8 +191,7 @@ public:
WrapperModel::setOperandValue(index, &value, sizeof(value));
}
-private:
-
+ private:
const ModelBuilder* builder() const {
return reinterpret_cast<const ModelBuilder*>(getHandle());
}
@@ -217,7 +209,7 @@ private:
// to provide access to certain features from CompilationBuilder that are not
// exposed by the base class.
class TestCompilation : public WrapperCompilation {
-public:
+ public:
TestCompilation(const WrapperModel* model) : WrapperCompilation(model) {}
TestCompilation(const WrapperModel* model, std::vector<std::shared_ptr<Device>> devices) {
@@ -234,17 +226,13 @@ public:
return static_cast<Result>(builder()->setPartitioning(partitioning));
}
- const ExecutionPlan& getExecutionPlan() const {
- return builder()->forTest_getExecutionPlan();
- }
+ const ExecutionPlan& getExecutionPlan() const { return builder()->forTest_getExecutionPlan(); }
-private:
+ private:
const CompilationBuilder* builder() const {
return reinterpret_cast<const CompilationBuilder*>(getHandle());
}
- CompilationBuilder* builder() {
- return reinterpret_cast<CompilationBuilder*>(getHandle());
- }
+ CompilationBuilder* builder() { return reinterpret_cast<CompilationBuilder*>(getHandle()); }
};
// This class is used to manage a collection of memory regions,
@@ -262,7 +250,7 @@ private:
// TestMemories instance, and are destroyed when the TestMemories
// instance is destroyed.
class TestMemories {
-public:
+ public:
TestMemories() = default;
~TestMemories();
@@ -274,9 +262,7 @@ public:
mMemorySizes.push_back(0);
return memoryCount() - 1;
}
- unsigned memoryCount() const {
- return mMemorySizes.size();
- }
+ unsigned memoryCount() const { return mMemorySizes.size(); }
unsigned addRegion(unsigned memoryIndex, uint32_t length) {
CHECK(!mLayoutDone);
@@ -287,14 +273,12 @@ public:
memorySize += length;
return regionCount() - 1;
}
- unsigned regionCount() const {
- return mRegions.size();
- }
+ unsigned regionCount() const { return mRegions.size(); }
void layout();
- void* getRegion(unsigned regionIndex,
- const WrapperMemory** pMemory, uint32_t* pOffset, uint32_t* pLength) {
+ void* getRegion(unsigned regionIndex, const WrapperMemory** pMemory, uint32_t* pOffset,
+ uint32_t* pLength) {
CHECK(mLayoutDone);
CHECK(regionIndex < regionCount());
const auto& regionDescriptor = mRegions[regionIndex];
@@ -319,7 +303,7 @@ public:
return getRegion(regionIndex, nullptr, nullptr, nullptr);
}
-private:
+ private:
// Index is the memory index; value is the size of the memory
// (aggregate size of all regions in the memory).
std::vector<uint32_t> mMemorySizes;
@@ -354,23 +338,23 @@ TestMemories::~TestMemories() {
}
class RandomPartitioningTest : public ::testing::TestWithParam<unsigned> {
-public:
+ public:
RandomPartitioningTest() : mRandNumEng(GetParam() /* seed */), mRandNumUnitDist(0.0, 1.0) {}
static Signature getSignature(const HidlModel& model, const Operation& operation);
-protected:
- static V1_0::IDevice* makeTestDriver(HalVersion version, const char* name,
- std::set<Signature> signatures);
+ protected:
+ static V1_0::IDevice* makeTestDriver(HalVersion version, const char* name,
+ std::set<Signature> signatures);
- static HalVersion getMinHalVersion(ANeuralNetworksOperationType type);
+ static HalVersion getMinHalVersion(ANeuralNetworksOperationType type);
- static std::string to_string(HalVersion version);
+ static std::string to_string(HalVersion version);
- bool randBool() { return randUInt(2) == 1; }
+ bool randBool() { return randUInt(2) == 1; }
- double randFrac() { // [0.0, 1.0)
- return mRandNumUnitDist(mRandNumEng);
+ double randFrac() { // [0.0, 1.0)
+ return mRandNumUnitDist(mRandNumEng);
}
unsigned randUInt(unsigned limit) { // [0, limit)
@@ -412,11 +396,10 @@ protected:
}
// input operand 3 is bias, a 1-D tensor
- const WrapperOperandType biasType(WrapperType::TENSOR_FLOAT32, { problemSize });
+ const WrapperOperandType biasType(WrapperType::TENSOR_FLOAT32, {problemSize});
const uint32_t operandIndex = model->addOperand(&biasType);
std::vector<float> biasValue(problemSize);
- std::generate(biasValue.begin(), biasValue.end(),
- [this]{ return randFrac(); });
+ std::generate(biasValue.begin(), biasValue.end(), [this] { return randFrac(); });
model->setOperandValue(operandIndex, biasValue);
return int(operandIndex);
}
@@ -440,24 +423,23 @@ protected:
#ifdef VERBOSE
class ModelStats {
- public:
- ModelStats(const ModelBuilder* model) :
- mBuilder(model) { }
- ModelStats(const WrapperModel* model) :
- mBuilder(reinterpret_cast<const ModelBuilder*>(model->getHandle())) { }
+ public:
+ ModelStats(const ModelBuilder* model) : mBuilder(model) {}
+ ModelStats(const WrapperModel* model)
+ : mBuilder(reinterpret_cast<const ModelBuilder*>(model->getHandle())) {}
friend std::ostream& operator<<(std::ostream& out, const ModelStats& stats) {
const uint32_t operandCount = stats.mBuilder->operandCount();
const uint32_t inputCount = stats.mBuilder->inputCount();
const uint32_t outputCount = stats.mBuilder->outputCount();
out << "operationCount = " << stats.mBuilder->operationCount()
- << ", operandCount = " << operandCount
- << ", inputCount = " << inputCount
- << " (" << (double(inputCount) / operandCount) << ")"
- << ", outputCount = " << outputCount
- << " (" << (double(outputCount) / operandCount) << ")";
+ << ", operandCount = " << operandCount << ", inputCount = " << inputCount << " ("
+ << (double(inputCount) / operandCount) << ")"
+ << ", outputCount = " << outputCount << " (" << (double(outputCount) / operandCount)
+ << ")";
return out;
}
- private:
+
+ private:
const ModelBuilder* mBuilder;
};
@@ -526,9 +508,7 @@ Signature RandomPartitioningTest::getSignature(const HidlModel& model, const Ope
CHECK(operand.lifetime == OperandLifeTime::CONSTANT_COPY);
CHECK(operand.type == OperandType::INT32);
int32_t value;
- memcpy(&value,
- &model.operandValues[operand.location.offset],
- operand.location.length);
+ memcpy(&value, &model.operandValues[operand.location.offset], operand.location.length);
return Signature(operationType, value);
}
@@ -546,11 +526,11 @@ std::string RandomPartitioningTest::to_string(HalVersion version) {
};
class TestDriver : public SampleDriver {
-public:
+ public:
// Behaves like SampleDriver, except that it only supports
// operations with the specified signatures.
- TestDriver(const char* name, std::set<Signature> signatures) :
- SampleDriver(name), mSignatures(std::move(signatures)) { }
+ TestDriver(const char* name, std::set<Signature> signatures)
+ : SampleDriver(name), mSignatures(std::move(signatures)) {}
Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
android::nn::initVLogMask();
@@ -569,11 +549,8 @@ public:
const size_t count = model.operations.size();
std::vector<bool> supported(count);
for (size_t i = 0; i < count; i++) {
- supported[i] =
- (mSignatures.count(
- RandomPartitioningTest::getSignature(
- model,
- model.operations[i])) != 0);
+ supported[i] = (mSignatures.count(RandomPartitioningTest::getSignature(
+ model, model.operations[i])) != 0);
}
cb(ErrorStatus::NONE, supported);
} else {
@@ -591,15 +568,15 @@ public:
// NOTE: We verify that all operations in the model are supported.
ErrorStatus outStatus = ErrorStatus::INVALID_ARGUMENT;
auto ret = getSupportedOperations_1_2(
- model,
- [&outStatus](ErrorStatus inStatus, const hidl_vec<bool>& supportedOperations) {
- if (inStatus == ErrorStatus::NONE) {
- if (std::all_of(supportedOperations.begin(), supportedOperations.end(),
- [](bool v){ return v; })) {
- outStatus = ErrorStatus::NONE;
+ model,
+ [&outStatus](ErrorStatus inStatus, const hidl_vec<bool>& supportedOperations) {
+ if (inStatus == ErrorStatus::NONE) {
+ if (std::all_of(supportedOperations.begin(), supportedOperations.end(),
+ [](bool v) { return v; })) {
+ outStatus = ErrorStatus::NONE;
+ }
}
- }
- });
+ });
if (ret.isOk() && (outStatus == ErrorStatus::NONE)) {
return SampleDriver::prepareModel_1_2(model, preference, modelCache, dataCache, token,
callback);
@@ -609,7 +586,7 @@ public:
}
}
-private:
+ private:
const std::set<Signature> mSignatures;
};
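A sketch of how such a restricted driver might be built through the makeTestDriver() helper declared earlier; HalVersion::V1_2 is assumed here as an enumerator and is not shown in this diff:

    std::set<Signature> signatures = {
            RandomPartitioningTest::getSignature(model, model.operations[0])};
    V1_0::IDevice* driver = RandomPartitioningTest::makeTestDriver(
            HalVersion::V1_2 /* assumed */, "test-driver", std::move(signatures));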
@@ -696,13 +673,13 @@ TEST_P(RandomPartitioningTest, Test) {
std::cout << std::setprecision(2) << std::fixed << std::setw(4);
#endif
- const unsigned problemSize = 1+randUInt(kMaxProblemSize);
- const WrapperOperandType problemType(WrapperType::TENSOR_FLOAT32, { problemSize, problemSize });
- const WrapperOperandType unknownDimensionsType(WrapperType::TENSOR_FLOAT32, { 0, 0 });
+ const unsigned problemSize = 1 + randUInt(kMaxProblemSize);
+ const WrapperOperandType problemType(WrapperType::TENSOR_FLOAT32, {problemSize, problemSize});
+ const WrapperOperandType unknownDimensionsType(WrapperType::TENSOR_FLOAT32, {0, 0});
- static const WrapperOperandType activationFunctionType(WrapperType::INT32, { });
+ static const WrapperOperandType activationFunctionType(WrapperType::INT32, {});
- const unsigned numOperations = 2+randUInt(kMaxNumOperations-1);
+ const unsigned numOperations = 2 + randUInt(kMaxNumOperations - 1);
const bool allowDeadOperations = (randFrac() < 0.2);
const bool allowUnknownDimensions = (randFrac() < 0.25);
@@ -783,7 +760,7 @@ TEST_P(RandomPartitioningTest, Test) {
}
if (operationPattern.mMakeSpecialInput != nullptr) {
const int operandIndex = (this->*(operationPattern.mMakeSpecialInput))(
- problemSize, &model, operationInputIndex);
+ problemSize, &model, operationInputIndex);
if (operandIndex >= 0) {
operationInputs[operationInputIndex] = operandIndex;
continue;
@@ -811,48 +788,46 @@ TEST_P(RandomPartitioningTest, Test) {
// decision later.
enum InputKind { IK_MODEL_INPUT, IK_OPERATION_OUTPUT, IK_VALUE };
std::vector<InputKind> normalOperationInputKinds(normalOperationInputCount);
- std::generate(normalOperationInputKinds.begin(), normalOperationInputKinds.end(),
- [this, &model,
- numOperations,
- normalOperationInputCount,
- &normalOperationInputConstantCount,
- &normalOperationInputModelInputCount]() -> InputKind {
- // Constant? Becomes less likely the more
- // constants we already have as inputs to
- // this operation.
- if (randFrac() < 0.3 * (1 - double(normalOperationInputConstantCount) /
- normalOperationInputCount)) {
- normalOperationInputConstantCount++;
- return IK_VALUE;
- }
+ std::generate(
+ normalOperationInputKinds.begin(), normalOperationInputKinds.end(),
+ [this, &model, numOperations, normalOperationInputCount,
+ &normalOperationInputConstantCount,
+ &normalOperationInputModelInputCount]() -> InputKind {
+ // Constant? Becomes less likely the more
+ // constants we already have as inputs to
+ // this operation.
+ if (randFrac() < 0.3 * (1 - double(normalOperationInputConstantCount) /
+ normalOperationInputCount)) {
+ normalOperationInputConstantCount++;
+ return IK_VALUE;
+ }
- // Model input? Becomes less likely the
- // more model inputs we already have as
- // inputs to this operation, and the further
- // along we are in generating this model
- // (i.e., the more operations we have
- // generated).
- if ((model.operationCount() == 0) ||
- (randFrac() < 0.5 *
- (1 - double(normalOperationInputModelInputCount) /
- normalOperationInputCount) *
- std::min(0.3, (1 - double(model.operationCount()) /
- numOperations)))) {
- normalOperationInputModelInputCount++;
- return IK_MODEL_INPUT;
- }
+ // Model input? Becomes less likely the
+ // more model inputs we already have as
+ // inputs to this operation, and the further
+ // along we are in generating this model
+ // (i.e., the more operations we have
+ // generated).
+ if ((model.operationCount() == 0) ||
+ (randFrac() < 0.5 *
+ (1 - double(normalOperationInputModelInputCount) /
+ normalOperationInputCount) *
+ std::min(0.3, (1 - double(model.operationCount()) /
+ numOperations)))) {
+ normalOperationInputModelInputCount++;
+ return IK_MODEL_INPUT;
+ }
- // Else output of an existing operation.
- return IK_OPERATION_OUTPUT;
- });
+ // Else output of an existing operation.
+ return IK_OPERATION_OUTPUT;
+ });
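For intuition, the constant-input branch above is a simple probability ramp: the more constants this operation already has, the less likely the next input is a constant. A self-contained sketch of the numbers, assuming four inputs per operation:

    #include <cstdio>
    #include <initializer_list>
    int main() {
        const double inputCount = 4;
        for (double constants : {0.0, 1.0, 2.0, 3.0}) {
            // mirrors: randFrac() < 0.3 * (1 - constants / inputCount)
            std::printf("constants so far = %.0f -> P(IK_VALUE) = %.3f\n", constants,
                        0.3 * (1 - constants / inputCount));  // 0.300, 0.225, 0.150, 0.075
        }
    }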
// Now force common root or model input, if necessary. (A
// model must have at least one input.)
- auto force =
- [this, &normalOperationInputKinds, normalOperationInputCount](InputKind forceKind){
- if (std::none_of(normalOperationInputKinds.begin(),
- normalOperationInputKinds.end(),
- [forceKind](InputKind kind){ return kind == forceKind; })) {
+ auto force = [this, &normalOperationInputKinds,
+ normalOperationInputCount](InputKind forceKind) {
+ if (std::none_of(normalOperationInputKinds.begin(), normalOperationInputKinds.end(),
+ [forceKind](InputKind kind) { return kind == forceKind; })) {
normalOperationInputKinds[randUInt(normalOperationInputCount)] = forceKind;
}
};
@@ -889,7 +864,7 @@ TEST_P(RandomPartitioningTest, Test) {
const auto& existingOperationOutputs =
model.getOperationOutputs(existingOperationIndex);
operandIndex =
- existingOperationOutputs[randUInt(existingOperationOutputs.size())];
+ existingOperationOutputs[randUInt(existingOperationOutputs.size())];
deadOperandI = deadOperands.find(operandIndex);
CHECK(deadOperandI == deadOperands.end() ||
deadOperandI->second == existingOperationIndex);
@@ -913,7 +888,8 @@ TEST_P(RandomPartitioningTest, Test) {
operandIndex = model.addOperand(&problemType);
if (randFrac() < 0.5) {
std::vector<float> value(problemSize * problemSize);
- std::generate(value.begin(), value.end(), [this]{ return randFrac(); });
+ std::generate(value.begin(), value.end(),
+ [this] { return randFrac(); });
model.setOperandValue(operandIndex, value);
valueOperands.push_back(std::make_pair(operandIndex, ~0U));
} else {
@@ -945,7 +921,7 @@ TEST_P(RandomPartitioningTest, Test) {
std::vector<uint32_t> operationOutputs(operationPattern.mNumOutputs);
std::generate(operationOutputs.begin(), operationOutputs.end(),
[&model, &problemType, &unknownDimensionsType, &hasUnknownDimensions,
- allowUnknownDimensions, this]{
+ allowUnknownDimensions, this] {
// 3% unknowns cause ~35% of partitionings to fail
// (determined by commenting out the fallback code,
// running tests and noting number of failures).
@@ -959,9 +935,8 @@ TEST_P(RandomPartitioningTest, Test) {
// OPERATION ///////////////////////////////////////////////////////////////////////////////
- const uint32_t operationIndex =
- model.addOperation(operationPattern.mOperationType,
- operationInputs, operationOutputs);
+ const uint32_t operationIndex = model.addOperation(operationPattern.mOperationType,
+ operationInputs, operationOutputs);
deadOperations.insert(operationIndex);
std::for_each(operationOutputs.begin(), operationOutputs.end(),
[&deadOperands, operationIndex](uint32_t operandIndex) {
@@ -984,7 +959,7 @@ TEST_P(RandomPartitioningTest, Test) {
float* region =
static_cast<float*>(weights.getRegion(regionIndex, &memory, &offset, &length));
CHECK(length == problemSize * problemSize * sizeof(float));
- std::generate(region, region + problemSize * problemSize, [this]{ return randFrac(); });
+ std::generate(region, region + problemSize * problemSize, [this] { return randFrac(); });
model.setOperandValueFromMemory(operandIndex, memory, offset, length);
}
@@ -1005,7 +980,7 @@ TEST_P(RandomPartitioningTest, Test) {
// more likely we are to classify this operation
// output as a model output.
const double probabilityOfModelOutput =
- 0.50 * [](double x){ return x*x; }((operationIdx + 1) / operationCount);
+ 0.50 * [](double x) { return x * x; }((operationIdx + 1) / operationCount);
modelOutput = (randFrac() < probabilityOfModelOutput);
} else {
// This is consumed within the model, so we'll rarely
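The immediately-invoked lambda above gives a quadratic ramp: the later an operation is created, the more likely its output is exposed as a model output, topping out at 0.5 for the last one. A tiny sketch of the resulting values:

    #include <cstdio>
    #include <initializer_list>
    int main() {
        const double operationCount = 10;
        for (double created : {1.0, 5.0, 10.0}) {  // plays the role of operationIdx + 1
            const double x = created / operationCount;
            std::printf("operation %2.0f of 10 -> P(model output) = %.3f\n", created,
                        0.50 * x * x);  // 0.005, 0.125, 0.500
        }
    }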
@@ -1044,8 +1019,7 @@ TEST_P(RandomPartitioningTest, Test) {
#ifdef VERBOSE
{
std::cout << "Original model: " << ModelStats(&model) << std::endl;
- std::cout << "rootOperationCount = " << rootOperationCount
- << ", deadOperations = ";
+ std::cout << "rootOperationCount = " << rootOperationCount << ", deadOperations = ";
if (allowDeadOperations) {
std::cout << deadOperations.size();
} else {
@@ -1072,8 +1046,8 @@ TEST_P(RandomPartitioningTest, Test) {
}
// Now remove each entry that has no signatures.
auto firstExtra =
- std::remove_if(signaturesForDriver.begin(), signaturesForDriver.end(),
- [](const std::set<Signature>& sigSet) { return sigSet.empty(); });
+ std::remove_if(signaturesForDriver.begin(), signaturesForDriver.end(),
+ [](const std::set<Signature>& sigSet) { return sigSet.empty(); });
if (firstExtra != signaturesForDriver.end()) {
signaturesForDriver.erase(firstExtra, signaturesForDriver.end());
}
@@ -1114,7 +1088,7 @@ TEST_P(RandomPartitioningTest, Test) {
// the fallback to succeed.
TestCompilation cNoFallback(&model, devices);
TestCompilation cWithFallback(&model, devices);
- TestCompilation *c2 = nullptr;
+ TestCompilation* c2 = nullptr;
ASSERT_EQ(cNoFallback.setPartitioning(DeviceManager::kPartitioningWithoutFallback),
Result::NO_ERROR);
auto compilationResult = cNoFallback.finish();
@@ -1134,8 +1108,8 @@ TEST_P(RandomPartitioningTest, Test) {
#ifdef VERBOSE
{
- std::cout << "signatures = " << signatures.size()
- << ", devices = " << devices.size() << std::endl;
+ std::cout << "signatures = " << signatures.size() << ", devices = " << devices.size()
+ << std::endl;
const ExecutionPlan& plan = c2->getExecutionPlan();
switch (plan.forTest_getKind()) {
case ExecutionPlan::Kind::SIMPLE:
@@ -1157,7 +1131,7 @@ TEST_P(RandomPartitioningTest, Test) {
}
default:
std::cout << "Unexpected plan kind: "
- << static_cast<unsigned>(plan.forTest_getKind());
+ << static_cast<unsigned>(plan.forTest_getKind());
break;
}
}
@@ -1187,7 +1161,7 @@ TEST_P(RandomPartitioningTest, Test) {
// should not be dependent on the outputs; but we'll initialize the
// outputs anyway.
std::vector<float> masterInputs(problemSize * problemSize * model.inputCount());
- std::generate(masterInputs.begin(), masterInputs.end(), [this]{ return randFrac(); });
+ std::generate(masterInputs.begin(), masterInputs.end(), [this] { return randFrac(); });
#ifdef VERBOSE
{
std::cout << "flat inputs = ";
@@ -1213,9 +1187,8 @@ TEST_P(RandomPartitioningTest, Test) {
};
std::vector<InputOutputDescriptor> ioDescriptors(model.inputCount() + model.outputCount());
for (unsigned i = 0; i < ioDescriptors.size(); i++) {
- ioDescriptors[i].mKind = (i < model.inputCount()
- ? InputOutputDescriptor::INPUT
- : InputOutputDescriptor::OUTPUT);
+ ioDescriptors[i].mKind = (i < model.inputCount() ? InputOutputDescriptor::INPUT
+ : InputOutputDescriptor::OUTPUT);
}
// We randomly interleave inputs and outputs in creation
// order, because when we create memory regions in a
@@ -1226,7 +1199,7 @@ TEST_P(RandomPartitioningTest, Test) {
// they'll be interleaved.
std::shuffle(ioDescriptors.begin(), ioDescriptors.end(), mRandNumEng);
TestMemories ioMemories;
- for (auto &desc : ioDescriptors) {
+ for (auto& desc : ioDescriptors) {
if (randFrac() < 0.5) {
desc.mVector.resize(problemSize * problemSize);
} else {
@@ -1245,11 +1218,10 @@ TEST_P(RandomPartitioningTest, Test) {
// Function to set up actual inputs and outputs (initializing them
// and telling the WrapperExecution about them).
- auto prepareForExecution =
- [&model, &ioDescriptors, &ioMemories,
- &masterInputs, &masterOutput, problemSize, &problemType](WrapperExecution *e) {
+ auto prepareForExecution = [&model, &ioDescriptors, &ioMemories, &masterInputs, &masterOutput,
+ problemSize, &problemType](WrapperExecution* e) {
uint32_t inputIndex = 0, outputIndex = 0;
- for (auto &desc : ioDescriptors) {
+ for (auto& desc : ioDescriptors) {
if (desc.getLocation() == InputOutputDescriptor::VECTOR) {
if (desc.mKind == InputOutputDescriptor::INPUT) {
const size_t inputOffset = inputIndex * problemSize * problemSize;
@@ -1260,18 +1232,15 @@ TEST_P(RandomPartitioningTest, Test) {
desc.mVector.size() * sizeof(float));
} else {
std::fill(desc.mVector.begin(),
- desc.mVector.begin() + problemSize * problemSize,
- masterOutput);
+ desc.mVector.begin() + problemSize * problemSize, masterOutput);
e->setOutput(outputIndex++, desc.mVector.data(),
- desc.mVector.size() * sizeof(float),
- &problemType.operandType);
+ desc.mVector.size() * sizeof(float), &problemType.operandType);
}
} else {
const WrapperMemory* memory;
uint32_t offset, length;
- float* region =
- static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion,
- &memory, &offset, &length));
+ float* region = static_cast<float*>(
+ ioMemories.getRegion(desc.mMemoryRegion, &memory, &offset, &length));
CHECK(length == problemSize * problemSize * sizeof(float));
if (desc.mKind == InputOutputDescriptor::INPUT) {
const size_t inputOffset = inputIndex * problemSize * problemSize;
@@ -1280,9 +1249,7 @@ TEST_P(RandomPartitioningTest, Test) {
region);
e->setInputFromMemory(inputIndex++, memory, offset, length);
} else {
- std::fill(region,
- region + problemSize * problemSize,
- masterOutput);
+ std::fill(region, region + problemSize * problemSize, masterOutput);
e->setOutputFromMemory(outputIndex++, memory, offset, length,
&problemType.operandType);
}
@@ -1307,13 +1274,11 @@ TEST_P(RandomPartitioningTest, Test) {
}
const size_t outputOffset = outputIndex * problemSize * problemSize;
if (desc.getLocation() == InputOutputDescriptor::VECTOR) {
- std::copy(desc.mVector.begin(),
- desc.mVector.end(),
+ std::copy(desc.mVector.begin(), desc.mVector.end(),
nonPartitionedOutputs.begin() + outputOffset);
} else {
float* region = static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion));
- std::copy(region,
- region + problemSize * problemSize,
+ std::copy(region, region + problemSize * problemSize,
nonPartitionedOutputs.begin() + outputOffset);
}
#ifdef VERBOSE
@@ -1347,8 +1312,7 @@ TEST_P(RandomPartitioningTest, Test) {
std::cout << " partitioned output[" << outputIndex << "] = ";
dump(desc.mVector.begin(), desc.mVector.end());
#endif
- ASSERT_TRUE(std::equal(desc.mVector.begin(),
- desc.mVector.end(),
+ ASSERT_TRUE(std::equal(desc.mVector.begin(), desc.mVector.end(),
nonPartitionedOutputs.begin() + outputOffset));
} else {
float* region = static_cast<float*>(ioMemories.getRegion(desc.mMemoryRegion));
@@ -1356,8 +1320,7 @@ TEST_P(RandomPartitioningTest, Test) {
std::cout << "part output[" << outputIndex << "] = ";
dump(region, region + problemSize * problemSize);
#endif
- ASSERT_TRUE(std::equal(region,
- region + problemSize * problemSize,
+ ASSERT_TRUE(std::equal(region, region + problemSize * problemSize,
nonPartitionedOutputs.begin() + outputOffset));
}
outputIndex++;
diff --git a/nn/runtime/test/TestTrivialModel.cpp b/nn/runtime/test/TestTrivialModel.cpp
index c5c065439..7280e6ae0 100644
--- a/nn/runtime/test/TestTrivialModel.cpp
+++ b/nn/runtime/test/TestTrivialModel.cpp
@@ -27,7 +27,7 @@ typedef float Matrix3x4[3][4];
typedef float Matrix4[4];
class TrivialTest : public ::testing::Test {
-protected:
+ protected:
virtual void SetUp() {}
const Matrix3x4 matrix1 = {{1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f}, {9.f, 10.f, 11.f, 12.f}};
@@ -35,9 +35,8 @@ protected:
{500.f, 600.f, 700.f, 800.f},
{900.f, 1000.f, 1100.f, 1200.f}};
const Matrix4 matrix2b = {100.f, 200.f, 300.f, 400.f};
- const Matrix3x4 matrix3 = {{20.f, 30.f, 40.f, 50.f},
- {21.f, 22.f, 23.f, 24.f},
- {31.f, 32.f, 33.f, 34.f}};
+ const Matrix3x4 matrix3 = {
+ {20.f, 30.f, 40.f, 50.f}, {21.f, 22.f, 23.f, 24.f}, {31.f, 32.f, 33.f, 34.f}};
const Matrix3x4 expected2 = {{101.f, 202.f, 303.f, 404.f},
{505.f, 606.f, 707.f, 808.f},
{909.f, 1010.f, 1111.f, 1212.f}};
@@ -51,9 +50,8 @@ protected:
const Matrix3x4 expected3 = {{121.f, 232.f, 343.f, 454.f},
{526.f, 628.f, 730.f, 832.f},
{940.f, 1042.f, 1144.f, 1246.f}};
- const Matrix3x4 expected3b = {{22.f, 34.f, 46.f, 58.f},
- {31.f, 34.f, 37.f, 40.f},
- {49.f, 52.f, 55.f, 58.f}};
+ const Matrix3x4 expected3b = {
+ {22.f, 34.f, 46.f, 58.f}, {31.f, 34.f, 37.f, 40.f}, {49.f, 52.f, 55.f, 58.f}};
};
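The expected matrices above follow directly from the input constants: elementwise, expected2 = matrix1 + matrix2, expected3 = matrix1 + matrix2 + matrix3, and expected3b = matrix1 + matrix1 + matrix3. A minimal standalone check of the first identity, using the values from this hunk:

    #include <cassert>
    int main() {
        const float matrix1[3][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
        const float matrix2[3][4] = {
                {100, 200, 300, 400}, {500, 600, 700, 800}, {900, 1000, 1100, 1200}};
        const float expected2[3][4] = {
                {101, 202, 303, 404}, {505, 606, 707, 808}, {909, 1010, 1111, 1212}};
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 4; j++) assert(matrix1[i][j] + matrix2[i][j] == expected2[i][j]);
    }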
// Create a model that can add two tensors using a one-node graph.
diff --git a/nn/runtime/test/TestUnknownDimensions.cpp b/nn/runtime/test/TestUnknownDimensions.cpp
index ada52c2b6..75465659a 100644
--- a/nn/runtime/test/TestUnknownDimensions.cpp
+++ b/nn/runtime/test/TestUnknownDimensions.cpp
@@ -28,8 +28,8 @@ using namespace test_helper;
namespace {
const uint32_t INTENDED_SIZE = 3;
-const uint32_t OTHER_SIZE = 2;
-const uint32_t UNKNOWN_SIZE = 0;
+const uint32_t OTHER_SIZE = 2;
+const uint32_t UNKNOWN_SIZE = 0;
// We test three basic scenarios for each tensor dimension:
// INTENDED_AT_COMPILE_AND_EXECUTE: set the dimension at compile
@@ -58,19 +58,21 @@ const uint32_t UNKNOWN_SIZE = 0;
// infrastructure to handle correctly. However, running all 16k in one test
// makes the ASAN version take so long that the automatic test runner thinks the
// command has become unresponsive, so we split on the first level.
-enum class DimensionKind { INTENDED_AT_COMPILE_AND_EXECUTE,
- INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE,
- UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE,
- UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE };
+enum class DimensionKind {
+ INTENDED_AT_COMPILE_AND_EXECUTE,
+ INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE,
+ UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE,
+ UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE
+};
typedef std::tuple<DimensionKind, DimensionKind> OperandParams;
std::vector<DimensionKind> ioDimensionValues = {
- DimensionKind::INTENDED_AT_COMPILE_AND_EXECUTE,
- DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE,
- DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE,
- DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE };
+ DimensionKind::INTENDED_AT_COMPILE_AND_EXECUTE,
+ DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE,
+ DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE,
+ DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE};
std::vector<DimensionKind> constantDimensionValues = {
DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE,
- DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE };
+ DimensionKind::UNKNOWN_AT_COMPILE_INTENDED_AT_EXECUTE};
std::vector<OperandParams> Combine(const std::vector<DimensionKind>& firsts,
const std::vector<DimensionKind>& seconds);
auto ioValues = Combine(ioDimensionValues, ioDimensionValues);
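For the record, the size of the sweep follows from these tables: Combine() over the four I/O kinds yields 4 × 4 = 16 OperandParams per input or output, and, assuming constantValues is built analogously from the two constant kinds, 2 × 2 = 4 for the constant operand. A full sweep is therefore 16 × 16 × 4 × 16 = 16,384 combinations, the "16k" mentioned in the comment above; parameterizing the first input via GetParam() splits that into 16 test instances of 1,024 combinations each.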
@@ -143,18 +145,20 @@ void UnknownDimensionsTest::CompareResults<_Float16>(std::map<int, std::vector<_
EXPECT_EQ(size_t{0}, totalNumberOfErrors);
}
-template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
- const OperandParams& paramsForInput0, const OperandParams& paramsForInput1,
- const OperandParams& paramsForConst, const OperandParams& paramsForOutput) {
+template <class T, Type TensorType>
+void UnknownDimensionsTest::TestOne(const OperandParams& paramsForInput0,
+ const OperandParams& paramsForInput1,
+ const OperandParams& paramsForConst,
+ const OperandParams& paramsForOutput) {
typedef T IntendedMatrix[INTENDED_SIZE][INTENDED_SIZE];
- static const IntendedMatrix ones = { { 1, 1, 1 }, { 1, 1, 1 }, { 1, 1, 1 } };
- static const IntendedMatrix twos = { { 2, 2, 2 }, { 2, 2, 2 }, { 2, 2, 2 } };
- static const IntendedMatrix fives = { { 5, 5, 5 }, { 5, 5, 5 }, { 5, 5, 5 } };
+ static const IntendedMatrix ones = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
+ static const IntendedMatrix twos = {{2, 2, 2}, {2, 2, 2}, {2, 2, 2}};
+ static const IntendedMatrix fives = {{5, 5, 5}, {5, 5, 5}, {5, 5, 5}};
const float scale = TensorType == Type::TENSOR_QUANT8_ASYMM ? 1.f : 0.f;
Model model;
- std::string input0Scope("Input 0:"), input1Scope("Input 1:"),
- constantScope("Constant:"), outputScope("Output:");
+ std::string input0Scope("Input 0:"), input1Scope("Input 1:"), constantScope("Constant:"),
+ outputScope("Output:");
auto getDimForCompile = [](DimensionKind kind, std::string* scope) {
switch (kind) {
@@ -176,8 +180,8 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
std::string* scope = nullptr) {
OperandType matrixTypeWithPotentiallyUnknownDims(
TensorType,
- { getDimForCompile(std::get<0>(params), scope),
- getDimForCompile(std::get<1>(params), scope) },
+ {getDimForCompile(std::get<0>(params), scope),
+ getDimForCompile(std::get<1>(params), scope)},
scale);
return model.addOperand(&matrixTypeWithPotentiallyUnknownDims);
};
@@ -202,11 +206,9 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
model.setOperandValue(activationOpd0, &activation, sizeof(activation));
model.setOperandValue(constantOpd0, twos, sizeof(twos));
- model.addOperation(ANEURALNETWORKS_ADD,
- {inputOpd0, inputOpd1, activationOpd0},
+ model.addOperation(ANEURALNETWORKS_ADD, {inputOpd0, inputOpd1, activationOpd0},
{intermediateOpd0});
- model.addOperation(ANEURALNETWORKS_ADD,
- {intermediateOpd0, constantOpd0, activationOpd0},
+ model.addOperation(ANEURALNETWORKS_ADD, {intermediateOpd0, constantOpd0, activationOpd0},
{outputOpd0});
model.identifyInputsAndOutputs({inputOpd0, inputOpd1}, {outputOpd0});
if (std::get<0>(paramsForConst) == DimensionKind::INTENDED_AT_COMPILE_NOT_SET_AT_EXECUTE &&
@@ -224,7 +226,7 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
Compilation compilation(&model);
ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
- IntendedMatrix actual = { { 10, 10, 10 }, { 10, 10, 10 }, { 10, 10, 10 } };
+ IntendedMatrix actual = {{10, 10, 10}, {10, 10, 10}, {10, 10, 10}};
Execution execution(&compilation);
OperandType matrixTypeIntended(TensorType, {INTENDED_SIZE, INTENDED_SIZE}, scale);
@@ -261,19 +263,21 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
// on OperandParams
auto sizeAtSet = [](OperandParams params) {
auto first = std::get<0>(params), second = std::get<1>(params);
- size_t firstDim = (first == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) ?
- OTHER_SIZE : INTENDED_SIZE;
- size_t secondDim = (second == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE) ?
- OTHER_SIZE : INTENDED_SIZE;
+ size_t firstDim = (first == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE)
+ ? OTHER_SIZE
+ : INTENDED_SIZE;
+ size_t secondDim = (second == DimensionKind::UNKNOWN_AT_COMPILE_OTHER_AT_EXECUTE)
+ ? OTHER_SIZE
+ : INTENDED_SIZE;
return firstDim * secondDim * sizeof(fives[0][0]);
};
ASSERT_EQ(execution.setInput(0, ones, sizeAtSet(paramsForInput0), typeAtSet(paramsForInput0)),
Result::NO_ERROR);
ASSERT_EQ(execution.setInput(1, twos, sizeAtSet(paramsForInput1), typeAtSet(paramsForInput1)),
Result::NO_ERROR);
- ASSERT_EQ(execution.setOutput(0, actual, sizeAtSet(paramsForOutput),
- typeAtSet(paramsForOutput)),
- Result::NO_ERROR);
+ ASSERT_EQ(
+ execution.setOutput(0, actual, sizeAtSet(paramsForOutput), typeAtSet(paramsForOutput)),
+ Result::NO_ERROR);
if (allAreIntendedSizeAtExecution) {
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
@@ -295,21 +299,22 @@ template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(
std::vector<OperandParams> Combine(const std::vector<DimensionKind>& firsts,
const std::vector<DimensionKind>& seconds) {
std::vector<OperandParams> ret;
- for (auto first: firsts) {
- for (auto second: seconds) {
+ for (auto first : firsts) {
+ for (auto second : seconds) {
ret.push_back({first, second});
}
}
return ret;
}
-template<class T, Type TensorType> void UnknownDimensionsTest::TestAll() {
+template <class T, Type TensorType>
+void UnknownDimensionsTest::TestAll() {
const OperandParams paramsForInput0 = GetParam();
- for (auto paramsForInput1: ioValues) {
- for (auto paramsForConst: constantValues) {
- for (auto paramsForOutput: ioValues) {
- TestOne<T, TensorType>(paramsForInput0, paramsForInput1,
- paramsForConst, paramsForOutput);
+ for (auto paramsForInput1 : ioValues) {
+ for (auto paramsForConst : constantValues) {
+ for (auto paramsForOutput : ioValues) {
+ TestOne<T, TensorType>(paramsForInput0, paramsForInput1, paramsForConst,
+ paramsForOutput);
}
}
}
diff --git a/nn/runtime/test/TestWrapper.cpp b/nn/runtime/test/TestWrapper.cpp
index d80d46950..1ab8f9540 100644
--- a/nn/runtime/test/TestWrapper.cpp
+++ b/nn/runtime/test/TestWrapper.cpp
@@ -23,7 +23,7 @@ using namespace ::android::nn::wrapper;
// This file tests certain aspects of the interfaces from NeuralNetworksWrapper.h.
class WrapperTestModelFinish : public ::testing::Test {
-protected:
+ protected:
void SetUp() override {
OperandType type(Type::TENSOR_FLOAT32, {1});
mIn = mModel.addOperand(&type);
diff --git a/nn/tools/ion_watcher/ion_watcher.cpp b/nn/tools/ion_watcher/ion_watcher.cpp
index 0061086dc..1a79b38b3 100644
--- a/nn/tools/ion_watcher/ion_watcher.cpp
+++ b/nn/tools/ion_watcher/ion_watcher.cpp
@@ -16,12 +16,12 @@
#define LOG_TAG "IonWatcher"
+#include <stdio.h>
+#include <unistd.h>
#include <fstream>
#include <iostream>
#include <sstream>
-#include <stdio.h>
#include <string>
-#include <unistd.h>
#include <android/log.h>
#define ATRACE_TAG ATRACE_TAG_NNAPI
@@ -54,16 +54,16 @@ int main(void) {
}
int size = 0;
while (true) {
- const int newSize = parseMemInfo("ION_heap");
- if (newSize < 0) {
- return newSize;
- }
- if (newSize != size) {
- size = newSize;
- std::cout << size << "\n";
- ATRACE_INT("ION_heap", size);
- __android_log_print(ANDROID_LOG_INFO, "ion", "ION_heap %d", size);
- }
- usleep(10);
+ const int newSize = parseMemInfo("ION_heap");
+ if (newSize < 0) {
+ return newSize;
+ }
+ if (newSize != size) {
+ size = newSize;
+ std::cout << size << "\n";
+ ATRACE_INT("ION_heap", size);
+ __android_log_print(ANDROID_LOG_INFO, "ion", "ION_heap %d", size);
+ }
+ usleep(10);
}
}
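parseMemInfo() itself is outside this hunk; the following is a hypothetical sketch of what such a helper could look like, scanning /proc/meminfo for the given key (the name, signature, and behavior are assumptions, not the actual implementation):

    #include <fstream>
    #include <sstream>
    #include <string>
    // Hypothetical stand-in for the parseMemInfo() used above.
    int parseMemInfoSketch(const std::string& key) {
        std::ifstream meminfo("/proc/meminfo");
        std::string line;
        while (std::getline(meminfo, line)) {
            if (line.rfind(key, 0) == 0) {  // line starts with the key
                std::istringstream fields(line.substr(line.find(':') + 1));
                int kb = -1;
                fields >> kb;  // /proc/meminfo reports sizes in kB
                return kb;     // non-negative on success
            }
        }
        return -1;  // negative signals failure, matching the newSize < 0 check above
    }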
diff --git a/nn/tools/test_generator/include/TestHarness.h b/nn/tools/test_generator/include/TestHarness.h
index e43658d03..3b4b26b16 100644
--- a/nn/tools/test_generator/include/TestHarness.h
+++ b/nn/tools/test_generator/include/TestHarness.h
@@ -224,8 +224,7 @@ void filter_internal(const std::map<int, std::vector<T>>& golden,
});
}
-inline MixedTyped filter(const MixedTyped& golden,
- std::function<bool(int)> is_ignored) {
+inline MixedTyped filter(const MixedTyped& golden, std::function<bool(int)> is_ignored) {
MixedTyped filtered;
filter_internal(golden.operandDimensions, &filtered.operandDimensions, is_ignored);
filter_internal(golden.float32Operands, &filtered.float32Operands, is_ignored);
@@ -275,8 +274,8 @@ inline int getQuant8AllowedError() {
}
}
-inline void compare(const MixedTyped& golden, const MixedTyped& test,
- float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
+inline void compare(const MixedTyped& golden, const MixedTyped& test, float fpAtol = 1e-5f,
+ float fpRtol = 1e-5f) {
int quant8AllowedError = getQuant8AllowedError();
for_each<uint32_t>(
golden.operandDimensions, test.operandDimensions,
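The fpAtol/fpRtol defaults above implement the usual mixed absolute/relative float comparison. As a sketch of the general acceptance test (an illustration of the pattern, not necessarily the exact expression used in TestHarness.h):

    #include <cmath>
    // Accept 'test' when it lies within the absolute-plus-relative envelope of 'golden'.
    bool closeEnough(float golden, float test, float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
        return std::fabs(golden - test) <= fpAtol + fpRtol * std::fabs(golden);
    }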