From 4f9c1332e69cf6f934e887b3f9957b0172df871b Mon Sep 17 00:00:00 2001 From: Lev Proleev Date: Sat, 11 Apr 2020 08:39:31 +0100 Subject: Add tests for multiplier quantization functions The tests are based on TF Lite's quantization_util_test.cc. Fix: 129569821 Test: atest NeuralNetworkTest_utils Change-Id: I9715e0a5f36d431f9d0bb2fbbf34e221996b6208 --- nn/common/QuantUtils.h | 17 +++++-- nn/common/UtilsTest.cpp | 122 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) (limited to 'nn') diff --git a/nn/common/QuantUtils.h b/nn/common/QuantUtils.h index 09a87405c..3da27e93d 100644 --- a/nn/common/QuantUtils.h +++ b/nn/common/QuantUtils.h @@ -5,11 +5,11 @@ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H +#include + #include #include -#include - #include "OperationsUtils.h" #include "Utils.h" @@ -77,12 +77,23 @@ int CountLeadingZeros(T integer_input) { inline bool GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift, int32_t* output_inv_sqrt, int* output_shift) { + NN_RET_CHECK_GE(input, 0); + if (input <= 1) { + // Handle the input value 1 separately to avoid overflow in that case + // in the general computation below. Also handle 0 as if it + // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid + // but rare/unrealistic input value. We can expect both to occur in some + // incompletely trained models, but probably not in fully trained models. + *output_inv_sqrt = std::numeric_limits::max(); + *output_shift = 0; + return true; + } + *output_shift = 11; while (input >= (1 << 29)) { input /= 4; ++*output_shift; } - NN_RET_CHECK_GT(input, 0); const unsigned max_left_shift_bits = CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; diff --git a/nn/common/UtilsTest.cpp b/nn/common/UtilsTest.cpp index d3eeee0dc..4d5a32d2a 100644 --- a/nn/common/UtilsTest.cpp +++ b/nn/common/UtilsTest.cpp @@ -16,9 +16,11 @@ #include #include +#include #include #include "OperationsUtils.cpp" +#include "QuantUtils.h" namespace android { namespace nn { @@ -155,6 +157,126 @@ TEST_F(CombineDimensionsTest, Dimensions) { testIncompatible({1, 2, 3, 4}, {1, 2, 3, 3}); } +TEST(QuantizationUtilsTest, QuantizeMultiplierSmallerThanOneExp) { + auto checkInvalidQuantization = [](double value) { + int32_t q; + int s; + EXPECT_FALSE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s)); + }; + + checkInvalidQuantization(-0.1); + checkInvalidQuantization(0.0); + // If we get close enough to 1.0 it crashes and dies in one of two ways: + // Either the shift becomes negative or we trigger the 'less-than-one' CHECK. + checkInvalidQuantization(1 - 1e-15); + checkInvalidQuantization(1 - 1e-17); + checkInvalidQuantization(1.0); + + auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) { + int32_t q; + int s; + EXPECT_TRUE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s)); + EXPECT_EQ(q, goldenQuantized); + EXPECT_EQ(s, goldenShift); + }; + + checkQuantization(0.25, 1073741824, -1); + checkQuantization(0.50 - 5e-9, 2147483627, -1); + checkQuantization(0.50 - 1e-10, 1073741824, 0); + checkQuantization(0.50, 1073741824, 0); + checkQuantization(0.75, 1610612736, 0); + checkQuantization(1 - 1e-9, 2147483646, 0); +} + +TEST(QuantizationUtilsTest, QuantizeMultiplierGreaterThanOne) { + auto checkInvalidQuantization = [](double value) { + int32_t q; + int s; + EXPECT_FALSE(QuantizeMultiplierGreaterThanOne(value, &q, &s)); + }; + + checkInvalidQuantization(1 + 1e-16); + + auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) { + int32_t q; + int s; + EXPECT_TRUE(QuantizeMultiplierGreaterThanOne(value, &q, &s)); + EXPECT_EQ(q, goldenQuantized); + EXPECT_EQ(s, goldenShift); + }; + + checkQuantization(1 + 1e-11, 1073741824, 1); + checkQuantization(1.25, 1342177280, 1); + checkQuantization(1.50, 1610612736, 1); + checkQuantization(1.50, 1610612736, 1); + checkQuantization(1.75, 1879048192, 1); + checkQuantization(2 - 1e-9, 2147483647, 1); + checkQuantization(2 - 1e-11, 1073741824, 2); + checkQuantization(2, 1073741824, 2); +} + +TEST(QuantizationUtilTest, QuantizeMultiplier) { + auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) { + int32_t q; + int s; + EXPECT_TRUE(QuantizeMultiplier(value, &q, &s)); + EXPECT_EQ(q, goldenQuantized); + EXPECT_EQ(s, goldenShift); + }; + + checkQuantization(-4, -1073741824, 3); + checkQuantization(-2, -1073741824, 2); + checkQuantization(-1, -1073741824, 1); + checkQuantization(-0.5, -1073741824, 0); + checkQuantization(-0.25, -1073741824, -1); + checkQuantization(-0.125, -1073741824, -2); + checkQuantization(0, 0, 0); + checkQuantization(0.125, 1073741824, -2); + checkQuantization(0.25, 1073741824, -1); + checkQuantization(0.5, 1073741824, 0); + checkQuantization(1, 1073741824, 1); + checkQuantization(2, 1073741824, 2); + checkQuantization(4, 1073741824, 3); +} + +TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) { + auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) { + int32_t q; + int s; + EXPECT_TRUE(QuantizeMultiplier(value, &q, &s)); + EXPECT_EQ(q, goldenQuantized); + EXPECT_EQ(s, goldenShift); + }; + + checkQuantization(std::ldexp(1.0f, -31), 1073741824, -30); + checkQuantization(std::ldexp(1.0f, -32), 1073741824, -31); + checkQuantization(std::ldexp(0.99f, -32), 0, 0); + checkQuantization(std::ldexp(1.0f, -33), 0, 0); +} + +TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) { + auto checkInvSqrtQuantization = [](int32_t input, int32_t goldenInvSqrt, int goldenShift) { + int32_t q; + int s; + EXPECT_TRUE(GetInvSqrtQuantizedMultiplierExp(input, 1, &q, &s)); + EXPECT_EQ(q, goldenInvSqrt); + EXPECT_EQ(s, goldenShift); + }; + + const auto kInt32Max = std::numeric_limits::max(); + checkInvSqrtQuantization(0, kInt32Max, 0); + checkInvSqrtQuantization(1, kInt32Max, 0); + checkInvSqrtQuantization(2, 1518498372, 0); + checkInvSqrtQuantization(3, 1239850284, 0); + checkInvSqrtQuantization(4, 1073741828, 0); + checkInvSqrtQuantization(100, 214748363, 0); + checkInvSqrtQuantization(10000, 343597361, 4); + checkInvSqrtQuantization(1000000, 274877901, 7); + checkInvSqrtQuantization(100000000, 219902323, 10); + checkInvSqrtQuantization((1 << 30), 268435457, 12); + checkInvSqrtQuantization(kInt32Max, 189812531, 12); +} + } // namespace wrapper } // namespace nn } // namespace android -- cgit v1.2.3