author     Lev Proleev <levp@google.com>  2021-02-26 21:44:39 +0000
committer  Lev Proleev <levp@google.com>  2021-02-26 22:17:12 +0000
commit     123f384187504585be3fe01002381dd459c17d96 (patch)
tree       a29716289a0b730ca66a3e632c6ce054eb3b90d6 /test
parent     8dd5f1b93261d6ea0fe0c8e51d13f89657ceb0b8 (diff)
download   gemmlowp-123f384187504585be3fe01002381dd459c17d96.tar.gz
Update gemmlowp to 13d57703abca3005d97b19df1f2db731607a7dc2
An update is needed after the TF Lite rebase.

Bug: 178609672
Test: mma, NeuralNetworksStatic_test
Change-Id: Ia7f04fc5b6bd760549395854618d8b20f5c8d228
Diffstat (limited to 'test')
-rw-r--r--  test/benchmark.cc               11
-rw-r--r--  test/benchmark_all_sizes.cc     19
-rw-r--r--  test/test.cc                    84
-rw-r--r--  test/test.h                      6
-rw-r--r--  test/test_blocking_counter.cc   55
-rw-r--r--  test/test_fixedpoint.cc        952
6 files changed, 683 insertions, 444 deletions
diff --git a/test/benchmark.cc b/test/benchmark.cc
index 9a87a41..d8236de 100644
--- a/test/benchmark.cc
+++ b/test/benchmark.cc
@@ -36,7 +36,16 @@
#warning "Building without NEON support on ARM, check your compiler setup!"
#endif
-#if defined(__SSE4_2__) && !defined(GEMMLOWP_SSE4)
+#if defined(__mips) && !defined(GEMMLOWP_MSA)
+#warning "Building without MSA support on MIPS, check your compiler setup!"
+#endif
+
+#if defined(__AVX2__) && !defined(GEMMLOWP_AVX2)
+#warning \
+ "Building without AVX2 support on AVX2 enabled machine, check your compiler setup!"
+#endif
+
+#if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4)
#warning \
"Building without SSE4.2 support on SSE4.2 enabled machine, check your compiler setup!"
#endif
diff --git a/test/benchmark_all_sizes.cc b/test/benchmark_all_sizes.cc
index 16cc57c..527aad6 100644
--- a/test/benchmark_all_sizes.cc
+++ b/test/benchmark_all_sizes.cc
@@ -16,6 +16,10 @@ test/benchmark_all_sizes.cc -o /tmp/b -O3 --std=c++11 -fPIE -static \
#include "../public/gemmlowp.h"
+#ifdef GEMMLOWP_PROFILING
+#include "../profiling/profiler.h"
+#endif
+
#if defined GEMMLOWP_ANDROID && defined GEMMLOWP_ARM_32
// Compilation workaround
namespace std {
@@ -122,10 +126,10 @@ float benchmark_8bit(int rows, int depth, int cols) {
MakeZero(&rhs);
MakeZero(&result);
- typedef std::tuple<OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ typedef std::tuple<OutputStageQuantizeDownInt32ByFixedPoint,
OutputStageSaturatingCastToUint8>
Pipeline;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = 128;
quantize_down_stage.result_fixedpoint_multiplier = 1234567890;
@@ -345,7 +349,18 @@ void run_benchmarks(std::map<Shape, float>* results) {
int main() {
std::map<Shape, float> results;
+
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::RegisterCurrentThreadForProfiling();
+ gemmlowp::StartProfiling();
+#endif
+
run_benchmarks(&results);
+
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::FinishProfiling();
+#endif
+
printf("Using %d thread(s)\n", kNumThreads);
printf("depth,rows,cols,latency(s),Gop/s\n");
for (const auto& result : results) {
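
A minimal standalone sketch of the profiling pattern introduced above, mirroring the GEMMLOWP_PROFILING-guarded calls added to benchmark_all_sizes.cc (the include path relative to the gemmlowp root, the ProfiledBenchmark function, and the workload placeholder are illustrative, not part of this commit):

// Build with -DGEMMLOWP_PROFILING to compile the profiler calls in;
// without it, the program builds as a plain benchmark.
#ifdef GEMMLOWP_PROFILING
#include "profiling/profiler.h"
#endif

void ProfiledBenchmark() {
#ifdef GEMMLOWP_PROFILING
  // The sampling profiler only looks at registered threads.
  gemmlowp::RegisterCurrentThreadForProfiling();
  gemmlowp::StartProfiling();
#endif
  // ... run the GEMM workload to be profiled ...
#ifdef GEMMLOWP_PROFILING
  gemmlowp::FinishProfiling();  // stop sampling and emit the collected profile
#endif
}

int main() { ProfiledBenchmark(); }
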
diff --git a/test/test.cc b/test/test.cc
index eee16b4..735ad1e 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1277,6 +1277,47 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
}
}
+ // Test a variant of the familiar default pipeline consisting of quantize-down
+ // and clamp-and-cast-to-int16.
+ OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
+ auto quantize_down_and_saturating_cast_int16_pipeline =
+ std::make_tuple(quantize_down_stage, saturating_cast_int16_stage);
+ Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows,
+ cols);
+ GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_quantized_down_saturated_int16, lhs_offset, rhs_offset,
+ quantize_down_and_saturating_cast_int16_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ std::int32_t quantized = result_quantized_down_int32(r, c);
+ std::int16_t expected = std::min(std::max(quantized, -32768), 32767);
+ Check(expected == result_quantized_down_saturated_int16(r, c));
+ }
+ }
+
+#ifdef GEMMLOWP_MSA
+ // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8.
+ OutputStageTruncatingCastToUint8 truncating_cast_stage;
+ auto quantize_down_and_truncating_cast_pipeline =
+ std::make_tuple(quantize_down_stage, truncating_cast_stage);
+ Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8(
+ rows, cols);
+ GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset,
+ quantize_down_and_truncating_cast_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ std::int32_t quantized = result_quantized_down_int32(r, c);
+ std::uint8_t expected = quantized & 255;
+ Check(expected == result_quantized_down_truncated_uint8(r, c));
+ }
+ }
+#endif
+
// Test a bias-addition with row-vector
std::vector<std::int32_t> row_vector_data(cols);
std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
@@ -1428,8 +1469,8 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
result_fixedpoint_shift++;
}
Check(result_fixedpoint_shift >= 0);
- // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
- OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ // Now test OutputStageQuantizeDownInt32ByFixedPoint
+ OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_by_fixedpoint_stage;
quantize_down_by_fixedpoint_stage.result_offset_after_shift =
static_cast<std::int32_t>(
@@ -1447,7 +1488,6 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
&result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
quantize_down_by_fixedpoint_pipeline);
- std::vector<std::int32_t> diffs_caused_by_fixedpoint;
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
const std::int32_t actual =
@@ -1462,6 +1502,44 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
}
}
+ // Test OutputStageScaleInt32ByFixedPointAndExponent
+ for (int exponent = -2; exponent <= 2; exponent++) {
+ OutputStageScaleInt32ByFixedPointAndExponent
+ scale_by_fixedpoint_and_exponent_stage;
+ scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift =
+ static_cast<std::int32_t>(round(static_cast<double>(
+ result_offset * result_mult_int * std::pow(2.0, exponent))));
+ scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier =
+ result_fixedpoint_multiplier;
+ scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent;
+ auto scale_by_fixedpoint_and_exponent_pipeline =
+ std::make_tuple(scale_by_fixedpoint_and_exponent_stage);
+ Matrix<std::int32_t, ResultOrder>
+ result_scaled_by_fixedpoint_and_exponent_int32(rows, cols);
+ GemmWithOutputPipeline<std::uint8_t, std::int32_t,
+ DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset,
+ scale_by_fixedpoint_and_exponent_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ const std::int32_t actual =
+ result_scaled_by_fixedpoint_and_exponent_int32(r, c);
+ const std::int32_t raw = result_raw_int32(r, c);
+ int left_shift = std::max(0, exponent);
+ int right_shift = std::max(0, -exponent);
+ const std::int32_t expected =
+ scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift +
+ RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
+ result_fixedpoint_multiplier),
+ right_shift);
+ Check(actual == expected);
+ }
+ }
+ }
+
// Test the variant of the familiar default pipeline consisting of
// quantize-down and
// clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the
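
The expected value computed in the new OutputStageScaleInt32ByFixedPointAndExponent loop above is a reference rescaling of the raw int32 accumulator. A minimal scalar sketch of that reference, using the same gemmlowp fixed-point helpers the test calls (the ScaleByFixedPointAndExponent name, the example values in main, and the include path relative to the gemmlowp root are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

#include "fixedpoint/fixedpoint.h"  // SaturatingRoundingDoublingHighMul, RoundingDivideByPOT

// Reference: shift the accumulator left by max(0, exponent), apply the
// rounding doubling high multiply by the fixed-point multiplier, then do a
// rounding right shift by max(0, -exponent) and add the offset.
std::int32_t ScaleByFixedPointAndExponent(std::int32_t raw,
                                          std::int32_t multiplier,
                                          int exponent,
                                          std::int32_t offset_after_shift) {
  const int left_shift = std::max(0, exponent);
  const int right_shift = std::max(0, -exponent);
  return offset_after_shift +
         gemmlowp::RoundingDivideByPOT(
             gemmlowp::SaturatingRoundingDoublingHighMul(
                 (1 << left_shift) * raw, multiplier),
             right_shift);
}

int main() {
  // Example: exponent 0 reduces to the plain fixed-point quantize-down case.
  std::printf("%d\n", ScaleByFixedPointAndExponent(1000, 1234567890, 0, 128));
}
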
diff --git a/test/test.h b/test/test.h
index aecd0c1..b381bad 100644
--- a/test/test.h
+++ b/test/test.h
@@ -49,7 +49,7 @@ class Matrix : public MatrixMap<tScalar, tOrder> {
typedef MatrixMap<tScalar, tOrder> Map;
typedef MatrixMap<const tScalar, tOrder> ConstMap;
typedef typename Map::Scalar Scalar;
- static const MapOrder Order = tOrder;
+ static constexpr MapOrder Order = tOrder;
using Map::kOrder;
using Map::rows_;
using Map::cols_;
@@ -92,12 +92,12 @@ class Matrix : public MatrixMap<tScalar, tOrder> {
std::vector<Scalar> storage;
};
-std::mt19937& RandomEngine() {
+inline std::mt19937& RandomEngine() {
static std::mt19937 engine;
return engine;
}
-int Random() {
+inline int Random() {
std::uniform_int_distribution<int> dist(0, std::numeric_limits<int>::max());
return dist(RandomEngine());
}
diff --git a/test/test_blocking_counter.cc b/test/test_blocking_counter.cc
index d1e0932..34d963d 100644
--- a/test/test_blocking_counter.cc
+++ b/test/test_blocking_counter.cc
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test.h"
-#include "../profiling/pthread_everywhere.h"
-
+#include <atomic> // NOLINT
#include <vector>
+#include <iostream>
+#include <cstdlib>
#include "../internal/multi_thread_gemm.h"
+#include "../profiling/pthread_everywhere.h"
+#include "test.h"
namespace gemmlowp {
@@ -26,16 +28,36 @@ class Thread {
Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
: blocking_counter_(blocking_counter),
number_of_times_to_decrement_(number_of_times_to_decrement),
- finished_(false),
- made_the_last_decrement_(false) {
+ made_the_last_decrement_(false),
+ finished_(false) {
+#if defined GEMMLOWP_USE_PTHREAD
+ // Limit the stack size so as not to deplete memory when creating
+ // many threads.
+ pthread_attr_t attr;
+ int err = pthread_attr_init(&attr);
+ if (!err) {
+ size_t stack_size;
+ err = pthread_attr_getstacksize(&attr, &stack_size);
+ if (!err && stack_size > max_stack_size_) {
+ err = pthread_attr_setstacksize(&attr, max_stack_size_);
+ }
+ if (!err) {
+ err = pthread_create(&thread_, &attr, ThreadFunc, this);
+ }
+ }
+ if (err) {
+ std::cerr << "Failed to create a thread.\n";
+ std::abort();
+ }
+#else
pthread_create(&thread_, nullptr, ThreadFunc, this);
+#endif
}
~Thread() { Join(); }
- bool Join() const {
- if (!finished_) {
- pthread_join(thread_, nullptr);
+ bool Join() {
+ while (!finished_.load()) {
}
return made_the_last_decrement_;
}
@@ -48,7 +70,7 @@ class Thread {
Check(!made_the_last_decrement_);
made_the_last_decrement_ = blocking_counter_->DecrementCount();
}
- finished_ = true;
+ finished_.store(true);
}
static void* ThreadFunc(void* ptr) {
@@ -56,11 +78,18 @@ class Thread {
return nullptr;
}
+ static constexpr size_t max_stack_size_ = 256 * 1024;
BlockingCounter* const blocking_counter_;
const int number_of_times_to_decrement_;
pthread_t thread_;
- bool finished_;
bool made_the_last_decrement_;
+ // finished_ is used to manually implement Join() by busy-waiting.
+ // I wanted to use pthread_join / std::thread::join, but the behavior
+ // observed on Android was that pthread_join aborts when the thread has
+ // already been joined before calling pthread_join, making it hard to use.
+ // It appeared simplest to just implement this simple spinlock, and that
+ // is good enough as this is just a test.
+ std::atomic<bool> finished_;
};
void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
@@ -89,10 +118,10 @@ void test_blocking_counter() {
// repeating the entire test sequence ensures that we test
// non-monotonic changes.
for (int repeat = 1; repeat <= 2; repeat++) {
- for (int num_threads = 1; num_threads <= 16; num_threads++) {
+ for (int num_threads = 1; num_threads <= 5; num_threads++) {
for (int num_decrements_per_thread = 1;
- num_decrements_per_thread <= 64 * 1024;
- num_decrements_per_thread *= 4) {
+ num_decrements_per_thread <= 4 * 1024;
+ num_decrements_per_thread *= 16) {
test_blocking_counter(blocking_counter, num_threads,
num_decrements_per_thread,
num_threads * num_decrements_per_thread);
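
The Join() rewrite above replaces pthread_join with a spin-wait on an atomic flag, for the reason given in the comment (pthread_join aborts on Android when the thread has already been joined). A minimal sketch of that pattern in isolation (class and function names are illustrative, not part of this test):

#include <pthread.h>

#include <atomic>

// The worker sets an atomic flag when done; Join() just spins on the flag,
// so it is safe to call any number of times, unlike pthread_join.
class SpinJoinThread {
 public:
  SpinJoinThread() { pthread_create(&thread_, nullptr, &ThreadFunc, this); }
  ~SpinJoinThread() { Join(); }
  void Join() {
    while (!finished_.load()) {
    }
  }

 private:
  static void* ThreadFunc(void* arg) {
    auto* self = static_cast<SpinJoinThread*>(arg);
    // ... do the work ...
    self->finished_.store(true);
    return nullptr;
  }
  // Like the test above, the pthread handle is intentionally never joined.
  pthread_t thread_;
  std::atomic<bool> finished_{false};
};
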
diff --git a/test/test_fixedpoint.cc b/test/test_fixedpoint.cc
index da222f0..44e6fae 100644
--- a/test/test_fixedpoint.cc
+++ b/test/test_fixedpoint.cc
@@ -17,479 +17,587 @@
#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
#include <algorithm>
+#include <cinttypes>
#include <cmath>
+#include <cstdio>
#include <random>
#include <vector>
-#include "test.h"
#include "../fixedpoint/fixedpoint.h"
+#include "test.h"
namespace gemmlowp {
namespace {
-// Explanation of SimdVector type and associated functions
-// (LoadSimdVector, StoreSimdVector):
-// The fixedpoint stuff being tested here is generic in an underlying
-// integer type which may be either scalar (int32_t) or SIMD (e.g.
-// NEON int32x4_t). We want to write uniform tests that can test
-// both the scalar and SIMD paths. We achieve this by having this
-// generic SimdVector abstraction, local to this test.
-
+template <typename T>
+T Load(const typename FixedPointRawTypeTraits<T>::ScalarRawType* src) {
+ return *src;
+}
+template <typename T>
+void Store(typename FixedPointRawTypeTraits<T>::ScalarRawType* dst, T v) {
+ *dst = v;
+}
#ifdef GEMMLOWP_NEON
-using SimdVector = int32x4_t;
-constexpr std::size_t SimdVectorSize = 4;
-SimdVector LoadSimdVector(const std::int32_t* src) { return vld1q_s32(src); }
-void StoreSimdVector(std::int32_t* dst, SimdVector v) { vst1q_s32(dst, v); }
-#elif defined(GEMMLOWP_SSE4)
-using SimdVector = __m128i;
-constexpr std::size_t SimdVectorSize = 4;
-SimdVector LoadSimdVector(const std::int32_t* src) {
+template <>
+int32x4_t Load<int32x4_t>(const std::int32_t* src) {
+ return vld1q_s32(src);
+}
+template <>
+int16x8_t Load<int16x8_t>(const std::int16_t* src) {
+ return vld1q_s16(src);
+}
+template <>
+void Store<int32x4_t>(std::int32_t* dst, int32x4_t v) {
+ vst1q_s32(dst, v);
+}
+template <>
+void Store<int16x8_t>(std::int16_t* dst, int16x8_t v) {
+ vst1q_s16(dst, v);
+}
+#endif
+#ifdef GEMMLOWP_SSE4
+template <>
+__m128i Load<__m128i>(const std::int32_t* src) {
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
}
-void StoreSimdVector(std::int32_t* dst, SimdVector v) {
+template <>
+void Store<__m128i>(std::int32_t* dst, __m128i v) {
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v);
}
-#else
-using SimdVector = std::int32_t;
-constexpr std::size_t SimdVectorSize = 1;
-SimdVector LoadSimdVector(const std::int32_t* src) { return *src; }
-void StoreSimdVector(std::int32_t* dst, SimdVector v) { *dst = v; }
+template <>
+int16x8_m128i Load<int16x8_m128i>(const std::int16_t* src) {
+ return to_int16x8_m128i(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)));
+}
+template <>
+void Store<int16x8_m128i>(std::int16_t* dst, int16x8_m128i v) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v.v);
+}
+#endif
+#ifdef GEMMLOWP_MSA
+template <>
+v4i32 Load<v4i32>(const std::int32_t* src) {
+ return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
+}
+template <>
+v8i16 Load<v8i16>(const std::int16_t* src) {
+ return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
+}
+template <>
+void Store<v4i32>(std::int32_t* dst, v4i32 v) {
+ __builtin_msa_st_w(v, dst, 0);
+}
+template <>
+void Store<v8i16>(std::int16_t* dst, v8i16 v) {
+ __builtin_msa_st_h(v, dst, 0);
+}
#endif
-// Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
-// Most (though not all) of the fixedpoint functionality being tested
-// consists of functions taking one fixedpoint value and returning one
-// fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
-// We factor a lot of testing boilerplate into a common TestUnaryOp function
-// taking a "unary op" object that fully describes the function to be tested.
-// These objects inherit UnaryOpBase mostly as a means to share some default
-// values for some properties.
-//
-// An important design element here is that the fixed-point values are passed
-// around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
-// as higher-level FixedPoint objects. The motivation for this design is 1) to
-// avoid having to templatize everything in the tIntegerBits parameter of
-// class FixedPoint, and 2) to allow directly testing low-level functions
-// operating on raw types (e.g. RoundingDivideByPOT) without needlessly
-// requiring
-// wrapping raw values in FixedPoint objects.
-class UnaryOpBase {
- public:
- // Min bound of the input range of this op. For example, an op only handling
- // nonnegative values would return 0.
- std::int32_t MinInput() const {
- return std::numeric_limits<std::int32_t>::min();
- }
- // Max bound of the input range of this op. For example, an op only handling
- // nonpositive values would return 0.
- std::int32_t MaxInput() const {
- return std::numeric_limits<std::int32_t>::max();
- }
- // Tolerated difference between actual and reference int32 values.
- // Note that the corresponding real-numbers tolerance depends on the number
- // of integer bits of the fixed-point representation of the results of this
- // op.
- // For example, for an op returning fixed-point values with 0 integer bits,
- // the correspondence between real-number values and raw values is
- // real_number = (2^31) * raw_value.
- std::int32_t Tolerance() const { return 0; }
-};
+#ifdef GEMMLOWP_AVX2
+template <>
+__m256i Load<__m256i>(const std::int32_t* src) {
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
+}
-// Op wrapping RoundingDivideByPOT
-class RoundingDivideByPOTOp final : public UnaryOpBase {
- public:
- RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
- std::int32_t ReferenceOp(std::int32_t x) const {
- const double d = static_cast<double>(x) / (1ll << exponent_);
- return static_cast<std::int32_t>(std::round(d));
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- return RoundingDivideByPOT(x, exponent_);
- }
+template <>
+int16x16_m256i Load<int16x16_m256i>(const std::int16_t* src) {
+ return to_int16x16_m256i(
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src)));
+}
- private:
- const int exponent_;
-};
+template <>
+void Store<__m256i>(std::int32_t* dst, __m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+}
-// Op wrapping SaturatingRoundingMultiplyByPOT
-template <int tExponent>
-class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
- public:
- std::int32_t ReferenceOp(std::int32_t x) const {
- const double d = static_cast<double>(x) * std::pow(2., tExponent);
- const double clamp_min = std::numeric_limits<std::int32_t>::min();
- const double clamp_max = std::numeric_limits<std::int32_t>::max();
- const double clamped = std::min(clamp_max, std::max(clamp_min, d));
- return static_cast<std::int32_t>(std::round(clamped));
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- return SaturatingRoundingMultiplyByPOT<tExponent>(x);
- }
-};
+template <>
+void Store<int16x16_m256i>(std::int16_t* dst, int16x16_m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v.v);
+}
+#endif
-// Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
-class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
- : public UnaryOpBase {
+template <typename tSimdType>
+class TestFixedPoint {
public:
- std::int32_t MinInput() const { return -(1 << 29); }
- std::int32_t MaxInput() const { return 0; }
- std::int32_t Tolerance() const { return 500; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::exp(d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
- return e.raw();
- }
-};
+ using SimdType = tSimdType;
+ using SimdTypeTraits = FixedPointRawTypeTraits<SimdType>;
+ using ScalarType = typename SimdTypeTraits::ScalarRawType;
+ static constexpr int kSimdLanes = SimdTypeTraits::kLanes;
+ static constexpr int kScalarTypeBits = 8 * sizeof(ScalarType);
+
+ // Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
+ // Most (though not all) of the fixedpoint functionality being tested
+ // consists of functions taking one fixedpoint value and returning one
+ // fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
+ // We factor a lot of testing boilerplate into a common TestUnaryOp function
+ // taking a "unary op" object that fully describes the function to be tested.
+ // These objects inherit UnaryOpBase mostly as a means to share some default
+ // values for some properties.
+ //
+ // An important design element here is that the fixed-point values are passed
+ // around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
+ // as higher-level FixedPoint objects. The motivation for this design is 1) to
+ // avoid having to templatize everything in the tIntegerBits parameter of
+ // class FixedPoint, and 2) to allow directly testing low-level functions
+ // operating on raw types (e.g. RoundingDivideByPOT) without needlessly
+ // requiring
+ // wrapping raw values in FixedPoint objects.
+ class UnaryOpBase {
+ public:
+ // Min bound of the input range of this op. For example, an op only handling
+ // nonnegative values would return 0.
+ ScalarType MinInput() const {
+ return std::numeric_limits<ScalarType>::min();
+ }
+ // Max bound of the input range of this op. For example, an op only handling
+ // nonpositive values would return 0.
+ ScalarType MaxInput() const {
+ return std::numeric_limits<ScalarType>::max();
+ }
+ // Tolerated difference between actual and reference ScalarType values.
+ // Note that the corresponding real-numbers tolerance depends on the number
+ // of integer bits of the fixed-point representation of the results of this
+ // op.
+ // For example, for an op returning fixed-point values with 0 integer bits,
+ // the correspondence between real-number values and raw values is
+ // real_number = (2^31) * raw_value.
+ ScalarType Tolerance() const { return 0; }
+ };
+
+ // Op wrapping RoundingDivideByPOT
+ class RoundingDivideByPOTOp final : public UnaryOpBase {
+ public:
+ RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
+ ScalarType ReferenceOp(ScalarType x) const {
+ const double d = static_cast<double>(x) / (1ll << exponent_);
+ return static_cast<ScalarType>(std::round(d));
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ return RoundingDivideByPOT(x, exponent_);
+ }
-// Op wrapping exp_on_negative_values
-template <int tIntegerBits>
-class ExpOnNegativeValuesOp final : public UnaryOpBase {
- public:
- std::int32_t MaxInput() const { return 0; }
- std::int32_t Tolerance() const { return 500; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::exp(d);
- return F0::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return exp_on_negative_values(f).raw();
+ private:
+ const int exponent_;
+ };
+
+ // Op wrapping SaturatingRoundingMultiplyByPOT
+ template <int tExponent>
+ class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
+ public:
+ ScalarType ReferenceOp(ScalarType x) const {
+ const double d = static_cast<double>(x) * std::pow(2., tExponent);
+ const double clamp_min = std::numeric_limits<ScalarType>::min();
+ const double clamp_max = std::numeric_limits<ScalarType>::max();
+ const double clamped = std::min(clamp_max, std::max(clamp_min, d));
+ return static_cast<ScalarType>(std::round(clamped));
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ return SaturatingRoundingMultiplyByPOT<tExponent>(x);
+ }
+ };
+
+ // Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
+ class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
+ : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return -(1 << (kScalarTypeBits - 3)); }
+ ScalarType MaxInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 1; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::exp(d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
+ return e.raw();
+ }
+ };
+
+ // Op wrapping exp_on_negative_values
+ template <int tIntegerBits>
+ class ExpOnNegativeValuesOp final : public UnaryOpBase {
+ public:
+ ScalarType MaxInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 2; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::exp(d);
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return exp_on_negative_values(f).raw();
+ }
+ };
+
+ // Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
+ class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 12 : 11; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = (1 - d) / (1 + d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
+ }
+ };
+
+ // Op wrapping tanh
+ template <int tIntegerBits>
+ class TanhOp final : public UnaryOpBase {
+ public:
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 310 : 12; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::tanh(d);
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return tanh(f).raw();
+ }
+ };
+
+ // Op wrapping one_over_one_plus_x_for_x_in_0_1
+ class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 6 : 5; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = 1 / (1 + d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ return one_over_one_plus_x_for_x_in_0_1(f).raw();
+ }
+ };
+
+ // Op wrapping logistic
+ template <int tIntegerBits>
+ class LogisticOp final : public UnaryOpBase {
+ public:
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 155 : 6; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = 1 / (1 + std::exp(-d));
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return logistic(f).raw();
+ }
+ };
+
+ // Tests a given op, on a given list of int32 input values.
+ template <typename tUnaryOpType>
+ void TestUnaryOp(const tUnaryOpType& unary_op,
+ const std::vector<ScalarType>& testvals) {
+ Check(0 == (testvals.size() % kSimdLanes));
+ for (std::size_t i = 0; i < testvals.size(); i += kSimdLanes) {
+ // First, clamp input values according to the MinInput() and MaxInput()
+ // bounds returned by the op.
+ ScalarType input[kSimdLanes] = {0};
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ const ScalarType raw_input = testvals[i + j];
+ input[j] = std::min(unary_op.MaxInput(),
+ std::max(unary_op.MinInput(), raw_input));
+ }
+ // Compute reference results and check that the actual results on
+ // scalar inputs agree with them, to the Tolerance() returned by the op.
+ ScalarType reference[kSimdLanes] = {0};
+ ScalarType actual_scalar[kSimdLanes] = {0};
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ reference[j] = unary_op.ReferenceOp(input[j]);
+ actual_scalar[j] = unary_op.Op(input[j]);
+ const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
+ static_cast<std::int64_t>(reference[j]);
+ if (std::abs(diff) > unary_op.Tolerance()) {
+ fprintf(stderr, "abs(diff) (%" PRId64 ") > tolerance (%d)\n", diff,
+ unary_op.Tolerance());
+ }
+ Check(std::abs(diff) <= unary_op.Tolerance());
+ }
+ // Check that the actual results on SIMD inputs agree *exactly* with the
+ // actual results on scalar inputs. I.e. SIMD must make absolutely no
+ // difference
+ // to the results, regardless of the fact that both scalar and SIMD
+ // results may differ from the reference results.
+ ScalarType actual_simd[kSimdLanes] = {0};
+ Store<SimdType>(actual_simd, unary_op.Op(Load<SimdType>(input)));
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ if (actual_simd[j] != actual_scalar[j]) {
+ fprintf(stderr, "SIMD (%d) != scalar (%d)\n", actual_simd[j],
+ actual_scalar[j]);
+ }
+ Check(actual_simd[j] == actual_scalar[j]);
+ }
+ }
}
-};
-// Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
-class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
- public:
- std::int32_t MinInput() const { return 0; }
- std::int32_t Tolerance() const { return 12; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = (1 - d) / (1 + d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
+ template <int tIntegerBits>
+ void test_convert(FixedPoint<ScalarType, tIntegerBits> x) {
+ typedef FixedPoint<ScalarType, tIntegerBits> F;
+ F y = F::FromDouble(ToDouble(x));
+ Check(y == x);
}
-};
-// Op wrapping tanh
-template <int tIntegerBits>
-class TanhOp final : public UnaryOpBase {
- public:
- std::int32_t Tolerance() const { return 310; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::tanh(d);
- return F0::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return tanh(f).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_Rescale(FixedPoint<ScalarType, tIntegerBits_a> a) {
+ FixedPoint<ScalarType, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
+ FixedPoint<ScalarType, tIntegerBits_b> expected =
+ FixedPoint<ScalarType, tIntegerBits_b>::FromDouble(ToDouble(a));
+ Check(actual == expected);
}
-};
-// Op wrapping one_over_one_plus_x_for_x_in_0_1
-class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
- public:
- std::int32_t MinInput() const { return 0; }
- std::int32_t Tolerance() const { return 6; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = 1 / (1 + d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- return one_over_one_plus_x_for_x_in_0_1(f).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_Rescale(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
+ aq.raw() = a;
+ test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
+ }
}
-};
-// Op wrapping logistic
-template <int tIntegerBits>
-class LogisticOp final : public UnaryOpBase {
- public:
- std::int32_t Tolerance() const { return 155; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = 1 / (1 + std::exp(-d));
- return F0::FromDouble(e).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_mul(FixedPoint<ScalarType, tIntegerBits_a> a,
+ FixedPoint<ScalarType, tIntegerBits_b> b) {
+ static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
+ using ProductFixedPoint = FixedPoint<ScalarType, ProductIntegerBits>;
+ ProductFixedPoint ab;
+ ab = a * b;
+ double a_double = ToDouble(a);
+ double b_double = ToDouble(b);
+ double ab_double = a_double * b_double;
+ ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
+ std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
+ Check(std::abs(diff) <= 1);
}
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return logistic(f).raw();
- }
-};
-// Tests a given op, on a given list of int32 input values.
-template <typename tUnaryOpType>
-void TestUnaryOp(const tUnaryOpType& unary_op,
- const std::vector<std::int32_t>& testvals_int32) {
- Check(0 == (testvals_int32.size() % SimdVectorSize));
- for (std::size_t i = 0; i < testvals_int32.size(); i += SimdVectorSize) {
- // First, clamp input int32 values accoding to the MinInput() and MaxInput()
- // bounds returned by the op.
- std::int32_t input[SimdVectorSize] = {0};
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- const std::int32_t raw_input = testvals_int32[i + j];
- input[j] = std::min(unary_op.MaxInput(),
- std::max(unary_op.MinInput(), raw_input));
- }
- // Compute reference results and check that the actual results on
- // scalar inputs agree with them, to the Tolerance() returned by the op.
- std::int32_t reference[SimdVectorSize] = {0};
- std::int32_t actual_scalar[SimdVectorSize] = {0};
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- reference[j] = unary_op.ReferenceOp(input[j]);
- actual_scalar[j] = unary_op.Op(input[j]);
- const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
- static_cast<std::int64_t>(reference[j]);
- Check(std::abs(diff) <= unary_op.Tolerance());
- }
- // Check that the actual results on SIMD inputs agree *exactly* with the
- // actual results on scalar inputs. I.e. SIMD must make absolutely no
- // difference
- // to the results, regardless of the fact that both scalar and SIMD results
- // may differ from the reference results.
- std::int32_t actual_simd[SimdVectorSize] = {0};
- StoreSimdVector(actual_simd, unary_op.Op(LoadSimdVector(input)));
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- Check(actual_simd[j] == actual_scalar[j]);
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_mul(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ for (auto b : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
+ FixedPoint<ScalarType, tIntegerBits_b> bq;
+ aq.raw() = a;
+ bq.raw() = b;
+ test_mul(aq, bq);
+ }
}
}
-}
-template <int tIntegerBits>
-void test_convert(FixedPoint<std::int32_t, tIntegerBits> x) {
- typedef FixedPoint<std::int32_t, tIntegerBits> F;
- F y = F::FromDouble(ToDouble(x));
- Check(y == x);
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_Rescale(FixedPoint<std::int32_t, tIntegerBits_a> a) {
- FixedPoint<std::int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
- FixedPoint<std::int32_t, tIntegerBits_b> expected =
- FixedPoint<std::int32_t, tIntegerBits_b>::FromDouble(ToDouble(a));
- Check(actual == expected);
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_Rescale(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- aq.raw() = a;
- test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
+ template <int tExponent, int tIntegerBits_a>
+ void test_ExactMulByPot(FixedPoint<ScalarType, tIntegerBits_a> a) {
+ double x = ToDouble(a) * std::pow(2.0, tExponent);
+ double y = ToDouble(ExactMulByPot<tExponent>(a));
+ Check(x == y);
}
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_mul(FixedPoint<std::int32_t, tIntegerBits_a> a,
- FixedPoint<std::int32_t, tIntegerBits_b> b) {
- static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
- using ProductFixedPoint = FixedPoint<std::int32_t, ProductIntegerBits>;
- ProductFixedPoint ab;
- ab = a * b;
- double a_double = ToDouble(a);
- double b_double = ToDouble(b);
- double ab_double = a_double * b_double;
- ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
- std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
- Check(std::abs(diff) <= 1);
-}
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_mul(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- for (auto b : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- FixedPoint<std::int32_t, tIntegerBits_b> bq;
+ template <int tExponent, int tIntegerBits_a>
+ void test_ExactMulByPot(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
aq.raw() = a;
- bq.raw() = b;
- test_mul(aq, bq);
+ test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
}
}
-}
-template <int tExponent, int tIntegerBits_a>
-void test_ExactMulByPot(FixedPoint<std::int32_t, tIntegerBits_a> a) {
- double x = ToDouble(a) * std::pow(2.0, tExponent);
- double y = ToDouble(ExactMulByPot<tExponent>(a));
- Check(x == y);
-}
+ // Make the list of test values to test each op against.
+ std::vector<ScalarType> MakeTestVals() {
+ std::vector<ScalarType> testvals;
+
+ for (int i = 0; i < kScalarTypeBits - 1; i++) {
+ testvals.push_back((1 << i) - 2);
+ testvals.push_back((1 << i) - 1);
+ testvals.push_back((1 << i));
+ testvals.push_back((1 << i) + 1);
+ testvals.push_back((1 << i) + 2);
+ testvals.push_back(-(1 << i) - 2);
+ testvals.push_back(-(1 << i) - 1);
+ testvals.push_back(-(1 << i));
+ testvals.push_back(-(1 << i) + 1);
+ testvals.push_back(-(1 << i) + 2);
+ }
+ testvals.push_back(std::numeric_limits<ScalarType>::min());
+ testvals.push_back(std::numeric_limits<ScalarType>::min() + 1);
+ testvals.push_back(std::numeric_limits<ScalarType>::min() + 2);
+ testvals.push_back(std::numeric_limits<ScalarType>::max() - 2);
+ testvals.push_back(std::numeric_limits<ScalarType>::max() - 1);
+ testvals.push_back(std::numeric_limits<ScalarType>::max());
+
+ std::mt19937 random_engine;
+ std::uniform_int_distribution<ScalarType> uniform_distribution(
+ std::numeric_limits<ScalarType>::min(),
+ std::numeric_limits<ScalarType>::max());
+ for (int i = 0; i < 1000; i++) {
+ testvals.push_back(uniform_distribution(random_engine));
+ }
-template <int tExponent, int tIntegerBits_a>
-void test_ExactMulByPot(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- aq.raw() = a;
- test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
- }
-}
+ // SIMD tests will require the length of testvals to be a multiple
+ // of SIMD vector size.
+ while (testvals.size() % kSimdLanes) {
+ testvals.push_back(0);
+ }
-// Make the list of test values to test each op against.
-std::vector<std::int32_t> MakeTestValsInt32() {
- std::vector<std::int32_t> testvals_int32;
-
- for (int i = 0; i < 31; i++) {
- testvals_int32.push_back((1 << i) - 2);
- testvals_int32.push_back((1 << i) - 1);
- testvals_int32.push_back((1 << i));
- testvals_int32.push_back((1 << i) + 1);
- testvals_int32.push_back((1 << i) + 2);
- testvals_int32.push_back(-(1 << i) - 2);
- testvals_int32.push_back(-(1 << i) - 1);
- testvals_int32.push_back(-(1 << i));
- testvals_int32.push_back(-(1 << i) + 1);
- testvals_int32.push_back(-(1 << i) + 2);
- }
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min());
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 1);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 2);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 2);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 1);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max());
-
- std::mt19937 random_engine;
- std::uniform_int_distribution<std::int32_t> uniform_distribution(
- std::numeric_limits<std::int32_t>::min(),
- std::numeric_limits<std::int32_t>::max());
- for (int i = 0; i < 1000; i++) {
- testvals_int32.push_back(uniform_distribution(random_engine));
+ std::sort(testvals.begin(), testvals.end());
+ return testvals;
}
- // SIMD tests will require the length of testvals_int32 to be a multiple
- // of SIMD vector size.
- while (testvals_int32.size() % SimdVectorSize) {
- testvals_int32.push_back(0);
- }
+ void RunTests(const char* msg) {
+ const std::vector<ScalarType> testvals = MakeTestVals();
- std::sort(testvals_int32.begin(), testvals_int32.end());
- return testvals_int32;
-}
+ for (int s = 0; s < kScalarTypeBits; s++) {
+ TestUnaryOp(RoundingDivideByPOTOp(s), testvals);
+ }
+
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<14 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 15>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 14>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 3>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 2>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 1>(),
+ testvals);
+
+ TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals);
+
+ TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals);
+ TestUnaryOp(TanhOp<0>(), testvals);
+ TestUnaryOp(TanhOp<1>(), testvals);
+ TestUnaryOp(TanhOp<2>(), testvals);
+ TestUnaryOp(TanhOp<3>(), testvals);
+ TestUnaryOp(TanhOp<4>(), testvals);
+ TestUnaryOp(TanhOp<5>(), testvals);
+ TestUnaryOp(TanhOp<6>(), testvals);
+
+ TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals);
+ TestUnaryOp(LogisticOp<0>(), testvals);
+ TestUnaryOp(LogisticOp<1>(), testvals);
+ TestUnaryOp(LogisticOp<2>(), testvals);
+ TestUnaryOp(LogisticOp<3>(), testvals);
+ TestUnaryOp(LogisticOp<4>(), testvals);
+ TestUnaryOp(LogisticOp<5>(), testvals);
+ TestUnaryOp(LogisticOp<6>(), testvals);
+
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, 4> x;
+ x.raw() = a;
+ test_convert(x);
+ }
+
+ test_mul<0, 0>(testvals);
+ test_mul<0, 1>(testvals);
+ test_mul<2, 0>(testvals);
+ test_mul<1, 1>(testvals);
+ test_mul<4, 4>(testvals);
+ test_mul<3, 5>(testvals);
+ test_mul<7, 2>(testvals);
+ test_mul<kScalarTypeBits / 2 - 1, kScalarTypeBits / 2 - 2>(testvals);
+
+ test_Rescale<0, 0>(testvals);
+ test_Rescale<0, 1>(testvals);
+ test_Rescale<2, 0>(testvals);
+ test_Rescale<4, 4>(testvals);
+ test_Rescale<4, 5>(testvals);
+ test_Rescale<6, 3>(testvals);
+ test_Rescale<13, 9>(testvals);
+
+ test_ExactMulByPot<0, 0>(testvals);
+ test_ExactMulByPot<0, 4>(testvals);
+ test_ExactMulByPot<1, 4>(testvals);
+ test_ExactMulByPot<3, 2>(testvals);
+ test_ExactMulByPot<-4, 5>(testvals);
+ test_ExactMulByPot<-2, 6>(testvals);
+
+ fprintf(stderr, "PASS (%s)\n", msg);
+ }
+};
} // end anonymous namespace
} // end namespace gemmlowp
int main() {
- using namespace gemmlowp;
-
- const std::vector<std::int32_t> testvals_int32 = MakeTestValsInt32();
-
- for (int s = 0; s < 32; s++) {
- TestUnaryOp(RoundingDivideByPOTOp(s), testvals_int32);
- }
-
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-31>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-30>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-29>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-17>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-16>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<16>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<17>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<29>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<30>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<31>(), testvals_int32);
-
- TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(),
- testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals_int32);
-
- TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals_int32);
- TestUnaryOp(TanhOp<0>(), testvals_int32);
- TestUnaryOp(TanhOp<1>(), testvals_int32);
- TestUnaryOp(TanhOp<2>(), testvals_int32);
- TestUnaryOp(TanhOp<3>(), testvals_int32);
- TestUnaryOp(TanhOp<4>(), testvals_int32);
- TestUnaryOp(TanhOp<5>(), testvals_int32);
- TestUnaryOp(TanhOp<6>(), testvals_int32);
-
- TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals_int32);
- TestUnaryOp(LogisticOp<0>(), testvals_int32);
- TestUnaryOp(LogisticOp<1>(), testvals_int32);
- TestUnaryOp(LogisticOp<2>(), testvals_int32);
- TestUnaryOp(LogisticOp<3>(), testvals_int32);
- TestUnaryOp(LogisticOp<4>(), testvals_int32);
- TestUnaryOp(LogisticOp<5>(), testvals_int32);
- TestUnaryOp(LogisticOp<6>(), testvals_int32);
-
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, 4> x;
- x.raw() = a;
- test_convert(x);
- }
-
- test_mul<0, 0>(testvals_int32);
- test_mul<0, 1>(testvals_int32);
- test_mul<2, 0>(testvals_int32);
- test_mul<1, 1>(testvals_int32);
- test_mul<4, 4>(testvals_int32);
- test_mul<3, 5>(testvals_int32);
- test_mul<7, 2>(testvals_int32);
- test_mul<14, 15>(testvals_int32);
-
- test_Rescale<0, 0>(testvals_int32);
- test_Rescale<0, 1>(testvals_int32);
- test_Rescale<2, 0>(testvals_int32);
- test_Rescale<4, 4>(testvals_int32);
- test_Rescale<4, 5>(testvals_int32);
- test_Rescale<6, 3>(testvals_int32);
- test_Rescale<13, 9>(testvals_int32);
-
- test_ExactMulByPot<0, 0>(testvals_int32);
- test_ExactMulByPot<0, 4>(testvals_int32);
- test_ExactMulByPot<1, 4>(testvals_int32);
- test_ExactMulByPot<3, 2>(testvals_int32);
- test_ExactMulByPot<-4, 5>(testvals_int32);
- test_ExactMulByPot<-2, 6>(testvals_int32);
-
- std::cerr << "All tests passed." << std::endl;
+ gemmlowp::TestFixedPoint<std::int32_t>().RunTests("Scalar int32");
+ gemmlowp::TestFixedPoint<std::int16_t>().RunTests("Scalar int16");
+#ifdef GEMMLOWP_SSE4
+ gemmlowp::TestFixedPoint<__m128i>().RunTests("SSE4 __m128i = int32x4");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x8_m128i>().RunTests(
+ "SSE4 __m128i = int16x8");
+#endif
+#ifdef GEMMLOWP_NEON
+ gemmlowp::TestFixedPoint<int32x4_t>().RunTests("NEON int32x4_t");
+ gemmlowp::TestFixedPoint<int16x8_t>().RunTests("NEON int16x8_t");
+#endif
+#ifdef GEMMLOWP_MSA
+ gemmlowp::TestFixedPoint<v4i32>().RunTests("MSA v4i32");
+ gemmlowp::TestFixedPoint<v8i16>().RunTests("MSA v8i16");
+#endif
+#ifdef GEMMLOWP_AVX2
+ gemmlowp::TestFixedPoint<__m256i>().RunTests("AVX __m256i");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x16_m256i>().RunTests(
+ "AVX2 __m256i = int16x16");
+#endif
}
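
The rewritten test is parameterized on the raw SIMD (or scalar) type and relies on FixedPointRawTypeTraits plus the Load/Store specializations above to bridge scalar arrays and vector registers. A small sketch of the scalar-side contract the test depends on (the static_asserts are illustrative checks, and the include path relative to the gemmlowp root is an assumption):

#include <cstdint>

#include "fixedpoint/fixedpoint.h"  // FixedPointRawTypeTraits

// For plain scalar types the traits report a single lane, and the generic
// Load/Store templates above degrade to ordinary loads and stores, so
// TestFixedPoint<std::int16_t> walks the same op list one value at a time.
using ScalarTraits = gemmlowp::FixedPointRawTypeTraits<std::int16_t>;
static_assert(ScalarTraits::kLanes == 1, "scalar raw types are single-lane");
static_assert(sizeof(ScalarTraits::ScalarRawType) == 2,
              "the int16 path uses 16-bit raw values");
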