summaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorLev Proleev <levp@google.com>2020-04-11 08:39:31 +0100
committerLev Proleev <levp@google.com>2020-04-20 15:05:00 +0100
commit4f9c1332e69cf6f934e887b3f9957b0172df871b (patch)
tree985fc822b0d5c0f86874a25142048b190c305c14 /nn
parentc9839533afd58835ad901a7b0fcf16a0ad5668ba (diff)
downloadml-4f9c1332e69cf6f934e887b3f9957b0172df871b.tar.gz
Add tests for multiplier quantization functions
The tests are based on TF Lite's quantization_util_test.cc. Fix: 129569821 Test: atest NeuralNetworkTest_utils Change-Id: I9715e0a5f36d431f9d0bb2fbbf34e221996b6208
Diffstat (limited to 'nn')
-rw-r--r--nn/common/QuantUtils.h17
-rw-r--r--nn/common/UtilsTest.cpp122
2 files changed, 136 insertions, 3 deletions
diff --git a/nn/common/QuantUtils.h b/nn/common/QuantUtils.h
index 09a87405c..3da27e93d 100644
--- a/nn/common/QuantUtils.h
+++ b/nn/common/QuantUtils.h
@@ -5,11 +5,11 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
+#include <public/gemmlowp.h>
+
#include <limits>
#include <memory>
-#include <public/gemmlowp.h>
-
#include "OperationsUtils.h"
#include "Utils.h"
@@ -77,12 +77,23 @@ int CountLeadingZeros(T integer_input) {
inline bool GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
int32_t* output_inv_sqrt, int* output_shift) {
+ NN_RET_CHECK_GE(input, 0);
+ if (input <= 1) {
+ // Handle the input value 1 separately to avoid overflow in that case
+ // in the general computation below. Also handle 0 as if it
+ // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
+ // but rare/unrealistic input value. We can expect both to occur in some
+ // incompletely trained models, but probably not in fully trained models.
+ *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
+ *output_shift = 0;
+ return true;
+ }
+
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
++*output_shift;
}
- NN_RET_CHECK_GT(input, 0);
const unsigned max_left_shift_bits = CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
diff --git a/nn/common/UtilsTest.cpp b/nn/common/UtilsTest.cpp
index d3eeee0dc..4d5a32d2a 100644
--- a/nn/common/UtilsTest.cpp
+++ b/nn/common/UtilsTest.cpp
@@ -16,9 +16,11 @@
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
+#include <limits>
#include <vector>
#include "OperationsUtils.cpp"
+#include "QuantUtils.h"
namespace android {
namespace nn {
@@ -155,6 +157,126 @@ TEST_F(CombineDimensionsTest, Dimensions) {
testIncompatible({1, 2, 3, 4}, {1, 2, 3, 3});
}
+TEST(QuantizationUtilsTest, QuantizeMultiplierSmallerThanOneExp) {
+ auto checkInvalidQuantization = [](double value) {
+ int32_t q;
+ int s;
+ EXPECT_FALSE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
+ };
+
+ checkInvalidQuantization(-0.1);
+ checkInvalidQuantization(0.0);
+ // If we get close enough to 1.0 it crashes and dies in one of two ways:
+ // Either the shift becomes negative or we trigger the 'less-than-one' CHECK.
+ checkInvalidQuantization(1 - 1e-15);
+ checkInvalidQuantization(1 - 1e-17);
+ checkInvalidQuantization(1.0);
+
+ auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
+ int32_t q;
+ int s;
+ EXPECT_TRUE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
+ EXPECT_EQ(q, goldenQuantized);
+ EXPECT_EQ(s, goldenShift);
+ };
+
+ checkQuantization(0.25, 1073741824, -1);
+ checkQuantization(0.50 - 5e-9, 2147483627, -1);
+ checkQuantization(0.50 - 1e-10, 1073741824, 0);
+ checkQuantization(0.50, 1073741824, 0);
+ checkQuantization(0.75, 1610612736, 0);
+ checkQuantization(1 - 1e-9, 2147483646, 0);
+}
+
+TEST(QuantizationUtilsTest, QuantizeMultiplierGreaterThanOne) {
+ auto checkInvalidQuantization = [](double value) {
+ int32_t q;
+ int s;
+ EXPECT_FALSE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
+ };
+
+ checkInvalidQuantization(1 + 1e-16);
+
+ auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
+ int32_t q;
+ int s;
+ EXPECT_TRUE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
+ EXPECT_EQ(q, goldenQuantized);
+ EXPECT_EQ(s, goldenShift);
+ };
+
+ checkQuantization(1 + 1e-11, 1073741824, 1);
+ checkQuantization(1.25, 1342177280, 1);
+ checkQuantization(1.50, 1610612736, 1);
+ checkQuantization(1.50, 1610612736, 1);
+ checkQuantization(1.75, 1879048192, 1);
+ checkQuantization(2 - 1e-9, 2147483647, 1);
+ checkQuantization(2 - 1e-11, 1073741824, 2);
+ checkQuantization(2, 1073741824, 2);
+}
+
+TEST(QuantizationUtilTest, QuantizeMultiplier) {
+ auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
+ int32_t q;
+ int s;
+ EXPECT_TRUE(QuantizeMultiplier(value, &q, &s));
+ EXPECT_EQ(q, goldenQuantized);
+ EXPECT_EQ(s, goldenShift);
+ };
+
+ checkQuantization(-4, -1073741824, 3);
+ checkQuantization(-2, -1073741824, 2);
+ checkQuantization(-1, -1073741824, 1);
+ checkQuantization(-0.5, -1073741824, 0);
+ checkQuantization(-0.25, -1073741824, -1);
+ checkQuantization(-0.125, -1073741824, -2);
+ checkQuantization(0, 0, 0);
+ checkQuantization(0.125, 1073741824, -2);
+ checkQuantization(0.25, 1073741824, -1);
+ checkQuantization(0.5, 1073741824, 0);
+ checkQuantization(1, 1073741824, 1);
+ checkQuantization(2, 1073741824, 2);
+ checkQuantization(4, 1073741824, 3);
+}
+
+TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
+ auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
+ int32_t q;
+ int s;
+ EXPECT_TRUE(QuantizeMultiplier(value, &q, &s));
+ EXPECT_EQ(q, goldenQuantized);
+ EXPECT_EQ(s, goldenShift);
+ };
+
+ checkQuantization(std::ldexp(1.0f, -31), 1073741824, -30);
+ checkQuantization(std::ldexp(1.0f, -32), 1073741824, -31);
+ checkQuantization(std::ldexp(0.99f, -32), 0, 0);
+ checkQuantization(std::ldexp(1.0f, -33), 0, 0);
+}
+
+TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
+ auto checkInvSqrtQuantization = [](int32_t input, int32_t goldenInvSqrt, int goldenShift) {
+ int32_t q;
+ int s;
+ EXPECT_TRUE(GetInvSqrtQuantizedMultiplierExp(input, 1, &q, &s));
+ EXPECT_EQ(q, goldenInvSqrt);
+ EXPECT_EQ(s, goldenShift);
+ };
+
+ const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
+ checkInvSqrtQuantization(0, kInt32Max, 0);
+ checkInvSqrtQuantization(1, kInt32Max, 0);
+ checkInvSqrtQuantization(2, 1518498372, 0);
+ checkInvSqrtQuantization(3, 1239850284, 0);
+ checkInvSqrtQuantization(4, 1073741828, 0);
+ checkInvSqrtQuantization(100, 214748363, 0);
+ checkInvSqrtQuantization(10000, 343597361, 4);
+ checkInvSqrtQuantization(1000000, 274877901, 7);
+ checkInvSqrtQuantization(100000000, 219902323, 10);
+ checkInvSqrtQuantization((1 << 30), 268435457, 12);
+ checkInvSqrtQuantization(kInt32Max, 189812531, 12);
+}
+
} // namespace wrapper
} // namespace nn
} // namespace android