author     Lev Proleev <levp@google.com>    2021-02-26 21:44:39 +0000
committer  Lev Proleev <levp@google.com>    2021-02-26 22:17:12 +0000
commit     123f384187504585be3fe01002381dd459c17d96 (patch)
tree       a29716289a0b730ca66a3e632c6ce054eb3b90d6
parent     8dd5f1b93261d6ea0fe0c8e51d13f89657ceb0b8 (diff)
Update gemmlowp to 13d57703abca3005d97b19df1f2db731607a7dc2
An update is needed after the TF Lite rebase.
Bug: 178609672
Test: mma, NeuralNetworksStatic_test
Change-Id: Ia7f04fc5b6bd760549395854618d8b20f5c8d228
36 files changed, 3765 insertions, 1581 deletions
@@ -7,3 +7,8 @@
 # The email address is not required for organizations.

 Google Inc.
+Intel Corporation
+ARM Ltd.
+Silk Labs Inc.
+MIPS Tech LLC
+Wave Computing Inc.
@@ -30,7 +30,7 @@ license {
         "SPDX-license-identifier-Apache-2.0",
     ],
     license_text: [
-        "LICENSE.txt",
+        "LICENSE",
         "NOTICE",
     ],
 }
diff --git a/CONTRIBUTING.txt b/CONTRIBUTING
index d6d63bc..d6d63bc 100644
--- a/CONTRIBUTING.txt
+++ b/CONTRIBUTING
diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS
index 7c2415b..3740e0e 100644
--- a/CONTRIBUTORS.txt
+++ b/CONTRIBUTORS
@@ -15,8 +15,26 @@ Pete Warden <petewarden@google.com>
 Miao Wang <miaowang@google.com>
 David Andersen <dga@google.com>
 Maciek Chociej <maciekc@google.com>
+Justine Tunney <jart@google.com>
+Mark J. Matthews <mjmatthews@google.com>
+Marie White <mariewhite@google.com>
+Suharsh Sivakumar <suharshs@google.com>

 Intel:
 Sagi Marcovich <sagi.marcovich@intel.com>
 Murat Efe Guney <murat.e.guney@intel.com>
 Sarah Knepper <sarah.knepper@intel.com>
+Mourad Gouicem <mourad.gouicem@intel.com>
+Richard Winterton <richard.winterton@intel.com>
+
+ARM:
+David Mansell <David.Mansell@arm.com>
+
+Silk Labs:
+Andreas Gal <andreas@silklabs.com>
+
+MIPS Tech LLC:
+Alexey Frunze <Alexey.Frunze@mips.com>
+
+Wave Computing Inc.:
+Alexey Frunze <afrunze@wavecomp.com>
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..22fabac
--- /dev/null
+++ b/README.md
@@ -0,0 +1,276 @@
+# gemmlowp: a small self-contained low-precision GEMM library
+
+[![Build Status](https://secure.travis-ci.org/google/gemmlowp.png)](http://travis-ci.org/google/gemmlowp)
+
+This is not a full linear algebra library, only a GEMM library: it only does
+general matrix multiplication ("GEMM").
+
+The meaning of "low precision" is detailed in this document:
+[doc/low-precision.md](doc/low-precision.md)
+
+Some of the general design is explained in [doc/design.md](doc/design.md).
+
+**Warning:** This library goes very slow if compiled incorrectly; see below.
+
+## Disclaimer
+
+This is not an official Google product (experimental or otherwise), it is just
+code that happens to be owned by Google.
+
+## Mailing list
+
+gemmlowp-related discussion, about either development or usage, is welcome on
+this Google Group (mailing list / forum):
+
+https://groups.google.com/forum/#!forum/gemmlowp
+
+## Portability, target platforms/architectures
+
+Should be portable to any platform with some C++11 and POSIX support, while we
+have optional optimized code paths for specific architectures.
+
+Required:
+
+*   C++11 (a small conservative subset of it)
+
+Required for some features:
+
+*   Some POSIX interfaces:
+    *   pthreads (for multi-threaded operation and for profiling).
+    *   sysconf (for multi-threaded operation to detect number of cores; may be
+        bypassed).
+
+Optional:
+
+*   Architecture-specific code paths use intrinsics or inline assembly. See
+    "Architecture-specific optimized code paths" below.
+
+## Architecture-specific optimized code paths
+
+We have some optimized code paths for specific instruction sets. Some are
+written in inline assembly, some are written in C++ using intrinsics. Both GCC
+and Clang are supported.
+
+Current optimized code paths:
+
+*   ARM with NEON (both 32bit and 64bit).
+*   Intel x86 with SSE 4.1 (both 32bit and 64bit).
+
+When building for x86, it's very important to pass `-msse4.1` to the compiler,
+otherwise gemmlowp will use slow reference code. Bazel users can compile by
+running `bazel build --copt=-msse4.1 //gemmlowp:all`. The compiled binary should
+work on all Intel CPUs since 2008 (including low power microarchitectures) as
+well as AMD CPUs since 2011.
+
+Please note when compiling binaries that don't need to be distributed, it's
+generally a better idea to pass `-march=native` to the compiler. That flag
+implies `-msse4.1` flag, along with others that might be helpful. This of course
+assumes the host machine supports those instructions. Bazel users should prefer
+to run `bazel build --config=opt //gemmlowp:all` instead.
+
+Details of what it takes to make an efficient port of gemmlowp, namely writing a
+suitable GEMM kernel and accompanying packing code, are explained in this file:
+[doc/kernel.md](doc/kernel.md).
+
+## Public interfaces
+
+### The gemmlowp public interface
+
+gemmlowp's main public interface is in the `public/` subdirectory.
+
+This is a headers-only library, so there is nothing to link to.
+
+Usage documentation, and comments on the deprecation status of each public entry
+point, may be found in [doc/public.md](doc/public.md) .
+
+A full, self-contained usage example, showing how to quantize float matrices and
+perform a quantized matrix multiplication approximating a float matrix
+multiplication, is given in
+[doc/quantization_example.cc](doc/quantization_example.cc).
+
+### Old EightBitIntGemm legacy deprecated interface
+
+The `eight_bit_int_gemm/` subdirectory contains an alternate interface that
+should be considered purely legacy, deprecated, and going to be removed at some
+point in the future.
+
+## Building
+
+### Building by manually invoking your compiler
+
+Because gemmlowp is so simple, working with it involves only single-command-line
+compiler invocations. Therefore we expect that most people working with gemmlowp
+will either manually invoke their compiler, or write their own rules for their
+own preferred build system.
+
+Keep in mind (previous section) that gemmlowp itself is a pure-headers-only
+library so there is nothing to build.
+
+For a Android gemmlowp development workflow, the `scripts/` directory contains a
+script to build and run a program on an Android device:
+
+```
+scripts/test-android.sh
+```
+
+### Building using Bazel
+
+That being said, we also maintain a Bazel BUILD system as part of gemmlowp. Its
+usage is not mandatory at all and is only one possible way that gemmlowp
+libraries and tests may be built. If you are interested, Bazel's home page is
+http://bazel.build/ And you can get started with using Bazel to build gemmlowp
+targets by first creating an empty WORKSPACE file in a parent directory, for
+instance:
+
+```
+$ cd gemmlowp/..  # change to parent directory containing gemmlowp/
+$ touch WORKSPACE # declare that to be our workspace root
+$ bazel build gemmlowp:all
+```
+
+## Testing
+
+### Testing by manually building and running tests
+
+The test/ directory contains unit tests. The primary unit test is
+
+```
+test/test.cc
+```
+
+Since it covers also the EightBitIntGemm interface, it needs to be linked
+against
+
+```
+eight_bit_int_gemm/eight_bit_int_gemm.cc
+```
+
+It also uses realistic data captured from a neural network run in
+
+```
+test/test_data.cc
+```
+
+Thus you'll want to pass the following list of source files to your
+compiler/linker:
+
+```
+test/test.cc
+eight_bit_int_gemm/eight_bit_int_gemm.cc
+test/test_data.cc
+```
+
+The `scripts/` directory contains a script to build and run a program on an
+Android device:
+
+```
+scripts/test-android.sh
+```
+
+It expects the `CXX` environment variable to point to an Android toolchain's C++
+compiler, and expects source files (and optionally, cflags) as command-line
+parameters. To build and run the above-mentioned main unit test, first set `CXX`
+e.g.:
+
+```
+$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++
+```
+
+Then run:
+
+```
+$ ./scripts/test-android.sh \
+test/test.cc \
+eight_bit_int_gemm/eight_bit_int_gemm.cc \
+test/test_data.cc
+```
+
+### Testing using Bazel
+
+Alternatively, you can use Bazel to build and run tests. See the Bazel
+instruction in the above section on building. Once your Bazel workspace is set
+up, you can for instance do:
+
+```
+$ bazel test gemmlowp:all
+```
+
+## Troubleshooting Compilation
+
+If you're having trouble finding the compiler, follow these instructions to
+build a standalone toolchain:
+https://developer.android.com/ndk/guides/standalone_toolchain.html
+
+Here's an example of setting up Clang 3.5:
+
+```
+$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu
+$ $NDK/build/tools/make-standalone-toolchain.sh \
+--toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \
+--install-dir=$INSTALL_DIR
+$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \
+--sysroot=$INSTALL_DIR/sysroot"
+```
+
+Some compilers (e.g. the default clang++ in the same bin directory) don't
+support NEON assembly. The benchmark build process will issue a warning if
+support isn't detected, and you should make sure you're using a compiler like
+arm-linux-androideabi-g++ that does include NEON.
+
+## Benchmarking
+
+The main benchmark is
+
+```
+test/benchmark.cc
+```
+
+It doesn't need to be linked to any other source file. We recommend building
+with assertions disabled (`-DNDEBUG`).
+
+For example, the benchmark can be built and run on an Android device by doing:
+
+```
+$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG
+```
+
+If `GEMMLOWP_TEST_PROFILE` is defined then the benchmark will be built with
+profiling instrumentation (which makes it slower) and will dump profiles. See
+next section on profiling.
+
+## Profiling
+
+The `profiling/` subdirectory offers a very simple, naive, inaccurate,
+non-interrupting sampling profiler that only requires pthreads (no signals).
+
+It relies on source code being instrumented with pseudo-stack labels. See
+`profiling/instrumentation.h`. A full example of using this profiler is given in
+the top comment of `profiling/profiler.h`.
+
+## Contributing
+
+Contribution-related discussion is always welcome on the gemmlowp mailing list
+(see above).
+
+We try to keep a current list of TODO items in the `todo/` directory.
+Prospective contributors are welcome to pick one to work on, and communicate
+about it on the gemmlowp mailing list.
+
+Details of the contributing process, including legalese, are in CONTRIBUTING.
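Since the new README defers usage details to doc/public.md and doc/quantization_example.cc, a compressed sketch of the main public entry point may help orient readers of this patch. The offset/multiplier/shift values below are illustrative placeholders, and the entry-point and output-stage names are assumed from public/gemmlowp.h and public/output_stages.h:

```cpp
#include <cstdint>
#include <tuple>

#include "public/gemmlowp.h"
#include "public/output_stages.h"

// Quantized multiplication of uint8 matrices, sketched after
// doc/quantization_example.cc. All quantization parameters below are
// placeholders; real code derives them from the float ranges involved.
void QuantizedGemmSketch(const std::uint8_t* lhs_data,
                         const std::uint8_t* rhs_data,
                         std::uint8_t* result_data,
                         int rows, int depth, int cols) {
  using gemmlowp::MapOrder;
  gemmlowp::MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(
      lhs_data, rows, depth);
  gemmlowp::MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(
      rhs_data, depth, cols);
  gemmlowp::MatrixMap<std::uint8_t, MapOrder::ColMajor> result(
      result_data, rows, cols);

  // Requantization pipeline: scale the int32 accumulators back to uint8.
  gemmlowp::OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
  quantize_down_stage.result_offset = 50;     // placeholder
  quantize_down_stage.result_mult_int = 123;  // placeholder
  quantize_down_stage.result_shift = 17;      // placeholder
  gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
  const auto output_pipeline =
      std::make_tuple(quantize_down_stage, saturating_cast_stage);

  gemmlowp::GemmContext context;
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, -100 /* lhs_offset */,
      -100 /* rhs_offset */, output_pipeline);
}
```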
+
+## Performance goals
+
+Our performance goals differ from typical GEMM performance goals in the
+following ways:
+
+1.  We care not only about speed, but also about minimizing power usage. We
+    specifically care about charge usage in mobile/embedded devices. This
+    implies that we care doubly about minimizing memory bandwidth usage: we care
+    about it, like any GEMM, because of the impact on speed, and we also care
+    about it because it is a key factor of power usage.
+
+2.  Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000).
+    We do care about large sizes, but we also care specifically about the
+    typically smaller matrix sizes encountered in various mobile applications.
+    This means that we have to optimize for all sizes, not just for large enough
+    sizes.
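Both the new and the old README describe the pseudo-stack profiler in profiling/ only briefly. The following sketch shows the typical wiring, assuming the entry points declared in profiling/profiler.h and profiling/instrumentation.h; the instrumented workload is hypothetical:

```cpp
#define GEMMLOWP_PROFILING  // must be defined before including the headers
#include "profiling/instrumentation.h"
#include "profiling/profiler.h"

void HypotheticalWorkload() {
  // Pseudo-stack label: samples taken while this scope is live are
  // attributed to "HypotheticalWorkload".
  gemmlowp::ScopedProfilingLabel label("HypotheticalWorkload");
  // ... work to be profiled ...
}

int main() {
  gemmlowp::RegisterCurrentThreadForProfiling();
  gemmlowp::StartProfiling();
  HypotheticalWorkload();
  gemmlowp::FinishProfiling();  // dumps the collected profile
}
```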
diff --git a/README.txt b/README.txt
deleted file mode 100644
index e29f0e4..0000000
--- a/README.txt
+++ /dev/null
@@ -1,260 +0,0 @@
-gemmlowp: a small self-contained low-precision GEMM library
-===========================================================
-
-This is not a full linear algebra library, only a GEMM library: it only does
-general matrix multiplication ("GEMM").
-
-The meaning of "low precision" is detailed in this document:
-  doc/low-precision.txt
-
-Some of the general design is explained in
-  doc/design.txt
-
-
-Disclaimer
-==========
-
-This is not an official Google product (experimental or otherwise), it is just
-code that happens to be owned by Google.
-
-
-Mailing list
-============
-
-gemmlowp-related discussion, about either development or usage, is welcome
-on this Google Group (mailing list / forum):
-
-  https://groups.google.com/forum/#!forum/gemmlowp
-
-
-Portability, target platforms/architectures
-===========================================
-
-Should be portable to any platform with some C++11 and POSIX support,
-while we have optional optimized code paths for specific architectures.
-
-Required:
-  C++11 (a small conservative subset of it)
-
-Required for some features:
-  * Some POSIX interfaces:
-    * pthreads (for multi-threaded operation and for profiling).
-    * sysconf (for multi-threaded operation to detect number of cores;
-      may be bypassed).
-
-Optional:
-  Architecture-specific code paths use intrinsics or inline assembly.
-  See "Architecture-specific optimized code paths" below.
-
-Architecture-specific optimized code paths
-==========================================
-
-We have some optimized code paths for specific instruction sets.
-Some are written in inline assembly, some are written in C++ using
-intrinsics. Both GCC and Clang are supported.
-
-At the moment, we have a full set of optimized code paths (kernels,
-packing and unpacking paths) only for ARM NEON, supporting both
-ARMv7 (32bit) and ARMv8 (64bit).
-
-We also have a partial set of optimized code paths (only kernels
-at the moment) for Intel SSE. It supports both x86 and x86-64 but
-only targets SSE4. The lack of packing/unpacking code paths means
-that performance isn't optimal yet.
-
-Details of what it takes to make an efficient port of gemmlowp, namely
-writing a suitable GEMM kernel and accompanying packing code, are
-explained in this file:
-  doc/kernels.txt
-
-
-Public interfaces
-=================
-
-1. gemmlowp public interface
-----------------------------
-
-  gemmlowp's main public interface is in the public/ subdirectory. The
-  header to include is
-    public/gemmlowp.h.
-  This is a headers-only library, so there is nothing to link to.
-
-2. EightBitIntGemm standard interface
--------------------------------------
-
-  Additionally, the eight_bit_int_gemm/ subdirectory provides an
-  implementation of the standard EightBitIntGemm interface. The header
-  to include is
-    eight_bit_int_gemm/eight_bit_int_gemm.h
-  This is *NOT* a headers-only library, users need to link to
-    eight_bit_int_gemm/eight_bit_int_gemm.cc.
-  The API is similar to the standard BLAS GEMM interface, and implements
-  C = A * B. If the transpose flags for a matrix argument are false, its memory
-  order is treated as column major, and row major if its true.
-
-
-Building
-========
-
-Building by manually invoking your compiler
--------------------------------------------
-
-Because gemmlowp is so simple, working with it involves only
-single-command-line compiler invokations. Therefore we expect that
-most people working with gemmlowp will either manually invoke their
-compiler, or write their own rules for their own preferred build
-system.
-
-Keep in mind (previous section) that gemmlowp itself is a pure-headers-only
-library so there is nothing to build, and the eight_bit_int_gemm library
-consists of a single eight_bit_int_gemm.cc file to build.
-
-For a Android gemmlowp development workflow, the scripts/ directory
-contains a script to build and run a program on an Android device:
-  scripts/test-android.sh
-
-Building using Bazel
---------------------
-
-That being said, we also maintain a Bazel BUILD system as part of
-gemmlowp. Its usage is not mandatory at all and is only one
-possible way that gemmlowp libraries and tests may be built. If
-you are interested, Bazel's home page is
-  http://bazel.io/
-And you can get started with using Bazel to build gemmlowp targets
-by first creating an empty WORKSPACE file in a parent directory,
-for instance:
-
-$ cd gemmlowp/..  # change to parent directory containing gemmlowp/
-$ touch WORKSPACE # declare that to be our workspace root
-$ bazel build gemmlowp:all
-
-
-Testing
-=======
-
-Testing by manually building and running tests
-----------------------------------------------
-
-The test/ directory contains unit tests. The primary unit test is
-  test/test.cc
-Since it covers also the EightBitIntGemm interface, it needs to be
-linked against
-  eight_bit_int_gemm/eight_bit_int_gemm.cc
-It also uses realistic data captured from a neural network run in
-  test/test_data.cc
-
-Thus you'll want to pass the following list of source files to your
-compiler/linker:
-  test/test.cc
-  eight_bit_int_gemm/eight_bit_int_gemm.cc
-  test/test_data.cc
-
-The scripts/ directory contains a script to build and run a program
-on an Android device:
-  scripts/test-android.sh
-
-It expects the CXX environment variable to point to an Android toolchain's
-C++ compiler, and expects source files (and optionally, cflags) as
-command-line parameters. To build and run the above-mentioned main unit test,
-first set CXX e.g.:
-
-$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++
-
-Then run:
-
-$ ./scripts/test-android.sh \
-test/test.cc \
-eight_bit_int_gemm/eight_bit_int_gemm.cc \
-test/test_data.cc
-
-
-Testing using Bazel
--------------------
-
-Alternatively, you can use Bazel to build and run tests. See the Bazel
-instruction in the above section on building. Once your Bazel workspace
-is set up, you can for instance do:
-
-$ bazel test gemmlowp:all
-
-
-Troubleshooting Compilation
-===========================
-
-If you're having trouble finding the compiler, follow these instructions to
-build a standalone toolchain:
-https://developer.android.com/ndk/guides/standalone_toolchain.html
-
-Here's an example of setting up Clang 3.5:
-
-$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu
-$ $NDK/build/tools/make-standalone-toolchain.sh \
---toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \
---install-dir=$INSTALL_DIR
-$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \
---sysroot=$INSTALL_DIR/sysroot"
-
-Some compilers (e.g. the default clang++ in the same bin directory) don't
-support NEON assembly. The benchmark build process will issue a warning if
-support isn't detected, and you should make sure you're using a compiler like
-arm-linux-androideabi-g++ that does include NEON.
-
-
-Benchmarking
-============
-
-The main benchmark is
-  benchmark.cc
-It doesn't need to be linked to any
-other source file. We recommend building with assertions disabled (-DNDEBUG).
-
-For example, the benchmark can be built and run on an Android device by doing:
-
-$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG
-
-If GEMMLOWP_TEST_PROFILE is defined then the benchmark will be built with
-profiling instrumentation (which makes it slower) and will dump profiles.
-See next section on profiling.
-
-
-Profiling
-=========
-
-The profiling/ subdirectory offers a very simple non-interrupting sampling
-profiler that only requires pthreads (no signals).
-
-It relies on source code being instrumented with pseudo-stack labels.
-See profiling/instrumentation.h.
-A full example of using this profiler is given in profiling/profiler.h.
-
-
-Contributing
-============
-
-Contribution-related discussion is always welcome on the gemmlowp
-mailing list (see above).
-
-We try to keep a current list of TODO items in the todo/ directory.
-Prospective contributors are welcome to pick one to work on, and
-communicate about it on the gemmlowp mailing list.
-
-Details of the contributing process, including legalese, are in CONTRIBUTING.
-
-Performance goals
-=================
-
-Our performance goals differ from typical GEMM performance goals in the
-following ways:
-
-1. We care not only about speed, but also about minimizing power usage.
-   We specifically care about charge usage in mobile/embedded devices.
-   This implies that we care doubly about minimizing memory bandwidth usage:
-   we care about it, like any GEMM, because of the impact on speed, and we
-   also care about it because it is a key factor of power usage.
-
-2. Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000).
-   We do care about large sizes, but we also care specifically about the
-   typically smaller matrix sizes encountered in various mobile applications.
-   This means that we have to optimize for all sizes, not just for large enough
-   sizes.
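The deleted README describes the EightBitIntGemm semantics (C = A * B over offset uint8 values, scaled back down to uint8). As a reading aid, a scalar reference of that arithmetic — an illustration of the semantics only, not the library's implementation, with a hypothetical function name:

```cpp
#include <cstdint>

// Scalar reference for the quantized multiplication the READMEs describe:
// each uint8 entry represents (value + offset), accumulation happens in
// int32, and the result is rescaled and clamped to uint8. Column-major
// storage, matching the non-transposed EightBitIntGemm convention.
void ReferenceEightBitGemm(int m, int n, int k,
                           const std::uint8_t* a, std::int32_t a_offset, int lda,
                           const std::uint8_t* b, std::int32_t b_offset, int ldb,
                           std::uint8_t* c, std::int32_t c_offset,
                           std::int32_t c_mult_int, std::int32_t c_shift,
                           int ldc) {
  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      std::int32_t acc = 0;
      for (int l = 0; l < k; l++) {
        const std::int32_t a_val = a[i + l * lda] + a_offset;
        const std::int32_t b_val = b[l + j * ldb] + b_offset;
        acc += a_val * b_val;
      }
      // Round-to-nearest on the final right shift, then clamp to [0, 255].
      std::int32_t out =
          ((acc + c_offset) * c_mult_int + (1 << (c_shift - 1))) >> c_shift;
      if (out > 255) out = 255;
      if (out < 0) out = 0;
      c[i + j * ldc] = static_cast<std::uint8_t>(out);
    }
  }
}
```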
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index 58e8050..56e95c0 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -95,12 +95,13 @@ tIntegerType Add(tIntegerType a, tIntegerType b) {
   return a + b;
 }

-// Integer subtraction. Not saturating. Overflow is undefined behavior.
+// Integer multiplication. Not saturating. Overflow is undefined behavior.
 template <typename tIntegerType>
 tIntegerType Mul(tIntegerType a, tIntegerType b) {
   return a * b;
 }

+// Integer subtraction. Not saturating. Overflow is undefined behavior.
 template <typename tIntegerType>
 tIntegerType Sub(tIntegerType a, tIntegerType b) {
   return a - b;
@@ -268,6 +269,16 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
                       std::max(static_cast<std::int32_t>(-32768), sum)));
 }

+template <>
+inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
+  std::int16_t a16 = a;
+  std::int16_t b16 = b;
+  std::int16_t sum = a16 + b16;
+  return static_cast<std::int8_t>(std::min(
+      static_cast<int16_t>(std::numeric_limits<int8_t>::max()),
+      std::max(static_cast<int16_t>(std::numeric_limits<int8_t>::min()), sum)));
+}
+
 // Returns a+b, saturating if the integers are 16bit or narrower,
 // otherwise just a plain addition.
 template <typename IntegerType, bool Is16Bit>
@@ -767,13 +778,14 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
                            result * kMultiplier, result);                      \
   }

-  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
-  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
+  // Constants below are Q0 representations of negative exp fractionals:
+  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
+  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
+  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
+  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
+  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
+  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
+  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)

 #undef GEMMLOWP_EXP_BARREL_SHIFTER
@@ -895,6 +907,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
 #include "./fixedpoint_sse.h"
 #elif defined(GEMMLOWP_MSA)
 #include "./fixedpoint_msa.h"
+#elif defined(GEMMLOWP_WASMSIMD)
+#include "./fixedpoint_wasmsimd.h"
 #endif

 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
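The new std::int8_t specialization of SaturatingAdd above widens to 16 bits and clamps back to [-128, 127]. A tiny illustrative check (test scaffold only, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

#include "fixedpoint/fixedpoint.h"

int main() {
  // Sums past the int8 boundaries are clamped rather than wrapped.
  assert(gemmlowp::SaturatingAdd(std::int8_t{100}, std::int8_t{100}) == 127);
  assert(gemmlowp::SaturatingAdd(std::int8_t{-100}, std::int8_t{-100}) == -128);
  // In-range sums behave like ordinary addition.
  assert(gemmlowp::SaturatingAdd(std::int8_t{3}, std::int8_t{4}) == 7);
}
```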
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
index 1816386..f3fe732 100644
--- a/fixedpoint/fixedpoint_avx.h
+++ b/fixedpoint/fixedpoint_avx.h
@@ -17,69 +17,139 @@
 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
 #define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_

-#include <smmintrin.h>
+#include <immintrin.h>
 #include "fixedpoint.h"
 #include "fixedpoint_sse.h"

 namespace gemmlowp {

+struct int16x16_m256i {
+  __m256i v;
+};
+
+// Keep int16x16_m256i trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x16_m256i to_int16x16_m256i(__m256i w) {
+  int16x16_m256i r;
+  r.v = w;
+  return r;
+}
+
 template <>
 struct FixedPointRawTypeTraits<__m256i> {
   typedef std::int32_t ScalarRawType;
+  // TODO: This can actually support up to 8 lanes, so we should either
+  // change to 8 or create int32x8_m256i struct to handle that case.
   static const int kLanes = 4;
 };

 template <>
+struct FixedPointRawTypeTraits<int16x16_m256i> {
+  typedef std::int16_t ScalarRawType;
+  static const int kLanes = 16;
+};
+
+template <>
 inline __m256i BitAnd(__m256i a, __m256i b) {
   return _mm256_and_si256(a, b);
 }

 template <>
+inline int16x16_m256i BitAnd(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_and_si256(a.v, b.v));
+}
+
+template <>
 inline __m256i BitOr(__m256i a, __m256i b) {
   return _mm256_or_si256(a, b);
 }

 template <>
+inline int16x16_m256i BitOr(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_or_si256(a.v, b.v));
+}
+
+template <>
 inline __m256i BitXor(__m256i a, __m256i b) {
   return _mm256_xor_si256(a, b);
 }

 template <>
+inline int16x16_m256i BitXor(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_xor_si256(a.v, b.v));
+}
+
+template <>
 inline __m256i BitNot(__m256i a) {
   return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
 }

 template <>
+inline int16x16_m256i BitNot(int16x16_m256i a) {
+  return to_int16x16_m256i(_mm256_andnot_si256(a.v, _mm256_set1_epi16(-1)));
+}
+
+template <>
 inline __m256i Add(__m256i a, __m256i b) {
   return _mm256_add_epi32(a, b);
 }

 template <>
+inline int16x16_m256i Add(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_add_epi16(a.v, b.v));
+}
+
+template <>
 inline __m256i Mul(__m256i a, __m256i b) {
   return _mm256_mullo_epi32(a, b);
 }

 template <>
+inline int16x16_m256i Mul(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_mullo_epi16(a.v, b.v));
+}
+
+template <>
 inline __m256i Sub(__m256i a, __m256i b) {
   return _mm256_sub_epi32(a, b);
 }

 template <>
+inline int16x16_m256i Sub(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_sub_epi16(a.v, b.v));
+}
+
+template <>
 inline __m256i Neg(__m256i a) {
   return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
 }

 template <>
+inline int16x16_m256i Neg(int16x16_m256i a) {
+  return to_int16x16_m256i(_mm256_sign_epi16(a.v, _mm256_set1_epi16(-1)));
+}
+
+template <>
 inline __m256i ShiftLeft(__m256i a, int offset) {
   return _mm256_slli_epi32(a, offset);
 }

 template <>
+inline int16x16_m256i ShiftLeft(int16x16_m256i a, int offset) {
+  return to_int16x16_m256i(_mm256_slli_epi16(a.v, offset));
+}
+
+template <>
 inline __m256i ShiftRight(__m256i a, int offset) {
   return _mm256_srai_epi32(a, offset);
 }

 template <>
+inline int16x16_m256i ShiftRight(int16x16_m256i a, int offset) {
+  return to_int16x16_m256i(_mm256_srai_epi16(a.v, offset));
+}
+
+template <>
 inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
                                __m256i else_val) {
   return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
@@ -88,45 +158,97 @@ inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
 }

 template <>
+inline int16x16_m256i SelectUsingMask(int16x16_m256i if_mask,
+                                      int16x16_m256i then_val,
+                                      int16x16_m256i else_val) {
+  // Borrowed from Intel's arm_neon_sse.h header.
+  return to_int16x16_m256i(
+      _mm256_or_si256(_mm256_and_si256(if_mask.v, then_val.v),
+                      _mm256_andnot_si256(if_mask.v, else_val.v)));
+}
+
+template <>
 inline __m256i MaskIfEqual(__m256i a, __m256i b) {
   return _mm256_cmpeq_epi32(a, b);
 }

 template <>
+inline int16x16_m256i MaskIfEqual(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_cmpeq_epi16(a.v, b.v));
+}
+
+template <>
 inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
   return BitNot(MaskIfEqual(a, b));
 }

 template <>
+inline int16x16_m256i MaskIfNotEqual(int16x16_m256i a, int16x16_m256i b) {
+  return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
 inline __m256i MaskIfZero(__m256i a) {
   return MaskIfEqual(a, _mm256_set1_epi32(0));
 }

 template <>
+inline int16x16_m256i MaskIfZero(int16x16_m256i a) {
+  return MaskIfEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
+}
+
+template <>
 inline __m256i MaskIfNonZero(__m256i a) {
   return MaskIfNotEqual(a, _mm256_set1_epi32(0));
 }

 template <>
+inline int16x16_m256i MaskIfNonZero(int16x16_m256i a) {
+  return MaskIfNotEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
+}
+
+template <>
 inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
   return _mm256_cmpgt_epi32(a, b);
 }

 template <>
+inline int16x16_m256i MaskIfGreaterThan(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_cmpgt_epi16(a.v, b.v));
+}
+
+template <>
 inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
   return _mm256_cmpgt_epi32(b, a);
 }

 template <>
+inline int16x16_m256i MaskIfLessThan(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_cmpgt_epi16(b.v, a.v));
+}
+
+template <>
 inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
   return BitNot(MaskIfLessThan(a, b));
 }

 template <>
+inline int16x16_m256i MaskIfGreaterThanOrEqual(int16x16_m256i a,
+                                               int16x16_m256i b) {
+  return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
 inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
   return BitNot(MaskIfGreaterThan(a, b));
 }

+template <>
+inline int16x16_m256i MaskIfLessThanOrEqual(int16x16_m256i a,
+                                            int16x16_m256i b) {
+  return BitNot(MaskIfGreaterThan(a, b));
+}
+
 /* Assumptions:
    - All and Any are used on masks.
    - masks are all_ones for true lanes, all_zeroes otherwise.
@@ -139,11 +261,21 @@ inline bool All(__m256i a) {
 }

 template <>
+inline bool All(int16x16_m256i a) {
+  return _mm256_testc_si256(a.v, a.v);
+}
+
+template <>
 inline bool Any(__m256i a) {
   return BitNot(_mm256_testz_si256(a, a));
 }

 template <>
+inline bool Any(int16x16_m256i a) {
+  return BitNot(_mm256_testz_si256(a.v, a.v));
+}
+
+template <>
 inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
   /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
   /* We divide the inputs before the add to avoid the overflow and costly test
@@ -171,6 +303,17 @@ inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
 }

 template <>
+inline int16x16_m256i RoundingHalfSum(int16x16_m256i a, int16x16_m256i b) {
+  // Borrowed from Intel's arm_neon_sse.h header.
+  __m256i constant_neg_32768 = _mm256_set1_epi16(-32768);
+  __m256i a_unsigned = _mm256_sub_epi16(a.v, constant_neg_32768);
+  __m256i b_unsigned = _mm256_sub_epi16(b.v, constant_neg_32768);
+  __m256i avg_unsigned = _mm256_avg_epu16(a_unsigned, b_unsigned);
+  __m256i avg = _mm256_add_epi16(avg_unsigned, constant_neg_32768);
+  return to_int16x16_m256i(avg);
+}
+
+template <>
 inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
   __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
   __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
@@ -209,10 +352,33 @@ inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
 }

 template <>
+inline int16x16_m256i SaturatingRoundingDoublingHighMul(int16x16_m256i a,
+                                                        int16x16_m256i b) {
+  // Use _mm256_mulhrs_epi16 then saturate with a bit-operation,
+  // borrowed from Intel's arm_neon_sse.h header.
+  __m256i result_unsaturated = _mm256_mulhrs_epi16(a.v, b.v);
+  __m256i saturation_mask =
+      _mm256_cmpeq_epi16(result_unsaturated, _mm256_set1_epi16(0x8000));
+  __m256i result = _mm256_xor_si256(result_unsaturated, saturation_mask);
+  return to_int16x16_m256i(result);
+}
+
+template <>
 inline __m256i Dup<__m256i>(std::int32_t x) {
   return _mm256_set1_epi32(x);
 }

+template <>
+inline int16x16_m256i Dup<int16x16_m256i>(std::int16_t x) {
+  return to_int16x16_m256i(_mm256_set1_epi16(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x16_m256i SaturatingAdd(int16x16_m256i a, int16x16_m256i b) {
+  return to_int16x16_m256i(_mm256_adds_epi16(a.v, b.v));
+}
+
 }  // end namespace gemmlowp

 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
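In the AVX2 int16 path above, `_mm256_mulhrs_epi16` does the heavy lifting. A scalar model of what the mulhrs-plus-XOR sequence computes, as an illustration only (not library code):

```cpp
#include <cstdint>

// mulhrs semantics: (a * b + (1 << 14)) >> 15, i.e. a rounded doubling high
// multiply of Q15 values. The raw result lies in [-32768, 32768]; 32768
// occurs only for a == b == -32768, where the intrinsic wraps to -32768
// (0x8000). XORing with the all-ones comparison mask flips that single case
// to +32767, the saturated answer.
std::int16_t ScalarSRDHM16(std::int16_t a, std::int16_t b) {
  const std::int32_t rounded_high =
      (std::int32_t{a} * std::int32_t{b} + (1 << 14)) >> 15;
  return rounded_high > 32767 ? std::int16_t{32767}
                              : static_cast<std::int16_t>(rounded_high);
}
```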
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index a1fae32..fbaa26a 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -32,13 +32,17 @@ namespace gemmlowp {
 // data type, int16x8_m128i, that wraps __m128i while being a separate
 // type.
 struct int16x8_m128i {
-  int16x8_m128i() {}
-  explicit int16x8_m128i(__m128i w) : v(w) {}
-  ~int16x8_m128i() {}
-
   __m128i v;
 };

+// Keep int16x8_m128i trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x8_m128i to_int16x8_m128i(__m128i w) {
+  int16x8_m128i r;
+  r.v = w;
+  return r;
+}
+
 template <>
 struct FixedPointRawTypeTraits<__m128i> {
   typedef std::int32_t ScalarRawType;
@@ -58,7 +62,7 @@ inline __m128i BitAnd(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_and_si128(a.v, b.v));
+  return to_int16x8_m128i(_mm_and_si128(a.v, b.v));
 }

 template <>
@@ -68,7 +72,7 @@ inline __m128i BitOr(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_or_si128(a.v, b.v));
+  return to_int16x8_m128i(_mm_or_si128(a.v, b.v));
 }

 template <>
@@ -78,7 +82,7 @@ inline __m128i BitXor(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_xor_si128(a.v, b.v));
+  return to_int16x8_m128i(_mm_xor_si128(a.v, b.v));
 }

 template <>
@@ -88,7 +92,7 @@ inline __m128i BitNot(__m128i a) {

 template <>
 inline int16x8_m128i BitNot(int16x8_m128i a) {
-  return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
+  return to_int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
 }

 template <>
@@ -98,7 +102,7 @@ inline __m128i Add(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_add_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_add_epi16(a.v, b.v));
 }

 template <>
@@ -108,7 +112,7 @@ inline __m128i Mul(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
 }

 template <>
@@ -118,7 +122,7 @@ inline __m128i Sub(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_sub_epi16(a.v, b.v));
 }

 template <>
@@ -128,7 +132,7 @@ inline __m128i Neg(__m128i a) {

 template <>
 inline int16x8_m128i Neg(int16x8_m128i a) {
-  return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
+  return to_int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
 }

 template <>
@@ -138,7 +142,7 @@ inline __m128i ShiftLeft(__m128i a, int offset) {

 template <>
 inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
-  return int16x8_m128i(_mm_slli_epi16(a.v, offset));
+  return to_int16x8_m128i(_mm_slli_epi16(a.v, offset));
 }

 template <>
@@ -148,7 +152,7 @@ inline __m128i ShiftRight(__m128i a, int offset) {

 template <>
 inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
-  return int16x8_m128i(_mm_srai_epi16(a.v, offset));
+  return to_int16x8_m128i(_mm_srai_epi16(a.v, offset));
 }

 template <>
@@ -164,7 +168,7 @@ inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
                                      int16x8_m128i then_val,
                                      int16x8_m128i else_val) {
   // borrowed from Intel's arm_neon_sse.h header.
-  return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
+  return to_int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
 }

 template <>
@@ -174,7 +178,7 @@ inline __m128i MaskIfEqual(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
 }

 template <>
@@ -194,7 +198,7 @@ inline __m128i MaskIfZero(__m128i a) {

 template <>
 inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
-  return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+  return MaskIfEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
 }

 template <>
@@ -204,7 +208,7 @@ inline __m128i MaskIfNonZero(__m128i a) {

 template <>
 inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
-  return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+  return MaskIfNotEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
 }

 template <>
@@ -214,7 +218,7 @@ inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
 }

 template <>
@@ -224,7 +228,7 @@ inline __m128i MaskIfLessThan(__m128i a, __m128i b) {

 template <>
 inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
 }

 template <>
@@ -310,7 +314,7 @@ inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
   __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
   __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
   __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
-  return int16x8_m128i(avg);
+  return to_int16x8_m128i(avg);
 }

 template <>
@@ -360,7 +364,7 @@ inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
   __m128i saturation_mask =
       _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
   __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
-  return int16x8_m128i(result);
+  return to_int16x8_m128i(result);
 }

 template <>
@@ -370,13 +374,13 @@ inline __m128i Dup<__m128i>(std::int32_t x) {

 template <>
 inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
-  return int16x8_m128i(_mm_set1_epi16(x));
+  return to_int16x8_m128i(_mm_set1_epi16(x));
 }

 // So far this is only needed for int16.
 template <>
 inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
-  return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
+  return to_int16x8_m128i(_mm_adds_epi16(a.v, b.v));
 }

 }  // end namespace gemmlowp
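The change running through fixedpoint_sse.h — dropping int16x8_m128i's user-declared special members in favor of the free helper to_int16x8_m128i — keeps the wrapper a trivial type, which is friendlier to the optimizer and the ABI. A standalone sketch of the distinction, with illustrative type names:

```cpp
#include <type_traits>

// With user-provided special members, the wrapper is no longer trivial.
struct WrapperWithCtors {
  WrapperWithCtors() {}
  explicit WrapperWithCtors(int w) : v(w) {}
  ~WrapperWithCtors() {}
  int v;
};

// As a plain aggregate it stays trivial, and a helper replaces the ctor.
struct TrivialWrapper {
  int v;
};
inline TrivialWrapper to_trivial_wrapper(int w) {
  TrivialWrapper r;
  r.v = w;
  return r;
}

static_assert(!std::is_trivially_copyable<WrapperWithCtors>::value,
              "user-declared destructor makes the type non-trivial");
static_assert(std::is_trivially_copyable<TrivialWrapper>::value,
              "aggregate wrapper remains trivially copyable");
```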
diff --git a/fixedpoint/fixedpoint_wasmsimd.h b/fixedpoint/fixedpoint_wasmsimd.h
new file mode 100644
index 0000000..868fbfe
--- /dev/null
+++ b/fixedpoint/fixedpoint_wasmsimd.h
@@ -0,0 +1,381 @@
+// Copyright 2020 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_wasmsimd.h: optimized WAsm SIMD specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
+
+#include <wasm_simd128.h>
+
+namespace gemmlowp {
+
+// WAsm SIMD intrinsics are not typed: there is a single v128_t vector
+// type that does not distinguish between "int32x4" and "int16x8" use
+// cases, unlike the NEON equivalents. Because we had initially focused
+// on int32x4, we did not pay attention and specialized these fixedpoint
+// templates directly for v128_t hardcoding the int32x4 semantics,
+// not leaving room for int16x8 semantics. Amending that by adding a separate
+// data type, int16x8_v128_t, that wraps v128_t while being a separate
+// type.
+struct int16x8_v128_t {
+  v128_t v;
+};
+
+// Keep int16x8_v128_t trivially constructible/destructible and provide
+// easily optimized helper function.
+inline int16x8_v128_t to_int16x8_v128_t(v128_t w) {
+  int16x8_v128_t r;
+  r.v = w;
+  return r;
+}
+
+template <>
+struct FixedPointRawTypeTraits<v128_t> {
+  typedef std::int32_t ScalarRawType;
+  static constexpr int kLanes = 4;
+};
+
+template <>
+struct FixedPointRawTypeTraits<int16x8_v128_t> {
+  typedef std::int16_t ScalarRawType;
+  static constexpr int kLanes = 8;
+};
+
+template <>
+inline v128_t BitAnd(v128_t a, v128_t b) {
+  return wasm_v128_and(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitAnd(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_v128_and(a.v, b.v));
+}
+
+template <>
+inline v128_t BitOr(v128_t a, v128_t b) {
+  return wasm_v128_or(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitOr(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_v128_or(a.v, b.v));
+}
+
+template <>
+inline v128_t BitXor(v128_t a, v128_t b) {
+  return wasm_v128_xor(a, b);
+}
+
+template <>
+inline int16x8_v128_t BitXor(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_v128_xor(a.v, b.v));
+}
+
+template <>
+inline v128_t BitNot(v128_t a) {
+  return wasm_v128_not(a);
+}
+
+template <>
+inline int16x8_v128_t BitNot(int16x8_v128_t a) {
+  return to_int16x8_v128_t(wasm_v128_not(a.v));
+}
+
+template <>
+inline v128_t Add(v128_t a, v128_t b) {
+  return wasm_i32x4_add(a, b);
+}
+
+template <>
+inline int16x8_v128_t Add(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_add(a.v, b.v));
+}
+
+template <>
+inline v128_t Mul(v128_t a, v128_t b) {
+  return wasm_i32x4_mul(a, b);
+}
+
+template <>
+inline int16x8_v128_t Mul(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_mul(a.v, b.v));
+}
+
+template <>
+inline v128_t Sub(v128_t a, v128_t b) {
+  return wasm_i32x4_sub(a, b);
+}
+
+template <>
+inline int16x8_v128_t Sub(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_sub(a.v, b.v));
+}
+
+template <>
+inline v128_t Neg(v128_t a) {
+  return wasm_i32x4_neg(a);
+}
+
+template <>
+inline int16x8_v128_t Neg(int16x8_v128_t a) {
+  return to_int16x8_v128_t(wasm_i16x8_neg(a.v));
+}
+
+template <>
+inline v128_t ShiftLeft(v128_t a, int offset) {
+  return wasm_i32x4_shl(a, offset);
+}
+
+template <>
+inline int16x8_v128_t ShiftLeft(int16x8_v128_t a, int offset) {
+  return to_int16x8_v128_t(wasm_i16x8_shl(a.v, offset));
+}
+
+template <>
+inline v128_t ShiftRight(v128_t a, int offset) {
+  return wasm_i32x4_shr(a, offset);
+}
+
+template <>
+inline int16x8_v128_t ShiftRight(int16x8_v128_t a, int offset) {
+  return to_int16x8_v128_t(wasm_i16x8_shr(a.v, offset));
+}
+
+template <>
+inline v128_t SelectUsingMask(v128_t if_mask, v128_t then_val,
+                              v128_t else_val) {
+  return wasm_v128_bitselect(then_val, else_val, if_mask);
+}
+
+template <>
+inline int16x8_v128_t SelectUsingMask(int16x8_v128_t if_mask,
+                                      int16x8_v128_t then_val,
+                                      int16x8_v128_t else_val) {
+  return to_int16x8_v128_t(
+      wasm_v128_bitselect(then_val.v, else_val.v, if_mask.v));
+}
+
+template <>
+inline v128_t MaskIfEqual(v128_t a, v128_t b) {
+  return wasm_i32x4_eq(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfEqual(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_eq(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfNotEqual(v128_t a, v128_t b) {
+  return wasm_i32x4_ne(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfNotEqual(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_ne(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfZero(v128_t a) {
+  return MaskIfEqual(a, wasm_i32x4_const(0, 0, 0, 0));
+}
+
+template <>
+inline int16x8_v128_t MaskIfZero(int16x8_v128_t a) {
+  return MaskIfEqual(
+      a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0)));
+}
+
+template <>
+inline v128_t MaskIfNonZero(v128_t a) {
+  return MaskIfNotEqual(a, wasm_i32x4_const(0, 0, 0, 0));
+}
+
+template <>
+inline int16x8_v128_t MaskIfNonZero(int16x8_v128_t a) {
+  return MaskIfNotEqual(
+      a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0)));
+}
+
+template <>
+inline v128_t MaskIfGreaterThan(v128_t a, v128_t b) {
+  return wasm_i32x4_gt(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfGreaterThan(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_gt(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfLessThan(v128_t a, v128_t b) {
+  return wasm_i32x4_lt(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfLessThan(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_lt(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfGreaterThanOrEqual(v128_t a, v128_t b) {
+  return wasm_i32x4_ge(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfGreaterThanOrEqual(int16x8_v128_t a,
+                                               int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_ge(a.v, b.v));
+}
+
+template <>
+inline v128_t MaskIfLessThanOrEqual(v128_t a, v128_t b) {
+  return wasm_i32x4_le(a, b);
+}
+
+template <>
+inline int16x8_v128_t MaskIfLessThanOrEqual(int16x8_v128_t a,
+                                            int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_le(a.v, b.v));
+}
+
+/* Assumptions:
+   - All and Any are used on masks.
+   - masks are all_ones for true lanes, all_zeroes otherwise.
+Hence, All means all 128bits set, and Any means any bit set.
+*/
+
+template <>
+inline bool All(v128_t a) {
+  return wasm_i32x4_all_true(a);
+}
+
+template <>
+inline bool All(int16x8_v128_t a) {
+  return wasm_i16x8_all_true(a.v);
+}
+
+template <>
+inline bool Any(v128_t a) {
+  return wasm_i32x4_any_true(a);
+}
+
+template <>
+inline bool Any(int16x8_v128_t a) {
+  return wasm_i16x8_any_true(a.v);
+}
+
+template <>
+inline v128_t RoundingHalfSum(v128_t a, v128_t b) {
+  // We divide the inputs before the add to avoid the overflow and costly test.
+  const v128_t one = wasm_i32x4_const(1, 1, 1, 1);
+  const v128_t sign_bit_mask =
+      wasm_i32x4_const(0x80000000, 0x80000000, 0x80000000, 0x80000000);
+  const v128_t sum = Add(a, b);
+  const v128_t rounded_half_sum = ShiftRight(Add(sum, one), 1);
+  const v128_t overflow =
+      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
+             sign_bit_mask);
+  const v128_t result = BitXor(rounded_half_sum, overflow);
+  return result;
+}
+
+template <>
+inline int16x8_v128_t RoundingHalfSum(int16x8_v128_t a, int16x8_v128_t b) {
+  // Idea: go to unsigned to use wasm_u16x8_avgr,
+  // borrowed from Intel's arm_neon_sse.h header.
+  const v128_t constant_neg_32768 = wasm_i16x8_const(
+      -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768);
+  const v128_t a_unsigned = wasm_v128_xor(a.v, constant_neg_32768);
+  const v128_t b_unsigned = wasm_v128_xor(b.v, constant_neg_32768);
+  const v128_t avg_unsigned = wasm_u16x8_avgr(a_unsigned, b_unsigned);
+  const v128_t avg = wasm_v128_xor(avg_unsigned, constant_neg_32768);
+  return to_int16x8_v128_t(avg);
+}
+
+template <>
+inline v128_t SaturatingRoundingDoublingHighMul(v128_t a, v128_t b) {
+  // TODO: switch to extended multiplication once implemented in the toolchain
+  const v128_t a_sign = wasm_i32x4_shr(a, 31);
+  const v128_t b_sign = wasm_i32x4_shr(b, 31);
+
+  const v128_t a_ext_lo = wasm_v32x4_shuffle(a, a_sign, 0, 4, 1, 5);
+  const v128_t a_ext_hi = wasm_v32x4_shuffle(a, a_sign, 2, 6, 3, 7);
+  const v128_t b_ext_lo = wasm_v32x4_shuffle(b, b_sign, 0, 4, 1, 5);
+  const v128_t b_ext_hi = wasm_v32x4_shuffle(b, b_sign, 2, 6, 3, 7);
+
+  const v128_t ab_lo = wasm_i64x2_mul(a_ext_lo, b_ext_lo);
+  const v128_t ab_hi = wasm_i64x2_mul(a_ext_hi, b_ext_hi);
+
+  const v128_t nudge_2x =
+      wasm_i64x2_const(INT64_C(0x80000000), INT64_C(0x80000000));
+  const v128_t ab_lo_2x = wasm_i64x2_add(ab_lo, ab_lo);
+  const v128_t ab_hi_2x = wasm_i64x2_add(ab_hi, ab_hi);
+
+  const v128_t ab_lo_rounded_2x = wasm_i64x2_add(ab_lo_2x, nudge_2x);
+  const v128_t ab_hi_rounded_2x = wasm_i64x2_add(ab_hi_2x, nudge_2x);
+
+  const v128_t prod =
+      wasm_v32x4_shuffle(ab_lo_rounded_2x, ab_hi_rounded_2x, 1, 3, 5, 7);
+
+  // Saturation only happen if a == b == INT_MIN, and this is the only case
+  // where prod == INT_MIN (0x80000000) instead of INT_MAX (0x7FFFFFFF).
+  const v128_t min = wasm_i32x4_const(INT32_C(0x80000000), INT32_C(0x80000000),
+                                      INT32_C(0x80000000), INT32_C(0x80000000));
+
+  return wasm_v128_xor(prod, wasm_i32x4_eq(prod, min));
+}
+
+template <>
+inline int16x8_v128_t SaturatingRoundingDoublingHighMul(int16x8_v128_t a,
+                                                        int16x8_v128_t b) {
+#if 0
+  // TODO: enable if https://github.com/WebAssembly/simd/pull/365 is accepted
+  return to_int16x8_v128_t(__builtin_wasm_q15mulr_saturate_s_i16x8(a.v, b.v));
+#else
+  // TODO: switch to extended multiplication once implemented in the toolchain
+  v128_t lo = wasm_i32x4_mul(wasm_i32x4_widen_low_i16x8(a.v),
+                             wasm_i32x4_widen_low_i16x8(b.v));
+  v128_t hi = wasm_i32x4_mul(wasm_i32x4_widen_high_i16x8(a.v),
+                             wasm_i32x4_widen_high_i16x8(b.v));
+  const v128_t inc = wasm_i32x4_const(0x4000, 0x4000, 0x4000, 0x4000);
+  lo = wasm_i32x4_add(lo, inc);
+  hi = wasm_i32x4_add(hi, inc);
+  lo = wasm_i32x4_shr(lo, 15);
+  hi = wasm_i32x4_shr(hi, 15);
+  return to_int16x8_v128_t(wasm_i16x8_narrow_i32x4(lo, hi));
+#endif
+}
+
+template <>
+inline v128_t Dup<v128_t>(std::int32_t x) {
+  return wasm_i32x4_splat(x);
+}
+
+template <>
+inline int16x8_v128_t Dup<int16x8_v128_t>(std::int16_t x) {
+  return to_int16x8_v128_t(wasm_i16x8_splat(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_v128_t SaturatingAdd(int16x8_v128_t a, int16x8_v128_t b) {
+  return to_int16x8_v128_t(wasm_i16x8_add_saturate(a.v, b.v));
+}
+
+}  // end namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_
@@ -3,10 +3,9 @@ LIB_COPTS = []

 LIB_LINKOPTS = select({
     ":android": [],
+    ":windows": [],
     "//conditions:default": ["-lpthread"],
 })

-BIN_LINKOPTS = select({
-    ":android": [],
-    "//conditions:default": ["-lpthread"],
-})
+BIN_LINKOPTS = LIB_LINKOPTS
+
diff --git a/internal/allocator.h b/internal/allocator.h
index 3a6f077..e71df15 100644
--- a/internal/allocator.h
+++ b/internal/allocator.h
@@ -86,11 +86,11 @@ class Allocator {
   }

   // Alignment of allocated blocks.
-  static const std::size_t kAlignment = kDefaultCacheLineSize;
+  static constexpr std::size_t kAlignment = kDefaultCacheLineSize;

   // This is all we need so far, and since the usage pattern is fixed,
   // there is no point in allowing more until we need to.
-  static const std::size_t kMaxBlocks = 5;
+  static constexpr std::size_t kMaxBlocks = 5;

   void Commit() {
     assert(!committed_);
diff --git a/internal/common.h b/internal/common.h
index 332ad07..708cc40 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -165,7 +165,7 @@ Integer RoundUpToPowerOfTwo(Integer n) {

 template <int N>
 struct IsPowerOfTwo {
-  static const bool value = !(N & (N - 1));
+  static constexpr bool value = !(N & (N - 1));
 };

 template <typename T>
diff --git a/internal/detect_platform.h b/internal/detect_platform.h
index 6f06d19..7f0d78c 100644
--- a/internal/detect_platform.h
+++ b/internal/detect_platform.h
@@ -71,6 +71,11 @@
 #define GEMMLOWP_X86
 #endif

+// Detect WebAssembly SIMD.
+#if defined(__wasm_simd128__)
+#define GEMMLOWP_WASMSIMD
+#endif
+
 // Some of our optimized paths use inline assembly and for
 // now we don't bother enabling some other optimized paths using intrinddics
 // where we can't use inline assembly paths.
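A side note on the static const to static constexpr conversions in allocator.h and common.h above (and in several more headers below): constexpr makes the compile-time intent explicit, and under C++17 such members are implicitly inline, avoiding the classic link error when a member is odr-used. A minimal illustration with a hypothetical class:

```cpp
#include <algorithm>

struct Format {
  // Pre-C++17, odr-using a static const member (e.g. binding it to a const
  // reference) required an out-of-line definition in some .cc file. A
  // static constexpr member is implicitly inline since C++17, so the
  // in-class initializer is the complete definition.
  static constexpr int kWidth = 4;
};

int ClampToWidth(int n) {
  // std::min takes const references, which odr-uses Format::kWidth.
  return std::min(n, Format::kWidth);
}
```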
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
index ba4f341..b844f78 100644
--- a/internal/dispatch_gemm_shape.h
+++ b/internal/dispatch_gemm_shape.h
@@ -74,7 +74,8 @@ struct TransposeImpl<MatrixMap<Scalar, Order>> {
 template <VectorShape Shape>
 struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
-  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+  static constexpr VectorShape TransposedShape =
+      TransposeVectorShape<Shape>::Value;
   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
   static DstType Run(const SrcType& src) {
     DstType dst;
@@ -88,7 +89,8 @@ struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
 template <VectorShape Shape>
 struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
-  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+  static constexpr VectorShape TransposedShape =
+      TransposeVectorShape<Shape>::Value;
   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
       DstType;
   static DstType Run(const SrcType& src) {
diff --git a/internal/kernel.h b/internal/kernel.h
index 3120216..f1a3fd8 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -126,11 +126,11 @@ enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
 // out in a cell. That is, a CellOrder together with actual dimensions.
 template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
 struct CellFormat {
-  static const int kWidth = tWidth;
-  static const int kDepth = tDepth;
-  static const CellOrder kOrder = tOrder;
+  static constexpr int kWidth = tWidth;
+  static constexpr int kDepth = tDepth;
+  static constexpr CellOrder kOrder = tOrder;

-  static const int kSize = kWidth * kDepth;
+  static constexpr int kSize = kWidth * kDepth;
 };

 // KernelSideFormat describes how data is laid out in a kernel side
@@ -142,9 +142,9 @@ struct CellFormat {
 template <typename tCellFormat, int tCells>
 struct KernelSideFormat {
   typedef tCellFormat Cell;
-  static const int kCells = tCells;
-  static const int kWidth = kCells * Cell::kWidth;
-  static const int kDepth = Cell::kDepth;
+  static constexpr int kCells = tCells;
+  static constexpr int kWidth = kCells * Cell::kWidth;
+  static constexpr int kDepth = Cell::kDepth;
   typedef std::uint8_t Scalar;       // The scalar type of the Format.
   typedef std::uint8_t InputScalar;  // The scalar type of the original input.
 };
@@ -173,9 +173,9 @@ struct KernelFormat {
   typedef tRhs Rhs;

   static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
-  static const int kDepth = Lhs::Cell::kDepth;
-  static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
-  static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
+  static constexpr int kDepth = Lhs::Cell::kDepth;
+  static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells;
+  static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells;
 };

 inline const char* CellOrderName(CellOrder o) {
diff --git a/internal/output_sse.h b/internal/output_sse.h
index 75aebfd..6ea3290 100644
--- a/internal/output_sse.h
+++ b/internal/output_sse.h
@@ -535,6 +535,27 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
   }
 };

+// Specialization for MatrixMap, for performance.
+template <typename tScalar, MapOrder tOrder>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> {
+  static void Run(const RegBlockUint8<8, 8>& src,
+                  MatrixMap<tScalar, tOrder>* dst, int row, int col) {
+    std::uint8_t buf[64];
+    StoreUint8x16(buf, src.buf.reg[0]);
+    StoreUint8x16(buf + 16, src.buf.reg[1]);
+    StoreUint8x16(buf + 32, src.buf.reg[2]);
+    StoreUint8x16(buf + 48, src.buf.reg[3]);
+    // Make a local copy so that the compiler can prove that data_ does not
+    // alias &data_ or &stride_.
+    MatrixMap<tScalar, tOrder> local = *dst;
+    for (int c = 0; c < 8; c++) {
+      for (int r = 0; r < 8; r++) {
+        *local.data(row + r, col + c) = buf[r + 8 * c];
+      }
+    }
+  }
+};
+
 }  // namespace gemmlowp

 #endif  // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
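The comment inside the new StoreFinalOutputImpl specialization — making a local copy of the MatrixMap so the compiler can prove data_ doesn't alias the map's own fields — reflects a general pattern, sketched below with illustrative types:

```cpp
#include <cstdint>

struct Map {
  std::uint8_t* data;
  int stride;
};

// Every store through m->data could, for all the compiler knows, overwrite
// m->data or m->stride themselves, so both may be reloaded each iteration.
void StoreThroughPointer(Map* m, const std::uint8_t* src, int n) {
  for (int i = 0; i < n; i++) m->data[i * m->stride] = src[i];
}

// Copying the map by value severs that potential aliasing: local.data and
// local.stride are provably unchanged by the stores, enabling better code.
void StoreThroughLocalCopy(Map* m, const std::uint8_t* src, int n) {
  Map local = *m;
  for (int i = 0; i < n; i++) local.data[i * local.stride] = src[i];
}
```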
const int kCellDepth = CellFormat::kDepth; - static const int kCellSize = CellFormat::kSize; + static constexpr int kCells = KernelSideFormat::kCells; + static constexpr int kCellWidth = CellFormat::kWidth; + static constexpr int kKernelWidth = CellFormat::kWidth * kCells; + static constexpr int kCellDepth = CellFormat::kDepth; + static constexpr int kCellSize = CellFormat::kSize; void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { std::uint8_t* dst_ptr = dst->current_data(); diff --git a/internal/platform.h b/internal/platform.h index 54517c3..0f3a2b8 100644 --- a/internal/platform.h +++ b/internal/platform.h @@ -30,8 +30,7 @@ #include <sys/time.h> #endif -#if defined __ANDROID__ -#include <android/api-level.h> +#if defined ANDROID || defined __ANDROID__ #include <malloc.h> // The 18 here should be 16, but has to be 18 for now due // to a Google-internal issue. diff --git a/meta/generators/cc_emitter.py b/meta/generators/cc_emitter.py index 8615671..c1dc75d 100644 --- a/meta/generators/cc_emitter.py +++ b/meta/generators/cc_emitter.py @@ -52,16 +52,16 @@ class CCEmitter(object): self.indent = self.indent[:-2] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/meta/generators/common.py b/meta/generators/common.py index d680372..7269b50 100644 --- a/meta/generators/common.py +++ b/meta/generators/common.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """.""" +import collections _HEADER_COPYRIGHT = ( '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 
@@ -71,7 +72,7 @@ class StreamGenerator(object): self.emitter = emitter def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers): - if callable(getattr(self, 'EmitPack', None)): + if isinstance(getattr(self, 'EmitPack', None), collections.Callable): template_params = [in_type, lanes_count, pack_size, leftovers, self.name] self.emitter.EmitMemberFunctionBegin( 'Stream', [], template_params, 'Pack', diff --git a/meta/generators/neon_emitter.py b/meta/generators/neon_emitter.py index 726766e..0304317 100644 --- a/meta/generators/neon_emitter.py +++ b/meta/generators/neon_emitter.py @@ -187,7 +187,7 @@ class NeonEmitter(object): self.indent = self.indent[:-delta] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def PushOp(self, op): if op in self.ops.keys(): @@ -199,13 +199,13 @@ class NeonEmitter(object): self.ops.clear() def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/meta/generators/neon_emitter_64.py b/meta/generators/neon_emitter_64.py index 13a0715..956b16b 100644 --- a/meta/generators/neon_emitter_64.py +++ b/meta/generators/neon_emitter_64.py @@ -423,7 +423,7 @@ class NeonEmitter64(object): self.indent = self.indent[:-delta] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def PushOp(self, op): if op in self.ops.keys(): @@ -435,13 +435,13 @@ class NeonEmitter64(object): self.ops.clear() def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/public/bit_depth.h b/public/bit_depth.h index 412944e..5b19430 100644 --- a/public/bit_depth.h +++ b/public/bit_depth.h @@ -22,8 +22,8 @@ namespace gemmlowp { // The range of allowed values for an operand. template <int tMinValue, int tMaxValue> struct OperandRange { - static const int kMinValue = tMinValue; - static const int kMaxValue = tMaxValue; + static constexpr int kMinValue = tMinValue; + static constexpr int kMaxValue = tMaxValue; static_assert(kMinValue < kMaxValue, ""); }; diff --git a/public/map.h b/public/map.h index fe6bc5c..1b71f9e 100644 --- a/public/map.h +++ b/public/map.h @@ -32,7 +32,7 @@ template <typename tScalar, MapOrder tOrder> class MatrixMap { public: typedef tScalar Scalar; - static const MapOrder kOrder = tOrder; + static constexpr MapOrder kOrder = tOrder; protected: Scalar* data_; // not owned. @@ -84,7 +84,7 @@ template <typename tScalar, VectorShape tShape> class VectorMap { public: typedef tScalar Scalar; - static const VectorShape kShape = tShape; + static constexpr VectorShape kShape = tShape; protected: Scalar* data_; // not owned. 
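The static const to static constexpr substitutions above, repeated through pack.h, pack_sse.h, bit_depth.h, public/map.h, and later public/bit_depth.h, are behavior-preserving for these integral and enum constants. A minimal sketch of what distinguishes the two forms, assuming a C++17 toolchain for the inline behavior noted in the comments (the struct name is illustrative, not from the patch):

    // Both forms give a compile-time constant for integral types, but an
    // odr-use (e.g. binding a const int& to kA) requires an out-of-line
    // definition for the plain `const` form before C++17.
    struct CellFormatSketch {
      static const int kA = 4;
      // `constexpr` additionally guarantees a constant expression and is
      // implicitly inline since C++17, so no separate definition is needed.
      static constexpr int kB = 4;
    };
    static_assert(CellFormatSketch::kB == 4, "usable in constant expressions");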
@@ -113,7 +113,7 @@ template <typename tScalar, VectorShape tShape> class VectorDup { public: typedef tScalar Scalar; - static const VectorShape kShape = tShape; + static constexpr VectorShape kShape = tShape; protected: Scalar data_; diff --git a/standalone/cache_counters.cc b/standalone/cache_counters.cc new file mode 100644 index 0000000..24e971c --- /dev/null +++ b/standalone/cache_counters.cc @@ -0,0 +1,404 @@ +#include <asm/unistd.h> +#include <linux/perf_event.h> +#include <sys/ioctl.h> +#include <unistd.h> +#include <algorithm> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <random> + +#ifndef __aarch64__ +#error This program is for 64-bit ARM only. +#endif + +struct PerfEvent { + perf_event_attr pe; + int fd = -1; + + PerfEvent(std::uint32_t type, std::uint64_t config) { + memset(&pe, 0, sizeof(pe)); + pe.size = sizeof(pe); + pe.type = type; + pe.config = config; + pe.disabled = 1; + pe.exclude_kernel = 1; + pe.exclude_hv = 1; + fd = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); + if (fd == -1) { + fprintf(stderr, "perf_event_open failed for config 0x%lx\n", config); + abort(); + } + } + + void Start() { + ioctl(fd, PERF_EVENT_IOC_RESET, 0); + ioctl(fd, PERF_EVENT_IOC_ENABLE, 0); + } + + std::int64_t Stop() { + ioctl(fd, PERF_EVENT_IOC_DISABLE, 0); + std::int64_t count = 0; + read(fd, &count, sizeof(count)); + return count; + } + + ~PerfEvent() { close(fd); } +}; + +struct ArmPmuEvent : PerfEvent { + static constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; + static constexpr std::uint16_t L1I_TLB_REFILL = 0x02; + static constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; + static constexpr std::uint16_t L1D_CACHE = 0x04; + static constexpr std::uint16_t L1D_TLB_REFILL = 0x05; + static constexpr std::uint16_t LD_RETIRED = 0x06; + static constexpr std::uint16_t ST_RETIRED = 0x07; + static constexpr std::uint16_t INST_RETIRED = 0x08; + static constexpr std::uint16_t EXC_TAKEN = 0x09; + static constexpr std::uint16_t EXC_RETURN = 0x0A; + static constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; + static constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; + static constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; + static constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; + static constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; + static constexpr std::uint16_t BR_MIS_PRED = 0x10; + static constexpr std::uint16_t CPU_CYCLES = 0x11; + static constexpr std::uint16_t BR_PRED = 0x12; + static constexpr std::uint16_t MEM_ACCESS = 0x13; + static constexpr std::uint16_t L1I_CACHE = 0x14; + static constexpr std::uint16_t L1D_CACHE_WB = 0x15; + static constexpr std::uint16_t L2D_CACHE = 0x16; + static constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; + static constexpr std::uint16_t L2D_CACHE_WB = 0x18; + static constexpr std::uint16_t BUS_ACCESS = 0x19; + static constexpr std::uint16_t MEMORY_ERROR = 0x1A; + static constexpr std::uint16_t INST_SPEC = 0x1B; + static constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; + static constexpr std::uint16_t BUS_CYCLES = 0x1D; + static constexpr std::uint16_t CHAIN = 0x1E; + static constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; + static constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; + static constexpr std::uint16_t BR_RETIRED = 0x21; + static constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; + static constexpr std::uint16_t STALL_FRONTEND = 0x23; + static constexpr std::uint16_t STALL_BACKEND = 0x24; + static constexpr std::uint16_t L1D_TLB = 0x25; + static constexpr std::uint16_t L1I_TLB = 0x26; + 
static constexpr std::uint16_t L2I_CACHE = 0x27;
+ static constexpr std::uint16_t L2I_CACHE_REFILL = 0x28;
+ static constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29;
+ static constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A;
+ static constexpr std::uint16_t L3D_CACHE = 0x2B;
+ static constexpr std::uint16_t L3D_CACHE_WB = 0x2C;
+ static constexpr std::uint16_t L2D_TLB_REFILL = 0x2D;
+ static constexpr std::uint16_t L2I_TLB_REFILL = 0x2E;
+ static constexpr std::uint16_t L2D_TLB = 0x2F;
+ static constexpr std::uint16_t L2I_TLB = 0x30;
+ static constexpr std::uint16_t LL_CACHE = 0x32;
+ static constexpr std::uint16_t LL_CACHE_MISS = 0x33;
+ static constexpr std::uint16_t DTLB_WALK = 0x34;
+ static constexpr std::uint16_t LL_CACHE_RD = 0x36;
+ static constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37;
+ static constexpr std::uint16_t L1D_CACHE_RD = 0x40;
+ static constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42;
+ static constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C;
+ static constexpr std::uint16_t L1D_TLB_RD = 0x4E;
+ static constexpr std::uint16_t L2D_CACHE_RD = 0x50;
+ static constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52;
+ static constexpr std::uint16_t BUS_ACCESS_RD = 0x60;
+ static constexpr std::uint16_t MEM_ACCESS_RD = 0x66;
+ static constexpr std::uint16_t L3D_CACHE_RD = 0xA0;
+ static constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2;
+ ArmPmuEvent(std::uint16_t number) : PerfEvent(PERF_TYPE_RAW, number) {}
+};
+
+struct CacheCounts {
+ int ld_retired = 0;
+ int mem_access = 0;
+ int ll_cache = 0;
+ int ll_cache_miss = 0;
+ int l1d_cache = 0;
+ int l1d_cache_refill = 0;
+ int l2d_cache = 0;
+ int l2d_cache_refill = 0;
+ int l3d_cache = 0;
+ int l3d_cache_refill = 0;
+};
+
+void PrintCacheCounts(const CacheCounts& cache_counts) {
+ printf("ld_retired = %d\n", cache_counts.ld_retired);
+ printf("mem_access = %d\n", cache_counts.mem_access);
+ printf("ll_cache = %d\n", cache_counts.ll_cache);
+ printf("ll_cache_miss = %d\n", cache_counts.ll_cache_miss);
+ printf("l1d_cache = %d\n", cache_counts.l1d_cache);
+ printf("l1d_cache_refill = %d\n", cache_counts.l1d_cache_refill);
+ printf("l2d_cache = %d\n", cache_counts.l2d_cache);
+ printf("l2d_cache_refill = %d\n", cache_counts.l2d_cache_refill);
+ printf("l3d_cache = %d\n", cache_counts.l3d_cache);
+ printf("l3d_cache_refill = %d\n", cache_counts.l3d_cache_refill);
+}
+
+void Workload(int accesses, int size, std::uint8_t* buf) {
+ // The main reason to do this in assembly is an attempt to make sense
+ // of instruction count counters, such as LD_RETIRED.
+ // Also, if we did this in C++, we would need to be watchful of the compiler
+ // optimizing away operations whose result isn't consumed.
+ //
+ // Note that TWO separate tricks are needed here to prevent Cortex-A76
+ // speculative execution from prefetching data from future loop iterations:
+ // 1. A data-dependency whereby the pointers being dereferenced at the
+ // next loop iteration depend on values loaded at the current iteration.
+ // That is the role of 'dummy'.
+ // 2. A pseudo-random sequence. This is the role of register w0,
+ // where we implement a simple xorshift pseudorandom generator.
+ // BOTH of these tricks are needed: if we disable just one of them,
+ // Cortex-A76 successfully speculates some addresses, resulting in different
+ // L3 / DRAM hit percentages on large sizes.
+ std::uint64_t dummy = 123456789;
+ asm volatile(
+ // w0 := xorshift RNG state. Must be nonzero.
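For readers following the asm body that continues below: the pseudorandom sequence kept in w0 is the classic xorshift32 recurrence. A plain C++ sketch of that step and of the per-iteration offset computation (the names are illustrative, not part of this file):

    #include <cstdint>

    // xorshift32 step the asm keeps in w0; any nonzero seed works, since 0
    // is a fixed point of the recurrence (hence the "mov w0, #1" below).
    inline std::uint32_t Xorshift32(std::uint32_t x) {
      x ^= x << 13;
      x ^= x >> 17;
      x ^= x << 5;
      return x;
    }

    // Per-iteration byte offset: mix the pseudorandom value with the
    // load-dependent running sum, reduce modulo the power-of-two size, and
    // align down to a 64-byte cache line, as the eor/and sequence does.
    inline std::uint32_t NextOffset(std::uint32_t rng, std::uint64_t dummy,
                                    std::uint32_t size) {
      return (rng ^ static_cast<std::uint32_t>(dummy)) & (size - 1) & ~63u;
    }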
+ "mov w0, #1\n" + "1:\n" + // xorshift RNG iteration: update w0 with the next pseudorandom value + // in [1 .. 2^32-1]. + // This pseudorandomness is crucial to preventing speculative execution + // on Cortex-A76 from prefetching data from future loop iterations. + "eor w0, w0, w0, lsl #13\n" + "eor w0, w0, w0, lsr #17\n" + "eor w0, w0, w0, lsl #5\n" + // w1 := size - 1 = size mask (size is required to be power-of-two). + "sub w1, %w[size], #1\n" + // w2 := (pseudorandom value w0) xor (data-dependent sum). + "eor w2, w0, %w[dummy]\n" + // w1 := w2 modulo size + "and w1, w2, w1\n" + // align w1 + "and w1, w1, #-64\n" + // load at offset w1, again using x1 as destination. + "ldr x1, [%[buf], w1, uxtw]\n" + // Update our dummy so it depends on the value we have just loaded. + // This data-dependency is key to preventing speculative execution on + // Cortex-A76 from prefetching data from future loop iterations. + "add %[dummy], %[dummy], w1, uxtw\n" + // loop back. + "subs %w[accesses], %w[accesses], #1\n" + "bne 1b\n" + : [ accesses ] "+r"(accesses), [ dummy ] "+r"(dummy) + : [ size ] "r"(size), [ buf ] "r"(buf) + : "memory", "cc", "x0", "x1", "x2"); +} + +void MeasureCacheCounts(int accesses, int size, std::uint8_t* buf, + CacheCounts* cache_counts) { + const bool only_reads = getenv("ONLY_READS"); + ArmPmuEvent ld_retired(ArmPmuEvent::LD_RETIRED); + ArmPmuEvent mem_access(only_reads ? ArmPmuEvent::MEM_ACCESS_RD + : ArmPmuEvent::MEM_ACCESS); + ArmPmuEvent ll_cache(only_reads ? ArmPmuEvent::LL_CACHE_RD + : ArmPmuEvent::LL_CACHE); + ArmPmuEvent ll_cache_miss(only_reads ? ArmPmuEvent::LL_CACHE_MISS_RD + : ArmPmuEvent::LL_CACHE_MISS); + ArmPmuEvent l1d_cache(only_reads ? ArmPmuEvent::L1D_CACHE_RD + : ArmPmuEvent::L1D_CACHE); + ArmPmuEvent l1d_cache_refill(only_reads ? ArmPmuEvent::L1D_CACHE_REFILL_RD + : ArmPmuEvent::L1D_CACHE_REFILL); + ArmPmuEvent l2d_cache(only_reads ? ArmPmuEvent::L2D_CACHE_RD + : ArmPmuEvent::L2D_CACHE); + ArmPmuEvent l2d_cache_refill(only_reads ? ArmPmuEvent::L2D_CACHE_REFILL_RD + : ArmPmuEvent::L2D_CACHE_REFILL); + ArmPmuEvent l3d_cache(only_reads ? ArmPmuEvent::L3D_CACHE_RD + : ArmPmuEvent::L3D_CACHE); + ArmPmuEvent l3d_cache_refill(only_reads ? ArmPmuEvent::L3D_CACHE_REFILL_RD + : ArmPmuEvent::L3D_CACHE_REFILL); + + ld_retired.Start(); + mem_access.Start(); + ll_cache.Start(); + ll_cache_miss.Start(); + l1d_cache.Start(); + l1d_cache_refill.Start(); + l2d_cache.Start(); + l2d_cache_refill.Start(); + l3d_cache.Start(); + l3d_cache_refill.Start(); + + Workload(accesses, size, buf); + + cache_counts->ld_retired = ld_retired.Stop(); + cache_counts->mem_access = mem_access.Stop(); + cache_counts->ll_cache = ll_cache.Stop(); + cache_counts->ll_cache_miss = ll_cache_miss.Stop(); + cache_counts->l1d_cache = l1d_cache.Stop(); + cache_counts->l1d_cache_refill = l1d_cache_refill.Stop(); + cache_counts->l2d_cache = l2d_cache.Stop(); + cache_counts->l2d_cache_refill = l2d_cache_refill.Stop(); + cache_counts->l3d_cache = l3d_cache.Stop(); + cache_counts->l3d_cache_refill = l3d_cache_refill.Stop(); +} + +struct PieChart { + // How many accesses were recorded, total? The other fields must sum to that. + int total; + // How many accesses were serviced with the typical cost of a L1 cache hit? + int l1_hits; + // How many accesses were serviced with the typical cost of a L2 cache hit? + int l2_hits; + // How many accesses were serviced with the typical cost of a L3 cache hit? + int l3_hits; + // How many accesses were serviced with the typical cost of a DRAM access? 
+ int dram_hits; + + ~PieChart() { + // Consistency check + if (total != l1_hits + l2_hits + l3_hits + dram_hits) { + fprintf(stderr, "inconsistent pie-chart\n"); + abort(); + } + } +}; + +struct Hypothesis { + virtual ~Hypothesis() {} + virtual const char* Name() const = 0; + virtual void Analyze(const CacheCounts& cache_counts, + PieChart* pie) const = 0; +}; + +struct Hypothesis1 : Hypothesis { + const char* Name() const override { return "Hypothesis1"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache + cache_counts.l1d_cache_refill; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache_refill - + cache_counts.l3d_cache_refill; + pie->l2_hits = cache_counts.l1d_cache_refill; + pie->l3_hits = cache_counts.l2d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis2 : Hypothesis { + const char* Name() const override { return "Hypothesis2"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache; + pie->l2_hits = cache_counts.l2d_cache - cache_counts.l3d_cache; + pie->l3_hits = cache_counts.l3d_cache - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis3 : Hypothesis { + const char* Name() const override { return "Hypothesis3"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + int corrected_l2 = std::min(cache_counts.l2d_cache, cache_counts.l1d_cache); + int corrected_l3 = std::min(cache_counts.l3d_cache, corrected_l2); + pie->l1_hits = cache_counts.l1d_cache - corrected_l2; + pie->l2_hits = corrected_l2 - corrected_l3; + pie->l3_hits = corrected_l3 - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis4 : Hypothesis { + const char* Name() const override { return "Hypothesis4"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l1d_cache_refill; + pie->l2_hits = + cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill; + pie->l3_hits = + cache_counts.l2d_cache_refill - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis5 : Hypothesis { + const char* Name() const override { return "Hypothesis5"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->l1_hits = + std::max(0, cache_counts.l1d_cache - cache_counts.l1d_cache_refill); + pie->l2_hits = std::max( + 0, cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill); + const int l3_misses = + std::max(cache_counts.ll_cache_miss, cache_counts.l3d_cache_refill); + pie->l3_hits = std::max(0, cache_counts.l2d_cache_refill - l3_misses); + pie->dram_hits = l3_misses; + pie->total = pie->l1_hits + pie->l2_hits + pie->l3_hits + pie->dram_hits; + } +}; + +void PrintPieChart(const PieChart& pie) { + printf("total accesses: %d\n", pie.total); + double l1_hits_pct = 100. * pie.l1_hits / pie.total; + double l2_hits_pct = 100. * pie.l2_hits / pie.total; + double l3_hits_pct = 100. * pie.l3_hits / pie.total; + double dram_hits_pct = 100. 
* pie.dram_hits / pie.total;
+ printf("L1 hits: %.2f%%\n", l1_hits_pct);
+ printf("L2 hits: %.2f%%\n", l2_hits_pct);
+ printf("L1/2 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct);
+ printf("L3 hits: %.2f%%\n", l3_hits_pct);
+ printf("L1/2/3 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct + l3_hits_pct);
+ printf("DRAM hits: %.2f%%\n", dram_hits_pct);
+}
+
+void PrintPieChartCsvNoNewline(const PieChart& pie) {
+ double l1_hits_pct = 100. * pie.l1_hits / pie.total;
+ double l2_hits_pct = 100. * pie.l2_hits / pie.total;
+ double l3_hits_pct = 100. * pie.l3_hits / pie.total;
+ double dram_hits_pct = 100. * pie.dram_hits / pie.total;
+ printf("%.2f,%.2f,%.2f,%.2f", l1_hits_pct, l2_hits_pct, l3_hits_pct,
+ dram_hits_pct);
+}
+
+void Study(int accesses, int size, std::uint8_t* buf) {
+ CacheCounts cache_counts;
+ MeasureCacheCounts(accesses, size, buf, &cache_counts);
+ const Hypothesis* hypotheses[] = {
+ new Hypothesis5, new Hypothesis4, new Hypothesis3,
+ new Hypothesis2, new Hypothesis1,
+ };
+ if (getenv("DUMP_CSV")) {
+ printf("%d", size);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ printf(",");
+ PieChart pie;
+ hypothesis->Analyze(cache_counts, &pie);
+ PrintPieChartCsvNoNewline(pie);
+ }
+ printf("\n");
+ } else {
+ printf("\n\n\naccesses=%d, size=%d:\n", accesses, size);
+ printf("\nCache counts:\n");
+ PrintCacheCounts(cache_counts);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ printf("\n%s:\n", hypothesis->Name());
+ PieChart pie;
+ hypothesis->Analyze(cache_counts, &pie);
+ PrintPieChart(pie);
+ }
+ }
+ fflush(stdout);
+ for (const Hypothesis* hypothesis : hypotheses) {
+ delete hypothesis;
+ }
+}
+
+int main() {
+ const int kMinSize = 1 << 12;
+ const int kMaxSize = 1 << 24;
+ const int kAccesses = 1e8;
+ void* buf_void = nullptr;
+ posix_memalign(&buf_void, 64, kMaxSize);
+ std::uint8_t* buf = static_cast<std::uint8_t*>(buf_void);
+ std::default_random_engine random_engine;
+ for (int i = 0; i < kMaxSize; i++) {
+ buf[i] = random_engine();
+ }
+ for (int size = kMinSize; size <= kMaxSize; size *= 2) {
+ Study(kAccesses, size, buf);
+ }
+ // Memory obtained from posix_memalign must be released with free(),
+ // not delete[].
+ free(buf);
+}
diff --git a/standalone/encode.py b/standalone/encode.py
new file mode 100644
index 0000000..c192ab9
--- /dev/null
+++ b/standalone/encode.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The gemmlowp Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Encodes ARM asm code for certain instructions into the corresponding machine code encoding, as a .word directive in the asm code, preserving the original code in a comment.
+
+Reads from stdin, writes to stdout.
+
+Example diff:
+- "udot v16.4s, v4.16b, v0.16b\n"
++ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n"
+
+The intended use case is to make asm code easier to compile on toolchains that
+do not support certain new instructions.
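The transformation shown in the example above is pure bit-packing: the .word value is the base opcode of the (u/s)dot vector form with the register numbers and the signedness bit OR'd in. A C++ mirror of the formula the script applies (a sketch; the function name is illustrative):

    #include <cstdint>

    // Base opcode 0x4e809400: destination register in bits [4:0], first
    // source in bits [9:5], second source in bits [20:16], and bit 29
    // selecting udot over sdot.
    inline std::uint32_t EncodeUdotSdotVector(int accum, int lhs, int rhs,
                                              bool is_unsigned) {
      return 0x4e809400u | std::uint32_t(accum) | (std::uint32_t(lhs) << 5) |
             (std::uint32_t(rhs) << 16) | (std::uint32_t(is_unsigned) << 29);
    }

    // Example from the docstring: EncodeUdotSdotVector(16, 4, 0, true)
    // yields 0x6e809490, the ".word" emitted for
    // "udot v16.4s, v4.16b, v0.16b".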
+""" + +import sys +import re +import argparse + + +def encode_udot_sdot_vector(line): + m = re.search( + r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b', + line) + if not m: + return 0, line + + match = m.group(0) + unsigned = 1 if m.group(1) == 'u' else 0 + accum = int(m.group(2)) + lhs = int(m.group(3)) + rhs = int(m.group(4)) + assert accum >= 0 and accum <= 31 + assert lhs >= 0 and lhs <= 31 + assert rhs >= 0 and rhs <= 31 + mcode = 0x4e809400 | (accum << 0) | (lhs << 5) | (rhs << 16) | ( + unsigned << 29) + return mcode, match + + +def encode_udot_sdot_element(line): + m = re.search( + r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*4b[ ]*\[([0-9])\]', + line) + if not m: + return 0, line + + match = m.group(0) + unsigned = 1 if m.group(1) == 'u' else 0 + accum = int(m.group(2)) + lhs = int(m.group(3)) + rhs = int(m.group(4)) + lanegroup = int(m.group(5)) + assert accum >= 0 and accum <= 31 + assert lhs >= 0 and lhs <= 31 + assert rhs >= 0 and rhs <= 31 + assert lanegroup >= 0 and lanegroup <= 3 + l = 1 if lanegroup & 1 else 0 + h = 1 if lanegroup & 2 else 0 + mcode = 0x4f80e000 | (accum << 0) | (lhs << 5) | (rhs << 16) | (l << 21) | ( + h << 11) | ( + unsigned << 29) + return mcode, match + + +def encode(line): + for encode_func in [encode_udot_sdot_vector, encode_udot_sdot_element]: + mcode, match = encode_func(line) + if mcode: + return mcode, match + return 0, line + + +def read_existing_encoding(line): + m = re.search(r'\.word\ (0x[0-9a-f]+)', line) + if m: + return int(m.group(1), 16) + return 0 + + +parser = argparse.ArgumentParser(description='Encode some A64 instructions.') +parser.add_argument( + '-f', + '--fix', + help='fix existing wrong encodings in-place and continue', + action='store_true') +args = parser.parse_args() + +lineno = 0 +found_existing_encodings = False +found_error = False +found_fixes = False +for line in sys.stdin: + lineno = lineno + 1 + mcode, match = encode(line) + if mcode: + existing_encoding = read_existing_encoding(line) + if existing_encoding: + found_existing_encodings = True + if mcode != existing_encoding: + if args.fix: + line = line.replace('.word 0x%x // %s' % (existing_encoding, match), + '.word 0x%x // %s' % (mcode, match)) + found_fixes = True + else: + sys.stderr.write( + "Error at line %d: existing encoding 0x%x differs from encoding 0x%x for instruction '%s':\n\n%s\n\n" + % (lineno, existing_encoding, mcode, match, line)) + found_error = True + else: + line = line.replace(match, '.word 0x%x // %s' % (mcode, match)) + sys.stdout.write(line) +if found_error: + sys.exit(1) +if found_existing_encodings: + if found_fixes: + sys.stderr.write( + 'Note: some instructions that this program is able to encode, were already encoded and their existing encodings didn\'t match the specified asm instructions. Since --fix was passed, these were fixed in-place.\n' + ) + else: + sys.stderr.write( + 'Note: some instructions that this program is able to encode, were already encoded. These encodings have been checked.\n' + ) diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc index bff33fb..9146179 100644 --- a/standalone/neon-gemm-kernel-benchmark.cc +++ b/standalone/neon-gemm-kernel-benchmark.cc @@ -1,4 +1,4 @@ -// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. +// Copyright 2016 The gemmlowp Authors. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -240,6 +240,51 @@ struct KernelFormat { static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; }; +// KernelOperandRanges specifies the minimum and maximum values an operand can +// take. It consists of two ranges: one for the LHS and one for the RHS. The +// default values are the minimum and maximum values of the operand data type. +template <typename Kernel, typename OperandType = typename Kernel::OperandType> +struct KernelOperandRanges { + static OperandType LhsMin() { + return std::numeric_limits<OperandType>::lowest(); + } + static OperandType LhsMax() { + return std::numeric_limits<OperandType>::max(); + } + static OperandType RhsMin() { + return std::numeric_limits<OperandType>::lowest(); + } + static OperandType RhsMax() { + return std::numeric_limits<OperandType>::max(); + } +}; + +template <typename Kernel> +struct KernelOperandRanges<Kernel, float> { + static float LhsMin() { return -100.f; } + static float LhsMax() { return 100.f; } + static float RhsMin() { return -100.f; } + static float RhsMax() { return 100.f; } +}; + +#define SET_7BIT_RANGES(kernel) \ +template <> \ +struct KernelOperandRanges<kernel, std::int8_t> { \ + static std::int8_t LhsMin() { return -63; } \ + static std::int8_t LhsMax() { return 63; } \ + static std::int8_t RhsMin() { return -64; } \ + static std::int8_t RhsMax() { return 63; } \ +}; + +#define SET_425BIT_RANGES(kernel) \ +template <> \ +struct KernelOperandRanges<kernel, std::int8_t> { \ + static std::int8_t LhsMin() { return -7; } \ + static std::int8_t LhsMax() { return 7; } \ + static std::int8_t RhsMin() { return -9; } \ + static std::int8_t RhsMax() { return 9; } \ +}; + inline const char* CellOrderName(CellOrder o) { switch (o) { case CellOrder::DepthMajor: @@ -596,7 +641,6 @@ struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits { AccumulatorType* accum_ptr, int depth) { std::size_t start_depth = 123; std::size_t run_depth = depth; - std::size_t dst_col_stride = 4; AccumulatorType* dst_ptr = accum_ptr; asm volatile( @@ -2516,6 +2560,817 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits { } }; +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + std::size_t start_depth = 123; + std::size_t run_depth = depth; + std::size_t dst_col_stride = 4; + AccumulatorType* dst_ptr = accum_ptr; + asm volatile( + // Overview of register layout: + // + // A 4x16 block of Rhs is stored in 8 bit in v0--v3. + // A 4x16 block of Lhs is stored in 8 bit in v4--v7. + // + // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit + // components which need to be horizontally-added at the end) + // + // Register layout: + // + // +--------+--------+--------+--------+ + // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] | + // Rhs +--------+--------+--------+--------+ + // | ... | ... | ... | ... 
| + // +--------+--------+--------+--------| + // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]| + // +--------+--------+--------+--------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +-------+-----+--------+ - - +--------+--------+--------+--------+ + // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s | + // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s | + // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s | + // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s | + // +-------+--------------+ - - +--------+--------+--------+--------+ + // + // Accumulator + // + + // Clear accumulators + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "dup v16.4s, wzr\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "dup v19.4s, wzr\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "dup v20.4s, wzr\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "dup v21.4s, wzr\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "dup v22.4s, wzr\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + "dup v23.4s, wzr\n" + "subs %w[run_depth], %w[run_depth], #16\n" + "dup v24.4s, wzr\n" + "mov x0, %[dst_ptr]\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "beq 1f\n" + + "cmp %w[run_depth], #32\n" + "blt 2f\n" + + "3:\n" + "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + "prfm pldl1keep, [%[rhs_ptr], #128]\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + "ld1 {v14.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + "ld1 {v15.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + "prfm pldl1keep, [%[lhs_ptr], #128]\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e889590 // udot v16.4s, v12.16b, v8.16b\n" + ".word 0x6e899591 // udot v17.4s, v12.16b, v9.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8a9592 // udot v18.4s, v12.16b, v10.16b\n" + ".word 0x6e8b9593 // udot v19.4s, v12.16b, v11.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8895b4 // udot v20.4s, v13.16b, v8.16b\n" + ".word 0x6e8995b5 // udot v21.4s, v13.16b, v9.16b\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "sub %[run_depth], %[run_depth], #32\n" + ".word 0x6e8a95b6 // udot v22.4s, v13.16b, v10.16b\n" + ".word 0x6e8b95b7 // udot v23.4s, v13.16b, v11.16b\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8895d8 // udot v24.4s, v14.16b, v8.16b\n" + ".word 0x6e8995d9 // udot v25.4s, v14.16b, v9.16b\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8a95da // udot 
v26.4s, v14.16b, v10.16b\n" + ".word 0x6e8b95db // udot v27.4s, v14.16b, v11.16b\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8895fc // udot v28.4s, v15.16b, v8.16b\n" + "prfm pldl1keep, [%[rhs_ptr], #128]\n" + ".word 0x6e8995fd // udot v29.4s, v15.16b, v9.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "cmp %w[run_depth], #32\n" + ".word 0x6e8a95fe // udot v30.4s, v15.16b, v10.16b\n" + "prfm pldl1keep, [%[lhs_ptr], #128]\n" + ".word 0x6e8b95ff // udot v31.4s, v15.16b, v11.16b\n" + + "bge 3b\n" + + "cmp %w[run_depth], #0\n" + "beq 1f\n" + + "2:\n" + + "subs %w[run_depth], %w[run_depth], #16\n" + + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + "bne 2b\n" + + "1:\n" + + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + + // Load accumulators from memory + "ld1 {v8.16b}, [x0], #16\n" + "ld1 {v9.16b}, [x0], #16\n" + "ld1 {v10.16b}, [x0], #16\n" + "ld1 {v11.16b}, [x0], #16\n" + "mov x0, %[dst_ptr]\n" + + // Reduce aggregators horizontally + "addp v0.4s, v16.4s, v20.4s\n" + "addp v1.4s, v17.4s, v21.4s\n" + "addp v2.4s, v18.4s, v22.4s\n" + "addp v3.4s, v19.4s, v23.4s\n" + "addp v4.4s, v24.4s, v28.4s\n" + "addp v5.4s, v25.4s, v29.4s\n" + "addp v6.4s, v26.4s, v30.4s\n" + "addp v7.4s, v27.4s, v31.4s\n" + + "addp v12.4s, v0.4s, v4.4s\n" + "addp v13.4s, v1.4s, v5.4s\n" + "addp v14.4s, v2.4s, v6.4s\n" + "addp v15.4s, v3.4s, v7.4s\n" + + // Add to the accumulators loaded from memory + "add v8.4s, v8.4s, v12.4s\n" + "add v9.4s, v9.4s, v13.4s\n" + "add v10.4s, v10.4s, v14.4s\n" + "add v11.4s, v11.4s, v15.4s\n" + + // Store accumulators back to memory + "st1 {v8.16b}, [x0], #16\n" + "st1 {v9.16b}, 
[x0], #16\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
+ [dst_col_stride] "+r"(dst_col_stride)
+ : // inputs
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+ "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+ "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+ "v28", "v29", "v30", "v31");
+ }
+};
+
+// Fast kernel operating on int8 operands with 7-bit range.
+// It is assumed that one of the two operands only takes values in [-63, 63],
+// while the other takes values in [-64, 63].
+// With this restriction, it is possible to multiply-accumulate operands into
+// a 16-bit integer eight times without overflow.
+struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits {
+ typedef std::int8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+#define GEMMLOWP_LABEL_64_DEPTH_LOOP "1"
+#define GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "2"
+#define GEMMLOWP_LABEL_16_DEPTH_LOOP "3"
+#define GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "4"
+
+ AccumulatorType* dst_ptr = accum_ptr;
+ asm volatile(
+ // Overview of register layout:
+ //
+ // A 4x16 block of Lhs is stored in 8 bit in v0--v7.
+ // A 2x16 block of Rhs is stored in 8 bit in v8--v15.
+ //
+ // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
+ // components which need to be horizontally-added at the end).
+ //
+ // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
+ // components which are added to global accumulators every 64 depth
+ // iterations).
+ //
+ // The Lhs vectors are multiplied by the Rhs vectors with a widening
+ // multiply over the first 8 levels of depth, producing int16x8
+ // vectors of products for each position in the accumulator matrix.
+ //
+ // Like the trick used in the fast 8-bit kernel, the operands are
+ // restricted to 7-bit range [-2^6, 2^6) so their products are in range
+ // [-2^12, 2^12). This enables adding eight such products without any
+ // risk of overflowing int16, equating to 64 levels of depth before
+ // horizontally adding these int16x8 accumulators into the final int32x4
+ // accumulators.
+ //
+ // Register layout including both local and global accumulators.
+ // Since we do not have enough registers to store all Lhs values, we
+ // reuse the same registers v0--v7 to load the rest of the Lhs values.
+ //
+ // +-----+-----+
+ // | v8 | v9 |
+ // Rhs +-----+-----+
+ // | v10 | v11 |
+ // +-----+-----+
+ // | v12 | v13 |
+ // +-----+-----+
+ // | v14 | v15 |
+ // Lhs +-----+-----+
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ // | v0 | v4 | v0 | v4 | | v16 | v20 | | v24.4s | v28.4s |
+ // | v1 | v5 | v1 | v5 | | v17 | v21 | -> | v25.4s | v29.4s |
+ // | v2 | v6 | v2 | v6 | | v18 | v22 | | v26.4s | v30.4s |
+ // | v3 | v7 | v3 | v7 | | v19 | v23 | | v27.4s | v31.4s |
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ //
+ // Local Accumulator Global Accumulator
+ //
+
+ // Clear accumulators.
+ "dup v16.4s, wzr\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "dup v24.4s, wzr\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "dup v25.4s, wzr\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "dup v26.4s, wzr\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "dup v19.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v20.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v21.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v22.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v23.4s, wzr\n" + "dup v31.4s, wzr\n" + + "cmp %w[depth], #64\n" + "blt " GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "f\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_64_DEPTH_LOOP + ":\n" + "subs %w[depth], %w[depth], #64\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "sadalp v24.4s, v16.8h\n" + "smull v16.8h, v0.8b, v8.8b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "sadalp v25.4s, v17.8h\n" + "smull v17.8h, v1.8b, v8.8b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v26.4s, v18.8h\n" + "smull v18.8h, v2.8b, v8.8b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v27.4s, v19.8h\n" + "smull v19.8h, v3.8b, v8.8b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "sadalp v28.4s, v20.8h\n" + "smull v20.8h, v0.8b, v9.8b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v29.4s, v21.8h\n" + "smull v21.8h, v1.8b, v9.8b\n" + "ld1 {v12.16b}, [%[rhs_ptr]], #16\n" + "sadalp v30.4s, v22.8h\n" + "smull v22.8h, v2.8b, v9.8b\n" + "ld1 {v13.16b}, [%[rhs_ptr]], #16\n" + "sadalp v31.4s, v23.8h\n" + "smull v23.8h, v3.8b, v9.8b\n" + + "cmp %w[depth], #64\n" + "smlal2 v16.8h, v0.16b, v8.16b\n" + "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v17.8h, v1.16b, v8.16b\n" + "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v18.8h, v2.16b, v8.16b\n" + "smlal2 v19.8h, v3.16b, v8.16b\n" + + "smlal2 v20.8h, v0.16b, v9.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v9.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v3.16b, v9.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v4.8b, v10.8b\n" + "smlal v17.8h, v5.8b, v10.8b\n" + "smlal v18.8h, v6.8b, v10.8b\n" + "smlal v19.8h, v7.8b, v10.8b\n" + "smlal v20.8h, v4.8b, v11.8b\n" + + "smlal v21.8h, v5.8b, v11.8b\n" + "smlal v22.8h, v6.8b, v11.8b\n" + "smlal v23.8h, v7.8b, v11.8b\n" + + "smlal2 v16.8h, v4.16b, v10.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v17.8h, v5.16b, v10.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v18.8h, v6.16b, v10.16b\n" + "smlal2 v19.8h, v7.16b, v10.16b\n" + + "smlal2 v20.8h, v4.16b, v11.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v6.16b, v11.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v7.16b, v11.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v0.8b, v12.8b\n" + "smlal v17.8h, v1.8b, v12.8b\n" + "smlal v18.8h, v2.8b, v12.8b\n" + "smlal v19.8h, v3.8b, v12.8b\n" + "smlal v20.8h, v0.8b, v13.8b\n" + "smlal v21.8h, v1.8b, v13.8b\n" + "smlal v22.8h, v2.8b, v13.8b\n" + "smlal v23.8h, v3.8b, v13.8b\n" + + "smlal2 v16.8h, v0.16b, v12.16b\n" + "smlal2 v17.8h, v1.16b, v12.16b\n" + "smlal2 v18.8h, v2.16b, v12.16b\n" + "smlal2 v19.8h, v3.16b, v12.16b\n" + + "smlal2 v20.8h, v0.16b, v13.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v13.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v13.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + 
"smlal2 v23.8h, v3.16b, v13.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v4.8b, v14.8b\n" + "smlal v17.8h, v5.8b, v14.8b\n" + "smlal v18.8h, v6.8b, v14.8b\n" + "smlal v19.8h, v7.8b, v14.8b\n" + + "smlal v20.8h, v4.8b, v15.8b\n" + "smlal v21.8h, v5.8b, v15.8b\n" + "smlal v22.8h, v6.8b, v15.8b\n" + "smlal v23.8h, v7.8b, v15.8b\n" + + "smlal2 v16.8h, v4.16b, v14.16b\n" + "smlal2 v17.8h, v5.16b, v14.16b\n" + "smlal2 v18.8h, v6.16b, v14.16b\n" + "smlal2 v19.8h, v7.16b, v14.16b\n" + + "smlal2 v20.8h, v4.16b, v15.16b\n" + "smlal2 v21.8h, v5.16b, v15.16b\n" + "smlal2 v22.8h, v6.16b, v15.16b\n" + "smlal2 v23.8h, v7.16b, v15.16b\n" + + "bge " GEMMLOWP_LABEL_64_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP + ":\n" + + "cmp %w[depth], #16\n" + "blt " GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "f\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_16_DEPTH_LOOP + ":\n" + "sadalp v24.4s, v16.8h\n" + "smull v16.8h, v0.8b, v8.8b\n" + "subs %w[depth], %w[depth], #16\n" + "sadalp v25.4s, v17.8h\n" + "smull v17.8h, v1.8b, v8.8b\n" + "sadalp v26.4s, v18.8h\n" + "smull v18.8h, v2.8b, v8.8b\n" + "sadalp v27.4s, v19.8h\n" + "smull v19.8h, v3.8b, v8.8b\n" + "sadalp v28.4s, v20.8h\n" + "smull v20.8h, v0.8b, v9.8b\n" + "sadalp v29.4s, v21.8h\n" + "smull v21.8h, v1.8b, v9.8b\n" + "sadalp v30.4s, v22.8h\n" + "smull v22.8h, v2.8b, v9.8b\n" + "sadalp v31.4s, v23.8h\n" + "smull v23.8h, v3.8b, v9.8b\n" + + "cmp %w[depth], #16\n" + "smlal2 v16.8h, v0.16b, v8.16b\n" + "smlal2 v17.8h, v1.16b, v8.16b\n" + "smlal2 v18.8h, v2.16b, v8.16b\n" + "smlal2 v19.8h, v3.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + + "smlal2 v20.8h, v0.16b, v9.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v9.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v3.16b, v9.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + + "bge " GEMMLOWP_LABEL_16_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP + ":\n" + + "sadalp v24.4s, v16.8h\n" + "sadalp v25.4s, v17.8h\n" + "sadalp v26.4s, v18.8h\n" + "sadalp v27.4s, v19.8h\n" + "sadalp v28.4s, v20.8h\n" + "sadalp v29.4s, v21.8h\n" + "sadalp v30.4s, v22.8h\n" + "sadalp v31.4s, v23.8h\n" + + // Reduce aggregators horizontally. + "addp v0.4s, v24.4s, v25.4s\n" + "addp v1.4s, v26.4s, v27.4s\n" + "addp v2.4s, v28.4s, v29.4s\n" + "addp v3.4s, v30.4s, v31.4s\n" + + "addp v4.4s, v0.4s, v1.4s\n" + "addp v5.4s, v2.4s, v3.4s\n" + + // Load accumulators from memory. + "mov x0, %[dst_ptr]\n" + "ld1 {v6.16b}, [x0], #16\n" + "ld1 {v7.16b}, [x0], #16\n" + + // Add to the accumulators loaded from memory. + "add v6.4s, v6.4s, v4.4s\n" + "add v7.4s, v7.4s, v5.4s\n" + + // Store accumulators back to memory. + "mov x0, %[dst_ptr]\n" + "st1 {v6.16b}, [x0], #16\n" + "st1 {v7.16b}, [x0], #16\n" + + : + // Outputs. + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth) + : + // Inputs. + + : + // Clobbers. + "cc", "memory", + // We use these NEON registers + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x0"); + } +}; + +SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits); + +// Kernel operating on int8 operands with 4.25-bit range. 
+// It is assumed that one of the two operands only takes values in [-7, 7],
+// while the other takes values in [-9, 9].
+// With this restriction, it is possible to multiply-accumulate operands into
+// a 16-bit integer thirty-two times without overflow.
+struct NEON_64bit_GEMM_Int425Operands {
+ typedef std::int8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>,
+ KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+#define GEMMLOWP_LABEL_512_DEPTH_LOOP "1"
+#define GEMMLOWP_LABEL_32_DEPTH_LOOP "2"
+#define GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP "3"
+
+ AccumulatorType* dst_ptr = accum_ptr;
+ int outer_depth = depth / 512 + 1;
+
+ asm volatile(
+ // Overview of register layout:
+ //
+ // A 4x32 block of Lhs is stored in 8 bit in v0--v7.
+ // A 2x32 block of Rhs is stored in 8 bit in v8--v11.
+ //
+ // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit
+ // components which need to be horizontally-added at the end).
+ //
+ // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit
+ // components which are horizontally-added to global accumulators every
+ // 512 depth iterations).
+ //
+ // The Lhs vectors are multiplied by the Rhs vectors with a multiply
+ // over the first 16 levels of depth, producing int8x16 vectors of
+ // products for each position in the accumulator matrix.
+ //
+ // Like the trick used in the fast 8-bit and 7-bit kernels, the operands
+ // are restricted to 4.25-bit range, [-7, 7] for one operand and [-9, 9]
+ // for the other operand. This enables adding two such products without
+ // any risk of overflowing int8, and thirty-two such products without
+ // overflowing int16. This equates to 512 levels of depth before
+ // horizontally adding these int16x8 accumulators into the final int32x4
+ // accumulators.
+ //
+ // Register layout (ignoring the v12--v15 temporary 8-bit accumulators).
+ // Since we do not have enough registers to store all Lhs values and Rhs
+ // values, we reuse the same registers v0--v7 to load subsequent Lhs
+ // values and v8--v11 to load subsequent Rhs values.
+ //
+ // +-----+-----+
+ // | v8 | v9 |
+ // Rhs +-----+-----+
+ // | v10 | v11 |
+ // +-----+-----+
+ // | v8 | v9 |
+ // +-----+-----+
+ // | v10 | v11 |
+ // Lhs +-----+-----+
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ // | v0 | v4 | v0 | v4 | | v16 | v17 | | v24.4s | v25.4s |
+ // | v1 | v5 | v1 | v5 | | v18 | v19 | -> | v26.4s | v27.4s |
+ // | v2 | v6 | v2 | v6 | | v20 | v21 | | v28.4s | v29.4s |
+ // | v3 | v7 | v3 | v7 | | v22 | v23 | | v30.4s | v31.4s |
+ // +----+----+----+----+ - - +-----+-----+ +--------+--------+
+ //
+ // Local Accumulator Global Accumulator
+ //
+
+ // Clear global accumulators.
+ "dup v24.4s, wzr\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "dup v25.4s, wzr\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "dup v26.4s, wzr\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "dup v27.4s, wzr\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "dup v28.4s, wzr\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "dup v29.4s, wzr\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "dup v30.4s, wzr\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "dup v31.4s, wzr\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_512_DEPTH_LOOP + ":\n" + // Clear local accumulators. + "dup v16.8h, wzr\n" + "dup v17.8h, wzr\n" + "dup v18.8h, wzr\n" + "mov x1, #512\n" + "dup v19.8h, wzr\n" + "dup v20.8h, wzr\n" + "dup v21.8h, wzr\n" + "dup v22.8h, wzr\n" + "dup v23.8h, wzr\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_32_DEPTH_LOOP + ":\n" + "mul v12.16b, v0.16b, v8.16b\n" + "mul v13.16b, v0.16b, v10.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v2.16b, v8.16b\n" + "mul v15.16b, v2.16b, v10.16b\n" + + "mla v12.16b, v1.16b, v9.16b\n" + "mla v13.16b, v1.16b, v11.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v3.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "mla v15.16b, v3.16b, v11.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v16.8h, v12.16b\n" + "sadalp v17.8h, v13.16b\n" + "subs %w[depth], %w[depth], #32\n" + "sadalp v18.8h, v14.16b\n" + "sadalp v19.8h, v15.16b\n" + "subs x1, x1, #32\n" + + "mul v12.16b, v4.16b, v8.16b\n" + "mul v13.16b, v4.16b, v10.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v6.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "mul v15.16b, v6.16b, v10.16b\n" + + "mla v12.16b, v5.16b, v9.16b\n" + "mla v13.16b, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v7.16b, v9.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "mla v15.16b, v7.16b, v11.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + + "sadalp v20.8h, v12.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v21.8h, v13.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v22.8h, v14.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v23.8h, v15.16b\n" + + "mul v12.16b, v0.16b, v8.16b\n" + "mul v13.16b, v0.16b, v10.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v2.16b, v8.16b\n" + "mul v15.16b, v2.16b, v10.16b\n" + + "mla v12.16b, v1.16b, v9.16b\n" + "mla v13.16b, v1.16b, v11.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v3.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "mla v15.16b, v3.16b, v11.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v16.8h, v12.16b\n" + "sadalp v17.8h, v13.16b\n" + "sadalp v18.8h, v14.16b\n" + "sadalp v19.8h, v15.16b\n" + + "mul v12.16b, v4.16b, v8.16b\n" + "mul v13.16b, v4.16b, v10.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v6.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "mul v15.16b, v6.16b, v10.16b\n" + + "mla v12.16b, v5.16b, v9.16b\n" + "mla v13.16b, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v7.16b, v9.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "mla v15.16b, v7.16b, v11.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + + "sadalp v20.8h, v12.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v21.8h, v13.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v22.8h, v14.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v23.8h, v15.16b\n" + + "beq 
" GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP + "f\n" + + "cmp %w[depth], #0\n" + "bne " GEMMLOWP_LABEL_32_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP + ":\n" + + // Pairwise add 16-bit local accums to 32-bit global accums. + "sadalp v24.4s, v16.8h\n" + "sadalp v25.4s, v17.8h\n" + "sadalp v26.4s, v18.8h\n" + "sadalp v27.4s, v19.8h\n" + "sadalp v28.4s, v20.8h\n" + "sadalp v29.4s, v21.8h\n" + "sadalp v30.4s, v22.8h\n" + "sadalp v31.4s, v23.8h\n" + + "bne " GEMMLOWP_LABEL_512_DEPTH_LOOP + "b\n" + + // Reduce aggregators horizontally. + "addp v0.4s, v24.4s, v26.4s\n" + "addp v1.4s, v28.4s, v30.4s\n" + "addp v2.4s, v25.4s, v27.4s\n" + "addp v3.4s, v29.4s, v31.4s\n" + + "addp v4.4s, v0.4s, v1.4s\n" + "addp v5.4s, v2.4s, v3.4s\n" + + // Load accumulators from memory. + "mov x0, %[dst_ptr]\n" + "ld1 {v6.16b}, [x0], #16\n" + "ld1 {v7.16b}, [x0], #16\n" + + // Add to the accumulators loaded from memory. + "add v6.4s, v6.4s, v4.4s\n" + "add v7.4s, v7.4s, v5.4s\n" + + // Store accumulators back to memory. + "mov x0, %[dst_ptr]\n" + "st1 {v6.16b}, [x0], #16\n" + "st1 {v7.16b}, [x0], #16\n" + + : + // Outputs. + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth), + [outer_depth] "+r"(outer_depth) + : + // Inputs. + + : + // Clobbers. + "cc", "memory", + // We use these NEON registers + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x0", "x1"); + } +}; + +SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands); + #ifdef __ARM_FEATURE_DOTPROD // Kernels utilizing the Armv8.2 Dot Product extension. // @@ -2582,41 +3437,41 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { // Start the MACs at the head of the loop - 1st cell from each side // already loaded. - "udot v8.4s, v2.16b, v0.b[0]\n" - "udot v9.4s, v2.16b, v0.b[1]\n" + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. - "udot v10.4s, v2.16b, v0.b[2]\n" - "udot v11.4s, v2.16b, v0.b[3]\n" + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. - "udot v12.4s, v2.16b, v1.b[0]\n" - "udot v13.4s, v2.16b, v1.b[1]\n" + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. - "udot v14.4s, v2.16b, v1.b[2]\n" - "udot v15.4s, v2.16b, v1.b[3]\n" + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load // for the next iteration early. 
- "udot v16.4s, v3.16b, v0.b[0]\n" - "udot v17.4s, v3.16b, v0.b[1]\n" - "udot v18.4s, v3.16b, v0.b[2]\n" - "udot v19.4s, v3.16b, v0.b[3]\n" - "udot v20.4s, v3.16b, v1.b[0]\n" - "udot v21.4s, v3.16b, v1.b[1]\n" - "udot v22.4s, v3.16b, v1.b[2]\n" - "udot v23.4s, v3.16b, v1.b[3]\n" - "udot v24.4s, v4.16b, v0.b[0]\n" - "udot v25.4s, v4.16b, v0.b[1]\n" - "udot v26.4s, v4.16b, v0.b[2]\n" - "udot v27.4s, v4.16b, v0.b[3]\n" + ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - // load for the next iteration early. - "udot v28.4s, v4.16b, v1.b[0]\n" - "udot v29.4s, v4.16b, v1.b[1]\n" + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" // Loop. Decrement loop index (depth) by 4 as udot processes 4 // depth values. "subs %w[depth], %w[depth], #4\n" - "udot v30.4s, v4.16b, v1.b[2]\n" - "udot v31.4s, v4.16b, v1.b[3]\n" + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" "bne " GEMMLOWP_LABEL_LOOP "b\n" @@ -2712,54 +3567,67 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { GEMMLOWP_LABEL_LOOP ":\n" - "udot v8.4s, v2.16b, v0.b[0]\n" - "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 - "udot v9.4s, v2.16b, v0.b[1]\n" - "ins v0.d[1], x18\n" // Finish loading v0 - "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure. - "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register - "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure. - "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. - "udot v10.4s, v2.16b, v0.b[2]\n" - "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 - "udot v11.4s, v2.16b, v0.b[3]\n" - "ins v1.d[1], x18\n" // Finish loading v1 - "udot v12.4s, v2.16b, v1.b[0]\n" - "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register - "udot v13.4s, v2.16b, v1.b[1]\n" - "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. - "udot v14.4s, v2.16b, v1.b[2]\n" - - "udot v15.4s, v2.16b, v1.b[3]\n" - "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) - "udot v18.4s, v3.16b, v0.b[2]\n" - "ins v4.d[1], x18\n" // Finish loading v4 - "udot v19.4s, v3.16b, v0.b[3]\n" - "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register - "udot v20.4s, v3.16b, v1.b[0]\n" + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" + "ins v0.d[1], x18\n" // Finish loading v0 + ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" // out of + // sequence - + // used to + // reduce + // load/use + // pressure. 
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" // out of + // sequence - + // used to + // reduce + // load/use + // pressure. + "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment + // pointer. + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" + "ins v1.d[1], x18\n" // Finish loading v1 + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" + "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment + // pointer. + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" + "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + "ins v4.d[1], x18\n" // Finish loading v4 + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" "subs %w[depth], %w[depth], #4\n" - "udot v21.4s, v3.16b, v1.b[1]\n" - - "udot v22.4s, v3.16b, v1.b[2]\n" - - "udot v23.4s, v3.16b, v1.b[3]\n" - "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) - "udot v24.4s, v4.16b, v0.b[0]\n" - "ins v2.d[1], x18\n" // Finish loading next v2 - "udot v25.4s, v4.16b, v0.b[1]\n" - "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register - "udot v26.4s, v4.16b, v0.b[2]\n" - - "udot v27.4s, v4.16b, v0.b[3]\n" - "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) - "udot v28.4s, v4.16b, v1.b[0]\n" - "ins v3.d[1], x18\n" // Finish loading next v3 - "udot v29.4s, v4.16b, v1.b[1]\n" - "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register - "udot v30.4s, v4.16b, v1.b[2]\n" - - "udot v31.4s, v4.16b, v1.b[3]\n" - "bne " GEMMLOWP_LABEL_LOOP "b\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + "ins v2.d[1], x18\n" // Finish loading next v2 + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" + "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + "ins v3.d[1], x18\n" // Finish loading next v3 + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" + "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" + "bne " GEMMLOWP_LABEL_LOOP + "b\n" // Store accumulators "mov x0, %[accum_ptr]\n" @@ -3852,533 +4720,195 @@ using NEON_32bit_GEMM_Float32_WithScalar_intrinsics = using NEON_64bit_GEMM_Float32_WithScalar_intrinsics = NEON_GEMM_Float32_WithScalar_intrinsics<2>; -#endif // __arm__ || __aarch64__ -#ifdef __mips -static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) { - // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c). 
-#if 0 - return __builtin_msa_maddv_w(a, b, c); -#else - asm volatile("maddv.w %w[a], %w[b], %w[c]\n" - // Outputs - : [a] "+f"(a) - // Inputs - : [b] "f"(b), [c] "f"(c)); - return a; -#endif -} - -// Using 32x32=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~55 instructions in the loop. -struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics { - typedef std::uint8_t OperandType; +// C++ intrinsics-based variant of the deep, 7-bit, fast kernel +struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics { + typedef std::int8_t OperandType; typedef std::int32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> > Format; static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { - const v16i8 zeroes = __builtin_msa_ldi_b(0); - v4i32 acc[3][4]; - // Load accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); + int32x4_t acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vdupq_n_s32(0); } } - while (depth > 0) { - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - v8i16 lhs[6]; - lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); - lhs[1] = - reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast<v16i8>(lhs[0]))); - lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, - reinterpret_cast<v16i8>(lhs[1]))); - lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast<v16i8>(lhs[1]))); - - // Zero-extend 16-bit elements of lhs[] to 32 bits. - lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); - lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); - lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); - lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); - lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); - lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); - - // Depth 0. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + int d = 0; + for (; d <= depth - 64; d += 64) { + int16x8_t local_acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + local_acc[i][j] = vdupq_n_s16(0); } } - // Depth 1. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + // There are not enough registers to fit all lhs and rhs values for 64 + // depth. 
Instead, load values for 32 depth at a time. + for (int k = 0; k < 2; k++) { + int8x16_t lhs[4][2]; + for (int i = 0; i < 4; i++) { + lhs[i][0] = vld1q_s8(lhs_ptr + 16 * i + 128 * k); + lhs[i][1] = vld1q_s8(lhs_ptr + 64 + 16 * i + 128 * k); + } + + int8x16_t rhs[4]; + for (int i = 0; i < 4; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i + 64 * k); + } + + for (int i = 0; i < 4; i++) { + if (k == 0) { + local_acc[i][0] = vmull_s8(vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[2])); + local_acc[i][1] = vmull_s8(vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], + vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[3])); + } else { + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[2])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[3])); + } + + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][0]), + vget_high_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][1]), + vget_high_s8(rhs[2])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][0]), + vget_high_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][1]), + vget_high_s8(rhs[3])); } } - lhs_ptr += 24; - rhs_ptr += 8; - depth -= 2; + for (int i = 0; i < 4; i++) { + acc[i][0] = vpadalq_s16(acc[i][0], local_acc[i][0]); + acc[i][1] = vpadalq_s16(acc[i][1], local_acc[i][1]); + } + + lhs_ptr += 64 * 4; + rhs_ptr += 64 * 2; } + for (; d <= depth - 16; d += 16) { + int8x16_t lhs[4]; + for (int i = 0; i < 4; i++) { + lhs[i] = vld1q_s8(lhs_ptr + 16 * i); + } + int8x16_t rhs[2]; + for (int i = 0; i < 2; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i); + } - // Store accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + int16x8_t local_acc = + vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j])); + local_acc = + vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j])); + acc[i][j] = vpadalq_s16(acc[i][j], local_acc); + } } + lhs_ptr += 16 * 4; + rhs_ptr += 16 * 2; + } + for (int i = 0; i < 2; i++) { + int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]); + int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]); + int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1); + int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i); + dst_val = vaddq_s32(dst_val, acc_4x); + vst1q_s32(accum_ptr + 4 * i, dst_val); } } }; -// Assembly implementation of the above -// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics. -// Using 32x32=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~55 instructions in the loop. 
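It is worth spelling out why the narrow-operand NEON kernels above can defer widening for so long. In the 7-bit kernel, each int16 lane of local_acc absorbs at most eight products before vpadalq_s16 widens it into the int32 accumulators; assuming SET_7BIT_RANGES bounds operands to [-63, 63], that is at most 8 * 63 * 63 = 31752, safely below the int16 limit of 32767 (hence "AccumEightWithin16Bits"). The 4.25-bit kernels (the assembly one earlier and the intrinsics variant further down) push the same idea one level lower: assuming SET_425BIT_RANGES bounds one operand by [-7, 7] and the other by [-9, 9], the two products that vmlaq_s8 folds into an int8 temporary total at most 2 * 7 * 9 = 126 <= 127, and the int16 stage then takes at most 16 pairwise folds of such temporaries (hence the 512-depth outer loop): 16 * 2 * 126 = 4032, again with ample headroom. A compile-time restatement of these bounds (the exact ranges live in the SET_*BIT_RANGES macros outside this excerpt, so the numbers here are assumptions):

    static_assert(8 * 63 * 63 <= 32767, "eight 7-bit products fit in int16");
    static_assert(2 * 7 * 9 <= 127, "two 4.25-bit products fit in int8");
    static_assert(16 * 2 * 126 <= 32767, "one 512-depth pass fits in int16");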
-struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly { - typedef std::uint8_t OperandType; - typedef std::int32_t AccumulatorType; - typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > - Format; - static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "ld.w $w0, (0*16)(%[accum_ptr])\n" - "ld.w $w4, (1*16)(%[accum_ptr])\n" - "ld.w $w8, (2*16)(%[accum_ptr])\n" - "ld.w $w1, (3*16)(%[accum_ptr])\n" - "ld.w $w5, (4*16)(%[accum_ptr])\n" - "ld.w $w9, (5*16)(%[accum_ptr])\n" - "ld.w $w2, (6*16)(%[accum_ptr])\n" - "ld.w $w6, (7*16)(%[accum_ptr])\n" - "ld.w $w10, (8*16)(%[accum_ptr])\n" - "ld.w $w3, (9*16)(%[accum_ptr])\n" - "ld.w $w7, (10*16)(%[accum_ptr])\n" - "ld.w $w11, (11*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w19, 0\n" - - GEMMLOWP_LABEL_LOOP ":\n" - // Overview of register layout: - // - // A half of the 2x4 cell of Rhs is stored in 32bit in w18. - // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17. - // A 12x4 block of accumulators is stored in 32bit in w0-w11. - // - // +------+------+------+------+ - // Rhs |w18[0]|w18[1]|w18[2]|w18[3]| - // +------+------+------+------+ - // - // | | | | | - // - // Lhs | | | | | - // - // +---+---+ - - - - +------+------+------+------+ - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // +---+---+ - - - - +------+------+------+------+ - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // +---+---+ - - - - +------+------+------+------+ - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // +---+---+ - - - - +------+------+------+------+ - // - // Accumulator +SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics); - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - "ld.b $w12, 0(%[lhs_ptr])\n" - "ld.b $w13, 8(%[lhs_ptr])\n" - - // Load 4 bytes of rhs[] for depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - "ilvr.b $w12, $w19, $w12\n" - "ilvl.b $w14, $w19, $w13\n" - "ilvr.b $w13, $w19, $w13\n" - // Zero-extend 16-bit elements of lhs[] to 32 bits. - "ilvl.h $w15, $w19, $w12\n" - "ilvl.h $w16, $w19, $w13\n" - "ilvl.h $w17, $w19, $w14\n" - "ilvr.h $w12, $w19, $w12\n" - "ilvr.h $w13, $w19, $w13\n" - "ilvr.h $w14, $w19, $w14\n" - - // Depth 0. - "fill.w $w18, $a0\n" - "lbu $a0, 4(%[rhs_ptr])\n" - "maddv.w $w0, $w12, $w18\n" - "maddv.w $w4, $w13, $w18\n" - "maddv.w $w8, $w14, $w18\n" - "fill.w $w18, $a1\n" - "lbu $a1, 5(%[rhs_ptr])\n" - "maddv.w $w1, $w12, $w18\n" - "maddv.w $w5, $w13, $w18\n" - "maddv.w $w9, $w14, $w18\n" - "fill.w $w18, $a2\n" - "lbu $a2, 6(%[rhs_ptr])\n" - "maddv.w $w2, $w12, $w18\n" - "maddv.w $w6, $w13, $w18\n" - "maddv.w $w10, $w14, $w18\n" - "fill.w $w18, $a3\n" - "lbu $a3, 7(%[rhs_ptr])\n" - "maddv.w $w3, $w12, $w18\n" - "maddv.w $w7, $w13, $w18\n" - "maddv.w $w11, $w14, $w18\n" - - // Depth 1. 
- "fill.w $w18, $a0\n" - "maddv.w $w0, $w15, $w18\n" - "maddv.w $w4, $w16, $w18\n" - "maddv.w $w8, $w17, $w18\n" - "fill.w $w18, $a1\n" - "maddv.w $w1, $w15, $w18\n" - "maddv.w $w5, $w16, $w18\n" - "maddv.w $w9, $w17, $w18\n" - "fill.w $w18, $a2\n" - "maddv.w $w2, $w15, $w18\n" - "maddv.w $w6, $w16, $w18\n" - "maddv.w $w10, $w17, $w18\n" - "fill.w $w18, $a3\n" - "maddv.w $w3, $w15, $w18\n" - "maddv.w $w7, $w16, $w18\n" - "maddv.w $w11, $w17, $w18\n" - - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" - - // Store accumulators. - "st.w $w0, (0*16)(%[accum_ptr])\n" - "st.w $w4, (1*16)(%[accum_ptr])\n" - "st.w $w8, (2*16)(%[accum_ptr])\n" - "st.w $w1, (3*16)(%[accum_ptr])\n" - "st.w $w5, (4*16)(%[accum_ptr])\n" - "st.w $w9, (5*16)(%[accum_ptr])\n" - "st.w $w2, (6*16)(%[accum_ptr])\n" - "st.w $w6, (7*16)(%[accum_ptr])\n" - "st.w $w10, (8*16)(%[accum_ptr])\n" - "st.w $w3, (9*16)(%[accum_ptr])\n" - "st.w $w7, (10*16)(%[accum_ptr])\n" - "st.w $w11, (11*16)(%[accum_ptr])\n" - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) - : // inputs - [accum_ptr] "r"(accum_ptr) - : // clobbers - "memory", - "a0", "a1", "a2", "a3", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19"); - } -}; - -// Assembly implementation of the above -// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). -// Using 16x16=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 3 lhs -// - 4 rhs -// - 1 temps/zeroes -// ~45 instructions in the loop. -struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 { - typedef std::uint8_t OperandType; +// C++ intrinsics-based variant of the deep, 4.25-bit, fast kernel. +struct NEON_64bit_GEMM_Int425Operands_intrinsics { + typedef std::int8_t OperandType; typedef std::int32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > - Format; - static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "ld.w $w0, (0*16)(%[accum_ptr])\n" - "ld.w $w4, (1*16)(%[accum_ptr])\n" - "ld.w $w8, (2*16)(%[accum_ptr])\n" - "ld.w $w1, (3*16)(%[accum_ptr])\n" - "ld.w $w5, (4*16)(%[accum_ptr])\n" - "ld.w $w9, (5*16)(%[accum_ptr])\n" - "ld.w $w2, (6*16)(%[accum_ptr])\n" - "ld.w $w6, (7*16)(%[accum_ptr])\n" - "ld.w $w10, (8*16)(%[accum_ptr])\n" - "ld.w $w3, (9*16)(%[accum_ptr])\n" - "ld.w $w7, (10*16)(%[accum_ptr])\n" - "ld.w $w11, (11*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w19, 0\n" - - GEMMLOWP_LABEL_LOOP ":\n" - // Overview of register layout: - // - // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register - // contains 4 replicas of a pair of elements). - // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14. - // A 12x4 block of accumulators is stored in 32bit in w0-w11. 
- // - // +-----+-----+-----+-----+ - // Rhs | w15 | w16 | w17 | w18 | - // +-----+-----+-----+-----+ - // - // | | | | | - // - // Lhs | | | | | - // - // +---+ - - - - +-----+-----+-----+-----+ - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // +---+ - - - - +-----+-----+-----+-----+ - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // +---+ - - - - +-----+-----+-----+-----+ - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // +---+ - - - - +-----+-----+-----+-----+ - // - // Accumulators - - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - "ld.b $w12, 0(%[lhs_ptr])\n" - "ld.b $w13, 8(%[lhs_ptr])\n" - - // Load 4 bytes of rhs[] for depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for depth 1. - "lbu $v0, 4(%[rhs_ptr])\n" - "lbu $v1, 5(%[rhs_ptr])\n" - "lbu $t8, 6(%[rhs_ptr])\n" - "lbu $t9, 7(%[rhs_ptr])\n" - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - "ilvr.b $w12, $w19, $w12\n" - "ilvl.b $w14, $w19, $w13\n" - "ilvr.b $w13, $w19, $w13\n" - // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. - "ilvl.d $w15, $w19, $w12\n" - "ilvl.d $w16, $w19, $w13\n" - "ilvl.d $w17, $w19, $w14\n" - "ilvr.h $w12, $w15, $w12\n" - "ilvr.h $w13, $w16, $w13\n" - "ilvr.h $w14, $w17, $w14\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w. - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" - // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w15, $a0\n" - "fill.w $w16, $a1\n" - "fill.w $w17, $a2\n" - "fill.w $w18, $a3\n" - - // Depths 0 and 1. - // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w0, $w12, $w15\n" - "dpadd_u.w $w4, $w13, $w15\n" - "dpadd_u.w $w8, $w14, $w15\n" - "dpadd_u.w $w1, $w12, $w16\n" - "dpadd_u.w $w5, $w13, $w16\n" - "dpadd_u.w $w9, $w14, $w16\n" - "dpadd_u.w $w2, $w12, $w17\n" - "dpadd_u.w $w6, $w13, $w17\n" - "dpadd_u.w $w10, $w14, $w17\n" - "dpadd_u.w $w3, $w12, $w18\n" - "dpadd_u.w $w7, $w13, $w18\n" - "dpadd_u.w $w11, $w14, $w18\n" - - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" - - // Store accumulators. - "st.w $w0, (0*16)(%[accum_ptr])\n" - "st.w $w4, (1*16)(%[accum_ptr])\n" - "st.w $w8, (2*16)(%[accum_ptr])\n" - "st.w $w1, (3*16)(%[accum_ptr])\n" - "st.w $w5, (4*16)(%[accum_ptr])\n" - "st.w $w9, (5*16)(%[accum_ptr])\n" - "st.w $w2, (6*16)(%[accum_ptr])\n" - "st.w $w6, (7*16)(%[accum_ptr])\n" - "st.w $w10, (8*16)(%[accum_ptr])\n" - "st.w $w3, (9*16)(%[accum_ptr])\n" - "st.w $w7, (10*16)(%[accum_ptr])\n" - "st.w $w11, (11*16)(%[accum_ptr])\n" - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) - : // inputs - [accum_ptr] "r"(accum_ptr) - : // clobbers - "memory", - "v0", "v1", - "a0", "a1", "a2", "a3", - "t8", "t9", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19"); - } -}; - -// Using 32x32=32 multiplications. 
-// 32 MSA regs used: -// - 24 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~95 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics { - typedef std::uint8_t OperandType; - typedef std::uint32_t AccumulatorType; - typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > + KernelSideFormat<CellFormat<4, 32, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<2, 32, CellOrder::WidthMajor>, 1> > Format; static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { - const v16i8 zeroes = __builtin_msa_ldi_b(0); - v4i32 acc[3][8]; - // Load accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 8; j++) { - acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); + int32x4_t acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vdupq_n_s32(0); } } - while (depth > 0) { - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - v8i16 lhs[6]; - lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); - lhs[1] = - reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast<v16i8>(lhs[0]))); - lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, - reinterpret_cast<v16i8>(lhs[1]))); - lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast<v16i8>(lhs[1]))); - - // Zero-extend 16-bit elements of lhs[] to 32 bits. - lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); - lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); - lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); - lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); - lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); - lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); - - // Depth 0. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + const int num_outer_depth_loop = depth / 512 + 1; + int d = 0; + for (int od = 0; od < num_outer_depth_loop; od++) { + int16x8_t local_acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + local_acc[i][j] = vdupq_n_s16(0); } } - for (int j = 4; j < 8; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + for (int k = 0; k < 16 && d <= depth - 32; k++, d += 32) { + int8x16_t lhs[8]; + for (int i = 0; i < 8; i++) { + lhs[i] = vld1q_s8(lhs_ptr + 16 * i); } - } - - // Depth 1. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. 
- for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + int8x16_t rhs[4]; + for (int i = 0; i < 4; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i); } - } - for (int j = 4; j < 8; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + int8x16_t temp_acc = vmulq_s8(lhs[i * 2], rhs[j * 2]); + temp_acc = vmlaq_s8(temp_acc, lhs[i * 2 + 1], rhs[j * 2 + 1]); + local_acc[i][j] = vpadalq_s8(local_acc[i][j], temp_acc); + } } + lhs_ptr += 128; + rhs_ptr += 64; } - lhs_ptr += 24; - rhs_ptr += 16; - depth -= 2; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vpadalq_s16(acc[i][j], local_acc[i][j]); + } + } } - // Store accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 8; j++) { - __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); - } + for (int i = 0; i < 2; i++) { + int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]); + int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]); + int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1); + + int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i); + dst_val = vaddq_s32(dst_val, acc_4x); + vst1q_s32(accum_ptr + 4 * i, dst_val); } } }; -// Assembly implementation of the above -// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics. -// Using 32x32=32 multiplications. -// 32 MSA regs used: -// - 24 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~95 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { +SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands_intrinsics); + +#endif // __arm__ || __aarch64__ + +#ifdef __mips +// 12x8 depth 2 depth-major kernel. +struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1 { typedef std::uint8_t OperandType; typedef std::uint32_t AccumulatorType; typedef KernelFormat< @@ -4419,36 +4949,37 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { GEMMLOWP_LABEL_LOOP ":\n" // Overview of register layout: // - // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30. - // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29. + // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30 + // (each register contains 4 replicas of a pair of elements). + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. // A 12x8 block of accumulators is stored in 32bit in w0-w23. 
// // +------+------+------+------+ - // Rhs |w30[0]|w30[1]|w30[2]|w30[3]| + // Rhs |w27 |w28 |w29 |w30 | // +------+------+------+------+ // // | | | | | // - // Lhs | | | | | + // Lhs | | | | | // - // +---+---+ - - - - +------+------+------+------+ - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // +---+---+ - - - - +------+------+------+------+ - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // +---+---+ - - - - +------+------+------+------+ - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // +---+---+ - - - - +------+------+------+------+ + // +---+ - - - - +------+------+------+------+ + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // +---+ - - - - +------+------+------+------+ + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // +---+ - - - - +------+------+------+------+ + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // +---+ - - - - +------+------+------+------+ // - // Accumulator + // Accumulators // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. "ld.b $w24, 0(%[lhs_ptr])\n" @@ -4459,100 +4990,88 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { "lbu $a1, 1(%[rhs_ptr])\n" "lbu $a2, 2(%[rhs_ptr])\n" "lbu $a3, 3(%[rhs_ptr])\n" + // Load 4 bytes of rhs[] for the first half of depth 1. + "lbu $v0, 4(%[rhs_ptr])\n" + "lbu $v1, 5(%[rhs_ptr])\n" + "lbu $t8, 6(%[rhs_ptr])\n" + "lbu $t9, 7(%[rhs_ptr])\n" // Zero-extend 8-bit elements of lhs[] to 16 bits. "ilvr.b $w24, $w31, $w24\n" "ilvl.b $w26, $w31, $w25\n" "ilvr.b $w25, $w31, $w25\n" - // Zero-extend 16-bit elements of lhs[] to 32 bits. - "ilvl.h $w27, $w31, $w24\n" - "ilvl.h $w28, $w31, $w25\n" - "ilvl.h $w29, $w31, $w26\n" - "ilvr.h $w24, $w31, $w24\n" - "ilvr.h $w25, $w31, $w25\n" - "ilvr.h $w26, $w31, $w26\n" - - // Depth 0. 
- "fill.w $w30, $a0\n" - "lbu $a0, 8(%[rhs_ptr])\n" - "maddv.w $w0, $w24, $w30\n" - "maddv.w $w4, $w25, $w30\n" - "maddv.w $w8, $w26, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 9(%[rhs_ptr])\n" - "maddv.w $w1, $w24, $w30\n" - "maddv.w $w5, $w25, $w30\n" - "maddv.w $w9, $w26, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 10(%[rhs_ptr])\n" - "maddv.w $w2, $w24, $w30\n" - "maddv.w $w6, $w25, $w30\n" - "maddv.w $w10, $w26, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 11(%[rhs_ptr])\n" - "maddv.w $w3, $w24, $w30\n" - "maddv.w $w7, $w25, $w30\n" - "maddv.w $w11, $w26, $w30\n" - - "fill.w $w30, $a0\n" - "lbu $a0, 4(%[rhs_ptr])\n" - "maddv.w $w12, $w24, $w30\n" - "maddv.w $w16, $w25, $w30\n" - "maddv.w $w20, $w26, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 5(%[rhs_ptr])\n" - "maddv.w $w13, $w24, $w30\n" - "maddv.w $w17, $w25, $w30\n" - "maddv.w $w21, $w26, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 6(%[rhs_ptr])\n" - "maddv.w $w14, $w24, $w30\n" - "maddv.w $w18, $w25, $w30\n" - "maddv.w $w22, $w26, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 7(%[rhs_ptr])\n" - "maddv.w $w15, $w24, $w30\n" - "maddv.w $w19, $w25, $w30\n" - "maddv.w $w23, $w26, $w30\n" - - // Depth 1. - "fill.w $w30, $a0\n" - "lbu $a0, 12(%[rhs_ptr])\n" - "maddv.w $w0, $w27, $w30\n" - "maddv.w $w4, $w28, $w30\n" - "maddv.w $w8, $w29, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 13(%[rhs_ptr])\n" - "maddv.w $w1, $w27, $w30\n" - "maddv.w $w5, $w28, $w30\n" - "maddv.w $w9, $w29, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 14(%[rhs_ptr])\n" - "maddv.w $w2, $w27, $w30\n" - "maddv.w $w6, $w28, $w30\n" - "maddv.w $w10, $w29, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 15(%[rhs_ptr])\n" - "maddv.w $w3, $w27, $w30\n" - "maddv.w $w7, $w28, $w30\n" - "maddv.w $w11, $w29, $w30\n" - - "fill.w $w30, $a0\n" - "maddv.w $w12, $w27, $w30\n" - "maddv.w $w16, $w28, $w30\n" - "maddv.w $w20, $w29, $w30\n" - "fill.w $w30, $a1\n" - "maddv.w $w13, $w27, $w30\n" - "maddv.w $w17, $w28, $w30\n" - "maddv.w $w21, $w29, $w30\n" - "fill.w $w30, $a2\n" - "maddv.w $w14, $w27, $w30\n" - "maddv.w $w18, $w28, $w30\n" - "maddv.w $w22, $w29, $w30\n" - "fill.w $w30, $a3\n" - "maddv.w $w15, $w27, $w30\n" - "maddv.w $w19, $w28, $w30\n" - "maddv.w $w23, $w29, $w30\n" + // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. + "ilvl.d $w27, $w31, $w24\n" + "ilvl.d $w28, $w31, $w25\n" + "ilvl.d $w29, $w31, $w26\n" + "ilvr.h $w24, $w27, $w24\n" + "ilvr.h $w25, $w28, $w25\n" + "ilvr.h $w26, $w29, $w26\n" + + // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w + // (for the first half). + "ins $a0, $v0, 16, 8\n" + "ins $a1, $v1, 16, 8\n" + "ins $a2, $t8, 16, 8\n" + "ins $a3, $t9, 16, 8\n" + // Make 4 replicas of every pair of rhs[] elements. + "fill.w $w27, $a0\n" + "fill.w $w28, $a1\n" + "fill.w $w29, $a2\n" + "fill.w $w30, $a3\n" + + // Load 4 bytes of rhs[] for the second half of depth 0. + "lbu $a0, 8(%[rhs_ptr])\n" + "lbu $a1, 9(%[rhs_ptr])\n" + "lbu $a2, 10(%[rhs_ptr])\n" + "lbu $a3, 11(%[rhs_ptr])\n" + // Load 4 bytes of rhs[] for the second half of depth 1. + "lbu $v0, 12(%[rhs_ptr])\n" + "lbu $v1, 13(%[rhs_ptr])\n" + "lbu $t8, 14(%[rhs_ptr])\n" + "lbu $t9, 15(%[rhs_ptr])\n" + + // First half of depths 0 and 1. + // Dot-product-(and)-add doubles multiplicand width. 
+ "dpadd_u.w $w0, $w24, $w27\n" + "dpadd_u.w $w4, $w25, $w27\n" + "dpadd_u.w $w8, $w26, $w27\n" + "dpadd_u.w $w1, $w24, $w28\n" + "dpadd_u.w $w5, $w25, $w28\n" + "dpadd_u.w $w9, $w26, $w28\n" + "dpadd_u.w $w2, $w24, $w29\n" + "dpadd_u.w $w6, $w25, $w29\n" + "dpadd_u.w $w10, $w26, $w29\n" + "dpadd_u.w $w3, $w24, $w30\n" + "dpadd_u.w $w7, $w25, $w30\n" + "dpadd_u.w $w11, $w26, $w30\n" + + // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w + // (for the second half). + "ins $a0, $v0, 16, 8\n" + "ins $a1, $v1, 16, 8\n" + "ins $a2, $t8, 16, 8\n" + "ins $a3, $t9, 16, 8\n" + // Make 4 replicas of every pair of rhs[] elements. + "fill.w $w27, $a0\n" + "fill.w $w28, $a1\n" + "fill.w $w29, $a2\n" + "fill.w $w30, $a3\n" + + // Second half of depths 0 and 1. + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w12, $w24, $w27\n" + "dpadd_u.w $w16, $w25, $w27\n" + "dpadd_u.w $w20, $w26, $w27\n" + "dpadd_u.w $w13, $w24, $w28\n" + "dpadd_u.w $w17, $w25, $w28\n" + "dpadd_u.w $w21, $w26, $w28\n" + "dpadd_u.w $w14, $w24, $w29\n" + "dpadd_u.w $w18, $w25, $w29\n" + "dpadd_u.w $w22, $w26, $w29\n" + "dpadd_u.w $w15, $w24, $w30\n" + "dpadd_u.w $w19, $w25, $w30\n" + "dpadd_u.w $w23, $w26, $w30\n" "addiu %[depth], -2\n" GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" @@ -4591,7 +5110,9 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { [accum_ptr] "r"(accum_ptr) : // clobbers "memory", + "v0", "v1", "a0", "a1", "a2", "a3", + "t8", "t9", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", @@ -4599,21 +5120,14 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { } }; -// Assembly implementation of the above -// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). -// Using 16x16=32 multiplications. -// 32 MSA regs used: -// - 24 accumulators -// - 3 lhs -// - 4 rhs -// - 1 temps/zeroes -// ~70 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { +// 12x8 depth 2 width-major kernel. +// Does less shuffling and replication than the kernel above. +struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2 { typedef std::uint8_t OperandType; typedef std::uint32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> > Format; static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { @@ -4643,19 +5157,18 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { "ld.w $w15, (21*16)(%[accum_ptr])\n" "ld.w $w19, (22*16)(%[accum_ptr])\n" "ld.w $w23, (23*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w31, 0\n" - GEMMLOWP_LABEL_LOOP ":\n" + GEMMLOWP_LABEL_LOOP + ":\n" // Overview of register layout: // - // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30 + // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31 // (each register contains 4 replicas of a pair of elements). // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. // A 12x8 block of accumulators is stored in 32bit in w0-w23. 
// // +------+------+------+------+ - // Rhs |w27 |w28 |w29 |w30 | + // Rhs |w28 |w29 |w30 |w31 | // +------+------+------+------+ // // | | | | | @@ -4685,98 +5198,65 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { "ld.b $w24, 0(%[lhs_ptr])\n" "ld.b $w25, 8(%[lhs_ptr])\n" - // Load 4 bytes of rhs[] for the first half of depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for the first half of depth 1. - "lbu $v0, 4(%[rhs_ptr])\n" - "lbu $v1, 5(%[rhs_ptr])\n" - "lbu $t8, 6(%[rhs_ptr])\n" - "lbu $t9, 7(%[rhs_ptr])\n" + // Load 2 x 8 bytes of rhs[]. + "ld.b $w27, 0(%[rhs_ptr])\n" // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ldi.b $w31, 0\n" "ilvr.b $w24, $w31, $w24\n" "ilvl.b $w26, $w31, $w25\n" "ilvr.b $w25, $w31, $w25\n" - // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. - "ilvl.d $w27, $w31, $w24\n" - "ilvl.d $w28, $w31, $w25\n" - "ilvl.d $w29, $w31, $w26\n" - "ilvr.h $w24, $w27, $w24\n" - "ilvr.h $w25, $w28, $w25\n" - "ilvr.h $w26, $w29, $w26\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w - // (for the first half). - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" - // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w27, $a0\n" - "fill.w $w28, $a1\n" - "fill.w $w29, $a2\n" - "fill.w $w30, $a3\n" - - // Load 4 bytes of rhs[] for the second half of depth 0. - "lbu $a0, 8(%[rhs_ptr])\n" - "lbu $a1, 9(%[rhs_ptr])\n" - "lbu $a2, 10(%[rhs_ptr])\n" - "lbu $a3, 11(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for the second half of depth 1. - "lbu $v0, 12(%[rhs_ptr])\n" - "lbu $v1, 13(%[rhs_ptr])\n" - "lbu $t8, 14(%[rhs_ptr])\n" - "lbu $t9, 15(%[rhs_ptr])\n" // First half of depths 0 and 1. - // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w0, $w24, $w27\n" - "dpadd_u.w $w4, $w25, $w27\n" - "dpadd_u.w $w8, $w26, $w27\n" - "dpadd_u.w $w1, $w24, $w28\n" - "dpadd_u.w $w5, $w25, $w28\n" - "dpadd_u.w $w9, $w26, $w28\n" - "dpadd_u.w $w2, $w24, $w29\n" - "dpadd_u.w $w6, $w25, $w29\n" - "dpadd_u.w $w10, $w26, $w29\n" - "dpadd_u.w $w3, $w24, $w30\n" - "dpadd_u.w $w7, $w25, $w30\n" - "dpadd_u.w $w11, $w26, $w30\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w - // (for the second half). - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ilvr.b $w31, $w31, $w27\n" // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w27, $a0\n" - "fill.w $w28, $a1\n" - "fill.w $w29, $a2\n" - "fill.w $w30, $a3\n" + "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w0, $w24, $w28\n" + "dpadd_u.w $w4, $w25, $w28\n" + "dpadd_u.w $w8, $w26, $w28\n" + "dpadd_u.w $w1, $w24, $w29\n" + "dpadd_u.w $w5, $w25, $w29\n" + "dpadd_u.w $w9, $w26, $w29\n" + "dpadd_u.w $w2, $w24, $w30\n" + "dpadd_u.w $w6, $w25, $w30\n" + "dpadd_u.w $w10, $w26, $w30\n" + "dpadd_u.w $w3, $w24, $w31\n" + "dpadd_u.w $w7, $w25, $w31\n" + "dpadd_u.w $w11, $w26, $w31\n" // Second half of depths 0 and 1. + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ldi.b $w31, 0\n" + "ilvl.b $w31, $w31, $w27\n" + // Make 4 replicas of every pair of rhs[] elements. 
+ "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w12, $w24, $w27\n" - "dpadd_u.w $w16, $w25, $w27\n" - "dpadd_u.w $w20, $w26, $w27\n" - "dpadd_u.w $w13, $w24, $w28\n" - "dpadd_u.w $w17, $w25, $w28\n" - "dpadd_u.w $w21, $w26, $w28\n" - "dpadd_u.w $w14, $w24, $w29\n" - "dpadd_u.w $w18, $w25, $w29\n" - "dpadd_u.w $w22, $w26, $w29\n" - "dpadd_u.w $w15, $w24, $w30\n" - "dpadd_u.w $w19, $w25, $w30\n" - "dpadd_u.w $w23, $w26, $w30\n" - - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + "dpadd_u.w $w12, $w24, $w28\n" + "dpadd_u.w $w16, $w25, $w28\n" + "dpadd_u.w $w20, $w26, $w28\n" + "dpadd_u.w $w13, $w24, $w29\n" + "dpadd_u.w $w17, $w25, $w29\n" + "dpadd_u.w $w21, $w26, $w29\n" + "dpadd_u.w $w14, $w24, $w30\n" + "dpadd_u.w $w18, $w25, $w30\n" + "dpadd_u.w $w22, $w26, $w30\n" + "dpadd_u.w $w15, $w24, $w31\n" + "dpadd_u.w $w19, $w25, $w31\n" + "dpadd_u.w $w23, $w26, $w31\n" + + "addiu %[depth], -2\n" GEMMLOWP_MIPS_XADDIU + " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU + " %[rhs_ptr], 16\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP + "b\n" // Store accumulators. "st.w $w0, (0*16)(%[accum_ptr])\n" @@ -4809,14 +5289,304 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { : // inputs [accum_ptr] "r"(accum_ptr) : // clobbers - "memory", - "v0", "v1", - "a0", "a1", "a2", "a3", - "t8", "t9", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", - "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); + "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", + "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", + "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", + "$f27", "$f28", "$f29", "$f30", "$f31"); + } +}; + +// 4x4 depth 16 width-major kernel operating on int8 operands. +// It is assumed that one of the two int8 operands only takes values +// in [-127, 127], while the other may freely range in [-128, 127]. +// The issue with both operands taking the value -128 is that: +// -128*-128 + -128*-128 == -32768 overflows int16. +// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 +// range. That is the basic idea of this kernel. +struct MSA_GEMM_Int8Operands_AccumTwoWithin16Bits { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + std::size_t start_depth = 123; + std::size_t run_depth = depth; + std::size_t dst_col_stride = 4; + AccumulatorType* dst_ptr = accum_ptr; +#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1" +#define GEMMLOWP_LABEL_LOOP "2" +#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" +#define GEMMLOWP_LABEL_STORE "4" + asm volatile( + GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n" + // Load lhs[] and rhs[], zero out internal accumulators. 
+        "ld.b   $w16, 0(%[lhs_ptr])\n"
+        "ldi.b  $w0, 0\n"
+        "ld.b   $w20, 0(%[rhs_ptr])\n"
+        "ldi.b  $w1, 0\n"
+        "ld.b   $w17, 16(%[lhs_ptr])\n"
+        "ldi.b  $w2, 0\n"
+        "ld.b   $w21, 16(%[rhs_ptr])\n"
+        "ldi.b  $w3, 0\n"
+        "ld.b   $w18, 32(%[lhs_ptr])\n"
+        "ldi.b  $w4, 0\n"
+        "ld.b   $w19, 48(%[lhs_ptr])\n"
+        "ldi.b  $w5, 0\n"
+        "ld.b   $w22, 32(%[rhs_ptr])\n"
+        "ldi.b  $w6, 0\n"
+        "ld.b   $w23, 48(%[rhs_ptr])\n"
+        "ldi.b  $w7, 0\n"
+        "ldi.b  $w8, 0\n"
+        "ldi.b  $w9, 0\n"
+        "ldi.b  $w10, 0\n"
+        "ldi.b  $w11, 0\n"
+        "ldi.b  $w12, 0\n"
+        "ldi.b  $w13, 0\n"
+        "ldi.b  $w14, 0\n"
+        "ldi.b  $w15, 0\n"
+        "ldi.h  $w31, 1\n"
+        // If the loop depth is only 16, then we can skip the general loop
+        // and go straight to the final part of the code.
+        "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
+
+        GEMMLOWP_LABEL_LOOP ":\n"
+        // Overview of register layout:
+        //
+        // A 4x16 block of Lhs is stored in 8 bit in w16-w19.
+        // A 4x16 block of Rhs is stored in 8 bit in w20-w23.
+        //
+        // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
+        // components which need to be horizontally added at the end).
+        //
+        // Dot products of Lhs and Rhs are 16-bit values, which can't
+        // immediately be accumulated in 32-bit accumulators by the
+        // same instruction that calculates them.
+        // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
+        // sums in w25 (note, the 16 sums have already been reduced to 8
+        // by the horizontal addition of the dotp instruction).
+        // They are then sign-extended to 32 bits, horizontally added
+        // (again) to form 4 32-bit sums and then they are finally added
+        // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
+        //
+        //                           +-----+-----+-----+-----+
+        //                      Rhs  | w20 | w21 | w22 | w23 |
+        //                           +-----+-----+-----+-----+
+        //
+        //                           |     |     |     |     |
+        //
+        //      Lhs                  |     |     |     |     |
+        //
+        //     +---+ - - - -         +-----+-----+-----+-----+
+        //     |w16|                 | w0  | w4  | w8  | w12 |
+        //     |w17|                 | w1  | w5  | w9  | w13 |
+        //     |w18|                 | w2  | w6  | w10 | w14 |
+        //     |w19|                 | w3  | w7  | w11 | w15 |
+        //     +---+ - - - -         +-----+-----+-----+-----+
+        //
+        //                                Accumulators
+
+        // Calculate the results for 16 depths and load
+        // lhs[] and rhs[] for the next iteration.
+        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
+        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
+        GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h $w25, $w16, $w20\n"
+        "dotp_s.h $w26, $w17, $w20\n"
+        "dotp_s.h $w27, $w16, $w21\n"
+        "dotp_s.h $w28, $w17, $w21\n"
+        "dotp_s.h $w29, $w18, $w20\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+        "dpadd_s.w $w0, $w25, $w31\n"
+        "dpadd_s.w $w1, $w26, $w31\n"
+        "dpadd_s.w $w4, $w27, $w31\n"
+        "dpadd_s.w $w5, $w28, $w31\n"
+        "dpadd_s.w $w2, $w29, $w31\n"
+
+        // Dot product: multiply-add pairs of adjacent int8 elements.
+        // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+        "dotp_s.h $w24, $w16, $w22\n"
+        "dotp_s.h $w25, $w19, $w20\n"
+        "dotp_s.h $w26, $w16, $w23\n"
+        "dotp_s.h $w27, $w17, $w22\n"
+        "ld.b   $w20, 0(%[rhs_ptr])\n"
+        "dotp_s.h $w28, $w17, $w23\n"
+        "ld.b   $w16, 0(%[lhs_ptr])\n"
+        "dotp_s.h $w29, $w18, $w21\n"
+        "ld.b   $w17, 16(%[lhs_ptr])\n"
+        // Horizontal add of pairs of adjacent int16 sums into internal int32
+        // accumulators.
+ "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "ld.b $w21, 16(%[rhs_ptr])\n" + "dotp_s.h $w28, $w19, $w22\n" + "ld.b $w18, 32(%[lhs_ptr])\n" + "dotp_s.h $w29, $w19, $w23\n" + "ld.b $w22, 32(%[rhs_ptr])\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "ld.b $w19, 48(%[lhs_ptr])\n" + "dpadd_s.w $w10, $w26, $w31\n" + "ld.b $w23, 48(%[rhs_ptr])\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n" + // Calculate the results for the last 16 depths. + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w16, $w20\n" + "dotp_s.h $w26, $w17, $w20\n" + "dotp_s.h $w27, $w16, $w21\n" + "dotp_s.h $w28, $w17, $w21\n" + "dotp_s.h $w29, $w18, $w20\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w0, $w25, $w31\n" + "dpadd_s.w $w1, $w26, $w31\n" + "dpadd_s.w $w4, $w27, $w31\n" + "dpadd_s.w $w5, $w28, $w31\n" + "dpadd_s.w $w2, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w24, $w16, $w22\n" + "dotp_s.h $w25, $w19, $w20\n" + "dotp_s.h $w26, $w16, $w23\n" + "dotp_s.h $w27, $w17, $w22\n" + "dotp_s.h $w28, $w17, $w23\n" + "dotp_s.h $w29, $w18, $w21\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "dotp_s.h $w28, $w19, $w22\n" + "dotp_s.h $w29, $w19, $w23\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "dpadd_s.w $w10, $w26, $w31\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + // Horizontal-add internal accumulators. 
+ "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w1, $w1, $w1\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w3, $w3, $w3\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w5, $w5, $w5\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w7, $w7, $w7\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w9, $w9, $w9\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w11, $w11, $w11\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w13, $w13, $w13\n" + "hadd_s.d $w14, $w14, $w14\n" + "hadd_s.d $w15, $w15, $w15\n" + "pckev.w $w0, $w1, $w0\n" + "pckev.w $w2, $w3, $w2\n" + "pckev.w $w4, $w5, $w4\n" + "pckev.w $w6, $w7, $w6\n" + "pckev.w $w8, $w9, $w8\n" + "pckev.w $w10, $w11, $w10\n" + "pckev.w $w12, $w13, $w12\n" + "pckev.w $w14, $w15, $w14\n" + "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w14, $w14, $w14\n" + // 4 more pckev instructions follow in both paths below. + + // Check if start_depth==0 to decide whether we will load + // existing accumulators from memory. + "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n" + + "pckev.w $w0, $w2, $w0\n" + "pckev.w $w1, $w6, $w4\n" + "pckev.w $w2, $w10, $w8\n" + "pckev.w $w3, $w14, $w12\n" + + "b " GEMMLOWP_LABEL_STORE "f\n" + + GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n" + // Load accumulators from memory. + "ld.w $w16, 0(%[dst_ptr0])\n" + "pckev.w $w0, $w2, $w0\n" + "ld.w $w17, 0(%[dst_ptr1])\n" + "pckev.w $w1, $w6, $w4\n" + "ld.w $w18, 0(%[dst_ptr2])\n" + "pckev.w $w2, $w10, $w8\n" + "ld.w $w19, 0(%[dst_ptr3])\n" + "pckev.w $w3, $w14, $w12\n" + + // Add them to internal accumulators. + "addv.w $w0, $w0, $w16\n" + "addv.w $w1, $w1, $w17\n" + "addv.w $w2, $w2, $w18\n" + "addv.w $w3, $w3, $w19\n" + + GEMMLOWP_LABEL_STORE ":\n" + // Store accumulators. + "st.w $w0, 0(%[dst_ptr0])\n" + "st.w $w1, 0(%[dst_ptr1])\n" + "st.w $w2, 0(%[dst_ptr2])\n" + "st.w $w3, 0(%[dst_ptr3])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride), + [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2), + [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3), + [start_depth] "r"(start_depth) + : // clobbers + "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", + "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", + "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", + "$f27", "$f28", "$f29", "$f30", "$f31"); +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16 +#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES +#undef GEMMLOWP_LABEL_STORE } }; #endif // __mips @@ -4901,13 +5671,10 @@ class CacheLineAlignedBuffer { }; template <typename DataType> -void FillRandom(CacheLineAlignedBuffer<DataType>* buffer) { +void FillRandom(CacheLineAlignedBuffer<DataType>* buffer, DataType min, + DataType max) { static std::mt19937 generator(0); - // 100 is smaller than any nonzero bound of the range of any data type. - const DataType kMaxVal = DataType(100); - const DataType kMinVal = - std::is_signed<DataType>::value ? 
-kMaxVal : DataType(0); - std::uniform_real_distribution<float> dist(kMinVal, kMaxVal); + std::uniform_real_distribution<float> dist(min, max); for (std::size_t i = 0; i < buffer->size(); i++) { buffer->data()[i] = DataType(dist(generator)); } @@ -4971,9 +5738,16 @@ void test_kernel(int depth, const char* kernel_name) { CacheLineAlignedBuffer<AccumulatorType> accum_reference(kLhsWidth * kRhsWidth); - FillRandom(&lhs); - FillRandom(&rhs); - FillRandom(&accum_initial); + FillRandom(&lhs, KernelOperandRanges<Kernel>::LhsMin(), + KernelOperandRanges<Kernel>::LhsMax()); + FillRandom(&rhs, KernelOperandRanges<Kernel>::RhsMin(), + KernelOperandRanges<Kernel>::RhsMax()); + FillRandom(&accum_initial, + std::is_signed<AccumulatorType>::value + ? AccumulatorType(-100) + : AccumulatorType(0), + AccumulatorType(100)); + Copy(&accum, accum_initial); Copy(&accum_reference, accum_initial); @@ -5159,6 +5933,10 @@ int main() { #endif #ifdef __aarch64__ + BENCHMARK(NEON_64bit_GEMM_Int425Operands); + BENCHMARK(NEON_64bit_GEMM_Int425Operands_intrinsics); + BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits); + BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics); BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits); BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators); @@ -5167,6 +5945,7 @@ int main() { #ifdef __ARM_FEATURE_DOTPROD BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1); + BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow); #endif BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar); BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar); @@ -5180,12 +5959,9 @@ int main() { #endif #ifdef __mips - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics); - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly); - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2); + BENCHMARK(MSA_GEMM_Int8Operands_AccumTwoWithin16Bits); #endif return 0; diff --git a/test/benchmark.cc b/test/benchmark.cc index 9a87a41..d8236de 100644 --- a/test/benchmark.cc +++ b/test/benchmark.cc @@ -36,7 +36,16 @@ #warning "Building without NEON support on ARM, check your compiler setup!" #endif -#if defined(__SSE4_2__) && !defined(GEMMLOWP_SSE4) +#if defined(__mips) && !defined(GEMMLOWP_MSA) +#warning "Building without MSA support on MIPS, check your compiler setup!" +#endif + +#if defined(__AVX2__) && !defined(GEMMLOWP_AVX2) +#warning \ + "Building without AVX2 support on AVX2 enabled machine, check your compiler setup!" +#endif + +#if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4) #warning \ "Building without SSE4.2 support on SSE4.2 enabled machine, check your compiler setup!" 
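Returning to the test-harness change a little above: FillRandom now takes explicit bounds because the new narrow-operand kernels would overflow their int8/int16 intermediate accumulators if fed the old fixed +/-100 test values; each kernel instead declares, via a KernelOperandRanges specialization, the operand range its arithmetic tolerates. A minimal usage sketch of the new signature (buffer size and bounds are illustrative only):

    // Fill a 7-bit kernel's Lhs test buffer with values in [-63, 63].
    CacheLineAlignedBuffer<std::int8_t> lhs(4 * 64);
    FillRandom(&lhs, std::int8_t(-63), std::int8_t(63));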
#endif diff --git a/test/benchmark_all_sizes.cc b/test/benchmark_all_sizes.cc index 16cc57c..527aad6 100644 --- a/test/benchmark_all_sizes.cc +++ b/test/benchmark_all_sizes.cc @@ -16,6 +16,10 @@ test/benchmark_all_sizes.cc -o /tmp/b -O3 --std=c++11 -fPIE -static \ #include "../public/gemmlowp.h" +#ifdef GEMMLOWP_PROFILING +#include "../profiling/profiler.h" +#endif + #if defined GEMMLOWP_ANDROID && defined GEMMLOWP_ARM_32 // Compilation workaround namespace std { @@ -122,10 +126,10 @@ float benchmark_8bit(int rows, int depth, int cols) { MakeZero(&rhs); MakeZero(&result); - typedef std::tuple<OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + typedef std::tuple<OutputStageQuantizeDownInt32ByFixedPoint, OutputStageSaturatingCastToUint8> Pipeline; - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; quantize_down_stage.result_offset_after_shift = 128; quantize_down_stage.result_fixedpoint_multiplier = 1234567890; @@ -345,7 +349,18 @@ void run_benchmarks(std::map<Shape, float>* results) { int main() { std::map<Shape, float> results; + +#ifdef GEMMLOWP_PROFILING + gemmlowp::RegisterCurrentThreadForProfiling(); + gemmlowp::StartProfiling(); +#endif + run_benchmarks(&results); + +#ifdef GEMMLOWP_PROFILING + gemmlowp::FinishProfiling(); +#endif + printf("Using %d thread(s)\n", kNumThreads); printf("depth,rows,cols,latency(s),Gop/s\n"); for (const auto& result : results) { diff --git a/test/test.cc b/test/test.cc index eee16b4..735ad1e 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1277,6 +1277,47 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, } } + // Test a variant of the familiar default pipeline consisting of quantize-down + // and clamp-and-cast-to-int16. + OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto quantize_down_and_saturating_cast_int16_pipeline = + std::make_tuple(quantize_down_stage, saturating_cast_int16_stage); + Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows, + cols); + GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>( + &context, lhs.const_map(), rhs.const_map(), + &result_quantized_down_saturated_int16, lhs_offset, rhs_offset, + quantize_down_and_saturating_cast_int16_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + std::int32_t quantized = result_quantized_down_int32(r, c); + std::int16_t expected = std::min(std::max(quantized, -32768), 32767); + Check(expected == result_quantized_down_saturated_int16(r, c)); + } + } + +#ifdef GEMMLOWP_MSA + // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8. 
+ OutputStageTruncatingCastToUint8 truncating_cast_stage; + auto quantize_down_and_truncating_cast_pipeline = + std::make_tuple(quantize_down_stage, truncating_cast_stage); + Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8( + rows, cols); + GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>( + &context, lhs.const_map(), rhs.const_map(), + &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset, + quantize_down_and_truncating_cast_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + std::int32_t quantized = result_quantized_down_int32(r, c); + std::uint8_t expected = quantized & 255; + Check(expected == result_quantized_down_truncated_uint8(r, c)); + } + } +#endif + // Test a bias-addition with row-vector std::vector<std::int32_t> row_vector_data(cols); std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500, @@ -1428,8 +1469,8 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, result_fixedpoint_shift++; } Check(result_fixedpoint_shift >= 0); - // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint - OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + // Now test OutputStageQuantizeDownInt32ByFixedPoint + OutputStageQuantizeDownInt32ByFixedPoint quantize_down_by_fixedpoint_stage; quantize_down_by_fixedpoint_stage.result_offset_after_shift = static_cast<std::int32_t>( @@ -1447,7 +1488,6 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset, quantize_down_by_fixedpoint_pipeline); - std::vector<std::int32_t> diffs_caused_by_fixedpoint; for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { const std::int32_t actual = @@ -1462,6 +1502,44 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, } } + // Test OutputStageScaleInt32ByFixedPointAndExponent + for (int exponent = -2; exponent <= 2; exponent++) { + OutputStageScaleInt32ByFixedPointAndExponent + scale_by_fixedpoint_and_exponent_stage; + scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift = + static_cast<std::int32_t>(round(static_cast<double>( + result_offset * result_mult_int * std::pow(2.0, exponent)))); + scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier = + result_fixedpoint_multiplier; + scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent; + auto scale_by_fixedpoint_and_exponent_pipeline = + std::make_tuple(scale_by_fixedpoint_and_exponent_stage); + Matrix<std::int32_t, ResultOrder> + result_scaled_by_fixedpoint_and_exponent_int32(rows, cols); + GemmWithOutputPipeline<std::uint8_t, std::int32_t, + DefaultL8R8BitDepthParams>( + &context, lhs.const_map(), rhs.const_map(), + &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset, + scale_by_fixedpoint_and_exponent_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + const std::int32_t actual = + result_scaled_by_fixedpoint_and_exponent_int32(r, c); + const std::int32_t raw = result_raw_int32(r, c); + int left_shift = std::max(0, exponent); + int right_shift = std::max(0, -exponent); + const std::int32_t expected = + scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift + + RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul((1 << left_shift) * raw, + result_fixedpoint_multiplier), + right_shift); + Check(actual == expected); + } + } + } + // Test the variant of the familiar default pipeline consisting of // 
quantize-down and // clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the diff --git a/test/test.h b/test/test.h index aecd0c1..b381bad 100644 --- a/test/test.h +++ b/test/test.h @@ -49,7 +49,7 @@ class Matrix : public MatrixMap<tScalar, tOrder> { typedef MatrixMap<tScalar, tOrder> Map; typedef MatrixMap<const tScalar, tOrder> ConstMap; typedef typename Map::Scalar Scalar; - static const MapOrder Order = tOrder; + static constexpr MapOrder Order = tOrder; using Map::kOrder; using Map::rows_; using Map::cols_; @@ -92,12 +92,12 @@ class Matrix : public MatrixMap<tScalar, tOrder> { std::vector<Scalar> storage; }; -std::mt19937& RandomEngine() { +inline std::mt19937& RandomEngine() { static std::mt19937 engine; return engine; } -int Random() { +inline int Random() { std::uniform_int_distribution<int> dist(0, std::numeric_limits<int>::max()); return dist(RandomEngine()); } diff --git a/test/test_blocking_counter.cc b/test/test_blocking_counter.cc index d1e0932..34d963d 100644 --- a/test/test_blocking_counter.cc +++ b/test/test_blocking_counter.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test.h" -#include "../profiling/pthread_everywhere.h" - +#include <atomic> // NOLINT #include <vector> +#include <iostream> +#include <cstdlib> #include "../internal/multi_thread_gemm.h" +#include "../profiling/pthread_everywhere.h" +#include "test.h" namespace gemmlowp { @@ -26,16 +28,36 @@ class Thread { Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) : blocking_counter_(blocking_counter), number_of_times_to_decrement_(number_of_times_to_decrement), - finished_(false), - made_the_last_decrement_(false) { + made_the_last_decrement_(false), + finished_(false) { +#if defined GEMMLOWP_USE_PTHREAD + // Limit the stack size so as not to deplete memory when creating + // many threads. + pthread_attr_t attr; + int err = pthread_attr_init(&attr); + if (!err) { + size_t stack_size; + err = pthread_attr_getstacksize(&attr, &stack_size); + if (!err && stack_size > max_stack_size_) { + err = pthread_attr_setstacksize(&attr, max_stack_size_); + } + if (!err) { + err = pthread_create(&thread_, &attr, ThreadFunc, this); + } + } + if (err) { + std::cerr << "Failed to create a thread.\n"; + std::abort(); + } +#else pthread_create(&thread_, nullptr, ThreadFunc, this); +#endif } ~Thread() { Join(); } - bool Join() const { - if (!finished_) { - pthread_join(thread_, nullptr); + bool Join() { + while (!finished_.load()) { } return made_the_last_decrement_; } @@ -48,7 +70,7 @@ class Thread { Check(!made_the_last_decrement_); made_the_last_decrement_ = blocking_counter_->DecrementCount(); } - finished_ = true; + finished_.store(true); } static void* ThreadFunc(void* ptr) { @@ -56,11 +78,18 @@ class Thread { return nullptr; } + static constexpr size_t max_stack_size_ = 256 * 1024; BlockingCounter* const blocking_counter_; const int number_of_times_to_decrement_; pthread_t thread_; - bool finished_; bool made_the_last_decrement_; + // finished_ is used to manually implement Join() by busy-waiting. + // I wanted to use pthread_join / std::thread::join, but the behavior + // observed on Android was that pthread_join aborts when the thread has + // already joined before calling pthread_join, making that hard to use. + // It appeared simplest to just implement this simple spinlock, and that + // is good enough as this is just a test. 
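+  // Making finished_ a std::atomic<bool> gives the store performed when the
+  // thread finishes and the busy-waiting loads in Join() the synchronization
+  // needed to keep this spin-wait free of data races.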
+ std::atomic<bool> finished_; }; void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, @@ -89,10 +118,10 @@ void test_blocking_counter() { // repeating the entire test sequence ensures that we test // non-monotonic changes. for (int repeat = 1; repeat <= 2; repeat++) { - for (int num_threads = 1; num_threads <= 16; num_threads++) { + for (int num_threads = 1; num_threads <= 5; num_threads++) { for (int num_decrements_per_thread = 1; - num_decrements_per_thread <= 64 * 1024; - num_decrements_per_thread *= 4) { + num_decrements_per_thread <= 4 * 1024; + num_decrements_per_thread *= 16) { test_blocking_counter(blocking_counter, num_threads, num_decrements_per_thread, num_threads * num_decrements_per_thread); diff --git a/test/test_fixedpoint.cc b/test/test_fixedpoint.cc index da222f0..44e6fae 100644 --- a/test/test_fixedpoint.cc +++ b/test/test_fixedpoint.cc @@ -17,479 +17,587 @@ #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS #include <algorithm> +#include <cinttypes> #include <cmath> +#include <cstdio> #include <random> #include <vector> -#include "test.h" #include "../fixedpoint/fixedpoint.h" +#include "test.h" namespace gemmlowp { namespace { -// Explanation of SimdVector type and associated functions -// (LoadSimdVector, StoreSimdVector): -// The fixedpoint stuff being tested here is generic in an underlying -// integer type which may be either scalar (int32_t) or SIMD (e.g. -// NEON int32x4_t). We want to write uniform tests that can test -// both the scalar and SIMD paths. We achieve this by having this -// generic SimdVector abstraction, local to this test. - +template <typename T> +T Load(const typename FixedPointRawTypeTraits<T>::ScalarRawType* src) { + return *src; +} +template <typename T> +void Store(typename FixedPointRawTypeTraits<T>::ScalarRawType* dst, T v) { + *dst = v; +} #ifdef GEMMLOWP_NEON -using SimdVector = int32x4_t; -constexpr std::size_t SimdVectorSize = 4; -SimdVector LoadSimdVector(const std::int32_t* src) { return vld1q_s32(src); } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { vst1q_s32(dst, v); } -#elif defined(GEMMLOWP_SSE4) -using SimdVector = __m128i; -constexpr std::size_t SimdVectorSize = 4; -SimdVector LoadSimdVector(const std::int32_t* src) { +template <> +int32x4_t Load<int32x4_t>(const std::int32_t* src) { + return vld1q_s32(src); +} +template <> +int16x8_t Load<int16x8_t>(const std::int16_t* src) { + return vld1q_s16(src); +} +template <> +void Store<int32x4_t>(std::int32_t* dst, int32x4_t v) { + vst1q_s32(dst, v); +} +template <> +void Store<int16x8_t>(std::int16_t* dst, int16x8_t v) { + vst1q_s16(dst, v); +} +#endif +#ifdef GEMMLOWP_SSE4 +template <> +__m128i Load<__m128i>(const std::int32_t* src) { return _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)); } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { +template <> +void Store<__m128i>(std::int32_t* dst, __m128i v) { _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v); } -#else -using SimdVector = std::int32_t; -constexpr std::size_t SimdVectorSize = 1; -SimdVector LoadSimdVector(const std::int32_t* src) { return *src; } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { *dst = v; } +template <> +int16x8_m128i Load<int16x8_m128i>(const std::int16_t* src) { + return to_int16x8_m128i( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(src))); +} +template <> +void Store<int16x8_m128i>(std::int16_t* dst, int16x8_m128i v) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v.v); +} +#endif +#ifdef GEMMLOWP_MSA +template <> +v4i32 
Load<v4i32>(const std::int32_t* src) { + return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0); +} +template <> +v8i16 Load<v8i16>(const std::int16_t* src) { + return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0); +} +template <> +void Store<v4i32>(std::int32_t* dst, v4i32 v) { + __builtin_msa_st_w(v, dst, 0); +} +template <> +void Store<v8i16>(std::int16_t* dst, v8i16 v) { + __builtin_msa_st_h(v, dst, 0); +} #endif -// Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp: -// Most (though not all) of the fixedpoint functionality being tested -// consists of functions taking one fixedpoint value and returning one -// fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators". -// We factor a lot of testing boilerplate into a common TestUnaryOp function -// taking a "unary op" object that fully describes the function to be tested. -// These objects inherit UnaryOpBase mostly as a means to share some default -// values for some properties. -// -// An important design element here is that the fixed-point values are passed -// around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not -// as higher-level FixedPoint objects. The motivation for this design is 1) to -// avoid having to templatize everything in the tIntegerBits parameter of -// class FixedPoint, and 2) to allow directly testing low-level functions -// operating on raw types (e.g. RoundingDivideByPOT) without needlessly -// requiring -// wrapping raw values in FixedPoint objects. -class UnaryOpBase { - public: - // Min bound of the input range of this op. For example, an op only handling - // nonnegative values would return 0. - std::int32_t MinInput() const { - return std::numeric_limits<std::int32_t>::min(); - } - // Max bound of the input range of this op. For example, an op only handling - // nonpositive values would return 0. - std::int32_t MaxInput() const { - return std::numeric_limits<std::int32_t>::max(); - } - // Tolerated difference between actual and reference int32 values. - // Note that the corresponding real-numbers tolerance depends on the number - // of integer bits of the fixed-point representation of the results of this - // op. - // For example, for an op returning fixed-point values with 0 integer bits, - // the correspondence between real-number values and raw values is - // real_number = (2^31) * raw_value. 
- std::int32_t Tolerance() const { return 0; } -}; +#ifdef GEMMLOWP_AVX2 +template <> +__m256i Load<__m256i>(const std::int32_t* src) { + return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src)); +} -// Op wrapping RoundingDivideByPOT -class RoundingDivideByPOTOp final : public UnaryOpBase { - public: - RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {} - std::int32_t ReferenceOp(std::int32_t x) const { - const double d = static_cast<double>(x) / (1ll << exponent_); - return static_cast<std::int32_t>(std::round(d)); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - return RoundingDivideByPOT(x, exponent_); - } +template <> +int16x16_m256i Load<int16x16_m256i>(const std::int16_t* src) { + return to_int16x16_m256i( + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src))); +} - private: - const int exponent_; -}; +template <> +void Store<__m256i>(std::int32_t* dst, __m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} -// Op wrapping SaturatingRoundingMultiplyByPOT -template <int tExponent> -class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase { - public: - std::int32_t ReferenceOp(std::int32_t x) const { - const double d = static_cast<double>(x) * std::pow(2., tExponent); - const double clamp_min = std::numeric_limits<std::int32_t>::min(); - const double clamp_max = std::numeric_limits<std::int32_t>::max(); - const double clamped = std::min(clamp_max, std::max(clamp_min, d)); - return static_cast<std::int32_t>(std::round(clamped)); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - return SaturatingRoundingMultiplyByPOT<tExponent>(x); - } -}; +template <> +void Store<int16x16_m256i>(std::int16_t* dst, int16x16_m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v.v); +} +#endif -// Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl -class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final - : public UnaryOpBase { +template <typename tSimdType> +class TestFixedPoint { public: - std::int32_t MinInput() const { return -(1 << 29); } - std::int32_t MaxInput() const { return 0; } - std::int32_t Tolerance() const { return 500; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint<std::int32_t, 0>; - const double d = ToDouble(F::FromRaw(x)); - const double e = std::exp(d); - return F::FromDouble(e).raw(); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - using F = FixedPoint<tRawType, 0>; - const F f = F::FromRaw(x); - const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f); - return e.raw(); - } -}; + using SimdType = tSimdType; + using SimdTypeTraits = FixedPointRawTypeTraits<SimdType>; + using ScalarType = typename SimdTypeTraits::ScalarRawType; + static constexpr int kSimdLanes = SimdTypeTraits::kLanes; + static constexpr int kScalarTypeBits = 8 * sizeof(ScalarType); + + // Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp: + // Most (though not all) of the fixedpoint functionality being tested + // consists of functions taking one fixedpoint value and returning one + // fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators". + // We factor a lot of testing boilerplate into a common TestUnaryOp function + // taking a "unary op" object that fully describes the function to be tested. + // These objects inherit UnaryOpBase mostly as a means to share some default + // values for some properties. 
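+  //
+  // Concretely, TestUnaryOp below performs two checks for every input value:
+  // the scalar result must agree with the op's ReferenceOp() to within its
+  // Tolerance(), and the SIMD result must agree with the scalar result
+  // exactly, lane by lane.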
+  //
+  // An important design element here is that the fixed-point values are passed
+  // around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
+  // as higher-level FixedPoint objects. The motivation for this design is 1) to
+  // avoid having to templatize everything in the tIntegerBits parameter of
+  // class FixedPoint, and 2) to allow directly testing low-level functions
+  // operating on raw types (e.g. RoundingDivideByPOT) without needlessly
+  // requiring
+  // wrapping raw values in FixedPoint objects.
+  class UnaryOpBase {
+   public:
+    // Min bound of the input range of this op. For example, an op only handling
+    // nonnegative values would return 0.
+    ScalarType MinInput() const {
+      return std::numeric_limits<ScalarType>::min();
+    }
+    // Max bound of the input range of this op. For example, an op only handling
+    // nonpositive values would return 0.
+    ScalarType MaxInput() const {
+      return std::numeric_limits<ScalarType>::max();
+    }
+    // Tolerated difference between actual and reference ScalarType values.
+    // Note that the corresponding real-numbers tolerance depends on the number
+    // of integer bits of the fixed-point representation of the results of this
+    // op.
+    // For example, for an op returning fixed-point values with 0 integer bits,
+    // the correspondence between real-number values and raw values is
+    // real_number = (2^-31) * raw_value.
+    ScalarType Tolerance() const { return 0; }
+  };
+
+  // Op wrapping RoundingDivideByPOT
+  class RoundingDivideByPOTOp final : public UnaryOpBase {
+   public:
+    RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
+    ScalarType ReferenceOp(ScalarType x) const {
+      const double d = static_cast<double>(x) / (1ll << exponent_);
+      return static_cast<ScalarType>(std::round(d));
+    }
+    template <typename RawType>
+    RawType Op(RawType x) const {
+      return RoundingDivideByPOT(x, exponent_);
+    }
+
+   private:
+    const int exponent_;
+  };
+
+  // Op wrapping SaturatingRoundingMultiplyByPOT
+  template <int tExponent>
+  class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
+   public:
+    ScalarType ReferenceOp(ScalarType x) const {
+      const double d = static_cast<double>(x) * std::pow(2., tExponent);
+      const double clamp_min = std::numeric_limits<ScalarType>::min();
+      const double clamp_max = std::numeric_limits<ScalarType>::max();
+      const double clamped = std::min(clamp_max, std::max(clamp_min, d));
+      return static_cast<ScalarType>(std::round(clamped));
+    }
+    template <typename RawType>
+    RawType Op(RawType x) const {
+      return SaturatingRoundingMultiplyByPOT<tExponent>(x);
+    }
+  };
+
+  // Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
+  class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
+      : public UnaryOpBase {
+   public:
+    ScalarType MinInput() const { return -(1 << (kScalarTypeBits - 3)); }
+    ScalarType MaxInput() const {
return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 1; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint<ScalarType, 0>; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::exp(d); + return F::FromDouble(e).raw(); + } + template <typename RawType> + RawType Op(RawType x) const { + using F = FixedPoint<RawType, 0>; + const F f = F::FromRaw(x); + const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f); + return e.raw(); + } + }; + + // Op wrapping exp_on_negative_values + template <int tIntegerBits> + class ExpOnNegativeValuesOp final : public UnaryOpBase { + public: + ScalarType MaxInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 2; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint<ScalarType, tIntegerBits>; + using F0 = FixedPoint<ScalarType, 0>; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::exp(d); + return F0::FromDouble(e).raw(); + } + template <typename RawType> + RawType Op(RawType x) const { + using F = FixedPoint<RawType, tIntegerBits>; + const F f = F::FromRaw(x); + return exp_on_negative_values(f).raw(); + } + }; + + // Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1 + class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase { + public: + ScalarType MinInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 12 : 11; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint<ScalarType, 0>; + const double d = ToDouble(F::FromRaw(x)); + const double e = (1 - d) / (1 + d); + return F::FromDouble(e).raw(); + } + template <typename RawType> + RawType Op(RawType x) const { + using F = FixedPoint<RawType, 0>; + const F f = F::FromRaw(x); + return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw(); + } + }; + + // Op wrapping tanh + template <int tIntegerBits> + class TanhOp final : public UnaryOpBase { + public: + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 310 : 12; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint<ScalarType, tIntegerBits>; + using F0 = FixedPoint<ScalarType, 0>; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::tanh(d); + return F0::FromDouble(e).raw(); + } + template <typename RawType> + RawType Op(RawType x) const { + using F = FixedPoint<RawType, tIntegerBits>; + const F f = F::FromRaw(x); + return tanh(f).raw(); + } + }; + + // Op wrapping one_over_one_plus_x_for_x_in_0_1 + class OneOverOnePlusXForXIn01Op final : public UnaryOpBase { + public: + ScalarType MinInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 6 : 5; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint<ScalarType, 0>; + const double d = ToDouble(F::FromRaw(x)); + const double e = 1 / (1 + d); + return F::FromDouble(e).raw(); + } + template <typename RawType> + RawType Op(RawType x) const { + using F = FixedPoint<RawType, 0>; + const F f = F::FromRaw(x); + return one_over_one_plus_x_for_x_in_0_1(f).raw(); + } + }; + + // Op wrapping logistic + template <int tIntegerBits> + class LogisticOp final : public UnaryOpBase { + public: + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 
155 : 6; }
+    ScalarType ReferenceOp(ScalarType x) const {
+      using F = FixedPoint<ScalarType, tIntegerBits>;
+      using F0 = FixedPoint<ScalarType, 0>;
+      const double d = ToDouble(F::FromRaw(x));
+      const double e = 1 / (1 + std::exp(-d));
+      return F0::FromDouble(e).raw();
+    }
+    template <typename RawType>
+    RawType Op(RawType x) const {
+      using F = FixedPoint<RawType, tIntegerBits>;
+      const F f = F::FromRaw(x);
+      return logistic(f).raw();
+    }
+  };
+
+  // Tests a given op, on a given list of input values.
+  template <typename tUnaryOpType>
+  void TestUnaryOp(const tUnaryOpType& unary_op,
+                   const std::vector<ScalarType>& testvals) {
+    Check(0 == (testvals.size() % kSimdLanes));
+    for (std::size_t i = 0; i < testvals.size(); i += kSimdLanes) {
+      // First, clamp input values according to the MinInput() and MaxInput()
+      // bounds returned by the op.
+      ScalarType input[kSimdLanes] = {0};
+      for (std::size_t j = 0; j < kSimdLanes; j++) {
+        const ScalarType raw_input = testvals[i + j];
+        input[j] = std::min(unary_op.MaxInput(),
+                            std::max(unary_op.MinInput(), raw_input));
+      }
+      // Compute reference results and check that the actual results on
+      // scalar inputs agree with them, to the Tolerance() returned by the op.
+      ScalarType reference[kSimdLanes] = {0};
+      ScalarType actual_scalar[kSimdLanes] = {0};
+      for (std::size_t j = 0; j < kSimdLanes; j++) {
+        reference[j] = unary_op.ReferenceOp(input[j]);
+        actual_scalar[j] = unary_op.Op(input[j]);
+        const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
+                                  static_cast<std::int64_t>(reference[j]);
+        if (std::abs(diff) > unary_op.Tolerance()) {
+          fprintf(stderr, "abs(diff) (%" PRId64 ") > tolerance (%d)\n", diff,
+                  unary_op.Tolerance());
+        }
+        Check(std::abs(diff) <= unary_op.Tolerance());
+      }
+      // Check that the actual results on SIMD inputs agree *exactly* with the
+      // actual results on scalar inputs. I.e. SIMD must make absolutely no
+      // difference
+      // to the results, regardless of the fact that both scalar and SIMD
+      // results may differ from the reference results.
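+      // Any lane that disagrees here therefore points at a bug in the SIMD
+      // specialization itself, not at an acceptable rounding difference.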
+ ScalarType actual_simd[kSimdLanes] = {0}; + Store<SimdType>(actual_simd, unary_op.Op(Load<SimdType>(input))); + for (std::size_t j = 0; j < kSimdLanes; j++) { + if (actual_simd[j] != actual_scalar[j]) { + fprintf(stderr, "SIMD (%d) != scalar (%d)\n", actual_simd[j], + actual_scalar[j]); + } + Check(actual_simd[j] == actual_scalar[j]); + } + } } -}; -// Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1 -class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase { - public: - std::int32_t MinInput() const { return 0; } - std::int32_t Tolerance() const { return 12; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint<std::int32_t, 0>; - const double d = ToDouble(F::FromRaw(x)); - const double e = (1 - d) / (1 + d); - return F::FromDouble(e).raw(); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - using F = FixedPoint<tRawType, 0>; - const F f = F::FromRaw(x); - return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw(); + template <int tIntegerBits> + void test_convert(FixedPoint<ScalarType, tIntegerBits> x) { + typedef FixedPoint<ScalarType, tIntegerBits> F; + F y = F::FromDouble(ToDouble(x)); + Check(y == x); } -}; -// Op wrapping tanh -template <int tIntegerBits> -class TanhOp final : public UnaryOpBase { - public: - std::int32_t Tolerance() const { return 310; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint<std::int32_t, tIntegerBits>; - using F0 = FixedPoint<std::int32_t, 0>; - const double d = ToDouble(F::FromRaw(x)); - const double e = std::tanh(d); - return F0::FromDouble(e).raw(); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - using F = FixedPoint<tRawType, tIntegerBits>; - const F f = F::FromRaw(x); - return tanh(f).raw(); + template <int tIntegerBits_a, int tIntegerBits_b> + void test_Rescale(FixedPoint<ScalarType, tIntegerBits_a> a) { + FixedPoint<ScalarType, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a); + FixedPoint<ScalarType, tIntegerBits_b> expected = + FixedPoint<ScalarType, tIntegerBits_b>::FromDouble(ToDouble(a)); + Check(actual == expected); } -}; -// Op wrapping one_over_one_plus_x_for_x_in_0_1 -class OneOverOnePlusXForXIn01Op final : public UnaryOpBase { - public: - std::int32_t MinInput() const { return 0; } - std::int32_t Tolerance() const { return 6; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint<std::int32_t, 0>; - const double d = ToDouble(F::FromRaw(x)); - const double e = 1 / (1 + d); - return F::FromDouble(e).raw(); - } - template <typename tRawType> - tRawType Op(tRawType x) const { - using F = FixedPoint<tRawType, 0>; - const F f = F::FromRaw(x); - return one_over_one_plus_x_for_x_in_0_1(f).raw(); + template <int tIntegerBits_a, int tIntegerBits_b> + void test_Rescale(const std::vector<ScalarType>& testvals) { + for (auto a : testvals) { + FixedPoint<ScalarType, tIntegerBits_a> aq; + aq.raw() = a; + test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq); + } } -}; -// Op wrapping logistic -template <int tIntegerBits> -class LogisticOp final : public UnaryOpBase { - public: - std::int32_t Tolerance() const { return 155; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint<std::int32_t, tIntegerBits>; - using F0 = FixedPoint<std::int32_t, 0>; - const double d = ToDouble(F::FromRaw(x)); - const double e = 1 / (1 + std::exp(-d)); - return F0::FromDouble(e).raw(); + template <int tIntegerBits_a, int tIntegerBits_b> + void test_mul(FixedPoint<ScalarType, tIntegerBits_a> a, + FixedPoint<ScalarType, 
tIntegerBits_b> b) { + static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b; + using ProductFixedPoint = FixedPoint<ScalarType, ProductIntegerBits>; + ProductFixedPoint ab; + ab = a * b; + double a_double = ToDouble(a); + double b_double = ToDouble(b); + double ab_double = a_double * b_double; + ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double); + std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw()); + Check(std::abs(diff) <= 1); } - template <typename tRawType> - tRawType Op(tRawType x) const { - using F = FixedPoint<tRawType, tIntegerBits>; - const F f = F::FromRaw(x); - return logistic(f).raw(); - } -}; -// Tests a given op, on a given list of int32 input values. -template <typename tUnaryOpType> -void TestUnaryOp(const tUnaryOpType& unary_op, - const std::vector<std::int32_t>& testvals_int32) { - Check(0 == (testvals_int32.size() % SimdVectorSize)); - for (std::size_t i = 0; i < testvals_int32.size(); i += SimdVectorSize) { - // First, clamp input int32 values accoding to the MinInput() and MaxInput() - // bounds returned by the op. - std::int32_t input[SimdVectorSize] = {0}; - for (std::size_t j = 0; j < SimdVectorSize; j++) { - const std::int32_t raw_input = testvals_int32[i + j]; - input[j] = std::min(unary_op.MaxInput(), - std::max(unary_op.MinInput(), raw_input)); - } - // Compute reference results and check that the actual results on - // scalar inputs agree with them, to the Tolerance() returned by the op. - std::int32_t reference[SimdVectorSize] = {0}; - std::int32_t actual_scalar[SimdVectorSize] = {0}; - for (std::size_t j = 0; j < SimdVectorSize; j++) { - reference[j] = unary_op.ReferenceOp(input[j]); - actual_scalar[j] = unary_op.Op(input[j]); - const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) - - static_cast<std::int64_t>(reference[j]); - Check(std::abs(diff) <= unary_op.Tolerance()); - } - // Check that the actual results on SIMD inputs agree *exactly* with the - // actual results on scalar inputs. I.e. SIMD must make absolutely no - // difference - // to the results, regardless of the fact that both scalar and SIMD results - // may differ from the reference results. 
- std::int32_t actual_simd[SimdVectorSize] = {0}; - StoreSimdVector(actual_simd, unary_op.Op(LoadSimdVector(input))); - for (std::size_t j = 0; j < SimdVectorSize; j++) { - Check(actual_simd[j] == actual_scalar[j]); + template <int tIntegerBits_a, int tIntegerBits_b> + void test_mul(const std::vector<ScalarType>& testvals) { + for (auto a : testvals) { + for (auto b : testvals) { + FixedPoint<ScalarType, tIntegerBits_a> aq; + FixedPoint<ScalarType, tIntegerBits_b> bq; + aq.raw() = a; + bq.raw() = b; + test_mul(aq, bq); + } } } -} -template <int tIntegerBits> -void test_convert(FixedPoint<std::int32_t, tIntegerBits> x) { - typedef FixedPoint<std::int32_t, tIntegerBits> F; - F y = F::FromDouble(ToDouble(x)); - Check(y == x); -} - -template <int tIntegerBits_a, int tIntegerBits_b> -void test_Rescale(FixedPoint<std::int32_t, tIntegerBits_a> a) { - FixedPoint<std::int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a); - FixedPoint<std::int32_t, tIntegerBits_b> expected = - FixedPoint<std::int32_t, tIntegerBits_b>::FromDouble(ToDouble(a)); - Check(actual == expected); -} - -template <int tIntegerBits_a, int tIntegerBits_b> -void test_Rescale(const std::vector<std::int32_t>& testvals_int32) { - for (auto a : testvals_int32) { - FixedPoint<std::int32_t, tIntegerBits_a> aq; - aq.raw() = a; - test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq); + template <int tExponent, int tIntegerBits_a> + void test_ExactMulByPot(FixedPoint<ScalarType, tIntegerBits_a> a) { + double x = ToDouble(a) * std::pow(2.0, tExponent); + double y = ToDouble(ExactMulByPot<tExponent>(a)); + Check(x == y); } -} - -template <int tIntegerBits_a, int tIntegerBits_b> -void test_mul(FixedPoint<std::int32_t, tIntegerBits_a> a, - FixedPoint<std::int32_t, tIntegerBits_b> b) { - static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b; - using ProductFixedPoint = FixedPoint<std::int32_t, ProductIntegerBits>; - ProductFixedPoint ab; - ab = a * b; - double a_double = ToDouble(a); - double b_double = ToDouble(b); - double ab_double = a_double * b_double; - ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double); - std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw()); - Check(std::abs(diff) <= 1); -} -template <int tIntegerBits_a, int tIntegerBits_b> -void test_mul(const std::vector<std::int32_t>& testvals_int32) { - for (auto a : testvals_int32) { - for (auto b : testvals_int32) { - FixedPoint<std::int32_t, tIntegerBits_a> aq; - FixedPoint<std::int32_t, tIntegerBits_b> bq; + template <int tExponent, int tIntegerBits_a> + void test_ExactMulByPot(const std::vector<ScalarType>& testvals) { + for (auto a : testvals) { + FixedPoint<ScalarType, tIntegerBits_a> aq; aq.raw() = a; - bq.raw() = b; - test_mul(aq, bq); + test_ExactMulByPot<tExponent, tIntegerBits_a>(aq); } } -} -template <int tExponent, int tIntegerBits_a> -void test_ExactMulByPot(FixedPoint<std::int32_t, tIntegerBits_a> a) { - double x = ToDouble(a) * std::pow(2.0, tExponent); - double y = ToDouble(ExactMulByPot<tExponent>(a)); - Check(x == y); -} + // Make the list of test values to test each op against. 
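+  // The list built below clusters values around every power of two (the
+  // boundary cases for shift- and saturation-based ops), adds the extreme
+  // values of ScalarType and 1000 uniformly random values, and finally pads
+  // the list to a multiple of the SIMD width.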
+ std::vector<ScalarType> MakeTestVals() { + std::vector<ScalarType> testvals; + + for (int i = 0; i < kScalarTypeBits - 1; i++) { + testvals.push_back((1 << i) - 2); + testvals.push_back((1 << i) - 1); + testvals.push_back((1 << i)); + testvals.push_back((1 << i) + 1); + testvals.push_back((1 << i) + 2); + testvals.push_back(-(1 << i) - 2); + testvals.push_back(-(1 << i) - 1); + testvals.push_back(-(1 << i)); + testvals.push_back(-(1 << i) + 1); + testvals.push_back(-(1 << i) + 2); + } + testvals.push_back(std::numeric_limits<ScalarType>::min()); + testvals.push_back(std::numeric_limits<ScalarType>::min() + 1); + testvals.push_back(std::numeric_limits<ScalarType>::min() + 2); + testvals.push_back(std::numeric_limits<ScalarType>::max() - 2); + testvals.push_back(std::numeric_limits<ScalarType>::max() - 1); + testvals.push_back(std::numeric_limits<ScalarType>::max()); + + std::mt19937 random_engine; + std::uniform_int_distribution<ScalarType> uniform_distribution( + std::numeric_limits<ScalarType>::min(), + std::numeric_limits<ScalarType>::max()); + for (int i = 0; i < 1000; i++) { + testvals.push_back(uniform_distribution(random_engine)); + } -template <int tExponent, int tIntegerBits_a> -void test_ExactMulByPot(const std::vector<std::int32_t>& testvals_int32) { - for (auto a : testvals_int32) { - FixedPoint<std::int32_t, tIntegerBits_a> aq; - aq.raw() = a; - test_ExactMulByPot<tExponent, tIntegerBits_a>(aq); - } -} + // SIMD tests will require the length of testvals to be a multiple + // of SIMD vector size. + while (testvals.size() % kSimdLanes) { + testvals.push_back(0); + } -// Make the list of test values to test each op against. -std::vector<std::int32_t> MakeTestValsInt32() { - std::vector<std::int32_t> testvals_int32; - - for (int i = 0; i < 31; i++) { - testvals_int32.push_back((1 << i) - 2); - testvals_int32.push_back((1 << i) - 1); - testvals_int32.push_back((1 << i)); - testvals_int32.push_back((1 << i) + 1); - testvals_int32.push_back((1 << i) + 2); - testvals_int32.push_back(-(1 << i) - 2); - testvals_int32.push_back(-(1 << i) - 1); - testvals_int32.push_back(-(1 << i)); - testvals_int32.push_back(-(1 << i) + 1); - testvals_int32.push_back(-(1 << i) + 2); - } - testvals_int32.push_back(std::numeric_limits<std::int32_t>::min()); - testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 1); - testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 2); - testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 2); - testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 1); - testvals_int32.push_back(std::numeric_limits<std::int32_t>::max()); - - std::mt19937 random_engine; - std::uniform_int_distribution<std::int32_t> uniform_distribution( - std::numeric_limits<std::int32_t>::min(), - std::numeric_limits<std::int32_t>::max()); - for (int i = 0; i < 1000; i++) { - testvals_int32.push_back(uniform_distribution(random_engine)); + std::sort(testvals.begin(), testvals.end()); + return testvals; } - // SIMD tests will require the length of testvals_int32 to be a multiple - // of SIMD vector size. 
- while (testvals_int32.size() % SimdVectorSize) { - testvals_int32.push_back(0); - } + void RunTests(const char* msg) { + const std::vector<ScalarType> testvals = MakeTestVals(); - std::sort(testvals_int32.begin(), testvals_int32.end()); - return testvals_int32; -} + for (int s = 0; s < kScalarTypeBits; s++) { + TestUnaryOp(RoundingDivideByPOTOp(s), testvals); + } + + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<14 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 15>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 14>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 3>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 2>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 1>(), + testvals); + + TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals); + + TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals); + TestUnaryOp(TanhOp<0>(), testvals); + TestUnaryOp(TanhOp<1>(), testvals); + TestUnaryOp(TanhOp<2>(), testvals); + TestUnaryOp(TanhOp<3>(), testvals); + TestUnaryOp(TanhOp<4>(), testvals); + TestUnaryOp(TanhOp<5>(), testvals); + TestUnaryOp(TanhOp<6>(), testvals); + + TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals); + TestUnaryOp(LogisticOp<0>(), testvals); + TestUnaryOp(LogisticOp<1>(), testvals); + TestUnaryOp(LogisticOp<2>(), testvals); + TestUnaryOp(LogisticOp<3>(), testvals); + TestUnaryOp(LogisticOp<4>(), testvals); + TestUnaryOp(LogisticOp<5>(), testvals); + TestUnaryOp(LogisticOp<6>(), testvals); + + for (auto a : testvals) { + FixedPoint<ScalarType, 4> x; + x.raw() = a; + test_convert(x); + } + + test_mul<0, 0>(testvals); + test_mul<0, 1>(testvals); + test_mul<2, 0>(testvals); + test_mul<1, 1>(testvals); + test_mul<4, 4>(testvals); + test_mul<3, 5>(testvals); + test_mul<7, 2>(testvals); + test_mul<kScalarTypeBits / 2 - 1, kScalarTypeBits / 2 - 2>(testvals); + + test_Rescale<0, 0>(testvals); + test_Rescale<0, 1>(testvals); + test_Rescale<2, 0>(testvals); + 
test_Rescale<4, 4>(testvals); + test_Rescale<4, 5>(testvals); + test_Rescale<6, 3>(testvals); + test_Rescale<13, 9>(testvals); + + test_ExactMulByPot<0, 0>(testvals); + test_ExactMulByPot<0, 4>(testvals); + test_ExactMulByPot<1, 4>(testvals); + test_ExactMulByPot<3, 2>(testvals); + test_ExactMulByPot<-4, 5>(testvals); + test_ExactMulByPot<-2, 6>(testvals); + + fprintf(stderr, "PASS (%s)\n", msg); + } +}; } // end anonymous namespace } // end namespace gemmlowp int main() { - using namespace gemmlowp; - - const std::vector<std::int32_t> testvals_int32 = MakeTestValsInt32(); - - for (int s = 0; s < 32; s++) { - TestUnaryOp(RoundingDivideByPOTOp(s), testvals_int32); - } - - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-31>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-30>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-29>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-17>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-16>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<16>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<17>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<29>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<30>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<31>(), testvals_int32); - - TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), - testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals_int32); - - TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals_int32); - TestUnaryOp(TanhOp<0>(), testvals_int32); - TestUnaryOp(TanhOp<1>(), testvals_int32); - TestUnaryOp(TanhOp<2>(), testvals_int32); - TestUnaryOp(TanhOp<3>(), testvals_int32); - TestUnaryOp(TanhOp<4>(), testvals_int32); - TestUnaryOp(TanhOp<5>(), testvals_int32); - TestUnaryOp(TanhOp<6>(), testvals_int32); - - TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals_int32); - TestUnaryOp(LogisticOp<0>(), testvals_int32); - TestUnaryOp(LogisticOp<1>(), testvals_int32); - TestUnaryOp(LogisticOp<2>(), testvals_int32); - TestUnaryOp(LogisticOp<3>(), testvals_int32); - TestUnaryOp(LogisticOp<4>(), testvals_int32); - TestUnaryOp(LogisticOp<5>(), testvals_int32); - TestUnaryOp(LogisticOp<6>(), testvals_int32); - - for (auto a : testvals_int32) { - FixedPoint<std::int32_t, 4> x; - x.raw() = a; - 
test_convert(x);
- }
-
- test_mul<0, 0>(testvals_int32);
- test_mul<0, 1>(testvals_int32);
- test_mul<2, 0>(testvals_int32);
- test_mul<1, 1>(testvals_int32);
- test_mul<4, 4>(testvals_int32);
- test_mul<3, 5>(testvals_int32);
- test_mul<7, 2>(testvals_int32);
- test_mul<14, 15>(testvals_int32);
-
- test_Rescale<0, 0>(testvals_int32);
- test_Rescale<0, 1>(testvals_int32);
- test_Rescale<2, 0>(testvals_int32);
- test_Rescale<4, 4>(testvals_int32);
- test_Rescale<4, 5>(testvals_int32);
- test_Rescale<6, 3>(testvals_int32);
- test_Rescale<13, 9>(testvals_int32);
-
- test_ExactMulByPot<0, 0>(testvals_int32);
- test_ExactMulByPot<0, 4>(testvals_int32);
- test_ExactMulByPot<1, 4>(testvals_int32);
- test_ExactMulByPot<3, 2>(testvals_int32);
- test_ExactMulByPot<-4, 5>(testvals_int32);
- test_ExactMulByPot<-2, 6>(testvals_int32);
-
- std::cerr << "All tests passed." << std::endl;
+ gemmlowp::TestFixedPoint<std::int32_t>().RunTests("Scalar int32");
+ gemmlowp::TestFixedPoint<std::int16_t>().RunTests("Scalar int16");
+#ifdef GEMMLOWP_SSE4
+ gemmlowp::TestFixedPoint<__m128i>().RunTests("SSE4 __m128i = int32x4");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x8_m128i>().RunTests(
+     "SSE4 __m128i = int16x8");
+#endif
+#ifdef GEMMLOWP_NEON
+ gemmlowp::TestFixedPoint<int32x4_t>().RunTests("NEON int32x4_t");
+ gemmlowp::TestFixedPoint<int16x8_t>().RunTests("NEON int16x8_t");
+#endif
+#ifdef GEMMLOWP_MSA
+ gemmlowp::TestFixedPoint<v4i32>().RunTests("MSA v4i32");
+ gemmlowp::TestFixedPoint<v8i16>().RunTests("MSA v8i16");
+#endif
+#ifdef GEMMLOWP_AVX2
+ gemmlowp::TestFixedPoint<__m256i>().RunTests("AVX2 __m256i");
+ gemmlowp::TestFixedPoint<gemmlowp::int16x16_m256i>().RunTests(
+     "AVX2 __m256i = int16x16");
+#endif
}
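Note on the fixed-point arithmetic exercised above: the expected values in the
new OutputStageScaleInt32ByFixedPointAndExponent test are built from two scalar
primitives, SaturatingRoundingDoublingHighMul and RoundingDivideByPOT. Below is
a minimal self-contained scalar sketch of those primitives and of the test's
expected-value formula, following gemmlowp's scalar reference definitions in
fixedpoint/fixedpoint.h; the ScaleByFixedPointAndExponent wrapper is an
illustrative name introduced here, not a gemmlowp API.

    #include <cstdint>
    #include <limits>

    // Rounding, saturating doubling high multiply: the high 32 bits of
    // 2*a*b, with rounding. The single overflow case
    // (a == b == INT32_MIN) saturates to INT32_MAX.
    std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                   std::int32_t b) {
      const bool overflow =
          a == b && a == std::numeric_limits<std::int32_t>::min();
      const std::int64_t ab_64 = static_cast<std::int64_t>(a) * b;
      // The nudge implements round-half-away-from-zero.
      const std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
      const std::int32_t ab_x2_high32 =
          static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
      return overflow ? std::numeric_limits<std::int32_t>::max()
                      : ab_x2_high32;
    }

    // Division by a power of two, rounding half away from zero.
    std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
      const std::int32_t mask = (1ll << exponent) - 1;
      const std::int32_t remainder = x & mask;
      const std::int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
      return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
    }

    // Mirrors the expected-value formula in the test above for a single
    // int32 accumulator `raw`: left-shift for positive exponents happens
    // before the fixed-point multiply, right-shift for negative exponents
    // after it.
    std::int32_t ScaleByFixedPointAndExponent(std::int32_t raw,
                                              std::int32_t multiplier,
                                              int exponent,
                                              std::int32_t offset_after_shift) {
      const int left_shift = exponent > 0 ? exponent : 0;
      const int right_shift = exponent > 0 ? 0 : -exponent;
      return offset_after_shift +
             RoundingDivideByPOT(
                 SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
                                                   multiplier),
                 right_shift);
    }

The test checks exact equality of the output-stage result against this
formula, so the rounding behavior of both primitives is part of the contract,
not an implementation detail.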