author     Lev Proleev <levp@google.com>  2021-03-12 18:40:35 +0000
committer  Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>  2021-03-12 18:40:35 +0000
commit     8a4a46b623427068a13c3cbdc93a751953f7ce09 (patch)
tree       a29716289a0b730ca66a3e632c6ce054eb3b90d6
parent     5ddee49b6b65a7be1f67186e0aa6ad74cf956165 (diff)
parent     4696348bdca97daf302db4a53048b4d6127f800c (diff)
download   gemmlowp-8a4a46b623427068a13c3cbdc93a751953f7ce09.tar.gz

Original change: https://android-review.googlesource.com/c/platform/external/gemmlowp/+/1610833

MUST ONLY BE SUBMITTED BY AUTOMERGER

Change-Id: I0e39ec9870ed171cd4e6a85f280801262d840f98
-rw-r--r--  AUTHORS (renamed from AUTHORS.txt)  5
-rw-r--r--  Android.bp  2
-rw-r--r--  CONTRIBUTING (renamed from CONTRIBUTING.txt)  0
-rw-r--r--  CONTRIBUTORS (renamed from CONTRIBUTORS.txt)  18
-rw-r--r--  LICENSE (renamed from LICENSE.txt)  0
-rw-r--r--  README.md  276
-rw-r--r--  README.txt  260
-rw-r--r--  fixedpoint/fixedpoint.h  30
-rw-r--r--  fixedpoint/fixedpoint_avx.h  168
-rw-r--r--  fixedpoint/fixedpoint_sse.h  52
-rw-r--r--  fixedpoint/fixedpoint_wasmsimd.h  381
-rw-r--r--  flags.bzl  7
-rw-r--r--  internal/allocator.h  4
-rw-r--r--  internal/common.h  2
-rw-r--r--  internal/detect_platform.h  5
-rw-r--r--  internal/dispatch_gemm_shape.h  6
-rw-r--r--  internal/kernel.h  20
-rw-r--r--  internal/output_sse.h  21
-rw-r--r--  internal/pack.h  24
-rw-r--r--  internal/pack_sse.h  10
-rw-r--r--  internal/platform.h  3
-rw-r--r--  meta/generators/cc_emitter.py  8
-rw-r--r--  meta/generators/common.py  3
-rw-r--r--  meta/generators/neon_emitter.py  8
-rw-r--r--  meta/generators/neon_emitter_64.py  8
-rw-r--r--  public/bit_depth.h  4
-rw-r--r--  public/map.h  6
-rw-r--r--  standalone/cache_counters.cc  404
-rw-r--r--  standalone/encode.py  134
-rw-r--r--  standalone/neon-gemm-kernel-benchmark.cc  2350
-rw-r--r--  test/benchmark.cc  11
-rw-r--r--  test/benchmark_all_sizes.cc  19
-rw-r--r--  test/test.cc  84
-rw-r--r--  test/test.h  6
-rw-r--r--  test/test_blocking_counter.cc  55
-rw-r--r--  test/test_fixedpoint.cc  952
36 files changed, 3765 insertions, 1581 deletions
diff --git a/AUTHORS.txt b/AUTHORS
index 13a49e0..996e104 100644
--- a/AUTHORS.txt
+++ b/AUTHORS
@@ -7,3 +7,8 @@
# The email address is not required for organizations.
Google Inc.
+Intel Corporation
+ARM Ltd.
+Silk Labs Inc.
+MIPS Tech LLC
+Wave Computing Inc.
diff --git a/Android.bp b/Android.bp
index 649324a..5efb5ff 100644
--- a/Android.bp
+++ b/Android.bp
@@ -30,7 +30,7 @@ license {
"SPDX-license-identifier-Apache-2.0",
],
license_text: [
- "LICENSE.txt",
+ "LICENSE",
"NOTICE",
],
}
diff --git a/CONTRIBUTING.txt b/CONTRIBUTING
index d6d63bc..d6d63bc 100644
--- a/CONTRIBUTING.txt
+++ b/CONTRIBUTING
diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS
index 7c2415b..3740e0e 100644
--- a/CONTRIBUTORS.txt
+++ b/CONTRIBUTORS
@@ -15,8 +15,26 @@ Pete Warden <petewarden@google.com>
Miao Wang <miaowang@google.com>
David Andersen <dga@google.com>
Maciek Chociej <maciekc@google.com>
+Justine Tunney <jart@google.com>
+Mark J. Matthews <mjmatthews@google.com>
+Marie White <mariewhite@google.com>
+Suharsh Sivakumar <suharshs@google.com>
Intel:
Sagi Marcovich <sagi.marcovich@intel.com>
Murat Efe Guney <murat.e.guney@intel.com>
Sarah Knepper <sarah.knepper@intel.com>
+Mourad Gouicem <mourad.gouicem@intel.com>
+Richard Winterton <richard.winterton@intel.com>
+
+ARM:
+David Mansell <David.Mansell@arm.com>
+
+Silk Labs:
+Andreas Gal <andreas@silklabs.com>
+
+MIPS Tech LLC:
+Alexey Frunze <Alexey.Frunze@mips.com>
+
+Wave Computing Inc.:
+Alexey Frunze <afrunze@wavecomp.com>
diff --git a/LICENSE.txt b/LICENSE
index d645695..d645695 100644
--- a/LICENSE.txt
+++ b/LICENSE
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..22fabac
--- /dev/null
+++ b/README.md
@@ -0,0 +1,276 @@
+# gemmlowp: a small self-contained low-precision GEMM library
+
+[![Build Status](https://secure.travis-ci.org/google/gemmlowp.png)](http://travis-ci.org/google/gemmlowp)
+
+This is not a full linear algebra library, only a GEMM library: it only does
+general matrix multiplication ("GEMM").
+
+The meaning of "low precision" is detailed in this document:
+[doc/low-precision.md](doc/low-precision.md)
+
+Some of the general design is explained in [doc/design.md](doc/design.md).
+
+**Warning:** This library runs very slowly if compiled incorrectly; see below.
+
+## Disclaimer
+
+This is not an official Google product (experimental or otherwise), it is just
+code that happens to be owned by Google.
+
+## Mailing list
+
+gemmlowp-related discussion, about either development or usage, is welcome on
+this Google Group (mailing list / forum):
+
+https://groups.google.com/forum/#!forum/gemmlowp
+
+## Portability, target platforms/architectures
+
+Should be portable to any platform with some C++11 and POSIX support, while we
+have optional optimized code paths for specific architectures.
+
+Required:
+
+* C++11 (a small conservative subset of it)
+
+Required for some features:
+
+* Some POSIX interfaces:
+ * pthreads (for multi-threaded operation and for profiling).
+ * sysconf (for multi-threaded operation to detect number of cores; may be
+ bypassed).
+
+Optional:
+
+* Architecture-specific code paths use intrinsics or inline assembly. See
+ "Architecture-specific optimized code paths" below.
+
+## Architecture-specific optimized code paths
+
+We have some optimized code paths for specific instruction sets. Some are
+written in inline assembly, some are written in C++ using intrinsics. Both GCC
+and Clang are supported.
+
+Current optimized code paths:
+
+* ARM with NEON (both 32bit and 64bit).
+* Intel x86 with SSE 4.1 (both 32bit and 64bit).
+
+When building for x86, it's very important to pass `-msse4.1` to the compiler,
+otherwise gemmlowp will use slow reference code. Bazel users can compile by
+running `bazel build --copt=-msse4.1 //gemmlowp:all`. The compiled binary should
+work on all Intel CPUs since 2008 (including low power microarchitectures) as
+well as AMD CPUs since 2011.
+
+Please note that when compiling binaries that don't need to be distributed, it's
+generally a better idea to pass `-march=native` to the compiler. That flag
+implies `-msse4.1`, along with other flags that might be helpful. This of course
+assumes that the host machine supports those instructions. Bazel users should
+prefer to run `bazel build --config=opt //gemmlowp:all` instead.
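+
+As a quick sanity check (a standalone snippet, not part of gemmlowp itself), a
+translation unit like the following fails to compile unless SSE 4.1 is enabled,
+since GCC and Clang predefine `__SSE4_1__` whenever `-msse4.1` (or an `-march=`
+setting implying it) is in effect:
+
+```
+// Fails to compile when SSE 4.1 is not enabled, i.e. in the configuration
+// where gemmlowp would fall back to its slow reference code on x86.
+#if !defined(__SSE4_1__)
+#error "SSE 4.1 is not enabled; pass -msse4.1 or -march=native."
+#endif
+```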
+
+Details of what it takes to make an efficient port of gemmlowp, namely writing a
+suitable GEMM kernel and accompanying packing code, are explained in this file:
+[doc/kernel.md](doc/kernel.md).
+
+## Public interfaces
+
+### The gemmlowp public interface
+
+gemmlowp's main public interface is in the `public/` subdirectory.
+
+This is a headers-only library, so there is nothing to link to.
+
+Usage documentation, and comments on the deprecation status of each public entry
+point, may be found in [doc/public.md](doc/public.md).
+
+A full, self-contained usage example, showing how to quantize float matrices and
+perform a quantized matrix multiplication approximating a float matrix
+multiplication, is given in
+[doc/quantization_example.cc](doc/quantization_example.cc).
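+
+For orientation, here is a minimal sketch of such a call (a sketch only: the
+example file above and [doc/public.md](doc/public.md) remain the authoritative
+references; the include paths depend on your setup, and the quantization
+parameters below are placeholders rather than a real quantization scheme):
+
+```
+#include <cstdint>
+#include <tuple>
+
+#include "public/gemmlowp.h"       // adjust the include path to your setup
+#include "public/output_stages.h"
+
+void Uint8GemmSketch(const std::uint8_t* lhs_data, const std::uint8_t* rhs_data,
+                     std::uint8_t* result_data, int rows, int depth, int cols) {
+  using gemmlowp::MapOrder;
+  gemmlowp::MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(
+      lhs_data, rows, depth, depth);
+  gemmlowp::MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(
+      rhs_data, depth, cols, depth);
+  gemmlowp::MatrixMap<std::uint8_t, MapOrder::ColMajor> result(
+      result_data, rows, cols, rows);
+
+  // Placeholder quantization parameters; real values come from how the float
+  // matrices were quantized.
+  const int lhs_offset = -128;
+  const int rhs_offset = -128;
+  gemmlowp::OutputStageQuantizeDownInt32ToUint8Scale quantize_down;
+  quantize_down.result_offset = 128;
+  quantize_down.result_mult_int = 1;
+  quantize_down.result_shift = 8;
+  gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast;
+  const auto output_pipeline = std::make_tuple(quantize_down, saturating_cast);
+
+  gemmlowp::GemmContext context;
+  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
+                                   gemmlowp::DefaultL8R8BitDepthParams>(
+      &context, lhs, rhs, &result, lhs_offset, rhs_offset, output_pipeline);
+}
+```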
+
+### Old EightBitIntGemm legacy deprecated interface
+
+The `eight_bit_int_gemm/` subdirectory contains an alternate interface that
+should be considered purely legacy, deprecated, and going to be removed at some
+point in the future.
+
+## Building
+
+### Building by manually invoking your compiler
+
+Because gemmlowp is so simple, working with it involves only single-command-line
+compiler invocations. Therefore we expect that most people working with gemmlowp
+will either manually invoke their compiler, or write their own rules for their
+own preferred build system.
+
+Keep in mind (previous section) that gemmlowp itself is a pure-headers-only
+library so there is nothing to build.
+
+For an Android gemmlowp development workflow, the `scripts/` directory contains a
+script to build and run a program on an Android device:
+
+```
+scripts/test-android.sh
+```
+
+### Building using Bazel
+
+That being said, we also maintain a Bazel BUILD system as part of gemmlowp. Its
+usage is not mandatory at all and is only one possible way that gemmlowp
+libraries and tests may be built. If you are interested, Bazel's home page is
+http://bazel.build/. You can get started using Bazel to build gemmlowp targets
+by first creating an empty WORKSPACE file in a parent directory, for instance:
+
+```
+$ cd gemmlowp/.. # change to parent directory containing gemmlowp/
+$ touch WORKSPACE # declare that to be our workspace root
+$ bazel build gemmlowp:all
+```
+
+## Testing
+
+### Testing by manually building and running tests
+
+The test/ directory contains unit tests. The primary unit test is
+
+```
+test/test.cc
+```
+
+Since it also covers the EightBitIntGemm interface, it needs to be linked
+against
+
+```
+eight_bit_int_gemm/eight_bit_int_gemm.cc
+```
+
+It also uses realistic data captured from a neural network run in
+
+```
+test/test_data.cc
+```
+
+Thus you'll want to pass the following list of source files to your
+compiler/linker:
+
+```
+test/test.cc
+eight_bit_int_gemm/eight_bit_int_gemm.cc
+test/test_data.cc
+```
+
+The `scripts/` directory contains a script to build and run a program on an
+Android device:
+
+```
+scripts/test-android.sh
+```
+
+It expects the `CXX` environment variable to point to an Android toolchain's C++
+compiler, and expects source files (and optionally, cflags) as command-line
+parameters. To build and run the above-mentioned main unit test, first set `CXX`
+e.g.:
+
+```
+$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++
+```
+
+Then run:
+
+```
+$ ./scripts/test-android.sh \
+test/test.cc \
+eight_bit_int_gemm/eight_bit_int_gemm.cc \
+test/test_data.cc
+```
+
+### Testing using Bazel
+
+Alternatively, you can use Bazel to build and run tests. See the Bazel
+instructions in the section on building above. Once your Bazel workspace is set
+up, you can for instance do:
+
+```
+$ bazel test gemmlowp:all
+```
+
+## Troubleshooting Compilation
+
+If you're having trouble finding the compiler, follow these instructions to
+build a standalone toolchain:
+https://developer.android.com/ndk/guides/standalone_toolchain.html
+
+Here's an example of setting up Clang 3.5:
+
+```
+$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu
+$ $NDK/build/tools/make-standalone-toolchain.sh \
+--toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \
+--install-dir=$INSTALL_DIR
+$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \
+--sysroot=$INSTALL_DIR/sysroot"
+```
+
+Some compilers (e.g. the default clang++ in the same bin directory) don't
+support NEON assembly. The benchmark build process will issue a warning if
+support isn't detected, and you should make sure you're using a compiler like
+arm-linux-androideabi-g++ that does include NEON.
+
+## Benchmarking
+
+The main benchmark is
+
+```
+test/benchmark.cc
+```
+
+It doesn't need to be linked to any other source file. We recommend building
+with assertions disabled (`-DNDEBUG`).
+
+For example, the benchmark can be built and run on an Android device by doing:
+
+```
+$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG
+```
+
+If `GEMMLOWP_TEST_PROFILE` is defined then the benchmark will be built with
+profiling instrumentation (which makes it slower) and will dump profiles. See
+next section on profiling.
+
+## Profiling
+
+The `profiling/` subdirectory offers a very simple, naive, inaccurate,
+non-interrupting sampling profiler that only requires pthreads (no signals).
+
+It relies on source code being instrumented with pseudo-stack labels. See
+`profiling/instrumentation.h`. A full example of using this profiler is given in
+the top comment of `profiling/profiler.h`.
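+
+For illustration, a minimal sketch of how the pieces fit together (this assumes
+a build with `GEMMLOWP_PROFILING` defined; the top comment of
+`profiling/profiler.h` remains the authoritative example):
+
+```
+#include "profiling/instrumentation.h"
+#include "profiling/profiler.h"
+
+void ProfiledWork() {
+  // Each thread that should show up in profiles opts in to being sampled.
+  gemmlowp::RegisterCurrentThreadForProfiling();
+  gemmlowp::StartProfiling();
+  {
+    // Pseudo-stack label; nested labels form the pseudo call stack that the
+    // sampling thread records.
+    gemmlowp::ScopedProfilingLabel label("my gemm workload");
+    // ... run gemmlowp here ...
+  }
+  gemmlowp::FinishProfiling();  // prints the aggregated profile
+}
+```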
+
+## Contributing
+
+Contribution-related discussion is always welcome on the gemmlowp mailing list
+(see above).
+
+We try to keep a current list of TODO items in the `todo/` directory.
+Prospective contributors are welcome to pick one to work on, and communicate
+about it on the gemmlowp mailing list.
+
+Details of the contributing process, including legalese, are in CONTRIBUTING.
+
+## Performance goals
+
+Our performance goals differ from typical GEMM performance goals in the
+following ways:
+
+1. We care not only about speed, but also about minimizing power usage. We
+ specifically care about charge usage in mobile/embedded devices. This
+ implies that we care doubly about minimizing memory bandwidth usage: we care
+ about it, like any GEMM, because of the impact on speed, and we also care
+ about it because it is a key factor of power usage.
+
+2. Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000).
+ We do care about large sizes, but we also care specifically about the
+ typically smaller matrix sizes encountered in various mobile applications.
+ This means that we have to optimize for all sizes, not just for large enough
+ sizes.
diff --git a/README.txt b/README.txt
deleted file mode 100644
index e29f0e4..0000000
--- a/README.txt
+++ /dev/null
@@ -1,260 +0,0 @@
-gemmlowp: a small self-contained low-precision GEMM library
-===========================================================
-
-This is not a full linear algebra library, only a GEMM library: it only does
-general matrix multiplication ("GEMM").
-
-The meaning of "low precision" is detailed in this document:
- doc/low-precision.txt
-
-Some of the general design is explained in
- doc/design.txt
-
-
-Disclaimer
-==========
-
-This is not an official Google product (experimental or otherwise), it is just
-code that happens to be owned by Google.
-
-
-Mailing list
-============
-
-gemmlowp-related discussion, about either development or usage, is welcome
-on this Google Group (mailing list / forum):
-
- https://groups.google.com/forum/#!forum/gemmlowp
-
-
-Portability, target platforms/architectures
-===========================================
-
-Should be portable to any platform with some C++11 and POSIX support,
-while we have optional optimized code paths for specific architectures.
-
-Required:
- C++11 (a small conservative subset of it)
-
-Required for some features:
- * Some POSIX interfaces:
- * pthreads (for multi-threaded operation and for profiling).
- * sysconf (for multi-threaded operation to detect number of cores;
- may be bypassed).
-
-Optional:
- Architecture-specific code paths use intrinsics or inline assembly.
- See "Architecture-specific optimized code paths" below.
-
-Architecture-specific optimized code paths
-==========================================
-
-We have some optimized code paths for specific instruction sets.
-Some are written in inline assembly, some are written in C++ using
-intrinsics. Both GCC and Clang are supported.
-
-At the moment, we have a full set of optimized code paths (kernels,
-packing and unpacking paths) only for ARM NEON, supporting both
-ARMv7 (32bit) and ARMv8 (64bit).
-
-We also have a partial set of optimized code paths (only kernels
-at the moment) for Intel SSE. It supports both x86 and x86-64 but
-only targets SSE4. The lack of packing/unpacking code paths means
-that performance isn't optimal yet.
-
-Details of what it takes to make an efficient port of gemmlowp, namely
-writing a suitable GEMM kernel and accompanying packing code, are
-explained in this file:
- doc/kernels.txt
-
-
-Public interfaces
-=================
-
-1. gemmlowp public interface
-----------------------------
-
- gemmlowp's main public interface is in the public/ subdirectory. The
- header to include is
- public/gemmlowp.h.
- This is a headers-only library, so there is nothing to link to.
-
-2. EightBitIntGemm standard interface
--------------------------------------
-
- Additionally, the eight_bit_int_gemm/ subdirectory provides an
- implementation of the standard EightBitIntGemm interface. The header
- to include is
- eight_bit_int_gemm/eight_bit_int_gemm.h
- This is *NOT* a headers-only library, users need to link to
- eight_bit_int_gemm/eight_bit_int_gemm.cc.
- The API is similar to the standard BLAS GEMM interface, and implements
- C = A * B. If the transpose flags for a matrix argument are false, its memory
- order is treated as column major, and row major if its true.
-
-
-Building
-========
-
-Building by manually invoking your compiler
--------------------------------------------
-
-Because gemmlowp is so simple, working with it involves only
-single-command-line compiler invokations. Therefore we expect that
-most people working with gemmlowp will either manually invoke their
-compiler, or write their own rules for their own preferred build
-system.
-
-Keep in mind (previous section) that gemmlowp itself is a pure-headers-only
-library so there is nothing to build, and the eight_bit_int_gemm library
-consists of a single eight_bit_int_gemm.cc file to build.
-
-For a Android gemmlowp development workflow, the scripts/ directory
-contains a script to build and run a program on an Android device:
- scripts/test-android.sh
-
-Building using Bazel
---------------------
-
-That being said, we also maintain a Bazel BUILD system as part of
-gemmlowp. Its usage is not mandatory at all and is only one
-possible way that gemmlowp libraries and tests may be built. If
-you are interested, Bazel's home page is
- http://bazel.io/
-And you can get started with using Bazel to build gemmlowp targets
-by first creating an empty WORKSPACE file in a parent directory,
-for instance:
-
-$ cd gemmlowp/.. # change to parent directory containing gemmlowp/
-$ touch WORKSPACE # declare that to be our workspace root
-$ bazel build gemmlowp:all
-
-
-Testing
-=======
-
-Testing by manually building and running tests
-----------------------------------------------
-
-The test/ directory contains unit tests. The primary unit test is
- test/test.cc
-Since it covers also the EightBitIntGemm interface, it needs to be
-linked against
- eight_bit_int_gemm/eight_bit_int_gemm.cc
-It also uses realistic data captured from a neural network run in
- test/test_data.cc
-
-Thus you'll want to pass the following list of source files to your
-compiler/linker:
- test/test.cc
- eight_bit_int_gemm/eight_bit_int_gemm.cc
- test/test_data.cc
-
-The scripts/ directory contains a script to build and run a program
-on an Android device:
- scripts/test-android.sh
-
-It expects the CXX environment variable to point to an Android toolchain's
-C++ compiler, and expects source files (and optionally, cflags) as
-command-line parameters. To build and run the above-mentioned main unit test,
-first set CXX e.g.:
-
-$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++
-
-Then run:
-
-$ ./scripts/test-android.sh \
-test/test.cc \
-eight_bit_int_gemm/eight_bit_int_gemm.cc \
-test/test_data.cc
-
-
-Testing using Bazel
--------------------
-
-Alternatively, you can use Bazel to build and run tests. See the Bazel
-instruction in the above section on building. Once your Bazel workspace
-is set up, you can for instance do:
-
-$ bazel test gemmlowp:all
-
-
-Troubleshooting Compilation
-===========================
-
-If you're having trouble finding the compiler, follow these instructions to
-build a standalone toolchain:
-https://developer.android.com/ndk/guides/standalone_toolchain.html
-
-Here's an example of setting up Clang 3.5:
-
-$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu
-$ $NDK/build/tools/make-standalone-toolchain.sh \
---toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \
---install-dir=$INSTALL_DIR
-$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \
---sysroot=$INSTALL_DIR/sysroot"
-
-Some compilers (e.g. the default clang++ in the same bin directory) don't
-support NEON assembly. The benchmark build process will issue a warning if
-support isn't detected, and you should make sure you're using a compiler like
-arm-linux-androideabi-g++ that does include NEON.
-
-
-Benchmarking
-============
-
-The main benchmark is
- benchmark.cc
-It doesn't need to be linked to any
-other source file. We recommend building with assertions disabled (-DNDEBUG).
-
-For example, the benchmark can be built and run on an Android device by doing:
-
-$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG
-
-If GEMMLOWP_TEST_PROFILE is defined then the benchmark will be built with
-profiling instrumentation (which makes it slower) and will dump profiles.
-See next section on profiling.
-
-
-Profiling
-=========
-
-The profiling/ subdirectory offers a very simple non-interrupting sampling
-profiler that only requires pthreads (no signals).
-
-It relies on source code being instrumented with pseudo-stack labels.
-See profiling/instrumentation.h.
-A full example of using this profiler is given in profiling/profiler.h.
-
-
-Contributing
-============
-
-Contribution-related discussion is always welcome on the gemmlowp
-mailing list (see above).
-
-We try to keep a current list of TODO items in the todo/ directory.
-Prospective contributors are welcome to pick one to work on, and
-communicate about it on the gemmlowp mailing list.
-
-Details of the contributing process, including legalese, are in CONTRIBUTING.
-
-Performance goals
-=================
-
-Our performance goals differ from typical GEMM performance goals in the
-following ways:
-
-1. We care not only about speed, but also about minimizing power usage.
- We specifically care about charge usage in mobile/embedded devices.
- This implies that we care doubly about minimizing memory bandwidth usage:
- we care about it, like any GEMM, because of the impact on speed, and we
- also care about it because it is a key factor of power usage.
-
-2. Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000).
- We do care about large sizes, but we also care specifically about the
- typically smaller matrix sizes encountered in various mobile applications.
- This means that we have to optimize for all sizes, not just for large enough
- sizes.
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index 58e8050..56e95c0 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -95,12 +95,13 @@ tIntegerType Add(tIntegerType a, tIntegerType b) {
return a + b;
}
-// Integer subtraction. Not saturating. Overflow is undefined behavior.
+// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
return a * b;
}
+// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
return a - b;
@@ -268,6 +269,16 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
std::max(static_cast<std::int32_t>(-32768), sum)));
}
+template <>
+inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
+ std::int16_t a16 = a;
+ std::int16_t b16 = b;
+ std::int16_t sum = a16 + b16;
+ return static_cast<std::int8_t>(std::min(
+ static_cast<int16_t>(std::numeric_limits<int8_t>::max()),
+ std::max(static_cast<int16_t>(std::numeric_limits<int8_t>::min()), sum)));
+}
+
// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
template <typename IntegerType, bool Is16Bit>
@@ -767,13 +778,14 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
result * kMultiplier, result); \
}
- GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
- GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
- GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
- GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
- GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
- GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
- GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
+ // Constants below are Q0 representations of negative exp fractionals:
+ GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); // exp(-1/4)
+ GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); // exp(-1/2)
+ GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); // exp(-1)
+ GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); // exp(-2)
+ GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); // exp(-4)
+ GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); // exp(-8)
+ GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); // exp(-16)
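+ // Worked check of that encoding (the Q0/int32 format has 31 fractional
+ // bits): round(exp(-1/4) * 2^31) = 1672461947, the first constant above.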
#undef GEMMLOWP_EXP_BARREL_SHIFTER
@@ -895,6 +907,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
#include "./fixedpoint_sse.h"
#elif defined(GEMMLOWP_MSA)
#include "./fixedpoint_msa.h"
+#elif defined(GEMMLOWP_WASMSIMD)
+#include "./fixedpoint_wasmsimd.h"
#endif
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
index 1816386..f3fe732 100644
--- a/fixedpoint/fixedpoint_avx.h
+++ b/fixedpoint/fixedpoint_avx.h
@@ -17,69 +17,139 @@
#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
-#include <smmintrin.h>
+#include <immintrin.h>
#include "fixedpoint.h"
#include "fixedpoint_sse.h"
namespace gemmlowp {
+struct int16x16_m256i {
+ __m256i v;
+};
+
+// Keep int16x16_m256i trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x16_m256i to_int16x16_m256i(__m256i w) {
+ int16x16_m256i r;
+ r.v = w;
+ return r;
+}
+
template <>
struct FixedPointRawTypeTraits<__m256i> {
typedef std::int32_t ScalarRawType;
+ // TODO: This can actually support up to 8 lanes, so we should either
+ // change to 8 or create int32x8_m256i struct to handle that case.
static const int kLanes = 4;
};
template <>
+struct FixedPointRawTypeTraits<int16x16_m256i> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 16;
+};
+
+template <>
inline __m256i BitAnd(__m256i a, __m256i b) {
return _mm256_and_si256(a, b);
}
template <>
+inline int16x16_m256i BitAnd(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_and_si256(a.v, b.v));
+}
+
+template <>
inline __m256i BitOr(__m256i a, __m256i b) {
return _mm256_or_si256(a, b);
}
template <>
+inline int16x16_m256i BitOr(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_or_si256(a.v, b.v));
+}
+
+template <>
inline __m256i BitXor(__m256i a, __m256i b) {
return _mm256_xor_si256(a, b);
}
template <>
+inline int16x16_m256i BitXor(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_xor_si256(a.v, b.v));
+}
+
+template <>
inline __m256i BitNot(__m256i a) {
return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
}
template <>
+inline int16x16_m256i BitNot(int16x16_m256i a) {
+ return to_int16x16_m256i(_mm256_andnot_si256(a.v, _mm256_set1_epi16(-1)));
+}
+
+template <>
inline __m256i Add(__m256i a, __m256i b) {
return _mm256_add_epi32(a, b);
}
template <>
+inline int16x16_m256i Add(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_add_epi16(a.v, b.v));
+}
+
+template <>
inline __m256i Mul(__m256i a, __m256i b) {
return _mm256_mullo_epi32(a, b);
}
template <>
+inline int16x16_m256i Mul(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_mullo_epi16(a.v, b.v));
+}
+
+template <>
inline __m256i Sub(__m256i a, __m256i b) {
return _mm256_sub_epi32(a, b);
}
template <>
+inline int16x16_m256i Sub(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_sub_epi16(a.v, b.v));
+}
+
+template <>
inline __m256i Neg(__m256i a) {
return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
}
template <>
+inline int16x16_m256i Neg(int16x16_m256i a) {
+ return to_int16x16_m256i(_mm256_sign_epi16(a.v, _mm256_set1_epi16(-1)));
+}
+
+template <>
inline __m256i ShiftLeft(__m256i a, int offset) {
return _mm256_slli_epi32(a, offset);
}
template <>
+inline int16x16_m256i ShiftLeft(int16x16_m256i a, int offset) {
+ return to_int16x16_m256i(_mm256_slli_epi16(a.v, offset));
+}
+
+template <>
inline __m256i ShiftRight(__m256i a, int offset) {
return _mm256_srai_epi32(a, offset);
}
template <>
+inline int16x16_m256i ShiftRight(int16x16_m256i a, int offset) {
+ return to_int16x16_m256i(_mm256_srai_epi16(a.v, offset));
+}
+
+template <>
inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
__m256i else_val) {
return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
@@ -88,45 +158,97 @@ inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
}
template <>
+inline int16x16_m256i SelectUsingMask(int16x16_m256i if_mask,
+ int16x16_m256i then_val,
+ int16x16_m256i else_val) {
+ // Borrowed from Intel's arm_neon_sse.h header.
+ return to_int16x16_m256i(
+ _mm256_or_si256(_mm256_and_si256(if_mask.v, then_val.v),
+ _mm256_andnot_si256(if_mask.v, else_val.v)));
+}
+
+template <>
inline __m256i MaskIfEqual(__m256i a, __m256i b) {
return _mm256_cmpeq_epi32(a, b);
}
template <>
+inline int16x16_m256i MaskIfEqual(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_cmpeq_epi16(a.v, b.v));
+}
+
+template <>
inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
return BitNot(MaskIfEqual(a, b));
}
template <>
+inline int16x16_m256i MaskIfNotEqual(int16x16_m256i a, int16x16_m256i b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
inline __m256i MaskIfZero(__m256i a) {
return MaskIfEqual(a, _mm256_set1_epi32(0));
}
template <>
+inline int16x16_m256i MaskIfZero(int16x16_m256i a) {
+ return MaskIfEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
+}
+
+template <>
inline __m256i MaskIfNonZero(__m256i a) {
return MaskIfNotEqual(a, _mm256_set1_epi32(0));
}
template <>
+inline int16x16_m256i MaskIfNonZero(int16x16_m256i a) {
+ return MaskIfNotEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
+}
+
+template <>
inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
return _mm256_cmpgt_epi32(a, b);
}
template <>
+inline int16x16_m256i MaskIfGreaterThan(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_cmpgt_epi16(a.v, b.v));
+}
+
+template <>
inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
return _mm256_cmpgt_epi32(b, a);
}
template <>
+inline int16x16_m256i MaskIfLessThan(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_cmpgt_epi16(b.v, a.v));
+}
+
+template <>
inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
return BitNot(MaskIfLessThan(a, b));
}
template <>
+inline int16x16_m256i MaskIfGreaterThanOrEqual(int16x16_m256i a,
+ int16x16_m256i b) {
+ return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
return BitNot(MaskIfGreaterThan(a, b));
}
+template <>
+inline int16x16_m256i MaskIfLessThanOrEqual(int16x16_m256i a,
+ int16x16_m256i b) {
+ return BitNot(MaskIfGreaterThan(a, b));
+}
+
/* Assumptions:
- All and Any are used on masks.
- masks are all_ones for true lanes, all_zeroes otherwise.
@@ -139,11 +261,21 @@ inline bool All(__m256i a) {
}
template <>
+inline bool All(int16x16_m256i a) {
+ return _mm256_testc_si256(a.v, a.v);
+}
+
+template <>
inline bool Any(__m256i a) {
return BitNot(_mm256_testz_si256(a, a));
}
template <>
+inline bool Any(int16x16_m256i a) {
+ return BitNot(_mm256_testz_si256(a.v, a.v));
+}
+
+template <>
inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
/* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
/* We divide the inputs before the add to avoid the overflow and costly test
@@ -171,6 +303,17 @@ inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
}
template <>
+inline int16x16_m256i RoundingHalfSum(int16x16_m256i a, int16x16_m256i b) {
+ // Borrowed from Intel's arm_neon_sse.h header.
+ __m256i constant_neg_32768 = _mm256_set1_epi16(-32768);
+ __m256i a_unsigned = _mm256_sub_epi16(a.v, constant_neg_32768);
+ __m256i b_unsigned = _mm256_sub_epi16(b.v, constant_neg_32768);
+ __m256i avg_unsigned = _mm256_avg_epu16(a_unsigned, b_unsigned);
+ __m256i avg = _mm256_add_epi16(avg_unsigned, constant_neg_32768);
+ return to_int16x16_m256i(avg);
+}
+
+template <>
inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
__m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
__m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
@@ -209,10 +352,33 @@ inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
}
template <>
+inline int16x16_m256i SaturatingRoundingDoublingHighMul(int16x16_m256i a,
+ int16x16_m256i b) {
+ // Use _mm256_mulhrs_epi16 then saturate with a bit-operation,
+ // borrowed from Intel's arm_neon_sse.h header.
+ __m256i result_unsaturated = _mm256_mulhrs_epi16(a.v, b.v);
+ __m256i saturation_mask =
+ _mm256_cmpeq_epi16(result_unsaturated, _mm256_set1_epi16(0x8000));
+ __m256i result = _mm256_xor_si256(result_unsaturated, saturation_mask);
+ return to_int16x16_m256i(result);
+}
+
+template <>
inline __m256i Dup<__m256i>(std::int32_t x) {
return _mm256_set1_epi32(x);
}
+template <>
+inline int16x16_m256i Dup<int16x16_m256i>(std::int16_t x) {
+ return to_int16x16_m256i(_mm256_set1_epi16(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x16_m256i SaturatingAdd(int16x16_m256i a, int16x16_m256i b) {
+ return to_int16x16_m256i(_mm256_adds_epi16(a.v, b.v));
+}
+
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index a1fae32..fbaa26a 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -32,13 +32,17 @@ namespace gemmlowp {
// data type, int16x8_m128i, that wraps __m128i while being a separate
// type.
struct int16x8_m128i {
- int16x8_m128i() {}
- explicit int16x8_m128i(__m128i w) : v(w) {}
- ~int16x8_m128i() {}
-
__m128i v;
};
+// Keep int16x8_m128i trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x8_m128i to_int16x8_m128i(__m128i w) {
+ int16x8_m128i r;
+ r.v = w;
+ return r;
+}
+
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
@@ -58,7 +62,7 @@ inline __m128i BitAnd(__m128i a, __m128i b) {
template <>
inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_and_si128(a.v, b.v));
+ return to_int16x8_m128i(_mm_and_si128(a.v, b.v));
}
template <>
@@ -68,7 +72,7 @@ inline __m128i BitOr(__m128i a, __m128i b) {
template <>
inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_or_si128(a.v, b.v));
+ return to_int16x8_m128i(_mm_or_si128(a.v, b.v));
}
template <>
@@ -78,7 +82,7 @@ inline __m128i BitXor(__m128i a, __m128i b) {
template <>
inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_xor_si128(a.v, b.v));
+ return to_int16x8_m128i(_mm_xor_si128(a.v, b.v));
}
template <>
@@ -88,7 +92,7 @@ inline __m128i BitNot(__m128i a) {
template <>
inline int16x8_m128i BitNot(int16x8_m128i a) {
- return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
+ return to_int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
}
template <>
@@ -98,7 +102,7 @@ inline __m128i Add(__m128i a, __m128i b) {
template <>
inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_add_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_add_epi16(a.v, b.v));
}
template <>
@@ -108,7 +112,7 @@ inline __m128i Mul(__m128i a, __m128i b) {
template <>
inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
}
template <>
@@ -118,7 +122,7 @@ inline __m128i Sub(__m128i a, __m128i b) {
template <>
inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_sub_epi16(a.v, b.v));
}
template <>
@@ -128,7 +132,7 @@ inline __m128i Neg(__m128i a) {
template <>
inline int16x8_m128i Neg(int16x8_m128i a) {
- return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
+ return to_int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
}
template <>
@@ -138,7 +142,7 @@ inline __m128i ShiftLeft(__m128i a, int offset) {
template <>
inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
- return int16x8_m128i(_mm_slli_epi16(a.v, offset));
+ return to_int16x8_m128i(_mm_slli_epi16(a.v, offset));
}
template <>
@@ -148,7 +152,7 @@ inline __m128i ShiftRight(__m128i a, int offset) {
template <>
inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
- return int16x8_m128i(_mm_srai_epi16(a.v, offset));
+ return to_int16x8_m128i(_mm_srai_epi16(a.v, offset));
}
template <>
@@ -164,7 +168,7 @@ inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
int16x8_m128i then_val,
int16x8_m128i else_val) {
// borrowed from Intel's arm_neon_sse.h header.
- return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
+ return to_int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}
template <>
@@ -174,7 +178,7 @@ inline __m128i MaskIfEqual(__m128i a, __m128i b) {
template <>
inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
}
template <>
@@ -194,7 +198,7 @@ inline __m128i MaskIfZero(__m128i a) {
template <>
inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
- return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+ return MaskIfEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
}
template <>
@@ -204,7 +208,7 @@ inline __m128i MaskIfNonZero(__m128i a) {
template <>
inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
- return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+ return MaskIfNotEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
}
template <>
@@ -214,7 +218,7 @@ inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
template <>
inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
}
template <>
@@ -224,7 +228,7 @@ inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
template <>
inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
}
template <>
@@ -310,7 +314,7 @@ inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
__m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
__m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
__m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
- return int16x8_m128i(avg);
+ return to_int16x8_m128i(avg);
}
template <>
@@ -360,7 +364,7 @@ inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
__m128i saturation_mask =
_mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
__m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
- return int16x8_m128i(result);
+ return to_int16x8_m128i(result);
}
template <>
@@ -370,13 +374,13 @@ inline __m128i Dup<__m128i>(std::int32_t x) {
template <>
inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
- return int16x8_m128i(_mm_set1_epi16(x));
+ return to_int16x8_m128i(_mm_set1_epi16(x));
}
// So far this is only needed for int16.
template <>
inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
- return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
+ return to_int16x8_m128i(_mm_adds_epi16(a.v, b.v));
}
} // end namespace gemmlowp
diff --git a/fixedpoint/fixedpoint_wasmsimd.h b/fixedpoint/fixedpoint_wasmsimd.h
new file mode 100644
index 0000000..868fbfe
--- /dev/null
+++ b/fixedpoint/fixedpoint_wasmsimd.h
@@ -0,0 +1,381 @@
+// Copyright 2020 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_wasmsimd.h: optimized WAsm SIMD specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
+
+#include <wasm_simd128.h>
+
+namespace gemmlowp {
+
+// WAsm SIMD intrinsics are not typed: there is a single v128_t vector
+// type that does not distinguish between "int32x4" and "int16x8" use
+// cases, unlike the NEON equivalents. Because we had initially focused
+// on int32x4, we did not pay attention and specialized these fixedpoint
+// templates directly for v128_t hardcoding the int32x4 semantics,
+// not leaving room for int16x8 semantics. Amending that by adding a separate
+// data type, int16x8_v128_t, that wraps v128_t while being a separate
+// type.
+struct int16x8_v128_t {
+ v128_t v;
+};
+
+// Keep int16x8_v128_t trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x8_v128_t to_int16x8_v128_t(v128_t w) {
+ int16x8_v128_t r;
+ r.v = w;
+ return r;
+}
+
+template <>
+struct FixedPointRawTypeTraits<v128_t> {
+ typedef std::int32_t ScalarRawType;
+ static constexpr int kLanes = 4;
+};
+
+template <>
+struct FixedPointRawTypeTraits<int16x8_v128_t> {
+ typedef std::int16_t ScalarRawType;
+ static constexpr int kLanes = 8;
+};
+
+template <>
+inline v128_t BitAnd(v128_t a, v128_t b) {
+ return wasm_v128_and(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitAnd(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_v128_and(a.v, b.v));
+}
+
+template <>
+inline v128_t BitOr(v128_t a, v128_t b) {
+ return wasm_v128_or(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitOr(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_v128_or(a.v, b.v));
+}
+
+template <>
+inline v128_t BitXor(v128_t a, v128_t b) {
+ return wasm_v128_xor(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitXor(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_v128_xor(a.v, b.v));
+}
+
+template <>
+inline v128_t BitNot(v128_t a) {
+ return wasm_v128_not(a);
+}
+
+template <>
+inline int16x8_v128_t BitNot(int16x8_v128_t a) {
+ return to_int16x8_v128_t(wasm_v128_not(a.v));
+}
+
+template <>
+inline v128_t Add(v128_t a, v128_t b) {
+ return wasm_i32x4_add(a, b);
+}
+
+template <>
+inline int16x8_v128_t Add(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_add(a.v, b.v));
+}
+
+template <>
+inline v128_t Mul(v128_t a, v128_t b) {
+ return wasm_i32x4_mul(a, b);
+}
+
+template <>
+inline int16x8_v128_t Mul(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_mul(a.v, b.v));
+}
+
+template <>
+inline v128_t Sub(v128_t a, v128_t b) {
+ return wasm_i32x4_sub(a, b);
+}
+
+template <>
+inline int16x8_v128_t Sub(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_sub(a.v, b.v));
+}
+
+template <>
+inline v128_t Neg(v128_t a) {
+ return wasm_i32x4_neg(a);
+}
+
+template <>
+inline int16x8_v128_t Neg(int16x8_v128_t a) {
+ return to_int16x8_v128_t(wasm_i16x8_neg(a.v));
+}
+
+template <>
+inline v128_t ShiftLeft(v128_t a, int offset) {
+ return wasm_i32x4_shl(a, offset);
+}
+
+template <>
+inline int16x8_v128_t ShiftLeft(int16x8_v128_t a, int offset) {
+ return to_int16x8_v128_t(wasm_i16x8_shl(a.v, offset));
+}
+
+template <>
+inline v128_t ShiftRight(v128_t a, int offset) {
+ return wasm_i32x4_shr(a, offset);
+}
+
+template <>
+inline int16x8_v128_t ShiftRight(int16x8_v128_t a, int offset) {
+ return to_int16x8_v128_t(wasm_i16x8_shr(a.v, offset));
+}
+
+template <>
+inline v128_t SelectUsingMask(v128_t if_mask, v128_t then_val,
+ v128_t else_val) {
+ return wasm_v128_bitselect(then_val, else_val, if_mask);
+}
+
+template <>
+inline int16x8_v128_t SelectUsingMask(int16x8_v128_t if_mask,
+ int16x8_v128_t then_val,
+ int16x8_v128_t else_val) {
+ return to_int16x8_v128_t(
+ wasm_v128_bitselect(then_val.v, else_val.v, if_mask.v));
+}
+
+template <>
+inline v128_t MaskIfEqual(v128_t a, v128_t b) {
+ return wasm_i32x4_eq(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfEqual(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_eq(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfNotEqual(v128_t a, v128_t b) {
+ return wasm_i32x4_ne(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfNotEqual(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_ne(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfZero(v128_t a) {
+ return MaskIfEqual(a, wasm_i32x4_const(0, 0, 0, 0));
+}
+
+template <>
+inline int16x8_v128_t MaskIfZero(int16x8_v128_t a) {
+ return MaskIfEqual(
+ a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0)));
+}
+
+template <>
+inline v128_t MaskIfNonZero(v128_t a) {
+ return MaskIfNotEqual(a, wasm_i32x4_const(0, 0, 0, 0));
+}
+
+template <>
+inline int16x8_v128_t MaskIfNonZero(int16x8_v128_t a) {
+ return MaskIfNotEqual(
+ a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0)));
+}
+
+template <>
+inline v128_t MaskIfGreaterThan(v128_t a, v128_t b) {
+ return wasm_i32x4_gt(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfGreaterThan(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_gt(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfLessThan(v128_t a, v128_t b) {
+ return wasm_i32x4_lt(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfLessThan(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_lt(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfGreaterThanOrEqual(v128_t a, v128_t b) {
+ return wasm_i32x4_ge(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfGreaterThanOrEqual(int16x8_v128_t a,
+ int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_ge(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfLessThanOrEqual(v128_t a, v128_t b) {
+ return wasm_i32x4_le(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfLessThanOrEqual(int16x8_v128_t a,
+ int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_le(a.v, b.v));
+}
+
+/* Assumptions:
+ - All and Any are used on masks.
+ - masks are all_ones for true lanes, all_zeroes otherwise.
+Hence, All means all 128bits set, and Any means any bit set.
+*/
+
+template <>
+inline bool All(v128_t a) {
+ return wasm_i32x4_all_true(a);
+}
+
+template <>
+inline bool All(int16x8_v128_t a) {
+ return wasm_i16x8_all_true(a.v);
+}
+
+template <>
+inline bool Any(v128_t a) {
+ return wasm_i32x4_any_true(a);
+}
+
+template <>
+inline bool Any(int16x8_v128_t a) {
+ return wasm_i16x8_any_true(a.v);
+}
+
+template <>
+inline v128_t RoundingHalfSum(v128_t a, v128_t b) {
+ // We divide the inputs before the add to avoid the overflow and costly test.
+ const v128_t one = wasm_i32x4_const(1, 1, 1, 1);
+ const v128_t sign_bit_mask =
+ wasm_i32x4_const(0x80000000, 0x80000000, 0x80000000, 0x80000000);
+ const v128_t sum = Add(a, b);
+ const v128_t rounded_half_sum = ShiftRight(Add(sum, one), 1);
+ const v128_t overflow =
+ BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
+ sign_bit_mask);
+ const v128_t result = BitXor(rounded_half_sum, overflow);
+ return result;
+}
+
+template <>
+inline int16x8_v128_t RoundingHalfSum(int16x8_v128_t a, int16x8_v128_t b) {
+ // Idea: go to unsigned to use wasm_u16x8_avgr,
+ // borrowed from Intel's arm_neon_sse.h header.
+ const v128_t constant_neg_32768 = wasm_i16x8_const(
+ -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768);
+ const v128_t a_unsigned = wasm_v128_xor(a.v, constant_neg_32768);
+ const v128_t b_unsigned = wasm_v128_xor(b.v, constant_neg_32768);
+ const v128_t avg_unsigned = wasm_u16x8_avgr(a_unsigned, b_unsigned);
+ const v128_t avg = wasm_v128_xor(avg_unsigned, constant_neg_32768);
+ return to_int16x8_v128_t(avg);
+}
+
+template <>
+inline v128_t SaturatingRoundingDoublingHighMul(v128_t a, v128_t b) {
+ // TODO: switch to extended multiplication once implemented in the toolchain
+ const v128_t a_sign = wasm_i32x4_shr(a, 31);
+ const v128_t b_sign = wasm_i32x4_shr(b, 31);
+
+ const v128_t a_ext_lo = wasm_v32x4_shuffle(a, a_sign, 0, 4, 1, 5);
+ const v128_t a_ext_hi = wasm_v32x4_shuffle(a, a_sign, 2, 6, 3, 7);
+ const v128_t b_ext_lo = wasm_v32x4_shuffle(b, b_sign, 0, 4, 1, 5);
+ const v128_t b_ext_hi = wasm_v32x4_shuffle(b, b_sign, 2, 6, 3, 7);
+
+ const v128_t ab_lo = wasm_i64x2_mul(a_ext_lo, b_ext_lo);
+ const v128_t ab_hi = wasm_i64x2_mul(a_ext_hi, b_ext_hi);
+
+ const v128_t nudge_2x =
+ wasm_i64x2_const(INT64_C(0x80000000), INT64_C(0x80000000));
+ const v128_t ab_lo_2x = wasm_i64x2_add(ab_lo, ab_lo);
+ const v128_t ab_hi_2x = wasm_i64x2_add(ab_hi, ab_hi);
+
+ const v128_t ab_lo_rounded_2x = wasm_i64x2_add(ab_lo_2x, nudge_2x);
+ const v128_t ab_hi_rounded_2x = wasm_i64x2_add(ab_hi_2x, nudge_2x);
+
+ const v128_t prod =
+ wasm_v32x4_shuffle(ab_lo_rounded_2x, ab_hi_rounded_2x, 1, 3, 5, 7);
+
+ // Saturation only happens if a == b == INT_MIN, and this is the only case
+ // where prod == INT_MIN (0x80000000) instead of INT_MAX (0x7FFFFFFF).
+ const v128_t min = wasm_i32x4_const(INT32_C(0x80000000), INT32_C(0x80000000),
+ INT32_C(0x80000000), INT32_C(0x80000000));
+
+ return wasm_v128_xor(prod, wasm_i32x4_eq(prod, min));
+}
+
+template <>
+inline int16x8_v128_t SaturatingRoundingDoublingHighMul(int16x8_v128_t a,
+ int16x8_v128_t b) {
+#if 0
+ // TODO: enable if https://github.com/WebAssembly/simd/pull/365 is accepted
+ return to_int16x8_v128_t(__builtin_wasm_q15mulr_saturate_s_i16x8(a.v, b.v));
+#else
+ // TODO: switch to extended multiplication once implemented in the toolchain
+ v128_t lo = wasm_i32x4_mul(wasm_i32x4_widen_low_i16x8(a.v),
+ wasm_i32x4_widen_low_i16x8(b.v));
+ v128_t hi = wasm_i32x4_mul(wasm_i32x4_widen_high_i16x8(a.v),
+ wasm_i32x4_widen_high_i16x8(b.v));
+ const v128_t inc = wasm_i32x4_const(0x4000, 0x4000, 0x4000, 0x4000);
+ lo = wasm_i32x4_add(lo, inc);
+ hi = wasm_i32x4_add(hi, inc);
+ lo = wasm_i32x4_shr(lo, 15);
+ hi = wasm_i32x4_shr(hi, 15);
+ return to_int16x8_v128_t(wasm_i16x8_narrow_i32x4(lo, hi));
+#endif
+}
+
+template <>
+inline v128_t Dup<v128_t>(std::int32_t x) {
+ return wasm_i32x4_splat(x);
+}
+
+template <>
+inline int16x8_v128_t Dup<int16x8_v128_t>(std::int16_t x) {
+ return to_int16x8_v128_t(wasm_i16x8_splat(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_v128_t SaturatingAdd(int16x8_v128_t a, int16x8_v128_t b) {
+ return to_int16x8_v128_t(wasm_i16x8_add_saturate(a.v, b.v));
+}
+
+} // end namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
diff --git a/flags.bzl b/flags.bzl
index 16dba2d..e35fe9e 100644
--- a/flags.bzl
+++ b/flags.bzl
@@ -3,10 +3,9 @@ LIB_COPTS = []
LIB_LINKOPTS = select({
":android": [],
+ ":windows": [],
"//conditions:default": ["-lpthread"],
})
-BIN_LINKOPTS = select({
- ":android": [],
- "//conditions:default": ["-lpthread"],
-})
+BIN_LINKOPTS = LIB_LINKOPTS
+
diff --git a/internal/allocator.h b/internal/allocator.h
index 3a6f077..e71df15 100644
--- a/internal/allocator.h
+++ b/internal/allocator.h
@@ -86,11 +86,11 @@ class Allocator {
}
// Alignment of allocated blocks.
- static const std::size_t kAlignment = kDefaultCacheLineSize;
+ static constexpr std::size_t kAlignment = kDefaultCacheLineSize;
// This is all we need so far, and since the usage pattern is fixed,
// there is no point in allowing more until we need to.
- static const std::size_t kMaxBlocks = 5;
+ static constexpr std::size_t kMaxBlocks = 5;
void Commit() {
assert(!committed_);
diff --git a/internal/common.h b/internal/common.h
index 332ad07..708cc40 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -165,7 +165,7 @@ Integer RoundUpToPowerOfTwo(Integer n) {
template <int N>
struct IsPowerOfTwo {
- static const bool value = !(N & (N - 1));
+ static constexpr bool value = !(N & (N - 1));
};
template <typename T>
diff --git a/internal/detect_platform.h b/internal/detect_platform.h
index 6f06d19..7f0d78c 100644
--- a/internal/detect_platform.h
+++ b/internal/detect_platform.h
@@ -71,6 +71,11 @@
#define GEMMLOWP_X86
#endif
+// Detect WebAssembly SIMD.
+#if defined(__wasm_simd128__)
+#define GEMMLOWP_WASMSIMD
+#endif
+
// Some of our optimized paths use inline assembly and for
// now we don't bother enabling some other optimized paths using intrinddics
// where we can't use inline assembly paths.
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
index ba4f341..b844f78 100644
--- a/internal/dispatch_gemm_shape.h
+++ b/internal/dispatch_gemm_shape.h
@@ -74,7 +74,8 @@ struct TransposeImpl<MatrixMap<Scalar, Order>> {
template <VectorShape Shape>
struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
- static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+ static constexpr VectorShape TransposedShape =
+ TransposeVectorShape<Shape>::Value;
typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
static DstType Run(const SrcType& src) {
DstType dst;
@@ -88,7 +89,8 @@ struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
template <VectorShape Shape>
struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
- static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+ static constexpr VectorShape TransposedShape =
+ TransposeVectorShape<Shape>::Value;
typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
DstType;
static DstType Run(const SrcType& src) {
diff --git a/internal/kernel.h b/internal/kernel.h
index 3120216..f1a3fd8 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -126,11 +126,11 @@ enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
// out in a cell. That is, a CellOrder together with actual dimensions.
template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
struct CellFormat {
- static const int kWidth = tWidth;
- static const int kDepth = tDepth;
- static const CellOrder kOrder = tOrder;
+ static constexpr int kWidth = tWidth;
+ static constexpr int kDepth = tDepth;
+ static constexpr CellOrder kOrder = tOrder;
- static const int kSize = kWidth * kDepth;
+ static constexpr int kSize = kWidth * kDepth;
};
// KernelSideFormat describes how data is laid out in a kernel side
@@ -142,9 +142,9 @@ struct CellFormat {
template <typename tCellFormat, int tCells>
struct KernelSideFormat {
typedef tCellFormat Cell;
- static const int kCells = tCells;
- static const int kWidth = kCells * Cell::kWidth;
- static const int kDepth = Cell::kDepth;
+ static constexpr int kCells = tCells;
+ static constexpr int kWidth = kCells * Cell::kWidth;
+ static constexpr int kDepth = Cell::kDepth;
typedef std::uint8_t Scalar; // The scalar type of the Format.
typedef std::uint8_t InputScalar; // The scalar type of the original input.
};
@@ -173,9 +173,9 @@ struct KernelFormat {
typedef tRhs Rhs;
static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
- static const int kDepth = Lhs::Cell::kDepth;
- static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
- static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
+ static constexpr int kDepth = Lhs::Cell::kDepth;
+ static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells;
+ static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells;
};
inline const char* CellOrderName(CellOrder o) {
diff --git a/internal/output_sse.h b/internal/output_sse.h
index 75aebfd..6ea3290 100644
--- a/internal/output_sse.h
+++ b/internal/output_sse.h
@@ -535,6 +535,27 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
}
};
+// Specialization for MatrixMap, for performance.
+template <typename tScalar, MapOrder tOrder>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> {
+ static void Run(const RegBlockUint8<8, 8>& src,
+ MatrixMap<tScalar, tOrder>* dst, int row, int col) {
+ std::uint8_t buf[64];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ StoreUint8x16(buf + 16, src.buf.reg[1]);
+ StoreUint8x16(buf + 32, src.buf.reg[2]);
+ StoreUint8x16(buf + 48, src.buf.reg[3]);
+ // Make a local copy so that the compiler can prove that data_ does not
+ // alias &data_ or &stride_.
+ MatrixMap<tScalar, tOrder> local = *dst;
+ for (int c = 0; c < 8; c++) {
+ for (int r = 0; r < 8; r++) {
+ *local.data(row + r, col + c) = buf[r + 8 * c];
+ }
+ }
+ }
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
diff --git a/internal/pack.h b/internal/pack.h
index 7c43d6e..82f0dd1 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -143,7 +143,7 @@ template <typename tScalar, SideMapOrder tOrder>
class SideMap {
public:
typedef tScalar Scalar;
- static const SideMapOrder kOrder = tOrder;
+ static constexpr SideMapOrder kOrder = tOrder;
SideMap(Scalar* data, int width, int depth, int stride)
: data_(data), width_(width), depth_(depth), stride_(stride) {}
@@ -214,13 +214,13 @@ class PackingRegisterBlockBase {
typedef typename KernelSideFormat::Cell CellFormat;
typedef typename KernelSideFormat::InputScalar KernelInputScalar;
typedef typename KernelSideFormat::Scalar KernelScalar;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
- static const SideMapOrder kSrcOrder = SrcMapType::kOrder;
- static const int kZeroPointInputValue =
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static constexpr int kCellWidth = CellFormat::kWidth;
+ static constexpr int kKernelWidth = CellFormat::kWidth * kCells;
+ static constexpr int kCellDepth = CellFormat::kDepth;
+ static constexpr int kCellSize = CellFormat::kSize;
+ static constexpr SideMapOrder kSrcOrder = SrcMapType::kOrder;
+ static constexpr int kZeroPointInputValue =
ZeroPointInputValue<KernelInputScalar, KernelScalar>::kValue;
PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {}
@@ -302,10 +302,10 @@ class PackSideBlockImpl {
public:
typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static constexpr int kCellWidth = CellFormat::kWidth;
+ static constexpr int kKernelWidth = CellFormat::kWidth * kCells;
+ static constexpr int kCellDepth = CellFormat::kDepth;
typedef PackingRegisterBlock<SrcMapType, PackedSideBlock>
PackingRegisterBlockType;
diff --git a/internal/pack_sse.h b/internal/pack_sse.h
index 52163c4..b729014 100644
--- a/internal/pack_sse.h
+++ b/internal/pack_sse.h
@@ -41,11 +41,11 @@ class PackingRegisterBlock<
public:
typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static constexpr int kCellWidth = CellFormat::kWidth;
+ static constexpr int kKernelWidth = CellFormat::kWidth * kCells;
+ static constexpr int kCellDepth = CellFormat::kDepth;
+ static constexpr int kCellSize = CellFormat::kSize;
void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
std::uint8_t* dst_ptr = dst->current_data();
diff --git a/internal/platform.h b/internal/platform.h
index 54517c3..0f3a2b8 100644
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -30,8 +30,7 @@
#include <sys/time.h>
#endif
-#if defined __ANDROID__
-#include <android/api-level.h>
+#if defined ANDROID || defined __ANDROID__
#include <malloc.h>
// The 18 here should be 16, but has to be 18 for now due
// to a Google-internal issue.
diff --git a/meta/generators/cc_emitter.py b/meta/generators/cc_emitter.py
index 8615671..c1dc75d 100644
--- a/meta/generators/cc_emitter.py
+++ b/meta/generators/cc_emitter.py
@@ -52,16 +52,16 @@ class CCEmitter(object):
self.indent = self.indent[:-2]
def EmitIndented(self, what):
- print self.indent + what
+ print(self.indent + what)
def EmitNewline(self):
- print ''
+ print('')
def EmitPreprocessor1(self, op, param):
- print '#%s %s' % (op, param)
+ print('#%s %s' % (op, param))
def EmitPreprocessor(self, op):
- print '#%s' % op
+ print('#%s' % op)
def EmitInclude(self, include):
self.EmitPreprocessor1('include', include)
diff --git a/meta/generators/common.py b/meta/generators/common.py
index d680372..7269b50 100644
--- a/meta/generators/common.py
+++ b/meta/generators/common.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""."""
+import collections
_HEADER_COPYRIGHT = (
'''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
@@ -71,7 +72,7 @@ class StreamGenerator(object):
self.emitter = emitter
def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers):
- if callable(getattr(self, 'EmitPack', None)):
+ if isinstance(getattr(self, 'EmitPack', None), collections.Callable):
template_params = [in_type, lanes_count, pack_size, leftovers, self.name]
self.emitter.EmitMemberFunctionBegin(
'Stream', [], template_params, 'Pack',
diff --git a/meta/generators/neon_emitter.py b/meta/generators/neon_emitter.py
index 726766e..0304317 100644
--- a/meta/generators/neon_emitter.py
+++ b/meta/generators/neon_emitter.py
@@ -187,7 +187,7 @@ class NeonEmitter(object):
self.indent = self.indent[:-delta]
def EmitIndented(self, what):
- print self.indent + what
+ print(self.indent + what)
def PushOp(self, op):
if op in self.ops.keys():
@@ -199,13 +199,13 @@ class NeonEmitter(object):
self.ops.clear()
def EmitNewline(self):
- print ''
+ print('')
def EmitPreprocessor1(self, op, param):
- print '#%s %s' % (op, param)
+ print('#%s %s' % (op, param))
def EmitPreprocessor(self, op):
- print '#%s' % op
+ print('#%s' % op)
def EmitInclude(self, include):
self.EmitPreprocessor1('include', include)
diff --git a/meta/generators/neon_emitter_64.py b/meta/generators/neon_emitter_64.py
index 13a0715..956b16b 100644
--- a/meta/generators/neon_emitter_64.py
+++ b/meta/generators/neon_emitter_64.py
@@ -423,7 +423,7 @@ class NeonEmitter64(object):
self.indent = self.indent[:-delta]
def EmitIndented(self, what):
- print self.indent + what
+ print(self.indent + what)
def PushOp(self, op):
if op in self.ops.keys():
@@ -435,13 +435,13 @@ class NeonEmitter64(object):
self.ops.clear()
def EmitNewline(self):
- print ''
+ print('')
def EmitPreprocessor1(self, op, param):
- print '#%s %s' % (op, param)
+ print('#%s %s' % (op, param))
def EmitPreprocessor(self, op):
- print '#%s' % op
+ print('#%s' % op)
def EmitInclude(self, include):
self.EmitPreprocessor1('include', include)
diff --git a/public/bit_depth.h b/public/bit_depth.h
index 412944e..5b19430 100644
--- a/public/bit_depth.h
+++ b/public/bit_depth.h
@@ -22,8 +22,8 @@ namespace gemmlowp {
// The range of allowed values for an operand.
template <int tMinValue, int tMaxValue>
struct OperandRange {
- static const int kMinValue = tMinValue;
- static const int kMaxValue = tMaxValue;
+ static constexpr int kMinValue = tMinValue;
+ static constexpr int kMaxValue = tMaxValue;
static_assert(kMinValue < kMaxValue, "");
};
diff --git a/public/map.h b/public/map.h
index fe6bc5c..1b71f9e 100644
--- a/public/map.h
+++ b/public/map.h
@@ -32,7 +32,7 @@ template <typename tScalar, MapOrder tOrder>
class MatrixMap {
public:
typedef tScalar Scalar;
- static const MapOrder kOrder = tOrder;
+ static constexpr MapOrder kOrder = tOrder;
protected:
Scalar* data_; // not owned.
@@ -84,7 +84,7 @@ template <typename tScalar, VectorShape tShape>
class VectorMap {
public:
typedef tScalar Scalar;
- static const VectorShape kShape = tShape;
+ static constexpr VectorShape kShape = tShape;
protected:
Scalar* data_; // not owned.
@@ -113,7 +113,7 @@ template <typename tScalar, VectorShape tShape>
class VectorDup {
public:
typedef tScalar Scalar;
- static const VectorShape kShape = tShape;
+ static constexpr VectorShape kShape = tShape;
protected:
Scalar data_;
diff --git a/standalone/cache_counters.cc b/standalone/cache_counters.cc
new file mode 100644
index 0000000..24e971c
--- /dev/null
+++ b/standalone/cache_counters.cc
@@ -0,0 +1,404 @@
+#include <asm/unistd.h>
+#include <linux/perf_event.h>
+#include <sys/ioctl.h>
+#include <unistd.h>
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <random>
+
+#ifndef __aarch64__
+#error This program is for 64-bit ARM only.
+#endif
+
+struct PerfEvent {
+ perf_event_attr pe;
+ int fd = -1;
+
+ PerfEvent(std::uint32_t type, std::uint64_t config) {
+ memset(&pe, 0, sizeof(pe));
+ pe.size = sizeof(pe);
+ pe.type = type;
+ pe.config = config;
+ pe.disabled = 1;
+ pe.exclude_kernel = 1;
+ pe.exclude_hv = 1;
+ fd = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0);
+ if (fd == -1) {
+ fprintf(stderr, "perf_event_open failed for config 0x%lx\n", config);
+ abort();
+ }
+ }
+
+ void Start() {
+ ioctl(fd, PERF_EVENT_IOC_RESET, 0);
+ ioctl(fd, PERF_EVENT_IOC_ENABLE, 0);
+ }
+
+ std::int64_t Stop() {
+ ioctl(fd, PERF_EVENT_IOC_DISABLE, 0);
+ std::int64_t count = 0;
+ read(fd, &count, sizeof(count));
+ return count;
+ }
+
+ ~PerfEvent() { close(fd); }
+};
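The PerfEvent wrapper above is not specific to the ARM PMU raw events used below; it can count any perf event around arbitrary code. A usage sketch, assuming a hypothetical DoWork() as the code under measurement (PERF_TYPE_HARDWARE and PERF_COUNT_HW_INSTRUCTIONS are standard <linux/perf_event.h> constants):

    void DoWork();  // hypothetical placeholder for the code being measured

    void ExampleUsage() {
      PerfEvent instructions(PERF_TYPE_HARDWARE, PERF_COUNT_HW_INSTRUCTIONS);
      instructions.Start();
      DoWork();
      printf("instructions retired: %lld\n",
             static_cast<long long>(instructions.Stop()));
    }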
+
+struct ArmPmuEvent : PerfEvent {
+ static constexpr std::uint16_t L1I_CACHE_REFILL = 0x01;
+ static constexpr std::uint16_t L1I_TLB_REFILL = 0x02;
+ static constexpr std::uint16_t L1D_CACHE_REFILL = 0x03;
+ static constexpr std::uint16_t L1D_CACHE = 0x04;
+ static constexpr std::uint16_t L1D_TLB_REFILL = 0x05;
+ static constexpr std::uint16_t LD_RETIRED = 0x06;
+ static constexpr std::uint16_t ST_RETIRED = 0x07;
+ static constexpr std::uint16_t INST_RETIRED = 0x08;
+ static constexpr std::uint16_t EXC_TAKEN = 0x09;
+ static constexpr std::uint16_t EXC_RETURN = 0x0A;
+ static constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B;
+ static constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C;
+ static constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D;
+ static constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E;
+ static constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F;
+ static constexpr std::uint16_t BR_MIS_PRED = 0x10;
+ static constexpr std::uint16_t CPU_CYCLES = 0x11;
+ static constexpr std::uint16_t BR_PRED = 0x12;
+ static constexpr std::uint16_t MEM_ACCESS = 0x13;
+ static constexpr std::uint16_t L1I_CACHE = 0x14;
+ static constexpr std::uint16_t L1D_CACHE_WB = 0x15;
+ static constexpr std::uint16_t L2D_CACHE = 0x16;
+ static constexpr std::uint16_t L2D_CACHE_REFILL = 0x17;
+ static constexpr std::uint16_t L2D_CACHE_WB = 0x18;
+ static constexpr std::uint16_t BUS_ACCESS = 0x19;
+ static constexpr std::uint16_t MEMORY_ERROR = 0x1A;
+ static constexpr std::uint16_t INST_SPEC = 0x1B;
+ static constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C;
+ static constexpr std::uint16_t BUS_CYCLES = 0x1D;
+ static constexpr std::uint16_t CHAIN = 0x1E;
+ static constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F;
+ static constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20;
+ static constexpr std::uint16_t BR_RETIRED = 0x21;
+ static constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22;
+ static constexpr std::uint16_t STALL_FRONTEND = 0x23;
+ static constexpr std::uint16_t STALL_BACKEND = 0x24;
+ static constexpr std::uint16_t L1D_TLB = 0x25;
+ static constexpr std::uint16_t L1I_TLB = 0x26;
+ static constexpr std::uint16_t L2I_CACHE = 0x27;
+ static constexpr std::uint16_t L2I_CACHE_REFILL = 0x28;
+ static constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29;
+ static constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A;
+ static constexpr std::uint16_t L3D_CACHE = 0x2B;
+ static constexpr std::uint16_t L3D_CACHE_WB = 0x2C;
+ static constexpr std::uint16_t L2D_TLB_REFILL = 0x2D;
+ static constexpr std::uint16_t L2I_TLB_REFILL = 0x2E;
+ static constexpr std::uint16_t L2D_TLB = 0x2F;
+ static constexpr std::uint16_t L2I_TLB = 0x30;
+ static constexpr std::uint16_t LL_CACHE = 0x32;
+ static constexpr std::uint16_t LL_CACHE_MISS = 0x33;
+ static constexpr std::uint16_t DTLB_WALK = 0x34;
+ static constexpr std::uint16_t LL_CACHE_RD = 0x36;
+ static constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37;
+ static constexpr std::uint16_t L1D_CACHE_RD = 0x40;
+ static constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42;
+ static constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C;
+ static constexpr std::uint16_t L1D_TLB_RD = 0x4E;
+ static constexpr std::uint16_t L2D_CACHE_RD = 0x50;
+ static constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52;
+ static constexpr std::uint16_t BUS_ACCESS_RD = 0x60;
+ static constexpr std::uint16_t MEM_ACCESS_RD = 0x66;
+ static constexpr std::uint16_t L3D_CACHE_RD = 0xA0;
+ static constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2;
+ ArmPmuEvent(std::uint16_t number) : PerfEvent(PERF_TYPE_RAW, number) {}
+};
+
+struct CacheCounts {
+ int ld_retired = 0;
+ int mem_access = 0;
+ int ll_cache = 0;
+ int ll_cache_miss = 0;
+ int l1d_cache = 0;
+ int l1d_cache_refill = 0;
+ int l2d_cache = 0;
+ int l2d_cache_refill = 0;
+ int l3d_cache = 0;
+ int l3d_cache_refill = 0;
+};
+
+void PrintCacheCounts(const CacheCounts& cache_counts) {
+ printf("ld_retired = %d\n", cache_counts.ld_retired);
+ printf("mem_access = %d\n", cache_counts.mem_access);
+ printf("ll_cache = %d\n", cache_counts.ll_cache);
+ printf("ll_cache_miss = %d\n", cache_counts.ll_cache_miss);
+ printf("l1d_cache = %d\n", cache_counts.l1d_cache);
+ printf("l1d_cache_refill = %d\n", cache_counts.l1d_cache_refill);
+ printf("l2d_cache = %d\n", cache_counts.l2d_cache);
+ printf("l2d_cache_refill = %d\n", cache_counts.l2d_cache_refill);
+ printf("l3d_cache = %d\n", cache_counts.l3d_cache);
+ printf("l3d_cache_refill = %d\n", cache_counts.l3d_cache_refill);
+}
+
+void Workload(int accesses, int size, std::uint8_t* buf) {
+ // The main reason to do this in assembly is an attempt to make sense
+ // of instruction count counters, such as LD_RETIRED.
+ // Also, if we did this in C++, we would need to be watchful of the compiler
+ // optimizing away operations whose result isn't consumed.
+ //
+ // Note that TWO separate tricks are needed here to prevent Cortex-A76
+ // speculative execution from prefetching data from future loop iterations:
+ // 1. A data-dependency whereby the pointers being dereferenced at the
+ // next loop iteration depend on values loaded at the current iteration.
+ // That is the role of 'dummy'.
+ // 2. A pseudo-random sequence. This is the role of register w0,
+ // where we implement a simple xorshift pseudorandom generator.
+ // BOTH of these tricks are needed: if we disable just one of them,
+ // Cortex-A76 successfully speculates some addresses, resulting in different
+ // L3 / DRAM hit percentages on large sizes.
+ std::uint64_t dummy = 123456789;
+ asm volatile(
+ // w0 := xorshift RNG state. Must be nonzero.
+ "mov w0, #1\n"
+ "1:\n"
+ // xorshift RNG iteration: update w0 with the next pseudorandom value
+ // in [1 .. 2^32-1].
+ // This pseudorandomness is crucial to preventing speculative execution
+ // on Cortex-A76 from prefetching data from future loop iterations.
+ "eor w0, w0, w0, lsl #13\n"
+ "eor w0, w0, w0, lsr #17\n"
+ "eor w0, w0, w0, lsl #5\n"
+ // w1 := size - 1 = size mask (size is required to be power-of-two).
+ "sub w1, %w[size], #1\n"
+ // w2 := (pseudorandom value w0) xor (data-dependent sum).
+ "eor w2, w0, %w[dummy]\n"
+ // w1 := w2 modulo size
+ "and w1, w2, w1\n"
+ // align w1
+ "and w1, w1, #-64\n"
+ // load at offset w1, reusing x1 as the destination register (so that
+ // w1 now holds the low 32 bits of the loaded value).
+ "ldr x1, [%[buf], w1, uxtw]\n"
+ // Update our dummy so it depends on the value we have just loaded.
+ // This data-dependency is key to preventing speculative execution on
+ // Cortex-A76 from prefetching data from future loop iterations.
+ "add %[dummy], %[dummy], w1, uxtw\n"
+ // loop back.
+ "subs %w[accesses], %w[accesses], #1\n"
+ "bne 1b\n"
+ : [ accesses ] "+r"(accesses), [ dummy ] "+r"(dummy)
+ : [ size ] "r"(size), [ buf ] "r"(buf)
+ : "memory", "cc", "x0", "x1", "x2");
+}
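For readers who prefer not to decode the assembly, the loop above is roughly equivalent to the following C++ (a sketch only; the real code stays in assembly so that LD_RETIRED counts exactly one load per iteration and the compiler cannot reorder, elide, or vectorize anything):

    void WorkloadSketch(int accesses, int size, std::uint8_t* buf) {
      // size must be a power of two, as in Workload() above.
      std::uint64_t dummy = 123456789;
      std::uint32_t rng = 1;  // xorshift state, must be nonzero
      for (; accesses > 0; accesses--) {
        // xorshift RNG iteration.
        rng ^= rng << 13;
        rng ^= rng >> 17;
        rng ^= rng << 5;
        // Pseudorandom, data-dependent, cache-line-aligned offset into buf.
        std::uint32_t offset =
            ((rng ^ static_cast<std::uint32_t>(dummy)) & (size - 1)) & ~63u;
        std::uint64_t loaded;
        memcpy(&loaded, buf + offset, sizeof(loaded));
        // The next offset depends on this load, defeating address speculation.
        dummy += static_cast<std::uint32_t>(loaded);
      }
    }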
+
+void MeasureCacheCounts(int accesses, int size, std::uint8_t* buf,
+ CacheCounts* cache_counts) {
+ const bool only_reads = getenv("ONLY_READS");
+ ArmPmuEvent ld_retired(ArmPmuEvent::LD_RETIRED);
+ ArmPmuEvent mem_access(only_reads ? ArmPmuEvent::MEM_ACCESS_RD
+ : ArmPmuEvent::MEM_ACCESS);
+ ArmPmuEvent ll_cache(only_reads ? ArmPmuEvent::LL_CACHE_RD
+ : ArmPmuEvent::LL_CACHE);
+ ArmPmuEvent ll_cache_miss(only_reads ? ArmPmuEvent::LL_CACHE_MISS_RD
+ : ArmPmuEvent::LL_CACHE_MISS);
+ ArmPmuEvent l1d_cache(only_reads ? ArmPmuEvent::L1D_CACHE_RD
+ : ArmPmuEvent::L1D_CACHE);
+ ArmPmuEvent l1d_cache_refill(only_reads ? ArmPmuEvent::L1D_CACHE_REFILL_RD
+ : ArmPmuEvent::L1D_CACHE_REFILL);
+ ArmPmuEvent l2d_cache(only_reads ? ArmPmuEvent::L2D_CACHE_RD
+ : ArmPmuEvent::L2D_CACHE);
+ ArmPmuEvent l2d_cache_refill(only_reads ? ArmPmuEvent::L2D_CACHE_REFILL_RD
+ : ArmPmuEvent::L2D_CACHE_REFILL);
+ ArmPmuEvent l3d_cache(only_reads ? ArmPmuEvent::L3D_CACHE_RD
+ : ArmPmuEvent::L3D_CACHE);
+ ArmPmuEvent l3d_cache_refill(only_reads ? ArmPmuEvent::L3D_CACHE_REFILL_RD
+ : ArmPmuEvent::L3D_CACHE_REFILL);
+
+ ld_retired.Start();
+ mem_access.Start();
+ ll_cache.Start();
+ ll_cache_miss.Start();
+ l1d_cache.Start();
+ l1d_cache_refill.Start();
+ l2d_cache.Start();
+ l2d_cache_refill.Start();
+ l3d_cache.Start();
+ l3d_cache_refill.Start();
+
+ Workload(accesses, size, buf);
+
+ cache_counts->ld_retired = ld_retired.Stop();
+ cache_counts->mem_access = mem_access.Stop();
+ cache_counts->ll_cache = ll_cache.Stop();
+ cache_counts->ll_cache_miss = ll_cache_miss.Stop();
+ cache_counts->l1d_cache = l1d_cache.Stop();
+ cache_counts->l1d_cache_refill = l1d_cache_refill.Stop();
+ cache_counts->l2d_cache = l2d_cache.Stop();
+ cache_counts->l2d_cache_refill = l2d_cache_refill.Stop();
+ cache_counts->l3d_cache = l3d_cache.Stop();
+ cache_counts->l3d_cache_refill = l3d_cache_refill.Stop();
+}
+
+struct PieChart {
+ // How many accesses were recorded, total? The other fields must sum to that.
+ int total;
+ // How many accesses were serviced with the typical cost of a L1 cache hit?
+ int l1_hits;
+ // How many accesses were serviced with the typical cost of a L2 cache hit?
+ int l2_hits;
+ // How many accesses were serviced with the typical cost of a L3 cache hit?
+ int l3_hits;
+ // How many accesses were serviced with the typical cost of a DRAM access?
+ int dram_hits;
+
+ ~PieChart() {
+ // Consistency check
+ if (total != l1_hits + l2_hits + l3_hits + dram_hits) {
+ fprintf(stderr, "inconsistent pie-chart\n");
+ abort();
+ }
+ }
+};
+
+struct Hypothesis {
+ virtual ~Hypothesis() {}
+ virtual const char* Name() const = 0;
+ virtual void Analyze(const CacheCounts& cache_counts,
+ PieChart* pie) const = 0;
+};
+
+struct Hypothesis1 : Hypothesis {
+ const char* Name() const override { return "Hypothesis1"; }
+ void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override {
+ pie->total = cache_counts.l1d_cache + cache_counts.l1d_cache_refill;
+ pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache_refill -
+ cache_counts.l3d_cache_refill;
+ pie->l2_hits = cache_counts.l1d_cache_refill;
+ pie->l3_hits = cache_counts.l2d_cache_refill;
+ pie->dram_hits = cache_counts.l3d_cache_refill;
+ }
+};
+
+struct Hypothesis2 : Hypothesis {
+ const char* Name() const override { return "Hypothesis2"; }
+ void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override {
+ pie->total = cache_counts.l1d_cache;
+ pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache;
+ pie->l2_hits = cache_counts.l2d_cache - cache_counts.l3d_cache;
+ pie->l3_hits = cache_counts.l3d_cache - cache_counts.l3d_cache_refill;
+ pie->dram_hits = cache_counts.l3d_cache_refill;
+ }
+};
+
+struct Hypothesis3 : Hypothesis {
+ const char* Name() const override { return "Hypothesis3"; }
+ void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override {
+ pie->total = cache_counts.l1d_cache;
+ int corrected_l2 = std::min(cache_counts.l2d_cache, cache_counts.l1d_cache);
+ int corrected_l3 = std::min(cache_counts.l3d_cache, corrected_l2);
+ pie->l1_hits = cache_counts.l1d_cache - corrected_l2;
+ pie->l2_hits = corrected_l2 - corrected_l3;
+ pie->l3_hits = corrected_l3 - cache_counts.l3d_cache_refill;
+ pie->dram_hits = cache_counts.l3d_cache_refill;
+ }
+};
+
+struct Hypothesis4 : Hypothesis {
+ const char* Name() const override { return "Hypothesis4"; }
+ void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override {
+ pie->total = cache_counts.l1d_cache;
+ pie->l1_hits = cache_counts.l1d_cache - cache_counts.l1d_cache_refill;
+ pie->l2_hits =
+ cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill;
+ pie->l3_hits =
+ cache_counts.l2d_cache_refill - cache_counts.l3d_cache_refill;
+ pie->dram_hits = cache_counts.l3d_cache_refill;
+ }
+};
+
+struct Hypothesis5 : Hypothesis {
+ const char* Name() const override { return "Hypothesis5"; }
+ void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override {
+ pie->l1_hits =
+ std::max(0, cache_counts.l1d_cache - cache_counts.l1d_cache_refill);
+ pie->l2_hits = std::max(
+ 0, cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill);
+ const int l3_misses =
+ std::max(cache_counts.ll_cache_miss, cache_counts.l3d_cache_refill);
+ pie->l3_hits = std::max(0, cache_counts.l2d_cache_refill - l3_misses);
+ pie->dram_hits = l3_misses;
+ pie->total = pie->l1_hits + pie->l2_hits + pie->l3_hits + pie->dram_hits;
+ }
+};
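To make the hypotheses concrete, here is a worked example with made-up counter values (the numbers are purely illustrative, not measured data). Hypothesis4 treats each refill at one cache level as a hit at the next level down:

    void HypothesisExample() {
      CacheCounts counts;
      counts.l1d_cache = 1000;        // total L1D accesses
      counts.l1d_cache_refill = 200;  // of which 200 missed L1
      counts.l2d_cache_refill = 50;   // of which 50 also missed L2
      counts.l3d_cache_refill = 10;   // of which 10 also missed L3
      PieChart pie;
      Hypothesis4().Analyze(counts, &pie);
      // Resulting breakdown: pie.l1_hits == 800, pie.l2_hits == 150,
      // pie.l3_hits == 40, pie.dram_hits == 10, pie.total == 1000,
      // which satisfies the consistency check in ~PieChart().
    }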
+
+void PrintPieChart(const PieChart& pie) {
+ printf("total accesses: %d\n", pie.total);
+ double l1_hits_pct = 100. * pie.l1_hits / pie.total;
+ double l2_hits_pct = 100. * pie.l2_hits / pie.total;
+ double l3_hits_pct = 100. * pie.l3_hits / pie.total;
+ double dram_hits_pct = 100. * pie.dram_hits / pie.total;
+ printf("L1 hits: %.2f%%\n", l1_hits_pct);
+ printf("L2 hits: %.2f%%\n", l2_hits_pct);
+ printf("L1/2 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct);
+ printf("L3 hits: %.2f%%\n", l3_hits_pct);
+ printf("L1/2/3 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct + l3_hits_pct);
+ printf("DRAM hits: %.2f%%\n", dram_hits_pct);
+}
+
+void PrintPieChartCsvNoNewline(const PieChart& pie) {
+ double l1_hits_pct = 100. * pie.l1_hits / pie.total;
+ double l2_hits_pct = 100. * pie.l2_hits / pie.total;
+ double l3_hits_pct = 100. * pie.l3_hits / pie.total;
+ double dram_hits_pct = 100. * pie.dram_hits / pie.total;
+ printf("%.2f,%.2f,%.2f,%.2f", l1_hits_pct, l2_hits_pct, l3_hits_pct,
+ dram_hits_pct);
+}
+
+void Study(int accesses, int size, std::uint8_t* buf) {
+ CacheCounts cache_counts;
+ MeasureCacheCounts(accesses, size, buf, &cache_counts);
+ const Hypothesis* hypotheses[] = {
+ new Hypothesis5, new Hypothesis4, new Hypothesis3,
+ new Hypothesis2, new Hypothesis1,
+ };
+ if (getenv("DUMP_CSV")) {
+ printf("%d", size);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ printf(",");
+ PieChart pie;
+ hypothesis->Analyze(cache_counts, &pie);
+ PrintPieChartCsvNoNewline(pie);
+ }
+ printf("\n");
+ } else {
+ printf("\n\n\naccesses=%d, size=%d:\n", accesses, size);
+ printf("\nCache counts:\n");
+ PrintCacheCounts(cache_counts);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ printf("\n%s:\n", hypothesis->Name());
+ PieChart pie;
+ hypothesis->Analyze(cache_counts, &pie);
+ PrintPieChart(pie);
+ }
+ }
+ fflush(stdout);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ delete hypothesis;
+ }
+}
+
+int main() {
+ const int kMinSize = 1 << 12;
+ const int kMaxSize = 1 << 24;
+ const int kAccesses = 1e8;
+ void* buf_void = nullptr;
+ posix_memalign(&buf_void, 64, kMaxSize);
+ std::uint8_t* buf = static_cast<std::uint8_t*>(buf_void);
+ std::default_random_engine random_engine;
+ for (int i = 0; i < kMaxSize; i++) {
+ buf[i] = random_engine();
+ }
+ for (int size = kMinSize; size <= kMaxSize; size *= 2) {
+ Study(kAccesses, size, buf);
+ }
+ free(buf);  // buf was allocated with posix_memalign, so free() is the matching call.
+}
diff --git a/standalone/encode.py b/standalone/encode.py
new file mode 100644
index 0000000..c192ab9
--- /dev/null
+++ b/standalone/encode.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The gemmlowp Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Encodes ARM asm code for certain instructions into the corresponding machine code encoding, as a .word directive in the asm code, preserving the original code in a comment.
+
+Reads from stdin, writes to stdout.
+
+Example diff:
+- "udot v16.4s, v4.16b, v0.16b\n"
++ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
+
+The intended use case is to make asm code easier to compile on toolchains that
+do not support certain new instructions.
+"""
+
+import sys
+import re
+import argparse
+
+
+def encode_udot_sdot_vector(line):
+ m = re.search(
+ r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b',
+ line)
+ if not m:
+ return 0, line
+
+ match = m.group(0)
+ unsigned = 1 if m.group(1) == 'u' else 0
+ accum = int(m.group(2))
+ lhs = int(m.group(3))
+ rhs = int(m.group(4))
+ assert accum >= 0 and accum <= 31
+ assert lhs >= 0 and lhs <= 31
+ assert rhs >= 0 and rhs <= 31
+ mcode = 0x4e809400 | (accum << 0) | (lhs << 5) | (rhs << 16) | (
+ unsigned << 29)
+ return mcode, match
+
+
+def encode_udot_sdot_element(line):
+ m = re.search(
+ r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*4b[ ]*\[([0-9])\]',
+ line)
+ if not m:
+ return 0, line
+
+ match = m.group(0)
+ unsigned = 1 if m.group(1) == 'u' else 0
+ accum = int(m.group(2))
+ lhs = int(m.group(3))
+ rhs = int(m.group(4))
+ lanegroup = int(m.group(5))
+ assert accum >= 0 and accum <= 31
+ assert lhs >= 0 and lhs <= 31
+ assert rhs >= 0 and rhs <= 31
+ assert lanegroup >= 0 and lanegroup <= 3
+ l = 1 if lanegroup & 1 else 0
+ h = 1 if lanegroup & 2 else 0
+ mcode = 0x4f80e000 | (accum << 0) | (lhs << 5) | (rhs << 16) | (l << 21) | (
+ h << 11) | (
+ unsigned << 29)
+ return mcode, match
+
+
+def encode(line):
+ for encode_func in [encode_udot_sdot_vector, encode_udot_sdot_element]:
+ mcode, match = encode_func(line)
+ if mcode:
+ return mcode, match
+ return 0, line
+
+
+def read_existing_encoding(line):
+ m = re.search(r'\.word\ (0x[0-9a-f]+)', line)
+ if m:
+ return int(m.group(1), 16)
+ return 0
+
+
+parser = argparse.ArgumentParser(description='Encode some A64 instructions.')
+parser.add_argument(
+ '-f',
+ '--fix',
+ help='fix existing wrong encodings in-place and continue',
+ action='store_true')
+args = parser.parse_args()
+
+lineno = 0
+found_existing_encodings = False
+found_error = False
+found_fixes = False
+for line in sys.stdin:
+ lineno = lineno + 1
+ mcode, match = encode(line)
+ if mcode:
+ existing_encoding = read_existing_encoding(line)
+ if existing_encoding:
+ found_existing_encodings = True
+ if mcode != existing_encoding:
+ if args.fix:
+ line = line.replace('.word 0x%x // %s' % (existing_encoding, match),
+ '.word 0x%x // %s' % (mcode, match))
+ found_fixes = True
+ else:
+ sys.stderr.write(
+ "Error at line %d: existing encoding 0x%x differs from encoding 0x%x for instruction '%s':\n\n%s\n\n"
+ % (lineno, existing_encoding, mcode, match, line))
+ found_error = True
+ else:
+ line = line.replace(match, '.word 0x%x // %s' % (mcode, match))
+ sys.stdout.write(line)
+if found_error:
+ sys.exit(1)
+if found_existing_encodings:
+ if found_fixes:
+ sys.stderr.write(
+ 'Note: some instructions that this program is able to encode were already encoded and their existing encodings didn\'t match the specified asm instructions. Since --fix was passed, these were fixed in-place.\n'
+ )
+ else:
+ sys.stderr.write(
+ 'Note: some instructions that this program is able to encode were already encoded. These encodings have been checked.\n'
+ )
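The vector-form encoding above composes the instruction word as the base pattern 0x4e809400 plus the accumulator register in bits 0-4, the LHS register in bits 5-9, the RHS register in bits 16-20, and the unsigned flag in bit 29. A small C++ cross-check of the example from the docstring (a standalone sketch, independent of this script):

    #include <cstdint>
    #include <cstdio>

    // Mirrors encode_udot_sdot_vector() above.
    std::uint32_t EncodeDotVector(bool is_udot, int accum, int lhs, int rhs) {
      return 0x4e809400u | static_cast<std::uint32_t>(accum) |
             (static_cast<std::uint32_t>(lhs) << 5) |
             (static_cast<std::uint32_t>(rhs) << 16) |
             (static_cast<std::uint32_t>(is_udot) << 29);
    }

    int main() {
      // "udot v16.4s, v4.16b, v0.16b" encodes to 0x6e809490, matching the
      // .word directives in neon-gemm-kernel-benchmark.cc below.
      std::printf("0x%08x\n", EncodeDotVector(true, 16, 4, 0));
      return 0;
    }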
diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc
index bff33fb..9146179 100644
--- a/standalone/neon-gemm-kernel-benchmark.cc
+++ b/standalone/neon-gemm-kernel-benchmark.cc
@@ -1,4 +1,4 @@
-// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
+// Copyright 2016 The gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -240,6 +240,51 @@ struct KernelFormat {
static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
};
+// KernelOperandRanges specifies the minimum and maximum values an operand can
+// take. It consists of two ranges: one for the LHS and one for the RHS. The
+// default values are the minimum and maximum values of the operand data type.
+template <typename Kernel, typename OperandType = typename Kernel::OperandType>
+struct KernelOperandRanges {
+ static OperandType LhsMin() {
+ return std::numeric_limits<OperandType>::lowest();
+ }
+ static OperandType LhsMax() {
+ return std::numeric_limits<OperandType>::max();
+ }
+ static OperandType RhsMin() {
+ return std::numeric_limits<OperandType>::lowest();
+ }
+ static OperandType RhsMax() {
+ return std::numeric_limits<OperandType>::max();
+ }
+};
+
+template <typename Kernel>
+struct KernelOperandRanges<Kernel, float> {
+ static float LhsMin() { return -100.f; }
+ static float LhsMax() { return 100.f; }
+ static float RhsMin() { return -100.f; }
+ static float RhsMax() { return 100.f; }
+};
+
+#define SET_7BIT_RANGES(kernel) \
+template <> \
+struct KernelOperandRanges<kernel, std::int8_t> { \
+ static std::int8_t LhsMin() { return -63; } \
+ static std::int8_t LhsMax() { return 63; } \
+ static std::int8_t RhsMin() { return -64; } \
+ static std::int8_t RhsMax() { return 63; } \
+};
+
+#define SET_425BIT_RANGES(kernel) \
+template <> \
+struct KernelOperandRanges<kernel, std::int8_t> { \
+ static std::int8_t LhsMin() { return -7; } \
+ static std::int8_t LhsMax() { return 7; } \
+ static std::int8_t RhsMin() { return -9; } \
+ static std::int8_t RhsMax() { return 9; } \
+};
+
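A test or benchmark harness is expected to consult these ranges when generating random operand data, so that kernels relying on narrowed operand ranges are never fed out-of-range values. A hypothetical sketch of such a fill routine (FillRandomLhs is an illustrative name, not part of this patch):

    #include <random>

    template <typename Kernel>
    void FillRandomLhs(typename Kernel::OperandType* data, int count,
                       std::mt19937* rng) {
      using Ranges = KernelOperandRanges<Kernel>;
      // Integer operand types only; a float kernel would use a real-valued
      // distribution over [LhsMin(), LhsMax()] instead.
      std::uniform_int_distribution<int> dist(Ranges::LhsMin(), Ranges::LhsMax());
      for (int i = 0; i < count; i++) {
        data[i] = static_cast<typename Kernel::OperandType>(dist(*rng));
      }
    }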
inline const char* CellOrderName(CellOrder o) {
switch (o) {
case CellOrder::DepthMajor:
@@ -596,7 +641,6 @@ struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
AccumulatorType* accum_ptr, int depth) {
std::size_t start_depth = 123;
std::size_t run_depth = depth;
- std::size_t dst_col_stride = 4;
AccumulatorType* dst_ptr = accum_ptr;
asm volatile(
@@ -2516,6 +2560,817 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
}
};
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ std::size_t start_depth = 123;
+ std::size_t run_depth = depth;
+ std::size_t dst_col_stride = 4;
+ AccumulatorType* dst_ptr = accum_ptr;
+ asm volatile(
+ // Overview of register layout:
+ //
+ // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
+ // A 4x16 block of Lhs is stored in 8 bit in v4--v7.
+ //
+ // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit
+ // components which need to be horizontally-added at the end)
+ //
+ // Register layout:
+ //
+ // +--------+--------+--------+--------+
+ // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
+ // Rhs +--------+--------+--------+--------+
+ // | ... | ... | ... | ... |
+ // +--------+--------+--------+--------|
+ // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
+ // +--------+--------+--------+--------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +-------+-----+--------+ - - +--------+--------+--------+--------+
+ // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
+ // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
+ // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
+ // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
+ // +-------+--------------+ - - +--------+--------+--------+--------+
+ //
+ // Accumulator
+ //
+
+ // Clear accumulators
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+ "dup v16.4s, wzr\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
+ "dup v17.4s, wzr\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "dup v18.4s, wzr\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "dup v19.4s, wzr\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "dup v20.4s, wzr\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+ "dup v21.4s, wzr\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "dup v22.4s, wzr\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+ "dup v23.4s, wzr\n"
+ "subs %w[run_depth], %w[run_depth], #16\n"
+ "dup v24.4s, wzr\n"
+ "mov x0, %[dst_ptr]\n"
+ "dup v25.4s, wzr\n"
+ "dup v26.4s, wzr\n"
+ "dup v27.4s, wzr\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ "beq 1f\n"
+
+ "cmp %w[run_depth], #32\n"
+ "blt 2f\n"
+
+ "3:\n"
+ "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
+ ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
+ "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
+ ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
+ ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
+ ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
+ ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
+ "prfm pldl1keep, [%[rhs_ptr], #128]\n"
+ ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
+ "ld1 {v14.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
+ ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
+ "ld1 {v15.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
+ "prfm pldl1keep, [%[lhs_ptr], #128]\n"
+ ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e889590 // udot v16.4s, v12.16b, v8.16b\n"
+ ".word 0x6e899591 // udot v17.4s, v12.16b, v9.16b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8a9592 // udot v18.4s, v12.16b, v10.16b\n"
+ ".word 0x6e8b9593 // udot v19.4s, v12.16b, v11.16b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8895b4 // udot v20.4s, v13.16b, v8.16b\n"
+ ".word 0x6e8995b5 // udot v21.4s, v13.16b, v9.16b\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+ "sub %[run_depth], %[run_depth], #32\n"
+ ".word 0x6e8a95b6 // udot v22.4s, v13.16b, v10.16b\n"
+ ".word 0x6e8b95b7 // udot v23.4s, v13.16b, v11.16b\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8895d8 // udot v24.4s, v14.16b, v8.16b\n"
+ ".word 0x6e8995d9 // udot v25.4s, v14.16b, v9.16b\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8a95da // udot v26.4s, v14.16b, v10.16b\n"
+ ".word 0x6e8b95db // udot v27.4s, v14.16b, v11.16b\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8895fc // udot v28.4s, v15.16b, v8.16b\n"
+ "prfm pldl1keep, [%[rhs_ptr], #128]\n"
+ ".word 0x6e8995fd // udot v29.4s, v15.16b, v9.16b\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+ "cmp %w[run_depth], #32\n"
+ ".word 0x6e8a95fe // udot v30.4s, v15.16b, v10.16b\n"
+ "prfm pldl1keep, [%[lhs_ptr], #128]\n"
+ ".word 0x6e8b95ff // udot v31.4s, v15.16b, v11.16b\n"
+
+ "bge 3b\n"
+
+ "cmp %w[run_depth], #0\n"
+ "beq 1f\n"
+
+ "2:\n"
+
+ "subs %w[run_depth], %w[run_depth], #16\n"
+
+ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
+ ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
+ ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
+ ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
+ ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
+ ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
+ ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
+ ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
+ ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
+ ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+
+ "bne 2b\n"
+
+ "1:\n"
+
+ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
+ ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n"
+ ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n"
+ ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n"
+ ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n"
+ ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n"
+ ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n"
+ ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n"
+ ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n"
+ ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n"
+ ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n"
+ ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n"
+ ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n"
+ ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n"
+ ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n"
+ ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n"
+
+ // Load accumulators from memory
+ "ld1 {v8.16b}, [x0], #16\n"
+ "ld1 {v9.16b}, [x0], #16\n"
+ "ld1 {v10.16b}, [x0], #16\n"
+ "ld1 {v11.16b}, [x0], #16\n"
+ "mov x0, %[dst_ptr]\n"
+
+ // Reduce aggregators horizontally
+ "addp v0.4s, v16.4s, v20.4s\n"
+ "addp v1.4s, v17.4s, v21.4s\n"
+ "addp v2.4s, v18.4s, v22.4s\n"
+ "addp v3.4s, v19.4s, v23.4s\n"
+ "addp v4.4s, v24.4s, v28.4s\n"
+ "addp v5.4s, v25.4s, v29.4s\n"
+ "addp v6.4s, v26.4s, v30.4s\n"
+ "addp v7.4s, v27.4s, v31.4s\n"
+
+ "addp v12.4s, v0.4s, v4.4s\n"
+ "addp v13.4s, v1.4s, v5.4s\n"
+ "addp v14.4s, v2.4s, v6.4s\n"
+ "addp v15.4s, v3.4s, v7.4s\n"
+
+ // Add to the accumulators loaded from memory
+ "add v8.4s, v8.4s, v12.4s\n"
+ "add v9.4s, v9.4s, v13.4s\n"
+ "add v10.4s, v10.4s, v14.4s\n"
+ "add v11.4s, v11.4s, v15.4s\n"
+
+ // Store accumulators back to memory
+ "st1 {v8.16b}, [x0], #16\n"
+ "st1 {v9.16b}, [x0], #16\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
+ [dst_col_stride] "+r"(dst_col_stride)
+ : // inputs
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+ "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+ "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+ "v28", "v29", "v30", "v31");
+ }
+};
+
+// Fast kernel operating on int8 operands with 7-bit range.
+// It is assumed that one of the two operands only takes values in [-63, 63],
+// while the other takes values in [-64, 63].
+// With this restriction, it is possible to multiply-accumulate operands into
+// a 16-bit integer eight times without overflow.
+struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits {
+ typedef std::int8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+#define GEMMLOWP_LABEL_64_DEPTH_LOOP "1"
+#define GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "2"
+#define GEMMLOWP_LABEL_16_DEPTH_LOOP "3"
+#define GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "4"
+
+ AccumulatorType* dst_ptr = accum_ptr;
+ asm volatile(
+ // Overview of register layout:
+ //
+ // A 4x16 block of Lhs is stored in 8 bit in v0--v7.
+ // A 2x16 block of Rhs is stored in 8 bit in v8--v15.
+ //
+ // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
+ // components which need to be horizontally-added at the end).
+ //
+ // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
+ // components which are added to global accumulators every 64 levels
+ // of depth).
+ //
+ // The Lhs vectors are multiplied by the Rhs vectors with a widening
+ // multiply over the 8 first levels of depth, producing int16x8
+ // vectors of products for each position in the accumulator matrix.
+ //
+ // Like the trick used in the fast 8-bit kernel, the operands are
+ // restricted to 7-bit range [-2^6, 2^6) so their products are in range
+ // [-2^12, 2^12 -1). This enables adding eight such products without any
+ // risk of overflowing int16, equating to 64 levels of depth before
+ // horizontally adding these int16x8 accumulators into the final int32x4
+ // accumulators.
+ //
+ // Register layout including both local and global accumulators.
+ // Since we do not have enough registers to store all Lhs values, we
+ // reuse the same registers v0--v7 to load the rest of the Lhs values.
+ //
+ // +-----+-----+
+ // | v8 | v9 |
+ // Rhs +-----+-----+
+ // | v10 | v11 |
+ // +-----+-----+
+ // | v12 | v13 |
+ // +-----+-----+
+ // | v14 | v15 |
+ // Lhs +-----+-----+
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ // | v0 | v4 | v0 | v4 | | v16 | v20 | | v24.4s | v28.4s |
+ // | v1 | v5 | v1 | v5 | | v17 | v21 | -> | v25.4s | v29.4s |
+ // | v2 | v6 | v2 | v6 | | v18 | v22 | | v26.4s | v30.4s |
+ // | v3 | v7 | v3 | v7 | | v19 | v23 | | v27.4s | v31.4s |
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ //
+ // Local Accumulator Global Accumulator
+ //
+
+ // Clear accumulators.
+ "dup v16.4s, wzr\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "dup v24.4s, wzr\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "dup v17.4s, wzr\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "dup v25.4s, wzr\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "dup v18.4s, wzr\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ "dup v26.4s, wzr\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ "dup v19.4s, wzr\n"
+ "dup v27.4s, wzr\n"
+ "dup v20.4s, wzr\n"
+ "dup v28.4s, wzr\n"
+ "dup v21.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v22.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v23.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ "cmp %w[depth], #64\n"
+ "blt " GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ GEMMLOWP_LABEL_64_DEPTH_LOOP
+ ":\n"
+ "subs %w[depth], %w[depth], #64\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v24.4s, v16.8h\n"
+ "smull v16.8h, v0.8b, v8.8b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v25.4s, v17.8h\n"
+ "smull v17.8h, v1.8b, v8.8b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v26.4s, v18.8h\n"
+ "smull v18.8h, v2.8b, v8.8b\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v27.4s, v19.8h\n"
+ "smull v19.8h, v3.8b, v8.8b\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v28.4s, v20.8h\n"
+ "smull v20.8h, v0.8b, v9.8b\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v29.4s, v21.8h\n"
+ "smull v21.8h, v1.8b, v9.8b\n"
+ "ld1 {v12.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v30.4s, v22.8h\n"
+ "smull v22.8h, v2.8b, v9.8b\n"
+ "ld1 {v13.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v31.4s, v23.8h\n"
+ "smull v23.8h, v3.8b, v9.8b\n"
+
+ "cmp %w[depth], #64\n"
+ "smlal2 v16.8h, v0.16b, v8.16b\n"
+ "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
+ "smlal2 v17.8h, v1.16b, v8.16b\n"
+ "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
+ "smlal2 v18.8h, v2.16b, v8.16b\n"
+ "smlal2 v19.8h, v3.16b, v8.16b\n"
+
+ "smlal2 v20.8h, v0.16b, v9.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v21.8h, v1.16b, v9.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v22.8h, v2.16b, v9.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v23.8h, v3.16b, v9.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "smlal v16.8h, v4.8b, v10.8b\n"
+ "smlal v17.8h, v5.8b, v10.8b\n"
+ "smlal v18.8h, v6.8b, v10.8b\n"
+ "smlal v19.8h, v7.8b, v10.8b\n"
+ "smlal v20.8h, v4.8b, v11.8b\n"
+
+ "smlal v21.8h, v5.8b, v11.8b\n"
+ "smlal v22.8h, v6.8b, v11.8b\n"
+ "smlal v23.8h, v7.8b, v11.8b\n"
+
+ "smlal2 v16.8h, v4.16b, v10.16b\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ "smlal2 v17.8h, v5.16b, v10.16b\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ "smlal2 v18.8h, v6.16b, v10.16b\n"
+ "smlal2 v19.8h, v7.16b, v10.16b\n"
+
+ "smlal2 v20.8h, v4.16b, v11.16b\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v21.8h, v5.16b, v11.16b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v22.8h, v6.16b, v11.16b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v23.8h, v7.16b, v11.16b\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+
+ "smlal v16.8h, v0.8b, v12.8b\n"
+ "smlal v17.8h, v1.8b, v12.8b\n"
+ "smlal v18.8h, v2.8b, v12.8b\n"
+ "smlal v19.8h, v3.8b, v12.8b\n"
+ "smlal v20.8h, v0.8b, v13.8b\n"
+ "smlal v21.8h, v1.8b, v13.8b\n"
+ "smlal v22.8h, v2.8b, v13.8b\n"
+ "smlal v23.8h, v3.8b, v13.8b\n"
+
+ "smlal2 v16.8h, v0.16b, v12.16b\n"
+ "smlal2 v17.8h, v1.16b, v12.16b\n"
+ "smlal2 v18.8h, v2.16b, v12.16b\n"
+ "smlal2 v19.8h, v3.16b, v12.16b\n"
+
+ "smlal2 v20.8h, v0.16b, v13.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v21.8h, v1.16b, v13.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v22.8h, v2.16b, v13.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v23.8h, v3.16b, v13.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "smlal v16.8h, v4.8b, v14.8b\n"
+ "smlal v17.8h, v5.8b, v14.8b\n"
+ "smlal v18.8h, v6.8b, v14.8b\n"
+ "smlal v19.8h, v7.8b, v14.8b\n"
+
+ "smlal v20.8h, v4.8b, v15.8b\n"
+ "smlal v21.8h, v5.8b, v15.8b\n"
+ "smlal v22.8h, v6.8b, v15.8b\n"
+ "smlal v23.8h, v7.8b, v15.8b\n"
+
+ "smlal2 v16.8h, v4.16b, v14.16b\n"
+ "smlal2 v17.8h, v5.16b, v14.16b\n"
+ "smlal2 v18.8h, v6.16b, v14.16b\n"
+ "smlal2 v19.8h, v7.16b, v14.16b\n"
+
+ "smlal2 v20.8h, v4.16b, v15.16b\n"
+ "smlal2 v21.8h, v5.16b, v15.16b\n"
+ "smlal2 v22.8h, v6.16b, v15.16b\n"
+ "smlal2 v23.8h, v7.16b, v15.16b\n"
+
+ "bge " GEMMLOWP_LABEL_64_DEPTH_LOOP "b\n"
+
+ GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP
+ ":\n"
+
+ "cmp %w[depth], #16\n"
+ "blt " GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ GEMMLOWP_LABEL_16_DEPTH_LOOP
+ ":\n"
+ "sadalp v24.4s, v16.8h\n"
+ "smull v16.8h, v0.8b, v8.8b\n"
+ "subs %w[depth], %w[depth], #16\n"
+ "sadalp v25.4s, v17.8h\n"
+ "smull v17.8h, v1.8b, v8.8b\n"
+ "sadalp v26.4s, v18.8h\n"
+ "smull v18.8h, v2.8b, v8.8b\n"
+ "sadalp v27.4s, v19.8h\n"
+ "smull v19.8h, v3.8b, v8.8b\n"
+ "sadalp v28.4s, v20.8h\n"
+ "smull v20.8h, v0.8b, v9.8b\n"
+ "sadalp v29.4s, v21.8h\n"
+ "smull v21.8h, v1.8b, v9.8b\n"
+ "sadalp v30.4s, v22.8h\n"
+ "smull v22.8h, v2.8b, v9.8b\n"
+ "sadalp v31.4s, v23.8h\n"
+ "smull v23.8h, v3.8b, v9.8b\n"
+
+ "cmp %w[depth], #16\n"
+ "smlal2 v16.8h, v0.16b, v8.16b\n"
+ "smlal2 v17.8h, v1.16b, v8.16b\n"
+ "smlal2 v18.8h, v2.16b, v8.16b\n"
+ "smlal2 v19.8h, v3.16b, v8.16b\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+
+ "smlal2 v20.8h, v0.16b, v9.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v21.8h, v1.16b, v9.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v22.8h, v2.16b, v9.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v23.8h, v3.16b, v9.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+
+ "bge " GEMMLOWP_LABEL_16_DEPTH_LOOP "b\n"
+
+ GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP
+ ":\n"
+
+ "sadalp v24.4s, v16.8h\n"
+ "sadalp v25.4s, v17.8h\n"
+ "sadalp v26.4s, v18.8h\n"
+ "sadalp v27.4s, v19.8h\n"
+ "sadalp v28.4s, v20.8h\n"
+ "sadalp v29.4s, v21.8h\n"
+ "sadalp v30.4s, v22.8h\n"
+ "sadalp v31.4s, v23.8h\n"
+
+ // Reduce aggregators horizontally.
+ "addp v0.4s, v24.4s, v25.4s\n"
+ "addp v1.4s, v26.4s, v27.4s\n"
+ "addp v2.4s, v28.4s, v29.4s\n"
+ "addp v3.4s, v30.4s, v31.4s\n"
+
+ "addp v4.4s, v0.4s, v1.4s\n"
+ "addp v5.4s, v2.4s, v3.4s\n"
+
+ // Load accumulators from memory.
+ "mov x0, %[dst_ptr]\n"
+ "ld1 {v6.16b}, [x0], #16\n"
+ "ld1 {v7.16b}, [x0], #16\n"
+
+ // Add to the accumulators loaded from memory.
+ "add v6.4s, v6.4s, v4.4s\n"
+ "add v7.4s, v7.4s, v5.4s\n"
+
+ // Store accumulators back to memory.
+ "mov x0, %[dst_ptr]\n"
+ "st1 {v6.16b}, [x0], #16\n"
+ "st1 {v7.16b}, [x0], #16\n"
+
+ :
+ // Outputs.
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth)
+ :
+ // Inputs.
+
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
+ "v31", "x0");
+ }
+};
+
+SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits);
+
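The overflow argument in the comments above reduces to a one-line bound: with one operand in [-63, 63] and the other in [-64, 63], each product has magnitude at most 63*64 = 4032, so eight of them sum to at most 32256, which fits in a signed 16-bit accumulator. As a compile-time sanity sketch (not part of the kernel itself):

    static_assert(8 * 63 * 64 <= 32767,
                  "eight 7-bit-range products fit in an int16 accumulator");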
+// Kernel operating on int8 operands with 4.25-bit range.
+// It is assumed that one of the two operands only takes values in [-7, 7],
+// while the other takes values in [-9, 9].
+// With this restriction, it is possible to multiply-accumulate operands into
+// a 16-bit integer thirty-two times without overflow.
+struct NEON_64bit_GEMM_Int425Operands {
+ typedef std::int8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+#define GEMMLOWP_LABEL_512_DEPTH_LOOP "1"
+#define GEMMLOWP_LABEL_32_DEPTH_LOOP "2"
+#define GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP "3"
+
+ AccumulatorType* dst_ptr = accum_ptr;
+ int outer_depth = depth / 512 + 1;
+
+ asm volatile(
+ // Overview of register layout:
+ //
+ // A 4x32 block of Lhs is stored in 8 bit in v0--v7.
+ // A 2x32 block of Rhs is stored in 8 bit in v8--v11.
+ //
+ // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
+ // components which need to be horizontally-added at the end).
+ //
+ // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
+ // components which are horizontally-added to global accumulators every
+ // 512 levels of depth).
+ //
+ // The Lhs vectors are multiplied by the Rhs vectors with a multiply
+ // over the 16 first levels of depth, producing int8x16 vectors of
+ // products for each position in the accumulator matrix.
+ //
+ // Like the trick used in the fast 8-bit and 7-bit kernels, the operands
+ // are restricted to 4.25-bit range, [-7, 7] for one operand and [-9, 9]
+ // for the other operand. This enables adding two such products without
+ // any risk of overflowing int8, and thirty-two such products without
+ // overflowing int16. This equates to 512 levels of depth before
+ // horizontally adding these int16x8 accumulators into the final int32x4
+ // accumulators.
+ //
+ // Register layout (ignoring the v12--v15 temporary 8-bit accumulators).
+ // Since we do not have enough registers to store all Lhs values and Rhs
+ // values, we reuse the same registers v0--v7 to load subsequent Lhs
+ // values and v8--v11 to load subsequent Rhs values.
+ //
+ // +-----+-----+
+ // | v8 | v9 |
+ // Rhs +-----+-----+
+ // | v10 | v11 |
+ // +-----+-----+
+ // | v8 | v9 |
+ // +-----+-----+
+ // | v10 | v11 |
+ // Lhs +-----+-----+
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ // | v0 | v4 | v0 | v4 | | v16 | v17 | | v24.4s | v25.4s |
+ // | v1 | v5 | v1 | v5 | | v18 | v19 | -> | v26.4s | v27.4s |
+ // | v2 | v6 | v2 | v6 | | v20 | v21 | | v28.4s | v29.4s |
+ // | v3 | v7 | v3 | v7 | | v22 | v23 | | v30.4s | v31.4s |
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ //
+ // Local Accumulator Global Accumulator
+ //
+
+ // Clear global accumulators.
+ "dup v24.4s, wzr\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ "dup v25.4s, wzr\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ "dup v26.4s, wzr\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+ "dup v27.4s, wzr\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "dup v28.4s, wzr\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "dup v29.4s, wzr\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "dup v30.4s, wzr\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "dup v31.4s, wzr\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+
+ //"loop_%=:\n"
+ GEMMLOWP_LABEL_512_DEPTH_LOOP
+ ":\n"
+ // Clear local accumulators.
+ "dup v16.8h, wzr\n"
+ "dup v17.8h, wzr\n"
+ "dup v18.8h, wzr\n"
+ "mov x1, #512\n"
+ "dup v19.8h, wzr\n"
+ "dup v20.8h, wzr\n"
+ "dup v21.8h, wzr\n"
+ "dup v22.8h, wzr\n"
+ "dup v23.8h, wzr\n"
+
+ //"loop_%=:\n"
+ GEMMLOWP_LABEL_32_DEPTH_LOOP
+ ":\n"
+ "mul v12.16b, v0.16b, v8.16b\n"
+ "mul v13.16b, v0.16b, v10.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "mul v14.16b, v2.16b, v8.16b\n"
+ "mul v15.16b, v2.16b, v10.16b\n"
+
+ "mla v12.16b, v1.16b, v9.16b\n"
+ "mla v13.16b, v1.16b, v11.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "mla v14.16b, v3.16b, v9.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "mla v15.16b, v3.16b, v11.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "sadalp v16.8h, v12.16b\n"
+ "sadalp v17.8h, v13.16b\n"
+ "subs %w[depth], %w[depth], #32\n"
+ "sadalp v18.8h, v14.16b\n"
+ "sadalp v19.8h, v15.16b\n"
+ "subs x1, x1, #32\n"
+
+ "mul v12.16b, v4.16b, v8.16b\n"
+ "mul v13.16b, v4.16b, v10.16b\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "mul v14.16b, v6.16b, v8.16b\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ "mul v15.16b, v6.16b, v10.16b\n"
+
+ "mla v12.16b, v5.16b, v9.16b\n"
+ "mla v13.16b, v5.16b, v11.16b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "mla v14.16b, v7.16b, v9.16b\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ "mla v15.16b, v7.16b, v11.16b\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+
+ "sadalp v20.8h, v12.16b\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v21.8h, v13.16b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v22.8h, v14.16b\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v23.8h, v15.16b\n"
+
+ "mul v12.16b, v0.16b, v8.16b\n"
+ "mul v13.16b, v0.16b, v10.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "mul v14.16b, v2.16b, v8.16b\n"
+ "mul v15.16b, v2.16b, v10.16b\n"
+
+ "mla v12.16b, v1.16b, v9.16b\n"
+ "mla v13.16b, v1.16b, v11.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "mla v14.16b, v3.16b, v9.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "mla v15.16b, v3.16b, v11.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "sadalp v16.8h, v12.16b\n"
+ "sadalp v17.8h, v13.16b\n"
+ "sadalp v18.8h, v14.16b\n"
+ "sadalp v19.8h, v15.16b\n"
+
+ "mul v12.16b, v4.16b, v8.16b\n"
+ "mul v13.16b, v4.16b, v10.16b\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "mul v14.16b, v6.16b, v8.16b\n"
+ "ld1 {v8.16b}, [%[rhs_ptr]], #16\n"
+ "mul v15.16b, v6.16b, v10.16b\n"
+
+ "mla v12.16b, v5.16b, v9.16b\n"
+ "mla v13.16b, v5.16b, v11.16b\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "mla v14.16b, v7.16b, v9.16b\n"
+ "ld1 {v9.16b}, [%[rhs_ptr]], #16\n"
+ "mla v15.16b, v7.16b, v11.16b\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+
+ "sadalp v20.8h, v12.16b\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "sadalp v21.8h, v13.16b\n"
+ "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v22.8h, v14.16b\n"
+ "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+ "sadalp v23.8h, v15.16b\n"
+
+ "beq " GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP
+ "f\n"
+
+ "cmp %w[depth], #0\n"
+ "bne " GEMMLOWP_LABEL_32_DEPTH_LOOP "b\n"
+
+ GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP
+ ":\n"
+
+ // Pairwise add 16-bit local accums to 32-bit global accums.
+ "sadalp v24.4s, v16.8h\n"
+ "sadalp v25.4s, v17.8h\n"
+ "sadalp v26.4s, v18.8h\n"
+ "sadalp v27.4s, v19.8h\n"
+ "sadalp v28.4s, v20.8h\n"
+ "sadalp v29.4s, v21.8h\n"
+ "sadalp v30.4s, v22.8h\n"
+ "sadalp v31.4s, v23.8h\n"
+
+ "bne " GEMMLOWP_LABEL_512_DEPTH_LOOP
+ "b\n"
+
+ // Reduce aggregators horizontally.
+ "addp v0.4s, v24.4s, v26.4s\n"
+ "addp v1.4s, v28.4s, v30.4s\n"
+ "addp v2.4s, v25.4s, v27.4s\n"
+ "addp v3.4s, v29.4s, v31.4s\n"
+
+ "addp v4.4s, v0.4s, v1.4s\n"
+ "addp v5.4s, v2.4s, v3.4s\n"
+
+ // Load accumulators from memory.
+ "mov x0, %[dst_ptr]\n"
+ "ld1 {v6.16b}, [x0], #16\n"
+ "ld1 {v7.16b}, [x0], #16\n"
+
+ // Add to the accumulators loaded from memory.
+ "add v6.4s, v6.4s, v4.4s\n"
+ "add v7.4s, v7.4s, v5.4s\n"
+
+ // Store accumulators back to memory.
+ "mov x0, %[dst_ptr]\n"
+ "st1 {v6.16b}, [x0], #16\n"
+ "st1 {v7.16b}, [x0], #16\n"
+
+ :
+ // Outputs.
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth),
+ [outer_depth] "+r"(outer_depth)
+ :
+ // Inputs.
+
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
+ "v31", "x0", "x1");
+ }
+};
+
+SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands);
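The assembly kernel above nests its depth loop: int8 products are pairwise-accumulated into int16 registers ("sadalp v16.8h, ...") for at most 512 levels of depth, then flushed into the int32 accumulators ("sadalp v24.4s, ...") before the next 512-depth chunk begins. A minimal aarch64-intrinsics sketch of that two-stage widening idea follows; the function name and loop bounds are illustrative only, it assumes depth is a positive multiple of 16, and it assumes operand ranges narrow enough (as the 4.25-bit ranges are meant to guarantee) that the int8 products do not wrap.

    #include <arm_neon.h>
    #include <algorithm>
    #include <cstdint>

    // Two-stage widening accumulation: int8 products -> int16 lanes -> int32 lanes.
    inline std::int32_t DotTwoStageWidening(const std::int8_t* lhs,
                                            const std::int8_t* rhs, int depth) {
      int32x4_t acc32 = vdupq_n_s32(0);
      for (int d = 0; d < depth; d += 512) {       // flush every 512 depth levels
        int16x8_t acc16 = vdupq_n_s16(0);
        const int end = std::min(d + 512, depth);
        for (int k = d; k < end; k += 16) {        // 16 int8 values per step
          int8x16_t products = vmulq_s8(vld1q_s8(lhs + k), vld1q_s8(rhs + k));
          acc16 = vpadalq_s8(acc16, products);     // pairwise add, widen to int16
        }
        acc32 = vpadalq_s16(acc32, acc16);         // pairwise add, widen to int32
      }
      return vaddvq_s32(acc32);                    // horizontal add of the 4 lanes
    }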
+
#ifdef __ARM_FEATURE_DOTPROD
// Kernels utilizing the Armv8.2 Dot Product extension.
//
@@ -2582,41 +3437,41 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
// Start the MACs at the head of the loop - 1st cell from each side
// already loaded.
- "udot v8.4s, v2.16b, v0.b[0]\n"
- "udot v9.4s, v2.16b, v0.b[1]\n"
+ ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
+ ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
"ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
- "udot v10.4s, v2.16b, v0.b[2]\n"
- "udot v11.4s, v2.16b, v0.b[3]\n"
+ ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
+ ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
- "udot v12.4s, v2.16b, v1.b[0]\n"
- "udot v13.4s, v2.16b, v1.b[1]\n"
+ ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
+ ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
"ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
- "udot v14.4s, v2.16b, v1.b[2]\n"
- "udot v15.4s, v2.16b, v1.b[3]\n"
+ ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
+ ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
// for the next iteration early.
- "udot v16.4s, v3.16b, v0.b[0]\n"
- "udot v17.4s, v3.16b, v0.b[1]\n"
- "udot v18.4s, v3.16b, v0.b[2]\n"
- "udot v19.4s, v3.16b, v0.b[3]\n"
- "udot v20.4s, v3.16b, v1.b[0]\n"
- "udot v21.4s, v3.16b, v1.b[1]\n"
- "udot v22.4s, v3.16b, v1.b[2]\n"
- "udot v23.4s, v3.16b, v1.b[3]\n"
- "udot v24.4s, v4.16b, v0.b[0]\n"
- "udot v25.4s, v4.16b, v0.b[1]\n"
- "udot v26.4s, v4.16b, v0.b[2]\n"
- "udot v27.4s, v4.16b, v0.b[3]\n"
+ ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n"
+ ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n"
+ ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
+ ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
+ ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
+ ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
+ ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
+ ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
+ ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
+ ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
+ ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
+ ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
"ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
// load for the next iteration early.
- "udot v28.4s, v4.16b, v1.b[0]\n"
- "udot v29.4s, v4.16b, v1.b[1]\n"
+ ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
+ ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
// Loop. Decrement loop index (depth) by 4 as udot processes 4
// depth values.
"subs %w[depth], %w[depth], #4\n"
- "udot v30.4s, v4.16b, v1.b[2]\n"
- "udot v31.4s, v4.16b, v1.b[3]\n"
+ ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
+ ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
"bne " GEMMLOWP_LABEL_LOOP
"b\n"
@@ -2712,54 +3567,67 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
GEMMLOWP_LABEL_LOOP
":\n"
- "udot v8.4s, v2.16b, v0.b[0]\n"
- "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
- "udot v9.4s, v2.16b, v0.b[1]\n"
- "ins v0.d[1], x18\n" // Finish loading v0
- "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure.
- "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
- "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure.
- "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
- "udot v10.4s, v2.16b, v0.b[2]\n"
- "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
- "udot v11.4s, v2.16b, v0.b[3]\n"
- "ins v1.d[1], x18\n" // Finish loading v1
- "udot v12.4s, v2.16b, v1.b[0]\n"
- "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
- "udot v13.4s, v2.16b, v1.b[1]\n"
- "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
- "udot v14.4s, v2.16b, v1.b[2]\n"
-
- "udot v15.4s, v2.16b, v1.b[3]\n"
- "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
- "udot v18.4s, v3.16b, v0.b[2]\n"
- "ins v4.d[1], x18\n" // Finish loading v4
- "udot v19.4s, v3.16b, v0.b[3]\n"
- "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
- "udot v20.4s, v3.16b, v1.b[0]\n"
+ ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" // out of
+ // sequence -
+ // used to
+ // reduce
+ // load/use
+ // pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" // out of
+ // sequence -
+ // used to
+ // reduce
+ // load/use
+ // pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment
+ // pointer.
+ ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment
+ // pointer.
+ ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
+
+ ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
"subs %w[depth], %w[depth], #4\n"
- "udot v21.4s, v3.16b, v1.b[1]\n"
-
- "udot v22.4s, v3.16b, v1.b[2]\n"
-
- "udot v23.4s, v3.16b, v1.b[3]\n"
- "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
- "udot v24.4s, v4.16b, v0.b[0]\n"
- "ins v2.d[1], x18\n" // Finish loading next v2
- "udot v25.4s, v4.16b, v0.b[1]\n"
- "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
- "udot v26.4s, v4.16b, v0.b[2]\n"
-
- "udot v27.4s, v4.16b, v0.b[3]\n"
- "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
- "udot v28.4s, v4.16b, v1.b[0]\n"
- "ins v3.d[1], x18\n" // Finish loading next v3
- "udot v29.4s, v4.16b, v1.b[1]\n"
- "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
- "udot v30.4s, v4.16b, v1.b[2]\n"
-
- "udot v31.4s, v4.16b, v1.b[3]\n"
- "bne " GEMMLOWP_LABEL_LOOP "b\n"
+ ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
+
+ ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
+
+ ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
+
+ ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
+
+ ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP
+ "b\n"
// Store accumulators
"mov x0, %[accum_ptr]\n"
@@ -3852,533 +4720,195 @@ using NEON_32bit_GEMM_Float32_WithScalar_intrinsics =
using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
NEON_GEMM_Float32_WithScalar_intrinsics<2>;
-#endif // __arm__ || __aarch64__
-#ifdef __mips
-static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
- // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
-#if 0
- return __builtin_msa_maddv_w(a, b, c);
-#else
- asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
- // Outputs
- : [a] "+f"(a)
- // Inputs
- : [b] "f"(b), [c] "f"(c));
- return a;
-#endif
-}
-
-// Using 32x32=32 multiplications.
-// 20 MSA regs used:
-// - 12 accumulators
-// - 6 lhs
-// - 1 rhs
-// - 1 temps/zeroes
-// ~55 instructions in the loop.
-struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics {
- typedef std::uint8_t OperandType;
+// C++ intrinsics-based variant of the deep, 7-bit, fast kernel.
+struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics {
+ typedef std::int8_t OperandType;
typedef std::int32_t AccumulatorType;
typedef KernelFormat<
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
Format;
static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
AccumulatorType* accum_ptr, int depth) {
- const v16i8 zeroes = __builtin_msa_ldi_b(0);
- v4i32 acc[3][4];
- // Load accumulators.
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 4; j++) {
- acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ int32x4_t acc[4][2];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ acc[i][j] = vdupq_n_s32(0);
}
}
- while (depth > 0) {
- // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
- v8i16 lhs[6];
- lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
- lhs[1] =
- reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
-
- // Zero-extend 8-bit elements of lhs[] to 16 bits.
- lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
- reinterpret_cast<v16i8>(lhs[0])));
- lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
- reinterpret_cast<v16i8>(lhs[1])));
- lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
- reinterpret_cast<v16i8>(lhs[1])));
-
- // Zero-extend 16-bit elements of lhs[] to 32 bits.
- lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
- lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
- lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
- lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
- lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
- lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
-
- // Depth 0.
- for (int j = 0; j < 4; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ int d = 0;
+ for (; d <= depth - 64; d += 64) {
+ int16x8_t local_acc[4][2];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ local_acc[i][j] = vdupq_n_s16(0);
}
}
- // Depth 1.
- for (int j = 0; j < 4; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ // There are not enough registers to fit all lhs and rhs values for 64
+ // depth. Instead, load values for 32 depth at a time.
+ for (int k = 0; k < 2; k++) {
+ int8x16_t lhs[4][2];
+ for (int i = 0; i < 4; i++) {
+ lhs[i][0] = vld1q_s8(lhs_ptr + 16 * i + 128 * k);
+ lhs[i][1] = vld1q_s8(lhs_ptr + 64 + 16 * i + 128 * k);
+ }
+
+ int8x16_t rhs[4];
+ for (int i = 0; i < 4; i++) {
+ rhs[i] = vld1q_s8(rhs_ptr + 16 * i + 64 * k);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ if (k == 0) {
+ local_acc[i][0] = vmull_s8(vget_low_s8(lhs[i][0]),
+ vget_low_s8(rhs[0]));
+ local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]),
+ vget_low_s8(rhs[2]));
+ local_acc[i][1] = vmull_s8(vget_low_s8(lhs[i][0]),
+ vget_low_s8(rhs[1]));
+ local_acc[i][1] = vmlal_s8(local_acc[i][1],
+ vget_low_s8(lhs[i][1]),
+ vget_low_s8(rhs[3]));
+ } else {
+ local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][0]),
+ vget_low_s8(rhs[0]));
+ local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]),
+ vget_low_s8(rhs[2]));
+ local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][0]),
+ vget_low_s8(rhs[1]));
+ local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][1]),
+ vget_low_s8(rhs[3]));
+ }
+
+ local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][0]),
+ vget_high_s8(rhs[0]));
+ local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][1]),
+ vget_high_s8(rhs[2]));
+ local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][0]),
+ vget_high_s8(rhs[1]));
+ local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][1]),
+ vget_high_s8(rhs[3]));
}
}
- lhs_ptr += 24;
- rhs_ptr += 8;
- depth -= 2;
+ for (int i = 0; i < 4; i++) {
+ acc[i][0] = vpadalq_s16(acc[i][0], local_acc[i][0]);
+ acc[i][1] = vpadalq_s16(acc[i][1], local_acc[i][1]);
+ }
+
+ lhs_ptr += 64 * 4;
+ rhs_ptr += 64 * 2;
}
+ for (; d <= depth - 16; d += 16) {
+ int8x16_t lhs[4];
+ for (int i = 0; i < 4; i++) {
+ lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
+ }
+ int8x16_t rhs[2];
+ for (int i = 0; i < 2; i++) {
+ rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
+ }
- // Store accumulators.
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 4; j++) {
- __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ int16x8_t local_acc =
+ vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j]));
+ local_acc =
+ vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j]));
+ acc[i][j] = vpadalq_s16(acc[i][j], local_acc);
+ }
}
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 2;
+ }
+ for (int i = 0; i < 2; i++) {
+ int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]);
+ int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]);
+ int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1);
+ int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i);
+ dst_val = vaddq_s32(dst_val, acc_4x);
+ vst1q_s32(accum_ptr + 4 * i, dst_val);
}
}
};
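+// Note on the "AccumEightWithin16Bits" name: in the intrinsics above, each
+// int16 local accumulator lane receives at most eight vmull_s8/vmlal_s8
+// products before vpadalq_s16 widens it into the int32 accumulators.
+// Assuming the 7-bit ranges installed by SET_7BIT_RANGES (defined elsewhere)
+// cap operand magnitudes at roughly 63 and 64, the worst case is
+// 8 * 63 * 64 = 32256, which fits in int16.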
-// Assembly implementation of the above
-// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics.
-// Using 32x32=32 multiplications.
-// 20 MSA regs used:
-// - 12 accumulators
-// - 6 lhs
-// - 1 rhs
-// - 1 temps/zeroes
-// ~55 instructions in the loop.
-struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly {
- typedef std::uint8_t OperandType;
- typedef std::int32_t AccumulatorType;
- typedef KernelFormat<
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
- Format;
- static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
- AccumulatorType* accum_ptr, int depth) {
- asm volatile(
- // Load accumulators
- "ld.w $w0, (0*16)(%[accum_ptr])\n"
- "ld.w $w4, (1*16)(%[accum_ptr])\n"
- "ld.w $w8, (2*16)(%[accum_ptr])\n"
- "ld.w $w1, (3*16)(%[accum_ptr])\n"
- "ld.w $w5, (4*16)(%[accum_ptr])\n"
- "ld.w $w9, (5*16)(%[accum_ptr])\n"
- "ld.w $w2, (6*16)(%[accum_ptr])\n"
- "ld.w $w6, (7*16)(%[accum_ptr])\n"
- "ld.w $w10, (8*16)(%[accum_ptr])\n"
- "ld.w $w3, (9*16)(%[accum_ptr])\n"
- "ld.w $w7, (10*16)(%[accum_ptr])\n"
- "ld.w $w11, (11*16)(%[accum_ptr])\n"
- // Set a temp to all zeroes.
- "ldi.b $w19, 0\n"
-
- GEMMLOWP_LABEL_LOOP ":\n"
- // Overview of register layout:
- //
- // A half of the 2x4 cell of Rhs is stored in 32bit in w18.
- // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17.
- // A 12x4 block of accumulators is stored in 32bit in w0-w11.
- //
- // +------+------+------+------+
- // Rhs |w18[0]|w18[1]|w18[2]|w18[3]|
- // +------+------+------+------+
- //
- // | | | | |
- //
- // Lhs | | | | |
- //
- // +---+---+ - - - - +------+------+------+------+
- // |w12|w15| | w0 | w1 | w2 | w3 |
- // |w12|w15| | w0 | w1 | w2 | w3 |
- // |w12|w15| | w0 | w1 | w2 | w3 |
- // |w12|w15| | w0 | w1 | w2 | w3 |
- // +---+---+ - - - - +------+------+------+------+
- // |w13|w16| | w4 | w5 | w6 | w7 |
- // |w13|w16| | w4 | w5 | w6 | w7 |
- // |w13|w16| | w4 | w5 | w6 | w7 |
- // |w13|w16| | w4 | w5 | w6 | w7 |
- // +---+---+ - - - - +------+------+------+------+
- // |w14|w17| | w8 | w9 | w10 | w11 |
- // |w14|w17| | w8 | w9 | w10 | w11 |
- // |w14|w17| | w8 | w9 | w10 | w11 |
- // |w14|w17| | w8 | w9 | w10 | w11 |
- // +---+---+ - - - - +------+------+------+------+
- //
- // Accumulator
+SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics);
- // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
- "ld.b $w12, 0(%[lhs_ptr])\n"
- "ld.b $w13, 8(%[lhs_ptr])\n"
-
- // Load 4 bytes of rhs[] for depth 0.
- "lbu $a0, 0(%[rhs_ptr])\n"
- "lbu $a1, 1(%[rhs_ptr])\n"
- "lbu $a2, 2(%[rhs_ptr])\n"
- "lbu $a3, 3(%[rhs_ptr])\n"
-
- // Zero-extend 8-bit elements of lhs[] to 16 bits.
- "ilvr.b $w12, $w19, $w12\n"
- "ilvl.b $w14, $w19, $w13\n"
- "ilvr.b $w13, $w19, $w13\n"
- // Zero-extend 16-bit elements of lhs[] to 32 bits.
- "ilvl.h $w15, $w19, $w12\n"
- "ilvl.h $w16, $w19, $w13\n"
- "ilvl.h $w17, $w19, $w14\n"
- "ilvr.h $w12, $w19, $w12\n"
- "ilvr.h $w13, $w19, $w13\n"
- "ilvr.h $w14, $w19, $w14\n"
-
- // Depth 0.
- "fill.w $w18, $a0\n"
- "lbu $a0, 4(%[rhs_ptr])\n"
- "maddv.w $w0, $w12, $w18\n"
- "maddv.w $w4, $w13, $w18\n"
- "maddv.w $w8, $w14, $w18\n"
- "fill.w $w18, $a1\n"
- "lbu $a1, 5(%[rhs_ptr])\n"
- "maddv.w $w1, $w12, $w18\n"
- "maddv.w $w5, $w13, $w18\n"
- "maddv.w $w9, $w14, $w18\n"
- "fill.w $w18, $a2\n"
- "lbu $a2, 6(%[rhs_ptr])\n"
- "maddv.w $w2, $w12, $w18\n"
- "maddv.w $w6, $w13, $w18\n"
- "maddv.w $w10, $w14, $w18\n"
- "fill.w $w18, $a3\n"
- "lbu $a3, 7(%[rhs_ptr])\n"
- "maddv.w $w3, $w12, $w18\n"
- "maddv.w $w7, $w13, $w18\n"
- "maddv.w $w11, $w14, $w18\n"
-
- // Depth 1.
- "fill.w $w18, $a0\n"
- "maddv.w $w0, $w15, $w18\n"
- "maddv.w $w4, $w16, $w18\n"
- "maddv.w $w8, $w17, $w18\n"
- "fill.w $w18, $a1\n"
- "maddv.w $w1, $w15, $w18\n"
- "maddv.w $w5, $w16, $w18\n"
- "maddv.w $w9, $w17, $w18\n"
- "fill.w $w18, $a2\n"
- "maddv.w $w2, $w15, $w18\n"
- "maddv.w $w6, $w16, $w18\n"
- "maddv.w $w10, $w17, $w18\n"
- "fill.w $w18, $a3\n"
- "maddv.w $w3, $w15, $w18\n"
- "maddv.w $w7, $w16, $w18\n"
- "maddv.w $w11, $w17, $w18\n"
-
- "addiu %[depth], -2\n"
- GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
- GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
- "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
-
- // Store accumulators.
- "st.w $w0, (0*16)(%[accum_ptr])\n"
- "st.w $w4, (1*16)(%[accum_ptr])\n"
- "st.w $w8, (2*16)(%[accum_ptr])\n"
- "st.w $w1, (3*16)(%[accum_ptr])\n"
- "st.w $w5, (4*16)(%[accum_ptr])\n"
- "st.w $w9, (5*16)(%[accum_ptr])\n"
- "st.w $w2, (6*16)(%[accum_ptr])\n"
- "st.w $w6, (7*16)(%[accum_ptr])\n"
- "st.w $w10, (8*16)(%[accum_ptr])\n"
- "st.w $w3, (9*16)(%[accum_ptr])\n"
- "st.w $w7, (10*16)(%[accum_ptr])\n"
- "st.w $w11, (11*16)(%[accum_ptr])\n"
- : // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [depth] "+r"(depth)
- : // inputs
- [accum_ptr] "r"(accum_ptr)
- : // clobbers
- "memory",
- "a0", "a1", "a2", "a3",
- "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
- "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
- "$f16", "$f17", "$f18", "$f19");
- }
-};
-
-// Assembly implementation of the above
-// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
-// Using 16x16=32 multiplications.
-// 20 MSA regs used:
-// - 12 accumulators
-// - 3 lhs
-// - 4 rhs
-// - 1 temps/zeroes
-// ~45 instructions in the loop.
-struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 {
- typedef std::uint8_t OperandType;
+// C++ intrinsics-based variant of the deep, 4.25-bit, fast kernel.
+struct NEON_64bit_GEMM_Int425Operands_intrinsics {
+ typedef std::int8_t OperandType;
typedef std::int32_t AccumulatorType;
typedef KernelFormat<
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
- Format;
- static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
- AccumulatorType* accum_ptr, int depth) {
- asm volatile(
- // Load accumulators
- "ld.w $w0, (0*16)(%[accum_ptr])\n"
- "ld.w $w4, (1*16)(%[accum_ptr])\n"
- "ld.w $w8, (2*16)(%[accum_ptr])\n"
- "ld.w $w1, (3*16)(%[accum_ptr])\n"
- "ld.w $w5, (4*16)(%[accum_ptr])\n"
- "ld.w $w9, (5*16)(%[accum_ptr])\n"
- "ld.w $w2, (6*16)(%[accum_ptr])\n"
- "ld.w $w6, (7*16)(%[accum_ptr])\n"
- "ld.w $w10, (8*16)(%[accum_ptr])\n"
- "ld.w $w3, (9*16)(%[accum_ptr])\n"
- "ld.w $w7, (10*16)(%[accum_ptr])\n"
- "ld.w $w11, (11*16)(%[accum_ptr])\n"
- // Set a temp to all zeroes.
- "ldi.b $w19, 0\n"
-
- GEMMLOWP_LABEL_LOOP ":\n"
- // Overview of register layout:
- //
- // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register
- // contains 4 replicas of a pair of elements).
- // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14.
- // A 12x4 block of accumulators is stored in 32bit in w0-w11.
- //
- // +-----+-----+-----+-----+
- // Rhs | w15 | w16 | w17 | w18 |
- // +-----+-----+-----+-----+
- //
- // | | | | |
- //
- // Lhs | | | | |
- //
- // +---+ - - - - +-----+-----+-----+-----+
- // |w12| | w0 | w1 | w2 | w3 |
- // |w12| | w0 | w1 | w2 | w3 |
- // |w12| | w0 | w1 | w2 | w3 |
- // |w12| | w0 | w1 | w2 | w3 |
- // +---+ - - - - +-----+-----+-----+-----+
- // |w13| | w4 | w5 | w6 | w7 |
- // |w13| | w4 | w5 | w6 | w7 |
- // |w13| | w4 | w5 | w6 | w7 |
- // |w13| | w4 | w5 | w6 | w7 |
- // +---+ - - - - +-----+-----+-----+-----+
- // |w14| | w8 | w9 | w10 | w11 |
- // |w14| | w8 | w9 | w10 | w11 |
- // |w14| | w8 | w9 | w10 | w11 |
- // |w14| | w8 | w9 | w10 | w11 |
- // +---+ - - - - +-----+-----+-----+-----+
- //
- // Accumulators
-
- // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
- "ld.b $w12, 0(%[lhs_ptr])\n"
- "ld.b $w13, 8(%[lhs_ptr])\n"
-
- // Load 4 bytes of rhs[] for depth 0.
- "lbu $a0, 0(%[rhs_ptr])\n"
- "lbu $a1, 1(%[rhs_ptr])\n"
- "lbu $a2, 2(%[rhs_ptr])\n"
- "lbu $a3, 3(%[rhs_ptr])\n"
- // Load 4 bytes of rhs[] for depth 1.
- "lbu $v0, 4(%[rhs_ptr])\n"
- "lbu $v1, 5(%[rhs_ptr])\n"
- "lbu $t8, 6(%[rhs_ptr])\n"
- "lbu $t9, 7(%[rhs_ptr])\n"
-
- // Zero-extend 8-bit elements of lhs[] to 16 bits.
- "ilvr.b $w12, $w19, $w12\n"
- "ilvl.b $w14, $w19, $w13\n"
- "ilvr.b $w13, $w19, $w13\n"
- // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
- "ilvl.d $w15, $w19, $w12\n"
- "ilvl.d $w16, $w19, $w13\n"
- "ilvl.d $w17, $w19, $w14\n"
- "ilvr.h $w12, $w15, $w12\n"
- "ilvr.h $w13, $w16, $w13\n"
- "ilvr.h $w14, $w17, $w14\n"
-
- // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w.
- "ins $a0, $v0, 16, 8\n"
- "ins $a1, $v1, 16, 8\n"
- "ins $a2, $t8, 16, 8\n"
- "ins $a3, $t9, 16, 8\n"
- // Make 4 replicas of every pair of rhs[] elements.
- "fill.w $w15, $a0\n"
- "fill.w $w16, $a1\n"
- "fill.w $w17, $a2\n"
- "fill.w $w18, $a3\n"
-
- // Depths 0 and 1.
- // Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w0, $w12, $w15\n"
- "dpadd_u.w $w4, $w13, $w15\n"
- "dpadd_u.w $w8, $w14, $w15\n"
- "dpadd_u.w $w1, $w12, $w16\n"
- "dpadd_u.w $w5, $w13, $w16\n"
- "dpadd_u.w $w9, $w14, $w16\n"
- "dpadd_u.w $w2, $w12, $w17\n"
- "dpadd_u.w $w6, $w13, $w17\n"
- "dpadd_u.w $w10, $w14, $w17\n"
- "dpadd_u.w $w3, $w12, $w18\n"
- "dpadd_u.w $w7, $w13, $w18\n"
- "dpadd_u.w $w11, $w14, $w18\n"
-
- "addiu %[depth], -2\n"
- GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
- GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
- "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
-
- // Store accumulators.
- "st.w $w0, (0*16)(%[accum_ptr])\n"
- "st.w $w4, (1*16)(%[accum_ptr])\n"
- "st.w $w8, (2*16)(%[accum_ptr])\n"
- "st.w $w1, (3*16)(%[accum_ptr])\n"
- "st.w $w5, (4*16)(%[accum_ptr])\n"
- "st.w $w9, (5*16)(%[accum_ptr])\n"
- "st.w $w2, (6*16)(%[accum_ptr])\n"
- "st.w $w6, (7*16)(%[accum_ptr])\n"
- "st.w $w10, (8*16)(%[accum_ptr])\n"
- "st.w $w3, (9*16)(%[accum_ptr])\n"
- "st.w $w7, (10*16)(%[accum_ptr])\n"
- "st.w $w11, (11*16)(%[accum_ptr])\n"
- : // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [depth] "+r"(depth)
- : // inputs
- [accum_ptr] "r"(accum_ptr)
- : // clobbers
- "memory",
- "v0", "v1",
- "a0", "a1", "a2", "a3",
- "t8", "t9",
- "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
- "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
- "$f16", "$f17", "$f18", "$f19");
- }
-};
-
-// Using 32x32=32 multiplications.
-// 32 MSA regs used:
-// - 24 accumulators
-// - 6 lhs
-// - 1 rhs
-// - 1 temps/zeroes
-// ~95 instructions in the loop.
-struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics {
- typedef std::uint8_t OperandType;
- typedef std::uint32_t AccumulatorType;
- typedef KernelFormat<
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> >
Format;
static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
AccumulatorType* accum_ptr, int depth) {
- const v16i8 zeroes = __builtin_msa_ldi_b(0);
- v4i32 acc[3][8];
- // Load accumulators.
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 8; j++) {
- acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ int32x4_t acc[4][2];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ acc[i][j] = vdupq_n_s32(0);
}
}
- while (depth > 0) {
- // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
- v8i16 lhs[6];
- lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
- lhs[1] =
- reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
-
- // Zero-extend 8-bit elements of lhs[] to 16 bits.
- lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
- reinterpret_cast<v16i8>(lhs[0])));
- lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
- reinterpret_cast<v16i8>(lhs[1])));
- lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
- reinterpret_cast<v16i8>(lhs[1])));
-
- // Zero-extend 16-bit elements of lhs[] to 32 bits.
- lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
- lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
- lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
- lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
- lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
- lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
-
- // Depth 0.
- for (int j = 0; j < 4; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ const int num_outer_depth_loop = depth / 512 + 1;
+ int d = 0;
+ for (int od = 0; od < num_outer_depth_loop; od++) {
+ int16x8_t local_acc[4][2];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ local_acc[i][j] = vdupq_n_s16(0);
}
}
- for (int j = 4; j < 8; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ for (int k = 0; k < 16 && d <= depth - 32; k++, d += 32) {
+ int8x16_t lhs[8];
+ for (int i = 0; i < 8; i++) {
+ lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
}
- }
-
- // Depth 1.
- for (int j = 0; j < 4; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ int8x16_t rhs[4];
+ for (int i = 0; i < 4; i++) {
+ rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
}
- }
- for (int j = 4; j < 8; j++) {
- // Load 1 byte of rhs, making 4 32-bit replicas of it.
- v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8]));
- // Multiply-add into accumulators.
- for (int i = 0; i < 3; i++) {
- acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ int8x16_t temp_acc = vmulq_s8(lhs[i * 2], rhs[j * 2]);
+ temp_acc = vmlaq_s8(temp_acc, lhs[i * 2 + 1], rhs[j * 2 + 1]);
+ local_acc[i][j] = vpadalq_s8(local_acc[i][j], temp_acc);
+ }
}
+ lhs_ptr += 128;
+ rhs_ptr += 64;
}
- lhs_ptr += 24;
- rhs_ptr += 16;
- depth -= 2;
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 2; j++) {
+ acc[i][j] = vpadalq_s16(acc[i][j], local_acc[i][j]);
+ }
+ }
}
- // Store accumulators.
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 8; j++) {
- __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
- }
+ for (int i = 0; i < 2; i++) {
+ int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]);
+ int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]);
+ int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1);
+
+ int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i);
+ dst_val = vaddq_s32(dst_val, acc_4x);
+ vst1q_s32(accum_ptr + 4 * i, dst_val);
}
}
};
-// Assembly implementation of the above
-// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics.
-// Using 32x32=32 multiplications.
-// 32 MSA regs used:
-// - 24 accumulators
-// - 6 lhs
-// - 1 rhs
-// - 1 temps/zeroes
-// ~95 instructions in the loop.
-struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
+SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands_intrinsics);
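+// Note on the 512-depth outer chunking above: each 32-depth inner step adds
+// at most two int8 products (magnitude <= 2 * 127 = 254) into every int16
+// lane via vpadalq_s8, and at most 16 such steps run before vpadalq_s16
+// flushes into the int32 accumulators, so the int16 local accumulators stay
+// well within range (16 * 254 = 4064, far below 32767).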
+
+#endif // __arm__ || __aarch64__
+
+#ifdef __mips
+// 12x8 depth 2 depth-major kernel.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1 {
typedef std::uint8_t OperandType;
typedef std::uint32_t AccumulatorType;
typedef KernelFormat<
@@ -4419,36 +4949,37 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
GEMMLOWP_LABEL_LOOP ":\n"
// Overview of register layout:
//
- // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30.
- // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29.
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
// A 12x8 block of accumulators is stored in 32bit in w0-w23.
//
// +------+------+------+------+
- // Rhs |w30[0]|w30[1]|w30[2]|w30[3]|
+ // Rhs |w27 |w28 |w29 |w30 |
// +------+------+------+------+
//
// | | | | |
//
- // Lhs | | | | |
+ // Lhs | | | | |
//
- // +---+---+ - - - - +------+------+------+------+
- // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
- // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
- // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
- // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
- // +---+---+ - - - - +------+------+------+------+
- // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
- // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
- // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
- // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
- // +---+---+ - - - - +------+------+------+------+
- // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
- // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
- // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
- // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
- // +---+---+ - - - - +------+------+------+------+
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
//
- // Accumulator
+ // Accumulators
// Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
"ld.b $w24, 0(%[lhs_ptr])\n"
@@ -4459,100 +4990,88 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
"lbu $a1, 1(%[rhs_ptr])\n"
"lbu $a2, 2(%[rhs_ptr])\n"
"lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
// Zero-extend 8-bit elements of lhs[] to 16 bits.
"ilvr.b $w24, $w31, $w24\n"
"ilvl.b $w26, $w31, $w25\n"
"ilvr.b $w25, $w31, $w25\n"
- // Zero-extend 16-bit elements of lhs[] to 32 bits.
- "ilvl.h $w27, $w31, $w24\n"
- "ilvl.h $w28, $w31, $w25\n"
- "ilvl.h $w29, $w31, $w26\n"
- "ilvr.h $w24, $w31, $w24\n"
- "ilvr.h $w25, $w31, $w25\n"
- "ilvr.h $w26, $w31, $w26\n"
-
- // Depth 0.
- "fill.w $w30, $a0\n"
- "lbu $a0, 8(%[rhs_ptr])\n"
- "maddv.w $w0, $w24, $w30\n"
- "maddv.w $w4, $w25, $w30\n"
- "maddv.w $w8, $w26, $w30\n"
- "fill.w $w30, $a1\n"
- "lbu $a1, 9(%[rhs_ptr])\n"
- "maddv.w $w1, $w24, $w30\n"
- "maddv.w $w5, $w25, $w30\n"
- "maddv.w $w9, $w26, $w30\n"
- "fill.w $w30, $a2\n"
- "lbu $a2, 10(%[rhs_ptr])\n"
- "maddv.w $w2, $w24, $w30\n"
- "maddv.w $w6, $w25, $w30\n"
- "maddv.w $w10, $w26, $w30\n"
- "fill.w $w30, $a3\n"
- "lbu $a3, 11(%[rhs_ptr])\n"
- "maddv.w $w3, $w24, $w30\n"
- "maddv.w $w7, $w25, $w30\n"
- "maddv.w $w11, $w26, $w30\n"
-
- "fill.w $w30, $a0\n"
- "lbu $a0, 4(%[rhs_ptr])\n"
- "maddv.w $w12, $w24, $w30\n"
- "maddv.w $w16, $w25, $w30\n"
- "maddv.w $w20, $w26, $w30\n"
- "fill.w $w30, $a1\n"
- "lbu $a1, 5(%[rhs_ptr])\n"
- "maddv.w $w13, $w24, $w30\n"
- "maddv.w $w17, $w25, $w30\n"
- "maddv.w $w21, $w26, $w30\n"
- "fill.w $w30, $a2\n"
- "lbu $a2, 6(%[rhs_ptr])\n"
- "maddv.w $w14, $w24, $w30\n"
- "maddv.w $w18, $w25, $w30\n"
- "maddv.w $w22, $w26, $w30\n"
- "fill.w $w30, $a3\n"
- "lbu $a3, 7(%[rhs_ptr])\n"
- "maddv.w $w15, $w24, $w30\n"
- "maddv.w $w19, $w25, $w30\n"
- "maddv.w $w23, $w26, $w30\n"
-
- // Depth 1.
- "fill.w $w30, $a0\n"
- "lbu $a0, 12(%[rhs_ptr])\n"
- "maddv.w $w0, $w27, $w30\n"
- "maddv.w $w4, $w28, $w30\n"
- "maddv.w $w8, $w29, $w30\n"
- "fill.w $w30, $a1\n"
- "lbu $a1, 13(%[rhs_ptr])\n"
- "maddv.w $w1, $w27, $w30\n"
- "maddv.w $w5, $w28, $w30\n"
- "maddv.w $w9, $w29, $w30\n"
- "fill.w $w30, $a2\n"
- "lbu $a2, 14(%[rhs_ptr])\n"
- "maddv.w $w2, $w27, $w30\n"
- "maddv.w $w6, $w28, $w30\n"
- "maddv.w $w10, $w29, $w30\n"
- "fill.w $w30, $a3\n"
- "lbu $a3, 15(%[rhs_ptr])\n"
- "maddv.w $w3, $w27, $w30\n"
- "maddv.w $w7, $w28, $w30\n"
- "maddv.w $w11, $w29, $w30\n"
-
- "fill.w $w30, $a0\n"
- "maddv.w $w12, $w27, $w30\n"
- "maddv.w $w16, $w28, $w30\n"
- "maddv.w $w20, $w29, $w30\n"
- "fill.w $w30, $a1\n"
- "maddv.w $w13, $w27, $w30\n"
- "maddv.w $w17, $w28, $w30\n"
- "maddv.w $w21, $w29, $w30\n"
- "fill.w $w30, $a2\n"
- "maddv.w $w14, $w27, $w30\n"
- "maddv.w $w18, $w28, $w30\n"
- "maddv.w $w22, $w29, $w30\n"
- "fill.w $w30, $a3\n"
- "maddv.w $w15, $w27, $w30\n"
- "maddv.w $w19, $w28, $w30\n"
- "maddv.w $w23, $w29, $w30\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
"addiu %[depth], -2\n"
GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
@@ -4591,7 +5110,9 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
[accum_ptr] "r"(accum_ptr)
: // clobbers
"memory",
+ "v0", "v1",
"a0", "a1", "a2", "a3",
+ "t8", "t9",
"$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
"$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
"$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
@@ -4599,21 +5120,14 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
}
};
-// Assembly implementation of the above
-// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
-// Using 16x16=32 multiplications.
-// 32 MSA regs used:
-// - 24 accumulators
-// - 3 lhs
-// - 4 rhs
-// - 1 temps/zeroes
-// ~70 instructions in the loop.
-struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
+// 12x8 depth 2 width-major kernel.
+// Does less shuffling and replication than the kernel above.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2 {
typedef std::uint8_t OperandType;
typedef std::uint32_t AccumulatorType;
typedef KernelFormat<
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
Format;
static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
AccumulatorType* accum_ptr, int depth) {
@@ -4643,19 +5157,18 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
"ld.w $w15, (21*16)(%[accum_ptr])\n"
"ld.w $w19, (22*16)(%[accum_ptr])\n"
"ld.w $w23, (23*16)(%[accum_ptr])\n"
- // Set a temp to all zeroes.
- "ldi.b $w31, 0\n"
- GEMMLOWP_LABEL_LOOP ":\n"
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
// Overview of register layout:
//
- // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
// (each register contains 4 replicas of a pair of elements).
// A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
// A 12x8 block of accumulators is stored in 32bit in w0-w23.
//
// +------+------+------+------+
- // Rhs |w27 |w28 |w29 |w30 |
+ // Rhs |w28 |w29 |w30 |w31 |
// +------+------+------+------+
//
// | | | | |
@@ -4685,98 +5198,65 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
"ld.b $w24, 0(%[lhs_ptr])\n"
"ld.b $w25, 8(%[lhs_ptr])\n"
- // Load 4 bytes of rhs[] for the first half of depth 0.
- "lbu $a0, 0(%[rhs_ptr])\n"
- "lbu $a1, 1(%[rhs_ptr])\n"
- "lbu $a2, 2(%[rhs_ptr])\n"
- "lbu $a3, 3(%[rhs_ptr])\n"
- // Load 4 bytes of rhs[] for the first half of depth 1.
- "lbu $v0, 4(%[rhs_ptr])\n"
- "lbu $v1, 5(%[rhs_ptr])\n"
- "lbu $t8, 6(%[rhs_ptr])\n"
- "lbu $t9, 7(%[rhs_ptr])\n"
+ // Load 2 x 8 bytes of rhs[].
+ "ld.b $w27, 0(%[rhs_ptr])\n"
// Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ldi.b $w31, 0\n"
"ilvr.b $w24, $w31, $w24\n"
"ilvl.b $w26, $w31, $w25\n"
"ilvr.b $w25, $w31, $w25\n"
- // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
- "ilvl.d $w27, $w31, $w24\n"
- "ilvl.d $w28, $w31, $w25\n"
- "ilvl.d $w29, $w31, $w26\n"
- "ilvr.h $w24, $w27, $w24\n"
- "ilvr.h $w25, $w28, $w25\n"
- "ilvr.h $w26, $w29, $w26\n"
-
- // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
- // (for the first half).
- "ins $a0, $v0, 16, 8\n"
- "ins $a1, $v1, 16, 8\n"
- "ins $a2, $t8, 16, 8\n"
- "ins $a3, $t9, 16, 8\n"
- // Make 4 replicas of every pair of rhs[] elements.
- "fill.w $w27, $a0\n"
- "fill.w $w28, $a1\n"
- "fill.w $w29, $a2\n"
- "fill.w $w30, $a3\n"
-
- // Load 4 bytes of rhs[] for the second half of depth 0.
- "lbu $a0, 8(%[rhs_ptr])\n"
- "lbu $a1, 9(%[rhs_ptr])\n"
- "lbu $a2, 10(%[rhs_ptr])\n"
- "lbu $a3, 11(%[rhs_ptr])\n"
- // Load 4 bytes of rhs[] for the second half of depth 1.
- "lbu $v0, 12(%[rhs_ptr])\n"
- "lbu $v1, 13(%[rhs_ptr])\n"
- "lbu $t8, 14(%[rhs_ptr])\n"
- "lbu $t9, 15(%[rhs_ptr])\n"
// First half of depths 0 and 1.
- // Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w0, $w24, $w27\n"
- "dpadd_u.w $w4, $w25, $w27\n"
- "dpadd_u.w $w8, $w26, $w27\n"
- "dpadd_u.w $w1, $w24, $w28\n"
- "dpadd_u.w $w5, $w25, $w28\n"
- "dpadd_u.w $w9, $w26, $w28\n"
- "dpadd_u.w $w2, $w24, $w29\n"
- "dpadd_u.w $w6, $w25, $w29\n"
- "dpadd_u.w $w10, $w26, $w29\n"
- "dpadd_u.w $w3, $w24, $w30\n"
- "dpadd_u.w $w7, $w25, $w30\n"
- "dpadd_u.w $w11, $w26, $w30\n"
-
- // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
- // (for the second half).
- "ins $a0, $v0, 16, 8\n"
- "ins $a1, $v1, 16, 8\n"
- "ins $a2, $t8, 16, 8\n"
- "ins $a3, $t9, 16, 8\n"
+ // Zero-extend 8-bit elements of rhs[] to 16 bits.
+ "ilvr.b $w31, $w31, $w27\n"
// Make 4 replicas of every pair of rhs[] elements.
- "fill.w $w27, $a0\n"
- "fill.w $w28, $a1\n"
- "fill.w $w29, $a2\n"
- "fill.w $w30, $a3\n"
+ "splati.w $w28, $w31[0]\n"
+ "splati.w $w29, $w31[1]\n"
+ "splati.w $w30, $w31[2]\n"
+ "splati.w $w31, $w31[3]\n"
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w28\n"
+ "dpadd_u.w $w4, $w25, $w28\n"
+ "dpadd_u.w $w8, $w26, $w28\n"
+ "dpadd_u.w $w1, $w24, $w29\n"
+ "dpadd_u.w $w5, $w25, $w29\n"
+ "dpadd_u.w $w9, $w26, $w29\n"
+ "dpadd_u.w $w2, $w24, $w30\n"
+ "dpadd_u.w $w6, $w25, $w30\n"
+ "dpadd_u.w $w10, $w26, $w30\n"
+ "dpadd_u.w $w3, $w24, $w31\n"
+ "dpadd_u.w $w7, $w25, $w31\n"
+ "dpadd_u.w $w11, $w26, $w31\n"
// Second half of depths 0 and 1.
+ // Zero-extend 8-bit elements of rhs[] to 16 bits.
+ "ldi.b $w31, 0\n"
+ "ilvl.b $w31, $w31, $w27\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "splati.w $w28, $w31[0]\n"
+ "splati.w $w29, $w31[1]\n"
+ "splati.w $w30, $w31[2]\n"
+ "splati.w $w31, $w31[3]\n"
// Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w12, $w24, $w27\n"
- "dpadd_u.w $w16, $w25, $w27\n"
- "dpadd_u.w $w20, $w26, $w27\n"
- "dpadd_u.w $w13, $w24, $w28\n"
- "dpadd_u.w $w17, $w25, $w28\n"
- "dpadd_u.w $w21, $w26, $w28\n"
- "dpadd_u.w $w14, $w24, $w29\n"
- "dpadd_u.w $w18, $w25, $w29\n"
- "dpadd_u.w $w22, $w26, $w29\n"
- "dpadd_u.w $w15, $w24, $w30\n"
- "dpadd_u.w $w19, $w25, $w30\n"
- "dpadd_u.w $w23, $w26, $w30\n"
-
- "addiu %[depth], -2\n"
- GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
- GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
- "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+ "dpadd_u.w $w12, $w24, $w28\n"
+ "dpadd_u.w $w16, $w25, $w28\n"
+ "dpadd_u.w $w20, $w26, $w28\n"
+ "dpadd_u.w $w13, $w24, $w29\n"
+ "dpadd_u.w $w17, $w25, $w29\n"
+ "dpadd_u.w $w21, $w26, $w29\n"
+ "dpadd_u.w $w14, $w24, $w30\n"
+ "dpadd_u.w $w18, $w25, $w30\n"
+ "dpadd_u.w $w22, $w26, $w30\n"
+ "dpadd_u.w $w15, $w24, $w31\n"
+ "dpadd_u.w $w19, $w25, $w31\n"
+ "dpadd_u.w $w23, $w26, $w31\n"
+
+ "addiu %[depth], -2\n" GEMMLOWP_MIPS_XADDIU
+ " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
+ " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP
+ "b\n"
// Store accumulators.
"st.w $w0, (0*16)(%[accum_ptr])\n"
@@ -4809,14 +5289,304 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
: // inputs
[accum_ptr] "r"(accum_ptr)
: // clobbers
- "memory",
- "v0", "v1",
- "a0", "a1", "a2", "a3",
- "t8", "t9",
- "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
- "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
- "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
- "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
+ "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
+ "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
+ "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+
+// 4x4 depth 16 width-major kernel operating on int8 operands.
+// It is assumed that one of the two int8 operands only takes values
+// in [-127, 127], while the other may freely range in [-128, 127].
+// The issue with both operands taking the value -128 is that:
+// -128*-128 + -128*-128 == -32768 overflows int16.
+// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
+// range. That is the basic idea of this kernel.
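+// (With one operand restricted to [-127, 127], the worst case is
+// 127*(-128) + 127*(-128) = -32512, which still fits in int16.)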
+struct MSA_GEMM_Int8Operands_AccumTwoWithin16Bits {
+ typedef std::int8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ std::size_t start_depth = 123;
+ std::size_t run_depth = depth;
+ std::size_t dst_col_stride = 4;
+ AccumulatorType* dst_ptr = accum_ptr;
+#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
+#define GEMMLOWP_LABEL_LOOP "2"
+#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
+#define GEMMLOWP_LABEL_STORE "4"
+ asm volatile(
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+ // Load lhs[] and rhs[], zero out internal accumulators.
+ "ld.b $w16, 0(%[lhs_ptr])\n"
+ "ldi.b $w0, 0\n"
+ "ld.b $w20, 0(%[rhs_ptr])\n"
+ "ldi.b $w1, 0\n"
+ "ld.b $w17, 16(%[lhs_ptr])\n"
+ "ldi.b $w2, 0\n"
+ "ld.b $w21, 16(%[rhs_ptr])\n"
+ "ldi.b $w3, 0\n"
+ "ld.b $w18, 32(%[lhs_ptr])\n"
+ "ldi.b $w4, 0\n"
+ "ld.b $w19, 48(%[lhs_ptr])\n"
+ "ldi.b $w5, 0\n"
+ "ld.b $w22, 32(%[rhs_ptr])\n"
+ "ldi.b $w6, 0\n"
+ "ld.b $w23, 48(%[rhs_ptr])\n"
+ "ldi.b $w7, 0\n"
+ "ldi.b $w8, 0\n"
+ "ldi.b $w9, 0\n"
+ "ldi.b $w10, 0\n"
+ "ldi.b $w11, 0\n"
+ "ldi.b $w12, 0\n"
+ "ldi.b $w13, 0\n"
+ "ldi.b $w14, 0\n"
+ "ldi.b $w15, 0\n"
+ "ldi.h $w31, 1\n"
+ // If the loop depth is only 16, then we can skip the general loop
+ // and go straight to the final part of the code.
+ "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+      // A 4x16 block of Lhs is stored in 8 bit in w16-w19.
+      // A 4x16 block of Rhs is stored in 8 bit in w20-w23.
+ //
+ // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
+ // components which need to be horizontally added at the end).
+ //
+ // Dot products of Lhs and Rhs are 16-bit values, which can't
+ // immediately be accumulated in 32-bit accumulators by that
+ // same instruction that calculates them.
+ // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
+      // sums in w25 (note, the 16 products have already been reduced to
+      // 8 sums by the horizontal addition of the dotp instruction).
+ // They are then sign-extended to 32 bits, horizontally added
+ // (again) to form 4 32-bit sums and then they are finally added
+ // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
+ //
+ // +-----+-----+-----+-----+
+ // Rhs | w20 | w21 | w22 | w23 |
+ // +-----+-----+-----+-----+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w16| | w0 | w4 | w8 | w12 |
+ // |w17| | w1 | w5 | w9 | w13 |
+ // |w18| | w2 | w6 | w10 | w14 |
+ // |w19| | w3 | w7 | w11 | w15 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ //
+ // Accumulators
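+      //
+      // Scalar model of the dotp_s.h / dpadd_s.w pair described above
+      // (illustrative only; w31 holds all-ones halfwords):
+      //   dotp_s.h  wd, wa, wb  : h16[i]  = a8[2i]*b8[2i] + a8[2i+1]*b8[2i+1]
+      //   dpadd_s.w wd, wh, w31 : w32[i] += h16[2i] + h16[2i+1]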
+
+ // Calculate the results for 16 depths and load
+ // lhs[] and rhs[] for the next iteration.
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w16, $w20\n"
+ "dotp_s.h $w26, $w17, $w20\n"
+ "dotp_s.h $w27, $w16, $w21\n"
+ "dotp_s.h $w28, $w17, $w21\n"
+ "dotp_s.h $w29, $w18, $w20\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w0, $w25, $w31\n"
+ "dpadd_s.w $w1, $w26, $w31\n"
+ "dpadd_s.w $w4, $w27, $w31\n"
+ "dpadd_s.w $w5, $w28, $w31\n"
+ "dpadd_s.w $w2, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w24, $w16, $w22\n"
+ "dotp_s.h $w25, $w19, $w20\n"
+ "dotp_s.h $w26, $w16, $w23\n"
+ "dotp_s.h $w27, $w17, $w22\n"
+ "ld.b $w20, 0(%[rhs_ptr])\n"
+ "dotp_s.h $w28, $w17, $w23\n"
+ "ld.b $w16, 0(%[lhs_ptr])\n"
+ "dotp_s.h $w29, $w18, $w21\n"
+ "ld.b $w17, 16(%[lhs_ptr])\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w8, $w24, $w31\n"
+ "dpadd_s.w $w3, $w25, $w31\n"
+ "dpadd_s.w $w12, $w26, $w31\n"
+ "dpadd_s.w $w9, $w27, $w31\n"
+ "dpadd_s.w $w13, $w28, $w31\n"
+ "dpadd_s.w $w6, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w19, $w21\n"
+ "dotp_s.h $w26, $w18, $w22\n"
+ "dotp_s.h $w27, $w18, $w23\n"
+ "ld.b $w21, 16(%[rhs_ptr])\n"
+ "dotp_s.h $w28, $w19, $w22\n"
+ "ld.b $w18, 32(%[lhs_ptr])\n"
+ "dotp_s.h $w29, $w19, $w23\n"
+ "ld.b $w22, 32(%[rhs_ptr])\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w7, $w25, $w31\n"
+ "ld.b $w19, 48(%[lhs_ptr])\n"
+ "dpadd_s.w $w10, $w26, $w31\n"
+ "ld.b $w23, 48(%[rhs_ptr])\n"
+ "dpadd_s.w $w14, $w27, $w31\n"
+ "dpadd_s.w $w11, $w28, $w31\n"
+ "dpadd_s.w $w15, $w29, $w31\n"
+
+ "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"
+
+ GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
+ // Calculate the results for the last 16 depths.
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w16, $w20\n"
+ "dotp_s.h $w26, $w17, $w20\n"
+ "dotp_s.h $w27, $w16, $w21\n"
+ "dotp_s.h $w28, $w17, $w21\n"
+ "dotp_s.h $w29, $w18, $w20\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w0, $w25, $w31\n"
+ "dpadd_s.w $w1, $w26, $w31\n"
+ "dpadd_s.w $w4, $w27, $w31\n"
+ "dpadd_s.w $w5, $w28, $w31\n"
+ "dpadd_s.w $w2, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w24, $w16, $w22\n"
+ "dotp_s.h $w25, $w19, $w20\n"
+ "dotp_s.h $w26, $w16, $w23\n"
+ "dotp_s.h $w27, $w17, $w22\n"
+ "dotp_s.h $w28, $w17, $w23\n"
+ "dotp_s.h $w29, $w18, $w21\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w8, $w24, $w31\n"
+ "dpadd_s.w $w3, $w25, $w31\n"
+ "dpadd_s.w $w12, $w26, $w31\n"
+ "dpadd_s.w $w9, $w27, $w31\n"
+ "dpadd_s.w $w13, $w28, $w31\n"
+ "dpadd_s.w $w6, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w19, $w21\n"
+ "dotp_s.h $w26, $w18, $w22\n"
+ "dotp_s.h $w27, $w18, $w23\n"
+ "dotp_s.h $w28, $w19, $w22\n"
+ "dotp_s.h $w29, $w19, $w23\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w7, $w25, $w31\n"
+ "dpadd_s.w $w10, $w26, $w31\n"
+ "dpadd_s.w $w14, $w27, $w31\n"
+ "dpadd_s.w $w11, $w28, $w31\n"
+ "dpadd_s.w $w15, $w29, $w31\n"
+
+ // Horizontal-add internal accumulators.
+ "hadd_s.d $w0, $w0, $w0\n"
+ "hadd_s.d $w1, $w1, $w1\n"
+ "hadd_s.d $w2, $w2, $w2\n"
+ "hadd_s.d $w3, $w3, $w3\n"
+ "hadd_s.d $w4, $w4, $w4\n"
+ "hadd_s.d $w5, $w5, $w5\n"
+ "hadd_s.d $w6, $w6, $w6\n"
+ "hadd_s.d $w7, $w7, $w7\n"
+ "hadd_s.d $w8, $w8, $w8\n"
+ "hadd_s.d $w9, $w9, $w9\n"
+ "hadd_s.d $w10, $w10, $w10\n"
+ "hadd_s.d $w11, $w11, $w11\n"
+ "hadd_s.d $w12, $w12, $w12\n"
+ "hadd_s.d $w13, $w13, $w13\n"
+ "hadd_s.d $w14, $w14, $w14\n"
+ "hadd_s.d $w15, $w15, $w15\n"
+ "pckev.w $w0, $w1, $w0\n"
+ "pckev.w $w2, $w3, $w2\n"
+ "pckev.w $w4, $w5, $w4\n"
+ "pckev.w $w6, $w7, $w6\n"
+ "pckev.w $w8, $w9, $w8\n"
+ "pckev.w $w10, $w11, $w10\n"
+ "pckev.w $w12, $w13, $w12\n"
+ "pckev.w $w14, $w15, $w14\n"
+ "hadd_s.d $w0, $w0, $w0\n"
+ "hadd_s.d $w2, $w2, $w2\n"
+ "hadd_s.d $w4, $w4, $w4\n"
+ "hadd_s.d $w6, $w6, $w6\n"
+ "hadd_s.d $w8, $w8, $w8\n"
+ "hadd_s.d $w10, $w10, $w10\n"
+ "hadd_s.d $w12, $w12, $w12\n"
+ "hadd_s.d $w14, $w14, $w14\n"
+ // 4 more pckev instructions follow in both paths below.
+
+ // Check if start_depth==0 to decide whether we will load
+ // existing accumulators from memory.
+ "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"
+
+ "pckev.w $w0, $w2, $w0\n"
+ "pckev.w $w1, $w6, $w4\n"
+ "pckev.w $w2, $w10, $w8\n"
+ "pckev.w $w3, $w14, $w12\n"
+
+ "b " GEMMLOWP_LABEL_STORE "f\n"
+
+ GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
+ // Load accumulators from memory.
+ "ld.w $w16, 0(%[dst_ptr0])\n"
+ "pckev.w $w0, $w2, $w0\n"
+ "ld.w $w17, 0(%[dst_ptr1])\n"
+ "pckev.w $w1, $w6, $w4\n"
+ "ld.w $w18, 0(%[dst_ptr2])\n"
+ "pckev.w $w2, $w10, $w8\n"
+ "ld.w $w19, 0(%[dst_ptr3])\n"
+ "pckev.w $w3, $w14, $w12\n"
+
+ // Add them to internal accumulators.
+ "addv.w $w0, $w0, $w16\n"
+ "addv.w $w1, $w1, $w17\n"
+ "addv.w $w2, $w2, $w18\n"
+ "addv.w $w3, $w3, $w19\n"
+
+ GEMMLOWP_LABEL_STORE ":\n"
+ // Store accumulators.
+ "st.w $w0, 0(%[dst_ptr0])\n"
+ "st.w $w1, 0(%[dst_ptr1])\n"
+ "st.w $w2, 0(%[dst_ptr2])\n"
+ "st.w $w3, 0(%[dst_ptr3])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [run_depth] "+r"(run_depth)
+ : // inputs
+ [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
+ [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
+ [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
+ "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
+ "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
+ "$f27", "$f28", "$f29", "$f30", "$f31");
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
+#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+#undef GEMMLOWP_LABEL_STORE
}
};
#endif // __mips
@@ -4901,13 +5671,10 @@ class CacheLineAlignedBuffer {
};
template <typename DataType>
-void FillRandom(CacheLineAlignedBuffer<DataType>* buffer) {
+void FillRandom(CacheLineAlignedBuffer<DataType>* buffer, DataType min,
+ DataType max) {
static std::mt19937 generator(0);
- // 100 is smaller than any nonzero bound of the range of any data type.
- const DataType kMaxVal = DataType(100);
- const DataType kMinVal =
- std::is_signed<DataType>::value ? -kMaxVal : DataType(0);
- std::uniform_real_distribution<float> dist(kMinVal, kMaxVal);
+ std::uniform_real_distribution<float> dist(min, max);
for (std::size_t i = 0; i < buffer->size(); i++) {
buffer->data()[i] = DataType(dist(generator));
}
@@ -4971,9 +5738,16 @@ void test_kernel(int depth, const char* kernel_name) {
CacheLineAlignedBuffer<AccumulatorType> accum_reference(kLhsWidth *
kRhsWidth);
- FillRandom(&lhs);
- FillRandom(&rhs);
- FillRandom(&accum_initial);
+ FillRandom(&lhs, KernelOperandRanges<Kernel>::LhsMin(),
+ KernelOperandRanges<Kernel>::LhsMax());
+ FillRandom(&rhs, KernelOperandRanges<Kernel>::RhsMin(),
+ KernelOperandRanges<Kernel>::RhsMax());
+ FillRandom(&accum_initial,
+ std::is_signed<AccumulatorType>::value
+ ? AccumulatorType(-100)
+ : AccumulatorType(0),
+ AccumulatorType(100));
+
Copy(&accum, accum_initial);
Copy(&accum_reference, accum_initial);
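
KernelOperandRanges<Kernel> above supplies per-kernel operand bounds for FillRandom; its definition is outside this hunk. A hypothetical traits sketch inferred only from the call sites (LhsMin/LhsMax/RhsMin/RhsMax), shown for a kernel with plain uint8 operands; the benchmark's actual class may look different:

#include <cstdint>

// Illustrative stand-in, not the benchmark's real KernelOperandRanges.
template <typename Kernel, typename OperandType = std::uint8_t>
struct ExampleOperandRanges {
  static OperandType LhsMin() { return 0; }
  static OperandType LhsMax() { return 255; }
  static OperandType RhsMin() { return 0; }
  static OperandType RhsMax() { return 255; }
};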
@@ -5159,6 +5933,10 @@ int main() {
#endif
#ifdef __aarch64__
+ BENCHMARK(NEON_64bit_GEMM_Int425Operands);
+ BENCHMARK(NEON_64bit_GEMM_Int425Operands_intrinsics);
+ BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits);
+ BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
@@ -5167,6 +5945,7 @@ int main() {
#ifdef __ARM_FEATURE_DOTPROD
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow);
#endif
BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
@@ -5180,12 +5959,9 @@ int main() {
#endif
#ifdef __mips
- BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics);
- BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly);
- BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2);
- BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics);
- BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly);
- BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2);
+ BENCHMARK(MSA_GEMM_Int8Operands_AccumTwoWithin16Bits);
#endif
return 0;
diff --git a/test/benchmark.cc b/test/benchmark.cc
index 9a87a41..d8236de 100644
--- a/test/benchmark.cc
+++ b/test/benchmark.cc
@@ -36,7 +36,16 @@
#warning "Building without NEON support on ARM, check your compiler setup!"
#endif
-#if defined(__SSE4_2__) && !defined(GEMMLOWP_SSE4)
+#if defined(__mips) && !defined(GEMMLOWP_MSA)
+#warning "Building without MSA support on MIPS, check your compiler setup!"
+#endif
+
+#if defined(__AVX2__) && !defined(GEMMLOWP_AVX2)
+#warning \
+ "Building without AVX2 support on AVX2 enabled machine, check your compiler setup!"
+#endif
+
+#if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4)
#warning \
"Building without SSE4.2 support on SSE4.2 enabled machine, check your compiler setup!"
#endif
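
The warnings above only flag a mismatch between the compiler's target features and the gemmlowp paths that got enabled. A project that requires a particular SIMD path could turn the same check into a hard error; a short sketch using the same macros:

// Sketch: same macros as above, but failing the build instead of warning.
#if defined(__AVX2__) && !defined(GEMMLOWP_AVX2)
#error "gemmlowp AVX2 path expected to be enabled on this AVX2 build"
#endif
#if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4)
#error "gemmlowp SSE4 path expected to be enabled on this SSE4.2 build"
#endif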
diff --git a/test/benchmark_all_sizes.cc b/test/benchmark_all_sizes.cc
index 16cc57c..527aad6 100644
--- a/test/benchmark_all_sizes.cc
+++ b/test/benchmark_all_sizes.cc
@@ -16,6 +16,10 @@ test/benchmark_all_sizes.cc -o /tmp/b -O3 --std=c++11 -fPIE -static \
#include "../public/gemmlowp.h"
+#ifdef GEMMLOWP_PROFILING
+#include "../profiling/profiler.h"
+#endif
+
#if defined GEMMLOWP_ANDROID && defined GEMMLOWP_ARM_32
// Compilation workaround
namespace std {
@@ -122,10 +126,10 @@ float benchmark_8bit(int rows, int depth, int cols) {
MakeZero(&rhs);
MakeZero(&result);
- typedef std::tuple<OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ typedef std::tuple<OutputStageQuantizeDownInt32ByFixedPoint,
OutputStageSaturatingCastToUint8>
Pipeline;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = 128;
quantize_down_stage.result_fixedpoint_multiplier = 1234567890;
@@ -345,7 +349,18 @@ void run_benchmarks(std::map<Shape, float>* results) {
int main() {
std::map<Shape, float> results;
+
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::RegisterCurrentThreadForProfiling();
+ gemmlowp::StartProfiling();
+#endif
+
run_benchmarks(&results);
+
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::FinishProfiling();
+#endif
+
printf("Using %d thread(s)\n", kNumThreads);
printf("depth,rows,cols,latency(s),Gop/s\n");
for (const auto& result : results) {
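
The hunk above brackets the benchmark run with gemmlowp's profiler when GEMMLOWP_PROFILING is defined. A minimal sketch of the same pattern around an arbitrary workload, assuming only the calls already used above (RunWorkload is a hypothetical placeholder):

#ifdef GEMMLOWP_PROFILING
#include "../profiling/profiler.h"
#endif

void RunWorkload();  // hypothetical: whatever gemmlowp work is being profiled

void ProfiledRun() {
#ifdef GEMMLOWP_PROFILING
  // Same profiler calls as in the benchmark above.
  gemmlowp::RegisterCurrentThreadForProfiling();
  gemmlowp::StartProfiling();
#endif
  RunWorkload();
#ifdef GEMMLOWP_PROFILING
  gemmlowp::FinishProfiling();
#endif
}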
diff --git a/test/test.cc b/test/test.cc
index eee16b4..735ad1e 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1277,6 +1277,47 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
}
}
+ // Test a variant of the familiar default pipeline consisting of quantize-down
+ // and clamp-and-cast-to-int16.
+ OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
+ auto quantize_down_and_saturating_cast_int16_pipeline =
+ std::make_tuple(quantize_down_stage, saturating_cast_int16_stage);
+ Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows,
+ cols);
+ GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_quantized_down_saturated_int16, lhs_offset, rhs_offset,
+ quantize_down_and_saturating_cast_int16_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ std::int32_t quantized = result_quantized_down_int32(r, c);
+ std::int16_t expected = std::min(std::max(quantized, -32768), 32767);
+ Check(expected == result_quantized_down_saturated_int16(r, c));
+ }
+ }
+
+#ifdef GEMMLOWP_MSA
+ // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8.
+ OutputStageTruncatingCastToUint8 truncating_cast_stage;
+ auto quantize_down_and_truncating_cast_pipeline =
+ std::make_tuple(quantize_down_stage, truncating_cast_stage);
+ Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8(
+ rows, cols);
+ GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset,
+ quantize_down_and_truncating_cast_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ std::int32_t quantized = result_quantized_down_int32(r, c);
+ std::uint8_t expected = quantized & 255;
+ Check(expected == result_quantized_down_truncated_uint8(r, c));
+ }
+ }
+#endif
+
// Test a bias-addition with row-vector
std::vector<std::int32_t> row_vector_data(cols);
std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
@@ -1428,8 +1469,8 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
result_fixedpoint_shift++;
}
Check(result_fixedpoint_shift >= 0);
- // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
- OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ // Now test OutputStageQuantizeDownInt32ByFixedPoint
+ OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_by_fixedpoint_stage;
quantize_down_by_fixedpoint_stage.result_offset_after_shift =
static_cast<std::int32_t>(
@@ -1447,7 +1488,6 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
&result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
quantize_down_by_fixedpoint_pipeline);
- std::vector<std::int32_t> diffs_caused_by_fixedpoint;
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
const std::int32_t actual =
@@ -1462,6 +1502,44 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset,
}
}
+ // Test OutputStageScaleInt32ByFixedPointAndExponent
+ for (int exponent = -2; exponent <= 2; exponent++) {
+ OutputStageScaleInt32ByFixedPointAndExponent
+ scale_by_fixedpoint_and_exponent_stage;
+ scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift =
+ static_cast<std::int32_t>(round(static_cast<double>(
+ result_offset * result_mult_int * std::pow(2.0, exponent))));
+ scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier =
+ result_fixedpoint_multiplier;
+ scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent;
+ auto scale_by_fixedpoint_and_exponent_pipeline =
+ std::make_tuple(scale_by_fixedpoint_and_exponent_stage);
+ Matrix<std::int32_t, ResultOrder>
+ result_scaled_by_fixedpoint_and_exponent_int32(rows, cols);
+ GemmWithOutputPipeline<std::uint8_t, std::int32_t,
+ DefaultL8R8BitDepthParams>(
+ &context, lhs.const_map(), rhs.const_map(),
+ &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset,
+ scale_by_fixedpoint_and_exponent_pipeline);
+
+ for (int r = 0; r < rows; r++) {
+ for (int c = 0; c < cols; c++) {
+ const std::int32_t actual =
+ result_scaled_by_fixedpoint_and_exponent_int32(r, c);
+ const std::int32_t raw = result_raw_int32(r, c);
+ int left_shift = std::max(0, exponent);
+ int right_shift = std::max(0, -exponent);
+ const std::int32_t expected =
+ scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift +
+ RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
+ result_fixedpoint_multiplier),
+ right_shift);
+ Check(actual == expected);
+ }
+ }
+ }
+
// Test the variant of the familiar default pipeline consisting of
// quantize-down and
// clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the
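
For the OutputStageScaleInt32ByFixedPointAndExponent check above, the expected value for a single raw accumulator can be written as a small helper. A sketch using the same fixedpoint primitives the test already relies on; the function and parameter names are illustrative:

#include <algorithm>
#include <cstdint>
#include "../fixedpoint/fixedpoint.h"

// Sketch of the reference computation checked above, for one int32 value.
std::int32_t ScaleByFixedPointAndExponent(std::int32_t raw,
                                          std::int32_t fixedpoint_multiplier,
                                          int exponent,
                                          std::int32_t offset_after_shift) {
  const int left_shift = std::max(0, exponent);
  const int right_shift = std::max(0, -exponent);
  return offset_after_shift +
         gemmlowp::RoundingDivideByPOT(
             gemmlowp::SaturatingRoundingDoublingHighMul(
                 (1 << left_shift) * raw, fixedpoint_multiplier),
             right_shift);
}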
diff --git a/test/test.h b/test/test.h
index aecd0c1..b381bad 100644
--- a/test/test.h
+++ b/test/test.h
@@ -49,7 +49,7 @@ class Matrix : public MatrixMap<tScalar, tOrder> {
typedef MatrixMap<tScalar, tOrder> Map;
typedef MatrixMap<const tScalar, tOrder> ConstMap;
typedef typename Map::Scalar Scalar;
- static const MapOrder Order = tOrder;
+ static constexpr MapOrder Order = tOrder;
using Map::kOrder;
using Map::rows_;
using Map::cols_;
@@ -92,12 +92,12 @@ class Matrix : public MatrixMap<tScalar, tOrder> {
std::vector<Scalar> storage;
};
-std::mt19937& RandomEngine() {
+inline std::mt19937& RandomEngine() {
static std::mt19937 engine;
return engine;
}
-int Random() {
+inline int Random() {
std::uniform_int_distribution<int> dist(0, std::numeric_limits<int>::max());
return dist(RandomEngine());
}
diff --git a/test/test_blocking_counter.cc b/test/test_blocking_counter.cc
index d1e0932..34d963d 100644
--- a/test/test_blocking_counter.cc
+++ b/test/test_blocking_counter.cc
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test.h"
-#include "../profiling/pthread_everywhere.h"
-
+#include <atomic> // NOLINT
#include <vector>
+#include <iostream>
+#include <cstdlib>
#include "../internal/multi_thread_gemm.h"
+#include "../profiling/pthread_everywhere.h"
+#include "test.h"
namespace gemmlowp {
@@ -26,16 +28,36 @@ class Thread {
Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
: blocking_counter_(blocking_counter),
number_of_times_to_decrement_(number_of_times_to_decrement),
- finished_(false),
- made_the_last_decrement_(false) {
+ made_the_last_decrement_(false),
+ finished_(false) {
+#if defined GEMMLOWP_USE_PTHREAD
+ // Limit the stack size so as not to deplete memory when creating
+ // many threads.
+ pthread_attr_t attr;
+ int err = pthread_attr_init(&attr);
+ if (!err) {
+ size_t stack_size;
+ err = pthread_attr_getstacksize(&attr, &stack_size);
+ if (!err && stack_size > max_stack_size_) {
+ err = pthread_attr_setstacksize(&attr, max_stack_size_);
+ }
+ if (!err) {
+ err = pthread_create(&thread_, &attr, ThreadFunc, this);
+ }
+ }
+ if (err) {
+ std::cerr << "Failed to create a thread.\n";
+ std::abort();
+ }
+#else
pthread_create(&thread_, nullptr, ThreadFunc, this);
+#endif
}
~Thread() { Join(); }
- bool Join() const {
- if (!finished_) {
- pthread_join(thread_, nullptr);
+ bool Join() {
+ while (!finished_.load()) {
}
return made_the_last_decrement_;
}
@@ -48,7 +70,7 @@ class Thread {
Check(!made_the_last_decrement_);
made_the_last_decrement_ = blocking_counter_->DecrementCount();
}
- finished_ = true;
+ finished_.store(true);
}
static void* ThreadFunc(void* ptr) {
@@ -56,11 +78,18 @@ class Thread {
return nullptr;
}
+ static constexpr size_t max_stack_size_ = 256 * 1024;
BlockingCounter* const blocking_counter_;
const int number_of_times_to_decrement_;
pthread_t thread_;
- bool finished_;
bool made_the_last_decrement_;
+ // finished_ is used to manually implement Join() by busy-waiting.
+ // I wanted to use pthread_join / std::thread::join, but the behavior
+ // observed on Android was that pthread_join aborts when the thread has
+ // already joined before calling pthread_join, making that hard to use.
+ // It appeared simplest to just busy-wait on this flag, and that
+ // is good enough as this is just a test.
+ std::atomic<bool> finished_;
};
void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
@@ -89,10 +118,10 @@ void test_blocking_counter() {
// repeating the entire test sequence ensures that we test
// non-monotonic changes.
for (int repeat = 1; repeat <= 2; repeat++) {
- for (int num_threads = 1; num_threads <= 16; num_threads++) {
+ for (int num_threads = 1; num_threads <= 5; num_threads++) {
for (int num_decrements_per_thread = 1;
- num_decrements_per_thread <= 64 * 1024;
- num_decrements_per_thread *= 4) {
+ num_decrements_per_thread <= 4 * 1024;
+ num_decrements_per_thread *= 16) {
test_blocking_counter(blocking_counter, num_threads,
num_decrements_per_thread,
num_threads * num_decrements_per_thread);
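
The Thread constructor above caps each worker's stack at 256 KiB so that creating many threads does not exhaust memory. The same pattern in isolation, using the pthread attribute calls already shown; kMaxStackSize and WorkerFunc are illustrative names:

#include <pthread.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

constexpr std::size_t kMaxStackSize = 256 * 1024;

void* WorkerFunc(void*) { return nullptr; }

pthread_t CreateSmallStackThread() {
  pthread_attr_t attr;
  pthread_t thread;
  int err = pthread_attr_init(&attr);
  if (!err) {
    std::size_t stack_size;
    err = pthread_attr_getstacksize(&attr, &stack_size);
    // Only shrink the stack below the platform default, never grow it.
    if (!err && stack_size > kMaxStackSize) {
      err = pthread_attr_setstacksize(&attr, kMaxStackSize);
    }
    if (!err) {
      err = pthread_create(&thread, &attr, WorkerFunc, nullptr);
    }
    pthread_attr_destroy(&attr);
  }
  if (err) {
    std::fprintf(stderr, "Failed to create a thread.\n");
    std::abort();
  }
  return thread;
}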
diff --git a/test/test_fixedpoint.cc b/test/test_fixedpoint.cc
index da222f0..44e6fae 100644
--- a/test/test_fixedpoint.cc
+++ b/test/test_fixedpoint.cc
@@ -17,479 +17,587 @@
#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
#include <algorithm>
+#include <cinttypes>
#include <cmath>
+#include <cstdio>
#include <random>
#include <vector>
-#include "test.h"
#include "../fixedpoint/fixedpoint.h"
+#include "test.h"
namespace gemmlowp {
namespace {
-// Explanation of SimdVector type and associated functions
-// (LoadSimdVector, StoreSimdVector):
-// The fixedpoint stuff being tested here is generic in an underlying
-// integer type which may be either scalar (int32_t) or SIMD (e.g.
-// NEON int32x4_t). We want to write uniform tests that can test
-// both the scalar and SIMD paths. We achieve this by having this
-// generic SimdVector abstraction, local to this test.
-
+template <typename T>
+T Load(const typename FixedPointRawTypeTraits<T>::ScalarRawType* src) {
+ return *src;
+}
+template <typename T>
+void Store(typename FixedPointRawTypeTraits<T>::ScalarRawType* dst, T v) {
+ *dst = v;
+}
#ifdef GEMMLOWP_NEON
-using SimdVector = int32x4_t;
-constexpr std::size_t SimdVectorSize = 4;
-SimdVector LoadSimdVector(const std::int32_t* src) { return vld1q_s32(src); }
-void StoreSimdVector(std::int32_t* dst, SimdVector v) { vst1q_s32(dst, v); }
-#elif defined(GEMMLOWP_SSE4)
-using SimdVector = __m128i;
-constexpr std::size_t SimdVectorSize = 4;
-SimdVector LoadSimdVector(const std::int32_t* src) {
+template <>
+int32x4_t Load<int32x4_t>(const std::int32_t* src) {
+ return vld1q_s32(src);
+}
+template <>
+int16x8_t Load<int16x8_t>(const std::int16_t* src) {
+ return vld1q_s16(src);
+}
+template <>
+void Store<int32x4_t>(std::int32_t* dst, int32x4_t v) {
+ vst1q_s32(dst, v);
+}
+template <>
+void Store<int16x8_t>(std::int16_t* dst, int16x8_t v) {
+ vst1q_s16(dst, v);
+}
+#endif
+#ifdef GEMMLOWP_SSE4
+template <>
+__m128i Load<__m128i>(const std::int32_t* src) {
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
}
-void StoreSimdVector(std::int32_t* dst, SimdVector v) {
+template <>
+void Store<__m128i>(std::int32_t* dst, __m128i v) {
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v);
}
-#else
-using SimdVector = std::int32_t;
-constexpr std::size_t SimdVectorSize = 1;
-SimdVector LoadSimdVector(const std::int32_t* src) { return *src; }
-void StoreSimdVector(std::int32_t* dst, SimdVector v) { *dst = v; }
+template <>
+int16x8_m128i Load<int16x8_m128i>(const std::int16_t* src) {
+ return to_int16x8_m128i(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)));
+}
+template <>
+void Store<int16x8_m128i>(std::int16_t* dst, int16x8_m128i v) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v.v);
+}
+#endif
+#ifdef GEMMLOWP_MSA
+template <>
+v4i32 Load<v4i32>(const std::int32_t* src) {
+ return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
+}
+template <>
+v8i16 Load<v8i16>(const std::int16_t* src) {
+ return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
+}
+template <>
+void Store<v4i32>(std::int32_t* dst, v4i32 v) {
+ __builtin_msa_st_w(v, dst, 0);
+}
+template <>
+void Store<v8i16>(std::int16_t* dst, v8i16 v) {
+ __builtin_msa_st_h(v, dst, 0);
+}
#endif
-// Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
-// Most (though not all) of the fixedpoint functionality being tested
-// consists of functions taking one fixedpoint value and returning one
-// fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
-// We factor a lot of testing boilerplate into a common TestUnaryOp function
-// taking a "unary op" object that fully describes the function to be tested.
-// These objects inherit UnaryOpBase mostly as a means to share some default
-// values for some properties.
-//
-// An important design element here is that the fixed-point values are passed
-// around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
-// as higher-level FixedPoint objects. The motivation for this design is 1) to
-// avoid having to templatize everything in the tIntegerBits parameter of
-// class FixedPoint, and 2) to allow directly testing low-level functions
-// operating on raw types (e.g. RoundingDivideByPOT) without needlessly
-// requiring
-// wrapping raw values in FixedPoint objects.
-class UnaryOpBase {
- public:
- // Min bound of the input range of this op. For example, an op only handling
- // nonnegative values would return 0.
- std::int32_t MinInput() const {
- return std::numeric_limits<std::int32_t>::min();
- }
- // Max bound of the input range of this op. For example, an op only handling
- // nonpositive values would return 0.
- std::int32_t MaxInput() const {
- return std::numeric_limits<std::int32_t>::max();
- }
- // Tolerated difference between actual and reference int32 values.
- // Note that the corresponding real-numbers tolerance depends on the number
- // of integer bits of the fixed-point representation of the results of this
- // op.
- // For example, for an op returning fixed-point values with 0 integer bits,
- // the correspondence between real-number values and raw values is
- // real_number = (2^31) * raw_value.
- std::int32_t Tolerance() const { return 0; }
-};
+#ifdef GEMMLOWP_AVX2
+template <>
+__m256i Load<__m256i>(const std::int32_t* src) {
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
+}
-// Op wrapping RoundingDivideByPOT
-class RoundingDivideByPOTOp final : public UnaryOpBase {
- public:
- RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
- std::int32_t ReferenceOp(std::int32_t x) const {
- const double d = static_cast<double>(x) / (1ll << exponent_);
- return static_cast<std::int32_t>(std::round(d));
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- return RoundingDivideByPOT(x, exponent_);
- }
+template <>
+int16x16_m256i Load<int16x16_m256i>(const std::int16_t* src) {
+ return to_int16x16_m256i(
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src)));
+}
- private:
- const int exponent_;
-};
+template <>
+void Store<__m256i>(std::int32_t* dst, __m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+}
-// Op wrapping SaturatingRoundingMultiplyByPOT
-template <int tExponent>
-class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
- public:
- std::int32_t ReferenceOp(std::int32_t x) const {
- const double d = static_cast<double>(x) * std::pow(2., tExponent);
- const double clamp_min = std::numeric_limits<std::int32_t>::min();
- const double clamp_max = std::numeric_limits<std::int32_t>::max();
- const double clamped = std::min(clamp_max, std::max(clamp_min, d));
- return static_cast<std::int32_t>(std::round(clamped));
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- return SaturatingRoundingMultiplyByPOT<tExponent>(x);
- }
-};
+template <>
+void Store<int16x16_m256i>(std::int16_t* dst, int16x16_m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v.v);
+}
+#endif
-// Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
-class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
- : public UnaryOpBase {
+template <typename tSimdType>
+class TestFixedPoint {
public:
- std::int32_t MinInput() const { return -(1 << 29); }
- std::int32_t MaxInput() const { return 0; }
- std::int32_t Tolerance() const { return 500; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::exp(d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
- return e.raw();
- }
-};
+ using SimdType = tSimdType;
+ using SimdTypeTraits = FixedPointRawTypeTraits<SimdType>;
+ using ScalarType = typename SimdTypeTraits::ScalarRawType;
+ static constexpr int kSimdLanes = SimdTypeTraits::kLanes;
+ static constexpr int kScalarTypeBits = 8 * sizeof(ScalarType);
+
+ // Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
+ // Most (though not all) of the fixedpoint functionality being tested
+ // consists of functions taking one fixedpoint value and returning one
+ // fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
+ // We factor a lot of testing boilerplate into a common TestUnaryOp function
+ // taking a "unary op" object that fully describes the function to be tested.
+ // These objects inherit UnaryOpBase mostly as a means to share some default
+ // values for some properties.
+ //
+ // An important design element here is that the fixed-point values are passed
+ // around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
+ // as higher-level FixedPoint objects. The motivation for this design is 1) to
+ // avoid having to templatize everything in the tIntegerBits parameter of
+ // class FixedPoint, and 2) to allow directly testing low-level functions
+ // operating on raw types (e.g. RoundingDivideByPOT) without needlessly
+ // requiring wrapping raw values in FixedPoint objects.
+ class UnaryOpBase {
+ public:
+ // Min bound of the input range of this op. For example, an op only handling
+ // nonnegative values would return 0.
+ ScalarType MinInput() const {
+ return std::numeric_limits<ScalarType>::min();
+ }
+ // Max bound of the input range of this op. For example, an op only handling
+ // nonpositive values would return 0.
+ ScalarType MaxInput() const {
+ return std::numeric_limits<ScalarType>::max();
+ }
+ // Tolerated difference between actual and reference ScalarType values.
+ // Note that the corresponding real-numbers tolerance depends on the number
+ // of integer bits of the fixed-point representation of the results of this
+ // op.
+ // For example, for an op returning fixed-point values with 0 integer
+ // bits in a 32-bit raw type, the correspondence between real-number
+ // values and raw values is real_number = raw_value / 2^31.
+ ScalarType Tolerance() const { return 0; }
+ };
+
+ // Op wrapping RoundingDivideByPOT
+ class RoundingDivideByPOTOp final : public UnaryOpBase {
+ public:
+ RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
+ ScalarType ReferenceOp(ScalarType x) const {
+ const double d = static_cast<double>(x) / (1ll << exponent_);
+ return static_cast<ScalarType>(std::round(d));
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ return RoundingDivideByPOT(x, exponent_);
+ }
-// Op wrapping exp_on_negative_values
-template <int tIntegerBits>
-class ExpOnNegativeValuesOp final : public UnaryOpBase {
- public:
- std::int32_t MaxInput() const { return 0; }
- std::int32_t Tolerance() const { return 500; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::exp(d);
- return F0::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return exp_on_negative_values(f).raw();
+ private:
+ const int exponent_;
+ };
+
+ // Op wrapping SaturatingRoundingMultiplyByPOT
+ template <int tExponent>
+ class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
+ public:
+ ScalarType ReferenceOp(ScalarType x) const {
+ const double d = static_cast<double>(x) * std::pow(2., tExponent);
+ const double clamp_min = std::numeric_limits<ScalarType>::min();
+ const double clamp_max = std::numeric_limits<ScalarType>::max();
+ const double clamped = std::min(clamp_max, std::max(clamp_min, d));
+ return static_cast<ScalarType>(std::round(clamped));
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ return SaturatingRoundingMultiplyByPOT<tExponent>(x);
+ }
+ };
+
+ // Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
+ class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
+ : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return -(1 << (kScalarTypeBits - 3)); }
+ ScalarType MaxInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 1; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::exp(d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
+ return e.raw();
+ }
+ };
+
+ // Op wrapping exp_on_negative_values
+ template <int tIntegerBits>
+ class ExpOnNegativeValuesOp final : public UnaryOpBase {
+ public:
+ ScalarType MaxInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 2; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::exp(d);
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return exp_on_negative_values(f).raw();
+ }
+ };
+
+ // Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
+ class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 12 : 11; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = (1 - d) / (1 + d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
+ }
+ };
+
+ // Op wrapping tanh
+ template <int tIntegerBits>
+ class TanhOp final : public UnaryOpBase {
+ public:
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 310 : 12; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = std::tanh(d);
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return tanh(f).raw();
+ }
+ };
+
+ // Op wrapping one_over_one_plus_x_for_x_in_0_1
+ class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
+ public:
+ ScalarType MinInput() const { return 0; }
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 6 : 5; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = 1 / (1 + d);
+ return F::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, 0>;
+ const F f = F::FromRaw(x);
+ return one_over_one_plus_x_for_x_in_0_1(f).raw();
+ }
+ };
+
+ // Op wrapping logistic
+ template <int tIntegerBits>
+ class LogisticOp final : public UnaryOpBase {
+ public:
+ ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 155 : 6; }
+ ScalarType ReferenceOp(ScalarType x) const {
+ using F = FixedPoint<ScalarType, tIntegerBits>;
+ using F0 = FixedPoint<ScalarType, 0>;
+ const double d = ToDouble(F::FromRaw(x));
+ const double e = 1 / (1 + std::exp(-d));
+ return F0::FromDouble(e).raw();
+ }
+ template <typename RawType>
+ RawType Op(RawType x) const {
+ using F = FixedPoint<RawType, tIntegerBits>;
+ const F f = F::FromRaw(x);
+ return logistic(f).raw();
+ }
+ };
+
+ // Tests a given op, on a given list of input values.
+ template <typename tUnaryOpType>
+ void TestUnaryOp(const tUnaryOpType& unary_op,
+ const std::vector<ScalarType>& testvals) {
+ Check(0 == (testvals.size() % kSimdLanes));
+ for (std::size_t i = 0; i < testvals.size(); i += kSimdLanes) {
+ // First, clamp input values according to the MinInput() and MaxInput()
+ // bounds returned by the op.
+ ScalarType input[kSimdLanes] = {0};
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ const ScalarType raw_input = testvals[i + j];
+ input[j] = std::min(unary_op.MaxInput(),
+ std::max(unary_op.MinInput(), raw_input));
+ }
+ // Compute reference results and check that the actual results on
+ // scalar inputs agree with them, to the Tolerance() returned by the op.
+ ScalarType reference[kSimdLanes] = {0};
+ ScalarType actual_scalar[kSimdLanes] = {0};
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ reference[j] = unary_op.ReferenceOp(input[j]);
+ actual_scalar[j] = unary_op.Op(input[j]);
+ const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
+ static_cast<std::int64_t>(reference[j]);
+ if (std::abs(diff) > unary_op.Tolerance()) {
+ fprintf(stderr, "abs(diff) (%" PRId64 ") > tolerance (%d)\n", diff,
+ unary_op.Tolerance());
+ }
+ Check(std::abs(diff) <= unary_op.Tolerance());
+ }
+ // Check that the actual results on SIMD inputs agree *exactly* with the
+ // actual results on scalar inputs. I.e. SIMD must make absolutely no
+ // difference to the results, regardless of the fact that both scalar
+ // and SIMD results may differ from the reference results.
+ ScalarType actual_simd[kSimdLanes] = {0};
+ Store<SimdType>(actual_simd, unary_op.Op(Load<SimdType>(input)));
+ for (std::size_t j = 0; j < kSimdLanes; j++) {
+ if (actual_simd[j] != actual_scalar[j]) {
+ fprintf(stderr, "SIMD (%d) != scalar (%d)\n", actual_simd[j],
+ actual_scalar[j]);
+ }
+ Check(actual_simd[j] == actual_scalar[j]);
+ }
+ }
}
-};
-// Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
-class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
- public:
- std::int32_t MinInput() const { return 0; }
- std::int32_t Tolerance() const { return 12; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = (1 - d) / (1 + d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
+ template <int tIntegerBits>
+ void test_convert(FixedPoint<ScalarType, tIntegerBits> x) {
+ typedef FixedPoint<ScalarType, tIntegerBits> F;
+ F y = F::FromDouble(ToDouble(x));
+ Check(y == x);
}
-};
-// Op wrapping tanh
-template <int tIntegerBits>
-class TanhOp final : public UnaryOpBase {
- public:
- std::int32_t Tolerance() const { return 310; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = std::tanh(d);
- return F0::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return tanh(f).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_Rescale(FixedPoint<ScalarType, tIntegerBits_a> a) {
+ FixedPoint<ScalarType, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
+ FixedPoint<ScalarType, tIntegerBits_b> expected =
+ FixedPoint<ScalarType, tIntegerBits_b>::FromDouble(ToDouble(a));
+ Check(actual == expected);
}
-};
-// Op wrapping one_over_one_plus_x_for_x_in_0_1
-class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
- public:
- std::int32_t MinInput() const { return 0; }
- std::int32_t Tolerance() const { return 6; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = 1 / (1 + d);
- return F::FromDouble(e).raw();
- }
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, 0>;
- const F f = F::FromRaw(x);
- return one_over_one_plus_x_for_x_in_0_1(f).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_Rescale(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
+ aq.raw() = a;
+ test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
+ }
}
-};
-// Op wrapping logistic
-template <int tIntegerBits>
-class LogisticOp final : public UnaryOpBase {
- public:
- std::int32_t Tolerance() const { return 155; }
- std::int32_t ReferenceOp(std::int32_t x) const {
- using F = FixedPoint<std::int32_t, tIntegerBits>;
- using F0 = FixedPoint<std::int32_t, 0>;
- const double d = ToDouble(F::FromRaw(x));
- const double e = 1 / (1 + std::exp(-d));
- return F0::FromDouble(e).raw();
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_mul(FixedPoint<ScalarType, tIntegerBits_a> a,
+ FixedPoint<ScalarType, tIntegerBits_b> b) {
+ static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
+ using ProductFixedPoint = FixedPoint<ScalarType, ProductIntegerBits>;
+ ProductFixedPoint ab;
+ ab = a * b;
+ double a_double = ToDouble(a);
+ double b_double = ToDouble(b);
+ double ab_double = a_double * b_double;
+ ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
+ std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
+ Check(std::abs(diff) <= 1);
}
- template <typename tRawType>
- tRawType Op(tRawType x) const {
- using F = FixedPoint<tRawType, tIntegerBits>;
- const F f = F::FromRaw(x);
- return logistic(f).raw();
- }
-};
-// Tests a given op, on a given list of int32 input values.
-template <typename tUnaryOpType>
-void TestUnaryOp(const tUnaryOpType& unary_op,
- const std::vector<std::int32_t>& testvals_int32) {
- Check(0 == (testvals_int32.size() % SimdVectorSize));
- for (std::size_t i = 0; i < testvals_int32.size(); i += SimdVectorSize) {
- // First, clamp input int32 values accoding to the MinInput() and MaxInput()
- // bounds returned by the op.
- std::int32_t input[SimdVectorSize] = {0};
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- const std::int32_t raw_input = testvals_int32[i + j];
- input[j] = std::min(unary_op.MaxInput(),
- std::max(unary_op.MinInput(), raw_input));
- }
- // Compute reference results and check that the actual results on
- // scalar inputs agree with them, to the Tolerance() returned by the op.
- std::int32_t reference[SimdVectorSize] = {0};
- std::int32_t actual_scalar[SimdVectorSize] = {0};
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- reference[j] = unary_op.ReferenceOp(input[j]);
- actual_scalar[j] = unary_op.Op(input[j]);
- const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
- static_cast<std::int64_t>(reference[j]);
- Check(std::abs(diff) <= unary_op.Tolerance());
- }
- // Check that the actual results on SIMD inputs agree *exactly* with the
- // actual results on scalar inputs. I.e. SIMD must make absolutely no
- // difference
- // to the results, regardless of the fact that both scalar and SIMD results
- // may differ from the reference results.
- std::int32_t actual_simd[SimdVectorSize] = {0};
- StoreSimdVector(actual_simd, unary_op.Op(LoadSimdVector(input)));
- for (std::size_t j = 0; j < SimdVectorSize; j++) {
- Check(actual_simd[j] == actual_scalar[j]);
+ template <int tIntegerBits_a, int tIntegerBits_b>
+ void test_mul(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ for (auto b : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
+ FixedPoint<ScalarType, tIntegerBits_b> bq;
+ aq.raw() = a;
+ bq.raw() = b;
+ test_mul(aq, bq);
+ }
}
}
-}
-template <int tIntegerBits>
-void test_convert(FixedPoint<std::int32_t, tIntegerBits> x) {
- typedef FixedPoint<std::int32_t, tIntegerBits> F;
- F y = F::FromDouble(ToDouble(x));
- Check(y == x);
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_Rescale(FixedPoint<std::int32_t, tIntegerBits_a> a) {
- FixedPoint<std::int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
- FixedPoint<std::int32_t, tIntegerBits_b> expected =
- FixedPoint<std::int32_t, tIntegerBits_b>::FromDouble(ToDouble(a));
- Check(actual == expected);
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_Rescale(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- aq.raw() = a;
- test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
+ template <int tExponent, int tIntegerBits_a>
+ void test_ExactMulByPot(FixedPoint<ScalarType, tIntegerBits_a> a) {
+ double x = ToDouble(a) * std::pow(2.0, tExponent);
+ double y = ToDouble(ExactMulByPot<tExponent>(a));
+ Check(x == y);
}
-}
-
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_mul(FixedPoint<std::int32_t, tIntegerBits_a> a,
- FixedPoint<std::int32_t, tIntegerBits_b> b) {
- static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
- using ProductFixedPoint = FixedPoint<std::int32_t, ProductIntegerBits>;
- ProductFixedPoint ab;
- ab = a * b;
- double a_double = ToDouble(a);
- double b_double = ToDouble(b);
- double ab_double = a_double * b_double;
- ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
- std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
- Check(std::abs(diff) <= 1);
-}
-template <int tIntegerBits_a, int tIntegerBits_b>
-void test_mul(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- for (auto b : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- FixedPoint<std::int32_t, tIntegerBits_b> bq;
+ template <int tExponent, int tIntegerBits_a>
+ void test_ExactMulByPot(const std::vector<ScalarType>& testvals) {
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, tIntegerBits_a> aq;
aq.raw() = a;
- bq.raw() = b;
- test_mul(aq, bq);
+ test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
}
}
-}
-template <int tExponent, int tIntegerBits_a>
-void test_ExactMulByPot(FixedPoint<std::int32_t, tIntegerBits_a> a) {
- double x = ToDouble(a) * std::pow(2.0, tExponent);
- double y = ToDouble(ExactMulByPot<tExponent>(a));
- Check(x == y);
-}
+ // Make the list of test values to test each op against.
+ std::vector<ScalarType> MakeTestVals() {
+ std::vector<ScalarType> testvals;
+
+ for (int i = 0; i < kScalarTypeBits - 1; i++) {
+ testvals.push_back((1 << i) - 2);
+ testvals.push_back((1 << i) - 1);
+ testvals.push_back((1 << i));
+ testvals.push_back((1 << i) + 1);
+ testvals.push_back((1 << i) + 2);
+ testvals.push_back(-(1 << i) - 2);
+ testvals.push_back(-(1 << i) - 1);
+ testvals.push_back(-(1 << i));
+ testvals.push_back(-(1 << i) + 1);
+ testvals.push_back(-(1 << i) + 2);
+ }
+ testvals.push_back(std::numeric_limits<ScalarType>::min());
+ testvals.push_back(std::numeric_limits<ScalarType>::min() + 1);
+ testvals.push_back(std::numeric_limits<ScalarType>::min() + 2);
+ testvals.push_back(std::numeric_limits<ScalarType>::max() - 2);
+ testvals.push_back(std::numeric_limits<ScalarType>::max() - 1);
+ testvals.push_back(std::numeric_limits<ScalarType>::max());
+
+ std::mt19937 random_engine;
+ std::uniform_int_distribution<ScalarType> uniform_distribution(
+ std::numeric_limits<ScalarType>::min(),
+ std::numeric_limits<ScalarType>::max());
+ for (int i = 0; i < 1000; i++) {
+ testvals.push_back(uniform_distribution(random_engine));
+ }
-template <int tExponent, int tIntegerBits_a>
-void test_ExactMulByPot(const std::vector<std::int32_t>& testvals_int32) {
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, tIntegerBits_a> aq;
- aq.raw() = a;
- test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
- }
-}
+ // SIMD tests will require the length of testvals to be a multiple
+ // of SIMD vector size.
+ while (testvals.size() % kSimdLanes) {
+ testvals.push_back(0);
+ }
-// Make the list of test values to test each op against.
-std::vector<std::int32_t> MakeTestValsInt32() {
- std::vector<std::int32_t> testvals_int32;
-
- for (int i = 0; i < 31; i++) {
- testvals_int32.push_back((1 << i) - 2);
- testvals_int32.push_back((1 << i) - 1);
- testvals_int32.push_back((1 << i));
- testvals_int32.push_back((1 << i) + 1);
- testvals_int32.push_back((1 << i) + 2);
- testvals_int32.push_back(-(1 << i) - 2);
- testvals_int32.push_back(-(1 << i) - 1);
- testvals_int32.push_back(-(1 << i));
- testvals_int32.push_back(-(1 << i) + 1);
- testvals_int32.push_back(-(1 << i) + 2);
- }
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min());
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 1);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 2);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 2);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 1);
- testvals_int32.push_back(std::numeric_limits<std::int32_t>::max());
-
- std::mt19937 random_engine;
- std::uniform_int_distribution<std::int32_t> uniform_distribution(
- std::numeric_limits<std::int32_t>::min(),
- std::numeric_limits<std::int32_t>::max());
- for (int i = 0; i < 1000; i++) {
- testvals_int32.push_back(uniform_distribution(random_engine));
+ std::sort(testvals.begin(), testvals.end());
+ return testvals;
}
- // SIMD tests will require the length of testvals_int32 to be a multiple
- // of SIMD vector size.
- while (testvals_int32.size() % SimdVectorSize) {
- testvals_int32.push_back(0);
- }
+ void RunTests(const char* msg) {
+ const std::vector<ScalarType> testvals = MakeTestVals();
- std::sort(testvals_int32.begin(), testvals_int32.end());
- return testvals_int32;
-}
+ for (int s = 0; s < kScalarTypeBits; s++) {
+ TestUnaryOp(RoundingDivideByPOTOp(s), testvals);
+ }
+
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<14 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15 - kScalarTypeBits>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 15>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 14>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 3>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 2>(),
+ testvals);
+ TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 1>(),
+ testvals);
+
+ TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals);
+ TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals);
+
+ TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals);
+ TestUnaryOp(TanhOp<0>(), testvals);
+ TestUnaryOp(TanhOp<1>(), testvals);
+ TestUnaryOp(TanhOp<2>(), testvals);
+ TestUnaryOp(TanhOp<3>(), testvals);
+ TestUnaryOp(TanhOp<4>(), testvals);
+ TestUnaryOp(TanhOp<5>(), testvals);
+ TestUnaryOp(TanhOp<6>(), testvals);
+
+ TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals);
+ TestUnaryOp(LogisticOp<0>(), testvals);
+ TestUnaryOp(LogisticOp<1>(), testvals);
+ TestUnaryOp(LogisticOp<2>(), testvals);
+ TestUnaryOp(LogisticOp<3>(), testvals);
+ TestUnaryOp(LogisticOp<4>(), testvals);
+ TestUnaryOp(LogisticOp<5>(), testvals);
+ TestUnaryOp(LogisticOp<6>(), testvals);
+
+ for (auto a : testvals) {
+ FixedPoint<ScalarType, 4> x;
+ x.raw() = a;
+ test_convert(x);
+ }
+
+ test_mul<0, 0>(testvals);
+ test_mul<0, 1>(testvals);
+ test_mul<2, 0>(testvals);
+ test_mul<1, 1>(testvals);
+ test_mul<4, 4>(testvals);
+ test_mul<3, 5>(testvals);
+ test_mul<7, 2>(testvals);
+ test_mul<kScalarTypeBits / 2 - 1, kScalarTypeBits / 2 - 2>(testvals);
+
+ test_Rescale<0, 0>(testvals);
+ test_Rescale<0, 1>(testvals);
+ test_Rescale<2, 0>(testvals);
+ test_Rescale<4, 4>(testvals);
+ test_Rescale<4, 5>(testvals);
+ test_Rescale<6, 3>(testvals);
+ test_Rescale<13, 9>(testvals);
+
+ test_ExactMulByPot<0, 0>(testvals);
+ test_ExactMulByPot<0, 4>(testvals);
+ test_ExactMulByPot<1, 4>(testvals);
+ test_ExactMulByPot<3, 2>(testvals);
+ test_ExactMulByPot<-4, 5>(testvals);
+ test_ExactMulByPot<-2, 6>(testvals);
+
+ fprintf(stderr, "PASS (%s)\n", msg);
+ }
+};
} // end anonymous namespace
} // end namespace gemmlowp
int main() {
- using namespace gemmlowp;
-
- const std::vector<std::int32_t> testvals_int32 = MakeTestValsInt32();
-
- for (int s = 0; s < 32; s++) {
- TestUnaryOp(RoundingDivideByPOTOp(s), testvals_int32);
- }
-
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-31>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-30>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-29>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-17>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-16>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<16>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<17>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<29>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<30>(), testvals_int32);
- TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<31>(), testvals_int32);
-
- TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(),
- testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals_int32);
- TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals_int32);
-
- TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals_int32);
- TestUnaryOp(TanhOp<0>(), testvals_int32);
- TestUnaryOp(TanhOp<1>(), testvals_int32);
- TestUnaryOp(TanhOp<2>(), testvals_int32);
- TestUnaryOp(TanhOp<3>(), testvals_int32);
- TestUnaryOp(TanhOp<4>(), testvals_int32);
- TestUnaryOp(TanhOp<5>(), testvals_int32);
- TestUnaryOp(TanhOp<6>(), testvals_int32);
-
- TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals_int32);
- TestUnaryOp(LogisticOp<0>(), testvals_int32);
- TestUnaryOp(LogisticOp<1>(), testvals_int32);
- TestUnaryOp(LogisticOp<2>(), testvals_int32);
- TestUnaryOp(LogisticOp<3>(), testvals_int32);
- TestUnaryOp(LogisticOp<4>(), testvals_int32);
- TestUnaryOp(LogisticOp<5>(), testvals_int32);
- TestUnaryOp(LogisticOp<6>(), testvals_int32);
-
- for (auto a : testvals_int32) {
- FixedPoint<std::int32_t, 4> x;
- x.raw() = a;
- test_convert(x);
- }
-
- test_mul<0, 0>(testvals_int32);
- test_mul<0, 1>(testvals_int32);
- test_mul<2, 0>(testvals_int32);
- test_mul<1, 1>(testvals_int32);
- test_mul<4, 4>(testvals_int32);
- test_mul<3, 5>(testvals_int32);
- test_mul<7, 2>(testvals_int32);
- test_mul<14, 15>(testvals_int32);
-
- test_Rescale<0, 0>(testvals_int32);
- test_Rescale<0, 1>(testvals_int32);
- test_Rescale<2, 0>(testvals_int32);
- test_Rescale<4, 4>(testvals_int32);
- test_Rescale<4, 5>(testvals_int32);
- test_Rescale<6, 3>(testvals_int32);
- test_Rescale<13, 9>(testvals_int32);
-
- test_ExactMulByPot<0, 0>(testvals_int32);
- test_ExactMulByPot<0, 4>(testvals_int32);
- test_ExactMulByPot<1, 4>(testvals_int32);
- test_ExactMulByPot<3, 2>(testvals_int32);
- test_ExactMulByPot<-4, 5>(testvals_int32);
- test_ExactMulByPot<-2, 6>(testvals_int32);
-
- std::cerr << "All tests passed." << std::endl;
+ gemmlowp::TestFixedPoint<std::int32_t>().RunTests("Scalar int32");
+ gemmlowp::TestFixedPoint<std::int16_t>().RunTests("Scalar int16");
+#ifdef GEMMLOWP_SSE4
+ gemmlowp::TestFixedPoint<__m128i>().RunTests("SSE4 __m128i = int32x4");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x8_m128i>().RunTests(
+ "SSE4 __m128i = int16x8");
+#endif
+#ifdef GEMMLOWP_NEON
+ gemmlowp::TestFixedPoint<int32x4_t>().RunTests("NEON int32x4_t");
+ gemmlowp::TestFixedPoint<int16x8_t>().RunTests("NEON int16x8_t");
+#endif
+#ifdef GEMMLOWP_MSA
+ gemmlowp::TestFixedPoint<v4i32>().RunTests("MSA v4i32");
+ gemmlowp::TestFixedPoint<v8i16>().RunTests("MSA v8i16");
+#endif
+#ifdef GEMMLOWP_AVX2
+ gemmlowp::TestFixedPoint<__m256i>().RunTests("AVX __m256i");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x16_m256i>().RunTests(
+ "AVX2 __m256i = int16x16");
+#endif
}
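
The refactored test above exercises the FixedPoint conversion helpers (FromRaw, FromDouble, ToDouble, raw) across scalar and SIMD raw types. A tiny standalone usage sketch of those helpers for the scalar int32 case; the printed values are simple round-trips, and nothing here asserts exact rounding behavior:

#include <cstdint>
#include <cstdio>
#include "../fixedpoint/fixedpoint.h"

int main() {
  // FixedPoint<int32_t, 0> covers roughly [-1, 1): real value = raw / 2^31.
  using F0 = gemmlowp::FixedPoint<std::int32_t, 0>;
  const F0 half = F0::FromDouble(0.5);
  std::printf("raw(0.5) = %d, back to double = %f\n",
              static_cast<int>(half.raw()), gemmlowp::ToDouble(half));
  return 0;
}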