From 123f384187504585be3fe01002381dd459c17d96 Mon Sep 17 00:00:00 2001 From: Lev Proleev Date: Fri, 26 Feb 2021 21:44:39 +0000 Subject: Update gemmlowp to 13d57703abca3005d97b19df1f2db731607a7dc2 An updated is needed after TF Lite rebase. Bug: 178609672 Test: mma, NeuralNetworksStatic_test Change-Id: Ia7f04fc5b6bd760549395854618d8b20f5c8d228 --- AUTHORS | 14 + AUTHORS.txt | 9 - Android.bp | 2 +- CONTRIBUTING | 53 + CONTRIBUTING.txt | 53 - CONTRIBUTORS | 40 + CONTRIBUTORS.txt | 22 - LICENSE | 202 +++ LICENSE.txt | 202 --- README.md | 276 +++ README.txt | 260 --- fixedpoint/fixedpoint.h | 30 +- fixedpoint/fixedpoint_avx.h | 168 +- fixedpoint/fixedpoint_sse.h | 52 +- fixedpoint/fixedpoint_wasmsimd.h | 381 ++++ flags.bzl | 7 +- internal/allocator.h | 4 +- internal/common.h | 2 +- internal/detect_platform.h | 5 + internal/dispatch_gemm_shape.h | 6 +- internal/kernel.h | 20 +- internal/output_sse.h | 21 + internal/pack.h | 24 +- internal/pack_sse.h | 10 +- internal/platform.h | 3 +- meta/generators/cc_emitter.py | 8 +- meta/generators/common.py | 3 +- meta/generators/neon_emitter.py | 8 +- meta/generators/neon_emitter_64.py | 8 +- public/bit_depth.h | 4 +- public/map.h | 6 +- standalone/cache_counters.cc | 404 +++++ standalone/encode.py | 134 ++ standalone/neon-gemm-kernel-benchmark.cc | 2796 +++++++++++++++++++----------- test/benchmark.cc | 11 +- test/benchmark_all_sizes.cc | 19 +- test/test.cc | 84 +- test/test.h | 6 +- test/test_blocking_counter.cc | 55 +- test/test_fixedpoint.cc | 952 +++++----- 40 files changed, 4274 insertions(+), 2090 deletions(-) create mode 100644 AUTHORS delete mode 100644 AUTHORS.txt create mode 100644 CONTRIBUTING delete mode 100644 CONTRIBUTING.txt create mode 100644 CONTRIBUTORS delete mode 100644 CONTRIBUTORS.txt create mode 100644 LICENSE delete mode 100644 LICENSE.txt create mode 100644 README.md delete mode 100644 README.txt create mode 100644 fixedpoint/fixedpoint_wasmsimd.h create mode 100644 standalone/cache_counters.cc create mode 100644 standalone/encode.py diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..996e104 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,14 @@ +# This is the official list of gemmlowp authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS.txt file. +# See the latter for an explanation. + +# Names should be added to this file as: +# Name or Organization +# The email address is not required for organizations. + +Google Inc. +Intel Corporation +ARM Ltd. +Silk Labs Inc. +MIPS Tech LLC +Wave Computing Inc. diff --git a/AUTHORS.txt b/AUTHORS.txt deleted file mode 100644 index 13a49e0..0000000 --- a/AUTHORS.txt +++ /dev/null @@ -1,9 +0,0 @@ -# This is the official list of gemmlowp authors for copyright purposes. -# This file is distinct from the CONTRIBUTORS.txt file. -# See the latter for an explanation. - -# Names should be added to this file as: -# Name or Organization -# The email address is not required for organizations. - -Google Inc. diff --git a/Android.bp b/Android.bp index 649324a..5efb5ff 100644 --- a/Android.bp +++ b/Android.bp @@ -30,7 +30,7 @@ license { "SPDX-license-identifier-Apache-2.0", ], license_text: [ - "LICENSE.txt", + "LICENSE", "NOTICE", ], } diff --git a/CONTRIBUTING b/CONTRIBUTING new file mode 100644 index 0000000..d6d63bc --- /dev/null +++ b/CONTRIBUTING @@ -0,0 +1,53 @@ +Want to contribute? Great! First, read this page (including the small print at the end). + + +Before you contribute +===================== + +Before we can use your code, you must sign the Google Individual Contributor +License Agreement (CLA), + + https://developers.google.com/open-source/cla/individual?csw=1 + +which you can do online. The CLA is necessary mainly because you own the +copyright to your changes, even after your contribution becomes part of our +codebase, so we need your permission to use and distribute your code. We also +need to be sure of various other things—for instance that you'll tell us if you +know that your code infringes on other people's patents. You don't have to sign +the CLA until after you've submitted your code for review and a member has +approved it, but you must do it before we can put your code into our codebase. +Before you start working on a larger contribution, you should get in touch with +us first through the issue tracker with your idea so that we can help out and +possibly guide you. Coordinating up front makes it much easier to avoid +frustration later on. + + +Getting in touch with the gemmlowp community +============================================ + +The central point of communication around gemmlowp is the mailing list, + https://groups.google.com/forum/#!forum/gemmlowp + + +TODO items and projects +======================= + +We try to keep a current list of TODO items in the todo/ directory. +Please feel free to pick one to work on, and to ask current maintainers for +guidance. The gemmlowp mailing list is a good place for that. + + +Code reviews +============ + +All submissions, including submissions by project members, require review. +For this purpose, we use Github pull requests against this repository: + + https://github.com/google/gemmlowp + + +The small print +=============== + +Contributions made by corporations are covered by a different agreement than +the one above, the Software Grant and Corporate Contributor License Agreement. diff --git a/CONTRIBUTING.txt b/CONTRIBUTING.txt deleted file mode 100644 index d6d63bc..0000000 --- a/CONTRIBUTING.txt +++ /dev/null @@ -1,53 +0,0 @@ -Want to contribute? Great! First, read this page (including the small print at the end). - - -Before you contribute -===================== - -Before we can use your code, you must sign the Google Individual Contributor -License Agreement (CLA), - - https://developers.google.com/open-source/cla/individual?csw=1 - -which you can do online. The CLA is necessary mainly because you own the -copyright to your changes, even after your contribution becomes part of our -codebase, so we need your permission to use and distribute your code. We also -need to be sure of various other things—for instance that you'll tell us if you -know that your code infringes on other people's patents. You don't have to sign -the CLA until after you've submitted your code for review and a member has -approved it, but you must do it before we can put your code into our codebase. -Before you start working on a larger contribution, you should get in touch with -us first through the issue tracker with your idea so that we can help out and -possibly guide you. Coordinating up front makes it much easier to avoid -frustration later on. - - -Getting in touch with the gemmlowp community -============================================ - -The central point of communication around gemmlowp is the mailing list, - https://groups.google.com/forum/#!forum/gemmlowp - - -TODO items and projects -======================= - -We try to keep a current list of TODO items in the todo/ directory. -Please feel free to pick one to work on, and to ask current maintainers for -guidance. The gemmlowp mailing list is a good place for that. - - -Code reviews -============ - -All submissions, including submissions by project members, require review. -For this purpose, we use Github pull requests against this repository: - - https://github.com/google/gemmlowp - - -The small print -=============== - -Contributions made by corporations are covered by a different agreement than -the one above, the Software Grant and Corporate Contributor License Agreement. diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 0000000..3740e0e --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,40 @@ +# People who have agreed to one of the CLAs and can contribute patches. +# The AUTHORS.txt file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS.txt, because Google holds the copyright. +# +# https://developers.google.com/open-source/cla/individual +# https://developers.google.com/open-source/cla/corporate +# +# Names should be added to this file as: +# Name + +Google: +Benoit Jacob +Pete Warden +Miao Wang +David Andersen +Maciek Chociej +Justine Tunney +Mark J. Matthews +Marie White +Suharsh Sivakumar + +Intel: +Sagi Marcovich +Murat Efe Guney +Sarah Knepper +Mourad Gouicem +Richard Winterton + +ARM: +David Mansell + +Silk Labs: +Andreas Gal + +MIPS Tech LLC: +Alexey Frunze + +Wave Computing Inc.: +Alexey Frunze diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt deleted file mode 100644 index 7c2415b..0000000 --- a/CONTRIBUTORS.txt +++ /dev/null @@ -1,22 +0,0 @@ -# People who have agreed to one of the CLAs and can contribute patches. -# The AUTHORS.txt file lists the copyright holders; this file -# lists people. For example, Google employees are listed here -# but not in AUTHORS.txt, because Google holds the copyright. -# -# https://developers.google.com/open-source/cla/individual -# https://developers.google.com/open-source/cla/corporate -# -# Names should be added to this file as: -# Name - -Google: -Benoit Jacob -Pete Warden -Miao Wang -David Andersen -Maciek Chociej - -Intel: -Sagi Marcovich -Murat Efe Guney -Sarah Knepper diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE.txt b/LICENSE.txt deleted file mode 100644 index d645695..0000000 --- a/LICENSE.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..22fabac --- /dev/null +++ b/README.md @@ -0,0 +1,276 @@ +# gemmlowp: a small self-contained low-precision GEMM library + +[![Build Status](https://secure.travis-ci.org/google/gemmlowp.png)](http://travis-ci.org/google/gemmlowp) + +This is not a full linear algebra library, only a GEMM library: it only does +general matrix multiplication ("GEMM"). + +The meaning of "low precision" is detailed in this document: +[doc/low-precision.md](doc/low-precision.md) + +Some of the general design is explained in [doc/design.md](doc/design.md). + +**Warning:** This library goes very slow if compiled incorrectly; see below. + +## Disclaimer + +This is not an official Google product (experimental or otherwise), it is just +code that happens to be owned by Google. + +## Mailing list + +gemmlowp-related discussion, about either development or usage, is welcome on +this Google Group (mailing list / forum): + +https://groups.google.com/forum/#!forum/gemmlowp + +## Portability, target platforms/architectures + +Should be portable to any platform with some C++11 and POSIX support, while we +have optional optimized code paths for specific architectures. + +Required: + +* C++11 (a small conservative subset of it) + +Required for some features: + +* Some POSIX interfaces: + * pthreads (for multi-threaded operation and for profiling). + * sysconf (for multi-threaded operation to detect number of cores; may be + bypassed). + +Optional: + +* Architecture-specific code paths use intrinsics or inline assembly. See + "Architecture-specific optimized code paths" below. + +## Architecture-specific optimized code paths + +We have some optimized code paths for specific instruction sets. Some are +written in inline assembly, some are written in C++ using intrinsics. Both GCC +and Clang are supported. + +Current optimized code paths: + +* ARM with NEON (both 32bit and 64bit). +* Intel x86 with SSE 4.1 (both 32bit and 64bit). + +When building for x86, it's very important to pass `-msse4.1` to the compiler, +otherwise gemmlowp will use slow reference code. Bazel users can compile by +running `bazel build --copt=-msse4.1 //gemmlowp:all`. The compiled binary should +work on all Intel CPUs since 2008 (including low power microarchitectures) as +well as AMD CPUs since 2011. + +Please note when compiling binaries that don't need to be distributed, it's +generally a better idea to pass `-march=native` to the compiler. That flag +implies `-msse4.1` flag, along with others that might be helpful. This of course +assumes the host machine supports those instructions. Bazel users should prefer +to run `bazel build --config=opt //gemmlowp:all` instead. + +Details of what it takes to make an efficient port of gemmlowp, namely writing a +suitable GEMM kernel and accompanying packing code, are explained in this file: +[doc/kernel.md](doc/kernel.md). + +## Public interfaces + +### The gemmlowp public interface + +gemmlowp's main public interface is in the `public/` subdirectory. + +This is a headers-only library, so there is nothing to link to. + +Usage documentation, and comments on the deprecation status of each public entry +point, may be found in [doc/public.md](doc/public.md) . + +A full, self-contained usage example, showing how to quantize float matrices and +perform a quantized matrix multiplication approximating a float matrix +multiplication, is given in +[doc/quantization_example.cc](doc/quantization_example.cc). + +### Old EightBitIntGemm legacy deprecated interface + +The `eight_bit_int_gemm/` subdirectory contains an alternate interface that +should be considered purely legacy, deprecated, and going to be removed at some +point in the future. + +## Building + +### Building by manually invoking your compiler + +Because gemmlowp is so simple, working with it involves only single-command-line +compiler invocations. Therefore we expect that most people working with gemmlowp +will either manually invoke their compiler, or write their own rules for their +own preferred build system. + +Keep in mind (previous section) that gemmlowp itself is a pure-headers-only +library so there is nothing to build. + +For a Android gemmlowp development workflow, the `scripts/` directory contains a +script to build and run a program on an Android device: + +``` +scripts/test-android.sh +``` + +### Building using Bazel + +That being said, we also maintain a Bazel BUILD system as part of gemmlowp. Its +usage is not mandatory at all and is only one possible way that gemmlowp +libraries and tests may be built. If you are interested, Bazel's home page is +http://bazel.build/ And you can get started with using Bazel to build gemmlowp +targets by first creating an empty WORKSPACE file in a parent directory, for +instance: + +``` +$ cd gemmlowp/.. # change to parent directory containing gemmlowp/ +$ touch WORKSPACE # declare that to be our workspace root +$ bazel build gemmlowp:all +``` + +## Testing + +### Testing by manually building and running tests + +The test/ directory contains unit tests. The primary unit test is + +``` +test/test.cc +``` + +Since it covers also the EightBitIntGemm interface, it needs to be linked +against + +``` +eight_bit_int_gemm/eight_bit_int_gemm.cc +``` + +It also uses realistic data captured from a neural network run in + +``` +test/test_data.cc +``` + +Thus you'll want to pass the following list of source files to your +compiler/linker: + +``` +test/test.cc +eight_bit_int_gemm/eight_bit_int_gemm.cc +test/test_data.cc +``` + +The `scripts/` directory contains a script to build and run a program on an +Android device: + +``` +scripts/test-android.sh +``` + +It expects the `CXX` environment variable to point to an Android toolchain's C++ +compiler, and expects source files (and optionally, cflags) as command-line +parameters. To build and run the above-mentioned main unit test, first set `CXX` +e.g.: + +``` +$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++ +``` + +Then run: + +``` +$ ./scripts/test-android.sh \ +test/test.cc \ +eight_bit_int_gemm/eight_bit_int_gemm.cc \ +test/test_data.cc +``` + +### Testing using Bazel + +Alternatively, you can use Bazel to build and run tests. See the Bazel +instruction in the above section on building. Once your Bazel workspace is set +up, you can for instance do: + +``` +$ bazel test gemmlowp:all +``` + +## Troubleshooting Compilation + +If you're having trouble finding the compiler, follow these instructions to +build a standalone toolchain: +https://developer.android.com/ndk/guides/standalone_toolchain.html + +Here's an example of setting up Clang 3.5: + +``` +$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu +$ $NDK/build/tools/make-standalone-toolchain.sh \ +--toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \ +--install-dir=$INSTALL_DIR +$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \ +--sysroot=$INSTALL_DIR/sysroot" +``` + +Some compilers (e.g. the default clang++ in the same bin directory) don't +support NEON assembly. The benchmark build process will issue a warning if +support isn't detected, and you should make sure you're using a compiler like +arm-linux-androideabi-g++ that does include NEON. + +## Benchmarking + +The main benchmark is + +``` +test/benchmark.cc +``` + +It doesn't need to be linked to any other source file. We recommend building +with assertions disabled (`-DNDEBUG`). + +For example, the benchmark can be built and run on an Android device by doing: + +``` +$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG +``` + +If `GEMMLOWP_TEST_PROFILE` is defined then the benchmark will be built with +profiling instrumentation (which makes it slower) and will dump profiles. See +next section on profiling. + +## Profiling + +The `profiling/` subdirectory offers a very simple, naive, inaccurate, +non-interrupting sampling profiler that only requires pthreads (no signals). + +It relies on source code being instrumented with pseudo-stack labels. See +`profiling/instrumentation.h`. A full example of using this profiler is given in +the top comment of `profiling/profiler.h`. + +## Contributing + +Contribution-related discussion is always welcome on the gemmlowp mailing list +(see above). + +We try to keep a current list of TODO items in the `todo/` directory. +Prospective contributors are welcome to pick one to work on, and communicate +about it on the gemmlowp mailing list. + +Details of the contributing process, including legalese, are in CONTRIBUTING. + +## Performance goals + +Our performance goals differ from typical GEMM performance goals in the +following ways: + +1. We care not only about speed, but also about minimizing power usage. We + specifically care about charge usage in mobile/embedded devices. This + implies that we care doubly about minimizing memory bandwidth usage: we care + about it, like any GEMM, because of the impact on speed, and we also care + about it because it is a key factor of power usage. + +2. Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000). + We do care about large sizes, but we also care specifically about the + typically smaller matrix sizes encountered in various mobile applications. + This means that we have to optimize for all sizes, not just for large enough + sizes. diff --git a/README.txt b/README.txt deleted file mode 100644 index e29f0e4..0000000 --- a/README.txt +++ /dev/null @@ -1,260 +0,0 @@ -gemmlowp: a small self-contained low-precision GEMM library -=========================================================== - -This is not a full linear algebra library, only a GEMM library: it only does -general matrix multiplication ("GEMM"). - -The meaning of "low precision" is detailed in this document: - doc/low-precision.txt - -Some of the general design is explained in - doc/design.txt - - -Disclaimer -========== - -This is not an official Google product (experimental or otherwise), it is just -code that happens to be owned by Google. - - -Mailing list -============ - -gemmlowp-related discussion, about either development or usage, is welcome -on this Google Group (mailing list / forum): - - https://groups.google.com/forum/#!forum/gemmlowp - - -Portability, target platforms/architectures -=========================================== - -Should be portable to any platform with some C++11 and POSIX support, -while we have optional optimized code paths for specific architectures. - -Required: - C++11 (a small conservative subset of it) - -Required for some features: - * Some POSIX interfaces: - * pthreads (for multi-threaded operation and for profiling). - * sysconf (for multi-threaded operation to detect number of cores; - may be bypassed). - -Optional: - Architecture-specific code paths use intrinsics or inline assembly. - See "Architecture-specific optimized code paths" below. - -Architecture-specific optimized code paths -========================================== - -We have some optimized code paths for specific instruction sets. -Some are written in inline assembly, some are written in C++ using -intrinsics. Both GCC and Clang are supported. - -At the moment, we have a full set of optimized code paths (kernels, -packing and unpacking paths) only for ARM NEON, supporting both -ARMv7 (32bit) and ARMv8 (64bit). - -We also have a partial set of optimized code paths (only kernels -at the moment) for Intel SSE. It supports both x86 and x86-64 but -only targets SSE4. The lack of packing/unpacking code paths means -that performance isn't optimal yet. - -Details of what it takes to make an efficient port of gemmlowp, namely -writing a suitable GEMM kernel and accompanying packing code, are -explained in this file: - doc/kernels.txt - - -Public interfaces -================= - -1. gemmlowp public interface ----------------------------- - - gemmlowp's main public interface is in the public/ subdirectory. The - header to include is - public/gemmlowp.h. - This is a headers-only library, so there is nothing to link to. - -2. EightBitIntGemm standard interface -------------------------------------- - - Additionally, the eight_bit_int_gemm/ subdirectory provides an - implementation of the standard EightBitIntGemm interface. The header - to include is - eight_bit_int_gemm/eight_bit_int_gemm.h - This is *NOT* a headers-only library, users need to link to - eight_bit_int_gemm/eight_bit_int_gemm.cc. - The API is similar to the standard BLAS GEMM interface, and implements - C = A * B. If the transpose flags for a matrix argument are false, its memory - order is treated as column major, and row major if its true. - - -Building -======== - -Building by manually invoking your compiler -------------------------------------------- - -Because gemmlowp is so simple, working with it involves only -single-command-line compiler invokations. Therefore we expect that -most people working with gemmlowp will either manually invoke their -compiler, or write their own rules for their own preferred build -system. - -Keep in mind (previous section) that gemmlowp itself is a pure-headers-only -library so there is nothing to build, and the eight_bit_int_gemm library -consists of a single eight_bit_int_gemm.cc file to build. - -For a Android gemmlowp development workflow, the scripts/ directory -contains a script to build and run a program on an Android device: - scripts/test-android.sh - -Building using Bazel --------------------- - -That being said, we also maintain a Bazel BUILD system as part of -gemmlowp. Its usage is not mandatory at all and is only one -possible way that gemmlowp libraries and tests may be built. If -you are interested, Bazel's home page is - http://bazel.io/ -And you can get started with using Bazel to build gemmlowp targets -by first creating an empty WORKSPACE file in a parent directory, -for instance: - -$ cd gemmlowp/.. # change to parent directory containing gemmlowp/ -$ touch WORKSPACE # declare that to be our workspace root -$ bazel build gemmlowp:all - - -Testing -======= - -Testing by manually building and running tests ----------------------------------------------- - -The test/ directory contains unit tests. The primary unit test is - test/test.cc -Since it covers also the EightBitIntGemm interface, it needs to be -linked against - eight_bit_int_gemm/eight_bit_int_gemm.cc -It also uses realistic data captured from a neural network run in - test/test_data.cc - -Thus you'll want to pass the following list of source files to your -compiler/linker: - test/test.cc - eight_bit_int_gemm/eight_bit_int_gemm.cc - test/test_data.cc - -The scripts/ directory contains a script to build and run a program -on an Android device: - scripts/test-android.sh - -It expects the CXX environment variable to point to an Android toolchain's -C++ compiler, and expects source files (and optionally, cflags) as -command-line parameters. To build and run the above-mentioned main unit test, -first set CXX e.g.: - -$ export CXX=/some/toolchains/arm-linux-androideabi-4.8/bin/arm-linux-androideabi-g++ - -Then run: - -$ ./scripts/test-android.sh \ -test/test.cc \ -eight_bit_int_gemm/eight_bit_int_gemm.cc \ -test/test_data.cc - - -Testing using Bazel -------------------- - -Alternatively, you can use Bazel to build and run tests. See the Bazel -instruction in the above section on building. Once your Bazel workspace -is set up, you can for instance do: - -$ bazel test gemmlowp:all - - -Troubleshooting Compilation -=========================== - -If you're having trouble finding the compiler, follow these instructions to -build a standalone toolchain: -https://developer.android.com/ndk/guides/standalone_toolchain.html - -Here's an example of setting up Clang 3.5: - -$ export INSTALL_DIR=~/toolchains/clang-21-stl-gnu -$ $NDK/build/tools/make-standalone-toolchain.sh \ ---toolchain=arm-linux-androideabi-clang3.5 --platform=android-21 \ ---install-dir=$INSTALL_DIR -$ export CXX="$INSTALL_DIR/bin/arm-linux-androideabi-g++ \ ---sysroot=$INSTALL_DIR/sysroot" - -Some compilers (e.g. the default clang++ in the same bin directory) don't -support NEON assembly. The benchmark build process will issue a warning if -support isn't detected, and you should make sure you're using a compiler like -arm-linux-androideabi-g++ that does include NEON. - - -Benchmarking -============ - -The main benchmark is - benchmark.cc -It doesn't need to be linked to any -other source file. We recommend building with assertions disabled (-DNDEBUG). - -For example, the benchmark can be built and run on an Android device by doing: - -$ ./scripts/test-android.sh test/benchmark.cc -DNDEBUG - -If GEMMLOWP_TEST_PROFILE is defined then the benchmark will be built with -profiling instrumentation (which makes it slower) and will dump profiles. -See next section on profiling. - - -Profiling -========= - -The profiling/ subdirectory offers a very simple non-interrupting sampling -profiler that only requires pthreads (no signals). - -It relies on source code being instrumented with pseudo-stack labels. -See profiling/instrumentation.h. -A full example of using this profiler is given in profiling/profiler.h. - - -Contributing -============ - -Contribution-related discussion is always welcome on the gemmlowp -mailing list (see above). - -We try to keep a current list of TODO items in the todo/ directory. -Prospective contributors are welcome to pick one to work on, and -communicate about it on the gemmlowp mailing list. - -Details of the contributing process, including legalese, are in CONTRIBUTING. - -Performance goals -================= - -Our performance goals differ from typical GEMM performance goals in the -following ways: - -1. We care not only about speed, but also about minimizing power usage. - We specifically care about charge usage in mobile/embedded devices. - This implies that we care doubly about minimizing memory bandwidth usage: - we care about it, like any GEMM, because of the impact on speed, and we - also care about it because it is a key factor of power usage. - -2. Most GEMMs are optimized primarily for large dense matrix sizes (>= 1000). - We do care about large sizes, but we also care specifically about the - typically smaller matrix sizes encountered in various mobile applications. - This means that we have to optimize for all sizes, not just for large enough - sizes. diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h index 58e8050..56e95c0 100644 --- a/fixedpoint/fixedpoint.h +++ b/fixedpoint/fixedpoint.h @@ -95,12 +95,13 @@ tIntegerType Add(tIntegerType a, tIntegerType b) { return a + b; } -// Integer subtraction. Not saturating. Overflow is undefined behavior. +// Integer multiplication. Not saturating. Overflow is undefined behavior. template tIntegerType Mul(tIntegerType a, tIntegerType b) { return a * b; } +// Integer subtraction. Not saturating. Overflow is undefined behavior. template tIntegerType Sub(tIntegerType a, tIntegerType b) { return a - b; @@ -268,6 +269,16 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { std::max(static_cast(-32768), sum))); } +template <> +inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { + std::int16_t a16 = a; + std::int16_t b16 = b; + std::int16_t sum = a16 + b16; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max(static_cast(std::numeric_limits::min()), sum))); +} + // Returns a+b, saturating if the integers are 16bit or narrower, // otherwise just a plain addition. template @@ -767,13 +778,14 @@ FixedPoint exp_on_negative_values( result * kMultiplier, result); \ } - GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); - GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); - GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); - GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); - GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); - GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); - GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); + // Constants below are Q0 representations of negative exp fractionals: + GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); // exp(-1/4) + GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); // exp(-1/2) + GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); // exp(-1) + GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); // exp(-2) + GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); // exp(-4) + GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); // exp(-8) + GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); // exp(-16) #undef GEMMLOWP_EXP_BARREL_SHIFTER @@ -895,6 +907,8 @@ FixedPoint logistic(FixedPoint a) { #include "./fixedpoint_sse.h" #elif defined(GEMMLOWP_MSA) #include "./fixedpoint_msa.h" +#elif defined(GEMMLOWP_WASMSIMD) +#include "./fixedpoint_wasmsimd.h" #endif #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h index 1816386..f3fe732 100644 --- a/fixedpoint/fixedpoint_avx.h +++ b/fixedpoint/fixedpoint_avx.h @@ -17,68 +17,138 @@ #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ #define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ -#include +#include #include "fixedpoint.h" #include "fixedpoint_sse.h" namespace gemmlowp { +struct int16x16_m256i { + __m256i v; +}; + +// Keep int16x16_m256i trivially constructible/destructible and provide +// easily optimized helper function. +inline int16x16_m256i to_int16x16_m256i(__m256i w) { + int16x16_m256i r; + r.v = w; + return r; +} + template <> struct FixedPointRawTypeTraits<__m256i> { typedef std::int32_t ScalarRawType; + // TODO: This can actually support up to 8 lanes, so we should either + // change to 8 or create int32x8_m256i struct to handle that case. static const int kLanes = 4; }; +template <> +struct FixedPointRawTypeTraits { + typedef std::int16_t ScalarRawType; + static const int kLanes = 16; +}; + template <> inline __m256i BitAnd(__m256i a, __m256i b) { return _mm256_and_si256(a, b); } +template <> +inline int16x16_m256i BitAnd(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_and_si256(a.v, b.v)); +} + template <> inline __m256i BitOr(__m256i a, __m256i b) { return _mm256_or_si256(a, b); } +template <> +inline int16x16_m256i BitOr(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_or_si256(a.v, b.v)); +} + template <> inline __m256i BitXor(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); } +template <> +inline int16x16_m256i BitXor(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_xor_si256(a.v, b.v)); +} + template <> inline __m256i BitNot(__m256i a) { return _mm256_andnot_si256(a, _mm256_set1_epi32(-1)); } +template <> +inline int16x16_m256i BitNot(int16x16_m256i a) { + return to_int16x16_m256i(_mm256_andnot_si256(a.v, _mm256_set1_epi16(-1))); +} + template <> inline __m256i Add(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); } +template <> +inline int16x16_m256i Add(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_add_epi16(a.v, b.v)); +} + template <> inline __m256i Mul(__m256i a, __m256i b) { return _mm256_mullo_epi32(a, b); } +template <> +inline int16x16_m256i Mul(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_mullo_epi16(a.v, b.v)); +} + template <> inline __m256i Sub(__m256i a, __m256i b) { return _mm256_sub_epi32(a, b); } +template <> +inline int16x16_m256i Sub(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_sub_epi16(a.v, b.v)); +} + template <> inline __m256i Neg(__m256i a) { return _mm256_sign_epi32(a, _mm256_set1_epi32(-1)); } +template <> +inline int16x16_m256i Neg(int16x16_m256i a) { + return to_int16x16_m256i(_mm256_sign_epi16(a.v, _mm256_set1_epi16(-1))); +} + template <> inline __m256i ShiftLeft(__m256i a, int offset) { return _mm256_slli_epi32(a, offset); } +template <> +inline int16x16_m256i ShiftLeft(int16x16_m256i a, int offset) { + return to_int16x16_m256i(_mm256_slli_epi16(a.v, offset)); +} + template <> inline __m256i ShiftRight(__m256i a, int offset) { return _mm256_srai_epi32(a, offset); } +template <> +inline int16x16_m256i ShiftRight(int16x16_m256i a, int offset) { + return to_int16x16_m256i(_mm256_srai_epi16(a.v, offset)); +} + template <> inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val, __m256i else_val) { @@ -87,46 +157,98 @@ inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val, _mm256_castsi256_ps(if_mask))); } +template <> +inline int16x16_m256i SelectUsingMask(int16x16_m256i if_mask, + int16x16_m256i then_val, + int16x16_m256i else_val) { + // Borrowed from Intel's arm_neon_sse.h header. + return to_int16x16_m256i( + _mm256_or_si256(_mm256_and_si256(if_mask.v, then_val.v), + _mm256_andnot_si256(if_mask.v, else_val.v))); +} + template <> inline __m256i MaskIfEqual(__m256i a, __m256i b) { return _mm256_cmpeq_epi32(a, b); } +template <> +inline int16x16_m256i MaskIfEqual(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_cmpeq_epi16(a.v, b.v)); +} + template <> inline __m256i MaskIfNotEqual(__m256i a, __m256i b) { return BitNot(MaskIfEqual(a, b)); } +template <> +inline int16x16_m256i MaskIfNotEqual(int16x16_m256i a, int16x16_m256i b) { + return BitNot(MaskIfEqual(a, b)); +} + template <> inline __m256i MaskIfZero(__m256i a) { return MaskIfEqual(a, _mm256_set1_epi32(0)); } +template <> +inline int16x16_m256i MaskIfZero(int16x16_m256i a) { + return MaskIfEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0))); +} + template <> inline __m256i MaskIfNonZero(__m256i a) { return MaskIfNotEqual(a, _mm256_set1_epi32(0)); } +template <> +inline int16x16_m256i MaskIfNonZero(int16x16_m256i a) { + return MaskIfNotEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0))); +} + template <> inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) { return _mm256_cmpgt_epi32(a, b); } +template <> +inline int16x16_m256i MaskIfGreaterThan(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_cmpgt_epi16(a.v, b.v)); +} + template <> inline __m256i MaskIfLessThan(__m256i a, __m256i b) { return _mm256_cmpgt_epi32(b, a); } +template <> +inline int16x16_m256i MaskIfLessThan(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_cmpgt_epi16(b.v, a.v)); +} + template <> inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) { return BitNot(MaskIfLessThan(a, b)); } +template <> +inline int16x16_m256i MaskIfGreaterThanOrEqual(int16x16_m256i a, + int16x16_m256i b) { + return BitNot(MaskIfLessThan(a, b)); +} + template <> inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) { return BitNot(MaskIfGreaterThan(a, b)); } +template <> +inline int16x16_m256i MaskIfLessThanOrEqual(int16x16_m256i a, + int16x16_m256i b) { + return BitNot(MaskIfGreaterThan(a, b)); +} + /* Assumptions: - All and Any are used on masks. - masks are all_ones for true lanes, all_zeroes otherwise. @@ -138,11 +260,21 @@ inline bool All(__m256i a) { return _mm256_testc_si256(a, a); } +template <> +inline bool All(int16x16_m256i a) { + return _mm256_testc_si256(a.v, a.v); +} + template <> inline bool Any(__m256i a) { return BitNot(_mm256_testz_si256(a, a)); } +template <> +inline bool Any(int16x16_m256i a) { + return BitNot(_mm256_testz_si256(a.v, a.v)); +} + template <> inline __m256i RoundingHalfSum(__m256i a, __m256i b) { /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */ @@ -170,6 +302,17 @@ inline __m256i RoundingHalfSum(__m256i a, __m256i b) { return result; } +template <> +inline int16x16_m256i RoundingHalfSum(int16x16_m256i a, int16x16_m256i b) { + // Borrowed from Intel's arm_neon_sse.h header. + __m256i constant_neg_32768 = _mm256_set1_epi16(-32768); + __m256i a_unsigned = _mm256_sub_epi16(a.v, constant_neg_32768); + __m256i b_unsigned = _mm256_sub_epi16(b.v, constant_neg_32768); + __m256i avg_unsigned = _mm256_avg_epu16(a_unsigned, b_unsigned); + __m256i avg = _mm256_add_epi16(avg_unsigned, constant_neg_32768); + return to_int16x16_m256i(avg); +} + template <> inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) { __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3; @@ -208,11 +351,34 @@ inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) { return SelectUsingMask(saturation_mask, min, result); } +template <> +inline int16x16_m256i SaturatingRoundingDoublingHighMul(int16x16_m256i a, + int16x16_m256i b) { + // Use _mm256_mulhrs_epi16 then saturate with a bit-operation, + // borrowed from Intel's arm_neon_sse.h header. + __m256i result_unsaturated = _mm256_mulhrs_epi16(a.v, b.v); + __m256i saturation_mask = + _mm256_cmpeq_epi16(result_unsaturated, _mm256_set1_epi16(0x8000)); + __m256i result = _mm256_xor_si256(result_unsaturated, saturation_mask); + return to_int16x16_m256i(result); +} + template <> inline __m256i Dup<__m256i>(std::int32_t x) { return _mm256_set1_epi32(x); } +template <> +inline int16x16_m256i Dup(std::int16_t x) { + return to_int16x16_m256i(_mm256_set1_epi16(x)); +} + +// So far this is only needed for int16. +template <> +inline int16x16_m256i SaturatingAdd(int16x16_m256i a, int16x16_m256i b) { + return to_int16x16_m256i(_mm256_adds_epi16(a.v, b.v)); +} + } // end namespace gemmlowp #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h index a1fae32..fbaa26a 100644 --- a/fixedpoint/fixedpoint_sse.h +++ b/fixedpoint/fixedpoint_sse.h @@ -32,13 +32,17 @@ namespace gemmlowp { // data type, int16x8_m128i, that wraps __m128i while being a separate // type. struct int16x8_m128i { - int16x8_m128i() {} - explicit int16x8_m128i(__m128i w) : v(w) {} - ~int16x8_m128i() {} - __m128i v; }; +// Keep int16x8_m128i trivially constructible/destructible and provide +// easily optimized helper function. +inline int16x8_m128i to_int16x8_m128i(__m128i w) { + int16x8_m128i r; + r.v = w; + return r; +} + template <> struct FixedPointRawTypeTraits<__m128i> { typedef std::int32_t ScalarRawType; @@ -58,7 +62,7 @@ inline __m128i BitAnd(__m128i a, __m128i b) { template <> inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_and_si128(a.v, b.v)); + return to_int16x8_m128i(_mm_and_si128(a.v, b.v)); } template <> @@ -68,7 +72,7 @@ inline __m128i BitOr(__m128i a, __m128i b) { template <> inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_or_si128(a.v, b.v)); + return to_int16x8_m128i(_mm_or_si128(a.v, b.v)); } template <> @@ -78,7 +82,7 @@ inline __m128i BitXor(__m128i a, __m128i b) { template <> inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_xor_si128(a.v, b.v)); + return to_int16x8_m128i(_mm_xor_si128(a.v, b.v)); } template <> @@ -88,7 +92,7 @@ inline __m128i BitNot(__m128i a) { template <> inline int16x8_m128i BitNot(int16x8_m128i a) { - return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1))); + return to_int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1))); } template <> @@ -98,7 +102,7 @@ inline __m128i Add(__m128i a, __m128i b) { template <> inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_add_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_add_epi16(a.v, b.v)); } template <> @@ -108,7 +112,7 @@ inline __m128i Mul(__m128i a, __m128i b) { template <> inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_mullo_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_mullo_epi16(a.v, b.v)); } template <> @@ -118,7 +122,7 @@ inline __m128i Sub(__m128i a, __m128i b) { template <> inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_sub_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_sub_epi16(a.v, b.v)); } template <> @@ -128,7 +132,7 @@ inline __m128i Neg(__m128i a) { template <> inline int16x8_m128i Neg(int16x8_m128i a) { - return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1))); + return to_int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1))); } template <> @@ -138,7 +142,7 @@ inline __m128i ShiftLeft(__m128i a, int offset) { template <> inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) { - return int16x8_m128i(_mm_slli_epi16(a.v, offset)); + return to_int16x8_m128i(_mm_slli_epi16(a.v, offset)); } template <> @@ -148,7 +152,7 @@ inline __m128i ShiftRight(__m128i a, int offset) { template <> inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) { - return int16x8_m128i(_mm_srai_epi16(a.v, offset)); + return to_int16x8_m128i(_mm_srai_epi16(a.v, offset)); } template <> @@ -164,7 +168,7 @@ inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask, int16x8_m128i then_val, int16x8_m128i else_val) { // borrowed from Intel's arm_neon_sse.h header. - return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v)); + return to_int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v)); } template <> @@ -174,7 +178,7 @@ inline __m128i MaskIfEqual(__m128i a, __m128i b) { template <> inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v)); } template <> @@ -194,7 +198,7 @@ inline __m128i MaskIfZero(__m128i a) { template <> inline int16x8_m128i MaskIfZero(int16x8_m128i a) { - return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0))); + return MaskIfEqual(a, to_int16x8_m128i(_mm_set1_epi16(0))); } template <> @@ -204,7 +208,7 @@ inline __m128i MaskIfNonZero(__m128i a) { template <> inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) { - return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0))); + return MaskIfNotEqual(a, to_int16x8_m128i(_mm_set1_epi16(0))); } template <> @@ -214,7 +218,7 @@ inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) { template <> inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v)); } template <> @@ -224,7 +228,7 @@ inline __m128i MaskIfLessThan(__m128i a, __m128i b) { template <> inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_cmplt_epi16(a.v, b.v)); } template <> @@ -310,7 +314,7 @@ inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) { __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768); __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned); __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768); - return int16x8_m128i(avg); + return to_int16x8_m128i(avg); } template <> @@ -360,7 +364,7 @@ inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a, __m128i saturation_mask = _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000)); __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask); - return int16x8_m128i(result); + return to_int16x8_m128i(result); } template <> @@ -370,13 +374,13 @@ inline __m128i Dup<__m128i>(std::int32_t x) { template <> inline int16x8_m128i Dup(std::int16_t x) { - return int16x8_m128i(_mm_set1_epi16(x)); + return to_int16x8_m128i(_mm_set1_epi16(x)); } // So far this is only needed for int16. template <> inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) { - return int16x8_m128i(_mm_adds_epi16(a.v, b.v)); + return to_int16x8_m128i(_mm_adds_epi16(a.v, b.v)); } } // end namespace gemmlowp diff --git a/fixedpoint/fixedpoint_wasmsimd.h b/fixedpoint/fixedpoint_wasmsimd.h new file mode 100644 index 0000000..868fbfe --- /dev/null +++ b/fixedpoint/fixedpoint_wasmsimd.h @@ -0,0 +1,381 @@ +// Copyright 2020 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// fixedpoint_wasmsimd.h: optimized WAsm SIMD specializations of the templates +// in fixedpoint.h. + +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_ + +#include + +namespace gemmlowp { + +// WAsm SIMD intrinsics are not typed: there is a single v128_t vector +// type that does not distinguish between "int32x4" and "int16x8" use +// cases, unlike the NEON equivalents. Because we had initially focused +// on int32x4, we did not pay attention and specialized these fixedpoint +// templates directly for v128_t hardcoding the int32x4 semantics, +// not leaving room for int16x8 semantics. Amending that by adding a separate +// data type, int16x8_v128_t, that wraps v128_t while being a separate +// type. +struct int16x8_v128_t { + v128_t v; +}; + +// Keep int16x8_v128_t trivially constructible/destructible and provide +// easily optimized helper function. +inline int16x8_v128_t to_int16x8_v128_t(v128_t w) { + int16x8_v128_t r; + r.v = w; + return r; +} + +template <> +struct FixedPointRawTypeTraits { + typedef std::int32_t ScalarRawType; + static constexpr int kLanes = 4; +}; + +template <> +struct FixedPointRawTypeTraits { + typedef std::int16_t ScalarRawType; + static constexpr int kLanes = 8; +}; + +template <> +inline v128_t BitAnd(v128_t a, v128_t b) { + return wasm_v128_and(a, b); +} + +template <> +inline int16x8_v128_t BitAnd(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_v128_and(a.v, b.v)); +} + +template <> +inline v128_t BitOr(v128_t a, v128_t b) { + return wasm_v128_or(a, b); +} + +template <> +inline int16x8_v128_t BitOr(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_v128_or(a.v, b.v)); +} + +template <> +inline v128_t BitXor(v128_t a, v128_t b) { + return wasm_v128_xor(a, b); +} + +template <> +inline int16x8_v128_t BitXor(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_v128_xor(a.v, b.v)); +} + +template <> +inline v128_t BitNot(v128_t a) { + return wasm_v128_not(a); +} + +template <> +inline int16x8_v128_t BitNot(int16x8_v128_t a) { + return to_int16x8_v128_t(wasm_v128_not(a.v)); +} + +template <> +inline v128_t Add(v128_t a, v128_t b) { + return wasm_i32x4_add(a, b); +} + +template <> +inline int16x8_v128_t Add(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_add(a.v, b.v)); +} + +template <> +inline v128_t Mul(v128_t a, v128_t b) { + return wasm_i32x4_mul(a, b); +} + +template <> +inline int16x8_v128_t Mul(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_mul(a.v, b.v)); +} + +template <> +inline v128_t Sub(v128_t a, v128_t b) { + return wasm_i32x4_sub(a, b); +} + +template <> +inline int16x8_v128_t Sub(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_sub(a.v, b.v)); +} + +template <> +inline v128_t Neg(v128_t a) { + return wasm_i32x4_neg(a); +} + +template <> +inline int16x8_v128_t Neg(int16x8_v128_t a) { + return to_int16x8_v128_t(wasm_i16x8_neg(a.v)); +} + +template <> +inline v128_t ShiftLeft(v128_t a, int offset) { + return wasm_i32x4_shl(a, offset); +} + +template <> +inline int16x8_v128_t ShiftLeft(int16x8_v128_t a, int offset) { + return to_int16x8_v128_t(wasm_i16x8_shl(a.v, offset)); +} + +template <> +inline v128_t ShiftRight(v128_t a, int offset) { + return wasm_i32x4_shr(a, offset); +} + +template <> +inline int16x8_v128_t ShiftRight(int16x8_v128_t a, int offset) { + return to_int16x8_v128_t(wasm_i16x8_shr(a.v, offset)); +} + +template <> +inline v128_t SelectUsingMask(v128_t if_mask, v128_t then_val, + v128_t else_val) { + return wasm_v128_bitselect(then_val, else_val, if_mask); +} + +template <> +inline int16x8_v128_t SelectUsingMask(int16x8_v128_t if_mask, + int16x8_v128_t then_val, + int16x8_v128_t else_val) { + return to_int16x8_v128_t( + wasm_v128_bitselect(then_val.v, else_val.v, if_mask.v)); +} + +template <> +inline v128_t MaskIfEqual(v128_t a, v128_t b) { + return wasm_i32x4_eq(a, b); +} + +template <> +inline int16x8_v128_t MaskIfEqual(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_eq(a.v, b.v)); +} + +template <> +inline v128_t MaskIfNotEqual(v128_t a, v128_t b) { + return wasm_i32x4_ne(a, b); +} + +template <> +inline int16x8_v128_t MaskIfNotEqual(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_ne(a.v, b.v)); +} + +template <> +inline v128_t MaskIfZero(v128_t a) { + return MaskIfEqual(a, wasm_i32x4_const(0, 0, 0, 0)); +} + +template <> +inline int16x8_v128_t MaskIfZero(int16x8_v128_t a) { + return MaskIfEqual( + a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0))); +} + +template <> +inline v128_t MaskIfNonZero(v128_t a) { + return MaskIfNotEqual(a, wasm_i32x4_const(0, 0, 0, 0)); +} + +template <> +inline int16x8_v128_t MaskIfNonZero(int16x8_v128_t a) { + return MaskIfNotEqual( + a, to_int16x8_v128_t(wasm_i16x8_const(0, 0, 0, 0, 0, 0, 0, 0))); +} + +template <> +inline v128_t MaskIfGreaterThan(v128_t a, v128_t b) { + return wasm_i32x4_gt(a, b); +} + +template <> +inline int16x8_v128_t MaskIfGreaterThan(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_gt(a.v, b.v)); +} + +template <> +inline v128_t MaskIfLessThan(v128_t a, v128_t b) { + return wasm_i32x4_lt(a, b); +} + +template <> +inline int16x8_v128_t MaskIfLessThan(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_lt(a.v, b.v)); +} + +template <> +inline v128_t MaskIfGreaterThanOrEqual(v128_t a, v128_t b) { + return wasm_i32x4_ge(a, b); +} + +template <> +inline int16x8_v128_t MaskIfGreaterThanOrEqual(int16x8_v128_t a, + int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_ge(a.v, b.v)); +} + +template <> +inline v128_t MaskIfLessThanOrEqual(v128_t a, v128_t b) { + return wasm_i32x4_le(a, b); +} + +template <> +inline int16x8_v128_t MaskIfLessThanOrEqual(int16x8_v128_t a, + int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_le(a.v, b.v)); +} + +/* Assumptions: + - All and Any are used on masks. + - masks are all_ones for true lanes, all_zeroes otherwise. +Hence, All means all 128bits set, and Any means any bit set. +*/ + +template <> +inline bool All(v128_t a) { + return wasm_i32x4_all_true(a); +} + +template <> +inline bool All(int16x8_v128_t a) { + return wasm_i16x8_all_true(a.v); +} + +template <> +inline bool Any(v128_t a) { + return wasm_i32x4_any_true(a); +} + +template <> +inline bool Any(int16x8_v128_t a) { + return wasm_i16x8_any_true(a.v); +} + +template <> +inline v128_t RoundingHalfSum(v128_t a, v128_t b) { + // We divide the inputs before the add to avoid the overflow and costly test. + const v128_t one = wasm_i32x4_const(1, 1, 1, 1); + const v128_t sign_bit_mask = + wasm_i32x4_const(0x80000000, 0x80000000, 0x80000000, 0x80000000); + const v128_t sum = Add(a, b); + const v128_t rounded_half_sum = ShiftRight(Add(sum, one), 1); + const v128_t overflow = + BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)), + sign_bit_mask); + const v128_t result = BitXor(rounded_half_sum, overflow); + return result; +} + +template <> +inline int16x8_v128_t RoundingHalfSum(int16x8_v128_t a, int16x8_v128_t b) { + // Idea: go to unsigned to use wasm_u16x8_avgr, + // borrowed from Intel's arm_neon_sse.h header. + const v128_t constant_neg_32768 = wasm_i16x8_const( + -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768); + const v128_t a_unsigned = wasm_v128_xor(a.v, constant_neg_32768); + const v128_t b_unsigned = wasm_v128_xor(b.v, constant_neg_32768); + const v128_t avg_unsigned = wasm_u16x8_avgr(a_unsigned, b_unsigned); + const v128_t avg = wasm_v128_xor(avg_unsigned, constant_neg_32768); + return to_int16x8_v128_t(avg); +} + +template <> +inline v128_t SaturatingRoundingDoublingHighMul(v128_t a, v128_t b) { + // TODO: switch to extended multiplication once implemented in the toolchain + const v128_t a_sign = wasm_i32x4_shr(a, 31); + const v128_t b_sign = wasm_i32x4_shr(b, 31); + + const v128_t a_ext_lo = wasm_v32x4_shuffle(a, a_sign, 0, 4, 1, 5); + const v128_t a_ext_hi = wasm_v32x4_shuffle(a, a_sign, 2, 6, 3, 7); + const v128_t b_ext_lo = wasm_v32x4_shuffle(b, b_sign, 0, 4, 1, 5); + const v128_t b_ext_hi = wasm_v32x4_shuffle(b, b_sign, 2, 6, 3, 7); + + const v128_t ab_lo = wasm_i64x2_mul(a_ext_lo, b_ext_lo); + const v128_t ab_hi = wasm_i64x2_mul(a_ext_hi, b_ext_hi); + + const v128_t nudge_2x = + wasm_i64x2_const(INT64_C(0x80000000), INT64_C(0x80000000)); + const v128_t ab_lo_2x = wasm_i64x2_add(ab_lo, ab_lo); + const v128_t ab_hi_2x = wasm_i64x2_add(ab_hi, ab_hi); + + const v128_t ab_lo_rounded_2x = wasm_i64x2_add(ab_lo_2x, nudge_2x); + const v128_t ab_hi_rounded_2x = wasm_i64x2_add(ab_hi_2x, nudge_2x); + + const v128_t prod = + wasm_v32x4_shuffle(ab_lo_rounded_2x, ab_hi_rounded_2x, 1, 3, 5, 7); + + // Saturation only happen if a == b == INT_MIN, and this is the only case + // where prod == INT_MIN (0x80000000) instead of INT_MAX (0x7FFFFFFF). + const v128_t min = wasm_i32x4_const(INT32_C(0x80000000), INT32_C(0x80000000), + INT32_C(0x80000000), INT32_C(0x80000000)); + + return wasm_v128_xor(prod, wasm_i32x4_eq(prod, min)); +} + +template <> +inline int16x8_v128_t SaturatingRoundingDoublingHighMul(int16x8_v128_t a, + int16x8_v128_t b) { +#if 0 + // TODO: enable if https://github.com/WebAssembly/simd/pull/365 is accepted + return to_int16x8_v128_t(__builtin_wasm_q15mulr_saturate_s_i16x8(a.v, b.v)); +#else + // TODO: switch to extended multiplication once implemented in the toolchain + v128_t lo = wasm_i32x4_mul(wasm_i32x4_widen_low_i16x8(a.v), + wasm_i32x4_widen_low_i16x8(b.v)); + v128_t hi = wasm_i32x4_mul(wasm_i32x4_widen_high_i16x8(a.v), + wasm_i32x4_widen_high_i16x8(b.v)); + const v128_t inc = wasm_i32x4_const(0x4000, 0x4000, 0x4000, 0x4000); + lo = wasm_i32x4_add(lo, inc); + hi = wasm_i32x4_add(hi, inc); + lo = wasm_i32x4_shr(lo, 15); + hi = wasm_i32x4_shr(hi, 15); + return to_int16x8_v128_t(wasm_i16x8_narrow_i32x4(lo, hi)); +#endif +} + +template <> +inline v128_t Dup(std::int32_t x) { + return wasm_i32x4_splat(x); +} + +template <> +inline int16x8_v128_t Dup(std::int16_t x) { + return to_int16x8_v128_t(wasm_i16x8_splat(x)); +} + +// So far this is only needed for int16. +template <> +inline int16x8_v128_t SaturatingAdd(int16x8_v128_t a, int16x8_v128_t b) { + return to_int16x8_v128_t(wasm_i16x8_add_saturate(a.v, b.v)); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_WASMSIMD_H_ diff --git a/flags.bzl b/flags.bzl index 16dba2d..e35fe9e 100644 --- a/flags.bzl +++ b/flags.bzl @@ -3,10 +3,9 @@ LIB_COPTS = [] LIB_LINKOPTS = select({ ":android": [], + ":windows": [], "//conditions:default": ["-lpthread"], }) -BIN_LINKOPTS = select({ - ":android": [], - "//conditions:default": ["-lpthread"], -}) +BIN_LINKOPTS = LIB_LINKOPTS + diff --git a/internal/allocator.h b/internal/allocator.h index 3a6f077..e71df15 100644 --- a/internal/allocator.h +++ b/internal/allocator.h @@ -86,11 +86,11 @@ class Allocator { } // Alignment of allocated blocks. - static const std::size_t kAlignment = kDefaultCacheLineSize; + static constexpr std::size_t kAlignment = kDefaultCacheLineSize; // This is all we need so far, and since the usage pattern is fixed, // there is no point in allowing more until we need to. - static const std::size_t kMaxBlocks = 5; + static constexpr std::size_t kMaxBlocks = 5; void Commit() { assert(!committed_); diff --git a/internal/common.h b/internal/common.h index 332ad07..708cc40 100644 --- a/internal/common.h +++ b/internal/common.h @@ -165,7 +165,7 @@ Integer RoundUpToPowerOfTwo(Integer n) { template struct IsPowerOfTwo { - static const bool value = !(N & (N - 1)); + static constexpr bool value = !(N & (N - 1)); }; template diff --git a/internal/detect_platform.h b/internal/detect_platform.h index 6f06d19..7f0d78c 100644 --- a/internal/detect_platform.h +++ b/internal/detect_platform.h @@ -71,6 +71,11 @@ #define GEMMLOWP_X86 #endif +// Detect WebAssembly SIMD. +#if defined(__wasm_simd128__) +#define GEMMLOWP_WASMSIMD +#endif + // Some of our optimized paths use inline assembly and for // now we don't bother enabling some other optimized paths using intrinddics // where we can't use inline assembly paths. diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h index ba4f341..b844f78 100644 --- a/internal/dispatch_gemm_shape.h +++ b/internal/dispatch_gemm_shape.h @@ -74,7 +74,8 @@ struct TransposeImpl> { template struct TransposeImpl> { typedef OutputStageQuantizeDownInt32ToUint8ScalePC SrcType; - static const VectorShape TransposedShape = TransposeVectorShape::Value; + static constexpr VectorShape TransposedShape = + TransposeVectorShape::Value; typedef OutputStageQuantizeDownInt32ToUint8ScalePC DstType; static DstType Run(const SrcType& src) { DstType dst; @@ -88,7 +89,8 @@ struct TransposeImpl> { template struct TransposeImpl> { typedef OutputStageScaleInt32ByFixedPointAndExponentPC SrcType; - static const VectorShape TransposedShape = TransposeVectorShape::Value; + static constexpr VectorShape TransposedShape = + TransposeVectorShape::Value; typedef OutputStageScaleInt32ByFixedPointAndExponentPC DstType; static DstType Run(const SrcType& src) { diff --git a/internal/kernel.h b/internal/kernel.h index 3120216..f1a3fd8 100644 --- a/internal/kernel.h +++ b/internal/kernel.h @@ -126,11 +126,11 @@ enum class CellOrder { DepthMajor, WidthMajor, Diagonal }; // out in a cell. That is, a CellOrder together with actual dimensions. template struct CellFormat { - static const int kWidth = tWidth; - static const int kDepth = tDepth; - static const CellOrder kOrder = tOrder; + static constexpr int kWidth = tWidth; + static constexpr int kDepth = tDepth; + static constexpr CellOrder kOrder = tOrder; - static const int kSize = kWidth * kDepth; + static constexpr int kSize = kWidth * kDepth; }; // KernelSideFormat describes how data is laid out in a kernel side @@ -142,9 +142,9 @@ struct CellFormat { template struct KernelSideFormat { typedef tCellFormat Cell; - static const int kCells = tCells; - static const int kWidth = kCells * Cell::kWidth; - static const int kDepth = Cell::kDepth; + static constexpr int kCells = tCells; + static constexpr int kWidth = kCells * Cell::kWidth; + static constexpr int kDepth = Cell::kDepth; typedef std::uint8_t Scalar; // The scalar type of the Format. typedef std::uint8_t InputScalar; // The scalar type of the original input. }; @@ -173,9 +173,9 @@ struct KernelFormat { typedef tRhs Rhs; static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, ""); - static const int kDepth = Lhs::Cell::kDepth; - static const int kRows = Lhs::Cell::kWidth * Lhs::kCells; - static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; + static constexpr int kDepth = Lhs::Cell::kDepth; + static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells; + static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells; }; inline const char* CellOrderName(CellOrder o) { diff --git a/internal/output_sse.h b/internal/output_sse.h index 75aebfd..6ea3290 100644 --- a/internal/output_sse.h +++ b/internal/output_sse.h @@ -535,6 +535,27 @@ struct StoreFinalOutputImpl, DstType> { } }; +// Specialization for MatrixMap, for performance. +template +struct StoreFinalOutputImpl, MatrixMap> { + static void Run(const RegBlockUint8<8, 8>& src, + MatrixMap* dst, int row, int col) { + std::uint8_t buf[64]; + StoreUint8x16(buf, src.buf.reg[0]); + StoreUint8x16(buf + 16, src.buf.reg[1]); + StoreUint8x16(buf + 32, src.buf.reg[2]); + StoreUint8x16(buf + 48, src.buf.reg[3]); + // Make a local copy so that the compiler can prove that data_ does not + // alias &data_ or &stride_. + MatrixMap local = *dst; + for (int c = 0; c < 8; c++) { + for (int r = 0; r < 8; r++) { + *local.data(row + r, col + c) = buf[r + 8 * c]; + } + } + } +}; + } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ diff --git a/internal/pack.h b/internal/pack.h index 7c43d6e..82f0dd1 100644 --- a/internal/pack.h +++ b/internal/pack.h @@ -143,7 +143,7 @@ template class SideMap { public: typedef tScalar Scalar; - static const SideMapOrder kOrder = tOrder; + static constexpr SideMapOrder kOrder = tOrder; SideMap(Scalar* data, int width, int depth, int stride) : data_(data), width_(width), depth_(depth), stride_(stride) {} @@ -214,13 +214,13 @@ class PackingRegisterBlockBase { typedef typename KernelSideFormat::Cell CellFormat; typedef typename KernelSideFormat::InputScalar KernelInputScalar; typedef typename KernelSideFormat::Scalar KernelScalar; - static const int kCells = KernelSideFormat::kCells; - static const int kCellWidth = CellFormat::kWidth; - static const int kKernelWidth = CellFormat::kWidth * kCells; - static const int kCellDepth = CellFormat::kDepth; - static const int kCellSize = CellFormat::kSize; - static const SideMapOrder kSrcOrder = SrcMapType::kOrder; - static const int kZeroPointInputValue = + static constexpr int kCells = KernelSideFormat::kCells; + static constexpr int kCellWidth = CellFormat::kWidth; + static constexpr int kKernelWidth = CellFormat::kWidth * kCells; + static constexpr int kCellDepth = CellFormat::kDepth; + static constexpr int kCellSize = CellFormat::kSize; + static constexpr SideMapOrder kSrcOrder = SrcMapType::kOrder; + static constexpr int kZeroPointInputValue = ZeroPointInputValue::kValue; PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {} @@ -302,10 +302,10 @@ class PackSideBlockImpl { public: typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat; typedef typename KernelSideFormat::Cell CellFormat; - static const int kCells = KernelSideFormat::kCells; - static const int kCellWidth = CellFormat::kWidth; - static const int kKernelWidth = CellFormat::kWidth * kCells; - static const int kCellDepth = CellFormat::kDepth; + static constexpr int kCells = KernelSideFormat::kCells; + static constexpr int kCellWidth = CellFormat::kWidth; + static constexpr int kKernelWidth = CellFormat::kWidth * kCells; + static constexpr int kCellDepth = CellFormat::kDepth; typedef PackingRegisterBlock PackingRegisterBlockType; diff --git a/internal/pack_sse.h b/internal/pack_sse.h index 52163c4..b729014 100644 --- a/internal/pack_sse.h +++ b/internal/pack_sse.h @@ -41,11 +41,11 @@ class PackingRegisterBlock< public: typedef WidthMajorSideFormatNCells4x2 KernelSideFormat; typedef typename KernelSideFormat::Cell CellFormat; - static const int kCells = KernelSideFormat::kCells; - static const int kCellWidth = CellFormat::kWidth; - static const int kKernelWidth = CellFormat::kWidth * kCells; - static const int kCellDepth = CellFormat::kDepth; - static const int kCellSize = CellFormat::kSize; + static constexpr int kCells = KernelSideFormat::kCells; + static constexpr int kCellWidth = CellFormat::kWidth; + static constexpr int kKernelWidth = CellFormat::kWidth * kCells; + static constexpr int kCellDepth = CellFormat::kDepth; + static constexpr int kCellSize = CellFormat::kSize; void Pack(PackedSideBlock* dst, int start_width) { std::uint8_t* dst_ptr = dst->current_data(); diff --git a/internal/platform.h b/internal/platform.h index 54517c3..0f3a2b8 100644 --- a/internal/platform.h +++ b/internal/platform.h @@ -30,8 +30,7 @@ #include #endif -#if defined __ANDROID__ -#include +#if defined ANDROID || defined __ANDROID__ #include // The 18 here should be 16, but has to be 18 for now due // to a Google-internal issue. diff --git a/meta/generators/cc_emitter.py b/meta/generators/cc_emitter.py index 8615671..c1dc75d 100644 --- a/meta/generators/cc_emitter.py +++ b/meta/generators/cc_emitter.py @@ -52,16 +52,16 @@ class CCEmitter(object): self.indent = self.indent[:-2] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/meta/generators/common.py b/meta/generators/common.py index d680372..7269b50 100644 --- a/meta/generators/common.py +++ b/meta/generators/common.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """.""" +import collections _HEADER_COPYRIGHT = ( '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. @@ -71,7 +72,7 @@ class StreamGenerator(object): self.emitter = emitter def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers): - if callable(getattr(self, 'EmitPack', None)): + if isinstance(getattr(self, 'EmitPack', None), collections.Callable): template_params = [in_type, lanes_count, pack_size, leftovers, self.name] self.emitter.EmitMemberFunctionBegin( 'Stream', [], template_params, 'Pack', diff --git a/meta/generators/neon_emitter.py b/meta/generators/neon_emitter.py index 726766e..0304317 100644 --- a/meta/generators/neon_emitter.py +++ b/meta/generators/neon_emitter.py @@ -187,7 +187,7 @@ class NeonEmitter(object): self.indent = self.indent[:-delta] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def PushOp(self, op): if op in self.ops.keys(): @@ -199,13 +199,13 @@ class NeonEmitter(object): self.ops.clear() def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/meta/generators/neon_emitter_64.py b/meta/generators/neon_emitter_64.py index 13a0715..956b16b 100644 --- a/meta/generators/neon_emitter_64.py +++ b/meta/generators/neon_emitter_64.py @@ -423,7 +423,7 @@ class NeonEmitter64(object): self.indent = self.indent[:-delta] def EmitIndented(self, what): - print self.indent + what + print(self.indent + what) def PushOp(self, op): if op in self.ops.keys(): @@ -435,13 +435,13 @@ class NeonEmitter64(object): self.ops.clear() def EmitNewline(self): - print '' + print('') def EmitPreprocessor1(self, op, param): - print '#%s %s' % (op, param) + print('#%s %s' % (op, param)) def EmitPreprocessor(self, op): - print '#%s' % op + print('#%s' % op) def EmitInclude(self, include): self.EmitPreprocessor1('include', include) diff --git a/public/bit_depth.h b/public/bit_depth.h index 412944e..5b19430 100644 --- a/public/bit_depth.h +++ b/public/bit_depth.h @@ -22,8 +22,8 @@ namespace gemmlowp { // The range of allowed values for an operand. template struct OperandRange { - static const int kMinValue = tMinValue; - static const int kMaxValue = tMaxValue; + static constexpr int kMinValue = tMinValue; + static constexpr int kMaxValue = tMaxValue; static_assert(kMinValue < kMaxValue, ""); }; diff --git a/public/map.h b/public/map.h index fe6bc5c..1b71f9e 100644 --- a/public/map.h +++ b/public/map.h @@ -32,7 +32,7 @@ template class MatrixMap { public: typedef tScalar Scalar; - static const MapOrder kOrder = tOrder; + static constexpr MapOrder kOrder = tOrder; protected: Scalar* data_; // not owned. @@ -84,7 +84,7 @@ template class VectorMap { public: typedef tScalar Scalar; - static const VectorShape kShape = tShape; + static constexpr VectorShape kShape = tShape; protected: Scalar* data_; // not owned. @@ -113,7 +113,7 @@ template class VectorDup { public: typedef tScalar Scalar; - static const VectorShape kShape = tShape; + static constexpr VectorShape kShape = tShape; protected: Scalar data_; diff --git a/standalone/cache_counters.cc b/standalone/cache_counters.cc new file mode 100644 index 0000000..24e971c --- /dev/null +++ b/standalone/cache_counters.cc @@ -0,0 +1,404 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef __aarch64__ +#error This program is for 64-bit ARM only. +#endif + +struct PerfEvent { + perf_event_attr pe; + int fd = -1; + + PerfEvent(std::uint32_t type, std::uint64_t config) { + memset(&pe, 0, sizeof(pe)); + pe.size = sizeof(pe); + pe.type = type; + pe.config = config; + pe.disabled = 1; + pe.exclude_kernel = 1; + pe.exclude_hv = 1; + fd = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); + if (fd == -1) { + fprintf(stderr, "perf_event_open failed for config 0x%lx\n", config); + abort(); + } + } + + void Start() { + ioctl(fd, PERF_EVENT_IOC_RESET, 0); + ioctl(fd, PERF_EVENT_IOC_ENABLE, 0); + } + + std::int64_t Stop() { + ioctl(fd, PERF_EVENT_IOC_DISABLE, 0); + std::int64_t count = 0; + read(fd, &count, sizeof(count)); + return count; + } + + ~PerfEvent() { close(fd); } +}; + +struct ArmPmuEvent : PerfEvent { + static constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; + static constexpr std::uint16_t L1I_TLB_REFILL = 0x02; + static constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; + static constexpr std::uint16_t L1D_CACHE = 0x04; + static constexpr std::uint16_t L1D_TLB_REFILL = 0x05; + static constexpr std::uint16_t LD_RETIRED = 0x06; + static constexpr std::uint16_t ST_RETIRED = 0x07; + static constexpr std::uint16_t INST_RETIRED = 0x08; + static constexpr std::uint16_t EXC_TAKEN = 0x09; + static constexpr std::uint16_t EXC_RETURN = 0x0A; + static constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; + static constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; + static constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; + static constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; + static constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; + static constexpr std::uint16_t BR_MIS_PRED = 0x10; + static constexpr std::uint16_t CPU_CYCLES = 0x11; + static constexpr std::uint16_t BR_PRED = 0x12; + static constexpr std::uint16_t MEM_ACCESS = 0x13; + static constexpr std::uint16_t L1I_CACHE = 0x14; + static constexpr std::uint16_t L1D_CACHE_WB = 0x15; + static constexpr std::uint16_t L2D_CACHE = 0x16; + static constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; + static constexpr std::uint16_t L2D_CACHE_WB = 0x18; + static constexpr std::uint16_t BUS_ACCESS = 0x19; + static constexpr std::uint16_t MEMORY_ERROR = 0x1A; + static constexpr std::uint16_t INST_SPEC = 0x1B; + static constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; + static constexpr std::uint16_t BUS_CYCLES = 0x1D; + static constexpr std::uint16_t CHAIN = 0x1E; + static constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; + static constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; + static constexpr std::uint16_t BR_RETIRED = 0x21; + static constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; + static constexpr std::uint16_t STALL_FRONTEND = 0x23; + static constexpr std::uint16_t STALL_BACKEND = 0x24; + static constexpr std::uint16_t L1D_TLB = 0x25; + static constexpr std::uint16_t L1I_TLB = 0x26; + static constexpr std::uint16_t L2I_CACHE = 0x27; + static constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; + static constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; + static constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; + static constexpr std::uint16_t L3D_CACHE = 0x2B; + static constexpr std::uint16_t L3D_CACHE_WB = 0x2C; + static constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; + static constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; + static constexpr std::uint16_t L2D_TLB = 0x2F; + static constexpr std::uint16_t L2I_TLB = 0x30; + static constexpr std::uint16_t LL_CACHE = 0x32; + static constexpr std::uint16_t LL_CACHE_MISS = 0x33; + static constexpr std::uint16_t DTLB_WALK = 0x34; + static constexpr std::uint16_t LL_CACHE_RD = 0x36; + static constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; + static constexpr std::uint16_t L1D_CACHE_RD = 0x40; + static constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; + static constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; + static constexpr std::uint16_t L1D_TLB_RD = 0x4E; + static constexpr std::uint16_t L2D_CACHE_RD = 0x50; + static constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; + static constexpr std::uint16_t BUS_ACCESS_RD = 0x60; + static constexpr std::uint16_t MEM_ACCESS_RD = 0x66; + static constexpr std::uint16_t L3D_CACHE_RD = 0xA0; + static constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; + ArmPmuEvent(std::uint16_t number) : PerfEvent(PERF_TYPE_RAW, number) {} +}; + +struct CacheCounts { + int ld_retired = 0; + int mem_access = 0; + int ll_cache = 0; + int ll_cache_miss = 0; + int l1d_cache = 0; + int l1d_cache_refill = 0; + int l2d_cache = 0; + int l2d_cache_refill = 0; + int l3d_cache = 0; + int l3d_cache_refill = 0; +}; + +void PrintCacheCounts(const CacheCounts& cache_counts) { + printf("ld_retired = %d\n", cache_counts.ld_retired); + printf("mem_access = %d\n", cache_counts.mem_access); + printf("ll_cache = %d\n", cache_counts.ll_cache); + printf("ll_cache_miss = %d\n", cache_counts.ll_cache_miss); + printf("l1d_cache = %d\n", cache_counts.l1d_cache); + printf("l1d_cache_refill = %d\n", cache_counts.l1d_cache_refill); + printf("l2d_cache = %d\n", cache_counts.l2d_cache); + printf("l2d_cache_refill = %d\n", cache_counts.l2d_cache_refill); + printf("l3d_cache = %d\n", cache_counts.l3d_cache); + printf("l3d_cache_refill = %d\n", cache_counts.l3d_cache_refill); +} + +void Workload(int accesses, int size, std::uint8_t* buf) { + // The main reason to do this in assembly is an attempt to make sense + // of instruction count counters, such as LD_RETIRED. + // Also, if we did this in C++, we would need to be watchful of the compiler + // optimizing away operations whose result isn't consumed. + // + // Note that TWO separate tricks are needed here to prevent Cortex-A76 + // speculative execution om prefetching data from future loop iterations: + // 1. A data-dependency whereby the pointers being dereferenced at the + // next loop iteration depend on values loaded at the current iteration. + // That is the role of 'dummy'. + // 2. A pseudo-random sequence. This is the role of register w0, + // where we implement a simple xorshift pseudorandom generator. + // BOTH of these tricks are needed: if we disable just one of them, + // Cortex-A76 successfully speculates some addresses, resulting in different + // L3 / DRAM hit percentages on large sizes. + std::uint64_t dummy = 123456789; + asm volatile( + // w0 := xorshift RNG state. Must be nonzero. + "mov w0, #1\n" + "1:\n" + // xorshift RNG iteration: update w0 with the next pseudorandom value + // in [1 .. 2^32-1]. + // This pseudorandomness is crucial to preventing speculative execution + // on Cortex-A76 from prefetching data from future loop iterations. + "eor w0, w0, w0, lsl #13\n" + "eor w0, w0, w0, lsr #17\n" + "eor w0, w0, w0, lsl #5\n" + // w1 := size - 1 = size mask (size is required to be power-of-two). + "sub w1, %w[size], #1\n" + // w2 := (pseudorandom value w0) xor (data-dependent sum). + "eor w2, w0, %w[dummy]\n" + // w1 := w2 modulo size + "and w1, w2, w1\n" + // align w1 + "and w1, w1, #-64\n" + // load at offset w1, again using x1 as destination. + "ldr x1, [%[buf], w1, uxtw]\n" + // Update our dummy so it depends on the value we have just loaded. + // This data-dependency is key to preventing speculative execution on + // Cortex-A76 from prefetching data from future loop iterations. + "add %[dummy], %[dummy], w1, uxtw\n" + // loop back. + "subs %w[accesses], %w[accesses], #1\n" + "bne 1b\n" + : [ accesses ] "+r"(accesses), [ dummy ] "+r"(dummy) + : [ size ] "r"(size), [ buf ] "r"(buf) + : "memory", "cc", "x0", "x1", "x2"); +} + +void MeasureCacheCounts(int accesses, int size, std::uint8_t* buf, + CacheCounts* cache_counts) { + const bool only_reads = getenv("ONLY_READS"); + ArmPmuEvent ld_retired(ArmPmuEvent::LD_RETIRED); + ArmPmuEvent mem_access(only_reads ? ArmPmuEvent::MEM_ACCESS_RD + : ArmPmuEvent::MEM_ACCESS); + ArmPmuEvent ll_cache(only_reads ? ArmPmuEvent::LL_CACHE_RD + : ArmPmuEvent::LL_CACHE); + ArmPmuEvent ll_cache_miss(only_reads ? ArmPmuEvent::LL_CACHE_MISS_RD + : ArmPmuEvent::LL_CACHE_MISS); + ArmPmuEvent l1d_cache(only_reads ? ArmPmuEvent::L1D_CACHE_RD + : ArmPmuEvent::L1D_CACHE); + ArmPmuEvent l1d_cache_refill(only_reads ? ArmPmuEvent::L1D_CACHE_REFILL_RD + : ArmPmuEvent::L1D_CACHE_REFILL); + ArmPmuEvent l2d_cache(only_reads ? ArmPmuEvent::L2D_CACHE_RD + : ArmPmuEvent::L2D_CACHE); + ArmPmuEvent l2d_cache_refill(only_reads ? ArmPmuEvent::L2D_CACHE_REFILL_RD + : ArmPmuEvent::L2D_CACHE_REFILL); + ArmPmuEvent l3d_cache(only_reads ? ArmPmuEvent::L3D_CACHE_RD + : ArmPmuEvent::L3D_CACHE); + ArmPmuEvent l3d_cache_refill(only_reads ? ArmPmuEvent::L3D_CACHE_REFILL_RD + : ArmPmuEvent::L3D_CACHE_REFILL); + + ld_retired.Start(); + mem_access.Start(); + ll_cache.Start(); + ll_cache_miss.Start(); + l1d_cache.Start(); + l1d_cache_refill.Start(); + l2d_cache.Start(); + l2d_cache_refill.Start(); + l3d_cache.Start(); + l3d_cache_refill.Start(); + + Workload(accesses, size, buf); + + cache_counts->ld_retired = ld_retired.Stop(); + cache_counts->mem_access = mem_access.Stop(); + cache_counts->ll_cache = ll_cache.Stop(); + cache_counts->ll_cache_miss = ll_cache_miss.Stop(); + cache_counts->l1d_cache = l1d_cache.Stop(); + cache_counts->l1d_cache_refill = l1d_cache_refill.Stop(); + cache_counts->l2d_cache = l2d_cache.Stop(); + cache_counts->l2d_cache_refill = l2d_cache_refill.Stop(); + cache_counts->l3d_cache = l3d_cache.Stop(); + cache_counts->l3d_cache_refill = l3d_cache_refill.Stop(); +} + +struct PieChart { + // How many accesses were recorded, total? The other fields must sum to that. + int total; + // How many accesses were serviced with the typical cost of a L1 cache hit? + int l1_hits; + // How many accesses were serviced with the typical cost of a L2 cache hit? + int l2_hits; + // How many accesses were serviced with the typical cost of a L3 cache hit? + int l3_hits; + // How many accesses were serviced with the typical cost of a DRAM access? + int dram_hits; + + ~PieChart() { + // Consistency check + if (total != l1_hits + l2_hits + l3_hits + dram_hits) { + fprintf(stderr, "inconsistent pie-chart\n"); + abort(); + } + } +}; + +struct Hypothesis { + virtual ~Hypothesis() {} + virtual const char* Name() const = 0; + virtual void Analyze(const CacheCounts& cache_counts, + PieChart* pie) const = 0; +}; + +struct Hypothesis1 : Hypothesis { + const char* Name() const override { return "Hypothesis1"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache + cache_counts.l1d_cache_refill; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache_refill - + cache_counts.l3d_cache_refill; + pie->l2_hits = cache_counts.l1d_cache_refill; + pie->l3_hits = cache_counts.l2d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis2 : Hypothesis { + const char* Name() const override { return "Hypothesis2"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l2d_cache; + pie->l2_hits = cache_counts.l2d_cache - cache_counts.l3d_cache; + pie->l3_hits = cache_counts.l3d_cache - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis3 : Hypothesis { + const char* Name() const override { return "Hypothesis3"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + int corrected_l2 = std::min(cache_counts.l2d_cache, cache_counts.l1d_cache); + int corrected_l3 = std::min(cache_counts.l3d_cache, corrected_l2); + pie->l1_hits = cache_counts.l1d_cache - corrected_l2; + pie->l2_hits = corrected_l2 - corrected_l3; + pie->l3_hits = corrected_l3 - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis4 : Hypothesis { + const char* Name() const override { return "Hypothesis4"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->total = cache_counts.l1d_cache; + pie->l1_hits = cache_counts.l1d_cache - cache_counts.l1d_cache_refill; + pie->l2_hits = + cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill; + pie->l3_hits = + cache_counts.l2d_cache_refill - cache_counts.l3d_cache_refill; + pie->dram_hits = cache_counts.l3d_cache_refill; + } +}; + +struct Hypothesis5 : Hypothesis { + const char* Name() const override { return "Hypothesis5"; } + void Analyze(const CacheCounts& cache_counts, PieChart* pie) const override { + pie->l1_hits = + std::max(0, cache_counts.l1d_cache - cache_counts.l1d_cache_refill); + pie->l2_hits = std::max( + 0, cache_counts.l1d_cache_refill - cache_counts.l2d_cache_refill); + const int l3_misses = + std::max(cache_counts.ll_cache_miss, cache_counts.l3d_cache_refill); + pie->l3_hits = std::max(0, cache_counts.l2d_cache_refill - l3_misses); + pie->dram_hits = l3_misses; + pie->total = pie->l1_hits + pie->l2_hits + pie->l3_hits + pie->dram_hits; + } +}; + +void PrintPieChart(const PieChart& pie) { + printf("total accesses: %d\n", pie.total); + double l1_hits_pct = 100. * pie.l1_hits / pie.total; + double l2_hits_pct = 100. * pie.l2_hits / pie.total; + double l3_hits_pct = 100. * pie.l3_hits / pie.total; + double dram_hits_pct = 100. * pie.dram_hits / pie.total; + printf("L1 hits: %.2f%%\n", l1_hits_pct); + printf("L2 hits: %.2f%%\n", l2_hits_pct); + printf("L1/2 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct); + printf("L3 hits: %.2f%%\n", l3_hits_pct); + printf("L1/2/3 hits: %.2f%%\n", l1_hits_pct + l2_hits_pct + l3_hits_pct); + printf("DRAM hits: %.2f%%\n", dram_hits_pct); +} + +void PrintPieChartCsvNoNewline(const PieChart& pie) { + double l1_hits_pct = 100. * pie.l1_hits / pie.total; + double l2_hits_pct = 100. * pie.l2_hits / pie.total; + double l3_hits_pct = 100. * pie.l3_hits / pie.total; + double dram_hits_pct = 100. * pie.dram_hits / pie.total; + printf("%.2f,%.2f,%.2f,%.2f", l1_hits_pct, l2_hits_pct, l3_hits_pct, + dram_hits_pct); +} + +void Study(int accesses, int size, std::uint8_t* buf) { + CacheCounts cache_counts; + MeasureCacheCounts(accesses, size, buf, &cache_counts); + const Hypothesis* hypotheses[] = { + new Hypothesis5, new Hypothesis4, new Hypothesis3, + new Hypothesis2, new Hypothesis1, + }; + if (getenv("DUMP_CSV")) { + printf("%d", size); + for (const Hypothesis* hypothesis : hypotheses) { + printf(","); + PieChart pie; + hypothesis->Analyze(cache_counts, &pie); + PrintPieChartCsvNoNewline(pie); + } + printf("\n"); + } else { + printf("\n\n\naccesses=%d, size=%d:\n", accesses, size); + printf("\nCache counts:\n"); + PrintCacheCounts(cache_counts); + for (const Hypothesis* hypothesis : hypotheses) { + printf("\n%s:\n", hypothesis->Name()); + PieChart pie; + hypothesis->Analyze(cache_counts, &pie); + PrintPieChart(pie); + } + } + fflush(stdout); + for (const Hypothesis* hypothesis : hypotheses) { + delete hypothesis; + } +} + +int main() { + const int kMinSize = 1 << 12; + const int kMaxSize = 1 << 24; + const int kAccesses = 1e8; + void* buf_void = nullptr; + posix_memalign(&buf_void, 64, kMaxSize); + std::uint8_t* buf = static_cast(buf_void); + std::default_random_engine random_engine; + for (int i = 0; i < kMaxSize; i++) { + buf[i] = random_engine(); + } + for (int size = kMinSize; size <= kMaxSize; size *= 2) { + Study(kAccesses, size, buf); + } + delete[] buf; +} diff --git a/standalone/encode.py b/standalone/encode.py new file mode 100644 index 0000000..c192ab9 --- /dev/null +++ b/standalone/encode.py @@ -0,0 +1,134 @@ +# Copyright 2018 The gemmlowp Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Encodes ARM asm code for certain instructions into the corresponding machine code encoding, as a .word directive in the asm code, preserving the original code in a comment. + +Reads from stdin, writes to stdout. + +Example diff: +- "udot v16.4s, v4.16b, v0.16b\n" ++ ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + +The intended use case is to make asm code easier to compile on toolchains that +do not support certain new instructions. +""" + +import sys +import re +import argparse + + +def encode_udot_sdot_vector(line): + m = re.search( + r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b', + line) + if not m: + return 0, line + + match = m.group(0) + unsigned = 1 if m.group(1) == 'u' else 0 + accum = int(m.group(2)) + lhs = int(m.group(3)) + rhs = int(m.group(4)) + assert accum >= 0 and accum <= 31 + assert lhs >= 0 and lhs <= 31 + assert rhs >= 0 and rhs <= 31 + mcode = 0x4e809400 | (accum << 0) | (lhs << 5) | (rhs << 16) | ( + unsigned << 29) + return mcode, match + + +def encode_udot_sdot_element(line): + m = re.search( + r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*4b[ ]*\[([0-9])\]', + line) + if not m: + return 0, line + + match = m.group(0) + unsigned = 1 if m.group(1) == 'u' else 0 + accum = int(m.group(2)) + lhs = int(m.group(3)) + rhs = int(m.group(4)) + lanegroup = int(m.group(5)) + assert accum >= 0 and accum <= 31 + assert lhs >= 0 and lhs <= 31 + assert rhs >= 0 and rhs <= 31 + assert lanegroup >= 0 and lanegroup <= 3 + l = 1 if lanegroup & 1 else 0 + h = 1 if lanegroup & 2 else 0 + mcode = 0x4f80e000 | (accum << 0) | (lhs << 5) | (rhs << 16) | (l << 21) | ( + h << 11) | ( + unsigned << 29) + return mcode, match + + +def encode(line): + for encode_func in [encode_udot_sdot_vector, encode_udot_sdot_element]: + mcode, match = encode_func(line) + if mcode: + return mcode, match + return 0, line + + +def read_existing_encoding(line): + m = re.search(r'\.word\ (0x[0-9a-f]+)', line) + if m: + return int(m.group(1), 16) + return 0 + + +parser = argparse.ArgumentParser(description='Encode some A64 instructions.') +parser.add_argument( + '-f', + '--fix', + help='fix existing wrong encodings in-place and continue', + action='store_true') +args = parser.parse_args() + +lineno = 0 +found_existing_encodings = False +found_error = False +found_fixes = False +for line in sys.stdin: + lineno = lineno + 1 + mcode, match = encode(line) + if mcode: + existing_encoding = read_existing_encoding(line) + if existing_encoding: + found_existing_encodings = True + if mcode != existing_encoding: + if args.fix: + line = line.replace('.word 0x%x // %s' % (existing_encoding, match), + '.word 0x%x // %s' % (mcode, match)) + found_fixes = True + else: + sys.stderr.write( + "Error at line %d: existing encoding 0x%x differs from encoding 0x%x for instruction '%s':\n\n%s\n\n" + % (lineno, existing_encoding, mcode, match, line)) + found_error = True + else: + line = line.replace(match, '.word 0x%x // %s' % (mcode, match)) + sys.stdout.write(line) +if found_error: + sys.exit(1) +if found_existing_encodings: + if found_fixes: + sys.stderr.write( + 'Note: some instructions that this program is able to encode, were already encoded and their existing encodings didn\'t match the specified asm instructions. Since --fix was passed, these were fixed in-place.\n' + ) + else: + sys.stderr.write( + 'Note: some instructions that this program is able to encode, were already encoded. These encodings have been checked.\n' + ) diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc index bff33fb..9146179 100644 --- a/standalone/neon-gemm-kernel-benchmark.cc +++ b/standalone/neon-gemm-kernel-benchmark.cc @@ -1,4 +1,4 @@ -// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. +// Copyright 2016 The gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -240,6 +240,51 @@ struct KernelFormat { static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; }; +// KernelOperandRanges specifies the minimum and maximum values an operand can +// take. It consists of two ranges: one for the LHS and one for the RHS. The +// default values are the minimum and maximum values of the operand data type. +template +struct KernelOperandRanges { + static OperandType LhsMin() { + return std::numeric_limits::lowest(); + } + static OperandType LhsMax() { + return std::numeric_limits::max(); + } + static OperandType RhsMin() { + return std::numeric_limits::lowest(); + } + static OperandType RhsMax() { + return std::numeric_limits::max(); + } +}; + +template +struct KernelOperandRanges { + static float LhsMin() { return -100.f; } + static float LhsMax() { return 100.f; } + static float RhsMin() { return -100.f; } + static float RhsMax() { return 100.f; } +}; + +#define SET_7BIT_RANGES(kernel) \ +template <> \ +struct KernelOperandRanges { \ + static std::int8_t LhsMin() { return -63; } \ + static std::int8_t LhsMax() { return 63; } \ + static std::int8_t RhsMin() { return -64; } \ + static std::int8_t RhsMax() { return 63; } \ +}; + +#define SET_425BIT_RANGES(kernel) \ +template <> \ +struct KernelOperandRanges { \ + static std::int8_t LhsMin() { return -7; } \ + static std::int8_t LhsMax() { return 7; } \ + static std::int8_t RhsMin() { return -9; } \ + static std::int8_t RhsMax() { return 9; } \ +}; + inline const char* CellOrderName(CellOrder o) { switch (o) { case CellOrder::DepthMajor: @@ -596,7 +641,6 @@ struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits { AccumulatorType* accum_ptr, int depth) { std::size_t start_depth = 123; std::size_t run_depth = depth; - std::size_t dst_col_stride = 4; AccumulatorType* dst_ptr = accum_ptr; asm volatile( @@ -2516,142 +2560,235 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits { } }; -#ifdef __ARM_FEATURE_DOTPROD -// Kernels utilizing the Armv8.2 Dot Product extension. -// -// The dot product instructions work by taking 4 consecutive 8-bit depth -// values from each operand, multiplying the 4 pairs together and -// accumulating all the results into the corresponding 32-bit accumulator -// lane. As such, the operation is identical to a 32-bit instruction (like -// FMLA used in SGEMM), except that 4 depth values are processed at a time -// instead of 1. - -// Thus, this first kernel is a carbon copy of -// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good -// performance for most processors) below with the opcode (fmla -> udot) and -// types (float32 -> uint8/uint32) changed. -// -// A signed version of this kernel could be produced by replacing "udot" -// with "sdot" - performance should be identical to this udot kernel. -struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow { typedef std::uint8_t OperandType; typedef std::uint32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 2> > + KernelSideFormat, 1>, + KernelSideFormat, 1> > Format; static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { + std::size_t start_depth = 123; + std::size_t run_depth = depth; + std::size_t dst_col_stride = 4; + AccumulatorType* dst_ptr = accum_ptr; asm volatile( - // Load accumulators - "mov x0, %[accum_ptr]\n" - "ld1 {v8.4s}, [x0], #16\n" - "ld1 {v16.4s}, [x0], #16\n" - "ld1 {v24.4s}, [x0], #16\n" - "ld1 {v9.4s}, [x0], #16\n" - "ld1 {v17.4s}, [x0], #16\n" - "ld1 {v25.4s}, [x0], #16\n" - "ld1 {v10.4s}, [x0], #16\n" - "ld1 {v18.4s}, [x0], #16\n" - "ld1 {v26.4s}, [x0], #16\n" - "ld1 {v11.4s}, [x0], #16\n" - "ld1 {v19.4s}, [x0], #16\n" - "ld1 {v27.4s}, [x0], #16\n" - "ld1 {v12.4s}, [x0], #16\n" - "ld1 {v20.4s}, [x0], #16\n" - "ld1 {v28.4s}, [x0], #16\n" - "ld1 {v13.4s}, [x0], #16\n" - "ld1 {v21.4s}, [x0], #16\n" - "ld1 {v29.4s}, [x0], #16\n" - "ld1 {v14.4s}, [x0], #16\n" - "ld1 {v22.4s}, [x0], #16\n" - "ld1 {v30.4s}, [x0], #16\n" - "ld1 {v15.4s}, [x0], #16\n" - "ld1 {v23.4s}, [x0], #16\n" - "ld1 {v31.4s}, [x0], #16\n" + // Overview of register layout: + // + // A 4x16 block of Rhs is stored in 8 bit in v0--v3. + // A 4x16 block of Lhs is stored in 8 bit in v4--v7. + // + // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit + // components which need to be horizontally-added at the end) + // + // Register layout: + // + // +--------+--------+--------+--------+ + // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] | + // Rhs +--------+--------+--------+--------+ + // | ... | ... | ... | ... | + // +--------+--------+--------+--------| + // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]| + // +--------+--------+--------+--------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +-------+-----+--------+ - - +--------+--------+--------+--------+ + // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s | + // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s | + // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s | + // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s | + // +-------+--------------+ - - +--------+--------+--------+--------+ + // + // Accumulator + // - // The start of the loop assumes first Rhs cell is already loaded, so - // do it here for first iteration. + // Clear accumulators "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "dup v16.4s, wzr\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "dup v19.4s, wzr\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "dup v20.4s, wzr\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "dup v21.4s, wzr\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "dup v22.4s, wzr\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + "dup v23.4s, wzr\n" + "subs %w[run_depth], %w[run_depth], #16\n" + "dup v24.4s, wzr\n" + "mov x0, %[dst_ptr]\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" - // And the same for the first Lhs cell. - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "beq 1f\n" + + "cmp %w[run_depth], #32\n" + "blt 2f\n" + + "3:\n" + "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + "prfm pldl1keep, [%[rhs_ptr], #128]\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + "ld1 {v14.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + "ld1 {v15.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + "prfm pldl1keep, [%[lhs_ptr], #128]\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e889590 // udot v16.4s, v12.16b, v8.16b\n" + ".word 0x6e899591 // udot v17.4s, v12.16b, v9.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8a9592 // udot v18.4s, v12.16b, v10.16b\n" + ".word 0x6e8b9593 // udot v19.4s, v12.16b, v11.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8895b4 // udot v20.4s, v13.16b, v8.16b\n" + ".word 0x6e8995b5 // udot v21.4s, v13.16b, v9.16b\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "sub %[run_depth], %[run_depth], #32\n" + ".word 0x6e8a95b6 // udot v22.4s, v13.16b, v10.16b\n" + ".word 0x6e8b95b7 // udot v23.4s, v13.16b, v11.16b\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8895d8 // udot v24.4s, v14.16b, v8.16b\n" + ".word 0x6e8995d9 // udot v25.4s, v14.16b, v9.16b\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8a95da // udot v26.4s, v14.16b, v10.16b\n" + ".word 0x6e8b95db // udot v27.4s, v14.16b, v11.16b\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8895fc // udot v28.4s, v15.16b, v8.16b\n" + "prfm pldl1keep, [%[rhs_ptr], #128]\n" + ".word 0x6e8995fd // udot v29.4s, v15.16b, v9.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "cmp %w[run_depth], #32\n" + ".word 0x6e8a95fe // udot v30.4s, v15.16b, v10.16b\n" + "prfm pldl1keep, [%[lhs_ptr], #128]\n" + ".word 0x6e8b95ff // udot v31.4s, v15.16b, v11.16b\n" - GEMMLOWP_LABEL_LOOP - ":\n" + "bge 3b\n" - // Start the MACs at the head of the loop - 1st cell from each side - // already loaded. - "udot v8.4s, v2.16b, v0.b[0]\n" - "udot v9.4s, v2.16b, v0.b[1]\n" - "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. - "udot v10.4s, v2.16b, v0.b[2]\n" - "udot v11.4s, v2.16b, v0.b[3]\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. - "udot v12.4s, v2.16b, v1.b[0]\n" - "udot v13.4s, v2.16b, v1.b[1]\n" - "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. - "udot v14.4s, v2.16b, v1.b[2]\n" - "udot v15.4s, v2.16b, v1.b[3]\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load - // for the next iteration early. - "udot v16.4s, v3.16b, v0.b[0]\n" - "udot v17.4s, v3.16b, v0.b[1]\n" - "udot v18.4s, v3.16b, v0.b[2]\n" - "udot v19.4s, v3.16b, v0.b[3]\n" - "udot v20.4s, v3.16b, v1.b[0]\n" - "udot v21.4s, v3.16b, v1.b[1]\n" - "udot v22.4s, v3.16b, v1.b[2]\n" - "udot v23.4s, v3.16b, v1.b[3]\n" - "udot v24.4s, v4.16b, v0.b[0]\n" - "udot v25.4s, v4.16b, v0.b[1]\n" - "udot v26.4s, v4.16b, v0.b[2]\n" - "udot v27.4s, v4.16b, v0.b[3]\n" - "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - - // load for the next iteration early. - "udot v28.4s, v4.16b, v1.b[0]\n" - "udot v29.4s, v4.16b, v1.b[1]\n" + "cmp %w[run_depth], #0\n" + "beq 1f\n" - // Loop. Decrement loop index (depth) by 4 as udot processes 4 - // depth values. - "subs %w[depth], %w[depth], #4\n" - "udot v30.4s, v4.16b, v1.b[2]\n" - "udot v31.4s, v4.16b, v1.b[3]\n" + "2:\n" - "bne " GEMMLOWP_LABEL_LOOP - "b\n" + "subs %w[run_depth], %w[run_depth], #16\n" - // Store accumulators - "mov x0, %[accum_ptr]\n" + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + "bne 2b\n" + + "1:\n" + + ".word 0x6e809490 // udot v16.4s, v4.16b, v0.16b\n" + ".word 0x6e819491 // udot v17.4s, v4.16b, v1.16b\n" + ".word 0x6e829492 // udot v18.4s, v4.16b, v2.16b\n" + ".word 0x6e839493 // udot v19.4s, v4.16b, v3.16b\n" + ".word 0x6e8094b4 // udot v20.4s, v5.16b, v0.16b\n" + ".word 0x6e8194b5 // udot v21.4s, v5.16b, v1.16b\n" + ".word 0x6e8294b6 // udot v22.4s, v5.16b, v2.16b\n" + ".word 0x6e8394b7 // udot v23.4s, v5.16b, v3.16b\n" + ".word 0x6e8094d8 // udot v24.4s, v6.16b, v0.16b\n" + ".word 0x6e8194d9 // udot v25.4s, v6.16b, v1.16b\n" + ".word 0x6e8294da // udot v26.4s, v6.16b, v2.16b\n" + ".word 0x6e8394db // udot v27.4s, v6.16b, v3.16b\n" + ".word 0x6e8094fc // udot v28.4s, v7.16b, v0.16b\n" + ".word 0x6e8194fd // udot v29.4s, v7.16b, v1.16b\n" + ".word 0x6e8294fe // udot v30.4s, v7.16b, v2.16b\n" + ".word 0x6e8394ff // udot v31.4s, v7.16b, v3.16b\n" + + // Load accumulators from memory + "ld1 {v8.16b}, [x0], #16\n" + "ld1 {v9.16b}, [x0], #16\n" + "ld1 {v10.16b}, [x0], #16\n" + "ld1 {v11.16b}, [x0], #16\n" + "mov x0, %[dst_ptr]\n" + + // Reduce aggregators horizontally + "addp v0.4s, v16.4s, v20.4s\n" + "addp v1.4s, v17.4s, v21.4s\n" + "addp v2.4s, v18.4s, v22.4s\n" + "addp v3.4s, v19.4s, v23.4s\n" + "addp v4.4s, v24.4s, v28.4s\n" + "addp v5.4s, v25.4s, v29.4s\n" + "addp v6.4s, v26.4s, v30.4s\n" + "addp v7.4s, v27.4s, v31.4s\n" + + "addp v12.4s, v0.4s, v4.4s\n" + "addp v13.4s, v1.4s, v5.4s\n" + "addp v14.4s, v2.4s, v6.4s\n" + "addp v15.4s, v3.4s, v7.4s\n" + + // Add to the accumulators loaded from memory + "add v8.4s, v8.4s, v12.4s\n" + "add v9.4s, v9.4s, v13.4s\n" + "add v10.4s, v10.4s, v14.4s\n" + "add v11.4s, v11.4s, v15.4s\n" + + // Store accumulators back to memory "st1 {v8.16b}, [x0], #16\n" - "st1 {v16.16b}, [x0], #16\n" - "st1 {v24.16b}, [x0], #16\n" "st1 {v9.16b}, [x0], #16\n" - "st1 {v17.16b}, [x0], #16\n" - "st1 {v25.16b}, [x0], #16\n" "st1 {v10.16b}, [x0], #16\n" - "st1 {v18.16b}, [x0], #16\n" - "st1 {v26.16b}, [x0], #16\n" "st1 {v11.16b}, [x0], #16\n" - "st1 {v19.16b}, [x0], #16\n" - "st1 {v27.16b}, [x0], #16\n" - "st1 {v12.16b}, [x0], #16\n" - "st1 {v20.16b}, [x0], #16\n" - "st1 {v28.16b}, [x0], #16\n" - "st1 {v13.16b}, [x0], #16\n" - "st1 {v21.16b}, [x0], #16\n" - "st1 {v29.16b}, [x0], #16\n" - "st1 {v14.16b}, [x0], #16\n" - "st1 {v22.16b}, [x0], #16\n" - "st1 {v30.16b}, [x0], #16\n" - "st1 {v15.16b}, [x0], #16\n" - "st1 {v23.16b}, [x0], #16\n" - "st1 {v31.16b}, [x0], #16\n" : // outputs [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) + [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth), + [dst_col_stride] "+r"(dst_col_stride) : // inputs - [accum_ptr] "r"(accum_ptr) + [start_depth] "r"(start_depth) : // clobbers "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", @@ -2660,48 +2797,766 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { } }; -// As above, except tuned for Cortex-A55r1. -// -// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1 -// with the names changed. -struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { - typedef std::uint8_t OperandType; - typedef std::uint32_t AccumulatorType; +// Fast kernel operating on int8 operands with 7-bit range. +// It is assumed that one of the two operands only takes values in [-63, 63], +// while the other take values in [-64, 63]. +// With this restriction, it is possible to multiply-accumulate operands into +// a 16-bit integer eight times without overflow. +struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 2> > + KernelSideFormat, 1>, + KernelSideFormat, 1> > Format; static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "mov x0, %[accum_ptr]\n" - "ld1 {v8.4s}, [x0], #16\n" - "ld1 {v16.4s}, [x0], #16\n" - "ld1 {v24.4s}, [x0], #16\n" - "ld1 {v9.4s}, [x0], #16\n" - "ld1 {v17.4s}, [x0], #16\n" - "ld1 {v25.4s}, [x0], #16\n" - "ld1 {v10.4s}, [x0], #16\n" - "ld1 {v18.4s}, [x0], #16\n" - "ld1 {v26.4s}, [x0], #16\n" - "ld1 {v11.4s}, [x0], #16\n" - "ld1 {v19.4s}, [x0], #16\n" - "ld1 {v27.4s}, [x0], #16\n" - "ld1 {v12.4s}, [x0], #16\n" - "ld1 {v20.4s}, [x0], #16\n" - "ld1 {v28.4s}, [x0], #16\n" - "ld1 {v13.4s}, [x0], #16\n" - "ld1 {v21.4s}, [x0], #16\n" - "ld1 {v29.4s}, [x0], #16\n" - "ld1 {v14.4s}, [x0], #16\n" - "ld1 {v22.4s}, [x0], #16\n" - "ld1 {v30.4s}, [x0], #16\n" - "ld1 {v15.4s}, [x0], #16\n" - "ld1 {v23.4s}, [x0], #16\n" - "ld1 {v31.4s}, [x0], #16\n" +#define GEMMLOWP_LABEL_64_DEPTH_LOOP "1" +#define GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "2" +#define GEMMLOWP_LABEL_16_DEPTH_LOOP "3" +#define GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "4" - // For details on how this kernel works, see the Float32 kernel below. + AccumulatorType* dst_ptr = accum_ptr; + asm volatile( + // Overview of register layout: + // + // A 4x16 block of Lhs is stored in 8 bit in v0--v7. + // A 2x16 block of Rhs is stored in 8 bit in v8--v15. + // + // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit + // components which need to be horizontally-added at the end). + // + // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit + // components which are added to global accumulators every 64 depth + // iteration. + // + // The Lhs vectors are multiplied by the Rhs vectors with a widening + // multiply over the 8 first levels of depth, producing int16x8 + // vectors of products for each position in the accumulator matrix. + // + // Like the trick used in the fast 8-bit kernel, the operands are + // restricted to 7-bit range [-2^6, 2^6) so their products are in range + // [-2^12, 2^12 -1). This enables adding eight such products without any + // risk of overflowing int16, equating to 64 levels of depth before + // horizontally adding these int16x8 accumulators into the final int32x4 + // accumulators. + // + // Register layout including both local and global accumulators. + // Since we do not have enough registers to store all Lhs values, we + // reuse the same registers v0--v7 to load the rest of the Lhs values. + // + // +-----+-----+ + // | v8 | v9 | + // Rhs +-----+-----+ + // | v10 | v11 | + // +-----+-----+ + // | v12 | v13 | + // +-----+-----+ + // | v14 | v15 | + // Lhs +-----+-----+ + // +----+----+----+----+ - - +-----+-----+ +--------+--------+ + // | v0 | v4 | v0 | v4 | | v16 | v20 | | v24.4s | v28.4s | + // | v1 | v5 | v1 | v5 | | v17 | v21 | -> | v25.4s | v29.4s | + // | v2 | v6 | v2 | v6 | | v18 | v22 | | v26.4s | v30.4s | + // | v3 | v7 | v3 | v7 | | v19 | v23 | | v27.4s | v31.4s | + // +----+----+----+----+ - - +-----+-----+ +--------+--------+ + // + // Local Accumulator Global Accumulator + // + + // Clear accumulators. + "dup v16.4s, wzr\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "dup v24.4s, wzr\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "dup v25.4s, wzr\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "dup v26.4s, wzr\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "dup v19.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v20.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v21.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v22.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v23.4s, wzr\n" + "dup v31.4s, wzr\n" + + "cmp %w[depth], #64\n" + "blt " GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP "f\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_64_DEPTH_LOOP + ":\n" + "subs %w[depth], %w[depth], #64\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "sadalp v24.4s, v16.8h\n" + "smull v16.8h, v0.8b, v8.8b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "sadalp v25.4s, v17.8h\n" + "smull v17.8h, v1.8b, v8.8b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v26.4s, v18.8h\n" + "smull v18.8h, v2.8b, v8.8b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v27.4s, v19.8h\n" + "smull v19.8h, v3.8b, v8.8b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "sadalp v28.4s, v20.8h\n" + "smull v20.8h, v0.8b, v9.8b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v29.4s, v21.8h\n" + "smull v21.8h, v1.8b, v9.8b\n" + "ld1 {v12.16b}, [%[rhs_ptr]], #16\n" + "sadalp v30.4s, v22.8h\n" + "smull v22.8h, v2.8b, v9.8b\n" + "ld1 {v13.16b}, [%[rhs_ptr]], #16\n" + "sadalp v31.4s, v23.8h\n" + "smull v23.8h, v3.8b, v9.8b\n" + + "cmp %w[depth], #64\n" + "smlal2 v16.8h, v0.16b, v8.16b\n" + "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v17.8h, v1.16b, v8.16b\n" + "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v18.8h, v2.16b, v8.16b\n" + "smlal2 v19.8h, v3.16b, v8.16b\n" + + "smlal2 v20.8h, v0.16b, v9.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v9.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v3.16b, v9.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v4.8b, v10.8b\n" + "smlal v17.8h, v5.8b, v10.8b\n" + "smlal v18.8h, v6.8b, v10.8b\n" + "smlal v19.8h, v7.8b, v10.8b\n" + "smlal v20.8h, v4.8b, v11.8b\n" + + "smlal v21.8h, v5.8b, v11.8b\n" + "smlal v22.8h, v6.8b, v11.8b\n" + "smlal v23.8h, v7.8b, v11.8b\n" + + "smlal2 v16.8h, v4.16b, v10.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v17.8h, v5.16b, v10.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v18.8h, v6.16b, v10.16b\n" + "smlal2 v19.8h, v7.16b, v10.16b\n" + + "smlal2 v20.8h, v4.16b, v11.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v6.16b, v11.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v7.16b, v11.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v0.8b, v12.8b\n" + "smlal v17.8h, v1.8b, v12.8b\n" + "smlal v18.8h, v2.8b, v12.8b\n" + "smlal v19.8h, v3.8b, v12.8b\n" + "smlal v20.8h, v0.8b, v13.8b\n" + "smlal v21.8h, v1.8b, v13.8b\n" + "smlal v22.8h, v2.8b, v13.8b\n" + "smlal v23.8h, v3.8b, v13.8b\n" + + "smlal2 v16.8h, v0.16b, v12.16b\n" + "smlal2 v17.8h, v1.16b, v12.16b\n" + "smlal2 v18.8h, v2.16b, v12.16b\n" + "smlal2 v19.8h, v3.16b, v12.16b\n" + + "smlal2 v20.8h, v0.16b, v13.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v13.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v13.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v3.16b, v13.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smlal v16.8h, v4.8b, v14.8b\n" + "smlal v17.8h, v5.8b, v14.8b\n" + "smlal v18.8h, v6.8b, v14.8b\n" + "smlal v19.8h, v7.8b, v14.8b\n" + + "smlal v20.8h, v4.8b, v15.8b\n" + "smlal v21.8h, v5.8b, v15.8b\n" + "smlal v22.8h, v6.8b, v15.8b\n" + "smlal v23.8h, v7.8b, v15.8b\n" + + "smlal2 v16.8h, v4.16b, v14.16b\n" + "smlal2 v17.8h, v5.16b, v14.16b\n" + "smlal2 v18.8h, v6.16b, v14.16b\n" + "smlal2 v19.8h, v7.16b, v14.16b\n" + + "smlal2 v20.8h, v4.16b, v15.16b\n" + "smlal2 v21.8h, v5.16b, v15.16b\n" + "smlal2 v22.8h, v6.16b, v15.16b\n" + "smlal2 v23.8h, v7.16b, v15.16b\n" + + "bge " GEMMLOWP_LABEL_64_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_64_DEPTH_AFTER_LOOP + ":\n" + + "cmp %w[depth], #16\n" + "blt " GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP "f\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_16_DEPTH_LOOP + ":\n" + "sadalp v24.4s, v16.8h\n" + "smull v16.8h, v0.8b, v8.8b\n" + "subs %w[depth], %w[depth], #16\n" + "sadalp v25.4s, v17.8h\n" + "smull v17.8h, v1.8b, v8.8b\n" + "sadalp v26.4s, v18.8h\n" + "smull v18.8h, v2.8b, v8.8b\n" + "sadalp v27.4s, v19.8h\n" + "smull v19.8h, v3.8b, v8.8b\n" + "sadalp v28.4s, v20.8h\n" + "smull v20.8h, v0.8b, v9.8b\n" + "sadalp v29.4s, v21.8h\n" + "smull v21.8h, v1.8b, v9.8b\n" + "sadalp v30.4s, v22.8h\n" + "smull v22.8h, v2.8b, v9.8b\n" + "sadalp v31.4s, v23.8h\n" + "smull v23.8h, v3.8b, v9.8b\n" + + "cmp %w[depth], #16\n" + "smlal2 v16.8h, v0.16b, v8.16b\n" + "smlal2 v17.8h, v1.16b, v8.16b\n" + "smlal2 v18.8h, v2.16b, v8.16b\n" + "smlal2 v19.8h, v3.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + + "smlal2 v20.8h, v0.16b, v9.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v21.8h, v1.16b, v9.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v22.8h, v2.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v23.8h, v3.16b, v9.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + + "bge " GEMMLOWP_LABEL_16_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_16_DEPTH_AFTER_LOOP + ":\n" + + "sadalp v24.4s, v16.8h\n" + "sadalp v25.4s, v17.8h\n" + "sadalp v26.4s, v18.8h\n" + "sadalp v27.4s, v19.8h\n" + "sadalp v28.4s, v20.8h\n" + "sadalp v29.4s, v21.8h\n" + "sadalp v30.4s, v22.8h\n" + "sadalp v31.4s, v23.8h\n" + + // Reduce aggregators horizontally. + "addp v0.4s, v24.4s, v25.4s\n" + "addp v1.4s, v26.4s, v27.4s\n" + "addp v2.4s, v28.4s, v29.4s\n" + "addp v3.4s, v30.4s, v31.4s\n" + + "addp v4.4s, v0.4s, v1.4s\n" + "addp v5.4s, v2.4s, v3.4s\n" + + // Load accumulators from memory. + "mov x0, %[dst_ptr]\n" + "ld1 {v6.16b}, [x0], #16\n" + "ld1 {v7.16b}, [x0], #16\n" + + // Add to the accumulators loaded from memory. + "add v6.4s, v6.4s, v4.4s\n" + "add v7.4s, v7.4s, v5.4s\n" + + // Store accumulators back to memory. + "mov x0, %[dst_ptr]\n" + "st1 {v6.16b}, [x0], #16\n" + "st1 {v7.16b}, [x0], #16\n" + + : + // Outputs. + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth) + : + // Inputs. + + : + // Clobbers. + "cc", "memory", + // We use these NEON registers + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x0"); + } +}; + +SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits); + +// Kernel operating on int8 operands with 4.25-bit range. +// It is assumed that one of the two operands only takes values in [-7, 7], +// while the other take values in [-9, 9]. +// With this restriction, it is possible to multiply-accumulate operands into +// a 16-bit integer thirty-two times without overflow. +struct NEON_64bit_GEMM_Int425Operands { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 1>, + KernelSideFormat, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { +#define GEMMLOWP_LABEL_512_DEPTH_LOOP "1" +#define GEMMLOWP_LABEL_32_DEPTH_LOOP "2" +#define GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP "3" + + AccumulatorType* dst_ptr = accum_ptr; + int outer_depth = depth / 512 + 1; + + asm volatile( + // Overview of register layout: + // + // A 4x32 block of Lhs is stored in 8 bit in v0--v7. + // A 2x32 block of Rhs is stored in 8 bit in v8--v11. + // + // A 4x2 block of global accumulators is stored in v24-v31 (as 4x32 bit + // components which need to be horizontally-added at the end). + // + // A 4x2 block of local accumulators is stored in v16-v23 (as 8x16 bit + // components which are horizontally-added to global accumulators every + // 512 depth iteration. + // + // The Lhs vectors are multiplied by the Rhs vectors with a multiply + // over the 16 first levels of depth, producing int8x16 vectors of + // products for each position in the accumulator matrix. + // + // Like the trick used in the fast 8-bit and 7-bit kernels, the operands + // are restricted to 4.25-bit range, [-7, 7] for one operand and [-9, 9] + // for the other operand. This enables adding two such products without + // any risk of overflowing int8, and thiry-two such products without + // overflowing int16. This equates to 512 levels of depth before + // horizontally adding these int16x8 accumulators into the final int32x4 + // accumulators. + // + // Register layout (ignoring the v12--v15 temporary 8-bit accumulators). + // Since we do not have enough registers to store all Lhs values and Rhs + // values, we reuse the same registers v0--v7 to load subsequent Lhs + // values and v8-v11 to subsequent Rhs values. + // + // +-----+-----+ + // | v8 | v9 | + // Rhs +-----+-----+ + // | v10 | v11 | + // +-----+-----+ + // | v8 | v9 | + // +-----+-----+ + // | v10 | v11 | + // Lhs +-----+-----+ + // +----+----+----+----+ - - +-----+-----+ +--------+--------+ + // | v0 | v4 | v0 | v4 | | v16 | v17 | | v24.4s | v25.4s | + // | v1 | v5 | v1 | v5 | | v18 | v19 | -> | v26.4s | v27.4s | + // | v2 | v6 | v2 | v6 | | v20 | v21 | | v28.4s | v29.4s | + // | v3 | v7 | v3 | v7 | | v22 | v23 | | v30.4s | v31.4s | + // +----+----+----+----+ - - +-----+-----+ +--------+--------+ + // + // Local Accumulator Global Accumulator + // + + // Clear global accumulators. + "dup v24.4s, wzr\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "dup v25.4s, wzr\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "dup v26.4s, wzr\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "dup v27.4s, wzr\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "dup v28.4s, wzr\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "dup v29.4s, wzr\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "dup v30.4s, wzr\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "dup v31.4s, wzr\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_512_DEPTH_LOOP + ":\n" + // Clear local accumulators. + "dup v16.8h, wzr\n" + "dup v17.8h, wzr\n" + "dup v18.8h, wzr\n" + "mov x1, #512\n" + "dup v19.8h, wzr\n" + "dup v20.8h, wzr\n" + "dup v21.8h, wzr\n" + "dup v22.8h, wzr\n" + "dup v23.8h, wzr\n" + + //"loop_%=:\n" + GEMMLOWP_LABEL_32_DEPTH_LOOP + ":\n" + "mul v12.16b, v0.16b, v8.16b\n" + "mul v13.16b, v0.16b, v10.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v2.16b, v8.16b\n" + "mul v15.16b, v2.16b, v10.16b\n" + + "mla v12.16b, v1.16b, v9.16b\n" + "mla v13.16b, v1.16b, v11.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v3.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "mla v15.16b, v3.16b, v11.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v16.8h, v12.16b\n" + "sadalp v17.8h, v13.16b\n" + "subs %w[depth], %w[depth], #32\n" + "sadalp v18.8h, v14.16b\n" + "sadalp v19.8h, v15.16b\n" + "subs x1, x1, #32\n" + + "mul v12.16b, v4.16b, v8.16b\n" + "mul v13.16b, v4.16b, v10.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v6.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "mul v15.16b, v6.16b, v10.16b\n" + + "mla v12.16b, v5.16b, v9.16b\n" + "mla v13.16b, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v7.16b, v9.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "mla v15.16b, v7.16b, v11.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + + "sadalp v20.8h, v12.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v21.8h, v13.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v22.8h, v14.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v23.8h, v15.16b\n" + + "mul v12.16b, v0.16b, v8.16b\n" + "mul v13.16b, v0.16b, v10.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v2.16b, v8.16b\n" + "mul v15.16b, v2.16b, v10.16b\n" + + "mla v12.16b, v1.16b, v9.16b\n" + "mla v13.16b, v1.16b, v11.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v3.16b, v9.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "mla v15.16b, v3.16b, v11.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v16.8h, v12.16b\n" + "sadalp v17.8h, v13.16b\n" + "sadalp v18.8h, v14.16b\n" + "sadalp v19.8h, v15.16b\n" + + "mul v12.16b, v4.16b, v8.16b\n" + "mul v13.16b, v4.16b, v10.16b\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "mul v14.16b, v6.16b, v8.16b\n" + "ld1 {v8.16b}, [%[rhs_ptr]], #16\n" + "mul v15.16b, v6.16b, v10.16b\n" + + "mla v12.16b, v5.16b, v9.16b\n" + "mla v13.16b, v5.16b, v11.16b\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "mla v14.16b, v7.16b, v9.16b\n" + "ld1 {v9.16b}, [%[rhs_ptr]], #16\n" + "mla v15.16b, v7.16b, v11.16b\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + + "sadalp v20.8h, v12.16b\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "sadalp v21.8h, v13.16b\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "sadalp v22.8h, v14.16b\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "sadalp v23.8h, v15.16b\n" + + "beq " GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP + "f\n" + + "cmp %w[depth], #0\n" + "bne " GEMMLOWP_LABEL_32_DEPTH_LOOP "b\n" + + GEMMLOWP_LABEL_32_DEPTH_AFTER_LOOP + ":\n" + + // Pairwise add 16-bit local accums to 32-bit global accums. + "sadalp v24.4s, v16.8h\n" + "sadalp v25.4s, v17.8h\n" + "sadalp v26.4s, v18.8h\n" + "sadalp v27.4s, v19.8h\n" + "sadalp v28.4s, v20.8h\n" + "sadalp v29.4s, v21.8h\n" + "sadalp v30.4s, v22.8h\n" + "sadalp v31.4s, v23.8h\n" + + "bne " GEMMLOWP_LABEL_512_DEPTH_LOOP + "b\n" + + // Reduce aggregators horizontally. + "addp v0.4s, v24.4s, v26.4s\n" + "addp v1.4s, v28.4s, v30.4s\n" + "addp v2.4s, v25.4s, v27.4s\n" + "addp v3.4s, v29.4s, v31.4s\n" + + "addp v4.4s, v0.4s, v1.4s\n" + "addp v5.4s, v2.4s, v3.4s\n" + + // Load accumulators from memory. + "mov x0, %[dst_ptr]\n" + "ld1 {v6.16b}, [x0], #16\n" + "ld1 {v7.16b}, [x0], #16\n" + + // Add to the accumulators loaded from memory. + "add v6.4s, v6.4s, v4.4s\n" + "add v7.4s, v7.4s, v5.4s\n" + + // Store accumulators back to memory. + "mov x0, %[dst_ptr]\n" + "st1 {v6.16b}, [x0], #16\n" + "st1 {v7.16b}, [x0], #16\n" + + : + // Outputs. + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [depth] "+r"(depth), + [outer_depth] "+r"(outer_depth) + : + // Inputs. + + : + // Clobbers. + "cc", "memory", + // We use these NEON registers + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x0", "x1"); + } +}; + +SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands); + +#ifdef __ARM_FEATURE_DOTPROD +// Kernels utilizing the Armv8.2 Dot Product extension. +// +// The dot product instructions work by taking 4 consecutive 8-bit depth +// values from each operand, multiplying the 4 pairs together and +// accumulating all the results into the corresponding 32-bit accumulator +// lane. As such, the operation is identical to a 32-bit instruction (like +// FMLA used in SGEMM), except that 4 depth values are processed at a time +// instead of 1. + +// Thus, this first kernel is a carbon copy of +// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good +// performance for most processors) below with the opcode (fmla -> udot) and +// types (float32 -> uint8/uint32) changed. +// +// A signed version of this kernel could be produced by replacing "udot" +// with "sdot" - performance should be identical to this udot kernel. +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 3>, + KernelSideFormat, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "mov x0, %[accum_ptr]\n" + "ld1 {v8.4s}, [x0], #16\n" + "ld1 {v16.4s}, [x0], #16\n" + "ld1 {v24.4s}, [x0], #16\n" + "ld1 {v9.4s}, [x0], #16\n" + "ld1 {v17.4s}, [x0], #16\n" + "ld1 {v25.4s}, [x0], #16\n" + "ld1 {v10.4s}, [x0], #16\n" + "ld1 {v18.4s}, [x0], #16\n" + "ld1 {v26.4s}, [x0], #16\n" + "ld1 {v11.4s}, [x0], #16\n" + "ld1 {v19.4s}, [x0], #16\n" + "ld1 {v27.4s}, [x0], #16\n" + "ld1 {v12.4s}, [x0], #16\n" + "ld1 {v20.4s}, [x0], #16\n" + "ld1 {v28.4s}, [x0], #16\n" + "ld1 {v13.4s}, [x0], #16\n" + "ld1 {v21.4s}, [x0], #16\n" + "ld1 {v29.4s}, [x0], #16\n" + "ld1 {v14.4s}, [x0], #16\n" + "ld1 {v22.4s}, [x0], #16\n" + "ld1 {v30.4s}, [x0], #16\n" + "ld1 {v15.4s}, [x0], #16\n" + "ld1 {v23.4s}, [x0], #16\n" + "ld1 {v31.4s}, [x0], #16\n" + + // The start of the loop assumes first Rhs cell is already loaded, so + // do it here for first iteration. + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + + // And the same for the first Lhs cell. + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + + // Start the MACs at the head of the loop - 1st cell from each side + // already loaded. + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load + // for the next iteration early. + ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - + // load for the next iteration early. + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" + + // Loop. Decrement loop index (depth) by 4 as udot processes 4 + // depth values. + "subs %w[depth], %w[depth], #4\n" + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" + + "bne " GEMMLOWP_LABEL_LOOP + "b\n" + + // Store accumulators + "mov x0, %[accum_ptr]\n" + "st1 {v8.16b}, [x0], #16\n" + "st1 {v16.16b}, [x0], #16\n" + "st1 {v24.16b}, [x0], #16\n" + "st1 {v9.16b}, [x0], #16\n" + "st1 {v17.16b}, [x0], #16\n" + "st1 {v25.16b}, [x0], #16\n" + "st1 {v10.16b}, [x0], #16\n" + "st1 {v18.16b}, [x0], #16\n" + "st1 {v26.16b}, [x0], #16\n" + "st1 {v11.16b}, [x0], #16\n" + "st1 {v19.16b}, [x0], #16\n" + "st1 {v27.16b}, [x0], #16\n" + "st1 {v12.16b}, [x0], #16\n" + "st1 {v20.16b}, [x0], #16\n" + "st1 {v28.16b}, [x0], #16\n" + "st1 {v13.16b}, [x0], #16\n" + "st1 {v21.16b}, [x0], #16\n" + "st1 {v29.16b}, [x0], #16\n" + "st1 {v14.16b}, [x0], #16\n" + "st1 {v22.16b}, [x0], #16\n" + "st1 {v30.16b}, [x0], #16\n" + "st1 {v15.16b}, [x0], #16\n" + "st1 {v23.16b}, [x0], #16\n" + "st1 {v31.16b}, [x0], #16\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); + } +}; + +// As above, except tuned for Cortex-A55r1. +// +// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1 +// with the names changed. +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 3>, + KernelSideFormat, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "mov x0, %[accum_ptr]\n" + "ld1 {v8.4s}, [x0], #16\n" + "ld1 {v16.4s}, [x0], #16\n" + "ld1 {v24.4s}, [x0], #16\n" + "ld1 {v9.4s}, [x0], #16\n" + "ld1 {v17.4s}, [x0], #16\n" + "ld1 {v25.4s}, [x0], #16\n" + "ld1 {v10.4s}, [x0], #16\n" + "ld1 {v18.4s}, [x0], #16\n" + "ld1 {v26.4s}, [x0], #16\n" + "ld1 {v11.4s}, [x0], #16\n" + "ld1 {v19.4s}, [x0], #16\n" + "ld1 {v27.4s}, [x0], #16\n" + "ld1 {v12.4s}, [x0], #16\n" + "ld1 {v20.4s}, [x0], #16\n" + "ld1 {v28.4s}, [x0], #16\n" + "ld1 {v13.4s}, [x0], #16\n" + "ld1 {v21.4s}, [x0], #16\n" + "ld1 {v29.4s}, [x0], #16\n" + "ld1 {v14.4s}, [x0], #16\n" + "ld1 {v22.4s}, [x0], #16\n" + "ld1 {v30.4s}, [x0], #16\n" + "ld1 {v15.4s}, [x0], #16\n" + "ld1 {v23.4s}, [x0], #16\n" + "ld1 {v31.4s}, [x0], #16\n" + + // For details on how this kernel works, see the Float32 kernel below. "ldr d0, [%[rhs_ptr]]\n" "ldr x18, [%[rhs_ptr], #8]\n" @@ -2712,54 +3567,67 @@ struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { GEMMLOWP_LABEL_LOOP ":\n" - "udot v8.4s, v2.16b, v0.b[0]\n" - "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 - "udot v9.4s, v2.16b, v0.b[1]\n" - "ins v0.d[1], x18\n" // Finish loading v0 - "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure. - "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register - "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure. - "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. - "udot v10.4s, v2.16b, v0.b[2]\n" - "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 - "udot v11.4s, v2.16b, v0.b[3]\n" - "ins v1.d[1], x18\n" // Finish loading v1 - "udot v12.4s, v2.16b, v1.b[0]\n" - "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register - "udot v13.4s, v2.16b, v1.b[1]\n" - "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. - "udot v14.4s, v2.16b, v1.b[2]\n" - - "udot v15.4s, v2.16b, v1.b[3]\n" - "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) - "udot v18.4s, v3.16b, v0.b[2]\n" - "ins v4.d[1], x18\n" // Finish loading v4 - "udot v19.4s, v3.16b, v0.b[3]\n" - "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register - "udot v20.4s, v3.16b, v1.b[0]\n" + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" + "ins v0.d[1], x18\n" // Finish loading v0 + ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" // out of + // sequence - + // used to + // reduce + // load/use + // pressure. + "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" // out of + // sequence - + // used to + // reduce + // load/use + // pressure. + "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment + // pointer. + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" + "ins v1.d[1], x18\n" // Finish loading v1 + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" + "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment + // pointer. + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" + "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + "ins v4.d[1], x18\n" // Finish loading v4 + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" "subs %w[depth], %w[depth], #4\n" - "udot v21.4s, v3.16b, v1.b[1]\n" - - "udot v22.4s, v3.16b, v1.b[2]\n" - - "udot v23.4s, v3.16b, v1.b[3]\n" - "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) - "udot v24.4s, v4.16b, v0.b[0]\n" - "ins v2.d[1], x18\n" // Finish loading next v2 - "udot v25.4s, v4.16b, v0.b[1]\n" - "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register - "udot v26.4s, v4.16b, v0.b[2]\n" - - "udot v27.4s, v4.16b, v0.b[3]\n" - "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) - "udot v28.4s, v4.16b, v1.b[0]\n" - "ins v3.d[1], x18\n" // Finish loading next v3 - "udot v29.4s, v4.16b, v1.b[1]\n" - "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register - "udot v30.4s, v4.16b, v1.b[2]\n" - - "udot v31.4s, v4.16b, v1.b[3]\n" - "bne " GEMMLOWP_LABEL_LOOP "b\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + "ins v2.d[1], x18\n" // Finish loading next v2 + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" + "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + "ins v3.d[1], x18\n" // Finish loading next v3 + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" + "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" + "bne " GEMMLOWP_LABEL_LOOP + "b\n" // Store accumulators "mov x0, %[accum_ptr]\n" @@ -3775,840 +4643,272 @@ struct NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics { vmlal_lane_s16(acc[i][4 * j + 1], vget_high_s16(lhs[i]), vget_high_s16(rhs[j]), 1); acc[i][4 * j + 2] = - vmlal_lane_s16(acc[i][4 * j + 2], vget_high_s16(lhs[i]), - vget_high_s16(rhs[j]), 2); - acc[i][4 * j + 3] = - vmlal_lane_s16(acc[i][4 * j + 3], vget_high_s16(lhs[i]), - vget_high_s16(rhs[j]), 3); - } - } - lhs_ptr += 24; - rhs_ptr += 8 * RhsCells; - } - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4 * RhsCells; j++) { - vst1q_s32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); - } - } - } -}; - -using NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = - NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<1>; - -using NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = - NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<2>; - -template -struct NEON_GEMM_Float32_WithScalar_intrinsics { - typedef float OperandType; - typedef float AccumulatorType; - typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, RhsCells> > - Format; - static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - float32x4_t acc[3][4 * RhsCells]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4 * RhsCells; j++) { - acc[i][j] = vld1q_f32(accum_ptr + 4 * (i + 3 * j)); - } - } - for (int d = 0; d < depth; d++) { - float32x4_t lhs[3]; - for (int i = 0; i < 3; i++) { - lhs[i] = vld1q_f32(lhs_ptr + 4 * i); - } - float32x4_t rhs[RhsCells]; - for (int i = 0; i < RhsCells; i++) { - rhs[i] = vld1q_f32(rhs_ptr + 4 * i); - } - for (int i = 0; i < 3; i++) { - for (int j = 0; j < RhsCells; j++) { - acc[i][4 * j + 0] = vmlaq_lane_f32(acc[i][4 * j + 0], lhs[i], - vget_low_f32(rhs[j]), 0); - acc[i][4 * j + 1] = vmlaq_lane_f32(acc[i][4 * j + 1], lhs[i], - vget_low_f32(rhs[j]), 1); - acc[i][4 * j + 2] = vmlaq_lane_f32(acc[i][4 * j + 2], lhs[i], - vget_high_f32(rhs[j]), 0); - acc[i][4 * j + 3] = vmlaq_lane_f32(acc[i][4 * j + 3], lhs[i], - vget_high_f32(rhs[j]), 1); - } - } - lhs_ptr += 12; - rhs_ptr += 4 * RhsCells; - } - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4 * RhsCells; j++) { - vst1q_f32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); - } - } - } -}; - -using NEON_32bit_GEMM_Float32_WithScalar_intrinsics = - NEON_GEMM_Float32_WithScalar_intrinsics<1>; - -using NEON_64bit_GEMM_Float32_WithScalar_intrinsics = - NEON_GEMM_Float32_WithScalar_intrinsics<2>; -#endif // __arm__ || __aarch64__ - -#ifdef __mips -static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) { - // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c). -#if 0 - return __builtin_msa_maddv_w(a, b, c); -#else - asm volatile("maddv.w %w[a], %w[b], %w[c]\n" - // Outputs - : [a] "+f"(a) - // Inputs - : [b] "f"(b), [c] "f"(c)); - return a; -#endif -} - -// Using 32x32=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~55 instructions in the loop. -struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics { - typedef std::uint8_t OperandType; - typedef std::int32_t AccumulatorType; - typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 1> > - Format; - static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - const v16i8 zeroes = __builtin_msa_ldi_b(0); - v4i32 acc[3][4]; - // Load accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); - } - } - - while (depth > 0) { - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - v8i16 lhs[6]; - lhs[0] = reinterpret_cast(__builtin_msa_ld_b(const_cast(lhs_ptr), 0)); - lhs[1] = - reinterpret_cast(__builtin_msa_ld_b(const_cast(lhs_ptr + 8), 0)); - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - lhs[0] = reinterpret_cast(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast(lhs[0]))); - lhs[2] = reinterpret_cast(__builtin_msa_ilvl_b(zeroes, - reinterpret_cast(lhs[1]))); - lhs[1] = reinterpret_cast(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast(lhs[1]))); - - // Zero-extend 16-bit elements of lhs[] to 32 bits. - lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[0]); - lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[1]); - lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[2]); - lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[0]); - lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[1]); - lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[2]); - - // Depth 0. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i]), rhs); - } - } - - // Depth 1. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i + 3]), rhs); - } - } - - lhs_ptr += 24; - rhs_ptr += 8; - depth -= 2; - } - - // Store accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); - } - } - } -}; - -// Assembly implementation of the above -// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics. -// Using 32x32=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~55 instructions in the loop. -struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly { - typedef std::uint8_t OperandType; - typedef std::int32_t AccumulatorType; - typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 1> > - Format; - static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "ld.w $w0, (0*16)(%[accum_ptr])\n" - "ld.w $w4, (1*16)(%[accum_ptr])\n" - "ld.w $w8, (2*16)(%[accum_ptr])\n" - "ld.w $w1, (3*16)(%[accum_ptr])\n" - "ld.w $w5, (4*16)(%[accum_ptr])\n" - "ld.w $w9, (5*16)(%[accum_ptr])\n" - "ld.w $w2, (6*16)(%[accum_ptr])\n" - "ld.w $w6, (7*16)(%[accum_ptr])\n" - "ld.w $w10, (8*16)(%[accum_ptr])\n" - "ld.w $w3, (9*16)(%[accum_ptr])\n" - "ld.w $w7, (10*16)(%[accum_ptr])\n" - "ld.w $w11, (11*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w19, 0\n" - - GEMMLOWP_LABEL_LOOP ":\n" - // Overview of register layout: - // - // A half of the 2x4 cell of Rhs is stored in 32bit in w18. - // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17. - // A 12x4 block of accumulators is stored in 32bit in w0-w11. - // - // +------+------+------+------+ - // Rhs |w18[0]|w18[1]|w18[2]|w18[3]| - // +------+------+------+------+ - // - // | | | | | - // - // Lhs | | | | | - // - // +---+---+ - - - - +------+------+------+------+ - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // |w12|w15| | w0 | w1 | w2 | w3 | - // +---+---+ - - - - +------+------+------+------+ - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // |w13|w16| | w4 | w5 | w6 | w7 | - // +---+---+ - - - - +------+------+------+------+ - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // |w14|w17| | w8 | w9 | w10 | w11 | - // +---+---+ - - - - +------+------+------+------+ - // - // Accumulator - - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - "ld.b $w12, 0(%[lhs_ptr])\n" - "ld.b $w13, 8(%[lhs_ptr])\n" - - // Load 4 bytes of rhs[] for depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - "ilvr.b $w12, $w19, $w12\n" - "ilvl.b $w14, $w19, $w13\n" - "ilvr.b $w13, $w19, $w13\n" - // Zero-extend 16-bit elements of lhs[] to 32 bits. - "ilvl.h $w15, $w19, $w12\n" - "ilvl.h $w16, $w19, $w13\n" - "ilvl.h $w17, $w19, $w14\n" - "ilvr.h $w12, $w19, $w12\n" - "ilvr.h $w13, $w19, $w13\n" - "ilvr.h $w14, $w19, $w14\n" - - // Depth 0. - "fill.w $w18, $a0\n" - "lbu $a0, 4(%[rhs_ptr])\n" - "maddv.w $w0, $w12, $w18\n" - "maddv.w $w4, $w13, $w18\n" - "maddv.w $w8, $w14, $w18\n" - "fill.w $w18, $a1\n" - "lbu $a1, 5(%[rhs_ptr])\n" - "maddv.w $w1, $w12, $w18\n" - "maddv.w $w5, $w13, $w18\n" - "maddv.w $w9, $w14, $w18\n" - "fill.w $w18, $a2\n" - "lbu $a2, 6(%[rhs_ptr])\n" - "maddv.w $w2, $w12, $w18\n" - "maddv.w $w6, $w13, $w18\n" - "maddv.w $w10, $w14, $w18\n" - "fill.w $w18, $a3\n" - "lbu $a3, 7(%[rhs_ptr])\n" - "maddv.w $w3, $w12, $w18\n" - "maddv.w $w7, $w13, $w18\n" - "maddv.w $w11, $w14, $w18\n" - - // Depth 1. - "fill.w $w18, $a0\n" - "maddv.w $w0, $w15, $w18\n" - "maddv.w $w4, $w16, $w18\n" - "maddv.w $w8, $w17, $w18\n" - "fill.w $w18, $a1\n" - "maddv.w $w1, $w15, $w18\n" - "maddv.w $w5, $w16, $w18\n" - "maddv.w $w9, $w17, $w18\n" - "fill.w $w18, $a2\n" - "maddv.w $w2, $w15, $w18\n" - "maddv.w $w6, $w16, $w18\n" - "maddv.w $w10, $w17, $w18\n" - "fill.w $w18, $a3\n" - "maddv.w $w3, $w15, $w18\n" - "maddv.w $w7, $w16, $w18\n" - "maddv.w $w11, $w17, $w18\n" - - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" - - // Store accumulators. - "st.w $w0, (0*16)(%[accum_ptr])\n" - "st.w $w4, (1*16)(%[accum_ptr])\n" - "st.w $w8, (2*16)(%[accum_ptr])\n" - "st.w $w1, (3*16)(%[accum_ptr])\n" - "st.w $w5, (4*16)(%[accum_ptr])\n" - "st.w $w9, (5*16)(%[accum_ptr])\n" - "st.w $w2, (6*16)(%[accum_ptr])\n" - "st.w $w6, (7*16)(%[accum_ptr])\n" - "st.w $w10, (8*16)(%[accum_ptr])\n" - "st.w $w3, (9*16)(%[accum_ptr])\n" - "st.w $w7, (10*16)(%[accum_ptr])\n" - "st.w $w11, (11*16)(%[accum_ptr])\n" - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) - : // inputs - [accum_ptr] "r"(accum_ptr) - : // clobbers - "memory", - "a0", "a1", "a2", "a3", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19"); + vmlal_lane_s16(acc[i][4 * j + 2], vget_high_s16(lhs[i]), + vget_high_s16(rhs[j]), 2); + acc[i][4 * j + 3] = + vmlal_lane_s16(acc[i][4 * j + 3], vget_high_s16(lhs[i]), + vget_high_s16(rhs[j]), 3); + } + } + lhs_ptr += 24; + rhs_ptr += 8 * RhsCells; + } + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4 * RhsCells; j++) { + vst1q_s32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); + } + } } }; -// Assembly implementation of the above -// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). -// Using 16x16=32 multiplications. -// 20 MSA regs used: -// - 12 accumulators -// - 3 lhs -// - 4 rhs -// - 1 temps/zeroes -// ~45 instructions in the loop. -struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 { - typedef std::uint8_t OperandType; - typedef std::int32_t AccumulatorType; +using NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = + NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<1>; + +using NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = + NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<2>; + +template +struct NEON_GEMM_Float32_WithScalar_intrinsics { + typedef float OperandType; + typedef float AccumulatorType; typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 1> > + KernelSideFormat, 3>, + KernelSideFormat, RhsCells> > Format; - static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "ld.w $w0, (0*16)(%[accum_ptr])\n" - "ld.w $w4, (1*16)(%[accum_ptr])\n" - "ld.w $w8, (2*16)(%[accum_ptr])\n" - "ld.w $w1, (3*16)(%[accum_ptr])\n" - "ld.w $w5, (4*16)(%[accum_ptr])\n" - "ld.w $w9, (5*16)(%[accum_ptr])\n" - "ld.w $w2, (6*16)(%[accum_ptr])\n" - "ld.w $w6, (7*16)(%[accum_ptr])\n" - "ld.w $w10, (8*16)(%[accum_ptr])\n" - "ld.w $w3, (9*16)(%[accum_ptr])\n" - "ld.w $w7, (10*16)(%[accum_ptr])\n" - "ld.w $w11, (11*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w19, 0\n" - - GEMMLOWP_LABEL_LOOP ":\n" - // Overview of register layout: - // - // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register - // contains 4 replicas of a pair of elements). - // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14. - // A 12x4 block of accumulators is stored in 32bit in w0-w11. - // - // +-----+-----+-----+-----+ - // Rhs | w15 | w16 | w17 | w18 | - // +-----+-----+-----+-----+ - // - // | | | | | - // - // Lhs | | | | | - // - // +---+ - - - - +-----+-----+-----+-----+ - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // |w12| | w0 | w1 | w2 | w3 | - // +---+ - - - - +-----+-----+-----+-----+ - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // |w13| | w4 | w5 | w6 | w7 | - // +---+ - - - - +-----+-----+-----+-----+ - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // |w14| | w8 | w9 | w10 | w11 | - // +---+ - - - - +-----+-----+-----+-----+ - // - // Accumulators - - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - "ld.b $w12, 0(%[lhs_ptr])\n" - "ld.b $w13, 8(%[lhs_ptr])\n" - - // Load 4 bytes of rhs[] for depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for depth 1. - "lbu $v0, 4(%[rhs_ptr])\n" - "lbu $v1, 5(%[rhs_ptr])\n" - "lbu $t8, 6(%[rhs_ptr])\n" - "lbu $t9, 7(%[rhs_ptr])\n" - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - "ilvr.b $w12, $w19, $w12\n" - "ilvl.b $w14, $w19, $w13\n" - "ilvr.b $w13, $w19, $w13\n" - // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. - "ilvl.d $w15, $w19, $w12\n" - "ilvl.d $w16, $w19, $w13\n" - "ilvl.d $w17, $w19, $w14\n" - "ilvr.h $w12, $w15, $w12\n" - "ilvr.h $w13, $w16, $w13\n" - "ilvr.h $w14, $w17, $w14\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w. - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" - // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w15, $a0\n" - "fill.w $w16, $a1\n" - "fill.w $w17, $a2\n" - "fill.w $w18, $a3\n" - - // Depths 0 and 1. - // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w0, $w12, $w15\n" - "dpadd_u.w $w4, $w13, $w15\n" - "dpadd_u.w $w8, $w14, $w15\n" - "dpadd_u.w $w1, $w12, $w16\n" - "dpadd_u.w $w5, $w13, $w16\n" - "dpadd_u.w $w9, $w14, $w16\n" - "dpadd_u.w $w2, $w12, $w17\n" - "dpadd_u.w $w6, $w13, $w17\n" - "dpadd_u.w $w10, $w14, $w17\n" - "dpadd_u.w $w3, $w12, $w18\n" - "dpadd_u.w $w7, $w13, $w18\n" - "dpadd_u.w $w11, $w14, $w18\n" - - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" - - // Store accumulators. - "st.w $w0, (0*16)(%[accum_ptr])\n" - "st.w $w4, (1*16)(%[accum_ptr])\n" - "st.w $w8, (2*16)(%[accum_ptr])\n" - "st.w $w1, (3*16)(%[accum_ptr])\n" - "st.w $w5, (4*16)(%[accum_ptr])\n" - "st.w $w9, (5*16)(%[accum_ptr])\n" - "st.w $w2, (6*16)(%[accum_ptr])\n" - "st.w $w6, (7*16)(%[accum_ptr])\n" - "st.w $w10, (8*16)(%[accum_ptr])\n" - "st.w $w3, (9*16)(%[accum_ptr])\n" - "st.w $w7, (10*16)(%[accum_ptr])\n" - "st.w $w11, (11*16)(%[accum_ptr])\n" - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) - : // inputs - [accum_ptr] "r"(accum_ptr) - : // clobbers - "memory", - "v0", "v1", - "a0", "a1", "a2", "a3", - "t8", "t9", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19"); + float32x4_t acc[3][4 * RhsCells]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4 * RhsCells; j++) { + acc[i][j] = vld1q_f32(accum_ptr + 4 * (i + 3 * j)); + } + } + for (int d = 0; d < depth; d++) { + float32x4_t lhs[3]; + for (int i = 0; i < 3; i++) { + lhs[i] = vld1q_f32(lhs_ptr + 4 * i); + } + float32x4_t rhs[RhsCells]; + for (int i = 0; i < RhsCells; i++) { + rhs[i] = vld1q_f32(rhs_ptr + 4 * i); + } + for (int i = 0; i < 3; i++) { + for (int j = 0; j < RhsCells; j++) { + acc[i][4 * j + 0] = vmlaq_lane_f32(acc[i][4 * j + 0], lhs[i], + vget_low_f32(rhs[j]), 0); + acc[i][4 * j + 1] = vmlaq_lane_f32(acc[i][4 * j + 1], lhs[i], + vget_low_f32(rhs[j]), 1); + acc[i][4 * j + 2] = vmlaq_lane_f32(acc[i][4 * j + 2], lhs[i], + vget_high_f32(rhs[j]), 0); + acc[i][4 * j + 3] = vmlaq_lane_f32(acc[i][4 * j + 3], lhs[i], + vget_high_f32(rhs[j]), 1); + } + } + lhs_ptr += 12; + rhs_ptr += 4 * RhsCells; + } + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4 * RhsCells; j++) { + vst1q_f32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); + } + } } }; -// Using 32x32=32 multiplications. -// 32 MSA regs used: -// - 24 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~95 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics { - typedef std::uint8_t OperandType; - typedef std::uint32_t AccumulatorType; +using NEON_32bit_GEMM_Float32_WithScalar_intrinsics = + NEON_GEMM_Float32_WithScalar_intrinsics<1>; + +using NEON_64bit_GEMM_Float32_WithScalar_intrinsics = + NEON_GEMM_Float32_WithScalar_intrinsics<2>; + +// C++ intrinsics-based variant of the deep, 7-bit, fast kernel +struct NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 2> > + KernelSideFormat, 1>, + KernelSideFormat, 1> > Format; static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, AccumulatorType* accum_ptr, int depth) { - const v16i8 zeroes = __builtin_msa_ldi_b(0); - v4i32 acc[3][8]; - // Load accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 8; j++) { - acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); + int32x4_t acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vdupq_n_s32(0); } } - while (depth > 0) { - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - v8i16 lhs[6]; - lhs[0] = reinterpret_cast(__builtin_msa_ld_b(const_cast(lhs_ptr), 0)); - lhs[1] = - reinterpret_cast(__builtin_msa_ld_b(const_cast(lhs_ptr + 8), 0)); - - // Zero-extend 8-bit elements of lhs[] to 16 bits. - lhs[0] = reinterpret_cast(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast(lhs[0]))); - lhs[2] = reinterpret_cast(__builtin_msa_ilvl_b(zeroes, - reinterpret_cast(lhs[1]))); - lhs[1] = reinterpret_cast(__builtin_msa_ilvr_b(zeroes, - reinterpret_cast(lhs[1]))); - - // Zero-extend 16-bit elements of lhs[] to 32 bits. - lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[0]); - lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[1]); - lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast(zeroes), lhs[2]); - lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[0]); - lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[1]); - lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast(zeroes), lhs[2]); - - // Depth 0. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i]), rhs); + int d = 0; + for (; d <= depth - 64; d += 64) { + int16x8_t local_acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + local_acc[i][j] = vdupq_n_s16(0); } } - for (int j = 4; j < 8; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i]), rhs); + + // There are not enough registers to fit all lhs and rhs values for 64 + // depth. Instead, load values for 32 depth at a time. + for (int k = 0; k < 2; k++) { + int8x16_t lhs[4][2]; + for (int i = 0; i < 4; i++) { + lhs[i][0] = vld1q_s8(lhs_ptr + 16 * i + 128 * k); + lhs[i][1] = vld1q_s8(lhs_ptr + 64 + 16 * i + 128 * k); } - } - // Depth 1. - for (int j = 0; j < 4; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j + 4])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i + 3]), rhs); + int8x16_t rhs[4]; + for (int i = 0; i < 4; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i + 64 * k); } - } - for (int j = 4; j < 8; j++) { - // Load 1 byte of rhs, making 4 32-bit replicas of it. - v4i32 rhs = reinterpret_cast(__builtin_msa_fill_w(rhs_ptr[j + 8])); - // Multiply-add into accumulators. - for (int i = 0; i < 3; i++) { - acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast(lhs[i + 3]), rhs); + + for (int i = 0; i < 4; i++) { + if (k == 0) { + local_acc[i][0] = vmull_s8(vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[2])); + local_acc[i][1] = vmull_s8(vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], + vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[3])); + } else { + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[2])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][0]), + vget_low_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_low_s8(lhs[i][1]), + vget_low_s8(rhs[3])); + } + + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][0]), + vget_high_s8(rhs[0])); + local_acc[i][0] = vmlal_s8(local_acc[i][0], vget_high_s8(lhs[i][1]), + vget_high_s8(rhs[2])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][0]), + vget_high_s8(rhs[1])); + local_acc[i][1] = vmlal_s8(local_acc[i][1], vget_high_s8(lhs[i][1]), + vget_high_s8(rhs[3])); } } - lhs_ptr += 24; - rhs_ptr += 16; - depth -= 2; + for (int i = 0; i < 4; i++) { + acc[i][0] = vpadalq_s16(acc[i][0], local_acc[i][0]); + acc[i][1] = vpadalq_s16(acc[i][1], local_acc[i][1]); + } + + lhs_ptr += 64 * 4; + rhs_ptr += 64 * 2; } + for (; d <= depth - 16; d += 16) { + int8x16_t lhs[4]; + for (int i = 0; i < 4; i++) { + lhs[i] = vld1q_s8(lhs_ptr + 16 * i); + } + int8x16_t rhs[2]; + for (int i = 0; i < 2; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i); + } - // Store accumulators. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 8; j++) { - __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + int16x8_t local_acc = + vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j])); + local_acc = + vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j])); + acc[i][j] = vpadalq_s16(acc[i][j], local_acc); + } } + lhs_ptr += 16 * 4; + rhs_ptr += 16 * 2; + } + for (int i = 0; i < 2; i++) { + int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]); + int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]); + int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1); + int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i); + dst_val = vaddq_s32(dst_val, acc_4x); + vst1q_s32(accum_ptr + 4 * i, dst_val); } } }; -// Assembly implementation of the above -// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics. -// Using 32x32=32 multiplications. -// 32 MSA regs used: -// - 24 accumulators -// - 6 lhs -// - 1 rhs -// - 1 temps/zeroes -// ~95 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { - typedef std::uint8_t OperandType; - typedef std::uint32_t AccumulatorType; - typedef KernelFormat< - KernelSideFormat, 3>, - KernelSideFormat, 2> > - Format; - static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, - AccumulatorType* accum_ptr, int depth) { - asm volatile( - // Load accumulators - "ld.w $w0, (0*16)(%[accum_ptr])\n" - "ld.w $w4, (1*16)(%[accum_ptr])\n" - "ld.w $w8, (2*16)(%[accum_ptr])\n" - "ld.w $w1, (3*16)(%[accum_ptr])\n" - "ld.w $w5, (4*16)(%[accum_ptr])\n" - "ld.w $w9, (5*16)(%[accum_ptr])\n" - "ld.w $w2, (6*16)(%[accum_ptr])\n" - "ld.w $w6, (7*16)(%[accum_ptr])\n" - "ld.w $w10, (8*16)(%[accum_ptr])\n" - "ld.w $w3, (9*16)(%[accum_ptr])\n" - "ld.w $w7, (10*16)(%[accum_ptr])\n" - "ld.w $w11, (11*16)(%[accum_ptr])\n" - "ld.w $w12, (12*16)(%[accum_ptr])\n" - "ld.w $w16, (13*16)(%[accum_ptr])\n" - "ld.w $w20, (14*16)(%[accum_ptr])\n" - "ld.w $w13, (15*16)(%[accum_ptr])\n" - "ld.w $w17, (16*16)(%[accum_ptr])\n" - "ld.w $w21, (17*16)(%[accum_ptr])\n" - "ld.w $w14, (18*16)(%[accum_ptr])\n" - "ld.w $w18, (19*16)(%[accum_ptr])\n" - "ld.w $w22, (20*16)(%[accum_ptr])\n" - "ld.w $w15, (21*16)(%[accum_ptr])\n" - "ld.w $w19, (22*16)(%[accum_ptr])\n" - "ld.w $w23, (23*16)(%[accum_ptr])\n" - // Set a temp to all zeroes. - "ldi.b $w31, 0\n" - - GEMMLOWP_LABEL_LOOP ":\n" - // Overview of register layout: - // - // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30. - // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29. - // A 12x8 block of accumulators is stored in 32bit in w0-w23. - // - // +------+------+------+------+ - // Rhs |w30[0]|w30[1]|w30[2]|w30[3]| - // +------+------+------+------+ - // - // | | | | | - // - // Lhs | | | | | - // - // +---+---+ - - - - +------+------+------+------+ - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | - // +---+---+ - - - - +------+------+------+------+ - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | - // +---+---+ - - - - +------+------+------+------+ - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| - // +---+---+ - - - - +------+------+------+------+ - // - // Accumulator +SET_7BIT_RANGES(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics); - // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. - "ld.b $w24, 0(%[lhs_ptr])\n" - "ld.b $w25, 8(%[lhs_ptr])\n" +// C++ intrinsics-based variant of the deep, 4.25-bit, fast kernel. +struct NEON_64bit_GEMM_Int425Operands_intrinsics { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 1>, + KernelSideFormat, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + int32x4_t acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vdupq_n_s32(0); + } + } - // Load 4 bytes of rhs[] for the first half of depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" + const int num_outer_depth_loop = depth / 512 + 1; + int d = 0; + for (int od = 0; od < num_outer_depth_loop; od++) { + int16x8_t local_acc[4][2]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + local_acc[i][j] = vdupq_n_s16(0); + } + } + for (int k = 0; k < 16 && d <= depth - 32; k++, d += 32) { + int8x16_t lhs[8]; + for (int i = 0; i < 8; i++) { + lhs[i] = vld1q_s8(lhs_ptr + 16 * i); + } + int8x16_t rhs[4]; + for (int i = 0; i < 4; i++) { + rhs[i] = vld1q_s8(rhs_ptr + 16 * i); + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + int8x16_t temp_acc = vmulq_s8(lhs[i * 2], rhs[j * 2]); + temp_acc = vmlaq_s8(temp_acc, lhs[i * 2 + 1], rhs[j * 2 + 1]); + local_acc[i][j] = vpadalq_s8(local_acc[i][j], temp_acc); + } + } + lhs_ptr += 128; + rhs_ptr += 64; + } - // Zero-extend 8-bit elements of lhs[] to 16 bits. - "ilvr.b $w24, $w31, $w24\n" - "ilvl.b $w26, $w31, $w25\n" - "ilvr.b $w25, $w31, $w25\n" - // Zero-extend 16-bit elements of lhs[] to 32 bits. - "ilvl.h $w27, $w31, $w24\n" - "ilvl.h $w28, $w31, $w25\n" - "ilvl.h $w29, $w31, $w26\n" - "ilvr.h $w24, $w31, $w24\n" - "ilvr.h $w25, $w31, $w25\n" - "ilvr.h $w26, $w31, $w26\n" - - // Depth 0. - "fill.w $w30, $a0\n" - "lbu $a0, 8(%[rhs_ptr])\n" - "maddv.w $w0, $w24, $w30\n" - "maddv.w $w4, $w25, $w30\n" - "maddv.w $w8, $w26, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 9(%[rhs_ptr])\n" - "maddv.w $w1, $w24, $w30\n" - "maddv.w $w5, $w25, $w30\n" - "maddv.w $w9, $w26, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 10(%[rhs_ptr])\n" - "maddv.w $w2, $w24, $w30\n" - "maddv.w $w6, $w25, $w30\n" - "maddv.w $w10, $w26, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 11(%[rhs_ptr])\n" - "maddv.w $w3, $w24, $w30\n" - "maddv.w $w7, $w25, $w30\n" - "maddv.w $w11, $w26, $w30\n" - - "fill.w $w30, $a0\n" - "lbu $a0, 4(%[rhs_ptr])\n" - "maddv.w $w12, $w24, $w30\n" - "maddv.w $w16, $w25, $w30\n" - "maddv.w $w20, $w26, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 5(%[rhs_ptr])\n" - "maddv.w $w13, $w24, $w30\n" - "maddv.w $w17, $w25, $w30\n" - "maddv.w $w21, $w26, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 6(%[rhs_ptr])\n" - "maddv.w $w14, $w24, $w30\n" - "maddv.w $w18, $w25, $w30\n" - "maddv.w $w22, $w26, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 7(%[rhs_ptr])\n" - "maddv.w $w15, $w24, $w30\n" - "maddv.w $w19, $w25, $w30\n" - "maddv.w $w23, $w26, $w30\n" - - // Depth 1. - "fill.w $w30, $a0\n" - "lbu $a0, 12(%[rhs_ptr])\n" - "maddv.w $w0, $w27, $w30\n" - "maddv.w $w4, $w28, $w30\n" - "maddv.w $w8, $w29, $w30\n" - "fill.w $w30, $a1\n" - "lbu $a1, 13(%[rhs_ptr])\n" - "maddv.w $w1, $w27, $w30\n" - "maddv.w $w5, $w28, $w30\n" - "maddv.w $w9, $w29, $w30\n" - "fill.w $w30, $a2\n" - "lbu $a2, 14(%[rhs_ptr])\n" - "maddv.w $w2, $w27, $w30\n" - "maddv.w $w6, $w28, $w30\n" - "maddv.w $w10, $w29, $w30\n" - "fill.w $w30, $a3\n" - "lbu $a3, 15(%[rhs_ptr])\n" - "maddv.w $w3, $w27, $w30\n" - "maddv.w $w7, $w28, $w30\n" - "maddv.w $w11, $w29, $w30\n" - - "fill.w $w30, $a0\n" - "maddv.w $w12, $w27, $w30\n" - "maddv.w $w16, $w28, $w30\n" - "maddv.w $w20, $w29, $w30\n" - "fill.w $w30, $a1\n" - "maddv.w $w13, $w27, $w30\n" - "maddv.w $w17, $w28, $w30\n" - "maddv.w $w21, $w29, $w30\n" - "fill.w $w30, $a2\n" - "maddv.w $w14, $w27, $w30\n" - "maddv.w $w18, $w28, $w30\n" - "maddv.w $w22, $w29, $w30\n" - "fill.w $w30, $a3\n" - "maddv.w $w15, $w27, $w30\n" - "maddv.w $w19, $w28, $w30\n" - "maddv.w $w23, $w29, $w30\n" + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + acc[i][j] = vpadalq_s16(acc[i][j], local_acc[i][j]); + } + } + } - "addiu %[depth], -2\n" - GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" - GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" - "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + for (int i = 0; i < 2; i++) { + int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]); + int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]); + int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1); - // Store accumulators. - "st.w $w0, (0*16)(%[accum_ptr])\n" - "st.w $w4, (1*16)(%[accum_ptr])\n" - "st.w $w8, (2*16)(%[accum_ptr])\n" - "st.w $w1, (3*16)(%[accum_ptr])\n" - "st.w $w5, (4*16)(%[accum_ptr])\n" - "st.w $w9, (5*16)(%[accum_ptr])\n" - "st.w $w2, (6*16)(%[accum_ptr])\n" - "st.w $w6, (7*16)(%[accum_ptr])\n" - "st.w $w10, (8*16)(%[accum_ptr])\n" - "st.w $w3, (9*16)(%[accum_ptr])\n" - "st.w $w7, (10*16)(%[accum_ptr])\n" - "st.w $w11, (11*16)(%[accum_ptr])\n" - "st.w $w12, (12*16)(%[accum_ptr])\n" - "st.w $w16, (13*16)(%[accum_ptr])\n" - "st.w $w20, (14*16)(%[accum_ptr])\n" - "st.w $w13, (15*16)(%[accum_ptr])\n" - "st.w $w17, (16*16)(%[accum_ptr])\n" - "st.w $w21, (17*16)(%[accum_ptr])\n" - "st.w $w14, (18*16)(%[accum_ptr])\n" - "st.w $w18, (19*16)(%[accum_ptr])\n" - "st.w $w22, (20*16)(%[accum_ptr])\n" - "st.w $w15, (21*16)(%[accum_ptr])\n" - "st.w $w19, (22*16)(%[accum_ptr])\n" - "st.w $w23, (23*16)(%[accum_ptr])\n" - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [depth] "+r"(depth) - : // inputs - [accum_ptr] "r"(accum_ptr) - : // clobbers - "memory", - "a0", "a1", "a2", "a3", - "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", - "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", - "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", - "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); + int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i); + dst_val = vaddq_s32(dst_val, acc_4x); + vst1q_s32(accum_ptr + 4 * i, dst_val); + } } }; -// Assembly implementation of the above -// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). -// Using 16x16=32 multiplications. -// 32 MSA regs used: -// - 24 accumulators -// - 3 lhs -// - 4 rhs -// - 1 temps/zeroes -// ~70 instructions in the loop. -struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { +SET_425BIT_RANGES(NEON_64bit_GEMM_Int425Operands_intrinsics); + +#endif // __arm__ || __aarch64__ + +#ifdef __mips +// 12x8 depth 2 depth-major kernel. +struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1 { typedef std::uint8_t OperandType; typedef std::uint32_t AccumulatorType; typedef KernelFormat< @@ -4819,6 +5119,476 @@ struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); } }; + +// 12x8 depth 2 width-major kernel. +// Does less shuffling and replication than the kernel above. +struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2 { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 3>, + KernelSideFormat, 2> > + Format; + static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "ld.w $w0, (0*16)(%[accum_ptr])\n" + "ld.w $w4, (1*16)(%[accum_ptr])\n" + "ld.w $w8, (2*16)(%[accum_ptr])\n" + "ld.w $w1, (3*16)(%[accum_ptr])\n" + "ld.w $w5, (4*16)(%[accum_ptr])\n" + "ld.w $w9, (5*16)(%[accum_ptr])\n" + "ld.w $w2, (6*16)(%[accum_ptr])\n" + "ld.w $w6, (7*16)(%[accum_ptr])\n" + "ld.w $w10, (8*16)(%[accum_ptr])\n" + "ld.w $w3, (9*16)(%[accum_ptr])\n" + "ld.w $w7, (10*16)(%[accum_ptr])\n" + "ld.w $w11, (11*16)(%[accum_ptr])\n" + "ld.w $w12, (12*16)(%[accum_ptr])\n" + "ld.w $w16, (13*16)(%[accum_ptr])\n" + "ld.w $w20, (14*16)(%[accum_ptr])\n" + "ld.w $w13, (15*16)(%[accum_ptr])\n" + "ld.w $w17, (16*16)(%[accum_ptr])\n" + "ld.w $w21, (17*16)(%[accum_ptr])\n" + "ld.w $w14, (18*16)(%[accum_ptr])\n" + "ld.w $w18, (19*16)(%[accum_ptr])\n" + "ld.w $w22, (20*16)(%[accum_ptr])\n" + "ld.w $w15, (21*16)(%[accum_ptr])\n" + "ld.w $w19, (22*16)(%[accum_ptr])\n" + "ld.w $w23, (23*16)(%[accum_ptr])\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + // Overview of register layout: + // + // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31 + // (each register contains 4 replicas of a pair of elements). + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. + // A 12x8 block of accumulators is stored in 32bit in w0-w23. + // + // +------+------+------+------+ + // Rhs |w28 |w29 |w30 |w31 | + // +------+------+------+------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+ - - - - +------+------+------+------+ + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // +---+ - - - - +------+------+------+------+ + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // +---+ - - - - +------+------+------+------+ + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // +---+ - - - - +------+------+------+------+ + // + // Accumulators + + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + "ld.b $w24, 0(%[lhs_ptr])\n" + "ld.b $w25, 8(%[lhs_ptr])\n" + + // Load 2 x 8 bytes of rhs[]. + "ld.b $w27, 0(%[rhs_ptr])\n" + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ldi.b $w31, 0\n" + "ilvr.b $w24, $w31, $w24\n" + "ilvl.b $w26, $w31, $w25\n" + "ilvr.b $w25, $w31, $w25\n" + + // First half of depths 0 and 1. + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ilvr.b $w31, $w31, $w27\n" + // Make 4 replicas of every pair of rhs[] elements. + "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w0, $w24, $w28\n" + "dpadd_u.w $w4, $w25, $w28\n" + "dpadd_u.w $w8, $w26, $w28\n" + "dpadd_u.w $w1, $w24, $w29\n" + "dpadd_u.w $w5, $w25, $w29\n" + "dpadd_u.w $w9, $w26, $w29\n" + "dpadd_u.w $w2, $w24, $w30\n" + "dpadd_u.w $w6, $w25, $w30\n" + "dpadd_u.w $w10, $w26, $w30\n" + "dpadd_u.w $w3, $w24, $w31\n" + "dpadd_u.w $w7, $w25, $w31\n" + "dpadd_u.w $w11, $w26, $w31\n" + + // Second half of depths 0 and 1. + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ldi.b $w31, 0\n" + "ilvl.b $w31, $w31, $w27\n" + // Make 4 replicas of every pair of rhs[] elements. + "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w12, $w24, $w28\n" + "dpadd_u.w $w16, $w25, $w28\n" + "dpadd_u.w $w20, $w26, $w28\n" + "dpadd_u.w $w13, $w24, $w29\n" + "dpadd_u.w $w17, $w25, $w29\n" + "dpadd_u.w $w21, $w26, $w29\n" + "dpadd_u.w $w14, $w24, $w30\n" + "dpadd_u.w $w18, $w25, $w30\n" + "dpadd_u.w $w22, $w26, $w30\n" + "dpadd_u.w $w15, $w24, $w31\n" + "dpadd_u.w $w19, $w25, $w31\n" + "dpadd_u.w $w23, $w26, $w31\n" + + "addiu %[depth], -2\n" GEMMLOWP_MIPS_XADDIU + " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU + " %[rhs_ptr], 16\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP + "b\n" + + // Store accumulators. + "st.w $w0, (0*16)(%[accum_ptr])\n" + "st.w $w4, (1*16)(%[accum_ptr])\n" + "st.w $w8, (2*16)(%[accum_ptr])\n" + "st.w $w1, (3*16)(%[accum_ptr])\n" + "st.w $w5, (4*16)(%[accum_ptr])\n" + "st.w $w9, (5*16)(%[accum_ptr])\n" + "st.w $w2, (6*16)(%[accum_ptr])\n" + "st.w $w6, (7*16)(%[accum_ptr])\n" + "st.w $w10, (8*16)(%[accum_ptr])\n" + "st.w $w3, (9*16)(%[accum_ptr])\n" + "st.w $w7, (10*16)(%[accum_ptr])\n" + "st.w $w11, (11*16)(%[accum_ptr])\n" + "st.w $w12, (12*16)(%[accum_ptr])\n" + "st.w $w16, (13*16)(%[accum_ptr])\n" + "st.w $w20, (14*16)(%[accum_ptr])\n" + "st.w $w13, (15*16)(%[accum_ptr])\n" + "st.w $w17, (16*16)(%[accum_ptr])\n" + "st.w $w21, (17*16)(%[accum_ptr])\n" + "st.w $w14, (18*16)(%[accum_ptr])\n" + "st.w $w18, (19*16)(%[accum_ptr])\n" + "st.w $w22, (20*16)(%[accum_ptr])\n" + "st.w $w15, (21*16)(%[accum_ptr])\n" + "st.w $w19, (22*16)(%[accum_ptr])\n" + "st.w $w23, (23*16)(%[accum_ptr])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", + "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", + "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", + "$f27", "$f28", "$f29", "$f30", "$f31"); + } +}; + +// 4x4 depth 16 width-major kernel operating on int8 operands. +// It is assumed that one of the two int8 operands only takes values +// in [-127, 127], while the other may freely range in [-128, 127]. +// The issue with both operands taking the value -128 is that: +// -128*-128 + -128*-128 == -32768 overflows int16. +// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 +// range. That is the basic idea of this kernel. +struct MSA_GEMM_Int8Operands_AccumTwoWithin16Bits { + typedef std::int8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat, 1>, + KernelSideFormat, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + std::size_t start_depth = 123; + std::size_t run_depth = depth; + std::size_t dst_col_stride = 4; + AccumulatorType* dst_ptr = accum_ptr; +#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1" +#define GEMMLOWP_LABEL_LOOP "2" +#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" +#define GEMMLOWP_LABEL_STORE "4" + asm volatile( + GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n" + // Load lhs[] and rhs[], zero out internal accumulators. + "ld.b $w16, 0(%[lhs_ptr])\n" + "ldi.b $w0, 0\n" + "ld.b $w20, 0(%[rhs_ptr])\n" + "ldi.b $w1, 0\n" + "ld.b $w17, 16(%[lhs_ptr])\n" + "ldi.b $w2, 0\n" + "ld.b $w21, 16(%[rhs_ptr])\n" + "ldi.b $w3, 0\n" + "ld.b $w18, 32(%[lhs_ptr])\n" + "ldi.b $w4, 0\n" + "ld.b $w19, 48(%[lhs_ptr])\n" + "ldi.b $w5, 0\n" + "ld.b $w22, 32(%[rhs_ptr])\n" + "ldi.b $w6, 0\n" + "ld.b $w23, 48(%[rhs_ptr])\n" + "ldi.b $w7, 0\n" + "ldi.b $w8, 0\n" + "ldi.b $w9, 0\n" + "ldi.b $w10, 0\n" + "ldi.b $w11, 0\n" + "ldi.b $w12, 0\n" + "ldi.b $w13, 0\n" + "ldi.b $w14, 0\n" + "ldi.b $w15, 0\n" + "ldi.h $w31, 1\n" + // If the loop depth is only 16, then we can skip the general loop + // and go straight to the final part of the code. + "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A 4x16 block of Rhs is stored in 8 bit in w16-w19. + // A 4x16 block of Lhs is stored in 8 bit in w20-w23. + // + // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit + // components which need to be horizontally added at the end). + // + // Dot products of Lhs and Rhs are 16-bit values, which can't + // immediately be accumulated in 32-bit accumulators by that + // same instruction that calculates them. + // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit + // sums in w25 (note, the 16 sums have already been reduced to 8 + // by the horizontal addition of the dotp instruction). + // They are then sign-extended to 32 bits, horizontally added + // (again) to form 4 32-bit sums and then they are finally added + // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31". + // + // +-----+-----+-----+-----+ + // Rhs | w20 | w21 | w22 | w23 | + // +-----+-----+-----+-----+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+ - - - - +-----+-----+-----+-----+ + // |w16| | w0 | w4 | w8 | w12 | + // |w17| | w1 | w5 | w9 | w13 | + // |w18| | w2 | w6 | w10 | w14 | + // |w19| | w3 | w7 | w11 | w15 | + // +---+ - - - - +-----+-----+-----+-----+ + // + // Accumulators + + // Calculate the results for 16 depths and load + // lhs[] and rhs[] for the next iteration. + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n" + GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w16, $w20\n" + "dotp_s.h $w26, $w17, $w20\n" + "dotp_s.h $w27, $w16, $w21\n" + "dotp_s.h $w28, $w17, $w21\n" + "dotp_s.h $w29, $w18, $w20\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w0, $w25, $w31\n" + "dpadd_s.w $w1, $w26, $w31\n" + "dpadd_s.w $w4, $w27, $w31\n" + "dpadd_s.w $w5, $w28, $w31\n" + "dpadd_s.w $w2, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w24, $w16, $w22\n" + "dotp_s.h $w25, $w19, $w20\n" + "dotp_s.h $w26, $w16, $w23\n" + "dotp_s.h $w27, $w17, $w22\n" + "ld.b $w20, 0(%[rhs_ptr])\n" + "dotp_s.h $w28, $w17, $w23\n" + "ld.b $w16, 0(%[lhs_ptr])\n" + "dotp_s.h $w29, $w18, $w21\n" + "ld.b $w17, 16(%[lhs_ptr])\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "ld.b $w21, 16(%[rhs_ptr])\n" + "dotp_s.h $w28, $w19, $w22\n" + "ld.b $w18, 32(%[lhs_ptr])\n" + "dotp_s.h $w29, $w19, $w23\n" + "ld.b $w22, 32(%[rhs_ptr])\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "ld.b $w19, 48(%[lhs_ptr])\n" + "dpadd_s.w $w10, $w26, $w31\n" + "ld.b $w23, 48(%[rhs_ptr])\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n" + // Calculate the results for the last 16 depths. + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w16, $w20\n" + "dotp_s.h $w26, $w17, $w20\n" + "dotp_s.h $w27, $w16, $w21\n" + "dotp_s.h $w28, $w17, $w21\n" + "dotp_s.h $w29, $w18, $w20\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w0, $w25, $w31\n" + "dpadd_s.w $w1, $w26, $w31\n" + "dpadd_s.w $w4, $w27, $w31\n" + "dpadd_s.w $w5, $w28, $w31\n" + "dpadd_s.w $w2, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w24, $w16, $w22\n" + "dotp_s.h $w25, $w19, $w20\n" + "dotp_s.h $w26, $w16, $w23\n" + "dotp_s.h $w27, $w17, $w22\n" + "dotp_s.h $w28, $w17, $w23\n" + "dotp_s.h $w29, $w18, $w21\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "dotp_s.h $w28, $w19, $w22\n" + "dotp_s.h $w29, $w19, $w23\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "dpadd_s.w $w10, $w26, $w31\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + // Horizontal-add internal accumulators. + "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w1, $w1, $w1\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w3, $w3, $w3\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w5, $w5, $w5\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w7, $w7, $w7\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w9, $w9, $w9\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w11, $w11, $w11\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w13, $w13, $w13\n" + "hadd_s.d $w14, $w14, $w14\n" + "hadd_s.d $w15, $w15, $w15\n" + "pckev.w $w0, $w1, $w0\n" + "pckev.w $w2, $w3, $w2\n" + "pckev.w $w4, $w5, $w4\n" + "pckev.w $w6, $w7, $w6\n" + "pckev.w $w8, $w9, $w8\n" + "pckev.w $w10, $w11, $w10\n" + "pckev.w $w12, $w13, $w12\n" + "pckev.w $w14, $w15, $w14\n" + "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w14, $w14, $w14\n" + // 4 more pckev instructions follow in both paths below. + + // Check if start_depth==0 to decide whether we will load + // existing accumulators from memory. + "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n" + + "pckev.w $w0, $w2, $w0\n" + "pckev.w $w1, $w6, $w4\n" + "pckev.w $w2, $w10, $w8\n" + "pckev.w $w3, $w14, $w12\n" + + "b " GEMMLOWP_LABEL_STORE "f\n" + + GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n" + // Load accumulators from memory. + "ld.w $w16, 0(%[dst_ptr0])\n" + "pckev.w $w0, $w2, $w0\n" + "ld.w $w17, 0(%[dst_ptr1])\n" + "pckev.w $w1, $w6, $w4\n" + "ld.w $w18, 0(%[dst_ptr2])\n" + "pckev.w $w2, $w10, $w8\n" + "ld.w $w19, 0(%[dst_ptr3])\n" + "pckev.w $w3, $w14, $w12\n" + + // Add them to internal accumulators. + "addv.w $w0, $w0, $w16\n" + "addv.w $w1, $w1, $w17\n" + "addv.w $w2, $w2, $w18\n" + "addv.w $w3, $w3, $w19\n" + + GEMMLOWP_LABEL_STORE ":\n" + // Store accumulators. + "st.w $w0, 0(%[dst_ptr0])\n" + "st.w $w1, 0(%[dst_ptr1])\n" + "st.w $w2, 0(%[dst_ptr2])\n" + "st.w $w3, 0(%[dst_ptr3])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride), + [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2), + [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3), + [start_depth] "r"(start_depth) + : // clobbers + "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", + "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", + "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", + "$f27", "$f28", "$f29", "$f30", "$f31"); +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16 +#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES +#undef GEMMLOWP_LABEL_STORE + } +}; #endif // __mips // BEGIN code copied from gemmlowp/internal/kernel_reference.h @@ -4901,13 +5671,10 @@ class CacheLineAlignedBuffer { }; template -void FillRandom(CacheLineAlignedBuffer* buffer) { +void FillRandom(CacheLineAlignedBuffer* buffer, DataType min, + DataType max) { static std::mt19937 generator(0); - // 100 is smaller than any nonzero bound of the range of any data type. - const DataType kMaxVal = DataType(100); - const DataType kMinVal = - std::is_signed::value ? -kMaxVal : DataType(0); - std::uniform_real_distribution dist(kMinVal, kMaxVal); + std::uniform_real_distribution dist(min, max); for (std::size_t i = 0; i < buffer->size(); i++) { buffer->data()[i] = DataType(dist(generator)); } @@ -4971,9 +5738,16 @@ void test_kernel(int depth, const char* kernel_name) { CacheLineAlignedBuffer accum_reference(kLhsWidth * kRhsWidth); - FillRandom(&lhs); - FillRandom(&rhs); - FillRandom(&accum_initial); + FillRandom(&lhs, KernelOperandRanges::LhsMin(), + KernelOperandRanges::LhsMax()); + FillRandom(&rhs, KernelOperandRanges::RhsMin(), + KernelOperandRanges::RhsMax()); + FillRandom(&accum_initial, + std::is_signed::value + ? AccumulatorType(-100) + : AccumulatorType(0), + AccumulatorType(100)); + Copy(&accum, accum_initial); Copy(&accum_reference, accum_initial); @@ -5159,6 +5933,10 @@ int main() { #endif #ifdef __aarch64__ + BENCHMARK(NEON_64bit_GEMM_Int425Operands); + BENCHMARK(NEON_64bit_GEMM_Int425Operands_intrinsics); + BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits); + BENCHMARK(NEON_64bit_GEMM_Int7Operands_AccumEightWithin16Bits_intrinsics); BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits); BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators); @@ -5167,6 +5945,7 @@ int main() { #ifdef __ARM_FEATURE_DOTPROD BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1); + BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_narrow); #endif BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar); BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar); @@ -5180,12 +5959,9 @@ int main() { #endif #ifdef __mips - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics); - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly); - BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly); - BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators1); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators2); + BENCHMARK(MSA_GEMM_Int8Operands_AccumTwoWithin16Bits); #endif return 0; diff --git a/test/benchmark.cc b/test/benchmark.cc index 9a87a41..d8236de 100644 --- a/test/benchmark.cc +++ b/test/benchmark.cc @@ -36,7 +36,16 @@ #warning "Building without NEON support on ARM, check your compiler setup!" #endif -#if defined(__SSE4_2__) && !defined(GEMMLOWP_SSE4) +#if defined(__mips) && !defined(GEMMLOWP_MSA) +#warning "Building without MSA support on MIPS, check your compiler setup!" +#endif + +#if defined(__AVX2__) && !defined(GEMMLOWP_AVX2) +#warning \ + "Building without AVX2 support on AVX2 enabled machine, check your compiler setup!" +#endif + +#if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4) #warning \ "Building without SSE4.2 support on SSE4.2 enabled machine, check your compiler setup!" #endif diff --git a/test/benchmark_all_sizes.cc b/test/benchmark_all_sizes.cc index 16cc57c..527aad6 100644 --- a/test/benchmark_all_sizes.cc +++ b/test/benchmark_all_sizes.cc @@ -16,6 +16,10 @@ test/benchmark_all_sizes.cc -o /tmp/b -O3 --std=c++11 -fPIE -static \ #include "../public/gemmlowp.h" +#ifdef GEMMLOWP_PROFILING +#include "../profiling/profiler.h" +#endif + #if defined GEMMLOWP_ANDROID && defined GEMMLOWP_ARM_32 // Compilation workaround namespace std { @@ -122,10 +126,10 @@ float benchmark_8bit(int rows, int depth, int cols) { MakeZero(&rhs); MakeZero(&result); - typedef std::tuple Pipeline; - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; quantize_down_stage.result_offset_after_shift = 128; quantize_down_stage.result_fixedpoint_multiplier = 1234567890; @@ -345,7 +349,18 @@ void run_benchmarks(std::map* results) { int main() { std::map results; + +#ifdef GEMMLOWP_PROFILING + gemmlowp::RegisterCurrentThreadForProfiling(); + gemmlowp::StartProfiling(); +#endif + run_benchmarks(&results); + +#ifdef GEMMLOWP_PROFILING + gemmlowp::FinishProfiling(); +#endif + printf("Using %d thread(s)\n", kNumThreads); printf("depth,rows,cols,latency(s),Gop/s\n"); for (const auto& result : results) { diff --git a/test/test.cc b/test/test.cc index eee16b4..735ad1e 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1277,6 +1277,47 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, } } + // Test a variant of the familiar default pipeline consisting of quantize-down + // and clamp-and-cast-to-int16. + OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto quantize_down_and_saturating_cast_int16_pipeline = + std::make_tuple(quantize_down_stage, saturating_cast_int16_stage); + Matrix result_quantized_down_saturated_int16(rows, + cols); + GemmWithOutputPipeline( + &context, lhs.const_map(), rhs.const_map(), + &result_quantized_down_saturated_int16, lhs_offset, rhs_offset, + quantize_down_and_saturating_cast_int16_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + std::int32_t quantized = result_quantized_down_int32(r, c); + std::int16_t expected = std::min(std::max(quantized, -32768), 32767); + Check(expected == result_quantized_down_saturated_int16(r, c)); + } + } + +#ifdef GEMMLOWP_MSA + // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8. + OutputStageTruncatingCastToUint8 truncating_cast_stage; + auto quantize_down_and_truncating_cast_pipeline = + std::make_tuple(quantize_down_stage, truncating_cast_stage); + Matrix result_quantized_down_truncated_uint8( + rows, cols); + GemmWithOutputPipeline( + &context, lhs.const_map(), rhs.const_map(), + &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset, + quantize_down_and_truncating_cast_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + std::int32_t quantized = result_quantized_down_int32(r, c); + std::uint8_t expected = quantized & 255; + Check(expected == result_quantized_down_truncated_uint8(r, c)); + } + } +#endif + // Test a bias-addition with row-vector std::vector row_vector_data(cols); std::uniform_int_distribution uniform_minus_500_plus_500(-500, @@ -1428,8 +1469,8 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, result_fixedpoint_shift++; } Check(result_fixedpoint_shift >= 0); - // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint - OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + // Now test OutputStageQuantizeDownInt32ByFixedPoint + OutputStageQuantizeDownInt32ByFixedPoint quantize_down_by_fixedpoint_stage; quantize_down_by_fixedpoint_stage.result_offset_after_shift = static_cast( @@ -1447,7 +1488,6 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset, quantize_down_by_fixedpoint_pipeline); - std::vector diffs_caused_by_fixedpoint; for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { const std::int32_t actual = @@ -1462,6 +1502,44 @@ void TestOutputStages(int rows, int depth, int cols, int result_offset, } } + // Test OutputStageScaleInt32ByFixedPointAndExponent + for (int exponent = -2; exponent <= 2; exponent++) { + OutputStageScaleInt32ByFixedPointAndExponent + scale_by_fixedpoint_and_exponent_stage; + scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift = + static_cast(round(static_cast( + result_offset * result_mult_int * std::pow(2.0, exponent)))); + scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier = + result_fixedpoint_multiplier; + scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent; + auto scale_by_fixedpoint_and_exponent_pipeline = + std::make_tuple(scale_by_fixedpoint_and_exponent_stage); + Matrix + result_scaled_by_fixedpoint_and_exponent_int32(rows, cols); + GemmWithOutputPipeline( + &context, lhs.const_map(), rhs.const_map(), + &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset, + scale_by_fixedpoint_and_exponent_pipeline); + + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + const std::int32_t actual = + result_scaled_by_fixedpoint_and_exponent_int32(r, c); + const std::int32_t raw = result_raw_int32(r, c); + int left_shift = std::max(0, exponent); + int right_shift = std::max(0, -exponent); + const std::int32_t expected = + scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift + + RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul((1 << left_shift) * raw, + result_fixedpoint_multiplier), + right_shift); + Check(actual == expected); + } + } + } + // Test the variant of the familiar default pipeline consisting of // quantize-down and // clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the diff --git a/test/test.h b/test/test.h index aecd0c1..b381bad 100644 --- a/test/test.h +++ b/test/test.h @@ -49,7 +49,7 @@ class Matrix : public MatrixMap { typedef MatrixMap Map; typedef MatrixMap ConstMap; typedef typename Map::Scalar Scalar; - static const MapOrder Order = tOrder; + static constexpr MapOrder Order = tOrder; using Map::kOrder; using Map::rows_; using Map::cols_; @@ -92,12 +92,12 @@ class Matrix : public MatrixMap { std::vector storage; }; -std::mt19937& RandomEngine() { +inline std::mt19937& RandomEngine() { static std::mt19937 engine; return engine; } -int Random() { +inline int Random() { std::uniform_int_distribution dist(0, std::numeric_limits::max()); return dist(RandomEngine()); } diff --git a/test/test_blocking_counter.cc b/test/test_blocking_counter.cc index d1e0932..34d963d 100644 --- a/test/test_blocking_counter.cc +++ b/test/test_blocking_counter.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test.h" -#include "../profiling/pthread_everywhere.h" - +#include // NOLINT #include +#include +#include #include "../internal/multi_thread_gemm.h" +#include "../profiling/pthread_everywhere.h" +#include "test.h" namespace gemmlowp { @@ -26,16 +28,36 @@ class Thread { Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) : blocking_counter_(blocking_counter), number_of_times_to_decrement_(number_of_times_to_decrement), - finished_(false), - made_the_last_decrement_(false) { + made_the_last_decrement_(false), + finished_(false) { +#if defined GEMMLOWP_USE_PTHREAD + // Limit the stack size so as not to deplete memory when creating + // many threads. + pthread_attr_t attr; + int err = pthread_attr_init(&attr); + if (!err) { + size_t stack_size; + err = pthread_attr_getstacksize(&attr, &stack_size); + if (!err && stack_size > max_stack_size_) { + err = pthread_attr_setstacksize(&attr, max_stack_size_); + } + if (!err) { + err = pthread_create(&thread_, &attr, ThreadFunc, this); + } + } + if (err) { + std::cerr << "Failed to create a thread.\n"; + std::abort(); + } +#else pthread_create(&thread_, nullptr, ThreadFunc, this); +#endif } ~Thread() { Join(); } - bool Join() const { - if (!finished_) { - pthread_join(thread_, nullptr); + bool Join() { + while (!finished_.load()) { } return made_the_last_decrement_; } @@ -48,7 +70,7 @@ class Thread { Check(!made_the_last_decrement_); made_the_last_decrement_ = blocking_counter_->DecrementCount(); } - finished_ = true; + finished_.store(true); } static void* ThreadFunc(void* ptr) { @@ -56,11 +78,18 @@ class Thread { return nullptr; } + static constexpr size_t max_stack_size_ = 256 * 1024; BlockingCounter* const blocking_counter_; const int number_of_times_to_decrement_; pthread_t thread_; - bool finished_; bool made_the_last_decrement_; + // finished_ is used to manually implement Join() by busy-waiting. + // I wanted to use pthread_join / std::thread::join, but the behavior + // observed on Android was that pthread_join aborts when the thread has + // already joined before calling pthread_join, making that hard to use. + // It appeared simplest to just implement this simple spinlock, and that + // is good enough as this is just a test. + std::atomic finished_; }; void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, @@ -89,10 +118,10 @@ void test_blocking_counter() { // repeating the entire test sequence ensures that we test // non-monotonic changes. for (int repeat = 1; repeat <= 2; repeat++) { - for (int num_threads = 1; num_threads <= 16; num_threads++) { + for (int num_threads = 1; num_threads <= 5; num_threads++) { for (int num_decrements_per_thread = 1; - num_decrements_per_thread <= 64 * 1024; - num_decrements_per_thread *= 4) { + num_decrements_per_thread <= 4 * 1024; + num_decrements_per_thread *= 16) { test_blocking_counter(blocking_counter, num_threads, num_decrements_per_thread, num_threads * num_decrements_per_thread); diff --git a/test/test_fixedpoint.cc b/test/test_fixedpoint.cc index da222f0..44e6fae 100644 --- a/test/test_fixedpoint.cc +++ b/test/test_fixedpoint.cc @@ -17,479 +17,587 @@ #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS #include +#include #include +#include #include #include -#include "test.h" #include "../fixedpoint/fixedpoint.h" +#include "test.h" namespace gemmlowp { namespace { -// Explanation of SimdVector type and associated functions -// (LoadSimdVector, StoreSimdVector): -// The fixedpoint stuff being tested here is generic in an underlying -// integer type which may be either scalar (int32_t) or SIMD (e.g. -// NEON int32x4_t). We want to write uniform tests that can test -// both the scalar and SIMD paths. We achieve this by having this -// generic SimdVector abstraction, local to this test. - +template +T Load(const typename FixedPointRawTypeTraits::ScalarRawType* src) { + return *src; +} +template +void Store(typename FixedPointRawTypeTraits::ScalarRawType* dst, T v) { + *dst = v; +} #ifdef GEMMLOWP_NEON -using SimdVector = int32x4_t; -constexpr std::size_t SimdVectorSize = 4; -SimdVector LoadSimdVector(const std::int32_t* src) { return vld1q_s32(src); } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { vst1q_s32(dst, v); } -#elif defined(GEMMLOWP_SSE4) -using SimdVector = __m128i; -constexpr std::size_t SimdVectorSize = 4; -SimdVector LoadSimdVector(const std::int32_t* src) { +template <> +int32x4_t Load(const std::int32_t* src) { + return vld1q_s32(src); +} +template <> +int16x8_t Load(const std::int16_t* src) { + return vld1q_s16(src); +} +template <> +void Store(std::int32_t* dst, int32x4_t v) { + vst1q_s32(dst, v); +} +template <> +void Store(std::int16_t* dst, int16x8_t v) { + vst1q_s16(dst, v); +} +#endif +#ifdef GEMMLOWP_SSE4 +template <> +__m128i Load<__m128i>(const std::int32_t* src) { return _mm_loadu_si128(reinterpret_cast(src)); } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { +template <> +void Store<__m128i>(std::int32_t* dst, __m128i v) { _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v); } -#else -using SimdVector = std::int32_t; -constexpr std::size_t SimdVectorSize = 1; -SimdVector LoadSimdVector(const std::int32_t* src) { return *src; } -void StoreSimdVector(std::int32_t* dst, SimdVector v) { *dst = v; } +template <> +int16x8_m128i Load(const std::int16_t* src) { + return to_int16x8_m128i( + _mm_loadu_si128(reinterpret_cast(src))); +} +template <> +void Store(std::int16_t* dst, int16x8_m128i v) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v.v); +} +#endif +#ifdef GEMMLOWP_MSA +template <> +v4i32 Load(const std::int32_t* src) { + return __builtin_msa_ld_w(const_cast(src), 0); +} +template <> +v8i16 Load(const std::int16_t* src) { + return __builtin_msa_ld_h(const_cast(src), 0); +} +template <> +void Store(std::int32_t* dst, v4i32 v) { + __builtin_msa_st_w(v, dst, 0); +} +template <> +void Store(std::int16_t* dst, v8i16 v) { + __builtin_msa_st_h(v, dst, 0); +} #endif -// Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp: -// Most (though not all) of the fixedpoint functionality being tested -// consists of functions taking one fixedpoint value and returning one -// fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators". -// We factor a lot of testing boilerplate into a common TestUnaryOp function -// taking a "unary op" object that fully describes the function to be tested. -// These objects inherit UnaryOpBase mostly as a means to share some default -// values for some properties. -// -// An important design element here is that the fixed-point values are passed -// around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not -// as higher-level FixedPoint objects. The motivation for this design is 1) to -// avoid having to templatize everything in the tIntegerBits parameter of -// class FixedPoint, and 2) to allow directly testing low-level functions -// operating on raw types (e.g. RoundingDivideByPOT) without needlessly -// requiring -// wrapping raw values in FixedPoint objects. -class UnaryOpBase { - public: - // Min bound of the input range of this op. For example, an op only handling - // nonnegative values would return 0. - std::int32_t MinInput() const { - return std::numeric_limits::min(); - } - // Max bound of the input range of this op. For example, an op only handling - // nonpositive values would return 0. - std::int32_t MaxInput() const { - return std::numeric_limits::max(); - } - // Tolerated difference between actual and reference int32 values. - // Note that the corresponding real-numbers tolerance depends on the number - // of integer bits of the fixed-point representation of the results of this - // op. - // For example, for an op returning fixed-point values with 0 integer bits, - // the correspondence between real-number values and raw values is - // real_number = (2^31) * raw_value. - std::int32_t Tolerance() const { return 0; } -}; +#ifdef GEMMLOWP_AVX2 +template <> +__m256i Load<__m256i>(const std::int32_t* src) { + return _mm256_loadu_si256(reinterpret_cast(src)); +} -// Op wrapping RoundingDivideByPOT -class RoundingDivideByPOTOp final : public UnaryOpBase { - public: - RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {} - std::int32_t ReferenceOp(std::int32_t x) const { - const double d = static_cast(x) / (1ll << exponent_); - return static_cast(std::round(d)); - } - template - tRawType Op(tRawType x) const { - return RoundingDivideByPOT(x, exponent_); - } +template <> +int16x16_m256i Load(const std::int16_t* src) { + return to_int16x16_m256i( + _mm256_loadu_si256(reinterpret_cast(src))); +} - private: - const int exponent_; -}; +template <> +void Store<__m256i>(std::int32_t* dst, __m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} -// Op wrapping SaturatingRoundingMultiplyByPOT -template -class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase { - public: - std::int32_t ReferenceOp(std::int32_t x) const { - const double d = static_cast(x) * std::pow(2., tExponent); - const double clamp_min = std::numeric_limits::min(); - const double clamp_max = std::numeric_limits::max(); - const double clamped = std::min(clamp_max, std::max(clamp_min, d)); - return static_cast(std::round(clamped)); - } - template - tRawType Op(tRawType x) const { - return SaturatingRoundingMultiplyByPOT(x); - } -}; +template <> +void Store(std::int16_t* dst, int16x16_m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v.v); +} +#endif -// Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl -class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final - : public UnaryOpBase { +template +class TestFixedPoint { public: - std::int32_t MinInput() const { return -(1 << 29); } - std::int32_t MaxInput() const { return 0; } - std::int32_t Tolerance() const { return 500; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = std::exp(d); - return F::FromDouble(e).raw(); - } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f); - return e.raw(); - } -}; + using SimdType = tSimdType; + using SimdTypeTraits = FixedPointRawTypeTraits; + using ScalarType = typename SimdTypeTraits::ScalarRawType; + static constexpr int kSimdLanes = SimdTypeTraits::kLanes; + static constexpr int kScalarTypeBits = 8 * sizeof(ScalarType); + + // Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp: + // Most (though not all) of the fixedpoint functionality being tested + // consists of functions taking one fixedpoint value and returning one + // fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators". + // We factor a lot of testing boilerplate into a common TestUnaryOp function + // taking a "unary op" object that fully describes the function to be tested. + // These objects inherit UnaryOpBase mostly as a means to share some default + // values for some properties. + // + // An important design element here is that the fixed-point values are passed + // around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not + // as higher-level FixedPoint objects. The motivation for this design is 1) to + // avoid having to templatize everything in the tIntegerBits parameter of + // class FixedPoint, and 2) to allow directly testing low-level functions + // operating on raw types (e.g. RoundingDivideByPOT) without needlessly + // requiring + // wrapping raw values in FixedPoint objects. + class UnaryOpBase { + public: + // Min bound of the input range of this op. For example, an op only handling + // nonnegative values would return 0. + ScalarType MinInput() const { + return std::numeric_limits::min(); + } + // Max bound of the input range of this op. For example, an op only handling + // nonpositive values would return 0. + ScalarType MaxInput() const { + return std::numeric_limits::max(); + } + // Tolerated difference between actual and reference ScalarType values. + // Note that the corresponding real-numbers tolerance depends on the number + // of integer bits of the fixed-point representation of the results of this + // op. + // For example, for an op returning fixed-point values with 0 integer bits, + // the correspondence between real-number values and raw values is + // real_number = (2^31) * raw_value. + ScalarType Tolerance() const { return 0; } + }; + + // Op wrapping RoundingDivideByPOT + class RoundingDivideByPOTOp final : public UnaryOpBase { + public: + RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {} + ScalarType ReferenceOp(ScalarType x) const { + const double d = static_cast(x) / (1ll << exponent_); + return static_cast(std::round(d)); + } + template + RawType Op(RawType x) const { + return RoundingDivideByPOT(x, exponent_); + } -// Op wrapping exp_on_negative_values -template -class ExpOnNegativeValuesOp final : public UnaryOpBase { - public: - std::int32_t MaxInput() const { return 0; } - std::int32_t Tolerance() const { return 500; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - using F0 = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = std::exp(d); - return F0::FromDouble(e).raw(); - } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - return exp_on_negative_values(f).raw(); + private: + const int exponent_; + }; + + // Op wrapping SaturatingRoundingMultiplyByPOT + template + class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase { + public: + ScalarType ReferenceOp(ScalarType x) const { + const double d = static_cast(x) * std::pow(2., tExponent); + const double clamp_min = std::numeric_limits::min(); + const double clamp_max = std::numeric_limits::max(); + const double clamped = std::min(clamp_max, std::max(clamp_min, d)); + return static_cast(std::round(clamped)); + } + template + RawType Op(RawType x) const { + return SaturatingRoundingMultiplyByPOT(x); + } + }; + + // Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl + class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final + : public UnaryOpBase { + public: + ScalarType MinInput() const { return -(1 << (kScalarTypeBits - 3)); } + ScalarType MaxInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 1; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::exp(d); + return F::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f); + return e.raw(); + } + }; + + // Op wrapping exp_on_negative_values + template + class ExpOnNegativeValuesOp final : public UnaryOpBase { + public: + ScalarType MaxInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 2; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + using F0 = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::exp(d); + return F0::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + return exp_on_negative_values(f).raw(); + } + }; + + // Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1 + class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase { + public: + ScalarType MinInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 12 : 11; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = (1 - d) / (1 + d); + return F::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw(); + } + }; + + // Op wrapping tanh + template + class TanhOp final : public UnaryOpBase { + public: + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 310 : 12; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + using F0 = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = std::tanh(d); + return F0::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + return tanh(f).raw(); + } + }; + + // Op wrapping one_over_one_plus_x_for_x_in_0_1 + class OneOverOnePlusXForXIn01Op final : public UnaryOpBase { + public: + ScalarType MinInput() const { return 0; } + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 6 : 5; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = 1 / (1 + d); + return F::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + return one_over_one_plus_x_for_x_in_0_1(f).raw(); + } + }; + + // Op wrapping logistic + template + class LogisticOp final : public UnaryOpBase { + public: + ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 155 : 6; } + ScalarType ReferenceOp(ScalarType x) const { + using F = FixedPoint; + using F0 = FixedPoint; + const double d = ToDouble(F::FromRaw(x)); + const double e = 1 / (1 + std::exp(-d)); + return F0::FromDouble(e).raw(); + } + template + RawType Op(RawType x) const { + using F = FixedPoint; + const F f = F::FromRaw(x); + return logistic(f).raw(); + } + }; + + // Tests a given op, on a given list of int32 input values. + template + void TestUnaryOp(const tUnaryOpType& unary_op, + const std::vector& testvals) { + Check(0 == (testvals.size() % kSimdLanes)); + for (std::size_t i = 0; i < testvals.size(); i += kSimdLanes) { + // First, clamp input values accoding to the MinInput() and MaxInput() + // bounds returned by the op. + ScalarType input[kSimdLanes] = {0}; + for (std::size_t j = 0; j < kSimdLanes; j++) { + const ScalarType raw_input = testvals[i + j]; + input[j] = std::min(unary_op.MaxInput(), + std::max(unary_op.MinInput(), raw_input)); + } + // Compute reference results and check that the actual results on + // scalar inputs agree with them, to the Tolerance() returned by the op. + ScalarType reference[kSimdLanes] = {0}; + ScalarType actual_scalar[kSimdLanes] = {0}; + for (std::size_t j = 0; j < kSimdLanes; j++) { + reference[j] = unary_op.ReferenceOp(input[j]); + actual_scalar[j] = unary_op.Op(input[j]); + const std::int64_t diff = static_cast(actual_scalar[j]) - + static_cast(reference[j]); + if (std::abs(diff) > unary_op.Tolerance()) { + fprintf(stderr, "abs(diff) (%" PRId64 ") > tolerance (%d)\n", diff, + unary_op.Tolerance()); + } + Check(std::abs(diff) <= unary_op.Tolerance()); + } + // Check that the actual results on SIMD inputs agree *exactly* with the + // actual results on scalar inputs. I.e. SIMD must make absolutely no + // difference + // to the results, regardless of the fact that both scalar and SIMD + // results may differ from the reference results. + ScalarType actual_simd[kSimdLanes] = {0}; + Store(actual_simd, unary_op.Op(Load(input))); + for (std::size_t j = 0; j < kSimdLanes; j++) { + if (actual_simd[j] != actual_scalar[j]) { + fprintf(stderr, "SIMD (%d) != scalar (%d)\n", actual_simd[j], + actual_scalar[j]); + } + Check(actual_simd[j] == actual_scalar[j]); + } + } } -}; -// Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1 -class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase { - public: - std::int32_t MinInput() const { return 0; } - std::int32_t Tolerance() const { return 12; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = (1 - d) / (1 + d); - return F::FromDouble(e).raw(); - } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw(); + template + void test_convert(FixedPoint x) { + typedef FixedPoint F; + F y = F::FromDouble(ToDouble(x)); + Check(y == x); } -}; -// Op wrapping tanh -template -class TanhOp final : public UnaryOpBase { - public: - std::int32_t Tolerance() const { return 310; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - using F0 = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = std::tanh(d); - return F0::FromDouble(e).raw(); - } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - return tanh(f).raw(); + template + void test_Rescale(FixedPoint a) { + FixedPoint actual = Rescale(a); + FixedPoint expected = + FixedPoint::FromDouble(ToDouble(a)); + Check(actual == expected); } -}; -// Op wrapping one_over_one_plus_x_for_x_in_0_1 -class OneOverOnePlusXForXIn01Op final : public UnaryOpBase { - public: - std::int32_t MinInput() const { return 0; } - std::int32_t Tolerance() const { return 6; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = 1 / (1 + d); - return F::FromDouble(e).raw(); - } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - return one_over_one_plus_x_for_x_in_0_1(f).raw(); + template + void test_Rescale(const std::vector& testvals) { + for (auto a : testvals) { + FixedPoint aq; + aq.raw() = a; + test_Rescale(aq); + } } -}; -// Op wrapping logistic -template -class LogisticOp final : public UnaryOpBase { - public: - std::int32_t Tolerance() const { return 155; } - std::int32_t ReferenceOp(std::int32_t x) const { - using F = FixedPoint; - using F0 = FixedPoint; - const double d = ToDouble(F::FromRaw(x)); - const double e = 1 / (1 + std::exp(-d)); - return F0::FromDouble(e).raw(); + template + void test_mul(FixedPoint a, + FixedPoint b) { + static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b; + using ProductFixedPoint = FixedPoint; + ProductFixedPoint ab; + ab = a * b; + double a_double = ToDouble(a); + double b_double = ToDouble(b); + double ab_double = a_double * b_double; + ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double); + std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw()); + Check(std::abs(diff) <= 1); } - template - tRawType Op(tRawType x) const { - using F = FixedPoint; - const F f = F::FromRaw(x); - return logistic(f).raw(); - } -}; -// Tests a given op, on a given list of int32 input values. -template -void TestUnaryOp(const tUnaryOpType& unary_op, - const std::vector& testvals_int32) { - Check(0 == (testvals_int32.size() % SimdVectorSize)); - for (std::size_t i = 0; i < testvals_int32.size(); i += SimdVectorSize) { - // First, clamp input int32 values accoding to the MinInput() and MaxInput() - // bounds returned by the op. - std::int32_t input[SimdVectorSize] = {0}; - for (std::size_t j = 0; j < SimdVectorSize; j++) { - const std::int32_t raw_input = testvals_int32[i + j]; - input[j] = std::min(unary_op.MaxInput(), - std::max(unary_op.MinInput(), raw_input)); - } - // Compute reference results and check that the actual results on - // scalar inputs agree with them, to the Tolerance() returned by the op. - std::int32_t reference[SimdVectorSize] = {0}; - std::int32_t actual_scalar[SimdVectorSize] = {0}; - for (std::size_t j = 0; j < SimdVectorSize; j++) { - reference[j] = unary_op.ReferenceOp(input[j]); - actual_scalar[j] = unary_op.Op(input[j]); - const std::int64_t diff = static_cast(actual_scalar[j]) - - static_cast(reference[j]); - Check(std::abs(diff) <= unary_op.Tolerance()); - } - // Check that the actual results on SIMD inputs agree *exactly* with the - // actual results on scalar inputs. I.e. SIMD must make absolutely no - // difference - // to the results, regardless of the fact that both scalar and SIMD results - // may differ from the reference results. - std::int32_t actual_simd[SimdVectorSize] = {0}; - StoreSimdVector(actual_simd, unary_op.Op(LoadSimdVector(input))); - for (std::size_t j = 0; j < SimdVectorSize; j++) { - Check(actual_simd[j] == actual_scalar[j]); + template + void test_mul(const std::vector& testvals) { + for (auto a : testvals) { + for (auto b : testvals) { + FixedPoint aq; + FixedPoint bq; + aq.raw() = a; + bq.raw() = b; + test_mul(aq, bq); + } } } -} -template -void test_convert(FixedPoint x) { - typedef FixedPoint F; - F y = F::FromDouble(ToDouble(x)); - Check(y == x); -} - -template -void test_Rescale(FixedPoint a) { - FixedPoint actual = Rescale(a); - FixedPoint expected = - FixedPoint::FromDouble(ToDouble(a)); - Check(actual == expected); -} - -template -void test_Rescale(const std::vector& testvals_int32) { - for (auto a : testvals_int32) { - FixedPoint aq; - aq.raw() = a; - test_Rescale(aq); + template + void test_ExactMulByPot(FixedPoint a) { + double x = ToDouble(a) * std::pow(2.0, tExponent); + double y = ToDouble(ExactMulByPot(a)); + Check(x == y); } -} - -template -void test_mul(FixedPoint a, - FixedPoint b) { - static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b; - using ProductFixedPoint = FixedPoint; - ProductFixedPoint ab; - ab = a * b; - double a_double = ToDouble(a); - double b_double = ToDouble(b); - double ab_double = a_double * b_double; - ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double); - std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw()); - Check(std::abs(diff) <= 1); -} -template -void test_mul(const std::vector& testvals_int32) { - for (auto a : testvals_int32) { - for (auto b : testvals_int32) { - FixedPoint aq; - FixedPoint bq; + template + void test_ExactMulByPot(const std::vector& testvals) { + for (auto a : testvals) { + FixedPoint aq; aq.raw() = a; - bq.raw() = b; - test_mul(aq, bq); + test_ExactMulByPot(aq); } } -} -template -void test_ExactMulByPot(FixedPoint a) { - double x = ToDouble(a) * std::pow(2.0, tExponent); - double y = ToDouble(ExactMulByPot(a)); - Check(x == y); -} + // Make the list of test values to test each op against. + std::vector MakeTestVals() { + std::vector testvals; + + for (int i = 0; i < kScalarTypeBits - 1; i++) { + testvals.push_back((1 << i) - 2); + testvals.push_back((1 << i) - 1); + testvals.push_back((1 << i)); + testvals.push_back((1 << i) + 1); + testvals.push_back((1 << i) + 2); + testvals.push_back(-(1 << i) - 2); + testvals.push_back(-(1 << i) - 1); + testvals.push_back(-(1 << i)); + testvals.push_back(-(1 << i) + 1); + testvals.push_back(-(1 << i) + 2); + } + testvals.push_back(std::numeric_limits::min()); + testvals.push_back(std::numeric_limits::min() + 1); + testvals.push_back(std::numeric_limits::min() + 2); + testvals.push_back(std::numeric_limits::max() - 2); + testvals.push_back(std::numeric_limits::max() - 1); + testvals.push_back(std::numeric_limits::max()); + + std::mt19937 random_engine; + std::uniform_int_distribution uniform_distribution( + std::numeric_limits::min(), + std::numeric_limits::max()); + for (int i = 0; i < 1000; i++) { + testvals.push_back(uniform_distribution(random_engine)); + } -template -void test_ExactMulByPot(const std::vector& testvals_int32) { - for (auto a : testvals_int32) { - FixedPoint aq; - aq.raw() = a; - test_ExactMulByPot(aq); - } -} + // SIMD tests will require the length of testvals to be a multiple + // of SIMD vector size. + while (testvals.size() % kSimdLanes) { + testvals.push_back(0); + } -// Make the list of test values to test each op against. -std::vector MakeTestValsInt32() { - std::vector testvals_int32; - - for (int i = 0; i < 31; i++) { - testvals_int32.push_back((1 << i) - 2); - testvals_int32.push_back((1 << i) - 1); - testvals_int32.push_back((1 << i)); - testvals_int32.push_back((1 << i) + 1); - testvals_int32.push_back((1 << i) + 2); - testvals_int32.push_back(-(1 << i) - 2); - testvals_int32.push_back(-(1 << i) - 1); - testvals_int32.push_back(-(1 << i)); - testvals_int32.push_back(-(1 << i) + 1); - testvals_int32.push_back(-(1 << i) + 2); - } - testvals_int32.push_back(std::numeric_limits::min()); - testvals_int32.push_back(std::numeric_limits::min() + 1); - testvals_int32.push_back(std::numeric_limits::min() + 2); - testvals_int32.push_back(std::numeric_limits::max() - 2); - testvals_int32.push_back(std::numeric_limits::max() - 1); - testvals_int32.push_back(std::numeric_limits::max()); - - std::mt19937 random_engine; - std::uniform_int_distribution uniform_distribution( - std::numeric_limits::min(), - std::numeric_limits::max()); - for (int i = 0; i < 1000; i++) { - testvals_int32.push_back(uniform_distribution(random_engine)); + std::sort(testvals.begin(), testvals.end()); + return testvals; } - // SIMD tests will require the length of testvals_int32 to be a multiple - // of SIMD vector size. - while (testvals_int32.size() % SimdVectorSize) { - testvals_int32.push_back(0); - } + void RunTests(const char* msg) { + const std::vector testvals = MakeTestVals(); - std::sort(testvals_int32.begin(), testvals_int32.end()); - return testvals_int32; -} + for (int s = 0; s < kScalarTypeBits; s++) { + TestUnaryOp(RoundingDivideByPOTOp(s), testvals); + } + + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<14 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15 - kScalarTypeBits>(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp(), + testvals); + TestUnaryOp(SaturatingRoundingMultiplyByPOTOp(), + testvals); + + TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals); + TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals); + + TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals); + TestUnaryOp(TanhOp<0>(), testvals); + TestUnaryOp(TanhOp<1>(), testvals); + TestUnaryOp(TanhOp<2>(), testvals); + TestUnaryOp(TanhOp<3>(), testvals); + TestUnaryOp(TanhOp<4>(), testvals); + TestUnaryOp(TanhOp<5>(), testvals); + TestUnaryOp(TanhOp<6>(), testvals); + + TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals); + TestUnaryOp(LogisticOp<0>(), testvals); + TestUnaryOp(LogisticOp<1>(), testvals); + TestUnaryOp(LogisticOp<2>(), testvals); + TestUnaryOp(LogisticOp<3>(), testvals); + TestUnaryOp(LogisticOp<4>(), testvals); + TestUnaryOp(LogisticOp<5>(), testvals); + TestUnaryOp(LogisticOp<6>(), testvals); + + for (auto a : testvals) { + FixedPoint x; + x.raw() = a; + test_convert(x); + } + + test_mul<0, 0>(testvals); + test_mul<0, 1>(testvals); + test_mul<2, 0>(testvals); + test_mul<1, 1>(testvals); + test_mul<4, 4>(testvals); + test_mul<3, 5>(testvals); + test_mul<7, 2>(testvals); + test_mul(testvals); + + test_Rescale<0, 0>(testvals); + test_Rescale<0, 1>(testvals); + test_Rescale<2, 0>(testvals); + test_Rescale<4, 4>(testvals); + test_Rescale<4, 5>(testvals); + test_Rescale<6, 3>(testvals); + test_Rescale<13, 9>(testvals); + + test_ExactMulByPot<0, 0>(testvals); + test_ExactMulByPot<0, 4>(testvals); + test_ExactMulByPot<1, 4>(testvals); + test_ExactMulByPot<3, 2>(testvals); + test_ExactMulByPot<-4, 5>(testvals); + test_ExactMulByPot<-2, 6>(testvals); + + fprintf(stderr, "PASS (%s)\n", msg); + } +}; } // end anonymous namespace } // end namespace gemmlowp int main() { - using namespace gemmlowp; - - const std::vector testvals_int32 = MakeTestValsInt32(); - - for (int s = 0; s < 32; s++) { - TestUnaryOp(RoundingDivideByPOTOp(s), testvals_int32); - } - - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-31>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-30>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-29>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-17>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-16>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<16>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<17>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<29>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<30>(), testvals_int32); - TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<31>(), testvals_int32); - - TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), - testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals_int32); - TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals_int32); - - TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals_int32); - TestUnaryOp(TanhOp<0>(), testvals_int32); - TestUnaryOp(TanhOp<1>(), testvals_int32); - TestUnaryOp(TanhOp<2>(), testvals_int32); - TestUnaryOp(TanhOp<3>(), testvals_int32); - TestUnaryOp(TanhOp<4>(), testvals_int32); - TestUnaryOp(TanhOp<5>(), testvals_int32); - TestUnaryOp(TanhOp<6>(), testvals_int32); - - TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals_int32); - TestUnaryOp(LogisticOp<0>(), testvals_int32); - TestUnaryOp(LogisticOp<1>(), testvals_int32); - TestUnaryOp(LogisticOp<2>(), testvals_int32); - TestUnaryOp(LogisticOp<3>(), testvals_int32); - TestUnaryOp(LogisticOp<4>(), testvals_int32); - TestUnaryOp(LogisticOp<5>(), testvals_int32); - TestUnaryOp(LogisticOp<6>(), testvals_int32); - - for (auto a : testvals_int32) { - FixedPoint x; - x.raw() = a; - test_convert(x); - } - - test_mul<0, 0>(testvals_int32); - test_mul<0, 1>(testvals_int32); - test_mul<2, 0>(testvals_int32); - test_mul<1, 1>(testvals_int32); - test_mul<4, 4>(testvals_int32); - test_mul<3, 5>(testvals_int32); - test_mul<7, 2>(testvals_int32); - test_mul<14, 15>(testvals_int32); - - test_Rescale<0, 0>(testvals_int32); - test_Rescale<0, 1>(testvals_int32); - test_Rescale<2, 0>(testvals_int32); - test_Rescale<4, 4>(testvals_int32); - test_Rescale<4, 5>(testvals_int32); - test_Rescale<6, 3>(testvals_int32); - test_Rescale<13, 9>(testvals_int32); - - test_ExactMulByPot<0, 0>(testvals_int32); - test_ExactMulByPot<0, 4>(testvals_int32); - test_ExactMulByPot<1, 4>(testvals_int32); - test_ExactMulByPot<3, 2>(testvals_int32); - test_ExactMulByPot<-4, 5>(testvals_int32); - test_ExactMulByPot<-2, 6>(testvals_int32); - - std::cerr << "All tests passed." << std::endl; + gemmlowp::TestFixedPoint().RunTests("Scalar int32"); + gemmlowp::TestFixedPoint().RunTests("Scalar int16"); +#ifdef GEMMLOWP_SSE4 + gemmlowp::TestFixedPoint<__m128i>().RunTests("SSE4 __m128i = int32x4"); + gemmlowp::TestFixedPoint().RunTests( + "SSE4 __m128i = int16x8"); +#endif +#ifdef GEMMLOWP_NEON + gemmlowp::TestFixedPoint().RunTests("NEON int32x4_t"); + gemmlowp::TestFixedPoint().RunTests("NEON int16x8_t"); +#endif +#ifdef GEMMLOWP_MSA + gemmlowp::TestFixedPoint().RunTests("MSA v4i32"); + gemmlowp::TestFixedPoint().RunTests("MSA v8i16"); +#endif +#ifdef GEMMLOWP_AVX2 + gemmlowp::TestFixedPoint<__m256i>().RunTests("AVX __m256i"); + gemmlowp::TestFixedPoint().RunTests( + "AVX2 __m256i = int16x16"); +#endif } -- cgit v1.2.3