author    Viet Dang <vddang@google.com>  2019-12-04 16:18:08 +0000
committer Viet Dang <vddang@google.com>  2020-01-09 15:04:38 +0000
commit    f388cca00f9f6056f911a954843e9cd8adabef56 (patch)
tree      806b1eab0320727f94ed53907bd48a40c033e5b6
parent    2ff6fe9a732a56873eec9f2ac3fd45a43c145a4d (diff)
download  ml-f388cca00f9f6056f911a954843e9cd8adabef56.tar.gz
Implements Quantized LSTM op for R.

Also adds support for TENSOR_QUANT8_ASYMM_SIGNED in Test Generator.

Bug: 144841609
Bug: 145916330
Test: NeuralNetworksTest_static
Change-Id: I14b0d284b1945833d532cbaa33c66e4d77afd8b7
-rw-r--r--  nn/common/Android.bp                                           2
-rw-r--r--  nn/common/OperationResolver.cpp                                2
-rw-r--r--  nn/common/OperationsUtils.cpp                                 14
-rw-r--r--  nn/common/QuantUtils.cpp                                     193
-rw-r--r--  nn/common/QuantUtils.h                                       201
-rw-r--r--  nn/common/include/Utils.h                                      2
-rw-r--r--  nn/common/operations/QLSTM.cpp                               785
-rw-r--r--  nn/runtime/NeuralNetworks.cpp                                  1
-rw-r--r--  nn/runtime/include/NeuralNetworks.h                          131
-rw-r--r--  nn/runtime/test/TestValidateOperations.cpp                   120
-rw-r--r--  nn/runtime/test/generated/spec_V1_3/qlstm.example.cpp        380
-rw-r--r--  nn/runtime/test/specs/V1_3/qlstm.mod.py                      178
-rw-r--r--  nn/tools/api/NeuralNetworks.t                                  4
-rw-r--r--  nn/tools/api/types.spec                                      143
-rw-r--r--  nn/tools/test_generator/test_harness/include/TestHarness.h    1
15 files changed, 2155 insertions(+), 2 deletions(-)
diff --git a/nn/common/Android.bp b/nn/common/Android.bp
index 43c8fd83e..033292dbc 100644
--- a/nn/common/Android.bp
+++ b/nn/common/Android.bp
@@ -46,6 +46,7 @@ cc_defaults {
"operations/Neg.cpp",
"operations/PRelu.cpp",
"operations/Pooling.cpp",
+ "operations/QLSTM.cpp",
"operations/Quantize.cpp",
"operations/Reduce.cpp",
"operations/ResizeImageOps.cpp",
@@ -136,6 +137,7 @@ cc_library_static {
"MemoryUtils.cpp",
"MetaModel.cpp",
"OperationsUtils.cpp",
+ "QuantUtils.cpp",
"TokenHasher.cpp",
"Utils.cpp",
"ValidateHal.cpp",
diff --git a/nn/common/OperationResolver.cpp b/nn/common/OperationResolver.cpp
index ef9517221..c9a74fa58 100644
--- a/nn/common/OperationResolver.cpp
+++ b/nn/common/OperationResolver.cpp
@@ -64,6 +64,7 @@ const OperationRegistration* register_NEG();
const OperationRegistration* register_NOT_EQUAL();
const OperationRegistration* register_PRELU();
const OperationRegistration* register_QUANTIZE();
+const OperationRegistration* register_QUANTIZED_LSTM();
const OperationRegistration* register_REDUCE_ALL();
const OperationRegistration* register_REDUCE_ANY();
const OperationRegistration* register_REDUCE_MAX();
@@ -131,6 +132,7 @@ BuiltinOperationResolver::BuiltinOperationResolver() {
registerOperation(register_NOT_EQUAL());
registerOperation(register_PRELU());
registerOperation(register_QUANTIZE());
+ registerOperation(register_QUANTIZED_LSTM());
registerOperation(register_REDUCE_ALL());
registerOperation(register_REDUCE_ANY());
registerOperation(register_REDUCE_MAX());
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp
index 2cac8b344..418045e54 100644
--- a/nn/common/OperationsUtils.cpp
+++ b/nn/common/OperationsUtils.cpp
@@ -211,6 +211,20 @@ bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
++*shift;
}
NN_RET_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ // A shift amount smaller than -31 would cause all bits to be shifted out
+ // and thus all results would be zero. We implement that instead with
+ // q_fixed==0, so as to avoid hitting issues with right-shift
+ // operations with shift amounts greater than 31. Note that this happens
+ // roughly when abs(double_multiplier) < 2^-31, so the present handling
+ // effectively flushes tiny double_multiplier values to zero.
+ // We could conceivably handle right-shift amounts in the range (roughly)
+ // [32, 63] as 'denormals', i.e. (shift==0, q_fixed < 2^30). From that point
+ // of view the present handling is just 'flush denormals to zero'. We could
+ // reconsider and actually generate nonzero denormals if a need arises.
+ if (*shift < -31) {
+ *shift = 0;
+ q_fixed = 0;
+ }
*quantized_multiplier = static_cast<int32_t>(q_fixed);
return true;
}
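
For reference, the decomposition that QuantizeMultiplier performs can be reproduced in a few lines: a real multiplier m is split into a Q31 fixed-point value q and a power-of-two exponent so that m ~= q * 2^(shift - 31), and multipliers smaller than roughly 2^-31 are flushed to zero exactly as the new check above does. The standalone sketch below only illustrates that scheme and is not the NNAPI implementation; DecomposeMultiplier is a made-up name and only positive multipliers are exercised.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

static bool DecomposeMultiplier(double m, int32_t* q_out, int* shift_out) {
    if (m == 0.0) {
        *q_out = 0;
        *shift_out = 0;
        return true;
    }
    const double q = std::frexp(m, shift_out);  // m == q * 2^shift, q in [0.5, 1)
    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
    if (q_fixed == (1LL << 31)) {  // rounding pushed the mantissa up to 1.0
        q_fixed /= 2;
        ++*shift_out;
    }
    if (q_fixed > std::numeric_limits<int32_t>::max()) return false;
    // Mirror of the check added above: a right shift larger than 31 would drop
    // every bit, so such tiny multipliers are flushed to zero instead.
    if (*shift_out < -31) {
        *shift_out = 0;
        q_fixed = 0;
    }
    *q_out = static_cast<int32_t>(q_fixed);
    return true;
}

int main() {
    int32_t q = 0;
    int shift = 0;
    DecomposeMultiplier(0.00072, &q, &shift);  // a typical effective scale
    std::printf("0.00072 ~= %d * 2^%d\n", static_cast<int>(q), shift - 31);
    DecomposeMultiplier(1e-12, &q, &shift);    // abs(m) < 2^-31: flushed to zero
    std::printf("1e-12   -> q=%d, shift=%d\n", static_cast<int>(q), shift);
    return 0;
}
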
diff --git a/nn/common/QuantUtils.cpp b/nn/common/QuantUtils.cpp
new file mode 100644
index 000000000..97b76b74a
--- /dev/null
+++ b/nn/common/QuantUtils.cpp
@@ -0,0 +1,193 @@
+#include "QuantUtils.h"
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+
+namespace android {
+namespace nn {
+
+void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights, const int32_t* bias,
+ int32_t layer_norm_scale_a, int32_t layer_norm_scale_b, int32_t variance_limit,
+ int n_batch, int n_input, int16_t* output) {
+ static const int kOverflowGuard = 1 << 20;
+ for (int i = 0; i < n_batch; ++i) {
+ int64_t sum = 0;
+ int64_t sum_sq = 0;
+ for (int j = 0; j < n_input; ++j) {
+ const int32_t index = i * n_input + j;
+ int32_t val = static_cast<int32_t>(input[index]);
+ sum += val;
+ sum_sq += val * val;
+ }
+ int32_t mean = static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
+ // TODO(jianlijianli): This avoids overflow but only works for power-of-two (POT) n_input.
+ int32_t temp = kOverflowGuard / n_input;
+ int64_t variance = sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
+ int32_t variance2 = static_cast<int32_t>(variance / kOverflowGuard);
+ if (variance2 < 1) {
+ variance2 = variance_limit;
+ }
+ int32_t stddev_inverse_a;
+ int stddev_inverse_b;
+ GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1, &stddev_inverse_a,
+ &stddev_inverse_b);
+
+ for (int j = 0; j < n_input; ++j) {
+ const int32_t index = i * n_input + j;
+ int32_t val = static_cast<int32_t>(input[index]);
+ int32_t shifted = 1024 * val - mean;
+ int32_t rescaled =
+ MultiplyByQuantizedMultiplier(shifted, stddev_inverse_a, stddev_inverse_b);
+ // TODO(jianlijianli): Saturate this.
+ int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
+ int32_t val4 = static_cast<int32_t>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
+ int32_t val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
+ layer_norm_scale_b + 12);
+ val5 = std::min(std::max(INT16_MIN, val5), INT16_MAX);
+ output[index] = static_cast<int16_t>(val5);
+ }
+ }
+}
+
+void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, int32_t n_row,
+ int32_t n_col, int32_t* output) {
+ for (int i = 0; i < n_row; ++i) {
+ int32_t row_sum = 0;
+ for (int j = 0; j < n_col; ++j) {
+ row_sum += *matrix++;
+ }
+ output[i] += row_sum * scalar;
+ }
+}
+
+bool PrecomputeZeroPointTimesWeightWithBias(int32_t zero_point, const int8_t* weight_tensor,
+ const Shape& weight_shape, const int32_t* bias_tensor,
+ std::unique_ptr<int32_t[]>* output) {
+ if (weight_tensor == nullptr) {
+ return true;
+ }
+
+ NN_RET_CHECK_EQ(weight_shape.dimensions.size(), 2u);
+ const int row = weight_shape.dimensions[0];
+ const int col = weight_shape.dimensions[1];
+ *output = std::make_unique<int32_t[]>(row);
+ if (bias_tensor == nullptr) {
+ memset(output->get(), 0, row * sizeof(int32_t));
+ } else {
+ memcpy(output->get(), bias_tensor, row * sizeof(int32_t));
+ }
+ if (zero_point != 0) {
+ MatrixScalarMultiplyAccumulate(weight_tensor, zero_point, row, col, output->get());
+ }
+ return true;
+}
+
+void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int c = 0; c < n_input; c++) {
+ using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
+ using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+ const int index = batch * n_input + c;
+ F3 sigmoid_input = F3::FromRaw(input[index]);
+ F0 sigmoid_output = gemmlowp::logistic(sigmoid_input);
+ output[index] = sigmoid_output.raw();
+ }
+ }
+}
+
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input, int shift,
+ int16_t* output) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ const int16_t a = input_1[index];
+ const int16_t b = input_2[index];
+ const int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
+ output[index] = static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
+ }
+ }
+}
+
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int32_t multiplier, int32_t shift,
+ int32_t n_batch, int32_t n_input, int32_t output_zp, int8_t* output) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ const int16_t a = input_1[index];
+ const int16_t b = input_2[index];
+ int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
+ value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
+ value -= output_zp;
+ value = std::min(std::max(-128, value), 127);
+
+ output[index] = static_cast<int8_t>(value);
+ }
+ }
+}
+
+bool CheckedLog2(const float x, int* log2_result) {
+ const float x_log2 = std::log(x) * (1.0f / std::log(2.0f));
+ const float x_log2_rounded = std::round(x_log2);
+ const float x_log2_fracpart = x_log2 - x_log2_rounded;
+
+ *log2_result = static_cast<int>(x_log2_rounded);
+ return std::abs(x_log2_fracpart) < 1e-3;
+}
+
+void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input,
+ int16_t* output) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ int32_t sum = input_1[index] + input_2[index];
+ const int32_t sum_clamped = std::min(INT16_MAX, std::max(INT16_MIN, sum));
+ output[index] = static_cast<int16_t>(sum_clamped);
+ }
+ }
+}
+
+void CwiseClipping(int16_t* input, const int16_t clipping_value, int32_t n_batch, int32_t n_input) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ if (input[index] > clipping_value) {
+ input[index] = clipping_value;
+ }
+ if (input[index] < -clipping_value) {
+ input[index] = -clipping_value;
+ }
+ }
+ }
+}
+
+void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch, int32_t n_input) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ if (input[index] > clipping_value) {
+ input[index] = clipping_value;
+ }
+ if (input[index] < -clipping_value) {
+ input[index] = -clipping_value;
+ }
+ }
+ }
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
+ const int16_t* batch_vector, int n_batch,
+ int32_t multiplier, int shift, int16_t* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ int32_t prod = vector[v] * *batch_vector++;
+ prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
+ int32_t output = prod + *result;
+ output = std::max(std::min(32767, output), -32768);
+ *result++ = output;
+ }
+ }
+}
+
+} // namespace nn
+} // namespace android
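
The int16 helpers in this file operate on gemmlowp-style fixed point; for example, CwiseMul with shift == 15 multiplies two Q0.15 tensors and keeps the product in Q0.15. Below is a standalone sketch of that arithmetic, assuming a local round-half-away-from-zero shift is an adequate stand-in for gemmlowp::RoundingDivideByPOT; it is illustrative only, not code from this patch.

#include <cstdint>
#include <cstdio>

// Local stand-in for gemmlowp::RoundingDivideByPOT (round half away from zero).
static int32_t RoundingShiftRight(int32_t x, int exponent) {
    const int32_t mask = (1 << exponent) - 1;
    const int32_t remainder = x & mask;
    const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int main() {
    // Two gate activations stored as Q0.15: 0.5 and 0.25.
    const int16_t a = 1 << 14;  // 0.5  * 2^15
    const int16_t b = 1 << 13;  // 0.25 * 2^15
    // CwiseMul-style product with shift == 15 keeps the result in Q0.15.
    const int32_t wide = static_cast<int32_t>(a) * static_cast<int32_t>(b);
    const int16_t prod = static_cast<int16_t>(RoundingShiftRight(wide, 15));
    std::printf("0.5 * 0.25 = %f (raw %d)\n", prod / 32768.0, prod);  // ~0.125
    return 0;
}
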
diff --git a/nn/common/QuantUtils.h b/nn/common/QuantUtils.h
new file mode 100644
index 000000000..09a87405c
--- /dev/null
+++ b/nn/common/QuantUtils.h
@@ -0,0 +1,201 @@
+// Quantized calculation utilities.
+// TODO(vddang): Replace this with tensorflow/lite/kernels/internal/tensor_utils(common).h
+// after TFLite module has been synced.
+
+#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
+#define ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
+
+#include <limits>
+#include <memory>
+
+#include <public/gemmlowp.h>
+
+#include "OperationsUtils.h"
+#include "Utils.h"
+
+namespace android {
+namespace nn {
+
+inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ int left_shift = shift > 0 ? shift : 0;
+ int right_shift = shift > 0 ? 0 : -shift;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier),
+ right_shift);
+}
+
+template <typename T>
+void MatrixBatchVectorMultiplyAccumulate(const int8_t* input, const int32_t* bias,
+ const int8_t* input_to_gate_weights, int32_t multiplier,
+ int32_t shift, int32_t n_batch, int32_t n_input,
+ int32_t n_output, int32_t output_zp, T* output) {
+ const int16_t output_max = std::numeric_limits<T>::max();
+ const int16_t output_min = std::numeric_limits<T>::min();
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int row = 0; row < n_output; ++row) {
+ int32_t acc = bias[row];
+ for (int col = 0; col < n_input; ++col) {
+ int8_t input_val = input[batch * n_input + col];
+ int8_t weights_val = input_to_gate_weights[row * n_input + col];
+ acc += input_val * weights_val;
+ }
+ acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+ acc += output_zp;
+ acc += output[batch * n_output + row];
+ if (acc > output_max) {
+ acc = output_max;
+ }
+ if (acc < output_min) {
+ acc = output_min;
+ }
+ output[batch * n_output + row] = static_cast<T>(acc);
+ }
+ }
+}
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value, "Only unsigned integer types handled.");
+#if defined(__GNUC__)
+ return integer_input ? __builtin_clz(integer_input) : std::numeric_limits<T>::digits;
+#else
+ if (integer_input == 0) {
+ return std::numeric_limits<T>::digits;
+ }
+
+ const T one_in_leading_positive = static_cast<T>(1) << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+#endif
+}
+
+inline bool GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
+ int32_t* output_inv_sqrt, int* output_shift) {
+ *output_shift = 11;
+ while (input >= (1 << 29)) {
+ input /= 4;
+ ++*output_shift;
+ }
+ NN_RET_CHECK_GT(input, 0);
+ const unsigned max_left_shift_bits = CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ NN_RET_CHECK_GE(input, (1 << 27));
+ NN_RET_CHECK_LT(input, (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32_t, 3>;
+ using F0 = FixedPoint<int32_t, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input = SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++) {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0) {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+ // Convert right shift (right is positive) to left shift.
+ *output_shift *= reverse_shift;
+ return true;
+}
+
+void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights, const int32_t* bias,
+ int32_t layer_norm_scale_a, int32_t layer_norm_scale_b, int32_t variance_limit,
+ int n_batch, int n_input, int16_t* output);
+
+void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, int32_t n_row,
+ int32_t n_col, int32_t* output);
+
+bool PrecomputeZeroPointTimesWeightWithBias(int32_t zero_point, const int8_t* weight_tensor,
+ const Shape& weight_shape, const int32_t* bias_tensor,
+ std::unique_ptr<int32_t[]>* output);
+
+void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output);
+
+template <int IntegerBits>
+void ApplyTanh(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output) {
+ using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
+ using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ FX tanh_input = FX::FromRaw(input[index]);
+ F0 tanh_output = gemmlowp::tanh(tanh_input);
+ output[index] = tanh_output.raw();
+ }
+ }
+}
+
+inline void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch, int32_t n_input,
+ int16_t* output) {
+ assert(integer_bits <= 6);
+#define DISPATCH_TANH(i) \
+ case i: \
+ ApplyTanh<i>(input, n_batch, n_input, output); \
+ break;
+ switch (integer_bits) {
+ DISPATCH_TANH(0);
+ DISPATCH_TANH(1);
+ DISPATCH_TANH(2);
+ DISPATCH_TANH(3);
+ DISPATCH_TANH(4);
+ DISPATCH_TANH(5);
+ DISPATCH_TANH(6);
+ default:
+ return;
+ }
+#undef DISPATCH_TANH
+}
+
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input, int shift,
+ int16_t* output);
+void CwiseMul(const int16_t* input_1, const int16_t* input_2, int32_t multiplier, int32_t shift,
+ int32_t n_batch, int32_t n_input, int32_t output_zp, int8_t* output);
+
+bool CheckedLog2(const float x, int* log2_result);
+
+void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input,
+ int16_t* output);
+
+inline void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
+ static const int16_t kOne = 32767;
+ for (int v = 0; v < v_size; v++) {
+ *result++ = kOne - *vector++;
+ }
+}
+
+void CwiseClipping(int16_t* input, const int16_t clipping_value, int32_t n_batch, int32_t n_input);
+
+void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch, int32_t n_input);
+
+void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
+ const int16_t* batch_vector, int n_batch,
+ int32_t multiplier, int shift, int16_t* result);
+
+} // namespace nn
+} // namespace android
+
+#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
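
MultiplyByQuantizedMultiplier is the workhorse of this header: it rescales an int32 accumulator by multiplier * 2^(shift - 31), where the (multiplier, shift) pair comes from QuantizeMultiplier. A rough standalone equivalent using 64-bit arithmetic is sketched below; it ignores the INT32_MIN saturation corner case that gemmlowp::SaturatingRoundingDoublingHighMul handles, assumes shift <= 30 and a positive product, and is not the implementation used by the patch.

#include <cstdint>
#include <cstdio>

static int32_t MultiplyByQuantizedMultiplierRef(int32_t x, int32_t multiplier, int shift) {
    // result ~= x * multiplier * 2^(shift - 31), rounded to nearest.
    const int64_t prod = static_cast<int64_t>(x) * static_cast<int64_t>(multiplier);
    const int total_right_shift = 31 - shift;  // combines the Q31 scaling and 'shift'
    const int64_t round = int64_t{1} << (total_right_shift - 1);
    return static_cast<int32_t>((prod + round) >> total_right_shift);
}

int main() {
    // 0.375 == 1610612736 * 2^-32, i.e. mantissa 0.75 in Q31 with shift == -1.
    const int32_t acc = 12345;
    const int32_t rescaled = MultiplyByQuantizedMultiplierRef(acc, 1610612736, -1);
    std::printf("12345 * 0.375 ~= %d\n", static_cast<int>(rescaled));  // 4629 (exact: 4629.375)
    return 0;
}
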
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 46bdec4a6..146ae21e9 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -35,7 +35,7 @@ namespace nn {
const int kNumberOfDataTypes = 15;
// The number of operation types (OperationCode) defined in NeuralNetworks.h.
-const int kNumberOfOperationTypes = 95;
+const int kNumberOfOperationTypes = 96;
// The number of execution preferences defined in NeuralNetworks.h.
const int kNumberOfPreferences = 3;
diff --git a/nn/common/operations/QLSTM.cpp b/nn/common/operations/QLSTM.cpp
new file mode 100644
index 000000000..8fc63c916
--- /dev/null
+++ b/nn/common/operations/QLSTM.cpp
@@ -0,0 +1,785 @@
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "CpuExecutor.h"
+#include "OperationsUtils.h"
+#include "QuantUtils.h"
+
+namespace android {
+namespace nn {
+namespace qlstm {
+
+namespace {
+
+// Inputs
+constexpr uint32_t kNumInputs = 32;
+
+constexpr uint32_t kInputTensor = 0;
+// Input weight tensors of size: [numUnits, inputSize].
+constexpr uint32_t kInputToInputWeightsTensor = 1;
+constexpr uint32_t kInputToForgetWeightsTensor = 2;
+constexpr uint32_t kInputToCellWeightsTensor = 3;
+constexpr uint32_t kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size [numUnits, outputSize].
+constexpr uint32_t kRecurrentToInputWeightsTensor = 5;
+constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
+constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
+constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;
+
+// For peephole (optional).
+// Cell to input/forget/output weights of size [numUnits].
+constexpr uint32_t kCellToInputWeightsTensor = 9;
+constexpr uint32_t kCellToForgetWeightsTensor = 10;
+constexpr uint32_t kCellToOutputWeightsTensor = 11;
+
+// Gates bias tensors of size [numUnits].
+constexpr uint32_t kInputGateBiasTensor = 12;
+constexpr uint32_t kForgetGateBiasTensor = 13;
+constexpr uint32_t kCellGateBiasTensor = 14;
+constexpr uint32_t kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size [outputSize, numUnits].
+constexpr uint32_t kProjectionWeightsTensor = 16;
+// Projection bias tensor of size [outputSize].
+constexpr uint32_t kProjectionBiasTensor = 17;
+
+// Output from the previous time step, as tensor
+// of size [numBatches, outputSize].
+constexpr uint32_t kPrevOutputTensor = 18;
+
+// Cell state from the previous time step, as tensor
+// of size [numBatches, numUnits].
+constexpr uint32_t kPrevCellStateTensor = 19;
+
+// Layer normalization tensors of size [numUnits].
+constexpr uint32_t kInputLayerNormTensor = 20;
+constexpr uint32_t kForgetLayerNormTensor = 21;
+constexpr uint32_t kCellLayerNormTensor = 22;
+constexpr uint32_t kOutputLayerNormTensor = 23;
+
+// Clipping.
+constexpr uint32_t kCellClip = 24;
+constexpr uint32_t kProjectionClip = 25;
+
+// Scales of the matmul results, i.e. the inputs to layer normalization.
+constexpr uint32_t kInputIntermediateScale = 26;
+constexpr uint32_t kForgetIntermediateScale = 27;
+constexpr uint32_t kCellIntermediateScale = 28;
+constexpr uint32_t kOutputIntermediateScale = 29;
+
+// Zero point and scale of hidden state.
+constexpr uint32_t kHiddenStateZeroPoint = 30;
+constexpr uint32_t kHiddenStateScale = 31;
+
+// Outputs:
+constexpr uint32_t kNumOutputs = 3;
+constexpr uint32_t kOutputStateOutTensor = 0;
+constexpr uint32_t kCellStateOutTensor = 1;
+constexpr uint32_t kOutputTensor = 2;
+
+inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
+ return context->getInputBuffer(tensor) != nullptr;
+}
+
+} // namespace
+
+using hal::OperandType;
+
+bool validate(const IOperationValidationContext* context) {
+ NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
+ NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
+
+ std::vector<OperandType> inExpectedTypes;
+ // Input.
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
+ // Input-to-* and recurrent-to-* weights.
+ for (int i = 0; i < 8; ++i) {
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
+ }
+ // Cell-to-* weights.
+ for (int i = 0; i < 3; ++i) {
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
+ }
+ // Gate biases.
+ for (int i = 0; i < 4; ++i) {
+ inExpectedTypes.push_back(OperandType::TENSOR_INT32);
+ }
+ // Projection.
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
+ inExpectedTypes.push_back(OperandType::TENSOR_INT32);
+ // Previous output.
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
+ // Previous cell state.
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
+ // Layer norm weights
+ for (int i = 0; i < 4; ++i) {
+ inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
+ }
+ // Cell/projection clipping and scales of intermediate results at the 4 gates.
+ for (int i = 0; i < 6; ++i) {
+ inExpectedTypes.push_back(OperandType::FLOAT32);
+ }
+ // Zero point and scale of the hidden state.
+ inExpectedTypes.push_back(OperandType::INT32);
+ inExpectedTypes.push_back(OperandType::FLOAT32);
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+
+ std::vector<OperandType> outExpectedTypes;
+ // Output state (out).
+ outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
+ // Cell state (out).
+ outExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
+ // Output.
+ outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
+ NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
+
+ return validateHalVersion(context, HalVersion::V1_3);
+}
+
+bool prepare(IOperationExecutionContext* context) {
+ // Check that none of the required inputs are omitted
+ const std::vector<int> requiredTensorInputs = {
+ kInputTensor,
+ kInputToForgetWeightsTensor,
+ kInputToCellWeightsTensor,
+ kInputToOutputWeightsTensor,
+ kRecurrentToForgetWeightsTensor,
+ kRecurrentToCellWeightsTensor,
+ kRecurrentToOutputWeightsTensor,
+ kForgetGateBiasTensor,
+ kCellGateBiasTensor,
+ kOutputGateBiasTensor,
+ kPrevOutputTensor,
+ kPrevCellStateTensor,
+ };
+ for (const int tensor : requiredTensorInputs) {
+ NN_RET_CHECK(!context->isOmittedInput(tensor))
+ << "required input " << tensor << " is omitted";
+ }
+
+ const Shape inputShape = context->getInputShape(kInputTensor);
+ const uint32_t inputRank = getNumberOfDimensions(inputShape);
+ NN_RET_CHECK_EQ(inputRank, 2) << "Invalid input tensor rank: " << inputRank;
+
+ const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
+ const uint32_t inputSize = getSizeOfDimension(inputShape, 1);
+
+ const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
+ const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);
+
+ const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
+ const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);
+
+ if (hasTensor(context, kInputToInputWeightsTensor)) {
+ const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
+ }
+
+ const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
+ const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);
+
+ if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
+ const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
+ }
+
+ const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
+ const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
+ NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);
+
+ // Make sure the input-gate's parameters are either all present (non-CIFG)
+ // or all absent (CIFG).
+ const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
+ hasTensor(context, kRecurrentToInputWeightsTensor)) ||
+ (!hasTensor(context, kInputToInputWeightsTensor) &&
+ !hasTensor(context, kRecurrentToInputWeightsTensor));
+ NN_RET_CHECK(cifgWeightsAllOrNone);
+
+ if (hasTensor(context, kCellToInputWeightsTensor)) {
+ const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
+ }
+
+ if (hasTensor(context, kCellToForgetWeightsTensor)) {
+ const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
+ }
+
+ if (hasTensor(context, kCellToOutputWeightsTensor)) {
+ const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
+ }
+
+ // Make sure the peephole weights are either all present or all absent.
+ const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
+ const bool peepholeWeightsAllOrNone =
+ ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
+ hasTensor(context, kCellToForgetWeightsTensor) &&
+ hasTensor(context, kCellToOutputWeightsTensor)) ||
+ (!hasTensor(context, kCellToInputWeightsTensor) &&
+ !hasTensor(context, kCellToForgetWeightsTensor) &&
+ !hasTensor(context, kCellToOutputWeightsTensor));
+ NN_RET_CHECK(peepholeWeightsAllOrNone);
+
+ if (!cifgUsed) {
+ NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
+ const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
+ } else {
+ NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
+ << "Input gate bias tensor is present when CIFG is used";
+ }
+
+ const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
+ const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
+ const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);
+
+ if (hasTensor(context, kProjectionWeightsTensor)) {
+ const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
+ NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
+ }
+
+ if (hasTensor(context, kProjectionBiasTensor)) {
+ const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
+ }
+
+ const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
+ NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
+ const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);
+
+ if (hasTensor(context, kInputLayerNormTensor)) {
+ const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
+ }
+
+ if (hasTensor(context, kForgetLayerNormTensor)) {
+ const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
+ }
+
+ if (hasTensor(context, kCellLayerNormTensor)) {
+ const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
+ }
+
+ if (hasTensor(context, kOutputLayerNormTensor)) {
+ const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
+ NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
+ NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
+ }
+
+ if (cifgUsed) {
+ NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
+ << "Input layer norm weights tensor is present when CIFG is used";
+ const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
+ hasTensor(context, kCellLayerNormTensor) &&
+ hasTensor(context, kOutputLayerNormTensor)) ||
+ (!hasTensor(context, kForgetLayerNormTensor) &&
+ !hasTensor(context, kCellLayerNormTensor) &&
+ !hasTensor(context, kOutputLayerNormTensor));
+ NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
+ } else {
+ const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
+ hasTensor(context, kForgetLayerNormTensor) &&
+ hasTensor(context, kCellLayerNormTensor) &&
+ hasTensor(context, kOutputLayerNormTensor)) ||
+ (!hasTensor(context, kInputLayerNormTensor) &&
+ !hasTensor(context, kForgetLayerNormTensor) &&
+ !hasTensor(context, kCellLayerNormTensor) &&
+ !hasTensor(context, kOutputLayerNormTensor));
+ NN_RET_CHECK(layerNormWeightsAllOrNone);
+ }
+
+ const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
+ Shape outputShape = context->getOutputShape(kOutputTensor);
+ outputShape.dimensions = prevOutputShape.dimensions;
+
+ const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
+ Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
+ cellStateOutShape.dimensions = prevCellStateShape.dimensions;
+
+ return context->setOutputShape(kOutputStateOutTensor, outputShape) &&
+ context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
+ context->setOutputShape(kOutputTensor, outputShape);
+}
+
+bool execute(IOperationExecutionContext* context) {
+ // Gets the inputs.
+ const Shape inputShape = context->getInputShape(kInputTensor);
+ const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
+ const Shape recurrentToInputWeightsShape =
+ context->getInputShape(kRecurrentToInputWeightsTensor);
+ const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
+ const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
+ const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
+ const Shape recurrentToForgetWeightsShape =
+ context->getInputShape(kRecurrentToForgetWeightsTensor);
+ const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
+ const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
+ const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
+ const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
+ const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
+ const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
+ const Shape recurrentToOutputWeightsShape =
+ context->getInputShape(kRecurrentToOutputWeightsTensor);
+ const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
+ const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
+ const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
+ const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
+ const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
+
+ const uint32_t batchSize = inputShape.dimensions[0];
+ const uint32_t inputSize = inputShape.dimensions[1];
+ const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
+ const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];
+
+ const float cellClip = context->getInputValue<float>(kCellClip);
+ const float projectionClip = context->getInputValue<float>(kProjectionClip);
+ const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
+ const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
+ const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
+ const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
+ const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
+ const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);
+
+ const int8_t* inputBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));
+
+ const int8_t* inputToInputWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
+ const bool useCifg = (inputToInputWeightsBuffer == nullptr);
+ const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
+ context->getInputBuffer(kRecurrentToInputWeightsTensor));
+ const int16_t* cellToInputBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
+ const int16_t* inputLayerNormBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
+ const int32_t* inputBiasBuffer =
+ reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));
+
+ const int8_t* inputToForgetWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
+ const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
+ context->getInputBuffer(kRecurrentToForgetWeightsTensor));
+ const int16_t* cellToForgetBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
+ const int16_t* forgetLayerNormBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
+ const int32_t* forgetBiasBuffer =
+ reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));
+
+ const int8_t* inputToCellWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
+ const int8_t* recurrentToCellWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
+ const int16_t* cellLayerNormBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
+ const int32_t* cellBiasBuffer =
+ reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));
+
+ const int8_t* inputToOutputWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
+ const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
+ context->getInputBuffer(kRecurrentToOutputWeightsTensor));
+ const int16_t* cellToOutputBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
+ const int16_t* outputLayerNormBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
+ const int32_t* outputBiasBuffer =
+ reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));
+
+ const int8_t* projectionWeightsBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
+ const int32_t* projectionBiasBuffer =
+ reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));
+
+ const int8_t* prevOutputBuffer =
+ reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
+ const int16_t* prevCellStateBuffer =
+ reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));
+
+ uint8_t* outputStateBuffer =
+ reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
+ int16_t* cellStateBuffer =
+ reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
+ int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));
+
+ // Calculates and decomposes effective scales.
+ // This is for optimizing the matmul calculation.
+ int cellShift;
+ NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
+ NN_RET_CHECK(cellShift <= -9);
+
+ int32_t inputToInputEffectiveScaleA;
+ int32_t inputToInputEffectiveScaleB;
+ int32_t recurrentToInputEffectiveScaleA;
+ int32_t recurrentToInputEffectiveScaleB;
+ int32_t cellToInputEffectiveScaleA;
+ int32_t cellToInputEffectiveScaleB;
+ if (!useCifg) {
+ const float inputToInputEffectiveScale =
+ inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
+ NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
+ &inputToInputEffectiveScaleB));
+ const float recurrentToInputEffectiveScale =
+ recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
+ NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
+ &recurrentToInputEffectiveScaleA,
+ &recurrentToInputEffectiveScaleB));
+ if (cellToInputBuffer != nullptr) {
+ const float cellToInputEffectiveScale =
+ std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
+ NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
+ &cellToInputEffectiveScaleB));
+ }
+ }
+
+ int32_t inputLayerNormScaleA;
+ int32_t inputLayerNormScaleB;
+ if (inputLayerNormBuffer != nullptr) {
+ NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
+ &inputLayerNormScaleB));
+ }
+
+ const float inputToForgetEffectiveScale =
+ inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
+ int32_t inputToForgetEffectiveScaleA;
+ int32_t inputToForgetEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
+ &inputToForgetEffectiveScaleB));
+ const float recurrentToForgetEffectiveScale =
+ recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
+ int32_t recurrentToForgetEffectiveScaleA;
+ int32_t recurrentToForgetEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
+ &recurrentToForgetEffectiveScaleA,
+ &recurrentToForgetEffectiveScaleB));
+ int32_t cellToForgetEffectiveScaleA;
+ int32_t cellToForgetEffectiveScaleB;
+ if (cellToForgetBuffer != nullptr) {
+ const float cellToForgetEffectiveScale =
+ std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
+ NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
+ &cellToForgetEffectiveScaleB));
+ }
+ int32_t forgetLayerNormScaleA;
+ int32_t forgetLayerNormScaleB;
+ if (forgetLayerNormBuffer != nullptr) {
+ NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
+ &forgetLayerNormScaleB));
+ }
+
+ const float inputToCellEffectiveScale =
+ inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
+ int32_t inputToCellEffectiveScaleA;
+ int32_t inputToCellEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
+ &inputToCellEffectiveScaleB));
+ const float recurrentToCellEffectiveScale =
+ recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
+ int32_t recurrentToCellEffectiveScaleA;
+ int32_t recurrentToCellEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
+ &recurrentToCellEffectiveScaleB));
+
+ int32_t cellLayerNormScaleA;
+ int32_t cellLayerNormScaleB;
+ if (cellLayerNormBuffer != nullptr) {
+ NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
+ &cellLayerNormScaleB));
+ }
+
+ const float inputToOutputEffectiveScale =
+ inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
+ int32_t inputToOutputEffectiveScaleA;
+ int32_t inputToOutputEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
+ &inputToOutputEffectiveScaleB));
+ const float recurrentToOutputEffectiveScale =
+ recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
+ int32_t recurrentToOutputEffectiveScaleA;
+ int32_t recurrentToOutputEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
+ &recurrentToOutputEffectiveScaleA,
+ &recurrentToOutputEffectiveScaleB));
+ int32_t cellToOutputEffectiveScaleA;
+ int32_t cellToOutputEffectiveScaleB;
+ if (cellToOutputBuffer != nullptr) {
+ const float cellToOutputEffectiveScale =
+ std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
+ NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
+ &cellToOutputEffectiveScaleB));
+ }
+ int32_t outputLayerNormScaleA;
+ int32_t outputLayerNormScaleB;
+ if (outputLayerNormBuffer != nullptr) {
+ NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
+ &outputLayerNormScaleB));
+ }
+
+ const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
+ int32_t hiddenStateEffectiveScaleA;
+ int32_t hiddenStateEffectiveScaleB;
+ NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
+ &hiddenStateEffectiveScaleB));
+
+ int32_t projectionEffectiveScaleA;
+ int32_t projectionEffectiveScaleB;
+ if (projectionWeightsBuffer != nullptr) {
+ const float projectionEffectiveScale =
+ projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
+ NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
+ &projectionEffectiveScaleB));
+ }
+
+ // Calculates quantized clipping parameters.
+ int16_t quantizedCellClip = 0;
+ if (cellClip > 0.0) {
+ quantizedCellClip = static_cast<int32_t>(
+ std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
+ }
+ int8_t quantizedProjectionClip = 0;
+ if (projectionClip > 0.0) {
+ quantizedProjectionClip = static_cast<int32_t>(
+ std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
+ }
+
+ // Calculates effective bias.
+ // This is for optimizing the matmul calculation.
+ std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
+ std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
+ if (!useCifg) {
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
+ /*bias=*/nullptr, &inputToInputEffectiveBias));
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -prevOutputShape.offset, recurrentToInputWeightsBuffer,
+ recurrentToInputWeightsShape,
+ /*bias=*/nullptr, &recurrentToInputEffectiveBias));
+ }
+
+ std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
+ /*bias=*/nullptr, &inputToForgetEffectiveBias));
+ std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
+ /*bias=*/nullptr, &recurrentToForgetEffectiveBias));
+
+ std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
+ /*bias=*/nullptr, &inputToCellEffectiveBias));
+ std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
+ /*bias=*/nullptr, &recurrentToCellEffectiveBias));
+
+ std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
+ /*bias=*/nullptr, &inputToOutputEffectiveBias));
+ std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
+ /*bias=*/nullptr, &recurrentToOutputEffectiveBias));
+
+ std::unique_ptr<int32_t[]> projectionEffectiveBias;
+ if (projectionBiasBuffer != nullptr) {
+ NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
+ hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
+ projectionBiasBuffer, &projectionEffectiveBias));
+ }
+
+ // Temporary buffers.
+ std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
+ std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
+ std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
+ std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
+ std::vector<int8_t> buffer8(batchSize * numUnits);
+
+ // To avoid overflow when calculating layer norm.
+ const int32_t inputInvLargeValue =
+ std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
+ const int32_t forgetInvLargeValue =
+ std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
+ const int32_t cellInvLargeValue =
+ std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
+ const int32_t outputInvLargeValue =
+ std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));
+
+ // Forget gate.
+ MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
+ inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
+ inputToForgetEffectiveScaleB, batchSize, inputSize,
+ numUnits,
+ /*outputZeroPoint=*/0, forgetGateBuffer.data());
+ MatrixBatchVectorMultiplyAccumulate(
+ prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
+ recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
+ outputSize, numUnits,
+ /*outputZeroPoint=*/0, forgetGateBuffer.data());
+ if (cellToForgetBuffer != nullptr) {
+ VectorBatchVectorCwiseProductAccumulate(
+ cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
+ cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
+ }
+ if (forgetLayerNormBuffer != nullptr) {
+ ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
+ forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
+ numUnits, forgetGateBuffer.data());
+ }
+ ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());
+
+ // Modulation gate.
+ MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
+ inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
+ inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
+ /*outputZeroPoint=*/0, cellGateBuffer.data());
+ MatrixBatchVectorMultiplyAccumulate(
+ prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
+ recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
+ numUnits,
+ /*outputZeroPoint=*/0, cellGateBuffer.data());
+ if (cellLayerNormBuffer != nullptr) {
+ ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
+ cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
+ numUnits, cellGateBuffer.data());
+ }
+ ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());
+
+ // Input gate.
+ if (useCifg) {
+ Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
+ } else {
+ MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
+ inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
+ inputToInputEffectiveScaleB, batchSize, inputSize,
+ numUnits,
+ /*outputZeroPoint=*/0, inputGateBuffer.data());
+ MatrixBatchVectorMultiplyAccumulate(
+ prevOutputBuffer, recurrentToInputEffectiveBias.get(),
+ recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
+ recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
+ /*outputZeroPoint=*/0, inputGateBuffer.data());
+ if (cellToInputBuffer != nullptr) {
+ VectorBatchVectorCwiseProductAccumulate(
+ cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
+ cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
+ }
+ if (inputLayerNormBuffer != nullptr) {
+ ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
+ inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
+ batchSize, numUnits, inputGateBuffer.data());
+ }
+ ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
+ }
+
+ // Cell.
+ CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
+ /*shift=*/15, forgetGateBuffer.data());
+ CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
+ cellGateBuffer.data());
+ CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
+ if (quantizedCellClip > 0) {
+ CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
+ }
+
+ // Output gate.
+ MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
+ inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
+ inputToOutputEffectiveScaleB, batchSize, inputSize,
+ numUnits,
+ /*outputZeroPoint=*/0, outputGateBuffer.data());
+ MatrixBatchVectorMultiplyAccumulate(
+ prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
+ recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
+ outputSize, numUnits,
+ /*outputZeroPoint=*/0, outputGateBuffer.data());
+ if (cellToOutputBuffer != nullptr) {
+ VectorBatchVectorCwiseProductAccumulate(
+ cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
+ cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
+ }
+ if (outputLayerNormBuffer != nullptr) {
+ ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
+ outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
+ numUnits, outputGateBuffer.data());
+ }
+ ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());
+
+ // Hidden.
+ ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
+ CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
+ hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());
+
+ // Projection.
+ if (projectionWeightsBuffer != nullptr) {
+ MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
+ projectionWeightsBuffer, projectionEffectiveScaleA,
+ projectionEffectiveScaleB, batchSize, numUnits,
+ outputSize, prevOutputShape.offset, outputBuffer);
+ if (quantizedProjectionClip > 0) {
+ CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
+ }
+ }
+
+ // Copy output to output state out.
+ for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
+ outputStateBuffer[i] = outputBuffer[i];
+ }
+
+ return true;
+}
+
+} // namespace qlstm
+
+NN_REGISTER_OPERATION(QUANTIZED_LSTM, "QUANTIZED_LSTM", qlstm::validate, qlstm::prepare,
+ qlstm::execute, .allowOmittedOperand = true);
+
+} // namespace nn
+} // namespace android
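
Two ideas in execute() above are worth spelling out. First, each gate matmul folds the input (or previous-output) zero point into an "effective bias" via PrecomputeZeroPointTimesWeightWithBias, so the inner loop of MatrixBatchVectorMultiplyAccumulate can multiply raw quantized values without subtracting the zero point per element. Second, each real rescale factor (e.g. inputScale * inputToForgetWeightsScale / forgetIntermediateScale) is decomposed once by QuantizeMultiplier into a multiplier/shift pair. The standalone sketch below demonstrates only the bias-folding identity; the variable names are illustrative and not taken from the patch.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t inputZeroPoint = -5;                    // zero point of the QUANT8_ASYMM_SIGNED input
    const std::vector<int8_t> weights = {3, -7, 12, 1};   // one row of input-to-forget weights
    const std::vector<int8_t> input = {10, -3, 8, 127};

    // Direct form: subtract the zero point inside the inner loop.
    int32_t direct = 0;
    for (std::size_t j = 0; j < weights.size(); ++j) {
        direct += static_cast<int32_t>(weights[j]) *
                  (static_cast<int32_t>(input[j]) - inputZeroPoint);
    }

    // Precomputed form: fold (-zeroPoint) * rowSum into an "effective bias" once.
    int32_t rowSum = 0;
    for (int8_t w : weights) rowSum += w;
    const int32_t effectiveBias = -inputZeroPoint * rowSum;
    int32_t accumulated = effectiveBias;
    for (std::size_t j = 0; j < weights.size(); ++j) {
        accumulated += static_cast<int32_t>(weights[j]) * static_cast<int32_t>(input[j]);
    }

    // Both forms produce the same accumulator value.
    std::printf("direct=%d precomputed=%d\n", static_cast<int>(direct),
                static_cast<int>(accumulated));
    return 0;
}
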
diff --git a/nn/runtime/NeuralNetworks.cpp b/nn/runtime/NeuralNetworks.cpp
index 3542c24b1..5ab55dbf7 100644
--- a/nn/runtime/NeuralNetworks.cpp
+++ b/nn/runtime/NeuralNetworks.cpp
@@ -188,6 +188,7 @@ static_assert(ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM == 92,
"ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM has changed");
static_assert(ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN == 93,
"ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN has changed");
+static_assert(ANEURALNETWORKS_QUANTIZED_LSTM == 95, "ANEURALNETWORKS_QUANTIZED_LSTM has changed");
static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000, "ANEURALNETWORKS_OEM_OPERATION has changed");
diff --git a/nn/runtime/include/NeuralNetworks.h b/nn/runtime/include/NeuralNetworks.h
index 3bc646e2d..b2f72c47d 100644
--- a/nn/runtime/include/NeuralNetworks.h
+++ b/nn/runtime/include/NeuralNetworks.h
@@ -5084,6 +5084,137 @@ typedef enum {
* Available since API level 29.
*/
ANEURALNETWORKS_RESIZE_NEAREST_NEIGHBOR = 94,
+
+ /**
+ * Quantized version of {@link ANEURALNETWORKS_LSTM}.
+ *
+ * The input and the output use asymmetric quantized types, while the rest
+ * use symmetric ones.
+ *
+ * Inputs:
+ * * 0: The input to the LSTM cell.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, inputSize]
+ * * 1: The input-to-input weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 2: The input-to-forget weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 3: The input-to-cell weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 4: The input-to-output weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 5: The recurrent-to-input weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 6: The recurrent-to-forget weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 7: The recurrent-to-cell weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 8: The recurrent-to-output weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 9: The cell-to-input weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 10: The cell-to-forget weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 11: The cell-to-output weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 12: The input gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Optional.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 13: The forget gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 14: The cell bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 15: The output gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 16: The projection weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [outputSize, numUnits]
+ * * 17: The projection bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Optional.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [outputSize]
+ * * 18: The output from the previous time step.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+ * * 19: The cell state from the previous time step.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [batchSize, numUnits]
+ * * 20: The input layer normalization weights. Used to rescale
+ * normalized inputs to activation at input gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 21: The forget layer normalization weights. Used to
+ * rescale normalized inputs to activation at forget gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 22: The cell layer normalization weights. Used to rescale
+ * normalized inputs to activation at cell gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 23: The output layer normalization weights. Used to
+ * rescale normalized inputs to activation at output gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 24: The cell clip. If provided, the cell state is clipped
+ * by this value prior to the cell output activation. Optional.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 25: The projection clip. If provided and projection is enabled,
+ * this is used for clipping the projected values. Optional.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 26: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at input gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 27: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at forget gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 28: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at cell gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 29: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at output gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 30: The zero point of the hidden state, i.e. input to
+ * projection.
+ * Type: {@link OperandType::INT32}.
+ * * 31: The scale of the hidden state, i.e. input to
+ * projection.
+ * Type: {@link OperandType::FLOAT32}.
+ *
+ * Outputs:
+ * * 0: The output state (out).
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+ * * 1: The cell state (out).
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [batchSize, numUnits]
+ * * 2: The output. This is effectively the same as the current
+ * "output state (out)" value.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+ *
+ * Available since API level 30.
+ */
+ ANEURALNETWORKS_QUANTIZED_LSTM = 95,
} OperationCode;
/**
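Taken together, the operation consumes the 32 inputs listed above and produces 3 outputs. A minimal usage sketch against the NDK C API, assuming an API level 30 NeuralNetworks.h and that all 35 operands were already added to the model in the documented order; the fixed index values and the helper name addQuantizedLstm are illustrative only:

    #include <android/NeuralNetworks.h>
    #include <cstdint>

    // Wires up QUANTIZED_LSTM once operands 0..31 (inputs) and 32..34 (outputs)
    // exist on the model; real code would use the indices it assigned when
    // calling ANeuralNetworksModel_addOperand.
    int addQuantizedLstm(ANeuralNetworksModel* model) {
        uint32_t inputs[32];
        for (uint32_t i = 0; i < 32; ++i) inputs[i] = i;
        const uint32_t outputs[3] = {32, 33, 34};
        return ANeuralNetworksModel_addOperation(
                model, ANEURALNETWORKS_QUANTIZED_LSTM, 32, inputs, 3, outputs);
    }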
diff --git a/nn/runtime/test/TestValidateOperations.cpp b/nn/runtime/test/TestValidateOperations.cpp
index 0a259a2b4..0de593ef7 100644
--- a/nn/runtime/test/TestValidateOperations.cpp
+++ b/nn/runtime/test/TestValidateOperations.cpp
@@ -3520,4 +3520,124 @@ TEST(OperationValidationTest, RESIZE_NEAREST_NEIGHBOR_quant8_signed) {
resizeNearestNeighborTest(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, ANEURALNETWORKS_FLOAT32);
}
+TEST(OperationValidationTest, QUANTIZED_LSTM) {
+ uint32_t oneDimensional[1] = {5};
+ uint32_t twoDimensional[2] = {5, 5};
+
+ ANeuralNetworksOperandType quant8AsymSignedTensor2D = {
+ .type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED,
+ .dimensionCount = 2,
+ .dimensions = twoDimensional,
+ .scale = 0.0078125,
+ .zeroPoint = 0,
+ };
+ ANeuralNetworksOperandType quant8SymTensor2D = {
+ .type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM,
+ .dimensionCount = 2,
+ .dimensions = twoDimensional,
+ .scale = 0.0078125,
+ .zeroPoint = 0,
+ };
+ ANeuralNetworksOperandType quant16SymTensor1D = {
+ .type = ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
+ .dimensionCount = 1,
+ .dimensions = oneDimensional,
+ .scale = 1.0,
+ .zeroPoint = 0,
+ };
+ ANeuralNetworksOperandType quant16SymTensor2D = {
+ .type = ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
+ .dimensionCount = 2,
+ .dimensions = twoDimensional,
+ .scale = 1.0,
+ .zeroPoint = 0,
+ };
+ ANeuralNetworksOperandType int32Tensor1D = {
+ .type = ANEURALNETWORKS_TENSOR_INT32,
+ .dimensionCount = 1,
+ .dimensions = oneDimensional,
+ .scale = 4.65661e-08,
+ .zeroPoint = 0,
+ };
+ ANeuralNetworksOperandType int32Scalar = {
+ .type = ANEURALNETWORKS_INT32,
+ };
+ ANeuralNetworksOperandType float32Scalar = {
+ .type = ANEURALNETWORKS_FLOAT32,
+ };
+
+ ANeuralNetworksOperandType input = quant8AsymSignedTensor2D;
+ ANeuralNetworksOperandType input_to_input_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType input_to_forget_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType input_to_cell_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType input_to_output_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType recurrent_to_input_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType recurrent_to_forget_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType recurrent_to_cell_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType recurrent_to_output_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType cell_to_input_weights = quant16SymTensor2D;
+ ANeuralNetworksOperandType cell_to_forget_weights = quant16SymTensor2D;
+ ANeuralNetworksOperandType cell_to_output_weights = quant16SymTensor2D;
+ ANeuralNetworksOperandType input_gate_bias = int32Tensor1D;
+ ANeuralNetworksOperandType forget_gate_bias = int32Tensor1D;
+ ANeuralNetworksOperandType cell_gate_bias = int32Tensor1D;
+ ANeuralNetworksOperandType output_gate_bias = int32Tensor1D;
+ ANeuralNetworksOperandType projection_weights = quant8SymTensor2D;
+ ANeuralNetworksOperandType projection_bias = int32Tensor1D;
+ ANeuralNetworksOperandType output_state_in = quant8AsymSignedTensor2D;
+ ANeuralNetworksOperandType cell_state_in = quant16SymTensor2D;
+ ANeuralNetworksOperandType input_layer_norm_weights = quant16SymTensor1D;
+ ANeuralNetworksOperandType forget_layer_norm_weights = quant16SymTensor1D;
+ ANeuralNetworksOperandType cell_layer_norm_weights = quant16SymTensor1D;
+ ANeuralNetworksOperandType output_layer_norm_weights = quant16SymTensor1D;
+ ANeuralNetworksOperandType cell_clip = float32Scalar;
+ ANeuralNetworksOperandType projection_clip = float32Scalar;
+ ANeuralNetworksOperandType input_intermediate_scale = float32Scalar;
+ ANeuralNetworksOperandType forget_intermediate_scale = float32Scalar;
+ ANeuralNetworksOperandType cell_intermediate_scale = float32Scalar;
+ ANeuralNetworksOperandType output_intermediate_scale = float32Scalar;
+ ANeuralNetworksOperandType hidden_state_zero_point = int32Scalar;
+ ANeuralNetworksOperandType hidden_state_scale = float32Scalar;
+
+ ANeuralNetworksOperandType output_state_out = quant8AsymSignedTensor2D;
+ ANeuralNetworksOperandType cell_state_out = quant16SymTensor2D;
+ ANeuralNetworksOperandType output = quant8AsymSignedTensor2D;
+
+ OperationTestBase test(ANEURALNETWORKS_QUANTIZED_LSTM,
+ {input,
+ input_to_input_weights,
+ input_to_forget_weights,
+ input_to_cell_weights,
+ input_to_output_weights,
+ recurrent_to_input_weights,
+ recurrent_to_forget_weights,
+ recurrent_to_cell_weights,
+ recurrent_to_output_weights,
+ cell_to_input_weights,
+ cell_to_forget_weights,
+ cell_to_output_weights,
+ input_gate_bias,
+ forget_gate_bias,
+ cell_gate_bias,
+ output_gate_bias,
+ projection_weights,
+ projection_bias,
+ output_state_in,
+ cell_state_in,
+ input_layer_norm_weights,
+ forget_layer_norm_weights,
+ cell_layer_norm_weights,
+ output_layer_norm_weights,
+ cell_clip,
+ projection_clip,
+ input_intermediate_scale,
+ forget_intermediate_scale,
+ cell_intermediate_scale,
+ output_intermediate_scale,
+ hidden_state_zero_point,
+ hidden_state_scale},
+ {output_state_out, cell_state_out, output});
+ test.testOpsValidations();
+}
+
} // end namespace
diff --git a/nn/runtime/test/generated/spec_V1_3/qlstm.example.cpp b/nn/runtime/test/generated/spec_V1_3/qlstm.example.cpp
new file mode 100644
index 000000000..e678fcd65
--- /dev/null
+++ b/nn/runtime/test/generated/spec_V1_3/qlstm.example.cpp
@@ -0,0 +1,380 @@
+// Generated from qlstm.mod.py
+// DO NOT EDIT
+// clang-format off
+#include "TestHarness.h"
+using namespace test_helper;
+
+namespace generated_tests::qlstm {
+
+const TestModel& get_test_model() {
+ static TestModel model = {
+ .expectFailure = false,
+ .expectedMultinomialDistributionTolerance = 0,
+ .inputIndexes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ .isRelaxed = false,
+ .minSupportedVersion = TestHalVersion::V1_3,
+ .operands = {{
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({90, 102, 13, 26, 38, 102, 13, 26, 51, 64}),
+ .dimensions = {2, 5},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.0078125f,
+ .type = TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13}),
+ .dimensions = {4, 5},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64}),
+ .dimensions = {4, 5},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77}),
+ .dimensions = {4, 5},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51}),
+ .dimensions = {4, 5},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77}),
+ .dimensions = {4, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25}),
+ .dimensions = {4, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25}),
+ .dimensions = {4, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25}),
+ .dimensions = {4, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00784314f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({0, 0, 0, 0}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 1.0f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({0, 0, 0, 0}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 1.0f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({0, 0, 0, 0}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 1.0f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({644245, 3221226, 4724464, 8160438}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 4.65661e-08f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({2147484, -6442451, -4294968, 2147484}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 4.65661e-08f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({-1073742, 15461883, 5368709, 1717987}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 4.65661e-08f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({1073742, -214748, 4294968, 2147484}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 4.65661e-08f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51}),
+ .dimensions = {3,4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.00392157f,
+ .type = TestOperandType::TENSOR_QUANT8_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({0, 0, 0}),
+ .dimensions = {3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({0, 0, 0, 0, 0, 0}),
+ .dimensions = {2, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05176e-05f,
+ .type = TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({0, 0, 0, 0, 0, 0, 0, 0}),
+ .dimensions = {2, 4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05176e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({3277, 6553, 9830, 16384}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05182e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({6553, 6553, 13107, 9830}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05182e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({22937, 6553, 9830, 26214}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05182e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({19660, 6553, 6553, 16384}),
+ .dimensions = {4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 3.05182e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.0f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.0f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.007059f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.007812f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.007059f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.007812f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({0}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<float>({0.007f}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::FLOAT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({127, 127, -108, -67, 127, 127}),
+ .dimensions = {2, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
+ .numberOfConsumers = 0,
+ .scale = 3.05176e-05f,
+ .type = TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int16_t>({-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939}),
+ .dimensions = {2, 4},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
+ .numberOfConsumers = 0,
+ .scale = 3.05176e-05f,
+ .type = TestOperandType::TENSOR_QUANT16_SYMM,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int8_t>({127, 127, -108, -67, 127, 127}),
+ .dimensions = {2, 3},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
+ .numberOfConsumers = 0,
+ .scale = 3.05176e-05f,
+ .type = TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ .zeroPoint = 0
+ }},
+ .operations = {{
+ .inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
+ .outputs = {32, 33, 34},
+ .type = TestOperationType::QUANTIZED_LSTM
+ }},
+ .outputIndexes = {32, 33, 34}
+ };
+ return model;
+}
+
+const auto dummy_test_model = TestModelManager::get().add("qlstm", get_test_model());
+
+} // namespace generated_tests::qlstm
+
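The golden buffers above are easiest to read back in real terms via the standard NNAPI affine dequantization, real = scale * (quantized - zeroPoint). A small sketch; the variable names and the sampled elements are chosen for illustration:

    #include <cstdio>

    int main() {
        const float outScale = 3.05176e-05f;   // output / output_state_out scale, zeroPoint 0
        const float cellScale = 3.05176e-05f;  // cell_state_out scale
        std::printf("%f\n", outScale * (127 - 0));     // about 0.0039, an int8 value at the positive limit
        std::printf("%f\n", cellScale * (-14650 - 0)); // about -0.447, first cell-state element
        return 0;
    }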
diff --git a/nn/runtime/test/specs/V1_3/qlstm.mod.py b/nn/runtime/test/specs/V1_3/qlstm.mod.py
new file mode 100644
index 000000000..c00c61400
--- /dev/null
+++ b/nn/runtime/test/specs/V1_3/qlstm.mod.py
@@ -0,0 +1,178 @@
+#
+# Copyright (C) 2019 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Test for QUANTIZED_LSTM op.
+import copy
+
+model = Model()
+
+batch_size = 2
+input_size = 5
+num_units = 4
+output_size = 3
+
+input = Input("input",
+ ("TENSOR_QUANT8_ASYMM_SIGNED", "{%d, %d}" % (batch_size, input_size), 0.0078125, 0))
+
+input_to_input_weights = Input("input_to_input_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, input_size), 0.00784314, 0))
+input_to_forget_weights = Input("input_to_forget_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, input_size), 0.00784314, 0))
+input_to_cell_weights = Input("input_to_cell_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, input_size), 0.00784314, 0))
+input_to_output_weights = Input("input_to_output_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, input_size), 0.00784314, 0))
+
+recurrent_to_input_weights = Input("recurrent_to_input_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, output_size),
+ 0.00784314, 0))
+recurrent_to_forget_weights = Input("recurrent_to_forget_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, output_size),
+ 0.00784314, 0))
+recurrent_to_cell_weights = Input("recurrent_to_cell_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, output_size),
+ 0.00784314, 0))
+recurrent_to_output_weights = Input("recurrent_to_output_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d, %d}" % (num_units, output_size),
+ 0.00784314, 0))
+
+cell_to_input_weights = Input("cell_to_input_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % (num_units), 1.0, 0))
+cell_to_forget_weights = Input("cell_to_forget_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % (num_units), 1.0, 0))
+cell_to_output_weights = Input("cell_to_output_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % (num_units), 1.0, 0))
+
+input_gate_bias = Input("input_gate_bias",
+ ("TENSOR_INT32", "{%d}" % (num_units), 4.65661e-08, 0))
+forget_gate_bias = Input("forget_gate_bias",
+ ("TENSOR_INT32", "{%d}" % (num_units), 4.65661e-08, 0))
+cell_gate_bias = Input("cell_gate_bias",
+ ("TENSOR_INT32", "{%d}" % (num_units), 4.65661e-08, 0))
+output_gate_bias = Input("output_gate_bias",
+ ("TENSOR_INT32", "{%d}" % (num_units), 4.65661e-08, 0))
+
+projection_weights = Input("projection_weights",
+ ("TENSOR_QUANT8_SYMM", "{%d,%d}" % (output_size, num_units), 0.00392157, 0))
+projection_bias = Input("projection_bias", "TENSOR_INT32", "{%d}" % (output_size))
+
+output_state_in = Input("output_state_in",
+ ("TENSOR_QUANT8_ASYMM_SIGNED", "{%d, %d}" % (batch_size, output_size),
+ 3.05176e-05, 0))
+cell_state_in = Input("cell_state_in",
+ ("TENSOR_QUANT16_SYMM", "{%d, %d}" % (batch_size, num_units), 3.05176e-05, 0))
+
+input_layer_norm_weights = Input("input_layer_norm_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % num_units, 3.05182e-05, 0))
+forget_layer_norm_weights = Input("forget_layer_norm_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % num_units, 3.05182e-05, 0))
+cell_layer_norm_weights = Input("cell_layer_norm_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % num_units, 3.05182e-05, 0))
+output_layer_norm_weights = Input("output_layer_norm_weights",
+ ("TENSOR_QUANT16_SYMM", "{%d}" % num_units, 3.05182e-05, 0))
+
+cell_clip = Float32Scalar("cell_clip", 0.)
+projection_clip = Float32Scalar("projection_clip", 0.)
+
+input_intermediate_scale = Float32Scalar("input_intermediate_scale", 0.007059)
+forget_intermediate_scale = Float32Scalar("forget_intermediate_scale", 0.007812)
+cell_intermediate_scale = Float32Scalar("cell_intermediate_scale", 0.007059)
+output_intermediate_scale = Float32Scalar("output_intermediate_scale", 0.007812)
+hidden_state_zero_point = Int32Scalar("hidden_state_zero_point", 0)
+hidden_state_scale = Float32Scalar("hidden_state_scale", 0.007)
+
+output_state_out = Output("output_state_out",
+ ("TENSOR_QUANT8_ASYMM_SIGNED", "{%d, %d}" % (batch_size, output_size),
+ 3.05176e-05, 0))
+cell_state_out = Output("cell_state_out",
+ ("TENSOR_QUANT16_SYMM", "{%d, %d}" % (batch_size, num_units), 3.05176e-05, 0))
+output = Output("output",
+ ("TENSOR_QUANT8_ASYMM_SIGNED", "{%d, %d}" % (batch_size, output_size),
+ 3.05176e-05, 0))
+
+model = model.Operation(
+ "QUANTIZED_LSTM", input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights,
+ recurrent_to_forget_weights, recurrent_to_cell_weights,
+ recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_gate_bias, forget_gate_bias, cell_gate_bias,
+ output_gate_bias, projection_weights, projection_bias, output_state_in,
+ cell_state_in, input_layer_norm_weights, forget_layer_norm_weights,
+ cell_layer_norm_weights, output_layer_norm_weights, cell_clip, projection_clip,
+ input_intermediate_scale, forget_intermediate_scale, cell_intermediate_scale,
+ output_intermediate_scale, hidden_state_zero_point, hidden_state_scale).To([output_state_out,
+ cell_state_out, output])
+
+# Example 1. Input in operand 0,
+input0 = {
+ input_to_input_weights: [
+ 64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13
+ ],
+ input_to_forget_weights: [
+ -77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64
+ ],
+ input_to_cell_weights: [
+ -51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77
+ ],
+ input_to_output_weights: [
+ -102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51
+ ],
+ input_gate_bias: [644245, 3221226, 4724464, 8160438],
+ forget_gate_bias: [2147484, -6442451, -4294968, 2147484],
+ cell_gate_bias: [-1073742, 15461883, 5368709, 1717987],
+ output_gate_bias: [1073742, -214748, 4294968, 2147484],
+ recurrent_to_input_weights: [
+ -25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77
+ ],
+ recurrent_to_forget_weights: [
+ -64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25
+ ],
+ recurrent_to_cell_weights: [
+ -38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25
+ ],
+ recurrent_to_output_weights: [
+ 38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25
+ ],
+ projection_weights: [
+ -25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51
+ ],
+ projection_bias: [ 0 for _ in range(output_size) ],
+ input_layer_norm_weights: [3277, 6553, 9830, 16384],
+ forget_layer_norm_weights: [6553, 6553, 13107, 9830],
+ cell_layer_norm_weights: [22937, 6553, 9830, 26214],
+ output_layer_norm_weights: [19660, 6553, 6553, 16384],
+}
+
+test_input = [90, 102, 13, 26, 38, 102, 13, 26, 51, 64]
+
+golden_output = [
+ 127, 127, -108, -67, 127, 127
+]
+
+output0 = {
+ output_state_out: golden_output,
+ cell_state_out: [-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939],
+ output: golden_output,
+}
+
+input0[input] = test_input
+input0[output_state_in] = [ 0 for _ in range(batch_size * output_size) ]
+input0[cell_state_in] = [ 0 for _ in range(batch_size * num_units) ]
+input0[cell_to_input_weights] = [0 for _ in range(num_units) ]
+input0[cell_to_forget_weights] = [0 for _ in range(num_units) ]
+input0[cell_to_output_weights] = [0 for _ in range(num_units) ]
+
+Example((input0, output0))
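As a quick sanity check on the quantization conventions in this spec (assuming the usual relation real = scale * quantized for the symmetric types, and scale * (quantized - zeroPoint) with zeroPoint 0 for the signed asymmetric input): the first input_to_input_weights entry decodes to 64 * 0.00784314, roughly 0.502; the last input_layer_norm_weights entry to 16384 * 3.05182e-05, roughly 0.500; and the first input sample to 90 * 0.0078125, roughly 0.703.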
diff --git a/nn/tools/api/NeuralNetworks.t b/nn/tools/api/NeuralNetworks.t
index 7129c0a9b..b65bfd4ec 100644
--- a/nn/tools/api/NeuralNetworks.t
+++ b/nn/tools/api/NeuralNetworks.t
@@ -75,6 +75,10 @@ typedef enum {
// Operations below are available since API level 29.
%insert Operation_1.2
+
+ // Operations below are available since API level 30.
+
+%insert Operation_1.3
} OperationCode;
/**
diff --git a/nn/tools/api/types.spec b/nn/tools/api/types.spec
index e68568e97..c6650d959 100644
--- a/nn/tools/api/types.spec
+++ b/nn/tools/api/types.spec
@@ -7,6 +7,7 @@
%define Ann ANeuralNetworks
%define DeclareOperation ANEURALNETWORKS_%{1} = %{2}
%define DeclareOperation_1.2 ANEURALNETWORKS_%{1} = %{2}
+%define DeclareOperation_1.3 ANEURALNETWORKS_%{1} = %{2}
%define FusedActivationFunc FuseCode
%define OperandType OperandCode
%define OperandTypeLinkPfx ANEURALNETWORKS_
@@ -83,6 +84,7 @@
%define-lines ZeroBatchesAPI29
%/define-lines
%define DeclareOperation_1.2 @@@NOT_DEFINED@@@
+%define DeclareOperation_1.3 @@@NOT_DEFINED@@@
%/kind
%kind hal_1.2 hal_1.3
@@ -93,11 +95,13 @@
%kind hal_1.2
%define DeclareOperation %{1} = @1.1::OperationType:%{1}
%define DeclareOperation_1.2 %{1} = %{2}
+%define DeclareOperation_1.3 @@@NOT_DEFINED@@@
%/kind
%kind hal_1.3
%define DeclareOperation %{1} = @1.2::OperationType:%{1}
%define DeclareOperation_1.2 %{1} = @1.2::OperationType:%{1}
+%define DeclareOperation_1.3 %{1} = %{2}
%/kind
%kind ndk hal_1.2 hal_1.3
@@ -5732,6 +5736,143 @@
FUNDAMENTAL_MAX = 14,
%/section
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+%% HAL OperationType for 1.3
+%% NDK OperationCode for API 30
+
+%section Operation_1.3
+ /**
+ * Quantized version of {@link OperationType:LSTM}.
+ *
+ * The input and the output use asymmetric quantized types, while the rest
+ * use symmetric ones.
+ *
+ * Inputs:
+ * * 0: The input to the LSTM cell.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, inputSize]
+ * * 1: The input-to-input weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 2: The input-to-forget weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 3: The input-to-cell weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 4: The input-to-output weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, inputSize]
+ * * 5: The recurrent-to-input weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 6: The recurrent-to-forget weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 7: The recurrent-to-cell weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 8: The recurrent-to-output weights.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [numUnits, outputSize]
+ * * 9: The cell-to-input weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 10: The cell-to-forget weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 11: The cell-to-output weights (for peephole). Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 12: The input gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Optional.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 13: The forget gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 14: The cell bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 15: The output gate bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [numUnits]
+ * * 16: The projection weights. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT8_SYMM}
+ * Shape: [outputSize, numUnits]
+ * * 17: The projection bias. Quantized with scale being the
+ * product of input and weights scales and zeroPoint equal to 0.
+ * Optional.
+ * Type: {@link OperandType::TENSOR_INT32}
+ * Shape: [outputSize]
+ * * 18: The output from the previous time step.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+ * * 19: The cell state from the previous time step.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [batchSize, numUnits]
+ * * 20: The input layer normalization weights. Used to rescale
+ * normalized inputs to activation at input gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 21: The forget layer normalization weights. Used to
+ * rescale normalized inputs to activation at forget gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 22: The cell layer normalization weights. Used to rescale
+ * normalized inputs to activation at cell gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 23: The output layer normalization weights. Used to
+ * rescale normalized inputs to activation at output gate. Optional.
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [numUnits]
+ * * 24: The cell clip. If provided, the cell state is clipped
+ * by this value prior to the cell output activation. Optional.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 25: The projection clip. If provided and projection is enabled,
+ * this is used for clipping the projected values. Optional.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 26: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at input gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 27: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at forget gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 28: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at cell gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 29: The scale of the intermediate result of matmul,
+ * i.e. input to layer normalization, at output gate.
+ * Type: {@link OperandType::FLOAT32}.
+ * * 30: The zero point of the hidden state, i.e. input to
+ * projection.
+ * Type: {@link OperandType::INT32}.
+ * * 31: The scale of the hidden state, i.e. input to
+ * projection.
+ * Type: {@link OperandType::FLOAT32}.
+ *
+ * Outputs:
+ * * 0: The output state (out).
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+ * * 1: The cell state (out).
+ * Type: {@link OperandType::TENSOR_QUANT16_SYMM}
+ * Shape: [batchSize, numUnits]
+ * * 2: The output. This is effectively the same as the current
+ * "output state (out)" value.
+ * Type: {@link OperandType::TENSOR_QUANT8_ASYMM_SIGNED}
+ * Shape: [batchSize, outputSize]
+%insert-lines AVAIL30
+ */
+ %{DeclareOperation_1.3 QUANTIZED_LSTM 95},
+%/section
+
%section Operation_1.3_MAX
- FUNDAMENTAL_MAX = 94,
+ FUNDAMENTAL_MAX = 95,
%/section
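For reference, given the per-kind definitions added earlier in this file, %{DeclareOperation_1.3 QUANTIZED_LSTM 95} expands to ANEURALNETWORKS_QUANTIZED_LSTM = 95 for the ndk kind (matching NeuralNetworks.h above) and to QUANTIZED_LSTM = 95 for hal_1.3, while kinds without 1.3 operations (including hal_1.2) leave DeclareOperation_1.3 as @@@NOT_DEFINED@@@ so this section cannot be pulled into pre-1.3 headers by mistake.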
diff --git a/nn/tools/test_generator/test_harness/include/TestHarness.h b/nn/tools/test_generator/test_harness/include/TestHarness.h
index d0b248f15..3de9f63dc 100644
--- a/nn/tools/test_generator/test_harness/include/TestHarness.h
+++ b/nn/tools/test_generator/test_harness/include/TestHarness.h
@@ -177,6 +177,7 @@ enum class TestOperationType {
UNIDIRECTIONAL_SEQUENCE_LSTM = 92,
UNIDIRECTIONAL_SEQUENCE_RNN = 93,
RESIZE_NEAREST_NEIGHBOR = 94,
+ QUANTIZED_LSTM = 95,
};
enum class TestHalVersion { UNKNOWN, V1_0, V1_1, V1_2, V1_3 };