author     Lev Proleev <levp@google.com>  2021-03-12 18:55:26 +0000
committer  Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>  2021-03-12 18:55:26 +0000
commit     2c005aca73d9a32151a040aa476fed0ec89a14ae (patch)
tree       45b033783360ed59b9efbfd8bd2dfdb4fc8bdb72
parent     d23d5384ee2dad29223e8c57248ea83ec23da4bf (diff)
parent     998a6df5933c918fe486c3cbd3cb1e699e0211b5 (diff)
download   ruy-2c005aca73d9a32151a040aa476fed0ec89a14ae.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' into tflite-rebase-feb-2021 am: 713d254ecf am: dd1f1778c2 am: 039d972c6e am: 998a6df593
Original change: https://android-review.googlesource.com/c/platform/external/ruy/+/1610773
MUST ONLY BE SUBMITTED BY AUTOMERGER
Change-Id: Ie195bc9d9ec4d8755daeb0c6b0c52f52ada4df51
-rw-r--r--  .gitignore  30
-rw-r--r--  .gitmodules  6
-rw-r--r--  BUILD  7
-rw-r--r--  CMakeLists.txt  90
-rw-r--r--  CONTRIBUTING.md  28
-rw-r--r--  LICENSE  202
-rw-r--r--  README.md  25
-rw-r--r--  WORKSPACE  44
-rwxr-xr-x  cmake/bazel_to_cmake.py  279
-rwxr-xr-x  cmake/bazel_to_cmake.sh  35
-rwxr-xr-x  cmake/run_android_test.sh  16
-rw-r--r--  cmake/ruy_add_all_subdirs.cmake  37
-rw-r--r--  cmake/ruy_cc_binary.cmake  57
-rw-r--r--  cmake/ruy_cc_library.cmake  85
-rw-r--r--  cmake/ruy_cc_test.cmake  76
-rw-r--r--  cmake/ruy_include_directories.cmake  33
-rw-r--r--  doc/README.md  3
-rwxr-xr-x  doc/depgraph.sh  153
-rw-r--r--  doc/depgraph.svg  377
-rw-r--r--  example/BUILD  16
-rw-r--r--  example/CMakeLists.txt  23
-rw-r--r--  example/README.md  14
-rw-r--r--  example/example.cc  161
-rw-r--r--  example/parametrized_example.cc  198
-rw-r--r--  ruy/BUILD  1220
-rw-r--r--  ruy/CMakeLists.txt  1696
-rw-r--r--  ruy/allocator.cc  124
-rw-r--r--  ruy/allocator.h  100
-rw-r--r--  ruy/allocator_test.cc  126
-rw-r--r--  ruy/apply_multiplier.cc  70
-rw-r--r--  ruy/apply_multiplier.h  92
-rw-r--r--  ruy/apply_multiplier_test.cc  137
-rw-r--r--  ruy/asm_helpers.h  43
-rw-r--r--  ruy/benchmark.cc  223
-rw-r--r--  ruy/block_map.cc  497
-rw-r--r--  ruy/block_map.h  162
-rw-r--r--  ruy/block_map_test.cc  259
-rw-r--r--  ruy/blocking_counter.cc  49
-rw-r--r--  ruy/blocking_counter.h  66
-rw-r--r--  ruy/build_defs.bzl  79
-rw-r--r--  ruy/build_defs.oss.bzl  15
-rw-r--r--  ruy/check_macros.h  147
-rw-r--r--  ruy/check_macros_test.cc  153
-rw-r--r--  ruy/context.cc  58
-rw-r--r--  ruy/context.h  108
-rw-r--r--  ruy/context_get_ctx.cc  27
-rw-r--r--  ruy/context_get_ctx.h  32
-rw-r--r--  ruy/context_test.cc  45
-rw-r--r--  ruy/cpu_cache_params.h  83
-rw-r--r--  ruy/cpuinfo.cc  163
-rw-r--r--  ruy/cpuinfo.h  61
-rw-r--r--  ruy/create_trmul_params.h  484
-rw-r--r--  ruy/ctx.cc  216
-rw-r--r--  ruy/ctx.h  91
-rw-r--r--  ruy/ctx_impl.h  84
-rw-r--r--  ruy/ctx_test.cc  76
-rw-r--r--  ruy/frontend.cc  37
-rw-r--r--  ruy/frontend.h  99
-rw-r--r--  ruy/gtest_wrapper.h  33
-rw-r--r--  ruy/have_built_path_for.h  31
-rw-r--r--  ruy/have_built_path_for_avx.cc  35
-rw-r--r--  ruy/have_built_path_for_avx2_fma.cc  35
-rw-r--r--  ruy/have_built_path_for_avx512.cc  35
-rw-r--r--  ruy/kernel.h  245
-rw-r--r--  ruy/kernel_arm.h  212
-rw-r--r--  ruy/kernel_arm32.cc  2515
-rw-r--r--  ruy/kernel_arm64.cc  8075
-rw-r--r--  ruy/kernel_avx.cc  1476
-rw-r--r--  ruy/kernel_avx2_fma.cc  1011
-rw-r--r--  ruy/kernel_avx512.cc  1550
-rw-r--r--  ruy/kernel_common.h  287
-rw-r--r--  ruy/kernel_x86.h  874
-rw-r--r--  ruy/mat.h  492
-rw-r--r--  ruy/matrix.h  218
-rw-r--r--  ruy/matrix_test.cc  101
-rw-r--r--  ruy/mul_params.h  299
-rw-r--r--  ruy/mul_params_test.cc  79
-rw-r--r--  ruy/opt_set.h  51
-rw-r--r--  ruy/pack.h  155
-rw-r--r--  ruy/pack_arm.cc  2480
-rw-r--r--  ruy/pack_arm.h  613
-rw-r--r--  ruy/pack_avx.cc  831
-rw-r--r--  ruy/pack_avx2_fma.cc  689
-rw-r--r--  ruy/pack_avx512.cc  828
-rw-r--r--  ruy/pack_common.h  143
-rw-r--r--  ruy/pack_x86.h  659
-rw-r--r--  ruy/path.h  203
-rw-r--r--  ruy/perchannel_buffers_reallocation_test.cc  120
-rw-r--r--  ruy/performance_advisory.h  40
-rw-r--r--  ruy/platform.h  159
-rw-r--r--  ruy/pmu.cc  297
-rw-r--r--  ruy/pmu.h  46
-rw-r--r--  ruy/prepacked_cache.cc  129
-rw-r--r--  ruy/prepacked_cache.h  141
-rw-r--r--  ruy/prepacked_cache_test.cc  309
-rw-r--r--  ruy/prepare_packed_matrices.cc  94
-rw-r--r--  ruy/prepare_packed_matrices.h  42
-rw-r--r--  ruy/profiler/BUILD  66
-rw-r--r--  ruy/profiler/CMakeLists.txt  72
-rw-r--r--  ruy/profiler/README.md  149
-rw-r--r--  ruy/profiler/instrumentation.cc  132
-rw-r--r--  ruy/profiler/instrumentation.h  203
-rw-r--r--  ruy/profiler/profiler.cc  109
-rw-r--r--  ruy/profiler/profiler.h  106
-rw-r--r--  ruy/profiler/test.cc  167
-rw-r--r--  ruy/profiler/test_instrumented_library.cc  59
-rw-r--r--  ruy/profiler/test_instrumented_library.h  23
-rw-r--r--  ruy/profiler/treeview.cc  252
-rw-r--r--  ruy/profiler/treeview.h  130
-rw-r--r--  ruy/reference_mul.h  56
-rw-r--r--  ruy/ruy.h  114
-rw-r--r--  ruy/ruy_test.bzl  34
-rw-r--r--  ruy/ruy_test_ext.oss.bzl  7
-rw-r--r--  ruy/side_pair.h  68
-rw-r--r--  ruy/size_util.h  105
-rw-r--r--  ruy/size_util_test.cc  101
-rw-r--r--  ruy/system_aligned_alloc.cc  51
-rw-r--r--  ruy/system_aligned_alloc.h  53
-rw-r--r--  ruy/test.h  2308
-rw-r--r--  ruy/test_fast.cc  109
-rw-r--r--  ruy/test_slow.cc  70
-rw-r--r--  ruy/thread_pool.cc  218
-rw-r--r--  ruy/thread_pool.h  127
-rw-r--r--  ruy/time.h  87
-rw-r--r--  ruy/trace.h  836
-rw-r--r--  ruy/trmul.cc  397
-rw-r--r--  ruy/trmul.h  39
-rw-r--r--  ruy/trmul_params.h  92
-rw-r--r--  ruy/tune.cc  46
-rw-r--r--  ruy/tune.h  117
-rw-r--r--  ruy/tune_test.cc  55
-rw-r--r--  ruy/validate.h  75
-rw-r--r--  ruy/wait.cc  44
-rw-r--r--  ruy/wait.h  68
-rw-r--r--  ruy/wait_test.cc  117
-rw-r--r--  third_party/BUILD  0
-rw-r--r--  third_party/CMakeLists.txt  5
m---------  third_party/cpuinfo  0
-rw-r--r--  third_party/cpuinfo.BUILD  386
m---------  third_party/googletest  0
140 files changed, 41802 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..950a145
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,30 @@
+# Visual Studio files
+.vs/
+.vscode/
+*.sdf
+*.opensdf
+*.VC.opendb
+*.suo
+*.user
+
+# macOS files
+.DS_Store
+
+# CMake artifacts
+build/
+build-*/
+
+# Bazel artifacts
+**/bazel-*
+
+# Emacs autosaves
+*~
+\#*\#
+
+# Vim swap files
+[._]*.sw[a-p]
+
+# Source indexing files
+compile_commands.json
+.cache/clangd
+.clangd/
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..0e03f4e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "googletest"]
+ path = third_party/googletest
+ url = https://github.com/google/googletest
+[submodule "cpuinfo"]
+ path = third_party/cpuinfo
+ url = https://github.com/pytorch/cpuinfo
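
[Editor's note] For readers building from a fresh clone: the two submodules above must be fetched before the CMake build further down can find cpuinfo and googletest. A typical sequence (the repository URL below assumes the public github.com/google/ruy mirror; the last command is the same one suggested by the CMake error message later in this diff):

```sh
git clone https://github.com/google/ruy
cd ruy
git submodule update --init   # fetches third_party/googletest and third_party/cpuinfo
```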
diff --git a/BUILD b/BUILD
new file mode 100644
index 0000000..8c2d62e
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,7 @@
+# Ruy is not BLAS
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["LICENSE"])
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..98d480d
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,90 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_policy(SET CMP0012 NEW)
+cmake_policy(SET CMP0048 NEW)
+project(ruy CXX)
+cmake_minimum_required(VERSION 3.13) # Copied from IREE
+set(CMAKE_CXX_STANDARD 14)
+
+if (PROJECT_NAME STREQUAL CMAKE_PROJECT_NAME)
+ set(RUY_IS_TOPLEVEL TRUE)
+ set(RUY_MINIMAL_BUILD_DEFAULT_VALUE OFF)
+else()
+ set(RUY_IS_TOPLEVEL FALSE)
+ set(RUY_MINIMAL_BUILD_DEFAULT_VALUE ON)
+endif()
+
+option(RUY_MINIMAL_BUILD "Disable ruy's tests, examples, etc. Build only ruy public libraries." ${RUY_MINIMAL_BUILD_DEFAULT_VALUE})
+if (NOT RUY_MINIMAL_BUILD)
+ enable_testing()
+endif()
+
+option(RUY_PROFILER "Enable ruy's built-in profiler (harms performance)" OFF)
+
+include(cmake/ruy_add_all_subdirs.cmake)
+include(cmake/ruy_cc_library.cmake)
+include(cmake/ruy_cc_binary.cmake)
+include(cmake/ruy_cc_test.cmake)
+
+# Skip cpuinfo if it was already generated, which can happen when ruy is
+# a subdirectory in a wider project that already uses cpuinfo.
+if (NOT TARGET cpuinfo)
+ # Test if the third_party/cpuinfo submodule was checked out before
+ # adding that subdirectory, so we can do more helpful things below in the
+ # else() block when it's not.
+ set(RUY_CPUINFO_CMAKELISTS_FILE "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cpuinfo/CMakeLists.txt")
+ if (EXISTS "${RUY_CPUINFO_CMAKELISTS_FILE}")
+ # Disabling cpuinfo's tests and benchmarks to prevent a copy of its
+ # googletest dependency getting downloaded into a 'deps' directory in the
+ # source tree!
+ set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
+ set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE)
+ set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE)
+ add_subdirectory("third_party/cpuinfo" EXCLUDE_FROM_ALL)
+ else()
+ # third_party/cpuinfo is not checked out. That could be intentional when
+ # ruy is a subdirectory in a wider project that is already providing
+ # the cpuinfo target. Maybe that wider project's CMakeLists is ordered
+ # in such a way that cpuinfo gets generated after ruy. In that case,
+ # it's helpful that we continue silently. In the worst case if the cpuinfo
+ # target never gets defined, ruy will fail to compile.
+ # On the other hand, if ruy is the top-level project here (not part of a
+ # wider project) then nothing will define the cpuinfo target for us,
+ # so we will definitely fail to compile, so we may as well fail right here.
+ if (RUY_IS_TOPLEVEL)
+ message(FATAL_ERROR "This file does not exist:\n${RUY_CPUINFO_CMAKELISTS_FILE}\n"
+ "That typically means that the git submodules of the ruy "
+ "repository haven't been checked out. Try this in the ruy "
+ "git directory:\n git submodule update --init")
+ endif()
+ endif()
+endif()
+
+# googletest is only needed for tests. Projects embedding ruy as a subdirectory
+# and not needing to build ruy tests may proceed without a local checkout of
+# third_party/googletest.
+if (NOT RUY_MINIMAL_BUILD
+ AND NOT TARGET gtest
+ AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/googletest/CMakeLists.txt")
+ add_subdirectory("third_party/googletest" EXCLUDE_FROM_ALL)
+endif()
+
+add_subdirectory("ruy")
+
+if (NOT RUY_MINIMAL_BUILD)
+ add_subdirectory("example")
+endif()
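
[Editor's note] As a sketch of the embedding flow described in the comments above: a wider project that already checks out ruy could consume it roughly as below, relying on RUY_MINIMAL_BUILD defaulting to ON when ruy is not the top-level project. The paths and the `my_app` target are placeholders; `ruy` is the public library target that the generated ruy/CMakeLists.txt defines.

```cmake
# Hypothetical host-project CMakeLists.txt fragment.
add_subdirectory(third_party/ruy EXCLUDE_FROM_ALL)  # path is an assumption

add_executable(my_app main.cc)              # placeholder target
target_link_libraries(my_app PRIVATE ruy)   # ruy's public library target
```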
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..654a071
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f341563
--- /dev/null
+++ b/README.md
@@ -0,0 +1,25 @@
+# The ruy matrix multiplication library
+
+This is not an officially supported Google product.
+
+ruy is a matrix multiplication library. Its focus is to cover the matrix
+multiplication needs of neural network inference engines. Its initial user has
+been TensorFlow Lite, where it is used by default on the ARM CPU architecture.
+
+ruy supports both floating-point and 8-bit-integer-quantized matrices.
+
+## Efficiency
+
+ruy is designed to achieve high performance not just on very large matrices,
+which are the focus of many established libraries, but on the actual sizes and
+shapes of matrices that matter most in current TensorFlow Lite applications.
+These are often quite small, e.g. 100x100 or even 50x50, and come in all sorts
+of rectangular shapes. ruy is not as fast as completely specialized code for
+each shape, but it aims to offer a good compromise of speed across all shapes
+together with a small binary size.
+
+## Documentation
+
+Some documentation will eventually be available in the doc/ directory; see
+[doc/README.md](doc/README.md).
+
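
[Editor's note] To make the README concrete, here is a minimal float-multiplication sketch in the spirit of the example/example.cc added in this commit; ruy/ruy.h and ruy/matrix.h in this change are the authoritative reference for the API, so treat the details below as an illustration rather than a specification.

```cpp
#include "ruy/ruy.h"

// Computes dst = lhs * rhs for small 2x2 float matrices.
int main() {
  const float lhs_data[] = {1, 2, 3, 4};  // 2x2, row-major
  const float rhs_data[] = {1, 2, 3, 4};  // 2x2, column-major
  float dst_data[4];

  ruy::Context context;

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
  lhs.set_data(lhs_data);

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
  rhs.set_data(rhs_data);

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
  dst.set_data(dst_data);

  // Plain float multiplication: no quantization parameters needed.
  ruy::MulParams<float, float> mul_params;
  ruy::Mul(lhs, rhs, mul_params, &context, &dst);
  return 0;
}
```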
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..da4fe9f
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,44 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Workspace file for the Ruy project.
+
+workspace(name = "com_google_ruy")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
+
+maybe(
+ local_repository,
+ name = "com_google_googletest",
+ path = "third_party/googletest",
+)
+
+maybe(
+ new_local_repository,
+ name = "cpuinfo",
+ path = "third_party/cpuinfo",
+ build_file = "@//third_party:cpuinfo.BUILD",
+)
+
+# skylib utility for additional bazel functionality.
+skylib_version = "0.9.0"
+http_archive(
+ name = "bazel_skylib",
+ type = "tar.gz",
+ url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel_skylib-{}.tar.gz".format (skylib_version, skylib_version),
+ sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
+)
+load("@bazel_skylib//lib:versions.bzl", "versions")
+versions.check(minimum_bazel_version = "2.0.0")
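
[Editor's note] The WORKSPACE above is for building ruy itself. For a Bazel project consuming ruy from another workspace, one plausible setup is sketched below; the git_repository pin is a placeholder, and the only details taken from this diff are the `com_google_ruy` workspace name and the `//ruy` package layout.

```python
# Hypothetical consumer WORKSPACE fragment.
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

git_repository(
    name = "com_google_ruy",
    remote = "https://github.com/google/ruy",
    commit = "<pin-a-commit-here>",  # placeholder
)
# A cc_* rule can then list "@com_google_ruy//ruy" in its deps.
```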
diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py
new file mode 100755
index 0000000..ba1a38b
--- /dev/null
+++ b/cmake/bazel_to_cmake.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Yet another bazel-to-cmake converter. It is written independently from
+scratch, but relies on the same basic idea as others (including IREE's):
+let the Python interpreter do the bulk of the work, exploiting the fact that
+both Bazel's BUILD syntax and the Starlark (".bzl") language are more or less
+subsets of Python.
+
+The main features that this converter supports, and that others did not as of
+early 2021 (justifying its existence), are:
+  1. Ad-hoc support for select(), generating CMake if()...elseif()... chains by
+     parsing the condition keys (e.g. anything ending in ":windows" is
+     interpreted as "the target platform is Windows"). This makes it possible
+     to ignore config_setting entirely, since we only care about config_setting
+     names, not their actual implementations, and likewise to ignore all the
+     variants from the Bazel 'selects' library.
+  2. Support for load(), loading macros from Starlark files.
+"""
+
+import re
+import os
+import os.path
+import pickle
+import sys
+import datetime
+import itertools
+
+# Ruy's dependencies.
+external_targets = ['gtest', 'gtest_main', 'cpuinfo']
+
+# Text replacement [oldstring, newstring] pairs, applied to all BUILD and
+# Starlark files that we load. Only used by preprocess_input_text.
+replacements = [
+ ['$(STACK_FRAME_UNLIMITED)', ''],
+ ['native.cc_', 'cc_'],
+ ['selects.config_setting_group', 'config_setting_group'],
+ ['@com_google_googletest//:gtest', 'gtest'],
+ ['@com_google_googletest//:gtest_main', 'gtest_main'],
+ ['@cpuinfo//:cpuinfo_with_unstripped_include_path', 'cpuinfo'],
+]
+
+
+def preprocess_input_text(text):
+ result = text
+ for replacement in replacements:
+ result = result.replace(replacement[0], replacement[1])
+ return result
+
+
+def set_cmake_list(list_name, values, indent):
+ semicolon_separated = ";".join(values)
+ print(f'{indent}set({list_name} "{semicolon_separated}")')
+
+
+def generate_cmake_select(select_name, dict):
+ new_if_branch_keyword = 'if'
+ default_value = []
+ for key in dict:
+ condition = ''
+ if key == '//conditions:default':
+ default_value = dict[key]
+ continue
+ elif re.search(r':windows$', key):
+ condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows'
+ elif re.search(r':ppc$', key):
+ condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le'
+ elif re.search(r':s390x$', key):
+ condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x'
+ elif re.search(r':fuchsia$', key):
+ condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia'
+ elif re.search(r':arm32_assuming_neon$', key):
+ condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm'
+ elif re.search(r':do_not_want_O3$', key):
+ # Ruy is a specialist library: we always want code to be compiled
+ # with -O3 unless the build type is Debug or the compiler does not
+ # support that flag syntax.
+ condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC'
+ elif re.search(r':x86_64_and_not_msvc$', key):
+ condition = '(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC'
+ elif re.search(r':windows_msvc$', key):
+ condition = 'MSVC'
+ elif re.search(r':ruy_profiler$', key):
+ condition = '${RUY_PROFILER}'
+ else:
+ raise ValueError(f'Unhandled key in select: {key}')
+
+ print(f'{new_if_branch_keyword}({condition})')
+ set_cmake_list(select_name, dict[key], ' ')
+ new_if_branch_keyword = 'elseif'
+
+ print('else()')
+ set_cmake_list(select_name, default_value, ' ')
+
+ print('endif()\n')
+
+
+def trim_multiple_ruy_prefixes(name):
+ return re.sub(r'(ruy_)+ruy', 'ruy', name)
+
+def get_cmake_local_target_name(name):
+ global package_prefix
+ return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}')
+
+
+def get_cmake_dep_target_name(name):
+ if name in external_targets:
+ return name
+ if name.startswith('$'):
+ # Happens for deps that are the result of expanding a select() that we
+ # have compiled to expanding a variable.
+ return name
+ if name.startswith('//'):
+ after_last_slash = name.split('/')[-1]
+ if not ':' in after_last_slash:
+ name = f'{name}:{after_last_slash}'
+ raw=name[2:].replace('/', '_').replace(':', '_')
+ return trim_multiple_ruy_prefixes(raw)
+ if name.startswith(':'):
+ name = name[1:]
+ return get_cmake_local_target_name(name)
+
+
+#
+# Functions implementing BUILD functions
+#
+
+
+def package(**kwargs):
+ pass
+
+
+def exports_files(*args):
+ pass
+
+
+def load(filename, *args):
+ if filename.startswith('@'):
+ return
+ elif filename.startswith(':'):
+ filename = os.path.join(bazel_package_dir, filename[1:])
+ elif filename.startswith('//'):
+ split = filename[2:].split(':')
+ filename = os.path.join(bazel_workspace_dir, split[0], split[1])
+
+ src_file_content = open(filename).read()
+ processed_file_content = preprocess_input_text(src_file_content)
+ exec(processed_file_content, globals(), globals())
+
+
+def config_setting(**kwargs):
+ # Nothing to do since our implementation of select() is based on parsing
+ # the names of config_settings, not looking deep into their actual
+ # implementation.
+ pass
+
+
+def filegroup(**kwargs):
+ pass
+
+
+def config_setting_group(**kwargs):
+ # See config_setting.
+ pass
+
+
+def bzl_library(**kwargs):
+ pass
+
+
+select_index = 0
+select_cache = {}
+
+
+def select(select_dict):
+ global select_index
+ global select_cache
+ global package_prefix
+ key = pickle.dumps(sorted(select_dict.items()))
+ if key in select_cache:
+ select_name = select_cache[key]
+ else:
+ unique_values = sorted(set(itertools.chain.from_iterable(select_dict.values()))) # sorting ensures determinism, no spurious diffs
+ description = '_'.join(unique_values)
+ select_name = f'{package_prefix}_{select_index}_{description}'
+ select_name = select_name.replace('c++', 'cxx')
+ select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name)
+ select_index = select_index + 1
+ select_cache[key] = select_name
+ generate_cmake_select(select_name, select_dict)
+
+ return [f'${{{select_name}}}']
+
+
+def generic_rule(rule_name, **kwargs):
+ print(f'{rule_name}(')
+ for key in kwargs.keys():
+ values = kwargs[key]
+ if type(values) is bool:
+ if values:
+ print(f' {key.upper()}')
+ continue
+ else:
+ raise ValueError(
+ 'Cannot specify FALSE boolean args in CMake')
+ if key == 'visibility':
+ if values == ['//visibility:public']:
+ print(f' PUBLIC')
+ continue
+ if key == 'tags':
+ values = list(filter(lambda x : not x.startswith('req_dep'), values))
+ if not values:
+ continue
+ print(f' {key.upper()}')
+ if type(values) is list:
+ for value in values:
+ if key == 'deps':
+ target_name = get_cmake_dep_target_name(value)
+ print(f' {target_name}')
+ else:
+ print(f' {value}')
+ else:
+ if key == 'name':
+ target_name = get_cmake_local_target_name(values)
+ print(f' {target_name}')
+ else:
+ print(f' {values}')
+ print(')\n')
+
+
+def cc_library(**kwargs):
+ generic_rule('ruy_cc_library', **kwargs)
+
+
+def cc_test(**kwargs):
+ generic_rule('ruy_cc_test', **kwargs)
+
+
+def cc_binary(**kwargs):
+ generic_rule('ruy_cc_binary', **kwargs)
+
+
+#
+# Program entry point.
+#
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print("Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir")
+ sys.exit(1)
+
+ bazel_workspace_dir = sys.argv[1]
+ bazel_package_dir = sys.argv[2]
+ bazel_package_relative_dir = os.path.relpath(
+ bazel_package_dir, bazel_workspace_dir)
+ package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_')
+
+ print("""# This file is generated (whence no license header). Do not edit!
+# To regenerate, run:
+# cmake/bazel_to_cmake.sh
+""")
+
+ src_build_file = os.path.join(bazel_package_dir, "BUILD")
+ src_build_content = open(src_build_file).read()
+ processed_build_content = preprocess_input_text(src_build_content)
+ exec(processed_build_content)
+
+ print("ruy_add_all_subdirs()")
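
[Editor's note] To illustrate the select() handling described in the docstring, consider a hypothetical BUILD fragment in the `ruy` package such as `copts = select({":windows_msvc": ["/arch:AVX2"], "//conditions:default": ["-mavx2"]})`. Tracing generate_cmake_select() above, the script would emit roughly the following CMake (the `0` in the variable name depends on how many select() calls were seen earlier in the same package):

```cmake
if(MSVC)
  set(ruy_0_mavx2_arch_AVX2 "/arch:AVX2")
else()
  set(ruy_0_mavx2_arch_AVX2 "-mavx2")
endif()
```

The generated rule then references `${ruy_0_mavx2_arch_AVX2}` in its COPTS.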
diff --git a/cmake/bazel_to_cmake.sh b/cmake/bazel_to_cmake.sh
new file mode 100755
index 0000000..296219e
--- /dev/null
+++ b/cmake/bazel_to_cmake.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+this_script_dir="$(dirname "$0")"
+
+root_dir="$(git -C "${this_script_dir}" rev-parse --show-toplevel)"
+
+build_files="$(find "${root_dir}" -type f -name BUILD)"
+
+if ! command -v python3 &> /dev/null; then
+ python_command=python
+else
+ python_command=python3
+fi
+
+for build_file in ${build_files}; do
+ package_dir="$(dirname "${build_file}")"
+ if [[ "${package_dir}" == "${root_dir}" ]]; then
+ # The root CMakeLists.txt is not generated.
+ continue
+ fi
+ "${python_command}" "${this_script_dir}/bazel_to_cmake.py" "${root_dir}" "${package_dir}" > "${package_dir}/CMakeLists.txt"
+done
diff --git a/cmake/run_android_test.sh b/cmake/run_android_test.sh
new file mode 100755
index 0000000..d643232
--- /dev/null
+++ b/cmake/run_android_test.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# Minimal script that pushes a file to a device and runs it there.
+# Contemporary versions of ADB properly propagate exit codes, so nothing more
+# is needed to let CTest report test success/failure.
+
+# TODO: consider clearing temporary files after testing, although that would
+# get in the way of debugging and would make the code more complex. Also,
+# Ruy's test files aren't huge, and people running these probably have
+# bigger clutter issues in their /data/local/tmp anyway. If we do want to
+# do this, we could copy IREE's code.
+
+device_tmpdir=/data/local/tmp
+
+adb push "$1" "${device_tmpdir}"
+adb shell "${device_tmpdir}/$(basename "$1")"
diff --git a/cmake/ruy_add_all_subdirs.cmake b/cmake/ruy_add_all_subdirs.cmake
new file mode 100644
index 0000000..1a7d126
--- /dev/null
+++ b/cmake/ruy_add_all_subdirs.cmake
@@ -0,0 +1,37 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Forked from IREE's iree_add_all_subdirs.cmake.
+
+# ruy_add_all_subdirs()
+#
+# CMake function to add all subdirectories of the current directory that contain
+# a CMakeLists.txt file
+#
+# Takes no arguments.
+function(ruy_add_all_subdirs)
+ FILE(GLOB _CHILDREN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*)
+ SET(_DIRLIST "")
+ foreach(_CHILD ${_CHILDREN})
+ if((NOT(_CHILD MATCHES third_party)) AND
+ (IS_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/${_CHILD}) AND
+ (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${_CHILD}/CMakeLists.txt))
+ LIST(APPEND _DIRLIST ${_CHILD})
+ endif()
+ endforeach()
+
+ foreach(subdir ${_DIRLIST})
+ add_subdirectory(${subdir})
+ endforeach()
+endfunction()
diff --git a/cmake/ruy_cc_binary.cmake b/cmake/ruy_cc_binary.cmake
new file mode 100644
index 0000000..93d1adf
--- /dev/null
+++ b/cmake/ruy_cc_binary.cmake
@@ -0,0 +1,57 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Forked from IREE's iree_cc_binary.cmake.
+
+include(CMakeParseArguments)
+include(cmake/ruy_include_directories.cmake)
+
+# ruy_cc_binary()
+#
+# CMake function to imitate Bazel's cc_binary rule.
+function(ruy_cc_binary)
+ cmake_parse_arguments(
+ _RULE
+ "TESTONLY"
+ "NAME"
+ "SRCS;COPTS;LINKOPTS;DEPS;TAGS"
+ ${ARGN}
+ )
+
+ if(_RULE_TESTONLY AND RUY_MINIMAL_BUILD)
+ return()
+ endif()
+
+ set(_NAME "${_RULE_NAME}")
+
+ add_executable(${_NAME} "")
+ target_sources(${_NAME}
+ PRIVATE
+ ${_RULE_SRCS}
+ )
+ set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}")
+ ruy_include_directories(${_NAME} "${_RULE_DEPS}")
+ target_compile_options(${_NAME}
+ PRIVATE
+ ${_RULE_COPTS}
+ )
+ target_link_options(${_NAME}
+ PRIVATE
+ ${_RULE_LINKOPTS}
+ )
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
+ )
+endfunction()
diff --git a/cmake/ruy_cc_library.cmake b/cmake/ruy_cc_library.cmake
new file mode 100644
index 0000000..38accc5
--- /dev/null
+++ b/cmake/ruy_cc_library.cmake
@@ -0,0 +1,85 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Forked from IREE's iree_cc_library.cmake.
+
+include(CMakeParseArguments)
+include(cmake/ruy_include_directories.cmake)
+
+# ruy_cc_library()
+#
+# CMake function to imitate Bazel's cc_library rule.
+function(ruy_cc_library)
+ cmake_parse_arguments(
+ _RULE
+ "PUBLIC;TESTONLY"
+ "NAME"
+ "HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DEPS"
+ ${ARGN}
+ )
+
+ if(_RULE_TESTONLY AND RUY_MINIMAL_BUILD)
+ return()
+ endif()
+
+ set(_NAME "${_RULE_NAME}")
+
+ # Check if this is a header-only library.
+ if("${_RULE_SRCS}" STREQUAL "")
+ set(_RULE_IS_INTERFACE 1)
+ else()
+ set(_RULE_IS_INTERFACE 0)
+ endif()
+
+ if(_RULE_IS_INTERFACE)
+ # Generating a header-only library.
+ add_library(${_NAME} INTERFACE)
+ target_include_directories(${_NAME}
+ INTERFACE
+ "${PROJECT_SOURCE_DIR}"
+ )
+ target_link_libraries(${_NAME}
+ INTERFACE
+ ${_RULE_DEPS}
+ ${_RULE_LINKOPTS}
+ )
+ target_compile_definitions(${_NAME}
+ INTERFACE
+ ${_RULE_DEFINES}
+ )
+ else()
+ # Generating a static binary library.
+ add_library(${_NAME} STATIC "")
+ target_sources(${_NAME}
+ PRIVATE
+ ${_RULE_SRCS}
+ ${_RULE_HDRS}
+ )
+ ruy_include_directories(${_NAME} "${_RULE_DEPS}")
+ target_compile_options(${_NAME}
+ PRIVATE
+ ${_RULE_COPTS}
+ )
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
+ PRIVATE
+ ${_RULE_LINKOPTS}
+ )
+ target_compile_definitions(${_NAME}
+ PUBLIC
+ ${_RULE_DEFINES}
+ )
+ endif()
+endfunction()
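
[Editor's note] For orientation, the generated ruy/CMakeLists.txt in this commit consists of calls to this function of roughly the following shape; the target and file names below are made up for illustration.

```cmake
ruy_cc_library(
  NAME
    ruy_example_component        # placeholder name
  SRCS
    example_component.cc         # placeholder sources
  HDRS
    example_component.h
  DEPS
    ruy_platform                 # another ruy_cc_library target
  PUBLIC                         # emitted for Bazel visibility = ["//visibility:public"]
)
```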
diff --git a/cmake/ruy_cc_test.cmake b/cmake/ruy_cc_test.cmake
new file mode 100644
index 0000000..6ad247e
--- /dev/null
+++ b/cmake/ruy_cc_test.cmake
@@ -0,0 +1,76 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Forked from IREE's iree_cc_test.cmake.
+
+include(CMakeParseArguments)
+include(cmake/ruy_include_directories.cmake)
+
+# ruy_cc_test()
+#
+# CMake function to imitate Bazel's cc_test rule.
+function(ruy_cc_test)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME"
+ "SRCS;COPTS;LINKOPTS;DEPS;TAGS"
+ ${ARGN}
+ )
+
+ if(RUY_MINIMAL_BUILD)
+ return()
+ endif()
+
+ set(_NAME "${_RULE_NAME}")
+
+ add_executable(${_NAME} "")
+ target_sources(${_NAME}
+ PRIVATE
+ ${_RULE_SRCS}
+ )
+ set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}")
+ ruy_include_directories(${_NAME} "${_RULE_DEPS}")
+ target_compile_options(${_NAME}
+ PRIVATE
+ ${_RULE_COPTS}
+ )
+ target_link_options(${_NAME}
+ PRIVATE
+ ${_RULE_LINKOPTS}
+ )
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
+ )
+ if(ANDROID)
+ add_test(
+ NAME
+ ${_NAME}
+ COMMAND
+ "${CMAKE_SOURCE_DIR}/cmake/run_android_test.sh"
+ "$<TARGET_FILE:${_NAME}>"
+ )
+ else()
+ add_test(
+ NAME
+ ${_NAME}
+ COMMAND
+ "$<TARGET_FILE:${_NAME}>"
+ )
+ endif()
+ if (_RULE_TAGS)
+ set_property(TEST ${_NAME} PROPERTY LABELS ${_RULE_TAGS})
+ endif()
+endfunction()
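
[Editor's note] A plausible way to exercise the tests registered by this function, assuming a top-level host build where RUY_MINIMAL_BUILD stays OFF (directory names are placeholders):

```sh
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
cd build && ctest
# On Android builds, each test instead runs through cmake/run_android_test.sh,
# which pushes the binary over adb, as shown in the diff above.
```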
diff --git a/cmake/ruy_include_directories.cmake b/cmake/ruy_include_directories.cmake
new file mode 100644
index 0000000..e9b50a9
--- /dev/null
+++ b/cmake/ruy_include_directories.cmake
@@ -0,0 +1,33 @@
+# Copyright 2019-2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+function(ruy_include_directories NAME DEPS)
+ target_include_directories(${NAME}
+ PUBLIC
+ "${PROJECT_SOURCE_DIR}"
+ )
+ if (cpuinfo IN_LIST DEPS)
+ target_include_directories(${NAME}
+ PRIVATE
+ "${PROJECT_SOURCE_DIR}/third_party/cpuinfo/include"
+ )
+ endif()
+ if ((gtest IN_LIST DEPS) OR
+ (gtest_main IN_LIST DEPS))
+ target_include_directories(${NAME}
+ PRIVATE
+ "${PROJECT_SOURCE_DIR}/third_party/googletest/googletest"
+ )
+ endif()
+endfunction()
\ No newline at end of file
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 0000000..2a03104
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,3 @@
+# Ruy documentation
+
+This directory will eventually contain ruy documentation.
diff --git a/doc/depgraph.sh b/doc/depgraph.sh
new file mode 100755
index 0000000..d66d44f
--- /dev/null
+++ b/doc/depgraph.sh
@@ -0,0 +1,153 @@
+#!/bin/bash
+
+# Generates a graphviz dependency graph for :ruy, with details trimmed.
+# Suggested rendering: pipe to `dot` (part of the standard graphviz distribution):
+# doc/depgraph.sh | dot -Tsvg > depgraph.svg
+
+drop=(
+ ':platform'
+ ':check_macros'
+ ':asm_helpers'
+ ':size_util'
+ ':system_aligned_alloc'
+ ':side_pair'
+ ':opt_set'
+ ':blocking_counter'
+ ':wait'
+ ':time'
+ ':path'
+ ':performance_advisory'
+ ':tune'
+ ':matrix'
+ ':mat'
+ ':mul_params'
+ ':context_get_ctx'
+ ':have_built_path_for'
+ ':pack_common'
+ ':kernel_common'
+ ':trace'
+ ':validate'
+ 'profiler:instrumentation'
+ '\bclog\b'
+ '\bcpuinfo_impl\b'
+ ':apply_multiplier'
+ '\blabel='
+)
+
+graph="$(bazel query 'kind("cc_library", deps(//ruy))' --output graph --noimplicit_deps 2>/dev/null)"
+
+graph="$(echo "${graph}" | sed 's|//ruy/\?||g')"
+
+for t in "${drop[@]}"; do
+ graph="$(echo "${graph}" | grep -v "${t}")"
+done
+
+graph="$(echo "${graph}" | sed 's|//:cpuinfo_with_unstripped_include_path||g')"
+graph="$(echo "${graph}" | sed 's|//third_party/cpuinfo:[a-z0-9_]*|@cpuinfo|g')"
+
+frontend=(
+ ':ruy'
+ ':context'
+ ':frontend'
+ ':prepare_packed_matrices'
+ ':create_trmul_params'
+)
+
+middleend=(
+ ':ctx'
+ ':trmul_params'
+ ':trmul'
+ ':block_map'
+ ':cpuinfo'
+ ':cpu_cache_params'
+ ':allocator'
+ ':prepacked_cache'
+)
+
+backend=(
+ ':kernel.*'
+ ':pack.*'
+)
+
+threadpool=(
+ ':thread_pool'
+)
+
+frontend_lines=()
+middleend_lines=()
+backend_lines=()
+threadpool_lines=()
+misc_lines=()
+arrow_lines=()
+
+while IFS= read -r line; do
+ if [[ "${line}" =~ '->' ]]; then
+ arrow_lines+=("${line}")
+ else
+ handled=false
+ if [ $handled = false ]; then
+ for f in "${frontend[@]}"; do
+ if [[ "${line}" =~ ${f} ]]; then
+ frontend_lines+=("${line}")
+ handled=true
+ break
+ fi
+ done
+ fi
+ if [ $handled = false ]; then
+ for f in "${middleend[@]}"; do
+ if [[ "${line}" =~ ${f} ]]; then
+ middleend_lines+=("${line}")
+ handled=true
+ break
+ fi
+ done
+ fi
+ if [ $handled = false ]; then
+ for f in "${backend[@]}"; do
+ if [[ "${line}" =~ ${f} ]]; then
+ backend_lines+=("${line}")
+ handled=true
+ break
+ fi
+ done
+ fi
+ if [ $handled = false ]; then
+ for f in "${threadpool[@]}"; do
+ if [[ "${line}" =~ ${f} ]]; then
+ threadpool_lines+=("${line}")
+ handled=true
+ break
+ fi
+ done
+ fi
+ if [ $handled = false ]; then
+ if [[ "${line}" =~ ^[[:space:]]+\" ]]; then
+ misc_lines+=("${line}")
+ fi
+ fi
+ fi
+done <<< "${graph}"
+
+echo "digraph ruy {"
+echo " splines = true"
+echo " node [shape=box]"
+for f in "${frontend_lines[@]}"; do
+ echo " $f [style=filled, color=\"#B2EBF2\"];"
+done
+for m in "${middleend_lines[@]}"; do
+ echo " $m [style=filled, color=\"#C8E6C9\"];"
+done
+for b in "${backend_lines[@]}"; do
+ echo " $b [style=filled, color=\"#FFCDD2\"];"
+done
+for b in "${threadpool_lines[@]}"; do
+ echo " $b [style=filled, color=\"#FFF9C4\"];"
+done
+for m in "${misc_lines[@]}"; do
+ echo "$m"
+done
+for a in "${arrow_lines[@]}"; do
+ echo "$a"
+done
+echo "}"
diff --git a/doc/depgraph.svg b/doc/depgraph.svg
new file mode 100644
index 0000000..79fb27c
--- /dev/null
+++ b/doc/depgraph.svg
@@ -0,0 +1,377 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: ruy Pages: 1 -->
+<svg width="1007pt" height="421pt"
+ viewBox="0.00 0.00 1007.00 421.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 417)">
+<title>ruy</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-417 1003,-417 1003,4 -4,4"/>
+<!-- :ruy -->
+<g id="node1" class="node">
+<title>:ruy</title>
+<polygon fill="#b2ebf2" stroke="#b2ebf2" points="233.5,-413 179.5,-413 179.5,-377 233.5,-377 233.5,-413"/>
+<text text-anchor="middle" x="206.5" y="-391.3" font-family="Times,serif" font-size="14.00">:ruy</text>
+</g>
+<!-- :frontend -->
+<g id="node2" class="node">
+<title>:frontend</title>
+<polygon fill="#b2ebf2" stroke="#b2ebf2" points="375.5,-341 309.5,-341 309.5,-305 375.5,-305 375.5,-341"/>
+<text text-anchor="middle" x="342.5" y="-319.3" font-family="Times,serif" font-size="14.00">:frontend</text>
+</g>
+<!-- :ruy&#45;&gt;:frontend -->
+<g id="edge2" class="edge">
+<title>:ruy&#45;&gt;:frontend</title>
+<path fill="none" stroke="black" d="M233.69,-380C252.69,-370.23 278.42,-356.98 300.09,-345.83"/>
+<polygon fill="black" stroke="black" points="301.81,-348.88 309.09,-341.19 298.6,-342.66 301.81,-348.88"/>
+</g>
+<!-- :context -->
+<g id="node5" class="node">
+<title>:context</title>
+<polygon fill="#b2ebf2" stroke="#b2ebf2" points="100.5,-269 40.5,-269 40.5,-233 100.5,-233 100.5,-269"/>
+<text text-anchor="middle" x="70.5" y="-247.3" font-family="Times,serif" font-size="14.00">:context</text>
+</g>
+<!-- :ruy&#45;&gt;:context -->
+<g id="edge1" class="edge">
+<title>:ruy&#45;&gt;:context</title>
+<path fill="none" stroke="black" d="M190.1,-376.87C166.2,-351.92 121.69,-305.45 94.23,-276.77"/>
+<polygon fill="black" stroke="black" points="96.41,-273.99 86.96,-269.19 91.35,-278.83 96.41,-273.99"/>
+</g>
+<!-- :prepare_packed_matrices -->
+<g id="node3" class="node">
+<title>:prepare_packed_matrices</title>
+<polygon fill="#b2ebf2" stroke="#b2ebf2" points="315,-269 156,-269 156,-233 315,-233 315,-269"/>
+<text text-anchor="middle" x="235.5" y="-247.3" font-family="Times,serif" font-size="14.00">:prepare_packed_matrices</text>
+</g>
+<!-- :frontend&#45;&gt;:prepare_packed_matrices -->
+<g id="edge6" class="edge">
+<title>:frontend&#45;&gt;:prepare_packed_matrices</title>
+<path fill="none" stroke="black" d="M316.32,-304.88C302.46,-295.81 285.26,-284.55 270.29,-274.76"/>
+<polygon fill="black" stroke="black" points="272.06,-271.74 261.78,-269.19 268.23,-277.59 272.06,-271.74"/>
+</g>
+<!-- :create_trmul_params -->
+<g id="node4" class="node">
+<title>:create_trmul_params</title>
+<polygon fill="#b2ebf2" stroke="#b2ebf2" points="700.5,-269 564.5,-269 564.5,-233 700.5,-233 700.5,-269"/>
+<text text-anchor="middle" x="632.5" y="-247.3" font-family="Times,serif" font-size="14.00">:create_trmul_params</text>
+</g>
+<!-- :frontend&#45;&gt;:create_trmul_params -->
+<g id="edge4" class="edge">
+<title>:frontend&#45;&gt;:create_trmul_params</title>
+<path fill="none" stroke="black" d="M375.77,-313.97C419.01,-303.53 495.91,-284.97 554.3,-270.88"/>
+<polygon fill="black" stroke="black" points="555.34,-274.22 564.24,-268.48 553.7,-267.42 555.34,-274.22"/>
+</g>
+<!-- :trmul -->
+<g id="node6" class="node">
+<title>:trmul</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="424.5,-269 370.5,-269 370.5,-233 424.5,-233 424.5,-269"/>
+<text text-anchor="middle" x="397.5" y="-247.3" font-family="Times,serif" font-size="14.00">:trmul</text>
+</g>
+<!-- :frontend&#45;&gt;:trmul -->
+<g id="edge7" class="edge">
+<title>:frontend&#45;&gt;:trmul</title>
+<path fill="none" stroke="black" d="M356.1,-304.7C362.62,-296.39 370.57,-286.28 377.75,-277.14"/>
+<polygon fill="black" stroke="black" points="380.63,-279.13 384.06,-269.1 375.13,-274.81 380.63,-279.13"/>
+</g>
+<!-- :trmul_params -->
+<g id="node8" class="node">
+<title>:trmul_params</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="643,-197 546,-197 546,-161 643,-161 643,-197"/>
+<text text-anchor="middle" x="594.5" y="-175.3" font-family="Times,serif" font-size="14.00">:trmul_params</text>
+</g>
+<!-- :frontend&#45;&gt;:trmul_params -->
+<g id="edge8" class="edge">
+<title>:frontend&#45;&gt;:trmul_params</title>
+<path fill="none" stroke="black" d="M372.9,-304.87C418.49,-279.18 504.61,-230.65 555.02,-202.25"/>
+<polygon fill="black" stroke="black" points="557,-205.15 564,-197.19 553.56,-199.05 557,-205.15"/>
+</g>
+<!-- :ctx -->
+<g id="node9" class="node">
+<title>:ctx</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="200.5,-197 146.5,-197 146.5,-161 200.5,-161 200.5,-197"/>
+<text text-anchor="middle" x="173.5" y="-175.3" font-family="Times,serif" font-size="14.00">:ctx</text>
+</g>
+<!-- :frontend&#45;&gt;:ctx -->
+<g id="edge5" class="edge">
+<title>:frontend&#45;&gt;:ctx</title>
+<path fill="none" stroke="black" d="M309.39,-317.89C258.84,-310.86 166.21,-294.75 146.5,-269 132.35,-250.52 142.7,-224.62 154.44,-205.53"/>
+<polygon fill="black" stroke="black" points="157.47,-207.31 160.04,-197.03 151.62,-203.46 157.47,-207.31"/>
+</g>
+<!-- :allocator -->
+<g id="node13" class="node">
+<title>:allocator</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="382.5,-116.5 314.5,-116.5 314.5,-80.5 382.5,-80.5 382.5,-116.5"/>
+<text text-anchor="middle" x="348.5" y="-94.8" font-family="Times,serif" font-size="14.00">:allocator</text>
+</g>
+<!-- :frontend&#45;&gt;:allocator -->
+<g id="edge3" class="edge">
+<title>:frontend&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M342.96,-304.91C343.99,-266.57 346.51,-173.41 347.76,-126.89"/>
+<polygon fill="black" stroke="black" points="351.27,-126.74 348.04,-116.65 344.27,-126.55 351.27,-126.74"/>
+</g>
+<!-- :prepare_packed_matrices&#45;&gt;:trmul_params -->
+<g id="edge20" class="edge">
+<title>:prepare_packed_matrices&#45;&gt;:trmul_params</title>
+<path fill="none" stroke="black" d="M315.21,-238.43C373.78,-229.36 455.42,-215.53 536.04,-196.97"/>
+<polygon fill="black" stroke="black" points="536.92,-200.35 545.87,-194.68 535.34,-193.54 536.92,-200.35"/>
+</g>
+<!-- :prepare_packed_matrices&#45;&gt;:ctx -->
+<g id="edge18" class="edge">
+<title>:prepare_packed_matrices&#45;&gt;:ctx</title>
+<path fill="none" stroke="black" d="M220.17,-232.7C212.74,-224.3 203.68,-214.07 195.52,-204.86"/>
+<polygon fill="black" stroke="black" points="197.9,-202.27 188.65,-197.1 192.66,-206.91 197.9,-202.27"/>
+</g>
+<!-- :prepacked_cache -->
+<g id="node10" class="node">
+<title>:prepacked_cache</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="113,-116.5 0,-116.5 0,-80.5 113,-80.5 113,-116.5"/>
+<text text-anchor="middle" x="56.5" y="-94.8" font-family="Times,serif" font-size="14.00">:prepacked_cache</text>
+</g>
+<!-- :prepare_packed_matrices&#45;&gt;:prepacked_cache -->
+<g id="edge19" class="edge">
+<title>:prepare_packed_matrices&#45;&gt;:prepacked_cache</title>
+<path fill="none" stroke="black" d="M195.68,-232.78C176.98,-223.65 155,-211.35 137.5,-197 112.04,-176.13 88.67,-146.35 73.64,-125.2"/>
+<polygon fill="black" stroke="black" points="76.32,-122.92 67.73,-116.72 70.58,-126.92 76.32,-122.92"/>
+</g>
+<!-- :prepare_packed_matrices&#45;&gt;:allocator -->
+<g id="edge17" class="edge">
+<title>:prepare_packed_matrices&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M248.46,-232.74C268.39,-206.2 306.64,-155.25 329.63,-124.63"/>
+<polygon fill="black" stroke="black" points="332.48,-126.66 335.69,-116.56 326.89,-122.46 332.48,-126.66"/>
+</g>
+<!-- :create_trmul_params&#45;&gt;:trmul_params -->
+<g id="edge25" class="edge">
+<title>:create_trmul_params&#45;&gt;:trmul_params</title>
+<path fill="none" stroke="black" d="M623.11,-232.7C618.74,-224.64 613.44,-214.89 608.6,-205.98"/>
+<polygon fill="black" stroke="black" points="611.63,-204.22 603.79,-197.1 605.48,-207.56 611.63,-204.22"/>
+</g>
+<!-- :create_trmul_params&#45;&gt;:ctx -->
+<g id="edge22" class="edge">
+<title>:create_trmul_params&#45;&gt;:ctx</title>
+<path fill="none" stroke="black" d="M564.28,-239.6C466.01,-224.61 287.97,-197.46 210.82,-185.69"/>
+<polygon fill="black" stroke="black" points="211.04,-182.18 200.62,-184.14 209.98,-189.1 211.04,-182.18"/>
+</g>
+<!-- :create_trmul_params&#45;&gt;:allocator -->
+<g id="edge21" class="edge">
+<title>:create_trmul_params&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M643.65,-232.97C654.87,-213.46 668.3,-181.55 651.5,-161 635.2,-141.06 473.1,-116.46 392.93,-105.39"/>
+<polygon fill="black" stroke="black" points="393.03,-101.87 382.65,-103.98 392.08,-108.81 393.03,-101.87"/>
+</g>
+<!-- :pack -->
+<g id="node14" class="node">
+<title>:pack</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="753.5,-197 699.5,-197 699.5,-161 753.5,-161 753.5,-197"/>
+<text text-anchor="middle" x="726.5" y="-175.3" font-family="Times,serif" font-size="14.00">:pack</text>
+</g>
+<!-- :create_trmul_params&#45;&gt;:pack -->
+<g id="edge24" class="edge">
+<title>:create_trmul_params&#45;&gt;:pack</title>
+<path fill="none" stroke="black" d="M655.74,-232.7C667.69,-223.8 682.42,-212.82 695.35,-203.2"/>
+<polygon fill="black" stroke="black" points="697.6,-205.88 703.53,-197.1 693.42,-200.27 697.6,-205.88"/>
+</g>
+<!-- :kernel -->
+<g id="node17" class="node">
+<title>:kernel</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="866.5,-197 812.5,-197 812.5,-161 866.5,-161 866.5,-197"/>
+<text text-anchor="middle" x="839.5" y="-175.3" font-family="Times,serif" font-size="14.00">:kernel</text>
+</g>
+<!-- :create_trmul_params&#45;&gt;:kernel -->
+<g id="edge23" class="edge">
+<title>:create_trmul_params&#45;&gt;:kernel</title>
+<path fill="none" stroke="black" d="M682.87,-232.97C719.86,-220.46 769.25,-203.75 802.61,-192.48"/>
+<polygon fill="black" stroke="black" points="804.06,-195.68 812.41,-189.16 801.82,-189.05 804.06,-195.68"/>
+</g>
+<!-- :context&#45;&gt;:ctx -->
+<g id="edge31" class="edge">
+<title>:context&#45;&gt;:ctx</title>
+<path fill="none" stroke="black" d="M95.7,-232.88C108.91,-223.89 125.29,-212.76 139.61,-203.03"/>
+<polygon fill="black" stroke="black" points="141.9,-205.71 148.21,-197.19 137.97,-199.92 141.9,-205.71"/>
+</g>
+<!-- :context&#45;&gt;:prepacked_cache -->
+<g id="edge32" class="edge">
+<title>:context&#45;&gt;:prepacked_cache</title>
+<path fill="none" stroke="black" d="M51.9,-232.76C42.97,-223.2 33.21,-210.53 28.5,-197 20.19,-173.13 30,-145.38 40.22,-125.6"/>
+<polygon fill="black" stroke="black" points="43.43,-127.03 45.19,-116.59 37.3,-123.65 43.43,-127.03"/>
+</g>
+<!-- :context&#45;&gt;:allocator -->
+<g id="edge30" class="edge">
+<title>:context&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M78.54,-232.97C89.1,-212.53 109.63,-178.8 137.5,-161 201.86,-119.9 234.39,-152.84 305.5,-125 308.23,-123.93 310.98,-122.7 313.71,-121.37"/>
+<polygon fill="black" stroke="black" points="315.61,-124.33 322.81,-116.55 312.33,-118.14 315.61,-124.33"/>
+</g>
+<!-- :thread_pool -->
+<g id="node20" class="node">
+<title>:thread_pool</title>
+<polygon fill="#fff9c4" stroke="#fff9c4" points="216,-116.5 131,-116.5 131,-80.5 216,-80.5 216,-116.5"/>
+<text text-anchor="middle" x="173.5" y="-94.8" font-family="Times,serif" font-size="14.00">:thread_pool</text>
+</g>
+<!-- :context&#45;&gt;:thread_pool -->
+<g id="edge33" class="edge">
+<title>:context&#45;&gt;:thread_pool</title>
+<path fill="none" stroke="black" d="M64.6,-232.97C59.14,-214.05 53.64,-183.22 66.5,-161 70.31,-154.41 102.87,-136.18 131.08,-121.27"/>
+<polygon fill="black" stroke="black" points="132.82,-124.31 140.04,-116.56 129.56,-118.11 132.82,-124.31"/>
+</g>
+<!-- :block_map -->
+<g id="node7" class="node">
+<title>:block_map</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="528,-197 447,-197 447,-161 528,-161 528,-197"/>
+<text text-anchor="middle" x="487.5" y="-175.3" font-family="Times,serif" font-size="14.00">:block_map</text>
+</g>
+<!-- :trmul&#45;&gt;:block_map -->
+<g id="edge10" class="edge">
+<title>:trmul&#45;&gt;:block_map</title>
+<path fill="none" stroke="black" d="M419.75,-232.7C431.08,-223.88 445.03,-213.03 457.32,-203.47"/>
+<polygon fill="black" stroke="black" points="459.76,-206.01 465.51,-197.1 455.47,-200.48 459.76,-206.01"/>
+</g>
+<!-- :trmul&#45;&gt;:trmul_params -->
+<g id="edge15" class="edge">
+<title>:trmul&#45;&gt;:trmul_params</title>
+<path fill="none" stroke="black" d="M424.89,-240.27C453.72,-230.02 499.94,-213.6 536.85,-200.48"/>
+<polygon fill="black" stroke="black" points="538.25,-203.7 546.51,-197.05 535.91,-197.1 538.25,-203.7"/>
+</g>
+<!-- :trmul&#45;&gt;:ctx -->
+<g id="edge13" class="edge">
+<title>:trmul&#45;&gt;:ctx</title>
+<path fill="none" stroke="black" d="M370.23,-243.12C334.21,-233.79 268.96,-216.11 214.5,-197 213.09,-196.51 211.66,-195.99 210.22,-195.46"/>
+<polygon fill="black" stroke="black" points="211.27,-192.11 200.68,-191.77 208.75,-198.64 211.27,-192.11"/>
+</g>
+<!-- :cpuinfo -->
+<g id="node11" class="node">
+<title>:cpuinfo</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="296.5,-116.5 234.5,-116.5 234.5,-80.5 296.5,-80.5 296.5,-116.5"/>
+<text text-anchor="middle" x="265.5" y="-94.8" font-family="Times,serif" font-size="14.00">:cpuinfo</text>
+</g>
+<!-- :trmul&#45;&gt;:cpuinfo -->
+<g id="edge12" class="edge">
+<title>:trmul&#45;&gt;:cpuinfo</title>
+<path fill="none" stroke="black" d="M382.36,-232.74C358.98,-206.09 314.02,-154.82 287.19,-124.23"/>
+<polygon fill="black" stroke="black" points="289.69,-121.77 280.47,-116.56 284.43,-126.39 289.69,-121.77"/>
+</g>
+<!-- :cpu_cache_params -->
+<g id="node12" class="node">
+<title>:cpu_cache_params</title>
+<polygon fill="#c8e6c9" stroke="#c8e6c9" points="480.5,-36 356.5,-36 356.5,0 480.5,0 480.5,-36"/>
+<text text-anchor="middle" x="418.5" y="-14.3" font-family="Times,serif" font-size="14.00">:cpu_cache_params</text>
+</g>
+<!-- :trmul&#45;&gt;:cpu_cache_params -->
+<g id="edge11" class="edge">
+<title>:trmul&#45;&gt;:cpu_cache_params</title>
+<path fill="none" stroke="black" d="M399.08,-232.64C402.71,-192.74 411.66,-94.27 416.02,-46.24"/>
+<polygon fill="black" stroke="black" points="419.51,-46.52 416.93,-36.25 412.54,-45.89 419.51,-46.52"/>
+</g>
+<!-- :trmul&#45;&gt;:allocator -->
+<g id="edge9" class="edge">
+<title>:trmul&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M391.88,-232.74C383.39,-206.65 367.22,-156.99 357.2,-126.21"/>
+<polygon fill="black" stroke="black" points="360.48,-124.99 354.06,-116.56 353.82,-127.16 360.48,-124.99"/>
+</g>
+<!-- :trmul&#45;&gt;:thread_pool -->
+<g id="edge14" class="edge">
+<title>:trmul&#45;&gt;:thread_pool</title>
+<path fill="none" stroke="black" d="M371.15,-232.94C355.27,-222.61 334.67,-209.14 316.5,-197 278.68,-171.74 235.6,-142.27 206.7,-122.4"/>
+<polygon fill="black" stroke="black" points="208.43,-119.34 198.2,-116.55 204.46,-125.1 208.43,-119.34"/>
+</g>
+<!-- :block_map&#45;&gt;:cpu_cache_params -->
+<g id="edge16" class="edge">
+<title>:block_map&#45;&gt;:cpu_cache_params</title>
+<path fill="none" stroke="black" d="M480.12,-160.98C468.14,-133.4 444.4,-78.69 430.14,-45.82"/>
+<polygon fill="black" stroke="black" points="433.23,-44.16 426.04,-36.38 426.81,-46.94 433.23,-44.16"/>
+</g>
+<!-- :ctx&#45;&gt;:prepacked_cache -->
+<g id="edge36" class="edge">
+<title>:ctx&#45;&gt;:prepacked_cache</title>
+<path fill="none" stroke="black" d="M148.11,-160.97C131.17,-149.6 108.7,-134.53 90.26,-122.15"/>
+<polygon fill="black" stroke="black" points="92.09,-119.17 81.84,-116.5 88.19,-124.98 92.09,-119.17"/>
+</g>
+<!-- :ctx&#45;&gt;:cpuinfo -->
+<g id="edge35" class="edge">
+<title>:ctx&#45;&gt;:cpuinfo</title>
+<path fill="none" stroke="black" d="M193.46,-160.97C206.48,-149.86 223.65,-135.21 237.96,-123"/>
+<polygon fill="black" stroke="black" points="240.24,-125.65 245.58,-116.5 235.7,-120.33 240.24,-125.65"/>
+</g>
+<!-- :ctx&#45;&gt;:allocator -->
+<g id="edge34" class="edge">
+<title>:ctx&#45;&gt;:allocator</title>
+<path fill="none" stroke="black" d="M200.57,-168.52C227.63,-158.72 270.18,-142.43 305.5,-125 307.75,-123.89 310.03,-122.71 312.32,-121.48"/>
+<polygon fill="black" stroke="black" points="314.06,-124.52 321.09,-116.6 310.65,-118.41 314.06,-124.52"/>
+</g>
+<!-- :ctx&#45;&gt;:thread_pool -->
+<g id="edge37" class="edge">
+<title>:ctx&#45;&gt;:thread_pool</title>
+<path fill="none" stroke="black" d="M173.5,-160.97C173.5,-150.99 173.5,-138.15 173.5,-126.8"/>
+<polygon fill="black" stroke="black" points="177,-126.5 173.5,-116.5 170,-126.5 177,-126.5"/>
+</g>
+<!-- :cpuinfo&#45;&gt;:cpu_cache_params -->
+<g id="edge38" class="edge">
+<title>:cpuinfo&#45;&gt;:cpu_cache_params</title>
+<path fill="none" stroke="black" d="M291.46,-80.27C296.07,-77.41 300.88,-74.54 305.5,-72 326.07,-60.68 349.47,-49.45 369.61,-40.26"/>
+<polygon fill="black" stroke="black" points="371.25,-43.36 378.92,-36.06 368.36,-36.98 371.25,-43.36"/>
+</g>
+<!-- @cpuinfo -->
+<g id="node21" class="node">
+<title>@cpuinfo</title>
+<polygon fill="none" stroke="black" points="301,-36 230,-36 230,0 301,0 301,-36"/>
+<text text-anchor="middle" x="265.5" y="-14.3" font-family="Times,serif" font-size="14.00">@cpuinfo</text>
+</g>
+<!-- :cpuinfo&#45;&gt;@cpuinfo -->
+<g id="edge39" class="edge">
+<title>:cpuinfo&#45;&gt;@cpuinfo</title>
+<path fill="none" stroke="black" d="M265.5,-80.47C265.5,-70.49 265.5,-57.65 265.5,-46.3"/>
+<polygon fill="black" stroke="black" points="269,-46 265.5,-36 262,-46 269,-46"/>
+</g>
+<!-- :pack_avx2_fma\n:pack_avx512\n:pack_avx -->
+<g id="node15" class="node">
+<title>:pack_avx2_fma\n:pack_avx512\n:pack_avx</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="671,-125 564,-125 564,-72 671,-72 671,-125"/>
+<text text-anchor="middle" x="617.5" y="-109.8" font-family="Times,serif" font-size="14.00">:pack_avx2_fma</text>
+<text text-anchor="middle" x="617.5" y="-94.8" font-family="Times,serif" font-size="14.00">:pack_avx512</text>
+<text text-anchor="middle" x="617.5" y="-79.8" font-family="Times,serif" font-size="14.00">:pack_avx</text>
+</g>
+<!-- :pack&#45;&gt;:pack_avx2_fma\n:pack_avx512\n:pack_avx -->
+<g id="edge27" class="edge">
+<title>:pack&#45;&gt;:pack_avx2_fma\n:pack_avx512\n:pack_avx</title>
+<path fill="none" stroke="black" d="M702.85,-160.97C690.71,-152.22 675.52,-141.28 661.33,-131.07"/>
+<polygon fill="black" stroke="black" points="663.28,-128.16 653.12,-125.16 659.19,-133.84 663.28,-128.16"/>
+</g>
+<!-- :pack_arm -->
+<g id="node16" class="node">
+<title>:pack_arm</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="763.5,-116.5 689.5,-116.5 689.5,-80.5 763.5,-80.5 763.5,-116.5"/>
+<text text-anchor="middle" x="726.5" y="-94.8" font-family="Times,serif" font-size="14.00">:pack_arm</text>
+</g>
+<!-- :pack&#45;&gt;:pack_arm -->
+<g id="edge26" class="edge">
+<title>:pack&#45;&gt;:pack_arm</title>
+<path fill="none" stroke="black" d="M726.5,-160.97C726.5,-150.99 726.5,-138.15 726.5,-126.8"/>
+<polygon fill="black" stroke="black" points="730,-126.5 726.5,-116.5 723,-126.5 730,-126.5"/>
+</g>
+<!-- :kernel_avx\n:kernel_avx512\n:kernel_avx2_fma -->
+<g id="node18" class="node">
+<title>:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="897.5,-125 781.5,-125 781.5,-72 897.5,-72 897.5,-125"/>
+<text text-anchor="middle" x="839.5" y="-109.8" font-family="Times,serif" font-size="14.00">:kernel_avx</text>
+<text text-anchor="middle" x="839.5" y="-94.8" font-family="Times,serif" font-size="14.00">:kernel_avx512</text>
+<text text-anchor="middle" x="839.5" y="-79.8" font-family="Times,serif" font-size="14.00">:kernel_avx2_fma</text>
+</g>
+<!-- :kernel&#45;&gt;:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma -->
+<g id="edge29" class="edge">
+<title>:kernel&#45;&gt;:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma</title>
+<path fill="none" stroke="black" d="M839.5,-160.97C839.5,-153.45 839.5,-144.31 839.5,-135.4"/>
+<polygon fill="black" stroke="black" points="843,-135.16 839.5,-125.16 836,-135.16 843,-135.16"/>
+</g>
+<!-- :kernel_arm -->
+<g id="node19" class="node">
+<title>:kernel_arm</title>
+<polygon fill="#ffcdd2" stroke="#ffcdd2" points="999,-116.5 916,-116.5 916,-80.5 999,-80.5 999,-116.5"/>
+<text text-anchor="middle" x="957.5" y="-94.8" font-family="Times,serif" font-size="14.00">:kernel_arm</text>
+</g>
+<!-- :kernel&#45;&gt;:kernel_arm -->
+<g id="edge28" class="edge">
+<title>:kernel&#45;&gt;:kernel_arm</title>
+<path fill="none" stroke="black" d="M865.1,-160.97C882.19,-149.6 904.85,-134.53 923.45,-122.15"/>
+<polygon fill="black" stroke="black" points="925.56,-124.95 931.95,-116.5 921.68,-119.12 925.56,-124.95"/>
+</g>
+</g>
+</svg>
diff --git a/example/BUILD b/example/BUILD
new file mode 100644
index 0000000..738c33e
--- /dev/null
+++ b/example/BUILD
@@ -0,0 +1,16 @@
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Usage examples.
+cc_binary(
+ name = "example",
+ srcs = ["example.cc"],
+ deps = ["//ruy"],
+)
+
+cc_binary(
+ name = "parametrized_example",
+ srcs = ["parametrized_example.cc"],
+ deps = ["//ruy"],
+)
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
new file mode 100644
index 0000000..63fbd66
--- /dev/null
+++ b/example/CMakeLists.txt
@@ -0,0 +1,23 @@
+# This file is generated (whence no license header). Do not edit!
+# To regenerate, run:
+# cmake/bazel_to_cmake.sh
+
+ruy_cc_binary(
+ NAME
+ ruy_example_example
+ SRCS
+ example.cc
+ DEPS
+ ruy
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_example_parametrized_example
+ SRCS
+ parametrized_example.cc
+ DEPS
+ ruy
+)
+
+ruy_add_all_subdirs()
diff --git a/example/README.md b/example/README.md
new file mode 100644
index 0000000..e29ee12
--- /dev/null
+++ b/example/README.md
@@ -0,0 +1,14 @@
+## Introduction
+
+These are some examples showing how to use ruy.
+
+## BUILD
+
+Build the examples with the following Bazel command:
+```
+bazel build //ruy/example:example
+```
+You can find the generated binaries under this directory:
+```
+./bazel-bin/ruy/example
+```
diff --git a/example/example.cc b/example/example.cc
new file mode 100644
index 0000000..3bb95f4
--- /dev/null
+++ b/example/example.cc
@@ -0,0 +1,161 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <iostream>
+
+#include "ruy/ruy.h"
+
+void ExampleMulFloat(ruy::Context *context) {
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2, 3, 4};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
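+ // A default-constructed MulParams<float, float> requests a plain matrix
+ // multiplication: no bias addition and no clamping of the results.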
+ ruy::MulParams<float, float> mul_params;
+ ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+ std::cout << "Example Mul, float:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) {
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2, 3, 4};
+ const float bias_data[] = {1, 0};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ ruy::MulParams<float, float> mul_params;
+ mul_params.set_bias(bias_data);
+ mul_params.set_clamp_min(0);
+ mul_params.set_clamp_max(15);
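+ // bias_data has one entry per destination row (the default channel
+ // dimension) and is added to every column of the result; the result is
+ // then clamped to [0, 15].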
+ ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+ std::cout << "Example Mul, float with bias addition and clamp:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) {
+ const std::uint8_t lhs_data[] = {124, 125, 126, 127};
+ const std::uint8_t rhs_data[] = {129, 130, 131, 132};
+ std::uint8_t dst_data[4];
+
+ ruy::Matrix<std::uint8_t> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ lhs.set_zero_point(125);
+ ruy::Matrix<std::uint8_t> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ rhs.set_zero_point(132);
+ ruy::Matrix<std::uint8_t> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+ dst.set_zero_point(129);
+
+ ruy::MulParams<std::int32_t, std::uint8_t> mul_params;
+ mul_params.set_multiplier_fixedpoint(1 << 30);
+
+ mul_params.set_multiplier_exponent(0);
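+ // The effective multiplier is multiplier_fixedpoint / 2^31 * 2^multiplier_exponent,
+ // here (1 << 30) / 2^31 * 2^0 = 0.5. Accumulators are computed on
+ // zero-point-adjusted values, scaled by 0.5, offset by the destination
+ // zero point (129), and saturated to the uint8 range.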
+ ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+ std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulInt8PerChannelQuantized(ruy::Context *context) {
+ const std::int8_t lhs_data[] = {1, 2, 3, 4};
+ const std::int8_t rhs_data[] = {1, 2, 3, 4};
+ const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
+ const int exponent_data[] = {1, -2};
+ std::int8_t dst_data[4];
+
+ ruy::Matrix<std::int8_t> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<std::int8_t> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<std::int8_t> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ ruy::MulParams<std::int32_t, std::int8_t> mul_params;
+ mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
+ mul_params.set_multiplier_exponent_perchannel(exponent_data);
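+ // The per-channel arrays have one entry per destination row (the default
+ // channel dimension): row 0 is effectively scaled by 3/8 * 2^1 = 0.75 and
+ // row 1 by 5/8 * 2^-2 = 0.15625.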
+ ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+ std::cout << "Example Mul, int8 quantized with per-channel multipliers\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulInt8GetRawAccumulators(ruy::Context *context) {
+ const std::int8_t lhs_data[] = {1, 2, 3, 4};
+ const std::int8_t rhs_data[] = {1, 2, 3, 4};
+ std::int32_t dst_data[4];
+
+ ruy::Matrix<std::int8_t> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<std::int8_t> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<std::int32_t> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ // When Dst is int32, mul_params is unused.
+ ruy::MulParams<std::int32_t, std::int32_t> mul_params;
+ ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+ std::cout << "Example Mul, returning raw int32 accumulators:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+int main() {
+ ruy::Context context;
+ ExampleMulFloat(&context);
+ ExampleMulFloatWithBiasAddAndClamp(&context);
+ ExampleMulUint8AsymmetricQuantized(&context);
+ ExampleMulInt8PerChannelQuantized(&context);
+ ExampleMulInt8GetRawAccumulators(&context);
+}
diff --git a/example/parametrized_example.cc b/example/parametrized_example.cc
new file mode 100644
index 0000000..ef6ad23
--- /dev/null
+++ b/example/parametrized_example.cc
@@ -0,0 +1,198 @@
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "ruy/context.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/ruy.h"
+
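+// Parses a single "--name=value" flag from argv into the given destination(s)
+// via sscanf, checking the value against allowed_values when one is provided.
+// When `help` is true, it only prints a usage line describing the flag.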
+template <typename... Dst>
+void read_cmdline_args(bool help, int argc, char* argv[], const char* name,
+ const char* format, const char* default_value,
+ const char* allowed_values, Dst... dst) {
+ if (help) {
+ fprintf(stderr, "%-20s %-12s %-16s %s\n", name, format, default_value,
+ allowed_values ? allowed_values : "");
+ return;
+ }
+ const char* value = default_value;
+ for (int i = 1; i < argc; i++) {
+ if (std::strstr(argv[i], name) == argv[i]) {
+ const char* equal_sign = std::strchr(argv[i], '=');
+ if (equal_sign == argv[i] + std::strlen(name)) {
+ value = equal_sign + 1;
+ }
+ break;
+ }
+ }
+ if (allowed_values) {
+ if (!std::strstr(allowed_values, value)) {
+ fprintf(stderr, "Illegal value %s. The legal values are %s.\n", value,
+ allowed_values);
+ exit(1);
+ }
+ }
+ if (sizeof...(Dst) != sscanf(value, format, dst...)) {
+ fprintf(stderr, "Failed to parse %s\n", value);
+ exit(1);
+ }
+}
+
+struct Params {
+ char types[100];
+ int m, k, n; // matmul shape m*k*n
+ int paths;
+ int num_threads;
+ int repeat;
+ int lhs_cache_policy;
+ int rhs_cache_policy;
+ int lhs_stride;
+ int rhs_stride;
+ int dst_stride;
+ int lhs_zero_point;
+ int rhs_zero_point;
+ int dst_zero_point;
+ char lhs_order[100];
+ char rhs_order[100];
+ char dst_order[100];
+};
+
+template <typename LhsType, typename RhsType, typename DstType>
+void run(const Params& params) {
+ using AccumType =
+ typename std::conditional<std::is_floating_point<DstType>::value, DstType,
+ std::int32_t>::type;
+
+ ruy::Matrix<LhsType> lhs;
+ ruy::Matrix<RhsType> rhs;
+ ruy::Matrix<DstType> dst;
+
+ auto parse_order = [](const char* name) {
+ if (!std::strcmp(name, "row-major")) {
+ return ruy::Order::kRowMajor;
+ } else if (!std::strcmp(name, "column-major")) {
+ return ruy::Order::kColMajor;
+ } else {
+ fprintf(stderr, "Failed to parse %s\n", name);
+ exit(1);
+ }
+ };
+
+ auto make_layout = [](int rows, int cols, int stride, ruy::Order order,
+ ruy::Layout* layout) {
+ layout->set_rows(rows);
+ layout->set_cols(cols);
+ layout->set_order(order);
+ int base_stride = order == ruy::Order::kRowMajor ? cols : rows;
+ layout->set_stride(stride ? stride : base_stride);
+ };
+
+ make_layout(params.m, params.k, params.lhs_stride,
+ parse_order(params.lhs_order), lhs.mutable_layout());
+ make_layout(params.k, params.n, params.rhs_stride,
+ parse_order(params.rhs_order), rhs.mutable_layout());
+ make_layout(params.m, params.n, params.dst_stride,
+ parse_order(params.dst_order), dst.mutable_layout());
+
+ lhs.set_zero_point(params.lhs_zero_point);
+ rhs.set_zero_point(params.rhs_zero_point);
+ dst.set_zero_point(params.dst_zero_point);
+
+ lhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.lhs_cache_policy));
+ rhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.rhs_cache_policy));
+
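+ // Number of elements to allocate for a matrix: the outer dimension size
+ // times the stride (the stride covers any padding between consecutive
+ // rows or columns).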
+ auto flat_size = [](const ruy::Layout& layout) {
+ int outer_size =
+ layout.order() == ruy::Order::kRowMajor ? layout.rows() : layout.cols();
+ return outer_size * layout.stride();
+ };
+
+ std::vector<LhsType> lhs_buf(flat_size(lhs.layout()));
+ std::vector<RhsType> rhs_buf(flat_size(rhs.layout()));
+ std::vector<DstType> dst_buf(flat_size(dst.layout()));
+
+ lhs.set_data(lhs_buf.data());
+ rhs.set_data(rhs_buf.data());
+ dst.set_data(dst_buf.data());
+
+ ruy::Context context;
+ context.set_max_num_threads(params.num_threads);
+ context.set_runtime_enabled_paths(static_cast<ruy::Path>(params.paths));
+
+ ruy::MulParams<AccumType, DstType> mul_params;
+ // Here an actual application might set some mul_params fields:
+ // quantization multipliers, a bias vector, clamp bounds, etc.
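+ // For instance (using the setters shown in example.cc), a float pipeline
+ // with a fused bias-add and clamp could do:
+ //   mul_params.set_bias(bias_data);  // bias_data: application-provided buffer
+ //   mul_params.set_clamp_min(0);
+ //   mul_params.set_clamp_max(6);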
+
+ for (int r = 0; r < params.repeat; r++) {
+ ruy::Mul(lhs, rhs, mul_params, &context, &dst);
+ }
+}
+
+int main(int argc, char* argv[]) {
+ bool help = argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"));
+ if (help) {
+ fprintf(stderr, "Command-line flags (all in the form --flag=value):\n");
+ fprintf(stderr, "%-20s %-12s %-16s %s\n", "flag", "format", "default",
+ "allowed");
+ }
+ Params params;
+ const char* allowed_types =
+ "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8";
+ const char* allowed_orders = "row-major, column-major";
+ read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32",
+ allowed_types, &params.types);
+ read_cmdline_args(help, argc, argv, "--shape", "%dx%dx%d", "100x100x100",
+ nullptr, &params.m, &params.k, &params.n);
+ read_cmdline_args(help, argc, argv, "--paths", "%x", "0", nullptr,
+ &params.paths);
+ read_cmdline_args(help, argc, argv, "--num_threads", "%d", "1", nullptr,
+ &params.num_threads);
+ read_cmdline_args(help, argc, argv, "--repeat", "%d", "1", nullptr,
+ &params.repeat);
+ read_cmdline_args(help, argc, argv, "--lhs_cache_policy", "%d", "0",
+ "0, 1, 2, 3", &params.lhs_cache_policy);
+ read_cmdline_args(help, argc, argv, "--rhs_cache_policy", "%d", "0",
+ "0, 1, 2, 3", &params.rhs_cache_policy);
+ read_cmdline_args(help, argc, argv, "--lhs_stride", "%d", "0", nullptr,
+ &params.lhs_stride);
+ read_cmdline_args(help, argc, argv, "--rhs_stride", "%d", "0", nullptr,
+ &params.rhs_stride);
+ read_cmdline_args(help, argc, argv, "--dst_stride", "%d", "0", nullptr,
+ &params.dst_stride);
+ read_cmdline_args(help, argc, argv, "--lhs_zero_point", "%d", "0", nullptr,
+ &params.lhs_zero_point);
+ read_cmdline_args(help, argc, argv, "--rhs_zero_point", "%d", "0", nullptr,
+ &params.rhs_zero_point);
+ read_cmdline_args(help, argc, argv, "--dst_zero_point", "%d", "0", nullptr,
+ &params.dst_zero_point);
+ read_cmdline_args(help, argc, argv, "--lhs_order", "%s", "row-major",
+ allowed_orders, &params.lhs_order);
+ read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
+ allowed_orders, &params.rhs_order);
+ read_cmdline_args(help, argc, argv, "--dst_order", "%s", "row-major",
+ allowed_orders, &params.dst_order);
+
+ if (help) {
+ exit(1);
+ }
+
+ if (!strcmp(params.types, "f32xf32->f32")) {
+ run<float, float, float>(params);
+ } else if (!strcmp(params.types, "i8xi8->i8")) {
+ run<std::int8_t, std::int8_t, std::int8_t>(params);
+ } else if (!strcmp(params.types, "i8xi8->i16")) {
+ run<std::int8_t, std::int8_t, std::int16_t>(params);
+ } else if (!strcmp(params.types, "i8xi8->i32")) {
+ run<std::int8_t, std::int8_t, std::int32_t>(params);
+ } else if (!strcmp(params.types, "u8xu8->i16")) {
+ run<std::uint8_t, std::uint8_t, std::int16_t>(params);
+ } else if (!strcmp(params.types, "u8xi8->u8")) {
+ run<std::uint8_t, std::int8_t, std::uint8_t>(params);
+ } else {
+ fprintf(stderr, "Unknown types: %s\n", params.types);
+ exit(1);
+ }
+}
diff --git a/ruy/BUILD b/ruy/BUILD
new file mode 100644
index 0000000..37e89ab
--- /dev/null
+++ b/ruy/BUILD
@@ -0,0 +1,1220 @@
+# Ruy is not BLAS
+
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+load("@bazel_skylib//lib:selects.bzl", "selects")
+load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx", "ruy_copts_avx2_fma", "ruy_copts_avx512")
+load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
+load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps")
+load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+config_setting(
+ name = "armeabi-v7a",
+ values = {"cpu": "armeabi-v7a"},
+)
+
+config_setting(
+ name = "armv7a",
+ values = {"cpu": "armv7a"},
+)
+
+# Detect ARM 32-bit targets where we are going to just assume NEON support.
+selects.config_setting_group(
+ name = "arm32_assuming_neon",
+ match_any = [
+ ":armeabi-v7a",
+ ":armv7a",
+ ],
+)
+
+config_setting(
+ name = "x86_64_k8",
+ values = {"cpu": "k8"},
+)
+
+config_setting(
+ name = "x86_64_haswell",
+ values = {"cpu": "haswell"},
+)
+
+# MSVC toolchains define a different "cpu" value. That is useful to us,
+# because we need to pass different flags on MSVC vs GCC-compatible
+# toolchains to enable x86 SIMD extensions.
+selects.config_setting_group(
+ name = "x86_64_and_not_msvc",
+ match_any = [
+ ":x86_64_k8",
+ ":x86_64_haswell",
+ ],
+)
+
+config_setting(
+ name = "ppc",
+ values = {
+ "cpu": "ppc",
+ },
+)
+
+config_setting(
+ name = "s390x",
+ values = {
+ "cpu": "s390x",
+ },
+)
+
+config_setting(
+ name = "fuchsia",
+ values = {"cpu": "fuchsia"},
+)
+
+config_setting(
+ name = "dbg_build",
+ values = {
+ "compilation_mode": "dbg",
+ },
+)
+
+config_setting(
+ name = "fastbuild_build",
+ values = {
+ "compilation_mode": "fastbuild",
+ },
+)
+
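+# Configurations in which we do not want to force -O3: MSVC and
+# debug/fastbuild compilation modes.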
+selects.config_setting_group(
+ name = "do_not_want_O3",
+ match_any = [
+ "@bazel_tools//src/conditions:windows_msvc",
+ ":dbg_build",
+ ":fastbuild_build",
+ ],
+)
+
+cc_library(
+ name = "trace",
+ hdrs = ["trace.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":mat",
+ ":matrix",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ],
+)
+
+cc_library(
+ name = "platform",
+ hdrs = ["platform.h"],
+ copts = ruy_copts(),
+)
+
+cc_library(
+ name = "gtest_wrapper",
+ testonly = True,
+ hdrs = ["gtest_wrapper.h"],
+ visibility = [":__subpackages__"],
+ deps = ["@com_google_googletest//:gtest"],
+)
+
+cc_library(
+ name = "check_macros",
+ hdrs = ["check_macros.h"],
+ copts = ruy_copts(),
+)
+
+cc_test(
+ name = "check_macros_test",
+ srcs = ["check_macros_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":gtest_wrapper",
+ ],
+)
+
+cc_library(
+ name = "opt_set",
+ hdrs = ["opt_set.h"],
+ copts = ruy_copts(),
+)
+
+cc_library(
+ name = "time",
+ hdrs = ["time.h"],
+ copts = ruy_copts(),
+)
+
+cc_library(
+ name = "wait",
+ srcs = ["wait.cc"],
+ hdrs = ["wait.h"],
+ copts = ruy_copts(),
+ linkopts = ruy_linkopts_thread_standard_library(),
+ deps = [":time"],
+)
+
+cc_test(
+ name = "wait_test",
+ srcs = ["wait_test.cc"],
+ copts = ruy_copts(),
+ linkopts = ruy_linkopts_thread_standard_library(),
+ deps = [
+ ":gtest_wrapper",
+ ":platform",
+ ":wait",
+ ],
+)
+
+cc_library(
+ name = "size_util",
+ hdrs = ["size_util.h"],
+ copts = ruy_copts(),
+ deps = [":check_macros"],
+)
+
+cc_test(
+ name = "size_util_test",
+ srcs = ["size_util_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":gtest_wrapper",
+ ":size_util",
+ ],
+)
+
+cc_library(
+ name = "tune",
+ srcs = [
+ "tune.cc",
+ ],
+ hdrs = [
+ "tune.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":cpu_cache_params",
+ ":cpuinfo",
+ ":opt_set",
+ ":platform",
+ ":time",
+ ],
+)
+
+cc_library(
+ name = "system_aligned_alloc",
+ srcs = [
+ "system_aligned_alloc.cc",
+ ],
+ hdrs = [
+ "system_aligned_alloc.h",
+ ],
+ copts = ruy_copts(),
+)
+
+cc_library(
+ name = "prepacked_cache",
+ srcs = [
+ "prepacked_cache.cc",
+ ],
+ hdrs = [
+ "prepacked_cache.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":mat",
+ ":system_aligned_alloc",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_test(
+ name = "tune_test",
+ srcs = ["tune_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":cpuinfo",
+ ":gtest_wrapper",
+ ":tune",
+ ],
+)
+
+cc_test(
+ name = "prepacked_cache_test",
+ srcs = ["prepacked_cache_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":context",
+ ":context_get_ctx",
+ ":ctx",
+ ":gtest_wrapper",
+ ":mat",
+ ":matrix",
+ ":prepacked_cache",
+ ":ruy",
+ ":time",
+ ],
+)
+
+cc_library(
+ name = "allocator",
+ srcs = [
+ "allocator.cc",
+ ],
+ hdrs = [
+ "allocator.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":opt_set",
+ ":size_util",
+ ":system_aligned_alloc",
+ ],
+)
+
+cc_test(
+ name = "allocator_test",
+ srcs = ["allocator_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":gtest_wrapper",
+ ],
+)
+
+cc_library(
+ name = "side_pair",
+ hdrs = ["side_pair.h"],
+ copts = ruy_copts(),
+ deps = [":check_macros"],
+)
+
+cc_library(
+ name = "block_map",
+ srcs = [
+ "block_map.cc",
+ ],
+ hdrs = [
+ "block_map.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":cpu_cache_params",
+ ":opt_set",
+ ":side_pair",
+ ":size_util",
+ ":trace",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_test(
+ name = "block_map_test",
+ srcs = ["block_map_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":block_map",
+ ":cpu_cache_params",
+ ":gtest_wrapper",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ],
+)
+
+cc_library(
+ name = "blocking_counter",
+ srcs = [
+ "blocking_counter.cc",
+ ],
+ hdrs = [
+ "blocking_counter.h",
+ ],
+ copts = ruy_copts(),
+ linkopts = ruy_linkopts_thread_standard_library(),
+ deps = [
+ ":check_macros",
+ ":time",
+ ":wait",
+ ],
+)
+
+cc_library(
+ name = "thread_pool",
+ srcs = [
+ "thread_pool.cc",
+ ],
+ hdrs = [
+ "thread_pool.h",
+ ],
+ copts = ruy_copts(),
+ linkopts = ruy_linkopts_thread_standard_library(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":blocking_counter",
+ ":check_macros",
+ ":time",
+ ":trace",
+ ":wait",
+ ],
+)
+
+cc_library(
+ name = "cpu_cache_params",
+ hdrs = ["cpu_cache_params.h"],
+ copts = ruy_copts(),
+)
+
+cc_library(
+ name = "cpuinfo",
+ srcs = [
+ "cpuinfo.cc",
+ ],
+ hdrs = [
+ "cpuinfo.h",
+ ],
+ copts = ruy_copts() +
+ select({
+ "@bazel_tools//src/conditions:windows": [],
+ "//conditions:default": [
+ # ruy_copts contains -Wundef, but cpuinfo's header warns with that.
+ "-Wno-undef",
+ ],
+ }) + select({
+ # This select must match the similar select in `deps`.
+ # We intentionally define this token in the BUILD
+ # file so that ports to other build-systems do not
+ # use cpuinfo by default - they need to port the
+ # cpuinfo BUILD first, and can then define this token.
+ ":ppc": [],
+ ":s390x": [],
+ ":fuchsia": [],
+ "//conditions:default": ["-DRUY_HAVE_CPUINFO"],
+ }),
+ deps = [
+ ":platform",
+ ":check_macros",
+ ":cpu_cache_params",
+ ] + select({
+ # This select must match the similar select in `copts`
+ ":ppc": [],
+ ":s390x": [],
+ ":fuchsia": [],
+ "//conditions:default": ["@cpuinfo"],
+ }),
+)
+
+cc_library(
+ name = "path",
+ hdrs = ["path.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":platform",
+ ":size_util",
+ ],
+)
+
+cc_library(
+ name = "performance_advisory",
+ hdrs = ["performance_advisory.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "matrix",
+ hdrs = ["matrix.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [":check_macros"],
+)
+
+cc_test(
+ name = "matrix_test",
+ srcs = ["matrix_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":gtest_wrapper",
+ ":matrix",
+ ],
+)
+
+cc_library(
+ name = "mul_params",
+ hdrs = ["mul_params.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":check_macros",
+ ":size_util",
+ ],
+)
+
+cc_test(
+ name = "mul_params_test",
+ srcs = ["mul_params_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":gtest_wrapper",
+ ":mul_params",
+ ],
+)
+
+cc_library(
+ name = "mat",
+ hdrs = ["mat.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":size_util",
+ ],
+)
+
+cc_library(
+ name = "asm_helpers",
+ hdrs = [
+ "asm_helpers.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":opt_set",
+ ],
+)
+
+cc_library(
+ name = "apply_multiplier",
+ srcs = ["apply_multiplier.cc"],
+ hdrs = ["apply_multiplier.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":mul_params",
+ ],
+)
+
+cc_test(
+ name = "apply_multiplier_test",
+ srcs = ["apply_multiplier_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":apply_multiplier",
+ ":gtest_wrapper",
+ ":mul_params",
+ ],
+)
+
+cc_library(
+ name = "kernel_common",
+ hdrs = [
+ "kernel_common.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":apply_multiplier",
+ ":check_macros",
+ ":mat",
+ ":matrix",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ":size_util",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_common",
+ hdrs = [
+ "pack_common.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":matrix",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "kernel_arm",
+ srcs = [
+ "kernel_arm32.cc",
+ "kernel_arm64.cc",
+ ],
+ hdrs = ["kernel_arm.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":asm_helpers",
+ ":check_macros",
+ ":kernel_common",
+ ":mat",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ":size_util",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_arm",
+ srcs = [
+ "pack_arm.cc",
+ ],
+ hdrs = [
+ "pack_arm.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":asm_helpers",
+ ":check_macros",
+ ":mat",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "kernel_avx512",
+ srcs = [
+ "kernel_avx512.cc",
+ ],
+ hdrs = [
+ "kernel_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx512(),
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":mat",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avx512",
+ srcs = [
+ "pack_avx512.cc",
+ ],
+ hdrs = [
+ "pack_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx512(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avx512",
+ srcs = [
+ "have_built_path_for_avx512.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx512(),
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "kernel_avx2_fma",
+ srcs = [
+ "kernel_avx2_fma.cc",
+ ],
+ hdrs = [
+ "kernel_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":mat",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avx2_fma",
+ srcs = [
+ "pack_avx2_fma.cc",
+ ],
+ hdrs = [
+ "pack_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avx2_fma",
+ srcs = [
+ "have_built_path_for_avx2_fma.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "kernel_avx",
+ srcs = [
+ "kernel_avx.cc",
+ ],
+ hdrs = [
+ "kernel_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx(),
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":mat",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avx",
+ srcs = [
+ "pack_avx.cc",
+ ],
+ hdrs = [
+ "pack_x86.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avx",
+ srcs = [
+ "have_built_path_for_avx.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = ruy_copts() + ruy_copts_avx(),
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "kernel",
+ hdrs = [
+ "kernel.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":apply_multiplier",
+ ":check_macros",
+ ":kernel_arm", # fixdeps: keep
+ ":kernel_avx",
+ ":kernel_avx2_fma", # fixdeps: keep
+ ":kernel_avx512", # fixdeps: keep
+ ":kernel_common",
+ ":mat",
+ ":matrix",
+ ":mul_params",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ":size_util",
+ ":trace",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack",
+ hdrs = [
+ "pack.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":matrix",
+ ":opt_set",
+ ":pack_arm", # fixdeps: keep
+ ":pack_avx", # fixdeps: keep
+ ":pack_avx2_fma", # fixdeps: keep
+ ":pack_avx512", # fixdeps: keep
+ ":pack_common",
+ ":path",
+ ":platform",
+ ":trace",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for",
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ deps = [
+ ":have_built_path_for_avx",
+ ":have_built_path_for_avx2_fma",
+ ":have_built_path_for_avx512",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "context",
+ srcs = ["context.cc"],
+ hdrs = [
+ "context.h",
+ ],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":allocator",
+ ":check_macros",
+ ":ctx",
+ ":path",
+ ":performance_advisory",
+ ":platform",
+ ":prepacked_cache",
+ ":thread_pool",
+ ":tune",
+ ],
+)
+
+cc_test(
+ name = "context_test",
+ srcs = ["context_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":context",
+ ":gtest_wrapper",
+ ":path",
+ ":platform",
+ ":prepacked_cache",
+ ":tune",
+ ],
+)
+
+cc_library(
+ name = "ctx_header_only_should_not_include_other_ruy_headers",
+ testonly = True,
+ hdrs = [
+ "ctx.h",
+ ],
+ # Intentionally no deps. This will cause the stand-alone build of ctx.h to
+ # fail if ctx.h #includes any other ruy header.
+)
+
+cc_library(
+ name = "ctx",
+ srcs = [
+ "ctx.cc",
+ ],
+ hdrs = [
+ "ctx.h",
+ "ctx_impl.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":check_macros",
+ ":cpuinfo",
+ ":have_built_path_for",
+ ":path",
+ ":performance_advisory",
+ ":platform",
+ ":prepacked_cache",
+ ":thread_pool",
+ ":trace",
+ ":tune",
+ ],
+)
+
+cc_library(
+ name = "context_get_ctx",
+ srcs = [
+ "context_get_ctx.cc",
+ ],
+ hdrs = [
+ "context_get_ctx.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":context",
+ ":ctx",
+ ],
+)
+
+cc_test(
+ name = "ctx_test",
+ srcs = ["ctx_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":ctx",
+ ":gtest_wrapper",
+ ":path",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "trmul_params",
+ hdrs = ["trmul_params.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":mat",
+ ":mul_params",
+ ":path",
+ ":side_pair",
+ ":tune",
+ ],
+)
+
+cc_library(
+ name = "trmul",
+ srcs = ["trmul.cc"],
+ hdrs = ["trmul.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":block_map",
+ ":check_macros",
+ ":cpu_cache_params",
+ ":cpuinfo",
+ ":ctx",
+ ":mat",
+ ":matrix",
+ ":mul_params",
+ ":opt_set",
+ ":side_pair",
+ ":size_util",
+ ":thread_pool",
+ ":trace",
+ ":trmul_params",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "prepare_packed_matrices",
+ srcs = ["prepare_packed_matrices.cc"],
+ hdrs = ["prepare_packed_matrices.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":ctx",
+ ":matrix",
+ ":prepacked_cache",
+ ":side_pair",
+ ":trace",
+ ":trmul_params",
+ ],
+)
+
+cc_library(
+ name = "create_trmul_params",
+ hdrs = ["create_trmul_params.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":check_macros",
+ ":ctx",
+ ":kernel",
+ ":mat",
+ ":mul_params",
+ ":pack",
+ ":path",
+ ":performance_advisory",
+ ":platform",
+ ":side_pair",
+ ":trace",
+ ":trmul_params",
+ ],
+)
+
+cc_library(
+ name = "validate",
+ hdrs = ["validate.h"],
+ copts = ruy_copts(),
+ deps = [
+ ":check_macros",
+ ":mat",
+ ":mul_params",
+ ":side_pair",
+ ],
+)
+
+cc_library(
+ name = "frontend",
+ srcs = [
+ "frontend.cc",
+ ],
+ hdrs = [
+ "frontend.h",
+ ],
+ copts = ruy_copts(),
+ deps = [
+ ":allocator",
+ ":create_trmul_params",
+ ":ctx",
+ ":mat",
+ ":mul_params",
+ ":prepare_packed_matrices",
+ ":trace",
+ ":trmul",
+ ":trmul_params",
+ ":validate",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+# The main library.
+cc_library(
+ name = "ruy",
+ hdrs = [
+ "context.h",
+ "matrix.h",
+ "mul_params.h",
+ "path.h",
+ "ruy.h",
+ ],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":check_macros",
+ ":context",
+ ":context_get_ctx",
+ ":frontend",
+ ":mat",
+ ":matrix",
+ ":mul_params",
+ ":path",
+ ":platform",
+ ":size_util",
+ ":trace",
+ ],
+)
+
+cc_test(
+ name = "perchannel_buffers_reallocation_test",
+ srcs = ["perchannel_buffers_reallocation_test.cc"],
+ copts = ruy_copts(),
+ deps = [
+ ":context",
+ ":gtest_wrapper",
+ ":kernel",
+ ":matrix",
+ ":path",
+ ":performance_advisory",
+ ":ruy",
+ ],
+)
+
+# Small library to query PMU counters, for benchmark only
+cc_library(
+ name = "pmu",
+ testonly = True,
+ srcs = ["pmu.cc"],
+ hdrs = ["pmu.h"],
+ copts = ruy_copts(),
+ deps = [":check_macros"],
+)
+
+cc_library(
+ name = "reference_mul",
+ hdrs = ["reference_mul.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":apply_multiplier",
+ ":matrix",
+ ":mul_params",
+ ],
+)
+
+# Testing framework.
+cc_library(
+ name = "test_lib",
+ testonly = True,
+ hdrs = ["test.h"],
+ copts = ruy_copts(),
+ # Need defines, not copts, because they control a header (test.h) compiled
+ # into dependent targets; unlike copts, defines propagate to dependents.
+ defines = ruy_test_ext_defines(),
+ linkopts = select({
+ "@bazel_tools//src/conditions:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
+ deps = [
+ ":allocator",
+ ":size_util",
+ ":reference_mul",
+ ":matrix",
+ ":pmu",
+ ":ruy",
+ ":mul_params",
+ ":time",
+ ":gtest_wrapper",
+ ":platform",
+ ":context",
+ ":ctx",
+ ":context_get_ctx",
+ ":pack_common",
+ "//ruy/profiler",
+ ] + ruy_test_ext_deps(),
+)
+
+ruy_benchmark(
+ name = "benchmark",
+ srcs = ["benchmark.cc"],
+ copts = ruy_copts(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ],
+ deps = [
+ ":test_lib",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+ruy_test(
+ name = "test_fast",
+ srcs = ["test_fast.cc"],
+ copts = ruy_copts(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("f64", "f32", "f64", "f32"),
+ ("f32", "f64", "f64", "f64"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("i8", "u8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ("i8", "u8", "i32", "i32"),
+ ],
+ deps = [
+ ":test_lib",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+ruy_test(
+ name = "test_slow",
+ srcs = ["test_slow.cc"],
+ copts = ruy_copts(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ],
+ tags = ["slow"],
+ deps = [
+ ":test_lib",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+bzl_library(
+ name = "ruy_test_ext.oss_bzl",
+ srcs = ["ruy_test_ext.oss.bzl"],
+ visibility = ["//visibility:private"],
+)
+
+bzl_library(
+ name = "ruy_test_bzl",
+ srcs = ["ruy_test.bzl"],
+ visibility = ["//visibility:private"],
+)
+
+bzl_library(
+ name = "build_defs.oss_bzl",
+ srcs = ["build_defs.oss.bzl"],
+ visibility = ["//visibility:private"],
+)
+
+bzl_library(
+ name = "build_defs_bzl",
+ srcs = ["build_defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt
new file mode 100644
index 0000000..4c3e394
--- /dev/null
+++ b/ruy/CMakeLists.txt
@@ -0,0 +1,1696 @@
+# This file is generated (whence no license header). Do not edit!
+# To regenerate, run:
+# cmake/bazel_to_cmake.sh
+
+if(CMAKE_SYSTEM_NAME STREQUAL Windows)
+ set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "")
+else()
+ set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "-Wall;-Wextra;-Wc++14-compat;-Wundef")
+endif()
+
+if(CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
+ set(ruy_1_mfpu_neon "-mfpu=neon")
+else()
+ set(ruy_1_mfpu_neon "")
+endif()
+
+if((CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC)
+ set(ruy_2_O3 "")
+else()
+ set(ruy_2_O3 "-O3")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_trace
+ HDRS
+ trace.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_mat
+ ruy_matrix
+ ruy_path
+ ruy_platform
+ ruy_side_pair
+)
+
+ruy_cc_library(
+ NAME
+ ruy_platform
+ HDRS
+ platform.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+ruy_cc_library(
+ NAME
+ ruy_gtest_wrapper
+ TESTONLY
+ HDRS
+ gtest_wrapper.h
+ DEPS
+ gtest
+)
+
+ruy_cc_library(
+ NAME
+ ruy_check_macros
+ HDRS
+ check_macros.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+ruy_cc_test(
+ NAME
+ ruy_check_macros_test
+ SRCS
+ check_macros_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_gtest_wrapper
+)
+
+ruy_cc_library(
+ NAME
+ ruy_opt_set
+ HDRS
+ opt_set.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+ruy_cc_library(
+ NAME
+ ruy_time
+ HDRS
+ time.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+if(CMAKE_SYSTEM_NAME STREQUAL Windows)
+ set(ruy_3_pthread "")
+else()
+ set(ruy_3_pthread "-pthread")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_wait
+ SRCS
+ wait.cc
+ HDRS
+ wait.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ LINKOPTS
+ ${ruy_3_pthread}
+ DEPS
+ ruy_time
+)
+
+ruy_cc_test(
+ NAME
+ ruy_wait_test
+ SRCS
+ wait_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ LINKOPTS
+ ${ruy_3_pthread}
+ DEPS
+ ruy_gtest_wrapper
+ ruy_platform
+ ruy_wait
+)
+
+ruy_cc_library(
+ NAME
+ ruy_size_util
+ HDRS
+ size_util.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+)
+
+ruy_cc_test(
+ NAME
+ ruy_size_util_test
+ SRCS
+ size_util_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_gtest_wrapper
+ ruy_size_util
+)
+
+ruy_cc_library(
+ NAME
+ ruy_tune
+ SRCS
+ tune.cc
+ HDRS
+ tune.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_cpu_cache_params
+ ruy_cpuinfo
+ ruy_opt_set
+ ruy_platform
+ ruy_time
+)
+
+ruy_cc_library(
+ NAME
+ ruy_system_aligned_alloc
+ SRCS
+ system_aligned_alloc.cc
+ HDRS
+ system_aligned_alloc.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+ruy_cc_library(
+ NAME
+ ruy_prepacked_cache
+ SRCS
+ prepacked_cache.cc
+ HDRS
+ prepacked_cache.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_mat
+ ruy_system_aligned_alloc
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_test(
+ NAME
+ ruy_tune_test
+ SRCS
+ tune_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_cpuinfo
+ ruy_gtest_wrapper
+ ruy_tune
+)
+
+ruy_cc_test(
+ NAME
+ ruy_prepacked_cache_test
+ SRCS
+ prepacked_cache_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_context
+ ruy_context_get_ctx
+ ruy_ctx
+ ruy_gtest_wrapper
+ ruy_mat
+ ruy_matrix
+ ruy_prepacked_cache
+ ruy
+ ruy_time
+)
+
+ruy_cc_library(
+ NAME
+ ruy_allocator
+ SRCS
+ allocator.cc
+ HDRS
+ allocator.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_opt_set
+ ruy_size_util
+ ruy_system_aligned_alloc
+)
+
+ruy_cc_test(
+ NAME
+ ruy_allocator_test
+ SRCS
+ allocator_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_gtest_wrapper
+)
+
+ruy_cc_library(
+ NAME
+ ruy_side_pair
+ HDRS
+ side_pair.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+)
+
+ruy_cc_library(
+ NAME
+ ruy_block_map
+ SRCS
+ block_map.cc
+ HDRS
+ block_map.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_cpu_cache_params
+ ruy_opt_set
+ ruy_side_pair
+ ruy_size_util
+ ruy_trace
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_test(
+ NAME
+ ruy_block_map_test
+ SRCS
+ block_map_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_block_map
+ ruy_cpu_cache_params
+ ruy_gtest_wrapper
+ ruy_path
+ ruy_platform
+ ruy_side_pair
+)
+
+ruy_cc_library(
+ NAME
+ ruy_blocking_counter
+ SRCS
+ blocking_counter.cc
+ HDRS
+ blocking_counter.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ LINKOPTS
+ ${ruy_3_pthread}
+ DEPS
+ ruy_check_macros
+ ruy_time
+ ruy_wait
+)
+
+ruy_cc_library(
+ NAME
+ ruy_thread_pool
+ SRCS
+ thread_pool.cc
+ HDRS
+ thread_pool.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ LINKOPTS
+ ${ruy_3_pthread}
+ PUBLIC
+ DEPS
+ ruy_blocking_counter
+ ruy_check_macros
+ ruy_time
+ ruy_trace
+ ruy_wait
+)
+
+ruy_cc_library(
+ NAME
+ ruy_cpu_cache_params
+ HDRS
+ cpu_cache_params.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+)
+
+if(CMAKE_SYSTEM_NAME STREQUAL Windows)
+ set(ruy_4_Wno_undef "")
+else()
+ set(ruy_4_Wno_undef "-Wno-undef")
+endif()
+
+if(CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le)
+ set(ruy_5_DRUY_HAVE_CPUINFO "")
+elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x)
+ set(ruy_5_DRUY_HAVE_CPUINFO "")
+elseif(CMAKE_SYSTEM_NAME STREQUAL Fuchsia)
+ set(ruy_5_DRUY_HAVE_CPUINFO "")
+else()
+ set(ruy_5_DRUY_HAVE_CPUINFO "-DRUY_HAVE_CPUINFO")
+endif()
+
+if(CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le)
+ set(ruy_6_cpuinfo "")
+elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x)
+ set(ruy_6_cpuinfo "")
+elseif(CMAKE_SYSTEM_NAME STREQUAL Fuchsia)
+ set(ruy_6_cpuinfo "")
+else()
+ set(ruy_6_cpuinfo "cpuinfo")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_cpuinfo
+ SRCS
+ cpuinfo.cc
+ HDRS
+ cpuinfo.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_4_Wno_undef}
+ ${ruy_5_DRUY_HAVE_CPUINFO}
+ DEPS
+ ruy_platform
+ ruy_check_macros
+ ruy_cpu_cache_params
+ ${ruy_6_cpuinfo}
+)
+
+ruy_cc_library(
+ NAME
+ ruy_path
+ HDRS
+ path.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_platform
+ ruy_size_util
+)
+
+ruy_cc_library(
+ NAME
+ ruy_performance_advisory
+ HDRS
+ performance_advisory.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+)
+
+ruy_cc_library(
+ NAME
+ ruy_matrix
+ HDRS
+ matrix.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_check_macros
+)
+
+ruy_cc_test(
+ NAME
+ ruy_matrix_test
+ SRCS
+ matrix_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_gtest_wrapper
+ ruy_matrix
+)
+
+ruy_cc_library(
+ NAME
+ ruy_mul_params
+ HDRS
+ mul_params.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_check_macros
+ ruy_size_util
+)
+
+ruy_cc_test(
+ NAME
+ ruy_mul_params_test
+ SRCS
+ mul_params_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_gtest_wrapper
+ ruy_mul_params
+)
+
+ruy_cc_library(
+ NAME
+ ruy_mat
+ HDRS
+ mat.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_matrix
+ ruy_size_util
+)
+
+ruy_cc_library(
+ NAME
+ ruy_asm_helpers
+ HDRS
+ asm_helpers.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_opt_set
+)
+
+ruy_cc_library(
+ NAME
+ ruy_apply_multiplier
+ SRCS
+ apply_multiplier.cc
+ HDRS
+ apply_multiplier.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_mul_params
+)
+
+ruy_cc_test(
+ NAME
+ ruy_apply_multiplier_test
+ SRCS
+ apply_multiplier_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_apply_multiplier
+ ruy_gtest_wrapper
+ ruy_mul_params
+)
+
+ruy_cc_library(
+ NAME
+ ruy_kernel_common
+ HDRS
+ kernel_common.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_apply_multiplier
+ ruy_check_macros
+ ruy_mat
+ ruy_matrix
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_side_pair
+ ruy_size_util
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack_common
+ HDRS
+ pack_common.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_matrix
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_kernel_arm
+ SRCS
+ kernel_arm32.cc
+ kernel_arm64.cc
+ HDRS
+ kernel_arm.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_asm_helpers
+ ruy_check_macros
+ ruy_kernel_common
+ ruy_mat
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_side_pair
+ ruy_size_util
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack_arm
+ SRCS
+ pack_arm.cc
+ HDRS
+ pack_arm.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_asm_helpers
+ ruy_check_macros
+ ruy_mat
+ ruy_opt_set
+ ruy_pack_common
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC)
+ set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 ";-mavx512f;-mavx512vl;-mavx512cd;-mavx512bw;-mavx512dq")
+elseif(MSVC)
+ set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 "/arch:AVX512")
+else()
+ set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 "")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_kernel_avx512
+ SRCS
+ kernel_avx512.cc
+ HDRS
+ kernel_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
+ DEPS
+ ruy_check_macros
+ ruy_kernel_common
+ ruy_mat
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack_avx512
+ SRCS
+ pack_avx512.cc
+ HDRS
+ pack_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_opt_set
+ ruy_pack_common
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_have_built_path_for_avx512
+ SRCS
+ have_built_path_for_avx512.cc
+ HDRS
+ have_built_path_for.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512}
+ DEPS
+ ruy_opt_set
+ ruy_platform
+)
+
+if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC)
+ set(ruy_8_mavx2_mfma_arch_AVX2 "-mavx2;-mfma")
+elseif(MSVC)
+ set(ruy_8_mavx2_mfma_arch_AVX2 "/arch:AVX2")
+else()
+ set(ruy_8_mavx2_mfma_arch_AVX2 "")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_kernel_avx2_fma
+ SRCS
+ kernel_avx2_fma.cc
+ HDRS
+ kernel_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_8_mavx2_mfma_arch_AVX2}
+ DEPS
+ ruy_check_macros
+ ruy_kernel_common
+ ruy_mat
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack_avx2_fma
+ SRCS
+ pack_avx2_fma.cc
+ HDRS
+ pack_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_8_mavx2_mfma_arch_AVX2}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_opt_set
+ ruy_pack_common
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_have_built_path_for_avx2_fma
+ SRCS
+ have_built_path_for_avx2_fma.cc
+ HDRS
+ have_built_path_for.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_8_mavx2_mfma_arch_AVX2}
+ DEPS
+ ruy_opt_set
+ ruy_platform
+)
+
+if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC)
+ set(ruy_9_mavx_arch_AVX "-mavx")
+elseif(MSVC)
+ set(ruy_9_mavx_arch_AVX "/arch:AVX")
+else()
+ set(ruy_9_mavx_arch_AVX "")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_kernel_avx
+ SRCS
+ kernel_avx.cc
+ HDRS
+ kernel_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_9_mavx_arch_AVX}
+ DEPS
+ ruy_check_macros
+ ruy_kernel_common
+ ruy_mat
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack_avx
+ SRCS
+ pack_avx.cc
+ HDRS
+ pack_x86.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_9_mavx_arch_AVX}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_opt_set
+ ruy_pack_common
+ ruy_path
+ ruy_platform
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_have_built_path_for_avx
+ SRCS
+ have_built_path_for_avx.cc
+ HDRS
+ have_built_path_for.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ ${ruy_9_mavx_arch_AVX}
+ DEPS
+ ruy_opt_set
+ ruy_platform
+)
+
+ruy_cc_library(
+ NAME
+ ruy_kernel
+ HDRS
+ kernel.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_apply_multiplier
+ ruy_check_macros
+ ruy_kernel_arm
+ ruy_kernel_avx
+ ruy_kernel_avx2_fma
+ ruy_kernel_avx512
+ ruy_kernel_common
+ ruy_mat
+ ruy_matrix
+ ruy_mul_params
+ ruy_opt_set
+ ruy_path
+ ruy_platform
+ ruy_side_pair
+ ruy_size_util
+ ruy_trace
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pack
+ HDRS
+ pack.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_matrix
+ ruy_opt_set
+ ruy_pack_arm
+ ruy_pack_avx
+ ruy_pack_avx2_fma
+ ruy_pack_avx512
+ ruy_pack_common
+ ruy_path
+ ruy_platform
+ ruy_trace
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_have_built_path_for
+ HDRS
+ have_built_path_for.h
+ DEPS
+ ruy_have_built_path_for_avx
+ ruy_have_built_path_for_avx2_fma
+ ruy_have_built_path_for_avx512
+ ruy_platform
+)
+
+ruy_cc_library(
+ NAME
+ ruy_context
+ SRCS
+ context.cc
+ HDRS
+ context.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_allocator
+ ruy_check_macros
+ ruy_ctx
+ ruy_path
+ ruy_performance_advisory
+ ruy_platform
+ ruy_prepacked_cache
+ ruy_thread_pool
+ ruy_tune
+)
+
+ruy_cc_test(
+ NAME
+ ruy_context_test
+ SRCS
+ context_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_context
+ ruy_gtest_wrapper
+ ruy_path
+ ruy_platform
+ ruy_prepacked_cache
+ ruy_tune
+)
+
+ruy_cc_library(
+ NAME
+ ruy_ctx_header_only_should_not_include_other_ruy_headers
+ TESTONLY
+ HDRS
+ ctx.h
+)
+
+ruy_cc_library(
+ NAME
+ ruy_ctx
+ SRCS
+ ctx.cc
+ HDRS
+ ctx.h
+ ctx_impl.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_check_macros
+ ruy_cpuinfo
+ ruy_have_built_path_for
+ ruy_path
+ ruy_performance_advisory
+ ruy_platform
+ ruy_prepacked_cache
+ ruy_thread_pool
+ ruy_trace
+ ruy_tune
+)
+
+ruy_cc_library(
+ NAME
+ ruy_context_get_ctx
+ SRCS
+ context_get_ctx.cc
+ HDRS
+ context_get_ctx.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_context
+ ruy_ctx
+)
+
+ruy_cc_test(
+ NAME
+ ruy_ctx_test
+ SRCS
+ ctx_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_ctx
+ ruy_gtest_wrapper
+ ruy_path
+ ruy_platform
+)
+
+ruy_cc_library(
+ NAME
+ ruy_trmul_params
+ HDRS
+ trmul_params.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_mat
+ ruy_mul_params
+ ruy_path
+ ruy_side_pair
+ ruy_tune
+)
+
+ruy_cc_library(
+ NAME
+ ruy_trmul
+ SRCS
+ trmul.cc
+ HDRS
+ trmul.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_block_map
+ ruy_check_macros
+ ruy_cpu_cache_params
+ ruy_cpuinfo
+ ruy_ctx
+ ruy_mat
+ ruy_matrix
+ ruy_mul_params
+ ruy_opt_set
+ ruy_side_pair
+ ruy_size_util
+ ruy_thread_pool
+ ruy_trace
+ ruy_trmul_params
+ ruy_tune
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_prepare_packed_matrices
+ SRCS
+ prepare_packed_matrices.cc
+ HDRS
+ prepare_packed_matrices.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_ctx
+ ruy_matrix
+ ruy_prepacked_cache
+ ruy_side_pair
+ ruy_trace
+ ruy_trmul_params
+)
+
+ruy_cc_library(
+ NAME
+ ruy_create_trmul_params
+ HDRS
+ create_trmul_params.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_check_macros
+ ruy_ctx
+ ruy_kernel
+ ruy_mat
+ ruy_mul_params
+ ruy_pack
+ ruy_path
+ ruy_performance_advisory
+ ruy_platform
+ ruy_side_pair
+ ruy_trace
+ ruy_trmul_params
+)
+
+ruy_cc_library(
+ NAME
+ ruy_validate
+ HDRS
+ validate.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+ ruy_mat
+ ruy_mul_params
+ ruy_side_pair
+)
+
+ruy_cc_library(
+ NAME
+ ruy_frontend
+ SRCS
+ frontend.cc
+ HDRS
+ frontend.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_allocator
+ ruy_create_trmul_params
+ ruy_ctx
+ ruy_mat
+ ruy_mul_params
+ ruy_prepare_packed_matrices
+ ruy_trace
+ ruy_trmul
+ ruy_trmul_params
+ ruy_validate
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy
+ HDRS
+ context.h
+ matrix.h
+ mul_params.h
+ path.h
+ ruy.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_check_macros
+ ruy_context
+ ruy_context_get_ctx
+ ruy_frontend
+ ruy_mat
+ ruy_matrix
+ ruy_mul_params
+ ruy_path
+ ruy_platform
+ ruy_size_util
+ ruy_trace
+)
+
+ruy_cc_test(
+ NAME
+ ruy_perchannel_buffers_reallocation_test
+ SRCS
+ perchannel_buffers_reallocation_test.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_context
+ ruy_gtest_wrapper
+ ruy_kernel
+ ruy_matrix
+ ruy_path
+ ruy_performance_advisory
+ ruy
+)
+
+ruy_cc_library(
+ NAME
+ ruy_pmu
+ TESTONLY
+ SRCS
+ pmu.cc
+ HDRS
+ pmu.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ DEPS
+ ruy_check_macros
+)
+
+ruy_cc_library(
+ NAME
+ ruy_reference_mul
+ HDRS
+ reference_mul.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ PUBLIC
+ DEPS
+ ruy_apply_multiplier
+ ruy_matrix
+ ruy_mul_params
+)
+
+if(CMAKE_SYSTEM_NAME STREQUAL Windows)
+ set(ruy_10_lm "")
+else()
+ set(ruy_10_lm "-lm")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_test_lib
+ TESTONLY
+ HDRS
+ test.h
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ LINKOPTS
+ ${ruy_10_lm}
+ DEPS
+ ruy_allocator
+ ruy_size_util
+ ruy_reference_mul
+ ruy_matrix
+ ruy_pmu
+ ruy
+ ruy_mul_params
+ ruy_time
+ ruy_gtest_wrapper
+ ruy_platform
+ ruy_context
+ ruy_ctx
+ ruy_context_get_ctx
+ ruy_pack_common
+ ruy_profiler_profiler
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_f32_f32_f32_f32
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=f32
+ -DRUY_TEST_RHSSCALAR=f32
+ -DRUY_TEST_ACCUMSCALAR=f32
+ -DRUY_TEST_DSTSCALAR=f32
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_u8_u8_i32_u8
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=u8
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_i8_i8_i32_u8
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=u8
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_i8_i8_i32_i8
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i8
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_u8_u8_i32_i16
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i16
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_binary(
+ NAME
+ ruy_benchmark_i8_i8_i32_i32
+ TESTONLY
+ SRCS
+ benchmark.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i32
+ DEPS
+ ruy_test_lib
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_f32_f32_f32_f32
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=f32
+ -DRUY_TEST_RHSSCALAR=f32
+ -DRUY_TEST_ACCUMSCALAR=f32
+ -DRUY_TEST_DSTSCALAR=f32
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_f64_f32_f64_f32
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=f64
+ -DRUY_TEST_RHSSCALAR=f32
+ -DRUY_TEST_ACCUMSCALAR=f64
+ -DRUY_TEST_DSTSCALAR=f32
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_f32_f64_f64_f64
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=f32
+ -DRUY_TEST_RHSSCALAR=f64
+ -DRUY_TEST_ACCUMSCALAR=f64
+ -DRUY_TEST_DSTSCALAR=f64
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_u8_u8_i32_u8
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=u8
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_i8_i8_i32_i8
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i8
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_i8_u8_i32_i8
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i8
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_u8_u8_i32_i16
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i16
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_i8_i8_i32_i32
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i32
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_fast_i8_u8_i32_i32
+ SRCS
+ test_fast.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i32
+ DEPS
+ ruy_test_lib
+ gtest_main
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_slow_f32_f32_f32_f32
+ SRCS
+ test_slow.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=f32
+ -DRUY_TEST_RHSSCALAR=f32
+ -DRUY_TEST_ACCUMSCALAR=f32
+ -DRUY_TEST_DSTSCALAR=f32
+ DEPS
+ ruy_test_lib
+ gtest_main
+ TAGS
+ slow
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_slow_u8_u8_i32_u8
+ SRCS
+ test_slow.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=u8
+ DEPS
+ ruy_test_lib
+ gtest_main
+ TAGS
+ slow
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_slow_i8_i8_i32_i8
+ SRCS
+ test_slow.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i8
+ DEPS
+ ruy_test_lib
+ gtest_main
+ TAGS
+ slow
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_slow_u8_u8_i32_i16
+ SRCS
+ test_slow.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=u8
+ -DRUY_TEST_RHSSCALAR=u8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i16
+ DEPS
+ ruy_test_lib
+ gtest_main
+ TAGS
+ slow
+)
+
+ruy_cc_test(
+ NAME
+ ruy_test_slow_i8_i8_i32_i32
+ SRCS
+ test_slow.cc
+ COPTS
+ ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef}
+ ${ruy_1_mfpu_neon}
+ ${ruy_2_O3}
+ -DRUY_TEST_LHSSCALAR=i8
+ -DRUY_TEST_RHSSCALAR=i8
+ -DRUY_TEST_ACCUMSCALAR=i32
+ -DRUY_TEST_DSTSCALAR=i32
+ DEPS
+ ruy_test_lib
+ gtest_main
+ TAGS
+ slow
+)
+
+ruy_add_all_subdirs()
diff --git a/ruy/allocator.cc b/ruy/allocator.cc
new file mode 100644
index 0000000..64da664
--- /dev/null
+++ b/ruy/allocator.cc
@@ -0,0 +1,124 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/allocator.h"
+
+#include "ruy/opt_set.h"
+#include "ruy/size_util.h"
+#include "ruy/system_aligned_alloc.h"
+
+namespace ruy {
+
+Allocator::~Allocator() {
+ FreeAll();
+ detail::SystemAlignedFree(ptr_);
+}
+
+void* Allocator::AllocateFast(std::ptrdiff_t num_bytes) {
+ if (current_ + num_bytes > size_) {
+ return nullptr;
+ }
+ void* ret = static_cast<char*>(ptr_) + current_;
+ current_ += num_bytes;
+ return ret;
+}
+
+void* Allocator::AllocateSlow(std::ptrdiff_t num_bytes) {
+ void* p = detail::SystemAlignedAlloc(num_bytes);
+ fallback_blocks_total_size_ += num_bytes;
+ fallback_blocks_.push_back(p);
+ return p;
+}
+
+void* Allocator::AllocateBytes(std::ptrdiff_t num_bytes) {
+ if (num_bytes == 0) {
+ return nullptr;
+ }
+ const std::ptrdiff_t rounded_num_bytes =
+ round_up_pot(num_bytes, detail::kMinimumBlockAlignment);
+ if (void* p = AllocateFast(rounded_num_bytes)) {
+ return p;
+ }
+ return AllocateSlow(rounded_num_bytes);
+}
+
+void* Allocator::AllocateBytesAvoidingAliasingWith(std::ptrdiff_t num_bytes,
+ const void* to_avoid) {
+#if RUY_OPT(AVOID_ALIASING)
+ if (num_bytes == 0) {
+ return nullptr;
+ }
+ // The minimum L1D cache aliasing periodicity in bytes that we expect to
+ // encounter on any device. This does not seem to be documented, but
+ // empirically we observe the following:
+ // Cortex-A53: 1024
+ // Cortex-A55r1: 2048
+ // Cortex-A76: not as easily observable.
+ // Over-estimating this makes the AVOID_ALIASING optimization useless on
+ // devices with lower periodicity.
+ // Under-estimating this by 2x should be harmless.
+ // Under-estimating this by a larger factor should gradually degrade
+ // performance due to cache aliasing causing mutual eviction between
+ // the packed matrix data, and the source matrix data being prefetched by the
+ // CPU ahead of the packing code execution.
+ static constexpr std::uint32_t kMinPeriod = 1024;
+ static_assert(is_pot(kMinPeriod), "");
+ void* p = AllocateBytes(num_bytes + kMinPeriod);
+ auto unsigned_low_bits = [](const void* p) {
+ return static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(p));
+ };
+ // This relies on unsigned integer overflow wrapping around.
+ std::uint32_t diff_modulus =
+ (unsigned_low_bits(p) - unsigned_low_bits(to_avoid)) % kMinPeriod;
+ // diff_modulus is in [0, kMinPeriod).
+ // We want it as close as possible to the middle of that interval,
+ // kMinPeriod/2. The bad 'aliasing' case, that we are working to avoid,
+ // is when diff_modulus is close to the ends of that interval, 0 or
+ // kMinPeriod. So we want to add an offset of kMinPeriod/2 if it is in the
+ // first or the last quarter of that interval.
+ bool need_offset =
+ diff_modulus < kMinPeriod / 4 || diff_modulus > 3 * kMinPeriod / 4;
+ return static_cast<char*>(p) + (need_offset ? (kMinPeriod / 2) : 0);
+#else
+ (void)to_avoid;
+ return AllocateBytes(num_bytes);
+#endif
+}
+
+void Allocator::FreeAll() {
+ current_ = 0;
+ if (fallback_blocks_.empty()) {
+ return;
+ }
+
+  // No rounding-up of the size means a linear instead of logarithmic bound
+  // on the number of allocations in some worst-case calling patterns.
+ // This is considered worth it because minimizing memory usage is important
+ // and actual calling patterns in applications that we care about still
+ // reach the no-further-allocations steady state in a small finite number
+ // of iterations.
+ std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_;
+ detail::SystemAlignedFree(ptr_);
+ ptr_ = detail::SystemAlignedAlloc(new_size);
+ size_ = new_size;
+
+ for (void* p : fallback_blocks_) {
+ detail::SystemAlignedFree(p);
+ }
+ fallback_blocks_.clear();
+ fallback_blocks_total_size_ = 0;
+}
+
+} // namespace ruy
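
The offset logic in AllocateBytesAvoidingAliasingWith above is easiest to follow with concrete numbers. The standalone sketch below is not part of ruy; the address bits are invented purely for illustration, and it just reproduces the modulus test for kMinPeriod = 1024:

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr std::uint32_t kMinPeriod = 1024;
      // Invented low address bits: the fresh block starts 128 bytes before the
      // buffer to avoid, i.e. the two nearly alias in the L1D cache.
      const std::uint32_t p_bits = 0x7f3400;
      const std::uint32_t avoid_bits = 0x7f3480;
      // Unsigned wrap-around, as in AllocateBytesAvoidingAliasingWith.
      const std::uint32_t diff_modulus = (p_bits - avoid_bits) % kMinPeriod;  // 896
      const bool need_offset =
          diff_modulus < kMinPeriod / 4 || diff_modulus > 3 * kMinPeriod / 4;
      // 896 falls in the last quarter of [0, 1024), so a 512-byte offset would
      // be added to the returned pointer.
      std::printf("diff_modulus=%u need_offset=%d\n",
                  static_cast<unsigned>(diff_modulus), need_offset ? 1 : 0);
      return 0;
    }
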
diff --git a/ruy/allocator.h b/ruy/allocator.h
new file mode 100644
index 0000000..8aee01c
--- /dev/null
+++ b/ruy/allocator.h
@@ -0,0 +1,100 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_ALLOCATOR_H_
+#define RUY_RUY_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+namespace ruy {
+
+// Specialized allocator designed to converge to a steady-state where all
+// allocations are bump-ptr allocations from an already-allocated buffer.
+//
+// To support these constraints, this allocator only supports two
+// operations.
+// - AllocateBytes/Allocate<Pointer>: allocates a pointer to storage of a
+// specified size, which will be aligned to kMinimumBlockAlignment.
+// - FreeAll: frees all previous allocations (but retains the internal
+// buffer to minimize future calls into the system allocator).
+//
+// This class is specialized for supporting just those two operations
+// under this specific steady-state usage pattern. Extending this class
+// with new allocation interfaces that don't fit that pattern is probably not
+// the right choice. Instead, build a new class on top of
+// SystemAlignedAlloc/SystemAlignedFree.
+//
+// All operations happen on aligned blocks for simplicity.
+//
+// Theory of operation:
+//
+// - ptr_, current_, and size_ implement a basic bump-ptr allocator.
+//
+// - in AllocateBytes, the fast path is just a bump-ptr
+// allocation. If our bump-ptr allocator doesn't have enough space for an
+// allocation, then we allocate a block from the system allocator to
+// service the allocation request. We save that block in fallback_blocks_
+// and track the total size of the fallback blocks in
+// fallback_blocks_total_size_.
+//
+// - in FreeAll, the fast path just resets the bump-ptr allocator. If
+// there are any fallback blocks, we free them and reallocate the
+// bump-ptr allocator's buffer so that the next sequence of allocations
+// will hopefully not need any fallback blocks.
+class Allocator final {
+ public:
+ ~Allocator();
+
+ // Allocate a buffer.
+ void* AllocateBytes(std::ptrdiff_t num_bytes);
+ // Allocate a buffer, trying to avoid having its address close to aliasing
+ // the specified `to_avoid` in the L1D cache.
+ void* AllocateBytesAvoidingAliasingWith(std::ptrdiff_t num_bytes,
+ const void* to_avoid);
+ // Allocate an array of `count` elements of type T.
+ template <typename T>
+ T* Allocate(std::ptrdiff_t count) {
+ return static_cast<T*>(AllocateBytes(count * sizeof(T)));
+ }
+ // Allocate an array of `count` elements of the given `Pointer` type's
+ // element_type.
+ template <typename Pointer>
+ void Allocate(std::ptrdiff_t count, Pointer* out) {
+ using T = typename std::pointer_traits<Pointer>::element_type;
+ *out = Allocate<T>(count);
+ }
+
+ // Free all allocated blocks. Internally consolidate allocated buffers as
+ // explained in the class comment.
+ void FreeAll();
+
+ private:
+ void operator=(const Allocator&) = delete;
+ void* AllocateFast(std::ptrdiff_t num_bytes);
+ void* AllocateSlow(std::ptrdiff_t num_bytes);
+
+ void* ptr_ = nullptr;
+ std::ptrdiff_t current_ = 0;
+ std::ptrdiff_t size_ = 0;
+ std::vector<void*> fallback_blocks_;
+ std::ptrdiff_t fallback_blocks_total_size_ = 0;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_ALLOCATOR_H_
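
A minimal usage sketch of the Allocator API declared above (not part of ruy; the buffer names and sizes are made up):

    #include <cstdint>
    #include "ruy/allocator.h"

    int main() {
      ruy::Allocator allocator;
      // Typed form: returns a float* to storage for 256 floats, aligned to
      // kMinimumBlockAlignment.
      float* coeffs = allocator.Allocate<float>(256);
      // Pointer-output form: the element type is deduced from the pointer type.
      std::int32_t* accum = nullptr;
      allocator.Allocate(64, &accum);
      coeffs[0] = 1.0f;
      accum[0] = 42;
      // Invalidates both buffers but keeps the internal arena for reuse.
      allocator.FreeAll();
      return 0;
    }
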
diff --git a/ruy/allocator_test.cc b/ruy/allocator_test.cc
new file mode 100644
index 0000000..ee31669
--- /dev/null
+++ b/ruy/allocator_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/allocator.h"
+
+#include "ruy/gtest_wrapper.h"
+
+namespace ruy {
+namespace {
+
+TEST(AllocatorTest, ReturnsValidMemory) {
+ Allocator allocator;
+ int *p;
+ allocator.Allocate(1, &p);
+ ASSERT_NE(p, nullptr);
+
+ // If this is bogus memory, ASan will cause this test to fail.
+ *p = 42;
+
+ allocator.FreeAll();
+}
+
+TEST(AllocatorTest, NoLeak) {
+ Allocator allocator;
+ // Allocate and free some ridiculously large total amount of memory, so
+ // that a leak will hopefully cause some sort of resource exhaustion.
+ //
+ // Despite the large number of allocations, this test is actually quite
+ // fast, since our fast-path allocation logic is very fast.
+ constexpr int kNumAllocations = 100 * 1024;
+ constexpr int kAllocationSize = 1024 * 1024;
+ for (int i = 0; i < kNumAllocations; i++) {
+ char *p;
+ allocator.Allocate(kAllocationSize, &p);
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, IncreasingSizes) {
+ Allocator allocator;
+ // Allocate sizes that increase by small amounts across FreeAll calls.
+ for (int i = 1; i < 100 * 1024; i++) {
+ char *p;
+ allocator.Allocate(i, &p);
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, ManySmallAllocations) {
+ Allocator allocator;
+ // Allocate many small allocations between FreeAll calls.
+ for (int i = 0; i < 10 * 1024; i += 100) {
+ for (int j = 0; j < i; j++) {
+ char *p;
+ allocator.Allocate(1, &p);
+ }
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, DestructorHandlesMainBumpPtr) {
+ // This is a white-box test.
+ Allocator allocator;
+ allocator.AllocateBytes(1);
+ allocator.FreeAll();
+ // After the call to FreeAll, the allocator will consolidate all of the memory
+ // into the main bump-ptr allocator's block, which we then expect to be freed
+ // in the destructor.
+ //
+  // We have no test assertions -- we primarily rely on a leak checker to flag
+  // any memory not freed by the destructor and cause the test to fail.
+}
+
+TEST(AllocatorTest, DestructorHandlesFallbackBlocks) {
+ // This is a white-box test.
+ Allocator allocator;
+ // Since we just created the allocator, this will allocate a fallback block,
+ // which we then expect to be freed in the destructor.
+ //
+  // We have no test assertions -- we primarily rely on a leak checker to flag
+  // any memory not freed by the destructor and cause the test to fail.
+ allocator.AllocateBytes(1);
+}
+
+TEST(AllocatorTest, AvoidAliasing) {
+ Allocator allocator;
+ // Run twice with a FreeAll in between, just in case some future
+ // change of internal logic makes that bug-prone.
+ for (int repeat = 0; repeat < 2; repeat++) {
+ for (int i = 1; i < 100; i++) {
+ const void *to_avoid =
+ reinterpret_cast<const void *>(0x1234567890123ull + 123 * i);
+ void *ptr = allocator.AllocateBytesAvoidingAliasingWith(i * 10, to_avoid);
+ auto unsigned_low_bits = [](const void *p) {
+ return static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(p));
+ };
+ static constexpr int kMinPeriod = 1024;
+ std::uint32_t unsigned_diff =
+ (unsigned_low_bits(ptr) - unsigned_low_bits(to_avoid)) % kMinPeriod;
+ std::uint32_t unsigned_diff_mod = unsigned_diff % kMinPeriod;
+ ASSERT_TRUE(unsigned_diff_mod >= (kMinPeriod / 4) &&
+ unsigned_diff_mod <= 3 * (kMinPeriod / 4));
+ }
+ allocator.FreeAll();
+ }
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc
new file mode 100644
index 0000000..19bfd88
--- /dev/null
+++ b/ruy/apply_multiplier.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/apply_multiplier.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+
+namespace ruy {
+namespace detail {
+
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Simplified multiplier application function
+//
+// Double rounding and symmetric rounding are removed compared to reference.
+// Double rounding seems unnecessary and can complicate implementations.
+// Symmetric rounding also adds implementation complexity.
+//
+// It is composed of a single rounding right shift, which can lead to more
+// HW-friendly implementations.
+//
+// On NEON this can be translated to a SQDMULH + rounding shift right sequence.
+// The use of SQDMULH rather than SQRDMULH gives a result that is
+// equivalent to a single rounded shift since the truncating shift of SQDMULH
+// can be combined with the rounding right shift via the formula (for k>=1):
+// ((x>>31)+(1<<(k-1)))>>k = (x + (1<<(30+k)))>>(31+k)
+//
+// Preconditions:
+// - quantized_multiplier >= 0
+// - shift is -31 to +7 (negative for right shift)
+std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x,
+ std::int32_t quantized_multiplier,
+ int shift) {
+ RUY_CHECK_GE(shift, -31);
+ RUY_CHECK_LE(shift, 7);
+
+ int total_shift = 31 - shift;
+
+ std::int64_t x_64(x);
+ std::int64_t quantized_multiplier_64(quantized_multiplier);
+  std::int64_t round = std::int64_t{1} << (total_shift - 1);
+  std::int64_t result = x_64 * quantized_multiplier_64 + round;
+ result = result >> total_shift;
+
+ RUY_DCHECK_GE(result, std::numeric_limits<std::int32_t>::lowest());
+ RUY_DCHECK_LE(result, std::numeric_limits<std::int32_t>::max());
+
+ return static_cast<std::int32_t>(result);
+}
+
+} // namespace detail
+
+} // namespace ruy
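
One case of the arithmetic above worked by hand, as a standalone check that is not part of ruy: x = 1000 with quantized_multiplier = 1 << 30 (0.5 in Q31) and shift = 0 gives 500, the same value the (1000, 1 << 30, 0, 500) test case in apply_multiplier_test.cc expects:

    #include <cassert>
    #include <cstdint>

    int main() {
      const std::int64_t x = 1000;
      const std::int64_t multiplier = std::int64_t{1} << 30;  // 0.5 in Q31
      const int total_shift = 31 - 0;                         // shift = 0
      const std::int64_t round = std::int64_t{1} << (total_shift - 1);
      const std::int64_t result = (x * multiplier + round) >> total_shift;
      assert(result == 500);
      (void)result;
      return 0;
    }
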
diff --git a/ruy/apply_multiplier.h b/ruy/apply_multiplier.h
new file mode 100644
index 0000000..120b990
--- /dev/null
+++ b/ruy/apply_multiplier.h
@@ -0,0 +1,92 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Provides a reference (portable, non-optimized) ApplyMultiplier function.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+#ifndef RUY_RUY_APPLY_MULTIPLIER_H_
+#define RUY_RUY_APPLY_MULTIPLIER_H_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/mul_params.h"
+
+namespace ruy {
+
+// Applies the quantized multiplier to the `*accum` accumulator value, if
+// applicable, that is, if AccumScalar==int32 and DstScalar!=int32. Otherwise,
+// does nothing.
+//
+// This is slow, portable, 'reference' code. It should only be used in
+// ReferenceMul and in Path::kStandardCpp. There isn't a point in optimizing it,
+// either. Fast paths have that multiplier work done as part of the kernel,
+// typically written in assembly anyway.
+template <typename AccumScalar, typename DstScalar>
+void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params,
+ int channel, AccumScalar* accum);
+
+namespace detail {
+
+// Copied from TF Lite code.
+std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x,
+ std::int32_t quantized_multiplier,
+ int shift);
+
+// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar
+// is int32 (i.e. in all cases except floating-point) and if the destination is
+// not int32 (i.e. unless the user wants to get raw accumulators).
+template <typename AccumScalar, typename DstScalar,
+ bool IsApplicable = std::is_same<AccumScalar, std::int32_t>::value &&
+ !std::is_same<DstScalar, std::int32_t>::value>
+struct ApplyMultiplierImpl {};
+
+// Specialization in non-applicable case: do nothing.
+template <typename AccumScalar, typename DstScalar>
+struct ApplyMultiplierImpl<AccumScalar, DstScalar, false> {
+ static void Run(const MulParams<AccumScalar, DstScalar>&, int, AccumScalar*) {
+ }
+};
+
+template <typename AccumScalar, typename DstScalar>
+struct ApplyMultiplierImpl<AccumScalar, DstScalar, true> {
+ static void Run(const MulParams<AccumScalar, DstScalar>& mul_params,
+ int channel, AccumScalar* accum) {
+ AccumScalar m = mul_params.multiplier_fixedpoint_perchannel()
+ ? mul_params.multiplier_fixedpoint_perchannel()[channel]
+ : mul_params.multiplier_fixedpoint();
+ int e = mul_params.multiplier_exponent_perchannel()
+ ? mul_params.multiplier_exponent_perchannel()[channel]
+ : mul_params.multiplier_exponent();
+ *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
+ }
+};
+
+} // namespace detail
+
+template <typename AccumScalar, typename DstScalar>
+void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params,
+ int channel, AccumScalar* accum) {
+ detail::ApplyMultiplierImpl<AccumScalar, DstScalar>::Run(mul_params, channel,
+ accum);
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_APPLY_MULTIPLIER_H_
diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc
new file mode 100644
index 0000000..2df80d7
--- /dev/null
+++ b/ruy/apply_multiplier_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/apply_multiplier.h"
+
+#include <cstdint>
+#include <limits>
+
+#include "ruy/gtest_wrapper.h"
+#include "ruy/mul_params.h"
+
+namespace ruy {
+namespace {
+
+void TestMultiplyByQuantizedMultiplier(std::int32_t input,
+ std::int32_t multiplier_fixedpoint,
+ int multiplier_exponent,
+ std::int32_t expected_output) {
+ EXPECT_EQ(expected_output,
+ detail::MultiplyByQuantizedMultiplier(input, multiplier_fixedpoint,
+ multiplier_exponent));
+}
+
+// These testcases exercise various multiplier_fixedpoint values, leaving
+// multiplier_exponent = 0. They exercise the logic in
+// SaturatingRoundingDoublingHighMul.
+TEST(ApplyMultiplierTest, SaturatingRoundingDoublingHighMul) {
+ const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max();
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, 0, 1000);
+ TestMultiplyByQuantizedMultiplier(1000, 1 << 30, 0, 500);
+ TestMultiplyByQuantizedMultiplier(1000, 1 << 29, 0, 250);
+ TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 29), 0, 750);
+ TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 28), 0, 625);
+ // This 563 expected value comes from the breaking of the tie on 562.5.
+ // As a positive value, it does not distinguish between 'upward' and
+ // 'away from zero' tie-breaking behavior.
+ TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 27), 0, 563);
+ TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 26), 0, 531);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, 0, -1000);
+ TestMultiplyByQuantizedMultiplier(-1000, 1 << 30, 0, -500);
+ TestMultiplyByQuantizedMultiplier(-1000, 1 << 29, 0, -250);
+ TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 29), 0, -750);
+ TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 28), 0, -625);
+ // This -562 expected value is because the SaturatingRoundingDoublingHighMul
+ // breaks ties upwards, not away-from-zero. The value before rounding is
+ // -562.5. No other test case here tests a tie on a negative value.
+ TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 27), 0, -562);
+ TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 26), 0, -531);
+}
+
+// These testcases exercise various negative multiplier_exponent values while
+// keeping multiplier_fixedpoint trivial.
+TEST(ApplyMultiplierTest, RoundingRightShift) {
+ const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max();
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -1, 500);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -2, 250);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -3, 125);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -4, 62);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -5, 31);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -6, 16);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -1, -500);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -2, -250);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -3, -125);
+ // This -62 value comes from rounding -62.5, which is a tie.
+  // It is the only test case here that exercises a tie-break on a negative value,
+ // distinguishing between 'upward' and 'away from zero'.
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -4, -62);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -5, -31);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, -6, -16);
+}
+
+// These testcases exercise various positive multiplier_exponent values while
+// keeping multiplier_fixedpoint trivial.
+TEST(ApplyMultiplierTest, LeftShift) {
+ const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max();
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, 1, 2000);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, 2, 4000);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, 3, 8000);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, 1, -2000);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, 2, -4000);
+ TestMultiplyByQuantizedMultiplier(-1000, max_int32, 3, -8000);
+}
+
+template <typename AccumScalar, typename DstScalar>
+void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params,
+ int channel, AccumScalar input,
+ AccumScalar expected_output) {
+ AccumScalar actual_output = input;
+ ApplyMultiplier(mul_params, channel, &actual_output);
+ EXPECT_EQ(expected_output, actual_output);
+}
+
+TEST(ApplyMultiplierTest, ApplyMultiplierUniform) {
+ MulParams<std::int32_t, std::int8_t> mul_params;
+ // Test that default values give a multiplication by 1.
+ TestApplyMultiplier(mul_params, 0, 1000, 1000);
+ mul_params.set_multiplier_fixedpoint(1 << 30);
+ mul_params.set_multiplier_exponent(-1);
+ TestApplyMultiplier(mul_params, 0, 1000, 250);
+ mul_params.set_multiplier_fixedpoint(1 << 25);
+ mul_params.set_multiplier_exponent(3);
+ TestApplyMultiplier(mul_params, 0, 1000, 125);
+}
+
+TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) {
+ const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max();
+ const std::int32_t multiplier_fixedpoint[4] = {max_int32, 1 << 30, max_int32,
+ 1 << 30};
+ const int multiplier_exponent[4] = {0, 0, -1, -1};
+ MulParams<std::int32_t, std::int8_t> mul_params;
+ mul_params.set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint);
+ mul_params.set_multiplier_exponent_perchannel(multiplier_exponent);
+ TestApplyMultiplier(mul_params, 0, 1000, 1000);
+ TestApplyMultiplier(mul_params, 1, 1000, 500);
+ TestApplyMultiplier(mul_params, 2, 1000, 500);
+ TestApplyMultiplier(mul_params, 3, 1000, 250);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/asm_helpers.h b/ruy/asm_helpers.h
new file mode 100644
index 0000000..f0bc303
--- /dev/null
+++ b/ruy/asm_helpers.h
@@ -0,0 +1,43 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Some helpers to write inline asm.
+
+#ifndef RUY_RUY_ASM_HELPERS_H_
+#define RUY_RUY_ASM_HELPERS_H_
+
+#include "ruy/opt_set.h"
+
+// Enclose load-prefetch instructions in RUY_PREFETCH_LOAD() so we can
+// conditionally enable them based on the RUY_OPT_SET.
+#if RUY_OPT(PREFETCH_LOAD)
+#define RUY_PREFETCH_LOAD(X) X
+#else
+#define RUY_PREFETCH_LOAD(X)
+#endif
+
+// Enclose store-prefetch instructions in RUY_PREFETCH_STORE() so we can
+// conditionally enable them based on the RUY_OPT_SET.
+#if RUY_OPT(PREFETCH_STORE)
+#define RUY_PREFETCH_STORE(X) X
+#else
+#define RUY_PREFETCH_STORE(X)
+#endif
+
+// The usual stringification macro.
+#define RUY_STR(s) RUY_STR_UNEXPANDED(s)
+#define RUY_STR_UNEXPANDED(s) #s
+
+#endif // RUY_RUY_ASM_HELPERS_H_
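
A sketch of how the helpers above are typically used; the function and offset below are hypothetical, shown only for illustration. Prefetch instructions inside an AArch64 inline-asm block are wrapped in RUY_PREFETCH_LOAD so they compile out when RUY_OPT(PREFETCH_LOAD) is disabled, and RUY_STR turns a constant into the string fragment the asm template needs:

    #include "ruy/asm_helpers.h"

    #define EXAMPLE_PREFETCH_OFFSET 64  // hypothetical constant, for the demo only

    #if defined(__aarch64__)
    void PrefetchExample(const void* src) {
      asm volatile(
          // A real kernel has many instructions here; this no-op keeps the asm
          // template non-empty even when the prefetch line compiles out.
          "add %[src], %[src], #0\n"
          RUY_PREFETCH_LOAD(
              "prfm pldl1keep, [%[src], #" RUY_STR(EXAMPLE_PREFETCH_OFFSET) "]\n")
          : [src] "+r"(src)
          :
          : "memory");
    }
    #endif
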
diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc
new file mode 100644
index 0000000..3c63249
--- /dev/null
+++ b/ruy/benchmark.cc
@@ -0,0 +1,223 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>;
+
+struct BenchmarkShape {
+ int rows;
+ int depth;
+ int cols;
+ int symm_lhs;
+ int symm_rhs;
+};
+
+template <typename TestSetType>
+std::vector<std::unique_ptr<TestResult<DstScalar>>> Benchmark(
+ const BenchmarkShape& shape) {
+ TestSetType test_set;
+ test_set.rows = shape.rows;
+ test_set.depth = shape.depth;
+ test_set.cols = shape.cols;
+ const char* orders = "RCC";
+ const char* orders_env = getenv("ORDERS");
+ if (orders_env) {
+ bool error = false;
+ if (strlen(orders_env) != 3) {
+ error = true;
+ } else {
+ for (int i = 0; i < 3; i++) {
+ if (orders_env[i] != 'R' && orders_env[i] != 'C') {
+ error = true;
+ }
+ }
+ }
+ if (error) {
+ fprintf(stderr,
+ "ORDERS must contain 3 letters, each either R or C, indicating "
+ "whether to use Row-major or Column-major storage order for the "
+ "LHS, RHS and Destination matrix.\n");
+ exit(EXIT_FAILURE);
+ }
+ orders = orders_env;
+ }
+ test_set.lhs_order = orders[0] == 'R' ? Order::kRowMajor : Order::kColMajor;
+ test_set.rhs_order = orders[1] == 'R' ? Order::kRowMajor : Order::kColMajor;
+ test_set.dst_order = orders[2] == 'R' ? Order::kRowMajor : Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kUnstridedLinear;
+ test_set.benchmark = true;
+ const int asymmetry_lhs = shape.symm_lhs ? 0 : 1;
+ const int asymmetry_rhs = shape.symm_rhs ? 0 : 1;
+ test_set.lhs_zero_point = SymmetricZeroPoint<LhsScalar>() + asymmetry_lhs;
+ test_set.rhs_zero_point = SymmetricZeroPoint<RhsScalar>() + asymmetry_rhs;
+ test_set.use_specified_zero_points = true;
+ test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL");
+ if (getenv("PREPACK_LHS") || getenv("PREPACK_RHS")) {
+ fprintf(stderr,
+ "PREPACK_LHS and PREPACK_RHS are deprecated. Use CACHE_LHS and "
+ "CACHE_RHS instead.\n");
+ exit(EXIT_FAILURE);
+ }
+ test_set.cache_lhs = GetBoolEnvVarOrFalse("CACHE_LHS");
+ test_set.cache_rhs = GetBoolEnvVarOrFalse("CACHE_RHS");
+ test_set.Run();
+ return std::move(test_set.results);
+}
+
+std::vector<int> ParseCommaSeparatedInts(
+ const std::string& comma_separated_ints) {
+ std::vector<int> result;
+ for (std::size_t pos = 0; pos < comma_separated_ints.size();) {
+ std::size_t delim_pos = comma_separated_ints.find(',', pos);
+ if (delim_pos == std::string::npos) {
+ delim_pos = comma_separated_ints.size();
+ }
+ result.push_back(
+ std::stoi(comma_separated_ints.substr(pos, delim_pos - pos)));
+ pos = delim_pos + 1;
+ }
+ return result;
+}
+
+void Benchmark() {
+ const bool symm_lhs = std::is_floating_point<LhsScalar>::value ||
+ GetBoolEnvVarOrFalse("SYMM_LHS");
+ const bool symm_rhs = std::is_floating_point<RhsScalar>::value ||
+ GetBoolEnvVarOrFalse("SYMM_RHS");
+ const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") ||
+ GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST");
+ const int explicit_rows = GetIntEnvVarOrZero("ROWS");
+ const int explicit_cols = GetIntEnvVarOrZero("COLS");
+ const int explicit_depth = GetIntEnvVarOrZero("DEPTH");
+
+ std::vector<BenchmarkShape> shapes;
+
+ if (benchmark_cubic) {
+ std::vector<int> sizes;
+ const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST");
+ if (benchmark_cubic_list_env) {
+ sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env);
+ } else {
+ // Often 8 is used for this multiplier, but to check teeny sizes one can
+ // use 1.
+ static constexpr int cubic_size_multiplier = 8;
+ for (int i = 2 * cubic_size_multiplier;
+ i <= (512 * cubic_size_multiplier); i *= 2) {
+ sizes.push_back(i);
+ if (i < (512 * cubic_size_multiplier)) {
+ sizes.push_back(i * 3 / 2);
+ }
+ }
+ }
+ for (int i : sizes) {
+ BenchmarkShape shape;
+ // Even in cubic mode, one may still override an individual dimension
+ // to allow testing a batch of rectangular sizes.
+ shape.rows = explicit_rows ? explicit_rows : i;
+ shape.cols = explicit_cols ? explicit_cols : i;
+ shape.depth = explicit_depth ? explicit_depth : i;
+ shape.symm_lhs = symm_lhs;
+ shape.symm_rhs = symm_rhs;
+ shapes.push_back(shape);
+ }
+ } else {
+ BenchmarkShape shape;
+ shape.rows = explicit_rows;
+ shape.cols = explicit_cols;
+ shape.depth = explicit_depth;
+ if (!shape.rows || !shape.depth || !shape.cols) {
+ fprintf(stderr,
+ "Please specify positive sizes with these env vars: ROWS, DEPTH, "
+ "COLS.\n");
+ exit(1);
+ }
+ shape.symm_lhs = symm_lhs;
+ shape.symm_rhs = symm_rhs;
+ shapes.push_back(shape);
+ }
+
+ for (int i = 0; i < static_cast<int>(shapes.size()); i++) {
+ const auto& shape = shapes[i];
+ const auto& results = Benchmark<TestSetType>(shape);
+ if (i == 0) {
+ if (benchmark_cubic) {
+ printf("size");
+ for (const auto& result : results) {
+ if (results.size() > 1) {
+ printf(",%s:Gop/s", PathName(*result).c_str());
+ } else {
+ printf(",Gop/s");
+ }
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(
+ ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill,"
+ "mispred,frontend_stall,backend_stall");
+ }
+ }
+ printf("\n");
+ } else {
+ printf("path,shape,Gop/s\n");
+ }
+ fflush(stdout);
+ }
+ if (benchmark_cubic) {
+ printf("%d", shape.rows);
+ for (const auto& result : results) {
+ printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth /
+ result->latency);
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g",
+ result->l1_refill_rate, result->l2_refill_rate,
+ result->l3_refill_rate, result->l1tlb_refill_rate,
+ result->l2tlb_refill_rate, result->mispred_rate,
+ result->frontend_stall_rate, result->backend_stall_rate);
+ }
+ }
+ printf("\n");
+ fflush(stdout);
+ } else {
+ for (const auto& result : results) {
+ printf(
+ "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows,
+ shape.depth, shape.cols,
+ 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency);
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g",
+ result->l1_refill_rate, result->l2_refill_rate,
+ result->l3_refill_rate, result->l1tlb_refill_rate,
+ result->l2tlb_refill_rate, result->mispred_rate,
+ result->frontend_stall_rate, result->backend_stall_rate);
+ }
+ printf("\n");
+ }
+ fflush(stdout);
+ }
+ }
+}
+
+} // namespace ruy
+
+int main() { ruy::Benchmark(); }
diff --git a/ruy/block_map.cc b/ruy/block_map.cc
new file mode 100644
index 0000000..8240de2
--- /dev/null
+++ b/ruy/block_map.cc
@@ -0,0 +1,497 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/block_map.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#ifdef RUY_MAKEBLOCKMAP_DEBUG
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#endif
+
+#include "ruy/check_macros.h"
+#include "ruy/opt_set.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/size_util.h"
+#include "ruy/trace.h"
+
+namespace ruy {
+
+namespace {
+
+void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
+ (*local_pos)[Side::kRhs] = square_index >> size_log2;
+}
+
+void DecodeTraversalFractalZ(std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ const std::uint32_t n1 = square_index;
+ const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
+ ((n1 & 0x22222222u) << 1);
+ const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
+ ((n2 & 0x0c0c0c0cu) << 2);
+ const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
+ ((n4 & 0x00f000f0u) << 4);
+ const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
+ ((n8 & 0x0000ff00u) << 8);
+ (*local_pos)[Side::kLhs] = n16 & 0xffff;
+ (*local_pos)[Side::kRhs] = n16 >> 16;
+}
+
+void DecodeTraversalFractalU(std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ DecodeTraversalFractalZ(square_index, local_pos);
+ // Change fractal z-order to u-order
+ (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
+}
+
+// Code inspired by the sample code in
+// https://en.wikipedia.org/wiki/Hilbert_curve
+// The main optimization is to avoid hard-to-predict conditional branches
+// based on the bits of the square_index parameter.
+void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ std::uint32_t t = square_index;
+ std::uint32_t x = 0;
+ std::uint32_t y = 0;
+ // Easy-to-predict for loop, the number of iterations is the same for
+ // an entire GEMM.
+ for (int sb = 0; sb < size_log2; sb++) {
+ std::uint32_t s = 1 << sb;
+ bool rx = t & 2;
+ bool ry = (t & 1) ^ rx;
+ std::uint32_t tmp = rx ? (s - 1 - x) : x;
+ x = ry ? x : rx ? (s - 1 - y) : y;
+ y = ry ? (y + s) : tmp;
+ x = rx ? (x + s) : x;
+ t >>= 2;
+ }
+ (*local_pos)[Side::kLhs] = y;
+ (*local_pos)[Side::kRhs] = x;
+}
+
+} // end anonymous namespace
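
To make the bit-swizzle in DecodeTraversalFractalZ above concrete: it de-interleaves square_index, sending its even-position bits to the LHS block coordinate and its odd-position bits to the RHS block coordinate. A naive standalone check of one value (not part of ruy):

    #include <cassert>
    #include <cstdint>

    int main() {
      const std::uint32_t square_index = 0b1101;  // block 13 along the Z curve
      std::uint32_t lhs = 0, rhs = 0;
      for (int bit = 0; bit < 16; ++bit) {
        lhs |= ((square_index >> (2 * bit)) & 1u) << bit;      // even bits
        rhs |= ((square_index >> (2 * bit + 1)) & 1u) << bit;  // odd bits
      }
      // DecodeTraversalFractalZ(0b1101, ...) produces the same pair via the
      // branch-free masks above: kLhs = 3, kRhs = 2.
      assert(lhs == 3 && rhs == 2);
      (void)lhs; (void)rhs;
      return 0;
    }
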
+
+void GetBlockByIndex(const BlockMap& block_map, int index,
+ SidePair<int>* block) {
+ profiler::ScopeLabel label("GetBlockByIndex");
+ const std::uint32_t index_u32 = index;
+
+ const std::uint32_t num_blocks_per_local_curve =
+ 1u << (2 * block_map.num_blocks_base_log2);
+ const std::uint32_t square_index =
+ index_u32 & (num_blocks_per_local_curve - 1);
+
+ const int size_log2 = block_map.num_blocks_base_log2;
+ SidePair<int> local_pos;
+ switch (block_map.traversal_order) {
+ case BlockMapTraversalOrder::kFractalZ:
+ DecodeTraversalFractalZ(square_index, &local_pos);
+ break;
+ case BlockMapTraversalOrder::kFractalU:
+ DecodeTraversalFractalU(square_index, &local_pos);
+ break;
+ case BlockMapTraversalOrder::kFractalHilbert:
+ DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
+ break;
+ default:
+ RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
+ DecodeTraversalLinear(size_log2, square_index, &local_pos);
+ break;
+ }
+
+ const std::uint32_t rectangular_index =
+ index_u32 >> 2 * block_map.num_blocks_base_log2;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
+ const int rectangular_offset = (rectangular_index & mask)
+ << block_map.num_blocks_base_log2;
+ (*block)[side] = local_pos[side] + rectangular_offset;
+ }
+}
+
+namespace {
+
+BlockMapTraversalOrder GetTraversalOrder(
+ int rows_after_rectangularness_division,
+ int cols_after_rectangularness_division, int depth, int lhs_scalar_size,
+ int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) {
+ static constexpr bool kAnyFractal =
+ RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT);
+ const int working_set_size =
+ (lhs_scalar_size * rows_after_rectangularness_division +
+ rhs_scalar_size * cols_after_rectangularness_division) *
+ depth;
+ if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) {
+ if (RUY_OPT(FRACTAL_HILBERT) &&
+ (working_set_size > cpu_cache_params.last_level_cache_size)) {
+ return BlockMapTraversalOrder::kFractalHilbert;
+ } else if (RUY_OPT(FRACTAL_U)) {
+ return BlockMapTraversalOrder::kFractalU;
+ } else {
+ return BlockMapTraversalOrder::kFractalZ;
+ }
+ } else {
+ return BlockMapTraversalOrder::kLinear;
+ }
+}
+
+int floor_log2_quotient(int num, int denom) {
+ if (num <= denom) {
+ return 0;
+ }
+ int log2_quotient = floor_log2(num) - ceil_log2(denom);
+ if ((denom << (log2_quotient + 1)) <= num) {
+ log2_quotient++;
+ }
+ return log2_quotient;
+}
+
+// Computes the rectangularness of the matrix shape (rows, cols). This is
+// essentially just the log2 of the quotient (rows / cols). The kernel_rows and
+// kernel_cols only get into the picture for clamping bounds but don't affect
+// the generic computation.
+void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
+ int* rows_rectangularness_log2,
+ int* cols_rectangularness_log2) {
+ *rows_rectangularness_log2 = 0;
+ *cols_rectangularness_log2 = 0;
+
+ // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel
+ // itself, we risk having too small kernel blocks for good kernel
+ // amortization. We avoid that by limiting rectangularness so that kernel
+ // blocks are not too tiny at least in that dimension. Specifically, we try to
+ // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
+ // kernel block along the large dimension.
+ const int min_kernel_inner_loop_runs_log2 = 3;
+ if (rows > cols) {
+ int cols_of_kernel_inner_loop_runs_log2 =
+ ceil_log2(cols) - pot_log2(kernel_cols);
+ int min_rows_of_kernel_inner_loop_runs_log2 =
+ std::max(0, min_kernel_inner_loop_runs_log2 -
+ cols_of_kernel_inner_loop_runs_log2);
+ *rows_rectangularness_log2 =
+ std::min(floor_log2_quotient(rows, cols),
+ std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
+ min_rows_of_kernel_inner_loop_runs_log2));
+ // Sanity check that we did not over-estimate rows_rectangularness_log2.
+ RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
+ } else if (cols > rows) {
+ int rows_of_kernel_inner_loop_runs_log2 =
+ ceil_log2(rows) - pot_log2(kernel_rows);
+ int min_cols_of_kernel_inner_loop_runs_log2 =
+ std::max(0, min_kernel_inner_loop_runs_log2 -
+ rows_of_kernel_inner_loop_runs_log2);
+ *cols_rectangularness_log2 =
+ std::min(floor_log2_quotient(cols, rows),
+ std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
+ min_cols_of_kernel_inner_loop_runs_log2));
+ // Sanity check that we did not over-estimate cols_rectangularness_log2.
+ RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
+ }
+ RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
+}
+
+// Computes a 'multithreading score'. When multithreading, we need there to
+// be at least as many tiles as there are threads, and hopefully
+// substantially more than that, so we benefit from ruy's ability to
+// dispatch fine-grained workloads to threads.
+int GetMultithreadingScore(int block_size_log2, int rows, int cols,
+ int tentative_thread_count) {
+ const int num_full_blocks_of_rows = rows >> block_size_log2;
+ const int num_full_blocks_of_cols = cols >> block_size_log2;
+ const int candidate_num_full_blocks_log2 = floor_log2(
+ std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));
+
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (tentative_thread_count == 1) {
+ return 0;
+ } else {
+ const int blocks_per_thread_log2 =
+ candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
+ if (blocks_per_thread_log2 < 0) {
+ return -64;
+ } else if (blocks_per_thread_log2 == 0) {
+ return -16;
+ } else if (blocks_per_thread_log2 == 1) {
+ return -8;
+ } else if (blocks_per_thread_log2 == 2) {
+ return 0;
+ } else if (blocks_per_thread_log2 == 3) {
+ return 8;
+ } else {
+ return 16;
+ }
+ }
+}
+
+// Computes a 'cache locality score'.
+int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
+ int kernel_rows_log2, int kernel_cols_log2,
+ int lhs_scalar_size, int rhs_scalar_size,
+ const CpuCacheParams& cpu_cache_params) {
+ // In the narrow case (e.g. matrix*vector), each byte of the big operand
+ // matrix (either LHS or RHS) is traversed only once, so any notion of data
+ // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
+ // be 0 in that case.
+ if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
+ return 0;
+ }
+ const int block_rows = std::min(1 << block_size_log2, rows);
+ const int block_cols = std::min(1 << block_size_log2, cols);
+ const int total_read_bytes =
+ (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
+ const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
+ const int nonlocality_log2 =
+ total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size);
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (nonlocality_log2 < -1) {
+ return 64;
+ } else if (nonlocality_log2 == -1) {
+ return 56;
+ } else if (nonlocality_log2 == 0) {
+ return 48;
+ } else if (nonlocality_log2 == 1) {
+ return 32;
+ } else if (nonlocality_log2 == 2) {
+ return 16;
+ } else if (nonlocality_log2 == 3) {
+ return 0;
+ } else {
+ return -64;
+ }
+}
+
+// Computes a 'kernel amortization score'. This is the notion that very small
+// tiles result in more overhead outside of kernels, more complex memory
+// access patterns, and less benefit from ruy's fat kernels, so we reward
+// larger blocks more than smaller ones.
+int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
+ int kernel_rows_log2, int kernel_cols_log2) {
+ const int block_rows = std::min(1 << block_size_log2, rows);
+ const int block_cols = std::min(1 << block_size_log2, cols);
+ const int kernels_per_block_log2 =
+ floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
+ RUY_DCHECK_GE(kernels_per_block_log2, 0);
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (kernels_per_block_log2 == 0) {
+ return 0;
+ } else if (kernels_per_block_log2 == 1) {
+ return 8;
+ } else if (kernels_per_block_log2 == 2) {
+ return 16;
+ } else if (kernels_per_block_log2 == 3) {
+ return 24;
+ } else if (kernels_per_block_log2 == 4) {
+ return 32;
+ } else if (kernels_per_block_log2 == 5) {
+ return 40;
+ } else if (kernels_per_block_log2 == 6) {
+ return 48;
+ } else if (kernels_per_block_log2 == 7) {
+ return 56;
+ } else {
+ return 64;
+ }
+}
+
+} // namespace
+
+bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
+ int lhs_scalar_size, int rhs_scalar_size,
+ const CpuCacheParams& cpu_cache_params) {
+ if (rows == 1 || cols == 1) {
+ return true;
+ }
+ // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided
+ // by the rectangularness factors, since any non-linear traversal order will
+ // be local to each subdivision. In the present function, we don't know the
+ // rectangularness factors yet, and we can't just call GetRectangularness
+ // as that requires knowing the kernel block layout. Since we just want
+ // a coarse estimate with only the guarantee that if we return `true` then
+ // linear traversal will be used, it is OK here to over-estimate `rows` and
+ // `cols`, by omitting to divide them by the rectangularness factors.
+ return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
+ cpu_cache_params) == BlockMapTraversalOrder::kLinear;
+}
+
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
+ int tentative_thread_count,
+ const CpuCacheParams& cpu_cache_params, BlockMap* block_map) {
+ RUY_TRACE_SCOPE;
+ profiler::ScopeLabel label("MakeBlockMap");
+
+ RUY_DCHECK_GE(rows, kernel_rows);
+ RUY_DCHECK_GE(cols, kernel_cols);
+ RUY_DCHECK_EQ(rows % kernel_rows, 0);
+ RUY_DCHECK_EQ(cols % kernel_cols, 0);
+
+ // Estimate the 'rectangularness', the first level of subdivision bringing
+ // the shape to within 2x of a square shape.
+ int rows_rectangularness_log2 = 0;
+ int cols_rectangularness_log2 = 0;
+ GetRectangularness(rows, cols, kernel_rows, kernel_cols,
+ &rows_rectangularness_log2, &cols_rectangularness_log2);
+
+ const int kernel_rows_log2 = pot_log2(kernel_rows);
+ const int kernel_cols_log2 = pot_log2(kernel_cols);
+ const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);
+
+ const int size = std::min(rows, cols);
+ const int size_log2 = std::max(kernel_size_log2, floor_log2(size));
+
+ RUY_DCHECK_GE(size_log2, kernel_size_log2);
+
+ // Heuristic selecting the power-of-two grid subdivision inside of each
+ // square-ish region (past the above subdivision by 'rectangularness').
+ // Note that it is the number of subdivisions, not the resulting block size,
+ // that will be a power of two. But inside of that heuristic, it simplifies
+ // code to talk in terms of 'block_size_log2', as if it were the block size
+ // that were a power of two. This 'block_size_log2' is to be interpreted as
+ // "log2 rounded below", e.g. when block_size_log2=8 we might have a block
+ // size in [256, 511]. When the shape is non-square, rows!=cols, this
+ // refers to the smaller of the two, so the other might be as large as
+ // 1021 (can't be 1022 because following the above 'rectangularness'
+ // subdivision, the aspect ratio is already < 2).
+
+ // We are going to try candidate values for block_size_log2 ranging from
+ // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2).
+ // For each of them we will compute a 'score' by adding individual scores
+ // for a few different considerations, all of which are entirely empirical.
+ // The values (and possibly the logic) around here are all subject to tuning
+ // based on benchmarks on different hardware. The current values are based
+ // on benchmarking on Qualcomm S855 (big and little cores), arm64,
+ // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
+ // and tune this as needed to achieve good performance elsewhere. Use
+ // the unit test, block_map_test, to encode values that should be preserved
+ // on specific architectures. Use RUY_TRACE to debug the current heuristics
+ // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a
+ // different block_size_log2 choice, to empirically find the optimal value
+ // before updating the heuristic so that it produces that value.
+ static constexpr int kMaxKernelsPerBlockLog2 = 6;
+ const int max_block_size_log2 =
+ std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
+ int best_score = std::numeric_limits<int>::min();
+ int best_score_block_size_log2 = -1;
+ RUY_TRACE_INFO(MAKE_BLOCK_MAP_START);
+ for (int block_size_log2 = kernel_size_log2;
+ block_size_log2 <= max_block_size_log2; block_size_log2++) {
+ const int multithreading_score = GetMultithreadingScore(
+ block_size_log2, rows, cols, tentative_thread_count);
+ const int cache_locality_score = GetCacheLocalityScore(
+ block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
+ lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
+ const int kernel_amortization_score = GetKernelAmortizationScore(
+ block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
+ const int score =
+ multithreading_score + cache_locality_score + kernel_amortization_score;
+ if (score >= best_score) {
+ best_score = score;
+ best_score_block_size_log2 = block_size_log2;
+ }
+ RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE);
+ }
+
+#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2
+ // Useful for tuning.
+ best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2;
+#endif
+
+ // As explained in the above comment, phrasing the above code in terms of
+ // block_size_log2 was only a convenience inside of that heuristic. Now we
+ // revert to talking in terms of grid subdivision. That is what will actually
+ // be powers of two.
+ int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
+ RUY_DCHECK_GE(num_blocks_base_log2, 0);
+ const int num_blocks_of_rows_log2 =
+ num_blocks_base_log2 + rows_rectangularness_log2;
+ const int num_blocks_of_cols_log2 =
+ num_blocks_base_log2 + cols_rectangularness_log2;
+
+ // Now that we know the grid subdivision, we can pinpoint the exact block
+ // sizes. They can't be powers of two in general; they can't even be all
+ // equal in general; so the following few parameters will govern how blocks
+ // of slightly different shapes are put together in the block map.
+ const int small_block_rows =
+ round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
+ const int small_block_cols =
+ round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
+ const int rows_of_large_blocks =
+ round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2),
+ kernel_rows) >>
+ pot_log2(kernel_rows);
+ const int cols_of_large_blocks =
+ round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2),
+ kernel_cols) >>
+ pot_log2(kernel_cols);
+
+ // We have everything! Write out to the destination block_map.
+ block_map->dims[Side::kLhs] = rows;
+ block_map->dims[Side::kRhs] = cols;
+ block_map->kernel_dims[Side::kLhs] = kernel_rows;
+ block_map->kernel_dims[Side::kRhs] = kernel_cols;
+ block_map->num_blocks_base_log2 = num_blocks_base_log2;
+ block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
+ block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
+ block_map->small_block_dims[Side::kLhs] = small_block_rows;
+ block_map->small_block_dims[Side::kRhs] = small_block_cols;
+ block_map->large_blocks[Side::kLhs] = rows_of_large_blocks;
+ block_map->large_blocks[Side::kRhs] = cols_of_large_blocks;
+ // See the comment on GetTraversalOrder for why we are dividing `rows` and
+ // `cols` by the rectangularness subdivision parameters here.
+ block_map->traversal_order = GetTraversalOrder(
+ rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2,
+ depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
+ // Done last: NumBlocks needs some of the block_map fields to be already set.
+ block_map->thread_count =
+ std::min(tentative_thread_count, NumBlocks(*block_map));
+ RUY_TRACE_INFO(MAKE_BLOCK_MAP_END);
+}
+
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end) {
+ profiler::ScopeLabel label("GetBlockMatrixCoords");
+ *start = block * block_map.small_block_dims[side] +
+ std::min(block, block_map.large_blocks[side]) *
+ block_map.kernel_dims[side];
+ *end =
+ *start + block_map.small_block_dims[side] +
+ (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
+
+ RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
+ RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
+ RUY_DCHECK_LE(*end, block_map.dims[side]);
+ RUY_DCHECK_LT(*start, *end);
+ RUY_DCHECK_GE(*start, 0);
+}
+
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+ SidePair<int>* start, SidePair<int>* end) {
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
+ &(*end)[side]);
+ }
+}
+
+} // namespace ruy
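
To make the tail of MakeBlockMap more concrete, here is a minimal standalone sketch of the same small/large block arithmetic with made-up numbers (1000 rows, kernel_rows = 8, four blocks along the rows dimension); plain integer division stands in for round_down_pot/round_up_pot since all quantities in this example are exact multiples:

#include <algorithm>
#include <cassert>

int main() {
  const int rows = 1000, kernel_rows = 8, num_blocks_of_rows_log2 = 2;
  const int num_blocks = 1 << num_blocks_of_rows_log2;
  // Analogue of round_down_pot(rows >> 2, kernel_rows): every block gets at
  // least 248 rows.
  const int small_block_rows =
      ((rows >> num_blocks_of_rows_log2) / kernel_rows) * kernel_rows;
  // The 1000 - 4*248 = 8 leftover rows are handed out kernel_rows at a time,
  // so exactly one block becomes a 'large' block here.
  const int rows_of_large_blocks =
      (rows - (small_block_rows << num_blocks_of_rows_log2)) / kernel_rows;
  assert(small_block_rows == 248);
  assert(rows_of_large_blocks == 1);
  // Same start/end arithmetic as GetBlockMatrixCoords: the blocks span
  // [0,256), [256,504), [504,752), [752,1000), exactly covering the matrix.
  int end = 0;
  for (int block = 0; block < num_blocks; block++) {
    const int start = block * small_block_rows +
                      std::min(block, rows_of_large_blocks) * kernel_rows;
    assert(start == end);
    end = start + small_block_rows +
          (block < rows_of_large_blocks ? kernel_rows : 0);
  }
  assert(end == rows);
  return 0;
}
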
diff --git a/ruy/block_map.h b/ruy/block_map.h
new file mode 100644
index 0000000..0057509
--- /dev/null
+++ b/ruy/block_map.h
@@ -0,0 +1,162 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_BLOCK_MAP_H_
+#define RUY_RUY_BLOCK_MAP_H_
+
+#include "ruy/cpu_cache_params.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+
+enum class BlockMapTraversalOrder {
+ // Plain old row-by-row or column-by-column traversal.
+ kLinear,
+ // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve
+ kFractalZ,
+ // Variant of Z-order doing a U instead of a Z.
+ kFractalU,
+ // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve
+ kFractalHilbert
+};
+
+// A BlockMap describes a tiling of a matrix, typically the destination matrix
+// of a matrix multiplication computation. As is standard in matrix
+// multiplication, a tile is called a "block".
+//
+// Ruy subdivides work by blocks of the destination matrix: each thread fully
+// computes a block at once, then moves on to another block; each block is
+// produced by a single thread.
+//
+// This ensures that the workloads for each block are mutually independent,
+// which reduces synchronization requirements.
+//
+// Typically, a matrix multiplication will early on create a BlockMap by
+// calling MakeBlockMap. It will then query the number of blocks in that
+// BlockMap by calling NumBlocks. It will then create a single atomic integer
+// counter indexing these blocks, called the 'index', and will distribute
+// work to its N threads by ensuring that each thread works on disjoint sets
+// of index values. For a given index value, the thread will call
+// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords
+// to find the actual row and column numbers of this block.
+//
+// There are two nested levels of subdivision. On a local level, the matrix is
+// tiled into a square NxN grid where N is a power of two, specifically:
+// N = 2^num_blocks_base_log2.
+//
+// At a larger scale, around these blocks, there may be one further
+// level of subdivision, in only one dimension: either along rows or along
+// columns. That is used to handle arbitrarily rectangular matrices. The
+// aforementioned high-level block grid is square, so by itself it does not
+// fit very rectangular matrices well.
+//
+// Taking these two nested levels of subdivision together, the effective
+// tiling is by
+// 2^(num_blocks_base_log2 + rows_rectangularness_log2)
+// blocks in the row dimension, and by
+// 2^(num_blocks_base_log2 + cols_rectangularness_log2)
+// blocks in the column dimension. See NumBlocksPerSide, NumBlocks.
+//
+// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero.
+//
+// Finally, this BlockMap is designed to operate under alignment constraints:
+// two fields, kernel_rows and kernel_cols, describe the requested alignment
+// of the effective grid in both dimensions. The idea is to feed matrix
+// multiplication kernels with tiles that fit their width as much as possible.
+// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp.
+// kernel_cols) then some tile will have to have unaligned size. BlockMap
+// will only allow that to happen in the last position along each axis, so
+// as to minimize the overhead incurred onto the matrix multiplication kernels.
+struct BlockMap {
+ // The number of threads to use (to distribute the blocks to).
+ int thread_count;
+ // The order in which to traverse the matrix of which this BlockMap represents
+ // a tiling (hereafter "the matrix").
+ BlockMapTraversalOrder traversal_order;
+ // The dimensions of the block_map, that is, of the destination
+ // matrix rounded up to the next multiples of kernel_dims.
+ SidePair<int> dims;
+ // Log2 of the minimum number of subdivisions of the grid along either axis.
+ int num_blocks_base_log2;
+ // Log2 of the additional subdivision of the rows/columns axis.
+ SidePair<int> rectangularness_log2;
+ // Requested alignment of the subdivisions of the grid along the rows/columns
+ // axis.
+ SidePair<int> kernel_dims;
+ // Internal helper. Minimum number of rows/columns in each block.
+ SidePair<int> small_block_dims;
+ // Internal helper. Number of blocks along each dimension that need to have
+ // their size in that dimension be given by (small_block_dims + kernel_dims)
+ // instead of just small_block_dims.
+ SidePair<int> large_blocks;
+};
+
+// This function produces a coarse estimate of whether linear traversal will
+// be used for this matmul. It offers a one-way guarantee: if this function
+// returns true then linear traversal will be used.
+//
+// The purpose of this function is to allow TrMul to make a cheap, early
+// decision to enter a "simple loop" code path for simple cases.
+bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
+ int lhs_scalar_size, int rhs_scalar_size,
+ const CpuCacheParams& cpu_cache_params);
+
+// Create a BlockMap suitable for tiling the destination matrix in a
+// matrix multiplication with the given parameters.
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
+ int tentative_thread_count,
+ const CpuCacheParams& cpu_cache_params, BlockMap* block_map);
+
+// Maps an integer index to a block position in the grid.
+void GetBlockByIndex(const BlockMap& block_map, int index,
+ SidePair<int>* block);
+
+// Given a block position in the grid, returns its actual
+// position in the matrix that the BlockMap refers to in the dimension
+// referred to by `side`: along rows if side==kLhs, along columns if
+// side==kRhs.
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end);
+
+// Given a block position in the grid, returns its actual
+// position in the matrix that the BlockMap refers to in terms of
+// actual row/column indices.
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+ SidePair<int>* start, SidePair<int>* end);
+
+// Returns the number of grid subdivisions along the rows dimension (if
+// side == kLhs) or columns dimension (if side == kRhs).
+inline int NumBlocksPerSide(Side side, const BlockMap& block_map) {
+ return 1 << (block_map.num_blocks_base_log2 +
+ block_map.rectangularness_log2[side]);
+}
+
+// Returns the overall number of blocks in
+// the BlockMap. The valid index values to pass to GetBlockByIndex are the
+// integers from 0 to N-1 where N is the value returned here.
+//
+// Note that it is always true that
+// NumBlocks == NumBlocksPerSide(kLhs) * NumBlocksPerSide(kRhs)
+// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0.
+inline int NumBlocks(const BlockMap& block_map) {
+ return 1 << (2 * block_map.num_blocks_base_log2 +
+ block_map.rectangularness_log2[Side::kLhs] +
+ block_map.rectangularness_log2[Side::kRhs]);
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_BLOCK_MAP_H_
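
A minimal single-threaded sketch of the typical usage flow documented above; the matrix dimensions, kernel shape and cache sizes are made-up values, and a real ruy::Mul would share the block index space across threads via an atomic counter:

#include "ruy/block_map.h"
#include "ruy/cpu_cache_params.h"
#include "ruy/side_pair.h"

// Walks every block of a destination matrix in the order encoded by the
// BlockMap, single-threaded for simplicity.
void VisitAllBlocks() {
  ruy::CpuCacheParams cache_params;
  cache_params.local_cache_size = 128 * 1024;        // made-up cache sizes
  cache_params.last_level_cache_size = 1024 * 1024;
  ruy::BlockMap block_map;
  ruy::MakeBlockMap(/*rows=*/256, /*cols=*/256, /*depth=*/256,
                    /*kernel_rows=*/8, /*kernel_cols=*/8,
                    /*lhs_scalar_size=*/1, /*rhs_scalar_size=*/1,
                    /*tentative_thread_count=*/1, cache_params, &block_map);
  for (int index = 0; index < ruy::NumBlocks(block_map); index++) {
    ruy::SidePair<int> block, start, end;
    ruy::GetBlockByIndex(block_map, index, &block);
    ruy::GetBlockMatrixCoords(block_map, block, &start, &end);
    // The block covers destination rows [start[kLhs], end[kLhs]) and
    // columns [start[kRhs], end[kRhs]); a kernel would be run on it here.
  }
}
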
diff --git a/ruy/block_map_test.cc b/ruy/block_map_test.cc
new file mode 100644
index 0000000..8245a5c
--- /dev/null
+++ b/ruy/block_map_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/block_map.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <vector>
+
+#include "ruy/cpu_cache_params.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/platform.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+namespace {
+
+#if RUY_PLATFORM_NEON_64
+
+// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55.
+void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size,
+ int rhs_scalar_size, int tentative_thread_count,
+ int expected_num_blocks_base_log2,
+ int expected_rectangularness_log2) {
+ // Plausible Cortex-A55 cache sizes.
+ CpuCacheParams cpu_cache_params;
+ cpu_cache_params.local_cache_size = 128 * 1024;
+ cpu_cache_params.last_level_cache_size = 1024 * 1024;
+ BlockMap block_map;
+ MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size,
+ rhs_scalar_size, tentative_thread_count, cpu_cache_params,
+ &block_map);
+ EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2);
+ EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs],
+ block_map.rectangularness_log2[Side::kRhs]),
+ 0);
+ EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs],
+ block_map.rectangularness_log2[Side::kRhs]),
+ expected_rectangularness_log2);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) {
+ MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest,
+ MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) {
+ MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 2,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 2,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTest32bit) {
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 3,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4,
+ /* tentative_thread_count */ 4,
+ /* expected_num_blocks_base_log2 */ 7,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) {
+ MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 3);
+ MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 6);
+}
+
+#endif
+
+int L1Distance(const SidePair<int>& a, const SidePair<int>& b) {
+ return std::abs(a[Side::kLhs] - b[Side::kLhs]) +
+ std::abs(a[Side::kRhs] - b[Side::kRhs]);
+}
+
+void GetBlockByIndexSquareTest(int num_blocks_base_log2,
+ BlockMapTraversalOrder traversal_order) {
+ // Arbitrary, does not affect this test. 3 is just a typical value.
+ constexpr int kKernelSizeLog2 = 3;
+
+ const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2;
+ BlockMap block_map;
+ block_map.thread_count = 1;
+ block_map.traversal_order = traversal_order;
+ block_map.num_blocks_base_log2 = num_blocks_base_log2;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ block_map.dims[side] = 1 << size_log2;
+ block_map.rectangularness_log2[side] = 0;
+ block_map.kernel_dims[side] = 1 << kKernelSizeLog2;
+ block_map.small_block_dims[side] = block_map.kernel_dims[side];
+ block_map.large_blocks[side] = 0;
+ }
+
+ const int num_blocks_per_side = 1 << num_blocks_base_log2;
+ const int num_blocks = num_blocks_per_side * num_blocks_per_side;
+ EXPECT_EQ(num_blocks, NumBlocks(block_map));
+
+ // Perform a full traversal of all blocks, as if computing a whole matrix
+ // multiplication.
+ //
+ // Used to record how many times each block was hit by the traversal.
+ std::vector<int> block_hit_counts(num_blocks);
+ // Here we rely on the assumption that all traversal orders start at (0, 0).
+ SidePair<int> previous_block_coords(0, 0);
+ // Sum of L1 norm of the coordinate change at every step of the traversal.
+ std::int64_t total_l1_distance = 0;
+ // Number of jumps, i.e. traversal steps with an L1 norm greater than 1.
+ int discontinuity_count = 0;
+ for (int block_index = 0; block_index < num_blocks; block_index++) {
+ SidePair<int> block_coords;
+ GetBlockByIndex(block_map, block_index, &block_coords);
+ ++block_hit_counts[block_coords[Side::kLhs] +
+ num_blocks_per_side * block_coords[Side::kRhs]];
+ int distance = L1Distance(block_coords, previous_block_coords);
+ total_l1_distance += distance;
+ discontinuity_count += (distance > 1);
+ previous_block_coords = block_coords;
+ }
+
+ // Verify that each block was traversed exactly once.
+ for (int l = 0; l < num_blocks_per_side; l++) {
+ for (int r = 0; r < num_blocks_per_side; r++) {
+ EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1);
+ }
+ }
+
+ // Verify that the discontinuity_count and total_l1_distance are as expected
+ // for the given traversal_order.
+ switch (traversal_order) {
+ case BlockMapTraversalOrder::kFractalHilbert:
+ // No discontinuity at all with this space-filling continuous curve!
+ EXPECT_EQ(discontinuity_count, 0);
+ // Therefore, total_l1_distance has to be the number of blocks minus one.
+ EXPECT_EQ(total_l1_distance, num_blocks - 1);
+ break;
+ case BlockMapTraversalOrder::kLinear:
+ EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1);
+ EXPECT_EQ(total_l1_distance,
+ 2 * num_blocks_per_side * (num_blocks_per_side - 1));
+ break;
+ case BlockMapTraversalOrder::kFractalZ:
+ EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0);
+ EXPECT_EQ(total_l1_distance,
+ 2 * num_blocks_per_side * (num_blocks_per_side - 1));
+ break;
+ case BlockMapTraversalOrder::kFractalU: {
+ if (num_blocks_base_log2 == 0) {
+ EXPECT_EQ(discontinuity_count, 0);
+ EXPECT_EQ(total_l1_distance, 0);
+ } else {
+ int expected_discontinuity_count = 0;
+ int expected_total_l1_distance = 3;
+ for (int i = 2; i <= num_blocks_base_log2; i++) {
+ expected_discontinuity_count = 4 * expected_discontinuity_count + 2;
+ expected_total_l1_distance =
+ 4 * expected_total_l1_distance + (1 << (i + 1)) - 1;
+ }
+ EXPECT_EQ(discontinuity_count, expected_discontinuity_count);
+ EXPECT_EQ(total_l1_distance, expected_total_l1_distance);
+ }
+ break;
+ }
+ default:
+ abort();
+ }
+}
+
+TEST(BlockMapTest, GetBlockByIndexSquare) {
+ for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10;
+ num_blocks_base_log2++) {
+ for (BlockMapTraversalOrder traversal_order :
+ {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ,
+ BlockMapTraversalOrder::kFractalU,
+ BlockMapTraversalOrder::kFractalHilbert}) {
+ GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order);
+ }
+ }
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/blocking_counter.cc b/ruy/blocking_counter.cc
new file mode 100644
index 0000000..b62eca9
--- /dev/null
+++ b/ruy/blocking_counter.cc
@@ -0,0 +1,49 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/blocking_counter.h"
+
+#include "ruy/check_macros.h"
+#include "ruy/wait.h"
+
+namespace ruy {
+
+void BlockingCounter::Reset(int initial_count) {
+ int old_count_value = count_.load(std::memory_order_relaxed);
+ RUY_DCHECK_EQ(old_count_value, 0);
+ (void)old_count_value;
+ count_.store(initial_count, std::memory_order_release);
+}
+
+bool BlockingCounter::DecrementCount() {
+ int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
+ RUY_DCHECK_GT(old_count_value, 0);
+ int count_value = old_count_value - 1;
+ bool hit_zero = (count_value == 0);
+ if (hit_zero) {
+ std::lock_guard<std::mutex> lock(count_mutex_);
+ count_cond_.notify_all();
+ }
+ return hit_zero;
+}
+
+void BlockingCounter::Wait(const Duration spin_duration) {
+ const auto& condition = [this]() {
+ return count_.load(std::memory_order_acquire) == 0;
+ };
+ ruy::Wait(condition, spin_duration, &count_cond_, &count_mutex_);
+}
+
+} // namespace ruy
diff --git a/ruy/blocking_counter.h b/ruy/blocking_counter.h
new file mode 100644
index 0000000..806909b
--- /dev/null
+++ b/ruy/blocking_counter.h
@@ -0,0 +1,66 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_BLOCKING_COUNTER_H_
+#define RUY_RUY_BLOCKING_COUNTER_H_
+
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11) // IWYU pragma: keep
+#include <mutex> // NOLINT(build/c++11) // IWYU pragma: keep
+
+#include "ruy/time.h"
+
+namespace ruy {
+
+// A BlockingCounter lets one thread wait for N events to occur.
+// This is how the master thread waits for all the worker threads
+// to have finished working.
+// The waiting is done by spin-waiting on the atomic count_ until it hits
+// the value 0, then reverting to a passive wait (see Wait()). Spinning is
+// acceptable because in our usage pattern, BlockingCounter is used only to
+// synchronize threads after short-lived tasks (performing parts of the same
+// GEMM). It is not used for synchronizing longer waits (resuming work on the
+// next GEMM).
+class BlockingCounter {
+ public:
+ BlockingCounter() : count_(0) {}
+
+ // Sets/resets the counter; initial_count is the number of
+ // decrementing events that the Wait() call will be waiting for.
+ void Reset(int initial_count);
+
+ // Decrements the counter; if the counter hits zero, signals
+ // the threads that were waiting for that, and returns true.
+ // Otherwise (if the decremented count is still nonzero),
+ // returns false.
+ bool DecrementCount();
+
+ // Waits for the N other threads (N having been set by Reset())
+ // to hit the BlockingCounter.
+ //
+ // Will first spin-wait for `spin_duration` before reverting to passive wait.
+ void Wait(const Duration spin_duration);
+
+ private:
+ std::atomic<int> count_;
+
+ // The condition variable and mutex allowing to passively wait for count_
+ // to reach the value zero, in the case of longer waits.
+ std::condition_variable count_cond_;
+ std::mutex count_mutex_;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_BLOCKING_COUNTER_H_
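
A minimal sketch of the Reset / DecrementCount / Wait protocol described above, using plain std::thread in place of ruy's thread pool; the zero spin duration is an arbitrary choice that goes straight to the passive wait:

#include <thread>
#include <vector>

#include "ruy/blocking_counter.h"
#include "ruy/time.h"

void RunShortTasksAndWait(int num_tasks) {
  ruy::BlockingCounter counter;
  counter.Reset(num_tasks);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_tasks; i++) {
    workers.emplace_back([&counter] {
      // ... perform one short-lived task (e.g. one slice of a GEMM) ...
      counter.DecrementCount();  // the final decrement wakes the waiter
    });
  }
  // Zero spin duration: go straight to the passive (condition variable) wait.
  counter.Wait(ruy::Duration::zero());
  for (std::thread& worker : workers) worker.join();
}
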
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
new file mode 100644
index 0000000..836f47a
--- /dev/null
+++ b/ruy/build_defs.bzl
@@ -0,0 +1,79 @@
+"""Build definitions for Ruy."""
+
+# Helper for ruy_copts().
+# Returns warnings flags to use for all ruy code.
+def ruy_copts_warnings():
+ return select({
+ "@bazel_tools//src/conditions:windows": [
+ # We run into trouble on Windows toolchains with warning flags,
+ # as mentioned in the comments below on each flag.
+ # We could be more aggressive in enabling supported warnings on each
+ # Windows toolchain, but we compromise with keeping BUILD files simple
+ # by limiting the number of config_setting's.
+ ],
+ "//conditions:default": [
+ "-Wall",
+ # Some clang-based Windows toolchains have more warnings in -Wextra.
+ "-Wextra",
+ # TensorFlow is C++14 at the moment. This flag ensures that we warn
+ # on any code that isn't C++14, but MSVC does not support it.
+ "-Wc++14-compat",
+ # Warn on preprocessor expansion of an undefined token, e.g. catching
+ # typos such as `#ifdef __linus__` instead of `#ifdef __linux__`.
+ # Not supported by MSVC.
+ "-Wundef",
+ ],
+ })
+
+# Helper for ruy_copts().
+# Returns flags to use to enable NEON if applicable, for all ruy code.
+def ruy_copts_neon():
+ return select({
+ # OK to crash old devices that lack full NEON support.
+ # No need to pass -mfloat-abi=softfp, that is already on.
+ "//ruy:arm32_assuming_neon": [
+ "-mfpu=neon",
+ ],
+ "//conditions:default": [],
+ })
+
+# Helper for ruy_copts().
+# Returns optimization flags to use for all ruy code.
+def ruy_copts_optimize():
+ return select({
+ # On some toolchains, typically mobile, "-c opt" is interpreted by
+ # default as "optimize for size, not for speed". For Ruy code,
+ # optimizing for speed is the better compromise, so we override that.
+ # Careful to keep debug builds debuggable, whence the select based
+ # on the compilation mode.
+ "//ruy:do_not_want_O3": [],
+ "//conditions:default": ["-O3"],
+ })
+
+# Returns compiler flags to use for all ruy code.
+def ruy_copts():
+ return ruy_copts_warnings() + ruy_copts_neon() + ruy_copts_optimize()
+
+def ruy_copts_avx():
+ return select({
+ "//ruy:x86_64_and_not_msvc": ["-mavx"],
+ "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX"],
+ "//conditions:default": [],
+ })
+
+def ruy_copts_avx2_fma():
+ return select({
+ "//ruy:x86_64_and_not_msvc": ["-mavx2", "-mfma"],
+ "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX2"],
+ "//conditions:default": [],
+ })
+
+def ruy_copts_avx512():
+ # In some clang-based toolchains, in the default compilation mode (not -c opt),
+ # heavy spillage in the AVX512 kernels results in stack frames > 50k. This issue does not exist
+ # in optimized builds (-c opt).
+ return select({
+ "//ruy:x86_64_and_not_msvc": ["$(STACK_FRAME_UNLIMITED)", "-mavx512f", "-mavx512vl", "-mavx512cd", "-mavx512bw", "-mavx512dq"],
+ "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX512"],
+ "//conditions:default": [],
+ })
diff --git a/ruy/build_defs.oss.bzl b/ruy/build_defs.oss.bzl
new file mode 100644
index 0000000..e405b41
--- /dev/null
+++ b/ruy/build_defs.oss.bzl
@@ -0,0 +1,15 @@
+"""Build definitions for Ruy that are specific to the open-source build."""
+
+# Used for targets that #include <thread>
+def ruy_linkopts_thread_standard_library():
+ # In open source builds, GCC is a common occurrence. It requires "-pthread"
+ # to use the C++11 <thread> standard library header. This breaks the
+ # opensource build on Windows and probably some other platforms, so that
+ # will need to be fixed as needed. Ideally we would like to do this based
+ # on GCC being the compiler, but that does not seem to be easy to achieve
+ # with Bazel. Instead we do the following, which is copied from
+ # https://github.com/abseil/abseil-cpp/blob/1112609635037a32435de7aa70a9188dcb591458/absl/base/BUILD.bazel#L155
+ return select({
+ "@bazel_tools//src/conditions:windows": [],
+ "//conditions:default": ["-pthread"],
+ })
diff --git a/ruy/check_macros.h b/ruy/check_macros.h
new file mode 100644
index 0000000..13d27e2
--- /dev/null
+++ b/ruy/check_macros.h
@@ -0,0 +1,147 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// self-contained, minimal, CHECK/DCHECK macros similar to glog.
+
+#ifndef RUY_RUY_CHECK_MACROS_H_
+#define RUY_RUY_CHECK_MACROS_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <functional>
+#include <type_traits>
+
+namespace ruy {
+namespace check_macros {
+
+constexpr int kValueBufSize = 32;
+
+template <typename T, typename Enable = void>
+struct ToString {
+ static void Run(const T&, char* buf) { snprintf(buf, kValueBufSize, "(?)"); }
+};
+
+template <>
+struct ToString<float, void> {
+ static void Run(float value, char* buf) {
+ snprintf(buf, kValueBufSize, "%.9g", static_cast<double>(value));
+ }
+};
+
+template <>
+struct ToString<double, void> {
+ static void Run(double value, char* buf) {
+ snprintf(buf, kValueBufSize, "%.16g", value);
+ }
+};
+
+template <typename T>
+struct ToString<T, typename std::enable_if<std::is_integral<T>::value>::type> {
+ static void Run(const T& value, char* buf) {
+ snprintf(buf, kValueBufSize, "%lld", static_cast<long long>(value));
+ }
+};
+
+template <typename T>
+struct ToString<T*, void> {
+ static void Run(T* value, char* buf) {
+ snprintf(buf, kValueBufSize, "%p", value);
+ }
+};
+
+template <typename T>
+struct ToString<T, typename std::enable_if<std::is_enum<T>::value>::type> {
+ static void Run(const T& value, char* buf) {
+ snprintf(buf, kValueBufSize, "(enum value %d)", static_cast<int>(value));
+ }
+};
+
+inline void CheckImpl(bool condition, const char* file, int line,
+ const char* macro, const char* condition_str) {
+ if (!condition) {
+ fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line,
+ macro, condition_str);
+ abort();
+ }
+}
+
+template <template <typename T> class Comparison, typename LhsType,
+ typename RhsType>
+inline void CheckImpl(const char* file, int line, const char* macro,
+ const char* lhs, const LhsType& lhs_value,
+ const char* op_symbol, const char* rhs,
+ const RhsType& rhs_value) {
+ using CommonType = typename std::common_type<LhsType, RhsType>::type;
+ if (!Comparison<CommonType>()(lhs_value, rhs_value)) {
+ char lhs_value_buf[kValueBufSize];
+ ToString<LhsType>::Run(lhs_value, lhs_value_buf);
+ char rhs_value_buf[kValueBufSize];
+ ToString<RhsType>::Run(rhs_value, rhs_value_buf);
+ fprintf(
+ stderr,
+ "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ "
+ "%s %s %s ].\n",
+ file, line, macro, lhs, op_symbol, rhs, lhs_value_buf, op_symbol,
+ rhs_value_buf);
+ abort();
+ }
+}
+
+#define RUY_CHECK_IMPL(macro, condition) \
+ ruy::check_macros::CheckImpl(condition, __FILE__, __LINE__, #macro, \
+ #condition)
+
+#define RUY_CHECK_OP_IMPL(macro, lhs, op_symbol, op_comparison, rhs) \
+ ruy::check_macros::CheckImpl<op_comparison>( \
+ __FILE__, __LINE__, #macro, #lhs, lhs, #op_symbol, #rhs, rhs)
+
+#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition)
+#define RUY_CHECK_EQ(x, y) \
+ RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, std::equal_to, y)
+#define RUY_CHECK_NE(x, y) \
+ RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, std::not_equal_to, y)
+#define RUY_CHECK_GE(x, y) \
+ RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, std::greater_equal, y)
+#define RUY_CHECK_GT(x, y) \
+ RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, std::greater, y)
+#define RUY_CHECK_LE(x, y) \
+ RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, std::less_equal, y)
+#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, std::less, y)
+
+#ifdef NDEBUG
+#define RUY_DCHECK_IS_ENABLED false
+#else
+#define RUY_DCHECK_IS_ENABLED true
+#endif
+
+#define RUY_DCHECK(condition) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK(condition)
+#define RUY_DCHECK_EQ(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_EQ(x, y)
+#define RUY_DCHECK_NE(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_NE(x, y)
+#define RUY_DCHECK_GE(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_GE(x, y)
+#define RUY_DCHECK_GT(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_GT(x, y)
+#define RUY_DCHECK_LE(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_LE(x, y)
+#define RUY_DCHECK_LT(x, y) \
+ if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_LT(x, y)
+
+} // end namespace check_macros
+} // end namespace ruy
+
+#endif // RUY_RUY_CHECK_MACROS_H_
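
A short usage sketch of these macros; the function and values are made up, and the failure message shown in the comment follows the format built by CheckImpl above:

#include "ruy/check_macros.h"

void CheckBlockShape(int rows, int kernel_rows) {
  // Enforced in all builds. A failure prints something like:
  //   file.cc:42: RUY_CHECK_EQ condition not satisfied:
  //   [ rows % kernel_rows == 0 ] with values [ 4 == 0 ].
  // and then aborts.
  RUY_CHECK_GT(kernel_rows, 0);
  RUY_CHECK_EQ(rows % kernel_rows, 0);
  // Only enforced when RUY_DCHECK_IS_ENABLED (i.e. NDEBUG is not defined).
  RUY_DCHECK_GE(rows, kernel_rows);
}
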
diff --git a/ruy/check_macros_test.cc b/ruy/check_macros_test.cc
new file mode 100644
index 0000000..4e7508f
--- /dev/null
+++ b/ruy/check_macros_test.cc
@@ -0,0 +1,153 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/check_macros.h"
+
+#include "ruy/gtest_wrapper.h"
+
+namespace {
+
+#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \
+ do { \
+ if (vacuously_succeeds || (condition)) { \
+ RUY_##family(condition); \
+ } \
+ } while (false)
+
+#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \
+ y) \
+ do { \
+ if (vacuously_succeeds || ((x)op(y))) { \
+ RUY_##family##_##op_name(x, y); \
+ } \
+ } while (false)
+
+#ifdef NDEBUG
+#define TEST_CONDITION(condition) \
+ do { \
+ TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \
+ } while (false)
+#define TEST_COMPARISON(op_name, x, op, y) \
+ do { \
+ TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \
+ } while (false)
+#else
+#define TEST_CONDITION(condition) \
+ do { \
+ TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \
+ TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \
+ } while (false)
+#define TEST_COMPARISON(op_name, x, op, y) \
+ do { \
+ TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \
+ TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \
+ } while (false)
+
+#endif
+
+template <typename LhsType, typename RhsType>
+void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) {
+ RUY_CHECK_EQ(lhs, lhs);
+ TEST_COMPARISON(EQ, lhs, ==, lhs);
+ RUY_CHECK_EQ(lhs, lhs);
+ RUY_CHECK_EQ(lhs, lhs);
+ if (lhs == rhs) {
+ RUY_CHECK_EQ(lhs, rhs);
+ }
+ if (lhs != rhs) {
+ RUY_CHECK_NE(lhs, rhs);
+ }
+}
+
+template <typename LhsType, typename RhsType>
+void TestComparisons(const LhsType& lhs, const RhsType& rhs) {
+ TestEqualityComparisons(lhs, rhs);
+ if (lhs > rhs) {
+ RUY_CHECK_GT(lhs, rhs);
+ }
+ if (lhs >= rhs) {
+ RUY_CHECK_GE(lhs, rhs);
+ }
+ if (lhs < rhs) {
+ RUY_CHECK_LT(lhs, rhs);
+ }
+ if (lhs <= rhs) {
+ RUY_CHECK_LE(lhs, rhs);
+ }
+}
+
+TEST(CheckMacrosTest, IntInt) {
+ TestComparisons(0, 0);
+ TestComparisons(0, 1);
+ TestComparisons(1, -1);
+ TestComparisons(-1, 0);
+ TestComparisons(123, -456);
+ TestComparisons(std::numeric_limits<int>::min(),
+ std::numeric_limits<int>::max());
+ TestComparisons(123, std::numeric_limits<int>::max());
+ TestComparisons(123, std::numeric_limits<int>::min());
+}
+
+TEST(CheckMacrosTest, Uint8Uint8) {
+ TestComparisons<std::uint8_t, std::uint8_t>(0, 0);
+ TestComparisons<std::uint8_t, std::uint8_t>(255, 0);
+ TestComparisons<std::uint8_t, std::uint8_t>(0, 255);
+ TestComparisons<std::uint8_t, std::uint8_t>(12, 34);
+}
+
+TEST(CheckMacrosTest, Uint8Int) {
+ TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::min());
+ TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::min());
+ TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::max());
+ TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::max());
+}
+
+TEST(CheckMacrosTest, FloatFloat) {
+ TestComparisons(0.f, 0.f);
+ TestComparisons(0.f, 1.f);
+ TestComparisons(1.f, -1.f);
+ TestComparisons(-1.f, 0.f);
+ TestComparisons(123.f, -456.f);
+ TestComparisons(std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::max());
+ TestComparisons(123.f, std::numeric_limits<float>::max());
+ TestComparisons(123.f, std::numeric_limits<float>::lowest());
+}
+
+TEST(CheckMacrosTest, IntFloat) {
+ TestComparisons(0, 0.f);
+ TestComparisons(0, 1.f);
+ TestComparisons(1, -1.f);
+ TestComparisons(-1, 0.f);
+ TestComparisons(123, -456.f);
+ TestComparisons(std::numeric_limits<int>::lowest(),
+ std::numeric_limits<float>::max());
+ TestComparisons(123, std::numeric_limits<float>::max());
+ TestComparisons(123, std::numeric_limits<float>::lowest());
+}
+
+TEST(CheckMacrosTest, EnumClass) {
+ enum class SomeEnumClass { kA, kB, kC };
+ TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA);
+ TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB);
+ TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB);
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/context.cc b/ruy/context.cc
new file mode 100644
index 0000000..4661738
--- /dev/null
+++ b/ruy/context.cc
@@ -0,0 +1,58 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/context.h"
+
+#include "ruy/ctx.h"
+#include "ruy/ctx_impl.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/thread_pool.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+Context::Context() : impl_(new CtxImpl) {}
+Context::~Context() { delete impl_; }
+
+const Ctx& Context::ctx() const { return static_cast<const Ctx&>(*impl_); }
+Ctx* Context::mutable_ctx() { return static_cast<Ctx*>(impl_); }
+
+Path Context::last_used_path() const { return ctx().last_used_path(); }
+Tuning Context::explicit_tuning() const { return ctx().explicit_tuning(); }
+void Context::set_explicit_tuning(Tuning value) {
+ mutable_ctx()->set_explicit_tuning(value);
+}
+const ThreadPool& Context::thread_pool() const { return ctx().thread_pool(); }
+ThreadPool* Context::mutable_thread_pool() {
+ return mutable_ctx()->mutable_thread_pool();
+}
+int Context::max_num_threads() const { return ctx().max_num_threads(); }
+void Context::set_max_num_threads(int value) {
+ mutable_ctx()->set_max_num_threads(value);
+}
+
+void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); }
+
+bool Context::performance_advisory(PerformanceAdvisory advisory) const {
+ return ctx().performance_advisory(advisory);
+}
+
+void Context::set_runtime_enabled_paths(Path paths) {
+ mutable_ctx()->SetRuntimeEnabledPaths(paths);
+}
+
+} // namespace ruy
diff --git a/ruy/context.h b/ruy/context.h
new file mode 100644
index 0000000..79a4b5c
--- /dev/null
+++ b/ruy/context.h
@@ -0,0 +1,108 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Context is the user-facing context class.
+
+#ifndef RUY_RUY_CONTEXT_H_
+#define RUY_RUY_CONTEXT_H_
+
+#include <cstdint>
+
+namespace ruy {
+
+class Ctx;
+class CtxImpl;
+class ThreadPool;
+enum class Path : std::uint8_t;
+enum class Tuning;
+enum class PerformanceAdvisory;
+
+// A Context holds runtime information used by Ruy. It holds runtime resources
+// such as the workers thread pool and the allocator (which holds buffers for
+// temporary data), as well as runtime options controlling which Paths are
+// enabled (typically based on which instruction sets are detected) and how
+// many threads to use.
+class Context final {
+ public:
+ Context();
+ ~Context();
+
+ // Returns the Path enum value that corresponds to the code path used by
+ // the last ruy::Mul with this Context.
+ Path last_used_path() const;
+
+ // Control of whether to use kernels tuned for in-order or out-of-order CPU
+ // cores. The default is auto-detection, so these methods should only be used
+ // to override that auto-detection if it's not working as intended or for
+ // testing.
+ Tuning explicit_tuning() const;
+ void set_explicit_tuning(Tuning value);
+
+ // The thread pool held by this context to dispatch a ruy::Mul to worker
+ // threads.
+ //
+ // By default, threads may spin-wait for a few milliseconds before reverting
+ // to passive wait. This can be controlled by
+ // `mutable_thread_pool()->set_spin_milliseconds(value)`.
+ const ThreadPool& thread_pool() const;
+ ThreadPool* mutable_thread_pool();
+
+ // Controls the maximum number of threads to be used by ruy::Mul with this
+ // Context. The number of threads in the pool will be that value minus one,
+ // as the remaining portion of the work is done directly on the calling
+ // thread.
+ //
+ // This defaults to 1. Multi-threading in ruy is always opt-in. There is
+ // no auto-detection of hardware concurrency. That is on purpose: ruy focuses
+ // on mobile applications where such concepts are difficult to define
+ // (e.g. ARM big.LITTLE).
+ int max_num_threads() const;
+ void set_max_num_threads(int value);
+
+ // Returns true if the last ruy::Mul using this Context flagged the specified
+ // `advisory`. This is reset by each ruy::Mul call.
+ bool performance_advisory(PerformanceAdvisory advisory) const;
+
+ // When using Matrix::set_cache_policy(), this Context will keep a cache of
+ // pre-packed matrix data. This function clears that cache.
+ void ClearPrepackedCache();
+
+ // Override auto-detection of supported code paths.
+ //
+ // Passing `paths == Path::kNone` means reverting to the default behavior.
+ // This will trigger auto-detection on the next use.
+ //
+ // Other values will override auto-detection with the explicitly provided set
+ // of paths.
+ //
+ // Paths in kNonArchPaths are always implicitly supported.
+ void set_runtime_enabled_paths(Path paths);
+
+ private:
+ CtxImpl* const impl_;
+
+ const Ctx& ctx() const;
+ Ctx* mutable_ctx();
+
+ friend const Ctx* get_ctx(const Context*);
+ friend Ctx* get_ctx(Context*);
+
+ // Disallow copy
+ Context(const Context&) = delete;
+};
+
+} // end namespace ruy
+
+#endif // RUY_RUY_CONTEXT_H_
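
A minimal configuration sketch for this class; the thread count is an arbitrary example value, and a subsequent ruy::Mul call taking this Context would pick these settings up:

#include "ruy/context.h"
#include "ruy/tune.h"

void ConfigureRuyContext(ruy::Context* context) {
  // Opt in to multi-threading; ruy never auto-detects hardware concurrency.
  context->set_max_num_threads(4);
  // Leave CPU core tuning to auto-detection (the default); set a specific
  // Tuning value here only to override that auto-detection.
  context->set_explicit_tuning(ruy::Tuning::kAuto);
  // Drop any cached pre-packed matrices held on behalf of
  // Matrix::set_cache_policy() users.
  context->ClearPrepackedCache();
}
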
diff --git a/ruy/context_get_ctx.cc b/ruy/context_get_ctx.cc
new file mode 100644
index 0000000..dce5858
--- /dev/null
+++ b/ruy/context_get_ctx.cc
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/context_get_ctx.h"
+
+#include "ruy/ctx_impl.h"
+
+namespace ruy {
+
+const Ctx* get_ctx(const Context* context) {
+ return static_cast<const Ctx*>(context->impl_);
+}
+Ctx* get_ctx(Context* context) { return static_cast<Ctx*>(context->impl_); }
+
+} // namespace ruy
diff --git a/ruy/context_get_ctx.h b/ruy/context_get_ctx.h
new file mode 100644
index 0000000..cc599fe
--- /dev/null
+++ b/ruy/context_get_ctx.h
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Gateway to access the Ctx (internal context interface for ruy code) from
+// a Context (public-facing class). Befriended by Context.
+
+#ifndef THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_
+#define THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_
+
+#include "ruy/context.h"
+#include "ruy/ctx.h"
+
+namespace ruy {
+
+const Ctx* get_ctx(const Context* context);
+Ctx* get_ctx(Context*);
+
+} // namespace ruy
+
+#endif // THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_
diff --git a/ruy/context_test.cc b/ruy/context_test.cc
new file mode 100644
index 0000000..4e69e65
--- /dev/null
+++ b/ruy/context_test.cc
@@ -0,0 +1,45 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/context.h"
+
+#include "ruy/gtest_wrapper.h"
+#include "ruy/path.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+namespace {
+
+TEST(ContextTest, ContextClassSanity) {
+ Context context;
+ EXPECT_EQ(context.last_used_path(), Path::kNone);
+ EXPECT_EQ(context.explicit_tuning(), Tuning::kAuto);
+ EXPECT_EQ(&context.thread_pool(), context.mutable_thread_pool());
+ EXPECT_NE(context.mutable_thread_pool(), nullptr);
+ EXPECT_EQ(context.max_num_threads(), 1);
+ context.set_explicit_tuning(Tuning::kGeneric);
+ context.set_max_num_threads(2);
+ EXPECT_EQ(context.explicit_tuning(), Tuning::kGeneric);
+ EXPECT_EQ(context.max_num_threads(), 2);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/cpu_cache_params.h b/ruy/cpu_cache_params.h
new file mode 100644
index 0000000..202cd54
--- /dev/null
+++ b/ruy/cpu_cache_params.h
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_CPU_CACHE_PARAMS_H_
+#define RUY_RUY_CPU_CACHE_PARAMS_H_
+
+namespace ruy {
+
+// Holds some information about a CPU's data caches.
+//
+// Meaning of 'local': a 'local' cache means a cache that is used by only one
+// CPU core, not shared with other cores. It might still be used by multiple
+// 'processors' in case of SMT as in Intel HyperThreading. CPUs often have
+// multiple levels of local cache, e.g. L1 and L2. We typically return the
+// larger one, the assumption being that even the larger one has substantially
+// lower latency than any higher (non-local) cache; however, as noted below (*)
+// the implementation may choose to ignore a cache level.
+//
+// Meaning of 'last level': this refers to some higher cache level, typically
+// shared among multiple CPU cores, so we considered using the terminology
+// 'shared' instead of 'last_level'. However that created some confusion of its
+// own, as the meaning of 'shared' varies between CPUs, with some CPUs not
+// having any level of cache shared among all cores. That is why we stick with
+// the 'last_level' terminology, however with the following caveats:
+// 1. As noted below (*) the implementation may choose to ignore a cache
+// level, which could cause the 'last level' cache according to ruy not to be
+// the actual last level.
+// 2. On some systems-on-chip there is a 'last level' cache outside of the
+// last level cache in the CPU complex. Ruy is not currently doing anything
+// specific regarding such caches.
+// 3. We haven't figured out how to amend our terminology to be meaningful
+// on NUMA architectures. NUMA hasn't been part of ruy's scope so far.
+//
+// (*) Note on ignoring certain cache levels:
+// The implementation may choose to ignore a cache if it's suspected not to
+// have compelling performance. This is true about all cache levels, but more
+// likely regarding the 'last level' cache. For example, a L4 cache may be
+// ignored if we believe that it's not the right latency/size compromise for us,
+// so on such a CPU, the L3 cache may be used as the 'last level' cache instead.
+//
+// (**) Note on CPUs with heterogeneous cores:
+// Some CPUs have multiple cores with different local caches. For example, some
+// ARM big.LITTLE CPUs have some CPU cores with L1=32k and L2=128k, and some
+// other CPU cores with L1=64k and L2=256k or even 512k. On such CPUs, the
+// fields in this struct refer to the minimum value over all cores. In other
+// words, we use conservative values that do not risk over-estimating local
+// cache sizes in case of a migration of our threads to smaller cores.
+//
+// Example:
+// On a Qualcomm S855 SoC, there are 8 CPU cores. Each core has L1 and L2 data
+// caches local to it:
+// - 4 cores have L1=32k, L2=128k.
+// - 3 cores have L1=64k, L2=256k.
+// - 1 core has L1=64k, L2=512k.
+// All 8 cores share a L3 cache of size 2M, and there is beyond that a SoC-level
+// cache of size 3M.
+// On such a system, we should have:
+// - local_cache_size=128k, the smallest L2 size.
+// - last_level_cache_size=2M, the L3 cache size, ignoring the SoC-level cache.
+struct CpuCacheParams final {
+ // Minimum value (see (**)), over all cores, of the size in bytes of its local
+ // cache (see "Meaning of 'local'").
+ int local_cache_size = 0;
+ // Minimum value (see (**)), over all cores, of the size in bytes of its last
+ // level cache (see "Meaning of 'last level'").
+ int last_level_cache_size = 0;
+};
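+
+// As a hedged illustration, the example system described above would be
+// described by values along these lines:
+//
+//   CpuCacheParams params;
+//   params.local_cache_size = 128 * 1024;            // smallest local L2
+//   params.last_level_cache_size = 2 * 1024 * 1024;  // shared L3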
+
+} // namespace ruy
+
+#endif // RUY_RUY_CPU_CACHE_PARAMS_H_
diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc
new file mode 100644
index 0000000..b1f54bc
--- /dev/null
+++ b/ruy/cpuinfo.cc
@@ -0,0 +1,163 @@
+#include "ruy/cpuinfo.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+
+#include "ruy/check_macros.h"
+#include "ruy/cpu_cache_params.h"
+#include "ruy/platform.h"
+
+#ifdef RUY_HAVE_CPUINFO
+#include <cpuinfo.h>
+#endif
+
+namespace ruy {
+
+namespace {
+void MakeDummyCacheParams(CpuCacheParams* result) {
+ // Reasonable dummy values
+ result->local_cache_size = 32 * 1024;
+ result->last_level_cache_size = 512 * 1024;
+}
+} // namespace
+
+#ifdef RUY_HAVE_CPUINFO
+
+CpuInfo::~CpuInfo() {
+ if (init_status_ == InitStatus::kInitialized) {
+ cpuinfo_deinitialize();
+ }
+}
+
+bool CpuInfo::EnsureInitialized() {
+ if (init_status_ == InitStatus::kNotYetAttempted) {
+ init_status_ = Initialize();
+ RUY_DCHECK_NE(init_status_, InitStatus::kNotYetAttempted);
+ }
+ return init_status_ == InitStatus::kInitialized;
+}
+
+namespace {
+void QueryCacheParams(CpuCacheParams* cache_params) {
+ const int processors_count = cpuinfo_get_processors_count();
+ RUY_DCHECK_GT(processors_count, 0);
+ int overall_local_cache_size = std::numeric_limits<int>::max();
+ int overall_last_level_cache_size = std::numeric_limits<int>::max();
+ for (int i = 0; i < processors_count; i++) {
+ int local_cache_size = 0;
+ int last_level_cache_size = 0;
+ const cpuinfo_processor* processor = cpuinfo_get_processor(i);
+ // Loop over cache levels. Ignoring L4 for now: it seems that in CPUs that
+ // have L4, we would still prefer to stay in lower-latency L3.
+ for (const cpuinfo_cache* cache :
+ {processor->cache.l1d, processor->cache.l2, processor->cache.l3}) {
+ if (!cache) {
+        continue; // continue, not break: it is possible to have L1+L3 but no
+ // L2.
+ }
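+      // A cache level is 'local' when the first and last processors sharing
+      // it belong to the same core, i.e. it is not shared across cores
+      // (though it may be shared by SMT siblings of that core).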
+ const bool is_local =
+ cpuinfo_get_processor(cache->processor_start)->core ==
+ cpuinfo_get_processor(cache->processor_start +
+ cache->processor_count - 1)
+ ->core;
+ if (is_local) {
+ local_cache_size = cache->size;
+ }
+ last_level_cache_size = cache->size;
+ }
+ // If no local cache was found, use the last-level cache.
+ if (!local_cache_size) {
+ local_cache_size = last_level_cache_size;
+ }
+ RUY_DCHECK_GT(local_cache_size, 0);
+ RUY_DCHECK_GT(last_level_cache_size, 0);
+ RUY_DCHECK_GE(last_level_cache_size, local_cache_size);
+ overall_local_cache_size =
+ std::min(overall_local_cache_size, local_cache_size);
+ overall_last_level_cache_size =
+ std::min(overall_last_level_cache_size, last_level_cache_size);
+ }
+ cache_params->local_cache_size = overall_local_cache_size;
+ cache_params->last_level_cache_size = overall_last_level_cache_size;
+}
+} // namespace
+
+CpuInfo::InitStatus CpuInfo::Initialize() {
+ RUY_DCHECK_EQ(init_status_, InitStatus::kNotYetAttempted);
+ if (!cpuinfo_initialize()) {
+ MakeDummyCacheParams(&cache_params_);
+ return InitStatus::kFailed;
+ }
+ QueryCacheParams(&cache_params_);
+ return InitStatus::kInitialized;
+}
+
+bool CpuInfo::NeonDotprod() {
+ return EnsureInitialized() && cpuinfo_has_arm_neon_dot();
+}
+
+bool CpuInfo::Sse42() {
+ return EnsureInitialized() && cpuinfo_has_x86_sse4_2();
+}
+
+bool CpuInfo::Avx2Fma() {
+ return EnsureInitialized() && cpuinfo_has_x86_avx2() &&
+ cpuinfo_has_x86_fma3();
+}
+
+bool CpuInfo::Avx() { return EnsureInitialized() && cpuinfo_has_x86_avx(); }
+
+bool CpuInfo::Avx512() {
+ return EnsureInitialized() && cpuinfo_has_x86_avx512f() &&
+ cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() &&
+ cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512vl();
+}
+
+bool CpuInfo::AvxVnni() {
+ return EnsureInitialized() && cpuinfo_has_x86_avx512vnni();
+}
+
+bool CpuInfo::CurrentCpuIsA55ish() {
+ if (!EnsureInitialized()) {
+ return false;
+ }
+
+ switch (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch) {
+ case cpuinfo_uarch_cortex_a53:
+ case cpuinfo_uarch_cortex_a55r0:
+ case cpuinfo_uarch_cortex_a55:
+ return true;
+ default:
+ return false;
+ }
+}
+
+#else // not defined RUY_HAVE_CPUINFO
+
+CpuInfo::~CpuInfo() {}
+bool CpuInfo::EnsureInitialized() {
+ if (init_status_ == InitStatus::kNotYetAttempted) {
+ MakeDummyCacheParams(&cache_params_);
+ init_status_ = InitStatus::kInitialized;
+ }
+ RUY_DCHECK_EQ(init_status_, InitStatus::kInitialized);
+ return true;
+}
+bool CpuInfo::NeonDotprod() { return false; }
+bool CpuInfo::Sse42() { return false; }
+bool CpuInfo::Avx() { return false; }
+bool CpuInfo::Avx2Fma() { return false; }
+bool CpuInfo::Avx512() { return false; }
+bool CpuInfo::AvxVnni() { return false; }
+bool CpuInfo::CurrentCpuIsA55ish() { return false; }
+
+#endif
+
+const CpuCacheParams& CpuInfo::CacheParams() {
+ EnsureInitialized();
+ // On failure, EnsureInitialized leaves dummy values in cache_params_.
+ return cache_params_;
+}
+
+} // namespace ruy
diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h
new file mode 100644
index 0000000..e45fa51
--- /dev/null
+++ b/ruy/cpuinfo.h
@@ -0,0 +1,61 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_CPUINFO_H_
+#define RUY_RUY_CPUINFO_H_
+
+#include "ruy/cpu_cache_params.h"
+
+namespace ruy {
+
+// Wraps the functionality that ruy needs from the cpuinfo library.
+class CpuInfo final {
+ public:
+ CpuInfo() {}
+ ~CpuInfo();
+
+ // ARM features
+ bool NeonDotprod();
+
+ // X86 features
+ bool Sse42();
+ bool Avx();
+ bool Avx2Fma();
+ bool Avx512();
+ bool AvxVnni();
+
+ // Common features
+ const CpuCacheParams& CacheParams();
+ bool CurrentCpuIsA55ish();
+
+ private:
+ enum class InitStatus {
+ kNotYetAttempted,
+ kInitialized,
+ kFailed,
+ };
+
+ InitStatus init_status_ = InitStatus::kNotYetAttempted;
+ CpuCacheParams cache_params_;
+
+ bool EnsureInitialized();
+ InitStatus Initialize();
+
+ CpuInfo(const CpuInfo&) = delete;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_CPUINFO_H_
diff --git a/ruy/create_trmul_params.h b/ruy/create_trmul_params.h
new file mode 100644
index 0000000..531e066
--- /dev/null
+++ b/ruy/create_trmul_params.h
@@ -0,0 +1,484 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implementation of CreateTrMulParams, see function comment.
+
+#ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_
+#define RUY_RUY_CREATE_TRMUL_PARAMS_H_
+
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+#include "ruy/allocator.h"
+#include "ruy/ctx.h"
+#include "ruy/kernel.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/trace.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+// While the only entry point to this file is CreateTrMulParams, its templatized
+// nature requires putting more code in this header than we would like. This
+// internal implementation code is enclosed in namespace 'detail'.
+namespace detail {
+
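+// Worked illustration for CreatePackedLayout below (hypothetical sizes): a
+// 50x70 source matrix with an 8x4 kernel layout gets a packed layout of
+// 56 rows by 72 columns, each dimension rounded up to a multiple of the
+// corresponding kernel block dimension, and a stride of 56.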
+inline void CreatePackedLayout(const MatLayout& src,
+ const KernelLayout& kernel_layout,
+ PMatLayout* packed_layout) {
+ // Packed matrices are always column-major, because in TrMul that is always
+ // the dimension of traversal of the kernel's inner loop.
+ packed_layout->order = Order::kColMajor;
+ packed_layout->rows = round_up_pot(src.rows, kernel_layout.rows);
+ packed_layout->cols = round_up_pot(src.cols, kernel_layout.cols);
+ packed_layout->stride = packed_layout->rows;
+ packed_layout->kernel = kernel_layout;
+}
+
+template <typename Scalar, typename PackedScalar>
+void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
+ TrMulParams* params) {
+ // Ruy always uses 32-bit signed accumulators for quantized
+ // matrix multiplication, so we would like to always use std::int32_t
+ // unconditionally for SumsType.
+ // However, for floating point types, we still need a reasonable type here to
+ // avoid tripping assertions elsewhere in the code.
+ using SumsType =
+ typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
+ std::int32_t>::type;
+
+ const EMat& src = params->src[side];
+ PEMat* packed_matrix = &params->packed_matrix[side];
+ packed_matrix->data_type = Type::Create<PackedScalar>();
+ packed_matrix->sums_type = Type::Create<SumsType>();
+ CreatePackedLayout(src.layout, kernel_layout, &packed_matrix->layout);
+ packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point);
+}
+
+template <typename KernelType>
+struct CheckKernelPathImpl {
+ static void Run(Path) {
+ // Do nothing.
+ // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL).
+ // That is to say that one may instantiate ruy::Mul with a weird combination
+ // of types, such as LhsScalar==float and RhsScalar==double, and have it
+ // work by silently falling back to Path::kStandardCpp. Only in specific
+ // cases do we have dedicated kernels overriding that fallback, and that is
+ // what partial specializations of this template will check.
+ }
+};
+
+#if RUY_DCHECK_IS_ENABLED
+template <Path ThePath, typename SrcScalar, typename AccumScalar,
+ typename DstScalar>
+struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
+ MulParams<AccumScalar, DstScalar>>>
+ final {
+ using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
+ MulParams<AccumScalar, DstScalar>>;
+ static void Run(Path expected_path) {
+ // We want to assert that we are using a dedicated Kernel specialization and
+ // not a fallback when we know we are in a case where such a kernel
+ // specialization exists. At the moment in the current state of ruy's
+ // architecture support for ARM and x86, that is when LhsScalar==RhsScalar
+ // (already implied in this partial specialization) and when that type is
+ // either float, int8, or uint8. Indeed, we have kernels supporting float
+ // and int8, and we have the packing code converting uint8 to int8 (see
+ // PackedTypeImpl).
+ static constexpr bool kSrcScalarTypeSupportsFastKernels =
+ std::is_same<SrcScalar, float>::value ||
+ std::is_same<SrcScalar, std::int8_t>::value ||
+ std::is_same<SrcScalar, std::uint8_t>::value;
+ if (kSrcScalarTypeSupportsFastKernels) {
+ RUY_DCHECK_EQ(expected_path, KernelType::kPath);
+ }
+ }
+};
+#endif
+
+template <typename KernelType>
+void CheckKernelPath(Path expected_path) {
+ CheckKernelPathImpl<KernelType>::Run(expected_path);
+}
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void PopulateTrMulParams(TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
+ using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
+ using Kernel =
+ Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>;
+ using LhsKernelLayout = typename Kernel::LhsLayout;
+ using RhsKernelLayout = typename Kernel::RhsLayout;
+
+ params->path = ThePath;
+
+ CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
+ Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
+ CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
+ Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
+ params->run_pack[Side::kLhs] =
+ &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
+ params->run_pack[Side::kRhs] =
+ &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
+ params->run_kernel = &RunKernel<Kernel>::Run;
+ CheckKernelPath<Kernel>(ThePath);
+ RUY_TRACE_INFO(POPULATE_TRMUL_PARAMS);
+}
+
+// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
+// instantiations of PopulateTrMulParams. For each bit that is set in
+// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
+// corresponding to that single bit. The call to PopulateTrMulParams is
+// guarded by a runtime check that it is in fact the dynamically selected path.
+//
+// PopulateTrMulParamsAllCompiledPaths is implemented with template
+// metaprogramming by mutual recursion between PathSearchCountdown and
+// PathSearchCompiledPaths.
+//
+// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
+// computation:
+//
+//   template <Path CompiledPaths>
+//   void PopulateTrMulParamsAllCompiledPaths(Path the_path,
+//                                            TrMulParams* params) {
+//     for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) {  // [1]
+//       Path current_path = static_cast<Path>(1 << bit);
+//       if ((CompiledPaths & current_path) != Path::kNone) {  // [2]
+//         if (current_path == the_path) {  // [3]
+//           PopulateTrMulParams<current_path, ...>(params);
+//           return;
+//         }
+//       }
+//     }
+//   }
+//
+// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
+// done in the recursion of PathSearchOnlyCompiledPaths.
+// [2] - Done by PathSearchOnlyCompiledPaths's partial template
+// specialization on InCompiledPaths. This is the check which necessitates
+// doing the whole computation at C++ compile time.
+// [3] - Done by the `if` in the main definition of
+// PathSearchOnlyCompiledPaths.
+//
+// The template metaprogramming is necessary because:
+// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
+// compile-time constant.
+// - PopulateTrMulParamsAllCompiledPaths must not instantiate
+// inner loops for paths that are not in CompiledPaths, since that can result in
+// bogus instantiations which cause a compile time failure.
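+//
+// As a hedged illustration: with CompiledPaths == (Path::kNeon |
+// Path::kNeonDotprod), only the bits corresponding to those two paths
+// instantiate the branch that compares against `the_path` and may call
+// PopulateTrMulParams; every other bit selects the partial specialization
+// below with InCompiledPaths==false, which merely recurses to the next
+// lower bit.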
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename AccumScalar, typename DstScalar>
+struct PathSearchCountdown;
+
+template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
+ typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+struct PathSearchOnlyCompiledPaths {
+ static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
+ static void Search(Path the_path, TrMulParams* params) {
+ if (kCurrentPath == the_path) {
+ PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar,
+ DstScalar>(params);
+ return;
+ }
+ PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
+ AccumScalar, DstScalar>::Search(the_path, params);
+ }
+};
+
+// Skip this iteration if CompiledPaths doesn't contain the specified path.
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename AccumScalar, typename DstScalar>
+struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar,
+ RhsScalar, AccumScalar, DstScalar> {
+ static void Search(Path the_path, TrMulParams* params) {
+ PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
+ AccumScalar, DstScalar>::Search(the_path, params);
+ }
+};
+
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename AccumScalar, typename DstScalar>
+struct PathSearchCountdown {
+ static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
+ static void Search(Path the_path, TrMulParams* params) {
+ PathSearchOnlyCompiledPaths<
+ CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber,
+ LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params);
+ }
+};
+
+// Termination of the countdown. If the counter reaches -1, then we haven't
+// found the specified path.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar,
+ DstScalar> {
+ static void Search(Path, TrMulParams*) { RUY_DCHECK(false); }
+};
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar,
+ RhsScalar, AccumScalar,
+ DstScalar>::Search(the_path, params);
+}
+
+template <typename AccumScalar, typename DstScalar>
+void AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
+ const MulParams<AccumScalar, DstScalar>& mul_params, int user_size,
+ int user_capacity) {
+#if RUY_DCHECK_IS_ENABLED
+ if (mul_params.bias()) {
+ for (int i = user_size; i < user_capacity; i++) {
+ RUY_DCHECK_EQ(mul_params.bias()[i], 0);
+ }
+ }
+ if (mul_params.multiplier_fixedpoint_perchannel()) {
+ for (int i = user_size; i < user_capacity; i++) {
+ RUY_DCHECK_EQ(mul_params.multiplier_fixedpoint_perchannel()[i], 0);
+ }
+ }
+ if (mul_params.multiplier_exponent_perchannel()) {
+ for (int i = user_size; i < user_capacity; i++) {
+ RUY_DCHECK_EQ(mul_params.multiplier_exponent_perchannel()[i], 0);
+ }
+ }
+#else
+ (void)mul_params;
+ (void)user_size;
+ (void)user_capacity;
+#endif
+}
+
+template <typename AccumScalar, typename DstScalar,
+ bool HaveQuantizedMultipliers =
+ std::is_same<AccumScalar, std::int32_t>::value &&
+ !std::is_same<DstScalar, std::int32_t>::value>
+struct EnsurePerChannelBuffersLargeEnoughImpl {
+ static void Run(const TrMulParams& params, Allocator* allocator,
+ MulParams<AccumScalar, DstScalar>* mul_params) {
+ const Side channel_side =
+ mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
+ : Side::kRhs;
+ const int required_capacity =
+ params.packed_matrix[channel_side].layout.cols;
+ const int user_size = params.src[channel_side].layout.cols;
+ const int user_capacity = round_up_pot(
+ user_size, mul_params->perchannel_buffers_capacity_rounding());
+ // We should have already checked earlier for the case where
+ // user_capacity >= required_capacity.
+ RUY_DCHECK_GT(required_capacity, user_capacity);
+ if (mul_params->bias()) {
+ AccumScalar* new_data =
+ allocator->Allocate<AccumScalar>(required_capacity);
+ std::memcpy(new_data, mul_params->bias(),
+ user_size * sizeof(AccumScalar));
+ std::memset(new_data + user_size, 0,
+ (required_capacity - user_size) * sizeof(AccumScalar));
+ mul_params->set_bias(new_data);
+ }
+ if (mul_params->multiplier_fixedpoint_perchannel()) {
+ AccumScalar* new_data =
+ allocator->Allocate<AccumScalar>(required_capacity);
+ std::memcpy(new_data, mul_params->multiplier_fixedpoint_perchannel(),
+ user_size * sizeof(AccumScalar));
+ std::memset(new_data + user_size, 0,
+ (required_capacity - user_size) * sizeof(AccumScalar));
+ mul_params->set_multiplier_fixedpoint_perchannel(new_data);
+ }
+ if (mul_params->multiplier_exponent_perchannel()) {
+ int* new_data = allocator->Allocate<int>(required_capacity);
+ std::memcpy(new_data, mul_params->multiplier_exponent_perchannel(),
+ user_size * sizeof(int));
+ std::memset(new_data + user_size, 0,
+ (required_capacity - user_size) * sizeof(int));
+ mul_params->set_multiplier_exponent_perchannel(new_data);
+ }
+ }
+};
+
+template <typename AccumScalar, typename DstScalar>
+struct EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false> {
+ static void Run(const TrMulParams& params, Allocator* allocator,
+ MulParams<AccumScalar, DstScalar>* mul_params) {
+ const Side channel_side =
+ mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
+ : Side::kRhs;
+ const int required_capacity =
+ params.packed_matrix[channel_side].layout.cols;
+ const int user_size = params.src[channel_side].layout.cols;
+ const int user_capacity = round_up_pot(
+ user_size, mul_params->perchannel_buffers_capacity_rounding());
+ // We should have already checked earlier for the case where
+ // user_capacity >= required_capacity.
+ RUY_DCHECK_GT(required_capacity, user_capacity);
+ if (mul_params->bias()) {
+ AccumScalar* new_data =
+ allocator->Allocate<AccumScalar>(required_capacity);
+ std::memcpy(new_data, mul_params->bias(),
+ user_size * sizeof(AccumScalar));
+ std::memset(new_data + user_size, 0,
+ (required_capacity - user_size) * sizeof(AccumScalar));
+ mul_params->set_bias(new_data);
+ }
+ }
+};
+
+template <typename AccumScalar, typename DstScalar>
+void EnsurePerChannelBuffersLargeEnough(
+ const TrMulParams& params, Ctx* ctx,
+ MulParams<AccumScalar, DstScalar>* mul_params) {
+ // Early exit in the common case where the packed matrix size matches the
+ // number of channels (as opposed to having been rounded up to a slightly
+ // larger value).
+ const Side channel_side =
+ mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
+ : Side::kRhs;
+ const int required_capacity = params.packed_matrix[channel_side].layout.cols;
+ const int user_size = params.src[channel_side].layout.cols;
+ const int user_capacity = round_up_pot(
+ user_size, mul_params->perchannel_buffers_capacity_rounding());
+ AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
+ *mul_params, user_size, user_capacity);
+ if (required_capacity <= user_capacity) {
+ return;
+ }
+ ctx->set_performance_advisory(
+ PerformanceAdvisory::kReallocatedPerChannelBuffer);
+ EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar>::Run(
+ params, ctx->GetMainAllocator(), mul_params);
+}
+
+// Ensures that `params->mul_params_bytes` contains MulParams data that's ready
+// to be consumed by the kernel. As a first-order approximation, that is simply
+// copying the user-provided `mul_params`, however there are a few changes.
+//
+// 1. The specified `channel_dimension` value overrides the channel_dimension
+// member in `mul_params`. The reason why `channel_dimension` is being
+// special-cased among MulParams members is that we will need to transpose
+// MulParams, and that consists just in toggling channel_dimension.
+// 2. Per-channel buffers may be reallocated, see
+// EnsurePerChannelBuffersLargeEnough.
+template <typename AccumScalar, typename DstScalar>
+void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params,
+ ChannelDimension channel_dimension, Ctx* ctx,
+ TrMulParams* params) {
+ using MulParamsType = MulParams<AccumScalar, DstScalar>;
+ static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, "");
+ static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, "");
+ static_assert(std::is_trivially_copyable<MulParamsType>::value, "");
+ auto* dst_mul_params =
+ reinterpret_cast<MulParamsType*>(params->mul_params_bytes);
+ std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType));
+ dst_mul_params->set_channel_dimension(channel_dimension);
+ EnsurePerChannelBuffersLargeEnough(*params, ctx, dst_mul_params);
+}
+
+// In this function, the `channel_dimension` parameter overrides the value
+// of the channel_dimension member in the `mul_params` parameter. See the
+// FinalizeMulParams comment.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void CreateTrMulParamsAssumingColMajorDst(
+ const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const Mat<DstScalar>& dst,
+ const MulParams<AccumScalar, DstScalar>& mul_params,
+ ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ RUY_DCHECK(IsColMajor(dst.layout));
+
+ // Fill in the fields we already know.
+ params->src[Side::kLhs] = EraseType(lhs);
+ params->src[Side::kRhs] = EraseType(rhs);
+ params->dst = EraseType(dst);
+
+ // Determine which exact Path we're going to take in this Mul call.
+ // This is cheap because it's cached in `ctx`. In user scenarios this always
+ // evaluates to the same value on a given machine with given `CompiledPaths`,
+ // but could be invalidated by a call to Ctx::SetRuntimeEnabledPaths(), which
+ // might be exposed publicly in Context in the future.
+ const Path the_path = ctx->SelectPath(CompiledPaths);
+
+ RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST);
+
+ // If we ever need again to fall back to Path::kStandardCpp, this is a good
+ // place to do it -- just pass Path::kStandardCpp as both the template and
+ // runtime parameters in this function call.
+ // In the past we did that here (as version control history remembers).
+ // A typical reason why we might need to resurrect that is if we implement
+ // a new Path (i.e. port to a new ISA) and need to subdivide that work into
+ // a series of incremental changes.
+ PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
+ AccumScalar, DstScalar>(the_path, params);
+
+ // This must be done last, as it depends on the specific choice of kernel.
+ // Specifically, the EnsurePerChannelBuffersLargeEnough part of this will read
+ // the packed matrix layouts that are written to `params` by the above
+ // PopulateTrMulParams* call.
+ FinalizeMulParams(mul_params, channel_dimension, ctx, params);
+}
+
+} // namespace detail
+
+inline ChannelDimension Transpose(ChannelDimension channel_dimension) {
+ return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow
+ : ChannelDimension::kCol;
+}
+
+// CreateTrMulParams's output is a TrMulParams object that encodes
+// all of the input information required by the middle-end, that is,
+// the TrMul function.
+//
+// CreateTrMulParams performs the following tasks:
+// 1. Reduce to the case of column-major destination, by transposing the
+// whole problem as needed.
+// 2. Select the single code path to be taken, out of the set of paths
+// described by the `CompiledPaths` template parameter, based on the
+// runtime input parameter `the_path`.
+// 3. Perform type-erasure, converting templatized typed input parameters
+// to the un-typed data stored in TrMulParams.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const Mat<DstScalar>& dst,
+ const MulParams<AccumScalar, DstScalar>& mul_params,
+ Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ ChannelDimension channel_dimension = mul_params.channel_dimension();
+ if (IsColMajor(dst.layout)) {
+ detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
+ lhs, rhs, dst, mul_params, channel_dimension, ctx, params);
+ } else {
+ RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_TRANSPOSING);
+ detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
+ rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), ctx,
+ params);
+ }
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_CREATE_TRMUL_PARAMS_H_
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
new file mode 100644
index 0000000..0ef098d
--- /dev/null
+++ b/ruy/ctx.cc
@@ -0,0 +1,216 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/ctx.h"
+
+#include <cstdlib>
+#include <functional>
+#include <string>
+
+#include "ruy/check_macros.h"
+#include "ruy/cpuinfo.h"
+#include "ruy/ctx_impl.h"
+#include "ruy/have_built_path_for.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/platform.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/trace.h"
+
+namespace ruy {
+
+const CtxImpl& Ctx::impl() const { return static_cast<const CtxImpl&>(*this); }
+CtxImpl* Ctx::mutable_impl() { return static_cast<CtxImpl*>(this); }
+
+Path Ctx::last_used_path() const { return impl().last_used_path_; }
+Tuning Ctx::explicit_tuning() const { return impl().explicit_tuning_; }
+void Ctx::set_explicit_tuning(Tuning value) {
+ mutable_impl()->explicit_tuning_ = value;
+}
+const ThreadPool& Ctx::thread_pool() const { return impl().thread_pool_; }
+ThreadPool* Ctx::mutable_thread_pool() { return &mutable_impl()->thread_pool_; }
+int Ctx::max_num_threads() const { return impl().max_num_threads_; }
+void Ctx::set_max_num_threads(int value) {
+ mutable_impl()->max_num_threads_ = value;
+}
+void Ctx::clear_performance_advisories() {
+ mutable_impl()->performance_advisory_ = PerformanceAdvisory::kNone;
+}
+void Ctx::set_performance_advisory(PerformanceAdvisory advisory) {
+ mutable_impl()->performance_advisory_ =
+ mutable_impl()->performance_advisory_ | advisory;
+}
+bool Ctx::performance_advisory(PerformanceAdvisory advisory) const {
+ return (impl().performance_advisory_ & advisory) !=
+ PerformanceAdvisory::kNone;
+}
+
+void Ctx::SetRuntimeEnabledPaths(Path paths) {
+ if (paths == Path::kNone) {
+ // Revert to default behavior using runtime detection.
+ mutable_impl()->runtime_enabled_paths_ = Path::kNone;
+ } else {
+ // Explicitly set enabled paths. Ensure that non-arch are always enabled
+ // (needed for fallbacks).
+ mutable_impl()->runtime_enabled_paths_ = paths | kNonArchPaths;
+ }
+}
+
+CpuInfo* Ctx::mutable_cpuinfo() { return &mutable_impl()->cpuinfo_; }
+
+namespace {
+
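+// Returns the value of the named environment variable parsed as a hexadecimal
+// integer, or 0 if the variable is not set. This is used below to let users
+// override path detection through the RUY_PATHS variable, e.g. (hypothetical
+// value):
+//   RUY_PATHS=4 ./some_ruy_binary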
+int GetHexIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val, nullptr, 16);
+}
+
+// For each Path bit set in `paths_to_detect`, performs runtime detection and
+// sets the corresponding bit in the return value if and only if it is
+// supported. Path bits that are not set in the input
+// `paths_to_detect` value are also left not set in the return value.
+Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) {
+ // Paths in kNonArchPathsIncludingInternalVariants are always implicitly
+  // supported. Further logic below may add more bits to `result`.
+ Path result = kNonArchPathsIncludingInternalVariants;
+
+ // Conditionally sets the `path` bit in `result`, if reported as supported
+ // by the `is_supported` predicate.
+ auto maybe_add = [&](Path path, std::function<bool(void)> is_supported) {
+ if ((paths_to_detect & path) != Path::kNone) {
+ if (is_supported()) {
+ result = result | path;
+ }
+ }
+ };
+
+#if RUY_PLATFORM_ARM
+ // NEON is unconditionally available on ARM64.
+ // On ARM32 it's technically possible for it to be unavailable, but we've
+ // always chosen to just crash on such devices. We could reevaluate that,
+ // however for non-NEON devices to be actually supported, we would need to
+ // address also compiler-generated NEON code. That would mean to remove
+ // -mfpu=neon from ruy_copts and only use this flag in select NEON translation
+ // units, and implement have_built_path_for_neon, similar to the x86 SIMD
+ // paths.
+ maybe_add(Path::kNeon, []() { return true; });
+
+ // NEON dotprod requires runtime detection, however unlike the x86 SIMD paths
+ // it still does not require have_built_path_for because we unconditionally
+ // build it at the moment. That is largely because we have had to machine
+ // encode dotprod instructions, so we don't actually rely on toolchain support
+ // for them.
+ maybe_add(Path::kNeonDotprod, [=]() { return cpuinfo->NeonDotprod(); });
+#elif RUY_PLATFORM_X86
+ // x86 SIMD paths currently require both runtime detection, and detection of
+ // whether we're building the path at all.
+ maybe_add(Path::kAvx,
+ [=]() { return HaveBuiltPathForAvx() && cpuinfo->Avx(); });
+ maybe_add(Path::kAvx2Fma,
+ [=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2Fma(); });
+ maybe_add(Path::kAvx512,
+ [=]() { return HaveBuiltPathForAvx512() && cpuinfo->Avx512(); });
+#else
+ (void)maybe_add;
+ (void)cpuinfo;
+#endif
+
+ // Sanity checks
+ RUY_DCHECK_EQ(kNonArchPaths & ~result, Path::kNone);
+ RUY_DCHECK_EQ(
+ result & ~(kNonArchPathsIncludingInternalVariants | paths_to_detect),
+ Path::kNone);
+ return result;
+}
+
+} // namespace
+
+Path Ctx::GetRuntimeEnabledPaths() {
+ RUY_TRACE_SCOPE;
+ // Just a shorthand alias. Using a pointer to make it clear we're mutating
+ // this value in-place.
+ Path* paths = &mutable_impl()->runtime_enabled_paths_;
+
+ // The value Path::kNone indicates the initial state before detection has been
+ // performed.
+ if (*paths != Path::kNone) {
+ RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE);
+ return *paths;
+ }
+ // User may have set path explicitly in env var.
+ Path paths_bitfield = static_cast<Path>(GetHexIntEnvVarOrZero("RUY_PATHS"));
+ if (paths_bitfield != Path::kNone) {
+ *paths = paths_bitfield;
+ RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR);
+ return *paths;
+ }
+ // Finally, use runtime detection.
+ *paths = DetectRuntimeSupportedPaths(kAllPaths, mutable_cpuinfo());
+ RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_DETECTION);
+ return *paths;
+}
+
+Path Ctx::SelectPath(Path compiled_paths) {
+ return mutable_impl()->last_used_path_ =
+ GetMostSignificantPath(compiled_paths & GetRuntimeEnabledPaths());
+}
+
+void Ctx::EnsureThreadSpecificResources(int thread_count) {
+ auto& resources = mutable_impl()->thread_specific_resources_;
+ while (thread_count > static_cast<int>(resources.size())) {
+ resources.emplace_back(new ThreadSpecificResource);
+ }
+ RUY_DCHECK_LE(thread_count, static_cast<int>(resources.size()));
+}
+
+TuningResolver* Ctx::GetThreadSpecificTuningResolver(int thread_index) const {
+ const auto& resources = impl().thread_specific_resources_;
+ RUY_DCHECK_LT(thread_index, static_cast<int>(resources.size()));
+ return &resources[thread_index]->tuning_resolver;
+}
+
+Allocator* Ctx::GetThreadSpecificAllocator(int thread_index) const {
+ const auto& resources = impl().thread_specific_resources_;
+ RUY_DCHECK_LT(thread_index, static_cast<int>(resources.size()));
+ return &resources[thread_index]->allocator;
+}
+
+Allocator* Ctx::GetMainAllocator() {
+ if (!impl().main_allocator_) {
+ mutable_impl()->main_allocator_.reset(new Allocator);
+ }
+ return impl().main_allocator_.get();
+}
+
+PrepackedCache* Ctx::GetPrepackedCache() {
+ if (!impl().prepacked_cache_) {
+ mutable_impl()->prepacked_cache_.reset(new PrepackedCache);
+ }
+ return impl().prepacked_cache_.get();
+}
+
+Tuning Ctx::GetMainThreadTuning() {
+ EnsureThreadSpecificResources(1);
+ TuningResolver* tuning_resolver = GetThreadSpecificTuningResolver(0);
+ tuning_resolver->SetTuning(explicit_tuning());
+ return tuning_resolver->Resolve(mutable_cpuinfo());
+}
+
+void Ctx::ClearPrepackedCache() { mutable_impl()->prepacked_cache_ = nullptr; }
+
+} // namespace ruy
diff --git a/ruy/ctx.h b/ruy/ctx.h
new file mode 100644
index 0000000..df9dee2
--- /dev/null
+++ b/ruy/ctx.h
@@ -0,0 +1,91 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Ctx is the internal context interface class used by most of ruy's own code.
+// It is subclassed by CtxImpl which provides the actual data members.
+
+#ifndef RUY_RUY_CTX_H_
+#define RUY_RUY_CTX_H_
+
+#include <cstdint>
+
+namespace ruy {
+
+class CtxImpl;
+class ThreadPool;
+class Allocator;
+class TuningResolver;
+class PrepackedCache;
+class CpuInfo;
+enum class Path : std::uint8_t;
+enum class Tuning;
+enum class PerformanceAdvisory;
+
+// Ctx is the internal context class used throughout ruy code. Whereas Context
+// is exposed to users, Ctx is internal to ruy. As many of ruy's internal
+// headers, included by ruy public headers, need to use Ctx, it is important
+// that it does not include definition of all the actual data members. This is
+// solved by a variant of the 'pimpl' idiom, where instead of being implemented
+// in the usual way with a pointer member, it is implemented in a subclass,
+// CtxImpl.
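+//
+// A minimal sketch of the idiom in isolation (illustrative, not actual ruy
+// code):
+//
+//   class Iface {
+//    public:
+//     int value() const;  // interface only, no data members
+//   };
+//   class Impl : public Iface {
+//     friend class Iface;
+//     int value_ = 0;  // data lives in the subclass
+//   };
+//   inline int Iface::value() const {
+//     return static_cast<const Impl&>(*this).value_;
+//   }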
+class Ctx /* not final, subclassed by CtxImpl */ {
+ public:
+ Path last_used_path() const;
+ Tuning explicit_tuning() const;
+ void set_explicit_tuning(Tuning value);
+ const ThreadPool& thread_pool() const;
+ ThreadPool* mutable_thread_pool();
+ int max_num_threads() const;
+ void set_max_num_threads(int value);
+ CpuInfo* mutable_cpuinfo();
+ void clear_performance_advisories();
+ void set_performance_advisory(PerformanceAdvisory advisory);
+ bool performance_advisory(PerformanceAdvisory advisory) const;
+
+ // Returns the set of Path's that are available. By default, this is based on
+ // runtime detection of CPU features, as well as on which code paths were
+ // built. Detection results are stored on the context object so that
+ // subsequent calls are fast. This is overridden by SetRuntimeEnabledPaths.
+ Path GetRuntimeEnabledPaths();
+
+ // Override auto-detection of supported code paths.
+ //
+ // Passing `paths == Path::kNone` means reverting to the default behavior.
+ // This will trigger auto-detection on the next use.
+ //
+ // Other values will override auto-detection with the explicitly provided set
+ // of paths.
+ //
+ // Paths in kNonArchPaths are always implicitly supported.
+ void SetRuntimeEnabledPaths(Path paths);
+
+ Path SelectPath(Path compiled_paths);
+ void EnsureThreadSpecificResources(int thread_count);
+ TuningResolver* GetThreadSpecificTuningResolver(int thread_index) const;
+ Allocator* GetThreadSpecificAllocator(int thread_index) const;
+ Allocator* GetMainAllocator();
+ PrepackedCache* GetPrepackedCache();
+ Tuning GetMainThreadTuning();
+ void ClearPrepackedCache();
+
+ private:
+ // Downcast helpers.
+ const CtxImpl& impl() const;
+ CtxImpl* mutable_impl();
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_CTX_H_
diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h
new file mode 100644
index 0000000..0a07ef6
--- /dev/null
+++ b/ruy/ctx_impl.h
@@ -0,0 +1,84 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Internal implementation details for Ctx. Drags in the entire world. Avoid
+// #including this, use "ctx.h" instead.
+
+#ifndef RUY_RUY_CTX_IMPL_H_
+#define RUY_RUY_CTX_IMPL_H_
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/cpuinfo.h"
+#include "ruy/ctx.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/thread_pool.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+// The resources private to each Ruy thread.
+struct ThreadSpecificResource final {
+ // Each thread may be running on a different microarchitecture. For example,
+ // some threads may be on big cores, while others are on little cores. Thus,
+ // it's best for the tuning to be per-thread.
+ TuningResolver tuning_resolver;
+ // Each thread has its own local allocator.
+ Allocator allocator;
+};
+
+// CtxImpl is what actually holds all the data members in a context.
+// It is a subclass of Ctx, which provides the interface that is what most
+// of ruy's code needs.
+//
+// A key requirement is that since many ruy files, including public headers,
+// need a definition of Ctx, the "ctx.h" header defining it must minimize how
+// many other ruy internal headers it includes. That is achieved by putting data
+// members in the CtxImpl subclass, and ensuring that only a few .cc files, not
+// header files, need a definition of CtxImpl.
+class CtxImpl final : public Ctx {
+ private:
+ friend class Ctx;
+
+ // Single Path bit indicating which Path was used last.
+ Path last_used_path_ = Path::kNone;
+ PerformanceAdvisory performance_advisory_ = PerformanceAdvisory::kNone;
+ Tuning explicit_tuning_ = Tuning::kAuto;
+ ThreadPool thread_pool_;
+ int max_num_threads_ = 1;
+ // Allocator for main thread work before invoking the threadpool.
+ // Our simple Allocator does not allow reserving/allocating more blocks
+ // while it's already in committed state, so the main thread needs both
+ // this allocator, and its per-thread allocator.
+ std::unique_ptr<Allocator> main_allocator_;
+ std::unique_ptr<PrepackedCache> prepacked_cache_;
+ // Set of Paths enabled at runtime. By default, that is based on runtime
+ // detection, but may be overridden. The initial value kNone
+ // means that detection has not yet been performed.
+ Path runtime_enabled_paths_ = Path::kNone;
+ CpuInfo cpuinfo_;
+ // State for each thread in the thread pool. Entry 0 is the main thread.
+ std::vector<std::unique_ptr<ThreadSpecificResource>>
+ thread_specific_resources_;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_CTX_IMPL_H_
diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc
new file mode 100644
index 0000000..e55dcfc
--- /dev/null
+++ b/ruy/ctx_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/ctx_impl.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+TEST(ContextInternalTest, EnabledPathsGeneral) {
+ CtxImpl ctx;
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ const auto ruy_paths_repeat = ctx.GetRuntimeEnabledPaths();
+ ASSERT_EQ(ruy_paths, ruy_paths_repeat);
+ EXPECT_NE(ruy_paths, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp);
+}
+
+#if RUY_PLATFORM_X86
+TEST(ContextInternalTest, EnabledPathsX86Explicit) {
+ CtxImpl ctx;
+ ctx.SetRuntimeEnabledPaths(Path::kAvx2Fma);
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2Fma);
+}
+#endif // RUY_PLATFORM_X86
+
+#if RUY_PLATFORM_ARM
+TEST(ContextInternalTest, EnabledPathsArmExplicit) {
+ CtxImpl ctx;
+ ctx.SetRuntimeEnabledPaths(Path::kNeonDotprod);
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kNeonDotprod);
+}
+
+TEST(ContextInternalTest, EnabledPathsArmDefault) {
+ CtxImpl ctx;
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp);
+ // NEON is always assumed to be supported at the moment.
+ EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon);
+}
+#endif // RUY_PLATFORM_ARM
+
+TEST(ContextInternalTest, ThreadSpecificResources) {
+ CtxImpl ctx;
+ for (int i = 1; i <= 4; i++) {
+ ctx.EnsureThreadSpecificResources(i);
+ for (int j = 0; j < i; j++) {
+ EXPECT_NE(ctx.GetThreadSpecificAllocator(j), nullptr);
+ EXPECT_NE(ctx.GetThreadSpecificTuningResolver(j), nullptr);
+ }
+ }
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/frontend.cc b/ruy/frontend.cc
new file mode 100644
index 0000000..01ee474
--- /dev/null
+++ b/ruy/frontend.cc
@@ -0,0 +1,37 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/frontend.h"
+
+#include "ruy/allocator.h"
+#include "ruy/prepare_packed_matrices.h"
+#include "ruy/trmul.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ // Handle Matrix::cache_policy, possibly retrieving existing packed matrices
+ // or packing and caching now.
+ PreparePackedMatrices(ctx, params);
+
+ // We're done with the front-end work, now enter the middle-end.
+ TrMul(ctx, params);
+
+ ctx->GetMainAllocator()->FreeAll();
+}
+
+} // namespace ruy
diff --git a/ruy/frontend.h b/ruy/frontend.h
new file mode 100644
index 0000000..a79f590
--- /dev/null
+++ b/ruy/frontend.h
@@ -0,0 +1,99 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implementation of MulFrontEnd, the front-end part of ruy.
+// This is what the ruy::Mul entry point calls, and this ends in a call to
+// TrMul, at which point we enter the middle-end.
+// The front-end work includes parameter validation (Validate), detemplatization
+// and resolution of the specific code path to take (CreateTrMulParams), and
+// any additional logic best done upfront before entering the middle-end
+// (e.g. PreparePackedMatrices).
+// The call to CreateTrMulParams is an important watershed in this code's
+// structure: code before it needs to be templatized like the ruy::Mul entry
+// point, code after it is un-templatized.
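+//
+// Rough call flow (as wired up in this file and frontend.cc):
+//   ruy::Mul -> MulFrontEnd -> Validate + CreateTrMulParams
+//            -> MulFrontEndFromTrMulParams -> PreparePackedMatrices -> TrMul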
+
+#ifndef RUY_RUY_FRONTEND_H_
+#define RUY_RUY_FRONTEND_H_
+
+#include "ruy/create_trmul_params.h"
+#include "ruy/ctx.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/trace.h"
+#include "ruy/trmul_params.h"
+#include "ruy/validate.h"
+
+namespace ruy {
+
+// The first half of front-end work, up to the point where we have TrMulParams.
+// In other words, this is the part of the front-end work that needs to be
+// templatized like the entry point, and that performs the initial work that
+// requires this templatization, and the de-templatization. The output of this
+// function is the TrMulParams, which contain enough information to allow the
+// un-templatized code to take over from there.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void MulFrontEndUpToCreateTrMulParams(
+ const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const Mat<DstScalar>& dst,
+ const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx,
+ TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path");
+ static_assert(
+ (CompiledPaths & ~kAllPathsIncludingInternalVariants) == Path::kNone,
+ "CompiledPaths must be a subset of "
+ "ruy::kAllPathsIncludingInternalVariants");
+
+ // Perform validation of parameters early so that failures are easier to map
+ // to user errors. In particular, perform this validation before the
+ // transposition.
+ Validate(lhs, rhs, dst);
+
+ // De-templatize this Mul call by creating a TrMulParams structure.
+ // This is also where the specific kernel and pack code paths corresponding to
+ // `the_path` are selected, among all the code paths in `CompiledPaths`, and
+ // recorded as function pointers in the TrMulParams.
+ // The Transpose(lhs) here is where we switch from 'Mul' to 'TrMul'.
+ CreateTrMulParams<CompiledPaths>(Transpose(lhs), rhs, dst, mul_params, ctx,
+ params);
+}
+
+// The second part of the front-end work, starting from where we have freshly
+// created TrMulParams, performing any remaining front-end work and entering the
+// middle-end.
+void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params);
+
+// Top-level function orchestrating the two halves of front-end work:
+// before and after we have detemplatized the call by creating TrMulParams.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void MulFrontEnd(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx,
+ Mat<DstScalar>* dst) {
+ RUY_TRACE_SCOPE;
+ profiler::ScopeLabel mul_label("Mul");
+ profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d",
+ lhs.layout.rows, lhs.layout.cols,
+ rhs.layout.cols);
+ ctx->clear_performance_advisories();
+ TrMulParams params;
+ MulFrontEndUpToCreateTrMulParams<CompiledPaths>(lhs, rhs, *dst, mul_params,
+ ctx, &params);
+ MulFrontEndFromTrMulParams(ctx, &params);
+}
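+
+// Illustrative sketch of how an entry point wires into this front-end. This is
+// not part of this header; the exact ruy::Mul signature and the ToInternal /
+// get_ctx helpers shown here are assumptions for illustration only, see
+// ruy/ruy.h for the real entry point.
+//
+// template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+// typename AccumScalar, typename DstScalar>
+// void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+// const MulParams<AccumScalar, DstScalar>& mul_params,
+// Context* context, Matrix<DstScalar>* dst) {
+// Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+// Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+// Mat<DstScalar> internal_dst = ToInternal(*dst);
+// MulFrontEnd<CompiledPaths>(internal_lhs, internal_rhs, mul_params,
+// get_ctx(context), &internal_dst);
+// }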
+
+} // namespace ruy
+
+#endif // RUY_RUY_FRONTEND_H_
diff --git a/ruy/gtest_wrapper.h b/ruy/gtest_wrapper.h
new file mode 100644
index 0000000..690cea4
--- /dev/null
+++ b/ruy/gtest_wrapper.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Wrapper around GTest that works around warnings and inconsistencies.
+
+#ifndef THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_
+#define THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#include "gtest/gtest.h" // IWYU pragma: export
+#pragma GCC diagnostic pop
+
+// When building for WebAssembly (Wasm), ASSERT_DEATH is not defined.
+#ifdef ASSERT_DEATH
+#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE)
+#else
+#define RUY_ASSERT_DEATH(CONDITION, MESSAGE)
+#endif
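+
+// Example use in a test (a sketch; TestedFunction is a hypothetical name):
+//
+// TEST(SomeDeathTest, DiesOnBadInput) {
+// RUY_ASSERT_DEATH(TestedFunction(nullptr), "expected error message");
+// }
+//
+// Where ASSERT_DEATH is unavailable, the macro expands to nothing, so the
+// statement inside it is simply not executed.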
+
+#endif // THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_
diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h
new file mode 100644
index 0000000..23cb028
--- /dev/null
+++ b/ruy/have_built_path_for.h
@@ -0,0 +1,31 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_HAVE_BUILT_PATH_FOR_H_
+#define RUY_RUY_HAVE_BUILT_PATH_FOR_H_
+
+#include "ruy/platform.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+bool HaveBuiltPathForAvx();
+bool HaveBuiltPathForAvx2Fma();
+bool HaveBuiltPathForAvx512();
+#endif // RUY_PLATFORM_X86
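+
+// Illustrative use (a sketch, not an API requirement): code that wants to
+// exercise a specific x86 SIMD path can first check whether this build
+// compiled it at all, e.g.
+//
+// #if RUY_PLATFORM_X86
+// if (!ruy::HaveBuiltPathForAvx512()) {
+// // Skip an AVX-512-specific test, or fall back to another path.
+// }
+// #endif
+//
+// Note that these functions report what was compiled into this binary, not
+// what the CPU supports at runtime; runtime capability detection is a separate
+// concern (see ruy/cpuinfo.h).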
+
+} // namespace ruy
+
+#endif // RUY_RUY_HAVE_BUILT_PATH_FOR_H_
diff --git a/ruy/have_built_path_for_avx.cc b/ruy/have_built_path_for_avx.cc
new file mode 100644
index 0000000..948c7a5
--- /dev/null
+++ b/ruy/have_built_path_for_avx.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
+
+bool HaveBuiltPathForAvx() { return false; }
+
+#else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+
+bool HaveBuiltPathForAvx() { return true; }
+
+#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+#endif // RUY_PLATFORM_X86
+
+} // namespace ruy
diff --git a/ruy/have_built_path_for_avx2_fma.cc b/ruy/have_built_path_for_avx2_fma.cc
new file mode 100644
index 0000000..03e8f8d
--- /dev/null
+++ b/ruy/have_built_path_for_avx2_fma.cc
@@ -0,0 +1,35 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
+
+bool HaveBuiltPathForAvx2Fma() { return false; }
+
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+
+bool HaveBuiltPathForAvx2Fma() { return true; }
+
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+#endif // RUY_PLATFORM_X86
+
+} // namespace ruy
diff --git a/ruy/have_built_path_for_avx512.cc b/ruy/have_built_path_for_avx512.cc
new file mode 100644
index 0000000..29d56a8
--- /dev/null
+++ b/ruy/have_built_path_for_avx512.cc
@@ -0,0 +1,35 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
+
+bool HaveBuiltPathForAvx512() { return false; }
+
+#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+
+bool HaveBuiltPathForAvx512() { return true; }
+
+#endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+#endif // RUY_PLATFORM_X86
+
+} // namespace ruy
diff --git a/ruy/kernel.h b/ruy/kernel.h
new file mode 100644
index 0000000..6bfeb4a
--- /dev/null
+++ b/ruy/kernel.h
@@ -0,0 +1,245 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_KERNEL_H_
+#define RUY_RUY_KERNEL_H_
+
+#include "ruy/kernel_common.h"
+#include "ruy/mul_params.h"
+#include "ruy/platform.h"
+#include "ruy/trace.h"
+
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM_NEON
+#include "ruy/kernel_arm.h"
+#elif RUY_PLATFORM_X86
+#include "ruy/kernel_x86.h"
+#endif
+// IWYU pragma: end_exports
+
+namespace ruy {
+
+// KernelArgs is a helper to access the template parameter values from a Kernel
+// template instantiation.
+template <typename KernelType>
+struct KernelArgs {};
+
+template <Path tPath, typename tLhsScalar, typename tRhsScalar,
+ typename tAccumScalar, typename tDstScalar>
+struct KernelArgs<
+ Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> {
+ static constexpr Path kPath = tPath;
+ using LhsScalar = tLhsScalar;
+ using RhsScalar = tRhsScalar;
+ using AccumScalar = tAccumScalar;
+ using DstScalar = tDstScalar;
+};
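+
+// For example (a sketch): given a concrete KernelType, its path and scalar
+// types can be recovered without spelling them out again:
+//
+// using K = Kernel<Path::kStandardCpp, float, float, float, float>;
+// static_assert(KernelArgs<K>::kPath == Path::kStandardCpp, "");
+// static_assert(std::is_same<KernelArgs<K>::DstScalar, float>::value, "");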
+
+// RunKernel::Run() is the only place that directly invokes Kernel::Run().
+// It performs the type un-erasure, and funneling all Kernel::Run() calls
+// through this function also gives a single place in which to conditionally
+// implement RUY_OPT(FAT_KERNEL). This would naturally be a free function, but
+// it is a class so that it can hide and share some boilerplate (see the member
+// types, and the RunTyped method that uses them).
+template <typename KernelType>
+class RunKernel final {
+ public:
+ static void Run(Tuning tuning, const SidePair<PEMat>& src,
+ const void* mul_params, const SidePair<int>& start,
+ const SidePair<int>& end, EMat* dst) {
+ RUY_TRACE_SCOPE_NAME("RunKernel");
+ const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
+ const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
+ auto unerased_dst = UneraseType<DstScalar>(*dst);
+ RUY_TRACE_INFO(RUN_KERNEL);
+ RunTyped(tuning, unerased_lhs, unerased_rhs,
+ *static_cast<const MulParamsType*>(mul_params), start, end,
+ &unerased_dst);
+ }
+
+ private:
+ using Args = KernelArgs<KernelType>;
+ using LhsScalar = typename Args::LhsScalar;
+ using RhsScalar = typename Args::RhsScalar;
+ using AccumScalar = typename Args::AccumScalar;
+ using DstScalar = typename Args::DstScalar;
+ using MulParamsType = MulParams<AccumScalar, DstScalar>;
+ static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
+ const PMat<RhsScalar>& rhs,
+ const MulParamsType& mul_params,
+ const SidePair<int>& start, const SidePair<int>& end,
+ Mat<DstScalar>* dst) {
+ const int start_row = start[Side::kLhs];
+ const int start_col = start[Side::kRhs];
+ const int end_row = end[Side::kLhs];
+ const int end_col = end[Side::kRhs];
+ KernelType kernel(tuning);
+ using LhsLayout = typename KernelType::LhsLayout;
+ using RhsLayout = typename KernelType::RhsLayout;
+ // This is a good place to validate kernel layouts. The Kernel class
+ // template itself isn't a good place to do that because it has
+ // specializations.
+ // The kRows of both sides have to match: in TrMul, kRows is the depth
+ // dimension, on which LHS and RHS have to agree for the matrix
+ // multiplication to be defined at all, so requiring the corresponding
+ // dimension of the kernel layouts to also match is reasonable. If it didn't
+ // match, then the packed matrices could have mismatching depth dimensions
+ // even with the source matrices agreeing.
+ static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
+ // The kernel layout dimensions have to be powers of two. This simplifies
+ // BlockMap logic considerably. It also avoids leaking fine performance
+ // optimization details up the stack. For instance, if one of the dimensions
+ // were 6, then users might notice that optimal performance is achieved with
+ // matrix dimensions that are multiples of 6, and might start contorting
+ // their own application code to match that requirement, in a way that would
+ // not be future-proof.
+ static_assert(is_pot(LhsLayout::kRows), "");
+ static_assert(is_pot(LhsLayout::kCols), "");
+ static_assert(is_pot(RhsLayout::kRows), "");
+ static_assert(is_pot(RhsLayout::kCols), "");
+ // end_row and end_col may be larger than the dst dimensions.
+ // That is because kernels write directly to the destination matrix, whose
+ // dimensions may not be a multiple of the kernel dimensions, and we try to
+ // keep this annoyance localized as an implementation detail in kernels,
+ // by allowing rounded-up values to be passed down as far as possible.
+ // These assertions encode the contract.
+ RUY_DCHECK_LE(0, start_row);
+ RUY_DCHECK_LE(start_row, end_row);
+ RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
+ RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
+ RUY_DCHECK_LE(0, start_col);
+ RUY_DCHECK_LE(start_col, end_col);
+ RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
+ RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
+#if RUY_OPT(FAT_KERNEL)
+ kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst);
+#else
+ for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
+ int block_end_col = std::min(col + RhsLayout::kCols, end_col);
+ for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
+ int block_end_row = std::min(row + LhsLayout::kCols, end_row);
+ kernel.Run(lhs, rhs, mul_params, row, col, block_end_row, block_end_col,
+ dst);
+ }
+ }
+#endif
+ }
+};
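+
+// Sketch of how RunKernel is consumed (the run_kernel field name below is an
+// assumption for illustration; see create_trmul_params.h and trmul_params.h
+// for the real wiring): CreateTrMulParams records a type-erased function
+// pointer such as
+//
+// params->run_kernel =
+// &RunKernel<Kernel<ThePath, LhsScalar, RhsScalar,
+// AccumScalar, DstScalar>>::Run;
+//
+// and the middle-end (TrMul) later invokes it on type-erased packed matrices,
+// without ever naming the Kernel type itself.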
+
+template <Path ThePath>
+struct StandardCppKernelLayout {};
+
+template <>
+struct StandardCppKernelLayout<Path::kStandardCpp> {
+ using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
+ using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
+};
+
+// A variant exercising RowMajor square blocks
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
+ using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
+ using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
+};
+
+// A variant with a rectangular layout: 4x8
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
+ using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
+ using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
+};
+
+// A variant with different block orders in LHS vs RHS.
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
+ using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
+ using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
+};
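+
+// In the FixedKernelLayout<Order, kRows, kCols> parameters above, kRows is the
+// kernel block's extent along the depth dimension (which is why it has to
+// match between Lhs and Rhs, see the static_asserts in RunKernel::RunTyped
+// above), and kCols is the number of destination rows (for Lhs) or destination
+// columns (for Rhs) covered per block; e.g. the "rectangular" variant above
+// yields 4x8 destination blocks.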
+
+// General implementation of the Kernel template, overridden by template
+// specializations for specific SIMD code paths. This general implementation
+// covers Path::kStandardCpp and its internal test-only variants.
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+struct Kernel {
+ // Each Kernel specialization defines kPath as the ground-truth path that it
+ // implements. This is used in assertions. As we support fallbacks between
+ // paths (see RUY_INHERIT_KERNEL), unless a specialization was defined for a
+ // specific set of template parameters, it is normal for template
+ // instantiations of the form Kernel<SomePath, ...> to have kPath!=SomePath.
+ // Assertions that kPath==SomePath are used in places where we know that we
+ // should be using a template specialization for a specific path rather than a
+ // fallback.
+ static constexpr Path kPath = ThePath;
+ using MulParamsType = MulParams<AccumScalar, DstScalar>;
+ using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
+ using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
+ explicit Kernel(Tuning) {}
+ void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int start_row, int start_col,
+ int end_row, int end_col, Mat<DstScalar>* dst) const {
+ // See the comment in RunKernel::RunTyped. end_row may be larger than
+ // dst->layout.rows. It's the responsibility of the kernel to avoid
+ // overrunning dst boundaries, which we do here by computing
+ // clamped_end_row.
+ int clamped_end_row = std::min(end_row, dst->layout.rows);
+ int clamped_end_col = std::min(end_col, dst->layout.cols);
+ RUY_DCHECK_LE(0, start_row);
+ RUY_DCHECK_LE(start_row, clamped_end_row);
+ RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
+ RUY_DCHECK_LE(clamped_end_row, end_row);
+ RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
+ RUY_DCHECK_LE(0, start_col);
+ RUY_DCHECK_LE(start_col, clamped_end_col);
+ RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
+ RUY_DCHECK_LE(clamped_end_col, end_col);
+ RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
+ profiler::ScopeLabel label("Kernel (Standard Cpp)");
+ const int depth = lhs.layout.rows;
+ for (int i = start_row; i < clamped_end_row; i++) {
+ for (int j = start_col; j < clamped_end_col; j++) {
+ AccumScalar accum = 0;
+ for (int k = 0; k < depth; k++) {
+ AccumScalar lhs_val = Element(lhs, k, i);
+ AccumScalar rhs_val = Element(rhs, k, j);
+ accum += lhs_val * rhs_val;
+ }
+ int channel =
+ mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
+ if (mul_params.bias()) {
+ accum += mul_params.bias()[channel];
+ }
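+ // Zero-point corrections (skipped when the zero points are 0, e.g. in the
+ // float case). The accumulator above was computed on the raw stored values;
+ // the intended product is
+ // sum_k (lhs(k,i) - lhs.zero_point) * (rhs(k,j) - rhs.zero_point),
+ // which expands to the raw sum, minus lhs.zero_point times the sum over
+ // depth of the rhs column, minus rhs.zero_point times the sum over depth of
+ // the lhs column, plus lhs.zero_point * rhs.zero_point * depth. The three
+ // corrections below apply exactly those terms, using the precomputed sums
+ // stored in the packed matrices (lhs.sums, rhs.sums).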
+ if (lhs.zero_point) {
+ accum -= lhs.zero_point * rhs.sums[j];
+ }
+ if (rhs.zero_point) {
+ accum -= rhs.zero_point * lhs.sums[i];
+ }
+ if (lhs.zero_point && rhs.zero_point) {
+ accum += lhs.zero_point * rhs.zero_point * depth;
+ }
+ ApplyMultiplier(mul_params, channel, &accum);
+ accum += dst->zero_point;
+ accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
+ accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
+ *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+ }
+ }
+ }
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_KERNEL_H_
diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h
new file mode 100644
index 0000000..76cfc82
--- /dev/null
+++ b/ruy/kernel_arm.h
@@ -0,0 +1,212 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_KERNEL_ARM_H_
+#define RUY_RUY_KERNEL_ARM_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "ruy/asm_helpers.h"
+#include "ruy/kernel_common.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_NEON && RUY_OPT(ASM)
+
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
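+
+// (RUY_INHERIT_KERNEL, defined in kernel_common.h, implements the fallback
+// between paths described in kernel.h: for template parameter combinations
+// where the more specific path has no Kernel specialization of its own, it
+// falls back to the kernel of the more generic path named here.)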
+
+#if RUY_PLATFORM_NEON_64
+void Kernel8bitNeon(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params);
+#elif RUY_PLATFORM_NEON_32
+void Kernel8bitNeon(const KernelParams8bit<4, 2>& params);
+void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params);
+#endif
+void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
+
+#if RUY_PLATFORM_NEON_64
+template <typename DstScalar>
+struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+ static constexpr Path kPath = Path::kNeon;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ Tuning tuning = Tuning::kAuto;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitNeon1Col(params);
+ return;
+ }
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ Kernel8bitNeonA55ish(params);
+ } else {
+ Kernel8bitNeon(params);
+ }
+ }
+};
+#endif
+
+#if RUY_PLATFORM_NEON_32
+template <typename DstScalar>
+struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+ static constexpr Path kPath = Path::kNeon;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
+ Tuning tuning = Tuning::kAuto;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitNeon1Col(params);
+ return;
+ }
+ Kernel8bitNeon(params);
+ }
+};
+#endif
+
+#if RUY_PLATFORM_NEON_64
+template <typename DstScalar>
+struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+ static constexpr Path kPath = Path::kNeonDotprod;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitNeonDotprod1Col(params);
+ } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ Kernel8bitNeonDotprodA55ish(params);
+ } else {
+ Kernel8bitNeonDotprod(params);
+ }
+ }
+};
+#endif
+
+void KernelFloatNeon(const KernelParamsFloat<8, 8>& params);
+void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params);
+void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params);
+void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params);
+
+#if RUY_PLATFORM_NEON_64
+// A Float kernel for ARM64 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, float> {
+ static constexpr Path kPath = Path::kNeon;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ KernelFloatNeonA55ish(params);
+ } else {
+ KernelFloatNeon(params);
+ }
+ }
+};
+#endif
+
+#if RUY_PLATFORM_NEON_32
+// A Float kernel for ARM32 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, float> {
+ static constexpr Path kPath = Path::kNeon;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<8, 4> params;
+
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+
+ KernelFloat32Neon(params);
+ }
+};
+#endif
+
+// While the dotprod NEON extension does not concern floating-point arithmetic,
+// its presence allows us to distinguish, in the in-order tuning case, between
+// A53 and A55r1. TODO: should this be folded into tuning?
+template <>
+struct Kernel<Path::kNeonDotprod, float, float, float, float> {
+ static constexpr Path kPath = Path::kNeonDotprod;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using Base = Kernel<Path::kNeon, float, float, float, float>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ KernelFloatNeonDotprodA55ish(params);
+ } else {
+ KernelFloatNeon(params);
+ }
+ }
+};
+
+#endif // RUY_PLATFORM_NEON && RUY_OPT(ASM)
+
+} // namespace ruy
+
+#endif // RUY_RUY_KERNEL_ARM_H_
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
new file mode 100644
index 0000000..b20f668
--- /dev/null
+++ b/ruy/kernel_arm32.cc
@@ -0,0 +1,2515 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/kernel_arm.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+#define RUY_ASM_LABEL_STORE_UINT8 91
+#define RUY_ASM_LABEL_STORE_INT8 92
+#define RUY_ASM_LABEL_STORE_INT16 93
+#define RUY_ASM_LABEL_STORE_INT32 94
+#define RUY_ASM_LABEL_AFTER_STORE 99
+
+#define RUY_OFFSET_LHS_BASE_PTR 0
+#define RUY_OFFSET_RHS_BASE_PTR 4
+#define RUY_OFFSET_DST_BASE_PTR 8
+#define RUY_OFFSET_BIAS 12
+#define RUY_OFFSET_START_ROW 16
+#define RUY_OFFSET_START_COL 20
+#define RUY_OFFSET_LAST_ROW 24
+#define RUY_OFFSET_LAST_COL 28
+#define RUY_OFFSET_DST_ROWS 32
+#define RUY_OFFSET_DST_COLS 36
+#define RUY_OFFSET_LHS_STRIDE 40
+#define RUY_OFFSET_RHS_STRIDE 44
+#define RUY_OFFSET_DST_STRIDE 48
+#define RUY_OFFSET_DEPTH 52
+#define RUY_OFFSET_CLAMP_MIN 56
+#define RUY_OFFSET_CLAMP_MAX 60
+#define RUY_OFFSET_FLAGS 64
+
+#define RUY_STACK_OFFSET_SIZE 96
+#define RUY_STACK_OFFSET_DST_COL_PTR 0
+#define RUY_STACK_OFFSET_DST_PTR 16
+#define RUY_STACK_OFFSET_ROW 32
+#define RUY_STACK_OFFSET_COL 48
+#define RUY_STACK_OFFSET_LHS_COL_PTR 64
+#define RUY_STACK_OFFSET_RHS_COL_PTR 80
+
+template <typename Params>
+void CheckOffsetsInKernelParamsFloat32(const Params&) {
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
+ static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+}
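+
+// The RUY_OFFSET_* values above are hard-coded byte offsets that the inline
+// assembly below uses to read fields out of the params struct, e.g.
+// "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]". The
+// static_asserts in CheckOffsetsInKernelParamsFloat32 keep these constants in
+// sync with the actual struct layout, so that a layout change breaks the build
+// instead of silently misreading fields.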
+
+// Float kernel for ARM32 out-of-order cores.
+// Just like the ARM64 float version, except it accumulates into an 8x4 block
+// so as to use only the 16 128-bit NEON registers available on ARM32. This is
+// a "first pass" kernel, not yet tuned. It is meant to run on out-of-order
+// CPUs like the Krait 400 or Cortex-A9.
+void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) {
+ CheckOffsetsInKernelParamsFloat32(params);
+ profiler::ScopeLabel label("Kernel (kNeon)");
+
+ const float* lhs_ptr = params.lhs_base_ptr;
+ const float* rhs_ptr = params.rhs_base_ptr;
+ // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
+ // each composed of two 64-bit "d" registers. The asm kernel below has the
+ // following NEON register allocation:
+ // Registers q3 -- q10 are accumulators. During accumulation,
+ // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
+ // are used to load an 8x1 block of LHS, and q2 is used to load a 1x4 block
+ // of RHS, like this:
+
+ // Register layout in "q" registers:
+ // RHS 1x4 block
+ // /--------------------------|
+ // |q2.s[0] ... q2.s[3] |
+ // \--------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /--------------------------|
+ // | q0.s[0] | | q3.s[0] ... q9.s[0] |
+ // | ... | | ... ... |
+ // | q0.s[3] | | q3.s[3] q9.s[3] |
+ // | q1.s[0] | | q4.s[0] q10.s[0] |
+ // | ... | | ... ... ... |
+ // | q1.s[3] | | q4.s[3] .. q10.s[3] |
+ // \---------------------/ \--------------------------/
+ // accumulators 8x4 block
+ // q11, q14, q15 currently unused. q12 and q13 are used to load
+ // parameters used for the post-accumulation part of the kernel.
+ // For completeness, here is the register layout in "d" registers:
+ // RHS 1x4 block
+ // /--------------------------|
+ // |d4[0] ... d5[1] |
+ // \--------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /--------------------------|
+ // | d0[0] | | d6[0] ... d18[0] |
+ // | ... | | ... ... |
+ // | d1[1] | | d7[1] d19[1] |
+ // | d2[0] | | d8[0] d20[0] |
+ // | ... | | ... ... ... |
+ // | d3[1] | | d9[1] ... d21[1] |
+ // \---------------------/ \--------------------------/
+ // accumulators 8x4 block
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
+
+ // clang-format off
+
+ // Load the first 32 bytes of LHS and RHS data.
+ // Load q0, q1
+ "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ // Load q2
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q3)
+ RUY_MAKE_ZERO(q4)
+ RUY_MAKE_ZERO(q5)
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov r1, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Accumulation loop
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r2\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ "vmla.f32 q3, q0, d4[0]\n"
+ "vmla.f32 q5, q0, d4[1]\n"
+ "vmla.f32 q7, q0, d5[0]\n"
+ "vmla.f32 q9, q0, d5[1]\n"
+ "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ "vmla.f32 q4, q1, d4[0]\n"
+ "vmla.f32 q6, q1, d4[1]\n"
+ "vmla.f32 q8, q1, d5[0]\n"
+ "vmla.f32 q10, q1, d5[1]\n"
+ "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ "add r1, r1, #1\n"
+ "cmp r1, r2\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "vmla.f32 q3, q0, d4[0]\n"
+ "vmla.f32 q5, q0, d4[1]\n"
+ "vmla.f32 q7, q0, d5[0]\n"
+ "vmla.f32 q9, q0, d5[1]\n"
+
+ "vmla.f32 q4, q1, d4[0]\n"
+ "vmla.f32 q6, q1, d4[1]\n"
+ "vmla.f32 q8, q1, d5[0]\n"
+ "vmla.f32 q10, q1, d5[1]\n"
+
+ // End of accumulation. The registers q3 -- q10 contain the final
+ // float32 accumulator values of the current 8x4 destination block.
+ // We now have to compute the final values from these accumulators
+ // and advance to the next 8x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #3\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #2\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Let r8 be stack offset of the row or column variable, whichever
+ // is the channel index.
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "ite eq\n"
+ "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
+ "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+ // Let r8 be the channel index.
+ "ldr r8, [sp, r8]\n"
+ // Compute the bias pointer, by conditionally using the channel index
+ // (r8) as offset into bias buffer (r1).
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "addne r1, r1, r8, lsl #2\n"
+
+ // Load 4 bias values. When the channel dimension is rows, we will load
+ // another 4 bias values just before performing the bias addition below,
+ // as this kernel has an 8x4 rectangular shape.
+ "vld1.32 {d24, d25}, [r1]!\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
+ // in the rest of the work on the current block.
+ // Load q0, q1
+ "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ // Load q2
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Perform the bias-addition.
+ // Jump based on channel dimension.
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows.
+ // Load the remaining 4 bias values, since we're on the width-8 side
+ // of this 8x4 kernel.
+ "vld1.32 {d26, d27}, [r1]\n"
+ "vadd.f32 q3, q3, q12\n"
+ "vadd.f32 q5, q5, q12\n"
+ "vadd.f32 q7, q7, q12\n"
+ "vadd.f32 q9, q9, q12\n"
+ "vadd.f32 q4, q4, q13\n"
+ "vadd.f32 q6, q6, q13\n"
+ "vadd.f32 q8, q8, q13\n"
+ "vadd.f32 q10, q10, q13\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns.
+ "vdup.32 q11, d24[0]\n"
+ "vdup.32 q13, d24[1]\n"
+ "vdup.32 q14, d25[0]\n"
+ "vdup.32 q15, d25[1]\n"
+ "vadd.f32 q3, q3, q11\n"
+ "vadd.f32 q4, q4, q11\n"
+ "vadd.f32 q5, q5, q13\n"
+ "vadd.f32 q6, q6, q13\n"
+ "vadd.f32 q7, q7, q14\n"
+ "vadd.f32 q8, q8, q14\n"
+ "vadd.f32 q9, q9, q15\n"
+ "vadd.f32 q10, q10, q15\n"
+ "7:\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.32 q12, r2\n" // clamp_min
+ "vdup.32 q13, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.f32 q3, q3, q12\n"
+ "vmax.f32 q4, q4, q12\n"
+ "vmax.f32 q5, q5, q12\n"
+ "vmax.f32 q6, q6, q12\n"
+ "vmax.f32 q7, q7, q12\n"
+ "vmax.f32 q8, q8, q12\n"
+ "vmax.f32 q9, q9, q12\n"
+ "vmax.f32 q10, q10, q12\n"
+
+ // Apply the clamp_max bound
+ "vmin.f32 q3, q3, q13\n"
+ "vmin.f32 q4, q4, q13\n"
+ "vmin.f32 q5, q5, q13\n"
+ "vmin.f32 q6, q6, q13\n"
+ "vmin.f32 q7, q7, q13\n"
+ "vmin.f32 q8, q8, q13\n"
+ "vmin.f32 q9, q9, q13\n"
+ "vmin.f32 q10, q10, q13\n"
+
+ // Compute how much of the 8x4 block of destination values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x4, there are some 8x4 blocks along the boundaries that do
+ // not fit entirely.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #8\n"
+ "mov r5, #4\n"
+ "cmp r1, #8\n"
+ // Compute r1 = how many rows of the 8x4 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+ "cmp r2, #4\n"
+ // Compute r2 = how many cols of the 8x4 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ // Yes, all of the 8x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x4 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x4 block fits.
+ // Set (r3 address, r4 stride) to write directly to destination matrix.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (r3 address, r4 stride)
+ "vst1.32 {d6, d7, d8, d9}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q3)
+ RUY_MAKE_ZERO(q4)
+ "vst1.32 {d10, d11, d12, d13}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q5)
+ RUY_MAKE_ZERO(q6)
+ "vst1.32 {d14, d15, d16, d17}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ "vst1.32 {d18, d19, d20, d21}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+
+ // If all of the 8x4 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x4 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #32\n"
+ "add r4, r4, r8\n"
+ // r2 = how many cols of the 8x4 block fit
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #32\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used r3, r5, r10 for a few other things
+ // since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #8\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
+ "add r1, r1, r8, lsl #2\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov r1, #1\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13");
+}
+
+#undef RUY_MAKE_ZERO
+#undef RUY_STACK_OFFSET_SIZE
+#undef RUY_STACK_OFFSET_DST_COL_PTR
+#undef RUY_STACK_OFFSET_DST_PTR
+#undef RUY_STACK_OFFSET_ROW
+#undef RUY_STACK_OFFSET_COL
+#undef RUY_STACK_OFFSET_LHS_COL_PTR
+#undef RUY_STACK_OFFSET_RHS_COL_PTR
+
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+
+#define RUY_OFFSET_BIAS 0
+#define RUY_OFFSET_LHS_SUMS 4
+#define RUY_OFFSET_RHS_SUMS 8
+#define RUY_OFFSET_LHS_BASE_PTR 12
+#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
+#define RUY_OFFSET_MULTIPLIER_EXPONENT 20
+#define RUY_OFFSET_RHS_BASE_PTR 24
+#define RUY_OFFSET_DST_BASE_PTR 28
+#define RUY_OFFSET_LHS_ZERO_POINT 32
+#define RUY_OFFSET_RHS_ZERO_POINT 36
+#define RUY_OFFSET_DST_ZERO_POINT 40
+#define RUY_OFFSET_PROD_ZP_DEPTH 44
+#define RUY_OFFSET_START_ROW 48
+#define RUY_OFFSET_START_COL 52
+#define RUY_OFFSET_LAST_ROW 56
+#define RUY_OFFSET_LAST_COL 60
+#define RUY_OFFSET_DST_ROWS 64
+#define RUY_OFFSET_DST_COLS 68
+#define RUY_OFFSET_LHS_STRIDE 72
+#define RUY_OFFSET_RHS_STRIDE 76
+#define RUY_OFFSET_DST_STRIDE 80
+#define RUY_OFFSET_DEPTH 84
+#define RUY_OFFSET_CLAMP_MIN 88
+#define RUY_OFFSET_CLAMP_MAX 92
+#define RUY_OFFSET_FLAGS 96
+#define RUY_OFFSET_DST_TYPE_ID 97
+
+#define RUY_STACK_OFFSET_SIZE 96
+#define RUY_STACK_OFFSET_DST_COL_PTR 0
+#define RUY_STACK_OFFSET_DST_PTR 16
+#define RUY_STACK_OFFSET_ROW 32
+#define RUY_STACK_OFFSET_COL 48
+#define RUY_STACK_OFFSET_LHS_COL_PTR 64
+#define RUY_STACK_OFFSET_RHS_COL_PTR 80
+
+template <typename Params>
+void CheckOffsetsInKernelParams8bit(const Params&) {
+ static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
+ "");
+ static_assert(offsetof(Params, multiplier_fixedpoint) ==
+ RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
+ "");
+ static_assert(
+ offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
+ "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
+ static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+}
+
+// Fast int8 kernel, ported from the ARM64 version.
+// Relevant target CPUs for this kernel include the Krait 400 and Cortex-A9,
+// since these are 32-bit, out-of-order CPUs.
+void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // q6 - q13 are 128-bit (4x32b) accumulators.
+ // During accumulation, d0 -- d7 are used to load int8 data from LHS and
+ // d8 -- d11 from RHS:
+ // int8 RHS 16x2 block
+ // /-----------------------------|
+ // |d8.b[0-7] ..... d10.b[0-7]|
+ // | ... ... |
+ // |d9.b[0-7] ..... d11.b[0-7]|
+ // \-----------------------------/
+ // int8 LHS 4x16 block
+ // /------------------------\ /-----------------------------|
+ // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
+ // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
+ // (Reload d0, d1, d2, d3)
+ // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
+ // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
+ // \------------------------/ \-----------------------------/
+ // 128-bit accumulators 4x2 block
+ //
+ // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
+
+ // clang-format off
+
+ // Load the first 64 bytes of LHS and RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ RUY_MAKE_ZERO(q11)
+ "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(q12)
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(q13)
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(q14)
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ RUY_MAKE_ZERO(q15)
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // r1 is how many levels of depth we have already loaded
+ // data for, r10 is the total depth.
+ "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r10\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q6, q7, q10, q11
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+ "vpadal.s16 q10, q2\n"
+ "vpadal.s16 q11, q3\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q8, q9, q12, q13
+ "vpadal.s16 q8, q14\n"
+ "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
+ "vpadal.s16 q9, q15\n"
+ "vpadal.s16 q12, q2\n"
+ "vpadal.s16 q13, q3\n"
+
+ // Prefetch the next 64 bytes of LHS and RHS data.
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add r1, r1, #16\n"
+
+ // Loop termination condition
+ "cmp r1, r10\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q6, q7, q10, q11
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+ "vpadal.s16 q10, q2\n"
+ "vpadal.s16 q11, q3\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+
+ // Then pairwise accumulate in to q8, q9, q12, q13
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+ "vpadal.s16 q12, q2\n"
+ "vpadal.s16 q13, q3\n"
+
+
+ // All accumulation over depth done. q6 - q13 contain the 4x32b
+ // accumulators for the 4x2 final matrix.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x2 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // q6-q13 now contain 4 x 32b
+ "vpadd.i32 d0, d12, d13\n"
+ "vpadd.i32 d1, d14, d15\n"
+ "vpadd.i32 d2, d16, d17\n"
+ "vpadd.i32 d3, d18, d19\n"
+ "vpadd.i32 d4, d20, d21\n"
+ "vpadd.i32 d5, d22, d23\n"
+ "vpadd.i32 d6, d24, d25\n"
+ "vpadd.i32 d7, d26, d27\n"
+
+ // d0 -- d7 each contain 2 x 32b accumulators.
+ // We need to add pairwise to get 1 x 32b for each of the 4x2 entries
+ // of the destination (four 'd' registers total).
+ "vpadd.i32 d28, d0, d1\n"
+ "vpadd.i32 d29, d2, d3\n"
+ "vpadd.i32 d30, d4, d5\n"
+ "vpadd.i32 d31, d6, d7\n"
+
+ // Now d28 -- d31 have the 1 x 32b accumulators for the 4x2 entries.
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #1\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Let r8 be stack offset of the row or column variable, whichever
+ // is the channel index.
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "ite eq\n"
+ "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
+ "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+ // Let r8 be the channel index.
+ "ldr r8, [sp, r8]\n"
+ // Compute the bias pointer, by conditionally using the channel index
+ // (r8) as offset into bias buffer (r1).
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "addne r1, r1, r8, lsl #2\n"
+
+ // Load 2 bias values. When the channel dimension is rows, we will load
+ // another 2 bias values just before performing the bias addition below,
+ // as this kernel has a 4x2 rectangular shape.
+ "vld1.32 {d24}, [r1]!\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into d0 -- d3 (LHS) and d8 -- d11 (RHS), as we
+ // don't need those registers anymore in the rest of the work on the
+ // current block.
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Add to the bias values the product
+ // (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in
+ // https://arxiv.org/pdf/1712.05877.pdf
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "vdup.32 q9, r3\n"
+ "vadd.i32 d24, d24, d18\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows.
+ // Load the remaining 2 bias values, since we're on the width-4 side
+ // of this 4x2 kernel.
+ "vld1.32 {d25}, [r1]\n"
+ "vadd.i32 d25, d25, d19\n"
+ "vadd.i32 q14, q14, q12\n"
+ "vadd.i32 q15, q15, q12\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns.
+ "vdup.32 q10, d24[0]\n"
+ "vdup.32 q11, d24[1]\n"
+ "vadd.i32 q14, q14, q10\n"
+ "vadd.i32 q15, q15, q11\n"
+ "7:\n"
+
+ // LHS/RHS zero points
+ // Has RHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ // Offset by current col * number of bytes per value
+ "add r3, r3, r4, lsl #2\n"
+ "vld1.32 { d12 }, [r3]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "vdup.32 q10, r5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vmls.i32 q14, q10, d12[0]\n"
+ "vmls.i32 q15, q10, d12[1]\n"
+ "401:\n"
+
+ // Has LHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Offset by current row * number of bytes per value
+ "add r2, r2, r4, lsl #2\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+
+ // Load 4 lhs_sums values.
+ "vld1.32 {d22, d23}, [r2]\n"
+ "vdup.32 d13, r5\n" // rhs_zero_point
+
+ // Compute lhs_sums * rhs_zero_point.
+ "vmul.i32 q11, q11, d13[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vsub.s32 q14, q14, q11\n"
+ "vsub.s32 q15, q15, q11\n"
+
+ // If the destination is int32, the user is asking for the raw
+ // accumulators; no need for us to down-quantize the values.
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8-bit values.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Compute the data pointers for the multiplier data
+ // r1 = exponent part
+ // r2 = fixedpoint part
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ // r6 has flags, r8 has channel index
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "it ne\n"
+ "addne r1, r1, r8, lsl #2\n"
+ "it ne\n"
+ "addne r2, r2, r8, lsl #2\n"
+
+ // Load the first 2 values of multiplier exponent and fixedpoint data
+ // Since this kernel is rectangular 4x2, we will only conditionally load
+ // 2 more values below.
+ "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent
+ "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint
+
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "vmvn.i32 q8, #0\n"
+ "bne 8f\n"
+ // Case where channels are rows.
+ // Load the remaining 2 multiplier exponent and fixedpoint values,
+ // since we're on the width-4 side of this 4x2 kernel.
+ "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent
+ "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint
+ "vmin.s32 q11, q10, q8\n"
+ "vsub.s32 q10, q10, q11\n"
+
+ // Apply the positive exponent part of the multiplier.
+ "vshl.s32 q14, q14, q10\n"
+ "vshl.s32 q15, q15, q10\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "vqdmulh.s32 q14, q14, q6\n"
+ "vqdmulh.s32 q15, q15, q6\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "vrshl.s32 q14, q14, q11\n"
+ "vrshl.s32 q15, q15, q11\n"
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns.
+ "vmin.s32 d22, d20, d16\n"
+ "vsub.s32 d20, d20, d22\n"
+
+ // Apply the positive exponent part of the multiplier.
+ "vdup.32 q12, d20[0]\n"
+ "vdup.32 q13, d20[1]\n"
+ "vshl.s32 q14, q14, q12\n"
+ "vshl.s32 q15, q15, q13\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "vqdmulh.s32 q14, q14, d12[0]\n"
+ "vqdmulh.s32 q15, q15, d12[1]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "vdup.32 q12, d22[0]\n"
+ "vdup.32 q13, d22[1]\n"
+ "vrshl.s32 q14, q14, q12\n"
+ "vrshl.s32 q15, q15, q13\n"
+
+ "9:\n"
+
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ // Store uint8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+ // At this point, d12 -- d27, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to uint8
+ // Now all 8 1-byte values are in d30.
+ "vqmovun.s16 d30, q14\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.u8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.u8 d30, d30, d29\n"
+
+ // Compute how much of the 4x2 block of destination 8-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x2, there are some 4x2 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #4\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.32 {d30[0]}, [r3]\n"
+ "add r4, r4, r5\n"
+ "mov r3, r4\n"
+ "vst1.32 {d30[1]}, [r3]\n"
+
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ // Store int8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+ // At this point, d12 -- d27, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to int8
+ // Now all 8 1-byte values are in d30.
+ "vqmovn.s16 d30, q14\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.s8 d30, d30, d29\n"
+
+ // Compute how much of the 4x2 block of destination 8-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x2, there are some 4x2 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #4\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.32 {d30[0]}, [r3]\n"
+ "add r4, r4, r5\n"
+ "mov r3, r4\n"
+ "vst1.32 {d30[1]}, [r3]\n"
+
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Load the destination zero point into each of the 4 32-bit slots
+ // in a q register.
+ "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.32 q13, r4\n" // dst_zero_point
+ // Add the destination zero point
+ "vadd.s32 q14, q14, q13\n"
+ "vadd.s32 q15, q15, q13\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+ // At this point, q6 -- q11 and q15 aren't used anymore for the current
+ // block, so we can start clearing these accumulators for the next
+ // block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.16 q12, r2\n" // clamp_min
+ "vdup.16 q13, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s16 q14, q14, q12\n"
+ // Apply the clamp_max bound
+ "vmin.s16 q14, q14, q13\n"
+
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x2 block of destination 16-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x2, there are some 4x2 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.16 {q14}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ // Shift of offset register for half-word loads not allowed in A32,
+ // so we shift, load/store, then shift back r8.
+ "lsl r8, r8, #1\n"
+ "ldrh r10, [r3, r8]\n"
+ "strh r10, [r4, r8]\n"
+ "lsr r8, r8, #1\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #8\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #2\n"
+
+ "vst1.16 {d28[0]}, [r3], r6\n"
+ "add r4, r4, r5\n"
+ "vst1.16 {d28[1]}, [r3], r6\n"
+ "vst1.16 {d28[2]}, [r3], r6\n"
+ "vst1.16 {d28[3]}, [r3], r6\n"
+ "mov r3, r4\n"
+ "vst1.16 {d29[0]}, [r3], r6\n"
+ "vst1.16 {d29[1]}, [r3], r6\n"
+ "vst1.16 {d29[2]}, [r3], r6\n"
+ "vst1.16 {d29[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #8\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q14)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accumulator type, there is
+ // no need for a downcast, nor for clamping by min/max.
+
+ // At this point, q6 -- q13 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x2 block of destination 32-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x2, there are some 4x2 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #16\n"
+ "b 31f\n"
+
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // r3 address, r4 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+
+ "31:\n"
+
+ "vst1.32 {d28, d29}, [r3]\n"
+ "add r3, r3, r4\n"
+ "vst1.32 {d30, d31}, [r3]\n"
+
+ // If all of the 4x2 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 4x2 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #16\n"
+ "add r4, r4, r8\n"
+ // r2 = how many cols of the 4x2 block fit
+ "cmp r6, r2\n"
+ "blt 50b\n"
+
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #16\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used r3, r5 and r6 for a few other
+ // things since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #4\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
+ "add r1, r1, r8, lsl #1\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13", "q14", "q15");
+}
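+
+// Illustrative only, and not used by the kernels in this file: a scalar C++
+// sketch of the down-quantization step that the vshl.s32 / vqdmulh.s32 /
+// vrshl.s32 sequences above appear to implement, under the usual ruy
+// convention that a positive multiplier_exponent means a left shift and a
+// negative one a rounding right shift. The helper name is hypothetical, and
+// saturation corner cases (e.g. INT32_MIN * INT32_MIN) are ignored.
+inline std::int32_t ScalarDownQuantizeSketch(std::int32_t acc,
+                                             std::int32_t multiplier_fixedpoint,
+                                             std::int32_t multiplier_exponent) {
+  // Split the exponent the same way as the asm above: a never-positive part
+  // (at most -1) and the never-negative remainder.
+  const std::int32_t neg_exponent =
+      multiplier_exponent < -1 ? multiplier_exponent : -1;
+  const std::int32_t pos_exponent = multiplier_exponent - neg_exponent;
+  // vshl.s32: apply the positive exponent part as a left shift in the lane.
+  acc = static_cast<std::int32_t>(static_cast<std::uint32_t>(acc)
+                                  << pos_exponent);
+  // vqdmulh.s32: high 32 bits of the doubled 64-bit product.
+  const std::int64_t prod =
+      static_cast<std::int64_t>(acc) * multiplier_fixedpoint;
+  acc = static_cast<std::int32_t>(prod >> 31);
+  // vrshl.s32 with a negative shift amount: rounding right shift.
+  const std::int32_t right_shift = -neg_exponent;
+  return static_cast<std::int32_t>(
+      (static_cast<std::int64_t>(acc) +
+       (std::int64_t{1} << (right_shift - 1))) >>
+      right_shift);
+}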
+
+// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
+// is still packed as if it had two columns.
+void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+
+ RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // q6 - q13 are 128-bit (4x32b) accumulators.
+ // During accumulation, d0 -- d7 are used to load int8 data from LHS and
+ // d8 -- d11 from RHS:
+ // int8 RHS 16x1 block
+ // /------------|
+ // | d8.b[0] |
+ // | ... |
+ // | d8.b[7] |
+ // | d9.b[0] |
+ // | ... |
+ // | d9.b[7] |
+ // \------------/
+ // int8 LHS 4x16 block
+ // /-----------------------------------------\ /------------|
+ // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
+ // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
+ // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
+ // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
+ // \-----------------------------------------/ \------------/
+ // 128-bit accumulators 4x1 block
+ //
+ // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
+
+ // clang-format off
+
+ // Load the first 64 bytes of LHS data and the first 16 bytes of RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // r1 is how many levels of depth we have already loaded
+ // data for, r10 is the total depth.
+ "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r10\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Mult, mult-acc in to q14, q15
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q15, d2, d8\n"
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q15, d3, d9\n"
+
+ // Then pairwise accumulate in to q6, q7
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+
+ // Mult, mult-acc in to q14, q15
+ "vmull.s8 q14, d4, d8\n"
+ "vmull.s8 q15, d6, d8\n"
+ "vmlal.s8 q14, d5, d9\n"
+ "vmlal.s8 q15, d7, d9\n"
+
+ // Then pairwise accumulate in to q8, q9
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+
+
+ // Load the next 64 bytes of LHS data and the next 16 bytes of RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add r1, r1, #16\n"
+
+ // Loop termination condition
+ "cmp r1, r10\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ // Mult, mult-acc in to q14, q15
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q15, d2, d8\n"
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q15, d3, d9\n"
+
+ // Then pairwise accumulate in to q6, q7
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+
+ // Mult, mult-acc in to q14, q15
+ "vmull.s8 q14, d4, d8\n"
+ "vmull.s8 q15, d6, d8\n"
+ "vmlal.s8 q14, d5, d9\n"
+ "vmlal.s8 q15, d7, d9\n"
+
+ // Then pairwise accumulate in to q8, q9
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+
+ // All accumulation over depth done. q6 - q9 contain the 4x32b
+ // accumulators for the 4x1 final matrix.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x1 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // q6-q9 now contain 4 x 32b
+ "vpadd.i32 d0, d12, d13\n"
+ "vpadd.i32 d1, d14, d15\n"
+ "vpadd.i32 d2, d16, d17\n"
+ "vpadd.i32 d3, d18, d19\n"
+
+ // d0 -- d3 each contain 2 x 32b accumulators.
+ // We need to add pairwise to get 1 x 32b for each of the 4x1 entries
+ // of the destination (two 'd' registers total).
+ "vpadd.i32 d28, d0, d1\n"
+ "vpadd.i32 d29, d2, d3\n"
+
+ // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #1\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "addne r1, r1, r8, lsl #2\n"
+
+ // Load 4 bias values.
+ "vld1.32 {d24, d25}, [r1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 64 bytes of
+ // LHS data into d0 -- d7 and the first 16 bytes of RHS data into
+ // d8, d9, as we don't need those registers anymore in the rest of the
+ // work on the current block.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Add to the bias values the product
+ // (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in
+ // https://arxiv.org/pdf/1712.05877.pdf
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "vdup.32 q9, r3\n"
+ "vadd.i32 q12, q12, q9\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "vadd.i32 q14, q14, q12\n"
+
+ // LHS/RHS zero points
+ // Has RHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ // Offset by current col * number of bytes per value
+ "add r3, r3, r4, lsl #2\n"
+ "vld1.32 { d12 }, [r3]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "vdup.32 q10, r5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vmls.i32 q14, q10, d12[0]\n"
+ "401:\n"
+
+ // Has LHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Offset by current row * number of bytes per value
+ "add r2, r2, r4, lsl #2\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+
+ // Load 4 lhs_sums values.
+ "vld1.32 {d22, d23}, [r2]\n"
+ "vdup.32 d13, r5\n" // rhs_zero_point
+
+ // Compute lhs_sums * rhs_zero_point.
+ "vmul.i32 q11, q11, d13[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vsub.s32 q14, q14, q11\n"
+
+ // If the destination is int32, the user is asking for the raw
+ // accumulators; no need for us to down-quantize the values.
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8-bit values.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "it ne\n"
+ "addne r1, r1, r4, lsl #2\n"
+
+ "vld1.32 {q10}, [r1]\n"
+
+ "vmvn.i32 q8, #0\n"
+ "vmin.s32 q13, q10, q8\n"
+ "vsub.s32 q12, q10, q13\n"
+
+ // Apply the positive exponent part of the multiplier.
+ "vshl.s32 q14, q14, q12\n"
+
+ // Load fixed point part of the multiplier
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ // r6 has flags, r4 has row
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "it ne\n"
+ "addne r1, r1, r4, lsl #2\n"
+ "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
+
+ // Apply the fixed-point part of the multiplier.
+ "vqdmulh.s32 q14, q14, q10\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "vrshl.s32 q14, q14, q13\n"
+
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ // Store uint8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+ // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "vqmovun.s16 d30, q14\n"
+ // At this point, we only need 4 8-bit values in the lower half
+ // of d30.
+
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.u8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.u8 d30, d30, d29\n"
+
+ // Compute how much of the 4x1 block of destination 8-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x1, there are some 4x1 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.8 {d30[0]}, [r3], r6\n"
+ "vst1.8 {d30[1]}, [r3], r6\n"
+ "vst1.8 {d30[2]}, [r3], r6\n"
+ "vst1.8 {d30[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ // Store int8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+ // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to int8
+ "vqmovn.s16 d30, q14\n"
+ // At this point, we only need 4 8-bit values in the lower half
+ // of d30.
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.s8 d30, d30, d29\n"
+
+ // Compute how much of the 4x1 block of destination 8-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x1, there are some 4x1 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.8 {d30[0]}, [r3], r6\n"
+ "vst1.8 {d30[1]}, [r3], r6\n"
+ "vst1.8 {d30[2]}, [r3], r6\n"
+ "vst1.8 {d30[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Load the destination zero point into each of the 4 32-bit slots
+ // in a q register.
+ "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.32 q13, r4\n" // dst_zero_point
+ // Add the destination zero point
+ "vadd.s32 q14, q14, q13\n"
+ //"vadd.s32 q15, q15, q13\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+ // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.16 d24, r2\n" // clamp_min
+ "vdup.16 d26, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s16 d28, d28, d24\n"
+ // Apply the clamp_max bound
+ "vmin.s16 d28, d28, d26\n"
+
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x1 block of destination 16-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x1, there are some 4x1 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.16 {d28}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ // Shift of offset register for half-word loads not allowed in A32,
+ // so we shift, load/store, then shift back r8.
+ "lsl r8, r8, #1\n"
+ "ldrh r10, [r3, r8]\n"
+ "strh r10, [r4, r8]\n"
+ "lsr r8, r8, #1\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #2\n"
+
+ "vst1.16 {d28[0]}, [r3], r6\n"
+ "vst1.16 {d28[1]}, [r3], r6\n"
+ "vst1.16 {d28[2]}, [r3], r6\n"
+ "vst1.16 {d28[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #8\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q14)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accumulator type, there is
+ // no need for a downcast, nor for clamping by min/max.
+
+ // At this point, q6 -- q13 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x1 block of destination 32-bit values we
+ // have computed fits in the destination matrix. Typically all of it
+ // fits, but when the destination matrix shape is not a multiple of
+ // 4x1, there are some 4x1 blocks along the boundaries that do not fit
+ // entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #16\n"
+ "b 31f\n"
+
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // r3 address, r4 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+
+ "31:\n"
+
+ "vst1.32 {d28, d29}, [r3]\n"
+
+ // If all of the 4x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 4x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #16\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used r3, r5 and r6 for a few other
+ // things since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #4\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by dst_stride (i.e. 1 column)
+ "add r1, r1, r8\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13", "q14", "q15");
+}
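+
+// Illustrative only, and not used by the kernels in this file: a scalar C++
+// sketch of the integer offset corrections that the kernels above apply to
+// each raw accumulator before down-quantization, following equation (7) in
+// https://arxiv.org/pdf/1712.05877.pdf. The parameter names mirror the
+// KernelParams8bit fields used above; the helper itself is hypothetical.
+inline std::int32_t ApplyOffsetCorrectionsSketch(
+    std::int32_t raw_accum,      // sum over depth of lhs_int8 * rhs_int8
+    std::int32_t bias,           // per-channel bias, or 0 if there is none
+    std::int32_t prod_zp_depth,  // depth * lhs_zero_point * rhs_zero_point
+    std::int32_t lhs_zero_point, std::int32_t rhs_zero_point,
+    std::int32_t lhs_sum,        // sum over depth of this LHS row
+    std::int32_t rhs_sum) {      // sum over depth of this RHS column
+  // Bias addition, with the NZ1Z2 term folded in (see the comments above).
+  std::int32_t acc = raw_accum + bias + prod_zp_depth;
+  // The two mixed terms of equation (7), subtracted when the corresponding
+  // sums are available (RUY_ASM_FLAG_HAS_RHS_SUMS / RUY_ASM_FLAG_HAS_LHS_SUMS).
+  acc -= lhs_zero_point * rhs_sum;
+  acc -= rhs_zero_point * lhs_sum;
+  return acc;
+}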
+
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_LHS_SUMS
+#undef RUY_OFFSET_RHS_SUMS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
+#undef RUY_OFFSET_MULTIPLIER_EXPONENT
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_LHS_ZERO_POINT
+#undef RUY_OFFSET_RHS_ZERO_POINT
+#undef RUY_OFFSET_DST_ZERO_POINT
+#undef RUY_OFFSET_PROD_ZP_DEPTH
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+#undef RUY_OFFSET_DST_TYPE_ID
+
+#undef RUY_STACK_OFFSET_SIZE
+#undef RUY_STACK_OFFSET_DST_COL_PTR
+#undef RUY_STACK_OFFSET_DST_PTR
+#undef RUY_STACK_OFFSET_ROW
+#undef RUY_STACK_OFFSET_COL
+#undef RUY_STACK_OFFSET_LHS_COL_PTR
+#undef RUY_STACK_OFFSET_RHS_COL_PTR
+
+#endif // RUY_PLATFORM_NEON_32 && (RUY_OPT(ASM)
+} // namespace ruy
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
new file mode 100644
index 0000000..fe65d9c
--- /dev/null
+++ b/ruy/kernel_arm64.cc
@@ -0,0 +1,8075 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+#include "ruy/asm_helpers.h"
+#include "ruy/check_macros.h"
+#include "ruy/kernel_arm.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+#define RUY_ASM_LABEL_STORE_UINT8 91
+#define RUY_ASM_LABEL_STORE_INT8 92
+#define RUY_ASM_LABEL_STORE_INT16 93
+#define RUY_ASM_LABEL_STORE_INT32 94
+#define RUY_ASM_LABEL_AFTER_STORE 99
+
+#define RUY_OFFSET_BIAS 0
+#define RUY_OFFSET_LHS_SUMS 8
+#define RUY_OFFSET_RHS_SUMS 16
+#define RUY_OFFSET_LHS_BASE_PTR 24
+#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
+#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
+#define RUY_OFFSET_RHS_BASE_PTR 48
+#define RUY_OFFSET_DST_BASE_PTR 56
+#define RUY_OFFSET_LHS_ZERO_POINT 64
+#define RUY_OFFSET_RHS_ZERO_POINT 68
+#define RUY_OFFSET_DST_ZERO_POINT 72
+#define RUY_OFFSET_PROD_ZP_DEPTH 76
+#define RUY_OFFSET_START_ROW 80
+#define RUY_OFFSET_START_COL 84
+#define RUY_OFFSET_LAST_ROW 88
+#define RUY_OFFSET_LAST_COL 92
+#define RUY_OFFSET_DST_ROWS 96
+#define RUY_OFFSET_DST_COLS 100
+#define RUY_OFFSET_LHS_STRIDE 104
+#define RUY_OFFSET_RHS_STRIDE 108
+#define RUY_OFFSET_DST_STRIDE 112
+#define RUY_OFFSET_DEPTH 116
+#define RUY_OFFSET_CLAMP_MIN 120
+#define RUY_OFFSET_CLAMP_MAX 124
+#define RUY_OFFSET_FLAGS 128
+
+template <typename Params>
+void CheckOffsetsInKernelParams8bit(const Params&) {
+ static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
+ "");
+ static_assert(offsetof(Params, multiplier_fixedpoint) ==
+ RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
+ "");
+ static_assert(
+ offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
+ "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
+ static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+}
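+
+// The RUY_OFFSET_* constants above are the literal byte offsets at which the
+// asm blocks below read fields of the params struct (via RUY_STR), so the
+// static_asserts in CheckOffsetsInKernelParams8bit keep them in sync with the
+// actual KernelParams8bit layout: if a field is added, removed or reordered,
+// compilation fails here instead of the asm silently reading the wrong member.
+// A minimal illustration of the pattern (hypothetical struct, not part of ruy;
+// assumes 8-byte pointers, as on all targets this file builds for):
+//
+//   struct Example {
+//     const std::int32_t* bias;      // offset 0
+//     const std::int32_t* lhs_sums;  // offset 8
+//   };
+//   static_assert(offsetof(Example, lhs_sums) == 8, "offset out of sync");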
+
+// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
+// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
+//
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon)");
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 -- v7 from RHS:
+ //
+ // int8 RHS 16x4 block
+ // /-----------------------------------------|
+ // |v4.b[0] ... v7.b[0] |
+ // | ... ... |
+ // |v4.b[15] ... v7.b[15] |
+ // \-----------------------------------------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------------------------------------|
+ // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 4x4 block
+ //
+ // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
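+ //
+ // For reference, a rough scalar sketch of the accumulation scheme used by
+ // the smull/smlal2/sadalp sequences below (illustrative only: the names
+ // are ours, lhs/rhs stand for one packed LHS row and one packed RHS column,
+ // and the actual lane ordering inside the vector registers differs, but the
+ // arithmetic is the same):
+ //
+ //   std::int32_t acc = 0;  // one lane of v16 -- v31
+ //   for (int k = 0; k < depth; k += 4) {
+ //     // smull + smlal2: two int8*int8 products summed in a 16-bit lane.
+ //     std::int16_t p01 = std::int16_t(lhs[k + 0]) * rhs[k + 0] +
+ //                        std::int16_t(lhs[k + 1]) * rhs[k + 1];
+ //     std::int16_t p23 = std::int16_t(lhs[k + 2]) * rhs[k + 2] +
+ //                        std::int16_t(lhs[k + 3]) * rhs[k + 3];
+ //     // sadalp: adjacent pairs of 16-bit lanes are widened and added
+ //     // into the 32-bit accumulator.
+ //     acc += std::int32_t(p01) + std::int32_t(p23);
+ //   }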
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 64 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "sadalp v26.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "sadalp v27.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+
+ // Loop termination condition
+ "cmp w1, w12\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "sadalp v25.4s, v9.8h\n"
+ "sadalp v26.4s, v10.8h\n"
+ "sadalp v27.4s, v11.8h\n"
+ "sadalp v28.4s, v12.8h\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "sadalp v31.4s, v15.8h\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 4x4 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+ "addp v20.4s, v20.4s, v21.4s\n"
+ "addp v22.4s, v22.4s, v23.4s\n"
+ "addp v24.4s, v24.4s, v25.4s\n"
+ "addp v26.4s, v26.4s, v27.4s\n"
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+ // (each pass adds pairwise; we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+ "addp v17.4s, v20.4s, v22.4s\n"
+ "addp v18.4s, v24.4s, v26.4s\n"
+ "addp v19.4s, v28.4s, v30.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Load the multiplier_fixedpoint values.
+ "add x5, x4, x3, lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 64 bytes of
+ // each of LHS and RHS, into v0 -- v7, as we don't need v0 -- v7 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v14.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v14.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v20.4s, v14.s[0]\n"
+ "dup v21.4s, v14.s[1]\n"
+ "dup v22.4s, v14.s[2]\n"
+ "dup v23.4s, v14.s[3]\n"
+ "add v16.4s, v16.4s, v20.4s\n"
+ "add v17.4s, v17.4s, v21.4s\n"
+ "add v18.4s, v18.4s, v22.4s\n"
+ "add v19.4s, v19.4s, v23.4s\n"
+ "7:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[1]\n"
+ "mls v18.4s, v10.4s, v14.s[2]\n"
+ "mls v19.4s, v10.4s, v14.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v11.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
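+ // Scalar sketch of the rescaling performed below (illustrative only; the
+ // names e, m and RoundingRightShift are ours). With exponent e and
+ // fixed-point multiplier m (a Q31 value), the net effect is approximately
+ // acc * (m / 2^31) * 2^e. v8 holds -1 in all lanes (from the mvni above).
+ //
+ //   std::int32_t right = std::min(e, -1);  // v11 below, always negative
+ //   std::int32_t left = e - right;         // v12 below, always >= 0
+ //   acc = acc << left;                                      // sshl
+ //   acc = std::int32_t((2 * std::int64_t(acc) * m) >> 32);  // sqdmulh
+ //   acc = RoundingRightShift(acc, -right);                  // srshl
+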
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, x3, lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smin v11.4s, v8.4s, v14.4s\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v12.4s\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "sshl v18.4s, v18.4s, v12.4s\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v18.4s, v18.4s, v15.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v11.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v11.4s\n"
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v20.4s, v12.s[0]\n"
+ "dup v21.4s, v12.s[1]\n"
+ "dup v22.4s, v12.s[2]\n"
+ "dup v23.4s, v12.s[3]\n"
+ "sshl v16.4s, v16.4s, v20.4s\n"
+ "sshl v17.4s, v17.4s, v21.4s\n"
+ "sshl v18.4s, v18.4s, v22.4s\n"
+ "sshl v19.4s, v19.4s, v23.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "sqdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "dup v20.4s, v11.s[0]\n"
+ "dup v21.4s, v11.s[1]\n"
+ "dup v22.4s, v11.s[2]\n"
+ "dup v23.4s, v11.s[3]\n"
+ "srshl v16.4s, v16.4s, v20.4s\n"
+ "srshl v17.4s, v17.4s, v21.4s\n"
+ "srshl v18.4s, v18.4s, v22.4s\n"
+ "srshl v19.4s, v19.4s, v23.4s\n"
+ "9:\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x4 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x4 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+
+ // Compute how much of the 4x4 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[5], [x3], #2\n"
+ "st1 {v16.h}[6], [x3], #2\n"
+ "st1 {v16.h}[7], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[1], [x3], #2\n"
+ "st1 {v17.h}[2], [x3], #2\n"
+ "st1 {v17.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[5], [x3], #2\n"
+ "st1 {v17.h}[6], [x3], #2\n"
+ "st1 {v17.h}[7], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // At this point, v20 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Compute how much of the 4x4 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ "str q18, [%[dst_tmp_buf], #32]\n"
+ "str q19, [%[dst_tmp_buf], #48]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.s}[1], [x3], #4\n"
+ "st1 {v17.s}[2], [x3], #4\n"
+ "st1 {v17.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v18.s}[1], [x3], #4\n"
+ "st1 {v18.s}[2], [x3], #4\n"
+ "st1 {v18.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v19.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v19.s}[1], [x3], #4\n"
+ "st1 {v19.s}[2], [x3], #4\n"
+ "st1 {v19.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Similar to existing Kernel8bitNeon but specialized for the case of
+// RHS cols == 1.
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v19 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 from RHS:
+ //
+ // int8 RHS 16x1 block
+ // /-----------|
+ // |v4.b[0] |
+ // | ... |
+ // |v4.b[15] |
+ // \-----------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------|
+ // |v0.b[0] ... v0.b[15] | |v16.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s |
+ // \---------------------/ \-----------/
+ // int32 accumulators 4x1 block
+ //
+ // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 64 bytes of LHS and the first 16 bytes of RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+ "sadalp v17.4s, v9.8h\n"
+ "sadalp v18.4s, v10.8h\n"
+ "sadalp v19.4s, v11.8h\n"
+
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+
+ // Loop termination condition
+ "cmp w1, w12\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "sadalp v17.4s, v9.8h\n"
+ "sadalp v18.4s, v10.8h\n"
+ "sadalp v19.4s, v11.8h\n"
+
+ // End of accumulation. The registers v16 -- v19 contain the final
+ // int32 accumulator values of the current 4x1 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x1 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+ // (each pass adds pairwise; we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ // (still multiply column stride by 4 due to packing)
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ "add x5, x1, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 64 bytes of
+ // LHS into v0 -- v3 and the first 16 bytes of RHS into v4, as we don't
+ // need those registers anymore in the rest of the work on the current
+ // block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // (all four 32-bit accumulators are in v16 at this point)
+ "add v16.4s, v16.4s, v14.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smin v11.4s, v8.4s, v14.4s\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v12.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this instruction, all data is in lower half (64-bits) of v16
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ // Now all data is in the first 32-bits of v16
+ "sqxtun v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x1 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in the lower half (64 bits) of v16.
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ // At this point, we only need 4 lowest 8-bit values in v16.
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x1 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this instruction, all data is in lower half of v16.
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+
+ // Compute how much of the 4x1 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 4x1 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19");
+}
+
+// Variant of the above Kernel8bitNeon, tuned for A55-ish CPUs.
+// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
+// the original Cortex-A55, since these are 64-bit and do not support dotprod.
+//
+// While this kernel does not have a direct equivalent in gemmlowp, it was
+// developed based on insights that David Mansell at ARM shared with their
+// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A53:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
+void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 -- v7 from RHS:
+ //
+ // int8 RHS 16x4 block
+ // /-----------------------------------------|
+ // |v4.b[0] ... v7.b[0] |
+ // | ... ... |
+ // |v4.b[15] ... v7.b[15] |
+ // \-----------------------------------------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------------------------------------|
+ // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 4x4 block
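+ //
+ // As an illustrative scalar sketch (not part of the kernel itself), the
+ // accumulation performed below for one 4x4 block amounts to:
+ //   for (int r = 0; r < 4; r++)
+ //     for (int c = 0; c < 4; c++)
+ //       for (int d = 0; d < depth; d++)
+ //         acc[r][c] += (int32)lhs[r][d] * (int32)rhs[d][c];
+ // with pairs of int8 products first gathered into int16 (smull/smlal2)
+ // and then widened into the int32 accumulators (sadalp).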
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(v16)
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(v17)
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ RUY_MAKE_ZERO(v18)
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ RUY_MAKE_ZERO(v19)
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v20)
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v21)
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v22)
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ RUY_MAKE_ZERO(v23)
+
+ // Load the first 64 bytes of each of the LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
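+ // (Numerically: an int8*int8 product lies in [-128*127, 128*128] =
+ // [-16256, 16384], so a pair of such products essentially fits in an
+ // int16 accumulator, whereas a single uint8*uint8 product can already
+ // be as large as 255*255 = 65025, which would not.)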
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "ldr x7, [%[rhs_ptr], #8]\n"
+ "sadalp v17.4s, v9.8h\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "ldr x8, [%[rhs_ptr], #24]\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "sadalp v20.4s, v12.8h\n"
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ // Loop termination condition
+ "cmp w1, w12\n"
+ "sadalp v21.4s, v13.8h\n"
+ "ldr x3, [%[lhs_ptr], #-56]\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "ldr x4, [%[lhs_ptr], #-40]\n"
+ "sadalp v22.4s, v14.8h\n"
+ "ldr x5, [%[lhs_ptr], #-24]\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "ldr x6, [%[lhs_ptr], #-8]\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "ldr x9, [%[rhs_ptr], #-24]\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+ "ldr d6, [%[rhs_ptr], #-32]\n"
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "ldr d0, [%[lhs_ptr], #-64]\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "ldr d1, [%[lhs_ptr], #-48]\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "ins v4.d[1], x7\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+ "ins v5.d[1], x8\n"
+
+ "ldr d2, [%[lhs_ptr], #-32]\n"
+ "ins v0.d[1], x3\n"
+ "sadalp v24.4s, v8.8h\n"
+ "ldr d3, [%[lhs_ptr], #-16]\n"
+ "ins v1.d[1], x4\n"
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "ins v2.d[1], x5\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ins v3.d[1], x6\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "ldr d7, [%[rhs_ptr], #-16]\n"
+ "sadalp v26.4s, v10.8h\n"
+ "ldr x10, [%[rhs_ptr], #-8]\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "sadalp v27.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "ins v6.d[1], x9\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "ins v7.d[1], x10\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "sadalp v26.4s, v10.8h\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "sadalp v27.4s, v11.8h\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "sadalp v28.4s, v12.8h\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "sadalp v29.4s, v13.8h\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "sadalp v30.4s, v14.8h\n"
+ "sadalp v31.4s, v15.8h\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 4x4 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+ "addp v20.4s, v20.4s, v21.4s\n"
+ "addp v22.4s, v22.4s, v23.4s\n"
+ "addp v24.4s, v24.4s, v25.4s\n"
+ "addp v26.4s, v26.4s, v27.4s\n"
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+ // (each pass adds pairwise; we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+ "addp v17.4s, v20.4s, v22.4s\n"
+ "addp v18.4s, v24.4s, v26.4s\n"
+ "addp v19.4s, v28.4s, v30.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done; however,
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Load the multiplier_fixedpoint values.
+ "add x5, x4, x3, lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading data for the next block:
+ // the low halves of the LHS into v0 -- v3 and of the RHS into v4 -- v7
+ // (the high halves follow later via x-register loads and ins), as we
+ // don't need these registers anymore in the rest of the work on the
+ // current block.
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "ldr d0, [%[lhs_ptr], #0]\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+
+ "add v16.4s, v16.4s, v14.4s\n"
+ "ldr d1, [%[lhs_ptr], #16]\n"
+ "add v17.4s, v17.4s, v14.4s\n"
+ "ldr d2, [%[lhs_ptr], #32]\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "ldr d3, [%[lhs_ptr], #48]\n"
+ "add v19.4s, v19.4s, v14.4s\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "ldr d6, [%[rhs_ptr], #32]\n"
+ "ldr d7, [%[rhs_ptr], #48]\n"
+
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v20.4s, v14.s[0]\n"
+ "ldr d1, [%[lhs_ptr], #16]\n"
+ "dup v21.4s, v14.s[1]\n"
+ "ldr d2, [%[lhs_ptr], #32]\n"
+ "dup v22.4s, v14.s[2]\n"
+ "ldr d3, [%[lhs_ptr], #48]\n"
+ "dup v23.4s, v14.s[3]\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "add v16.4s, v16.4s, v20.4s\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "add v17.4s, v17.4s, v21.4s\n"
+ "ldr d6, [%[rhs_ptr], #32]\n"
+ "add v18.4s, v18.4s, v22.4s\n"
+ "ldr d7, [%[rhs_ptr], #48]\n"
+ "add v19.4s, v19.4s, v23.4s\n"
+ "7:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[1]\n"
+ "mls v18.4s, v10.4s, v14.s[2]\n"
+ "mls v19.4s, v10.4s, v14.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v11.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user is asking for the raw
+ // accumulators; no need for us to down-quantize the values.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
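+ //
+ // Conceptually (an illustrative scalar sketch, not the exact instruction
+ // sequence used below): for an accumulator x, fixed-point multiplier m
+ // and exponent e, the result is approximately
+ //   rounding_right_shift(saturating_doubling_high_mul(x << left_shift, m),
+ //                        right_shift)
+ // where left_shift - right_shift == e; the left shift is done with sshl,
+ // the doubling high multiply with sqdmulh, and the rounding right shift
+ // with srshl by a negative amount.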
+
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, x3, lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smin v11.4s, v8.4s, v14.4s\n"
+ "ldr x1, [%[lhs_ptr], #8]\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
+
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v12.4s\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "sshl v18.4s, v18.4s, v12.4s\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+
+
+ // Apply the fixed-point part of the multiplier.
+ "ins v0.d[1], x1\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
+ "ins v1.d[1], x2\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "ins v2.d[1], x3\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+ "sqdmulh v18.4s, v18.4s, v15.4s\n"
+ "ins v3.d[1], x4\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v11.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v11.4s\n"
+
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v20.4s, v12.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "dup v21.4s, v12.s[1]\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+ "dup v22.4s, v12.s[2]\n"
+ "ins v0.d[1], x1\n"
+ "dup v23.4s, v12.s[3]\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "sshl v16.4s, v16.4s, v20.4s\n"
+ "ins v1.d[1], x2\n"
+ "sshl v17.4s, v17.4s, v21.4s\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "sshl v18.4s, v18.4s, v22.4s\n"
+ "ins v2.d[1], x3\n"
+ "sshl v19.4s, v19.4s, v23.4s\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "ins v3.d[1], x4\n"
+ "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "dup v20.4s, v11.s[0]\n"
+ "sqdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "dup v21.4s, v11.s[1]\n"
+ "srshl v16.4s, v16.4s, v20.4s\n"
+ "dup v22.4s, v11.s[2]\n"
+ "srshl v17.4s, v17.4s, v21.4s\n"
+ "dup v23.4s, v11.s[3]\n"
+ "srshl v18.4s, v18.4s, v22.4s\n"
+ "srshl v19.4s, v19.4s, v23.4s\n"
+
+ "9:\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "dup v14.8h, v13.h[4]\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v21)
+ "add v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v22)
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ RUY_MAKE_ZERO(v23)
+ "sqxtun2 v16.16b, v17.8h\n"
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.16b, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.16b, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 8bit values
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "dup v14.8h, v13.h[4]\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v21)
+ "add v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v22)
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ RUY_MAKE_ZERO(v23)
+ "sqxtn2 v16.16b, v17.8h\n"
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.16b, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.16b, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 8bit values
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.8h, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.8h, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 16bit values
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[5], [x3], #2\n"
+ "st1 {v16.h}[6], [x3], #2\n"
+ "st1 {v16.h}[7], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[1], [x3], #2\n"
+ "st1 {v17.h}[2], [x3], #2\n"
+ "st1 {v17.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[5], [x3], #2\n"
+ "st1 {v17.h}[6], [x3], #2\n"
+ "st1 {v17.h}[7], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ "ldr x1, [%[lhs_ptr], #8]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+
+ "ins v0.d[1], x1\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "ins v1.d[1], x2\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "ins v2.d[1], x3\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+ "ins v3.d[1], x4\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "ins v4.d[1], x1\n"
+ "ins v5.d[1], x2\n"
+ "ins v6.d[1], x3\n"
+ "ins v7.d[1], x4\n"
+
+ // Since the store type is the same as the accum type, there is no need
+ // for a downcast. There's also no need to clamp by min/max.
+
+ // At this point, v20 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+
+ RUY_MAKE_ZERO(v20)
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v21)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v22)
+
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination int32 values
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ "str q18, [%[dst_tmp_buf], #32]\n"
+ "str q19, [%[dst_tmp_buf], #48]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.s}[1], [x3], #4\n"
+ "st1 {v17.s}[2], [x3], #4\n"
+ "st1 {v17.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v18.s}[1], [x3], #4\n"
+ "st1 {v18.s}[2], [x3], #4\n"
+ "st1 {v18.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v19.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v19.s}[1], [x3], #4\n"
+ "st1 {v19.s}[2], [x3], #4\n"
+ "st1 {v19.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params),[dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Kernel taking advantage of the optional dotprod instruction.
+// This is very similar to (and directly inspired by) this gemmlowp kernel
+// which was contributed by David Mansell at ARM:
+// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
+//
+// Besides the ruy-ification, the main difference here is that we use an 8x8
+// instead of 12x8 width, so as to stick to power-of-two widths. This slightly
+// narrower kernel layout is still wide enough to achieve high performance,
+// although we haven't actually performed a real comparison to know exactly
+// how this compares to ARM's aforementioned kernel.
+//
+// Relevant target CPUs for this kernel include ARM Cortex-A76,
+// since these are 64-bit, out-of-order and have dotprod support.
+void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel (kNeonDotprod)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v15 are used to load int8 data from LHS and
+ // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
+ // v3 are used to load a 4x8 block of RHS, like this:
+ //
+ // int8 RHS 4x8 block
+ // /-----------------------------------------|
+ // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
+ // | ... ... |
+ // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
+ // \-----------------------------------------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /-----------------------------------------|
+ // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
+ // | ... ... | | ... ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
+ // | ... ... | | ... ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 8x8 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // instead of using only v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used for loading parameters used for the
+ // post-accumulation part of the kernel.
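+ //
+ // In terms of the diagram above, the int32 accumulator for destination
+ // element (r, c) of the 8x8 block lives in lane r of v(16 + 2*c) when
+ // r < 4, and in lane (r - 4) of v(17 + 2*c) when r >= 4.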
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of each of the LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Optional, maximally-streaming, partial-unrolling (4x unrolled)
+ // optimization of the kernel inner loop (over depth). For more
+ // comments, see the non-unrolled loop below after the #endif.
+#if RUY_OPT(MAX_STREAMING)
+ "cmp w12, #32\n"
+ "blt 78f\n"
+
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
+ "mov w1, #16\n"
+
+ "and w3, w12, #-16\n"
+ "81:\n"
+ "add w1, w1, #16\n"
+
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ldr q0, [%[lhs_ptr], #0]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ldr q2, [%[rhs_ptr], #0]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ldr q1, [%[lhs_ptr], #16]\n"
+
+ ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
+ ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
+ "ldr q3, [%[rhs_ptr], #16]\n"
+ ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
+ ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
+ ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
+ ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
+ ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
+ ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
+ ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
+ ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
+ ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
+ ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
+ "ldr q5, [%[lhs_ptr], #48]\n"
+ ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
+ ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
+ "ldr q7, [%[rhs_ptr], #48]\n"
+ ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
+ ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
+ "ldr q4, [%[lhs_ptr], #32]\n"
+
+ ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
+ ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
+ "ldr q6, [%[rhs_ptr], #32]\n"
+ ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
+ ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
+ ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
+ ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
+ ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
+ ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
+ ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
+ ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
+ ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
+ ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
+ "ldr q9, [%[lhs_ptr], #80]\n"
+ ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
+ ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
+ "ldr q11, [%[rhs_ptr], #80]\n"
+ ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
+ ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
+ "ldr q8, [%[lhs_ptr], #64]\n"
+
+ ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
+ ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
+ "ldr q10, [%[rhs_ptr], #64]\n"
+ ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
+ ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #128\n"
+ ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
+ ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #128\n"
+ ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
+ ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
+ ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
+ ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
+ "cmp w1, w3\n"
+ ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
+ ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
+ "ldr q13, [%[lhs_ptr], #-16]\n"
+ ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
+ ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
+ "ldr q15, [%[rhs_ptr], #-16]\n"
+ ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
+ ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
+ "ldr q12, [%[lhs_ptr], #-32]\n"
+
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ "ldr q14, [%[rhs_ptr], #-32]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ "blt 81b\n"
+
+ ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
+ ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
+ ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
+ ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
+ ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
+ ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
+ ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
+ ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
+ ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
+ ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
+ ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
+ ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
+ ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
+ ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
+ ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
+ ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
+
+ ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
+ ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
+ ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
+ ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
+ ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
+ ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
+ ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
+ ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
+ ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
+ ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
+ ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
+ ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
+ ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
+ ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
+ ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
+ ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
+
+ ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
+ ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
+ ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
+ ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
+ ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
+ ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
+ ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
+ ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
+ ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
+ ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
+ ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
+ ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
+ ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
+ ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
+ ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
+ ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
+
+ "78:\n"
+
+#endif // #if RUY_OPT(MAX_STREAMING)
+
+ // Ordinary kernel inner loop (over depth), the simpler loop that the
+ // above was an equivalent 4x-partially-unrolled version of.
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Because of the data that we have already loaded, we can start the
+ // loop body right away with some multiply-adds.
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ // Each iteration of this loop advances by 4 levels of depth.
+ "add w1, w1, #4\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ // Loop termination condition.
+ "cmp w1, w12\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last 4 levels of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 8x8 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done; however,
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v15.4s\n"
+ "add v20.4s, v20.4s, v14.4s\n"
+ "add v21.4s, v21.4s, v15.4s\n"
+ "add v22.4s, v22.4s, v14.4s\n"
+ "add v23.4s, v23.4s, v15.4s\n"
+ "add v24.4s, v24.4s, v14.4s\n"
+ "add v25.4s, v25.4s, v15.4s\n"
+ "add v26.4s, v26.4s, v14.4s\n"
+ "add v27.4s, v27.4s, v15.4s\n"
+ "add v28.4s, v28.4s, v14.4s\n"
+ "add v29.4s, v29.4s, v15.4s\n"
+ "add v30.4s, v30.4s, v14.4s\n"
+ "add v31.4s, v31.4s, v15.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v10.4s, v14.s[0]\n"
+ "dup v11.4s, v14.s[1]\n"
+ "dup v12.4s, v14.s[2]\n"
+ "dup v13.4s, v14.s[3]\n"
+ "add v16.4s, v16.4s, v10.4s\n"
+ "add v17.4s, v17.4s, v10.4s\n"
+ "add v18.4s, v18.4s, v11.4s\n"
+ "add v19.4s, v19.4s, v11.4s\n"
+ "add v20.4s, v20.4s, v12.4s\n"
+ "add v21.4s, v21.4s, v12.4s\n"
+ "add v22.4s, v22.4s, v13.4s\n"
+ "add v23.4s, v23.4s, v13.4s\n"
+ "dup v10.4s, v15.s[0]\n"
+ "dup v11.4s, v15.s[1]\n"
+ "dup v12.4s, v15.s[2]\n"
+ "dup v13.4s, v15.s[3]\n"
+ "add v24.4s, v24.4s, v10.4s\n"
+ "add v25.4s, v25.4s, v10.4s\n"
+ "add v26.4s, v26.4s, v11.4s\n"
+ "add v27.4s, v27.4s, v11.4s\n"
+ "add v28.4s, v28.4s, v12.4s\n"
+ "add v29.4s, v29.4s, v12.4s\n"
+ "add v30.4s, v30.4s, v13.4s\n"
+ "add v31.4s, v31.4s, v13.4s\n"
+ "7:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3], #16\n"
+ "ld1 {v15.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "mls v18.4s, v10.4s, v14.s[1]\n"
+ "mls v19.4s, v10.4s, v14.s[1]\n"
+ "mls v20.4s, v10.4s, v14.s[2]\n"
+ "mls v21.4s, v10.4s, v14.s[2]\n"
+ "mls v22.4s, v10.4s, v14.s[3]\n"
+ "mls v23.4s, v10.4s, v14.s[3]\n"
+ "mls v24.4s, v10.4s, v15.s[0]\n"
+ "mls v25.4s, v10.4s, v15.s[0]\n"
+ "mls v26.4s, v10.4s, v15.s[1]\n"
+ "mls v27.4s, v10.4s, v15.s[1]\n"
+ "mls v28.4s, v10.4s, v15.s[2]\n"
+ "mls v29.4s, v10.4s, v15.s[2]\n"
+ "mls v30.4s, v10.4s, v15.s[3]\n"
+ "mls v31.4s, v10.4s, v15.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2], #16\n"
+ "ld1 {v12.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v12.4s\n"
+ "sub v20.4s, v20.4s, v11.4s\n"
+ "sub v21.4s, v21.4s, v12.4s\n"
+ "sub v22.4s, v22.4s, v11.4s\n"
+ "sub v23.4s, v23.4s, v12.4s\n"
+ "sub v24.4s, v24.4s, v11.4s\n"
+ "sub v25.4s, v25.4s, v12.4s\n"
+ "sub v26.4s, v26.4s, v11.4s\n"
+ "sub v27.4s, v27.4s, v12.4s\n"
+ "sub v28.4s, v28.4s, v11.4s\n"
+ "sub v29.4s, v29.4s, v12.4s\n"
+ "sub v30.4s, v30.4s, v11.4s\n"
+ "sub v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
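+ // (Sketch of the instruction sequence that follows: the exponent is split
+ // into a non-negative part, applied as a left shift (sshl) before the
+ // fixed-point multiplication (sqdmulh), and a negative part, applied
+ // afterwards as a rounding right shift (srshl).)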
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+ // Compute the multiplier_exponent pointer
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, x3, lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+ // Load multiplier_exponent
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+ // Separate positive and negative exponents
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
+
+ // Compute the multiplier_fixedpoint pointer
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "add x5, x4, x3, lsl #2\n"
+ "csel x4, x4, x5, eq\n"
+ // Load multiplier_fixedpoint
+ "ldr q14, [x4]\n"
+ "ldr q15, [x4, #16]\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
+ "sshl v18.4s, v18.4s, v9.4s\n"
+ "sshl v19.4s, v19.4s, v10.4s\n"
+ "sshl v20.4s, v20.4s, v9.4s\n"
+ "sshl v21.4s, v21.4s, v10.4s\n"
+ "sshl v22.4s, v22.4s, v9.4s\n"
+ "sshl v23.4s, v23.4s, v10.4s\n"
+ "sshl v24.4s, v24.4s, v9.4s\n"
+ "sshl v25.4s, v25.4s, v10.4s\n"
+ "sshl v26.4s, v26.4s, v9.4s\n"
+ "sshl v27.4s, v27.4s, v10.4s\n"
+ "sshl v28.4s, v28.4s, v9.4s\n"
+ "sshl v29.4s, v29.4s, v10.4s\n"
+ "sshl v30.4s, v30.4s, v9.4s\n"
+ "sshl v31.4s, v31.4s, v10.4s\n"
+ "10:\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v18.4s, v18.4s, v14.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqdmulh v20.4s, v20.4s, v14.4s\n"
+ "sqdmulh v21.4s, v21.4s, v15.4s\n"
+ "sqdmulh v22.4s, v22.4s, v14.4s\n"
+ "sqdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqdmulh v31.4s, v31.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
+ "srshl v27.4s, v27.4s, v12.4s\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v4.4s, v9.s[0]\n"
+ "dup v5.4s, v9.s[1]\n"
+ "dup v6.4s, v9.s[2]\n"
+ "dup v7.4s, v9.s[3]\n"
+ "sshl v16.4s, v16.4s, v4.4s\n"
+ "sshl v17.4s, v17.4s, v4.4s\n"
+ "sshl v18.4s, v18.4s, v5.4s\n"
+ "sshl v19.4s, v19.4s, v5.4s\n"
+ "sshl v20.4s, v20.4s, v6.4s\n"
+ "sshl v21.4s, v21.4s, v6.4s\n"
+ "sshl v22.4s, v22.4s, v7.4s\n"
+ "sshl v23.4s, v23.4s, v7.4s\n"
+ "dup v4.4s, v10.s[0]\n"
+ "dup v5.4s, v10.s[1]\n"
+ "dup v6.4s, v10.s[2]\n"
+ "dup v7.4s, v10.s[3]\n"
+ "sshl v24.4s, v24.4s, v4.4s\n"
+ "sshl v25.4s, v25.4s, v4.4s\n"
+ "sshl v26.4s, v26.4s, v5.4s\n"
+ "sshl v27.4s, v27.4s, v5.4s\n"
+ "sshl v28.4s, v28.4s, v6.4s\n"
+ "sshl v29.4s, v29.4s, v6.4s\n"
+ "sshl v30.4s, v30.4s, v7.4s\n"
+ "sshl v31.4s, v31.4s, v7.4s\n"
+ "11:\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
+ "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
+ "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
+ "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
+ "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
+ "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
+ "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
+ "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
+ "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
+ "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
+ "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
+ "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
+ "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
+ "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
+ "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
+ "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "dup v4.4s, v11.s[0]\n"
+ "dup v5.4s, v11.s[1]\n"
+ "dup v6.4s, v11.s[2]\n"
+ "dup v7.4s, v11.s[3]\n"
+ "srshl v16.4s, v16.4s, v4.4s\n"
+ "srshl v17.4s, v17.4s, v4.4s\n"
+ "srshl v18.4s, v18.4s, v5.4s\n"
+ "srshl v19.4s, v19.4s, v5.4s\n"
+ "srshl v20.4s, v20.4s, v6.4s\n"
+ "srshl v21.4s, v21.4s, v6.4s\n"
+ "srshl v22.4s, v22.4s, v7.4s\n"
+ "srshl v23.4s, v23.4s, v7.4s\n"
+ "dup v4.4s, v12.s[0]\n"
+ "dup v5.4s, v12.s[1]\n"
+ "dup v6.4s, v12.s[2]\n"
+ "dup v7.4s, v12.s[3]\n"
+ "srshl v24.4s, v24.4s, v4.4s\n"
+ "srshl v25.4s, v25.4s, v4.4s\n"
+ "srshl v26.4s, v26.4s, v5.4s\n"
+ "srshl v27.4s, v27.4s, v5.4s\n"
+ "srshl v28.4s, v28.4s, v6.4s\n"
+ "srshl v29.4s, v29.4s, v6.4s\n"
+ "srshl v30.4s, v30.4s, v7.4s\n"
+ "srshl v31.4s, v31.4s, v7.4s\n"
+ "9:\n"
+
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+ "sqxtun v17.8b, v18.8h\n"
+ "sqxtun2 v17.16b, v19.8h\n"
+ "sqxtun v18.8b, v20.8h\n"
+ "sqxtun2 v18.16b, v21.8h\n"
+ "sqxtun v19.8b, v22.8h\n"
+ "sqxtun2 v19.16b, v23.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ "umax v17.16b, v17.16b, v14.16b\n"
+ "umax v18.16b, v18.16b, v14.16b\n"
+ "umax v19.16b, v19.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ "umin v17.16b, v17.16b, v15.16b\n"
+ "umin v18.16b, v18.16b, v15.16b\n"
+ "umin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+ "sqxtn v17.8b, v18.8h\n"
+ "sqxtn2 v17.16b, v19.8h\n"
+ "sqxtn v18.8b, v20.8h\n"
+ "sqxtn2 v18.16b, v21.8h\n"
+ "sqxtn v19.8b, v22.8h\n"
+ "sqxtn2 v19.16b, v23.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ "smax v17.16b, v17.16b, v14.16b\n"
+ "smax v18.16b, v18.16b, v14.16b\n"
+ "smax v19.16b, v19.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ "smin v17.16b, v17.16b, v15.16b\n"
+ "smin v18.16b, v18.16b, v15.16b\n"
+ "smin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 141f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 150b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+ "saddw v20.4s, v20.4s, v14.4h\n"
+ "saddw v21.4s, v21.4s, v14.4h\n"
+ "saddw v22.4s, v22.4s, v14.4h\n"
+ "saddw v23.4s, v23.4s, v14.4h\n"
+ "saddw v24.4s, v24.4s, v14.4h\n"
+ "saddw v25.4s, v25.4s, v14.4h\n"
+ "saddw v26.4s, v26.4s, v14.4h\n"
+ "saddw v27.4s, v27.4s, v14.4h\n"
+ "saddw v28.4s, v28.4s, v14.4h\n"
+ "saddw v29.4s, v29.4s, v14.4h\n"
+ "saddw v30.4s, v30.4s, v14.4h\n"
+ "saddw v31.4s, v31.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ "smax v18.8h, v18.8h, v14.8h\n"
+ "smax v19.8h, v19.8h, v14.8h\n"
+ "smax v20.8h, v20.8h, v14.8h\n"
+ "smax v21.8h, v21.8h, v14.8h\n"
+ "smax v22.8h, v22.8h, v14.8h\n"
+ "smax v23.8h, v23.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ "smin v18.8h, v18.8h, v15.8h\n"
+ "smin v19.8h, v19.8h, v15.8h\n"
+ "smin v20.8h, v20.8h, v15.8h\n"
+ "smin v21.8h, v21.8h, v15.8h\n"
+ "smin v22.8h, v22.8h, v15.8h\n"
+ "smin v23.8h, v23.8h, v15.8h\n"
+
+ // Compute how much of the 8x8 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
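+ // (Execution jumped to this label from the dst_type_id comparison right
+ // after the zero-point corrections, so the multiplier and clamping work
+ // above was skipped entirely and the raw int32 accumulators are stored.)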
+
+ // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.4s, v17.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.4s, v19.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v20.4s, v21.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v22.4s, v23.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v24.4s, v25.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v26.4s, v27.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v28.4s, v29.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v30.4s, v31.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 341f\n"
+
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 350b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Similar to the above 8-bit dotprod kernel, but specialized for the case of
+// RHS cols == 1.
+// Relevant target CPUs for this kernel include the ARM Cortex-A76, since it
+// is 64-bit, out-of-order and supports dotprod.
+void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel (kNeonDotprod)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v15 are used to load int8 data from LHS and
+ // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2
+ // is used to load a 4x1 block of RHS, like this:
+ //
+ // int8 RHS 4x1 block
+ // /-------|
+ // |v2.b[0]|
+ // | ... |
+ // |v2.b[3]|
+ // \-------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /--------|
+ // |v0.b[0] ... v0.b[3] | |v16.s[0]|
+ // | ... ... | | ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0]|
+ // | ... ... | | ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3]|
+ // \---------------------/ \--------/
+ // int32 accumulators 8x1 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used for loading parameters used for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Ordinary kernel inner loop (over depth), the simpler loop that the
+ // above was an equivalent 4x-partially-unrolled version of.
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Because of the data that we have already loaded, we can start the
+ // loop body right away with some multiply-adds.
+ // Each iteration of this loop advances by 4 levels of depth.
+ "add w1, w1, #4\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ // Loop termination condition.
+ "cmp w1, w12\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last 4 levels of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+
+ // End of accumulation. The registers v16 -- v17 contain the final
+ // int32 accumulator values of the current 8x1 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x1 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3], #16\n"
+ "ld1 {v15.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2], #16\n"
+ "ld1 {v12.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
+ "403:\n"
+
+ "ldr q14, [x4]\n" // multiplier_fixedpoint
+ "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
+
+ // Apply the fixed-point part of the multiplier.
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ // All data in v16 at this point.
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8, leaving all data in the
+ // lower half of v16.
+ "sqxtun v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+
+ // Compute how much of the 8x1 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8b}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+
+ // Compute how much of the 8x1 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8b}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 141f\n"
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+
+ // Compute how much of the 8x1 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8h}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x1 block of destination 32 bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x1 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+
+ // Write our 32bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.4s}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.4s}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x4, %[dst_ptr]\n"
+ "mov x3, x4\n"
+
+ // Write our 32bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.4s, v17.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 341f\n"
+
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17");
+}
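+
+// For reference, a minimal scalar sketch (hypothetical helper, not used by
+// the kernels) of what one of the ".word"-encoded dotprod instructions in
+// these kernels computes. For example,
+//     sdot v16.4s, v0.16b, v2.4b[0]
+// adds into each of the four int32 lanes of the accumulator the dot product
+// of one group of four int8 LHS bytes with the group of four int8 RHS bytes
+// selected by the lane index. Spelling the instructions as .word constants
+// lets them assemble even without assembler support for the dotprod
+// extension.
+inline void ReferenceSdotLane(std::int32_t acc[4], const std::int8_t lhs[16],
+                              const std::int8_t rhs[16], int rhs_lane) {
+  for (int i = 0; i < 4; ++i) {
+    for (int k = 0; k < 4; ++k) {
+      acc[i] += static_cast<std::int32_t>(lhs[4 * i + k]) *
+                static_cast<std::int32_t>(rhs[4 * rhs_lane + k]);
+    }
+  }
+}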
+
+// Variant of the above Kernel8bitNeonDotprod, tuned for in-order
+// CPUs. Specifically here, the relevant in-order CPU is the ARM Cortex-A55r1,
+// which is 64-bit and supports dotprod.
+//
+// While this kernel does not have a direct equivalent in gemmlowp, it was
+// developed based on insights that David Mansell at ARM shared with their
+// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A55r1:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
+void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // RHS.
+ //
+ // int8 RHS 4x8 block
+ // /-----------------------------------------|
+ // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
+ // | ... ... |
+ // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
+ // \-----------------------------------------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /-----------------------------------------|
+ // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
+ // | ... ... | | ... ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
+ // | ... ... | | ... ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+ // v4 -- v7 are unused, and v8 -- v15 are used to load parameters for
+ // the post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(v16)
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(v17)
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ RUY_MAKE_ZERO(v18)
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ RUY_MAKE_ZERO(v19)
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v20)
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v21)
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v22)
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_MAKE_ZERO(v28)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_MAKE_ZERO(v29)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_MAKE_ZERO(v30)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_MAKE_ZERO(v31)
+
+
+ "1:\n"
+
+ "add x5, %[lhs_ptr], x12, lsl #3\n"
+ "sub x5, x5, #32\n"
+ "cmp %[lhs_ptr], x5\n"
+
+ "beq 79f\n"
+
+ // Main accumulation loop
+ "2:\n"
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ "ldr x1, [%[lhs_ptr], #8]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ "ldr x3, [%[rhs_ptr], #8]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ "ldr x4, [%[rhs_ptr], #24]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ldr d0, [%[lhs_ptr], #0]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ "ins v0.d[1], x1\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #32\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ldr d2, [%[rhs_ptr], #0]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ "ins v2.d[1], x3\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ "cmp %[lhs_ptr], x5\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ldr d3, [%[rhs_ptr], #-16]\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ "ins v3.d[1], x4\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ "ins v1.d[1], x2\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ "blt 2b\n"
+
+ // Last accumulation steps, nothing left to load.
+ "79:\n"
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 8x8 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ // Load some parameters needed for the end work on current block.
+ "mvni v8.4s, #0\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.2s}, [x1], #8\n"
+ "ldr x5, [x1], #8\n"
+ "ins v14.d[1], x5\n"
+ "ld1 {v15.2s}, [x1], #8\n"
+ "ldr x5, [x1], #8\n"
+ "ins v15.d[1], x5\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v15.4s\n"
+ "add v20.4s, v20.4s, v14.4s\n"
+ "add v21.4s, v21.4s, v15.4s\n"
+ "add v22.4s, v22.4s, v14.4s\n"
+ "add v23.4s, v23.4s, v15.4s\n"
+ "add v24.4s, v24.4s, v14.4s\n"
+ "add v25.4s, v25.4s, v15.4s\n"
+ "add v26.4s, v26.4s, v14.4s\n"
+ "add v27.4s, v27.4s, v15.4s\n"
+ "add v28.4s, v28.4s, v14.4s\n"
+ "add v29.4s, v29.4s, v15.4s\n"
+ "add v30.4s, v30.4s, v14.4s\n"
+ "add v31.4s, v31.4s, v15.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v10.4s, v14.s[0]\n"
+ "dup v11.4s, v14.s[1]\n"
+ "add v16.4s, v16.4s, v10.4s\n"
+ "dup v12.4s, v14.s[2]\n"
+ "add v17.4s, v17.4s, v10.4s\n"
+ "dup v13.4s, v14.s[3]\n"
+ "add v18.4s, v18.4s, v11.4s\n"
+ "dup v10.4s, v15.s[0]\n"
+ "add v19.4s, v19.4s, v11.4s\n"
+ "dup v11.4s, v15.s[1]\n"
+ "add v20.4s, v20.4s, v12.4s\n"
+ "add v21.4s, v21.4s, v12.4s\n"
+ "dup v12.4s, v15.s[2]\n"
+ "add v22.4s, v22.4s, v13.4s\n"
+ "add v23.4s, v23.4s, v13.4s\n"
+ "dup v13.4s, v15.s[3]\n"
+ "add v24.4s, v24.4s, v10.4s\n"
+ "add v25.4s, v25.4s, v10.4s\n"
+ "add v26.4s, v26.4s, v11.4s\n"
+ "add v27.4s, v27.4s, v11.4s\n"
+ "add v28.4s, v28.4s, v12.4s\n"
+ "add v29.4s, v29.4s, v12.4s\n"
+ "add v30.4s, v30.4s, v13.4s\n"
+ "add v31.4s, v31.4s, v13.4s\n"
+ "7:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x5, x5, %x[col], lsl #2\n"
+ // Load 8 rhs_sums values.
+ "ld1 {v14.2s}, [x5], #8\n"
+ "ldr x7, [x5], #8\n"
+ "ld1 {v15.2s}, [x5], #8\n"
+ "ins v14.d[1], x7\n"
+ "ldr x7, [x5], #8\n"
+ "ins v15.d[1], x7\n"
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "mls v18.4s, v10.4s, v14.s[1]\n"
+ "mls v19.4s, v10.4s, v14.s[1]\n"
+ "mls v20.4s, v10.4s, v14.s[2]\n"
+ "mls v21.4s, v10.4s, v14.s[2]\n"
+ "mls v22.4s, v10.4s, v14.s[3]\n"
+ "mls v23.4s, v10.4s, v14.s[3]\n"
+ "mls v24.4s, v10.4s, v15.s[0]\n"
+ "mls v25.4s, v10.4s, v15.s[0]\n"
+ "mls v26.4s, v10.4s, v15.s[1]\n"
+ "mls v27.4s, v10.4s, v15.s[1]\n"
+ "mls v28.4s, v10.4s, v15.s[2]\n"
+ "mls v29.4s, v10.4s, v15.s[2]\n"
+ "mls v30.4s, v10.4s, v15.s[3]\n"
+ "mls v31.4s, v10.4s, v15.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Load 8 lhs_sums values.
+ "ld1 {v11.2s}, [x2], #8\n"
+ "ldr x4, [x2], #8\n"
+ "ins v11.d[1], x4\n"
+ "ld1 {v12.2s}, [x2], #8\n"
+ "ldr x4, [x2], #8\n"
+ "ins v12.d[1], x4\n"
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v12.4s\n"
+ "sub v20.4s, v20.4s, v11.4s\n"
+ "sub v21.4s, v21.4s, v12.4s\n"
+ "sub v22.4s, v22.4s, v11.4s\n"
+ "sub v23.4s, v23.4s, v12.4s\n"
+ "sub v24.4s, v24.4s, v11.4s\n"
+ "sub v25.4s, v25.4s, v12.4s\n"
+ "sub v26.4s, v26.4s, v11.4s\n"
+ "sub v27.4s, v27.4s, v12.4s\n"
+ "sub v28.4s, v28.4s, v11.4s\n"
+ "sub v29.4s, v29.4s, v12.4s\n"
+ "sub v30.4s, v30.4s, v11.4s\n"
+ "sub v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ // Compute the multiplier_exponent pointer
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, x3, lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+ // Load multiplier_exponent
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+ // Separate positive and negative exponents
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
+
+ // Compute the multiplier_fixedpoint pointer
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "add x5, x4, x3, lsl #2\n"
+ "csel x4, x4, x5, eq\n"
+ // Load multiplier_fixedpoint
+ "ldr q14, [x4]\n"
+ "ldr q15, [x4, #16]\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
+
+ // Apply the positive exponent part of the multiplier.
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
+ "sshl v18.4s, v18.4s, v9.4s\n"
+ "sshl v19.4s, v19.4s, v10.4s\n"
+ "sshl v20.4s, v20.4s, v9.4s\n"
+ "sshl v21.4s, v21.4s, v10.4s\n"
+ "sshl v22.4s, v22.4s, v9.4s\n"
+ "sshl v23.4s, v23.4s, v10.4s\n"
+ "sshl v24.4s, v24.4s, v9.4s\n"
+ "sshl v25.4s, v25.4s, v10.4s\n"
+ "sshl v26.4s, v26.4s, v9.4s\n"
+ "sshl v27.4s, v27.4s, v10.4s\n"
+ "sshl v28.4s, v28.4s, v9.4s\n"
+ "sshl v29.4s, v29.4s, v10.4s\n"
+ "sshl v30.4s, v30.4s, v9.4s\n"
+ "sshl v31.4s, v31.4s, v10.4s\n"
+ "10:\n"
+
+ // Apply the fixed-point part of the multiplier.
+ //
+ // ... and, interleaved into that:
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
+ "ldr x1, [%[lhs_ptr]], #8\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
+ "sqdmulh v18.4s, v18.4s, v14.4s\n"
+ "ldr x2, [%[lhs_ptr]], #8\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
+ "sqdmulh v20.4s, v20.4s, v14.4s\n"
+ "ldr x5, [%[rhs_ptr]], #8\n"
+ "sqdmulh v21.4s, v21.4s, v15.4s\n"
+ "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
+ "sqdmulh v22.4s, v22.4s, v14.4s\n"
+ "ldr x6, [%[rhs_ptr]], #8\n"
+ "sqdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqdmulh v31.4s, v31.4s, v15.4s\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "srshl v27.4s, v27.4s, v12.4s\n"
+ "ins v0.d[1], x1\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
+ "ins v1.d[1], x2\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
+ "ins v2.d[1], x5\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
+ "ins v3.d[1], x6\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v4.4s, v9.s[0]\n"
+ "dup v5.4s, v9.s[1]\n"
+ "sshl v16.4s, v16.4s, v4.4s\n"
+ "dup v6.4s, v9.s[2]\n"
+ "sshl v17.4s, v17.4s, v4.4s\n"
+ "dup v7.4s, v9.s[3]\n"
+ "sshl v18.4s, v18.4s, v5.4s\n"
+ "dup v4.4s, v10.s[0]\n"
+ "sshl v19.4s, v19.4s, v5.4s\n"
+ "dup v5.4s, v10.s[1]\n"
+ "sshl v20.4s, v20.4s, v6.4s\n"
+ "sshl v21.4s, v21.4s, v6.4s\n"
+ "dup v6.4s, v10.s[2]\n"
+ "sshl v22.4s, v22.4s, v7.4s\n"
+ "sshl v23.4s, v23.4s, v7.4s\n"
+ "dup v7.4s, v10.s[3]\n"
+ "sshl v24.4s, v24.4s, v4.4s\n"
+ "sshl v25.4s, v25.4s, v4.4s\n"
+ "sshl v26.4s, v26.4s, v5.4s\n"
+ "sshl v27.4s, v27.4s, v5.4s\n"
+ "sshl v28.4s, v28.4s, v6.4s\n"
+ "sshl v29.4s, v29.4s, v6.4s\n"
+ "sshl v30.4s, v30.4s, v7.4s\n"
+ "sshl v31.4s, v31.4s, v7.4s\n"
+ "11:\n"
+
+ // Apply the fixed-point part of the multiplier.
+ //
+ // ... and, interleaved into that:
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
+ "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
+ "ldr x1, [%[lhs_ptr]], #8\n"
+ "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
+ "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
+ "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
+ "ldr x2, [%[lhs_ptr]], #8\n"
+ "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
+ "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
+ "ldr x5, [%[rhs_ptr]], #8\n"
+ "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
+ "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
+ "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
+ "ldr x6, [%[rhs_ptr]], #8\n"
+ "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
+ "dup v4.4s, v11.s[0]\n"
+ "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
+ "dup v5.4s, v11.s[1]\n"
+ "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
+ "dup v6.4s, v11.s[2]\n"
+ "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
+ "dup v7.4s, v11.s[3]\n"
+ "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
+ "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
+ "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
+ "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
+ "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "srshl v16.4s, v16.4s, v4.4s\n"
+ "srshl v17.4s, v17.4s, v4.4s\n"
+ "dup v4.4s, v12.s[0]\n"
+ "srshl v18.4s, v18.4s, v5.4s\n"
+ "srshl v19.4s, v19.4s, v5.4s\n"
+ "dup v5.4s, v12.s[1]\n"
+ "srshl v20.4s, v20.4s, v6.4s\n"
+ "srshl v21.4s, v21.4s, v6.4s\n"
+ "dup v6.4s, v12.s[2]\n"
+ "srshl v22.4s, v22.4s, v7.4s\n"
+ "srshl v23.4s, v23.4s, v7.4s\n"
+ "dup v7.4s, v12.s[3]\n"
+ "srshl v24.4s, v24.4s, v4.4s\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "srshl v25.4s, v25.4s, v4.4s\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "srshl v26.4s, v26.4s, v5.4s\n"
+ "ins v0.d[1], x1\n"
+ "srshl v27.4s, v27.4s, v5.4s\n"
+ "ins v1.d[1], x2\n"
+ "srshl v28.4s, v28.4s, v6.4s\n"
+ "ins v2.d[1], x5\n"
+ "srshl v29.4s, v29.4s, v6.4s\n"
+ "ins v3.d[1], x6\n"
+ "srshl v30.4s, v30.4s, v7.4s\n"
+ "srshl v31.4s, v31.4s, v7.4s\n"
+ "9:\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // Destination zero_point
+ "dup v14.8h, v13.h[4]\n"
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+ "sqxtun v17.8b, v18.8h\n"
+ "sqxtun2 v17.16b, v19.8h\n"
+ "sqxtun v18.8b, v20.8h\n"
+ "sqxtun2 v18.16b, v21.8h\n"
+ "sqxtun v19.8b, v22.8h\n"
+ "sqxtun2 v19.16b, v23.8h\n"
+
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "umax v17.16b, v17.16b, v14.16b\n"
+ "mov w3, #8\n"
+ "umax v18.16b, v18.16b, v14.16b\n"
+ "cmp w1, #8\n"
+ "umax v19.16b, v19.16b, v14.16b\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ "cmp w2, #8\n"
+ "umin v17.16b, v17.16b, v15.16b\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+ "umin v18.16b, v18.16b, v15.16b\n"
+ "umin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // Destination zero_point
+ "dup v14.8h, v13.h[4]\n"
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+ "sqxtn v17.8b, v18.8h\n"
+ "sqxtn2 v17.16b, v19.8h\n"
+ "sqxtn v18.8b, v20.8h\n"
+ "sqxtn2 v18.16b, v21.8h\n"
+ "sqxtn v19.8b, v22.8h\n"
+ "sqxtn2 v19.16b, v23.8h\n"
+
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "smax v17.16b, v17.16b, v14.16b\n"
+ "mov w3, #8\n"
+ "smax v18.16b, v18.16b, v14.16b\n"
+ "cmp w1, #8\n"
+ "smax v19.16b, v19.16b, v14.16b\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ "cmp w2, #8\n"
+ "smin v17.16b, v17.16b, v15.16b\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+ "smin v18.16b, v18.16b, v15.16b\n"
+ "smin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 141f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 150b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+ "saddw v20.4s, v20.4s, v14.4h\n"
+ "saddw v21.4s, v21.4s, v14.4h\n"
+ "saddw v22.4s, v22.4s, v14.4h\n"
+ "saddw v23.4s, v23.4s, v14.4h\n"
+ "saddw v24.4s, v24.4s, v14.4h\n"
+ "saddw v25.4s, v25.4s, v14.4h\n"
+ "saddw v26.4s, v26.4s, v14.4h\n"
+ "saddw v27.4s, v27.4s, v14.4h\n"
+ "saddw v28.4s, v28.4s, v14.4h\n"
+ "saddw v29.4s, v29.4s, v14.4h\n"
+ "saddw v30.4s, v30.4s, v14.4h\n"
+ "saddw v31.4s, v31.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ "smax v18.8h, v18.8h, v14.8h\n"
+ "smax v19.8h, v19.8h, v14.8h\n"
+ "smax v20.8h, v20.8h, v14.8h\n"
+ "smax v21.8h, v21.8h, v14.8h\n"
+ "smax v22.8h, v22.8h, v14.8h\n"
+ "smax v23.8h, v23.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ "smin v18.8h, v18.8h, v15.8h\n"
+ "smin v19.8h, v19.8h, v15.8h\n"
+ "smin v20.8h, v20.8h, v15.8h\n"
+ "smin v21.8h, v21.8h, v15.8h\n"
+ "smin v22.8h, v22.8h, v15.8h\n"
+ "smin v23.8h, v23.8h, v15.8h\n"
+
+ // Compute how much of the 8x8 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
+ "ldr x1, [%[lhs_ptr]], #8\n"
+ "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
+ "ldr x2, [%[lhs_ptr]], #8\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
+ "ldr x5, [%[rhs_ptr]], #8\n"
+ "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
+ "ldr x6, [%[rhs_ptr]], #8\n"
+ "ins v0.d[1], x1\n"
+ "ins v1.d[1], x2\n"
+ "ins v2.d[1], x5\n"
+ "ins v3.d[1], x6\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.4s, v17.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v18.4s, v19.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v20.4s, v21.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v22.4s, v23.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v24.4s, v25.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v26.4s, v27.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v28.4s, v29.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v30.4s, v31.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 341f\n"
+
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 350b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
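+
+// For reference, a scalar sketch (hypothetical helper, not used by the
+// kernels) of the post-accumulation path that the code above applies to each
+// int32 accumulator lane in the 8-bit destination case: the zero-point
+// corrections of equation (7) in https://arxiv.org/pdf/1712.05877.pdf, then
+// the positive exponent part of the multiplier as a left shift (the sshl
+// step), the fixed-point part as a doubling multiply keeping the high 32 bits
+// (the sqdmulh step), the negative exponent part as a rounding right shift
+// (the srshl step), and finally the destination zero point and the clamp
+// bounds. Saturation and 32-bit wrap corner cases are glossed over here.
+inline std::int32_t ReferenceRequantize(
+    std::int32_t acc, std::int32_t bias, std::int32_t prod_zp_depth,
+    std::int32_t lhs_zero_point, std::int32_t rhs_sum,
+    std::int32_t rhs_zero_point, std::int32_t lhs_sum,
+    std::int32_t multiplier_fixedpoint, int multiplier_exponent,
+    std::int32_t dst_zero_point, std::int32_t clamp_min,
+    std::int32_t clamp_max) {
+  // Zero-point corrections (the bias / prod_zp_depth / sums terms above).
+  std::int64_t x = static_cast<std::int64_t>(acc) + bias + prod_zp_depth -
+                   static_cast<std::int64_t>(lhs_zero_point) * rhs_sum -
+                   static_cast<std::int64_t>(rhs_zero_point) * lhs_sum;
+  // Positive exponent part of the multiplier.
+  if (multiplier_exponent > 0) {
+    x *= std::int64_t{1} << multiplier_exponent;
+  }
+  // Fixed-point part: high half of the doubling product.
+  x = (x * multiplier_fixedpoint) >> 31;
+  // Negative exponent part: rounding right shift.
+  if (multiplier_exponent < 0) {
+    const int shift = -multiplier_exponent;
+    x = (x + (std::int64_t{1} << (shift - 1))) >> shift;
+  }
+  // Destination zero point and clamping (the add / max / min steps in the
+  // store paths).
+  x += dst_zero_point;
+  if (x < clamp_min) x = clamp_min;
+  if (x > clamp_max) x = clamp_max;
+  return static_cast<std::int32_t>(x);
+}
+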
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_LHS_SUMS
+#undef RUY_OFFSET_RHS_SUMS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
+#undef RUY_OFFSET_MULTIPLIER_EXPONENT
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_LHS_ZERO_POINT
+#undef RUY_OFFSET_RHS_ZERO_POINT
+#undef RUY_OFFSET_DST_ZERO_POINT
+#undef RUY_OFFSET_PROD_ZP_DEPTH
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+
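+// For reference, a scalar sketch (hypothetical helper, not used by the
+// kernels) of the "slow scalar loop" that the store paths above run for
+// boundary blocks: the whole block is first written to dst_tmp_buf with a
+// fixed per-column stride, and only the rows and columns that actually fit
+// are then copied into the destination matrix. Strides here are in elements,
+// whereas the asm works with byte strides.
+template <typename Scalar>
+inline void ReferenceCopyPartialBlock(const Scalar* tmp_buf, int tmp_stride,
+                                      Scalar* dst, int dst_stride,
+                                      int rows_that_fit, int cols_that_fit) {
+  for (int c = 0; c < cols_that_fit; ++c) {
+    for (int r = 0; r < rows_that_fit; ++r) {
+      dst[r + c * dst_stride] = tmp_buf[r + c * tmp_stride];
+    }
+  }
+}
+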
+#define RUY_OFFSET_LHS_BASE_PTR 0
+#define RUY_OFFSET_RHS_BASE_PTR 8
+#define RUY_OFFSET_DST_BASE_PTR 16
+#define RUY_OFFSET_BIAS 24
+#define RUY_OFFSET_START_ROW 32
+#define RUY_OFFSET_START_COL 36
+#define RUY_OFFSET_LAST_ROW 40
+#define RUY_OFFSET_LAST_COL 44
+#define RUY_OFFSET_LHS_STRIDE 56
+#define RUY_OFFSET_RHS_STRIDE 60
+#define RUY_OFFSET_DST_STRIDE 64
+#define RUY_OFFSET_DEPTH 68
+#define RUY_OFFSET_CLAMP_MIN 72
+#define RUY_OFFSET_CLAMP_MAX 76
+#define RUY_OFFSET_FLAGS 80
+
+template <typename Params>
+void CheckOffsetsInKernelParamsFloat(const Params&) {
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
+ static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+}
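+
+// For reference, a scalar sketch (hypothetical helper, not used by the
+// kernel) of the elementary step that the fmla-by-lane instructions in
+// KernelFloatNeon below implement: at each level of depth, the 8x8
+// accumulator block gains the outer product of an 8x1 LHS column with a
+// 1x8 RHS row.
+inline void ReferenceFloatRank1Update(float acc[8][8], const float lhs[8],
+                                      const float rhs[8]) {
+  for (int j = 0; j < 8; ++j) {
+    for (int i = 0; i < 8; ++i) {
+      acc[i][j] += lhs[i] * rhs[j];
+    }
+  }
+}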
+
+// Just a plain float kernel; good enough for out-of-order cores.
+// The closest to it in the gemmlowp collection would be
+// NEON_64bit_GEMM_Float32_WithScalar,
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925
+//
+// Besides ruy-ification, the main nuance here is that we stick to an 8x8
+// width instead of the wider 12x8 that the register space permits and that
+// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now
+// and we don't have evidence that going beyond 8x8 is needed.
+void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) {
+ CheckOffsetsInKernelParamsFloat(params);
+ profiler::ScopeLabel label("Kernel (kNeon)");
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
+ // At least v0 and v1 are used to load an 8x1 block of LHS, and v2 and
+ // v3 are used to load a 1x8 block of RHS, like this:
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------|
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------|
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used to load parameters for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov w1, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+#if RUY_OPT(MAX_STREAMING)
+ "cmp w12, #8\n"
+ "blt 78f\n"
+ "and w2, w12, #-4\n"
+
+ "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v5.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.4s}, [%[rhs_ptr]], #16\n"
+
+ "ld1 {v8.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v9.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v10.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v11.4s}, [%[rhs_ptr]], #16\n"
+
+ "ld1 {v12.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v13.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v14.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v15.4s}, [%[rhs_ptr]], #16\n"
+ "mov w1, #4\n"
+
+ "80:\n"
+
+ "add %[lhs_ptr], %[lhs_ptr], #128\n"
+ "add %[rhs_ptr], %[rhs_ptr], #128\n"
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr q0, [%[lhs_ptr], #-128]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr q3, [%[rhs_ptr], #-112]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "ldr q1, [%[lhs_ptr], #-112]\n"
+ "fmla v16.4s, v4.4s, v6.s[0]\n"
+ "fmla v18.4s, v4.4s, v6.s[1]\n"
+ "ldr q2, [%[rhs_ptr], #-128]\n"
+ "fmla v20.4s, v4.4s, v6.s[2]\n"
+ "fmla v22.4s, v4.4s, v6.s[3]\n"
+
+ "fmla v24.4s, v4.4s, v7.s[0]\n"
+ "fmla v26.4s, v4.4s, v7.s[1]\n"
+ "fmla v28.4s, v4.4s, v7.s[2]\n"
+ "fmla v30.4s, v4.4s, v7.s[3]\n"
+ "ldr q4, [%[lhs_ptr], #-96]\n"
+ "fmla v25.4s, v5.4s, v7.s[0]\n"
+ "fmla v27.4s, v5.4s, v7.s[1]\n"
+ "fmla v29.4s, v5.4s, v7.s[2]\n"
+ "fmla v31.4s, v5.4s, v7.s[3]\n"
+ "ldr q7, [%[rhs_ptr], #-80]\n"
+ "fmla v17.4s, v5.4s, v6.s[0]\n"
+ "fmla v19.4s, v5.4s, v6.s[1]\n"
+ "fmla v21.4s, v5.4s, v6.s[2]\n"
+ "fmla v23.4s, v5.4s, v6.s[3]\n"
+ "ldr q5, [%[lhs_ptr], #-80]\n"
+ "fmla v16.4s, v8.4s, v10.s[0]\n"
+ "fmla v18.4s, v8.4s, v10.s[1]\n"
+ "ldr q6, [%[rhs_ptr], #-96]\n"
+ "fmla v20.4s, v8.4s, v10.s[2]\n"
+ "fmla v22.4s, v8.4s, v10.s[3]\n"
+
+ "fmla v24.4s, v8.4s, v11.s[0]\n"
+ "fmla v26.4s, v8.4s, v11.s[1]\n"
+ "fmla v28.4s, v8.4s, v11.s[2]\n"
+ "fmla v30.4s, v8.4s, v11.s[3]\n"
+ "ldr q8, [%[lhs_ptr], #-64]\n"
+ "fmla v25.4s, v9.4s, v11.s[0]\n"
+ "fmla v27.4s, v9.4s, v11.s[1]\n"
+ "fmla v29.4s, v9.4s, v11.s[2]\n"
+ "fmla v31.4s, v9.4s, v11.s[3]\n"
+ "ldr q11, [%[rhs_ptr], #-48]\n"
+ "fmla v17.4s, v9.4s, v10.s[0]\n"
+ "fmla v19.4s, v9.4s, v10.s[1]\n"
+ "fmla v21.4s, v9.4s, v10.s[2]\n"
+ "fmla v23.4s, v9.4s, v10.s[3]\n"
+ "ldr q9, [%[lhs_ptr], #-48]\n"
+ "fmla v16.4s, v12.4s, v14.s[0]\n"
+ "fmla v18.4s, v12.4s, v14.s[1]\n"
+ "ldr q10, [%[rhs_ptr], #-64]\n"
+ "fmla v20.4s, v12.4s, v14.s[2]\n"
+ "fmla v22.4s, v12.4s, v14.s[3]\n"
+
+ "fmla v24.4s, v12.4s, v15.s[0]\n"
+ "fmla v26.4s, v12.4s, v15.s[1]\n"
+ "fmla v28.4s, v12.4s, v15.s[2]\n"
+ "fmla v30.4s, v12.4s, v15.s[3]\n"
+ "ldr q12, [%[lhs_ptr], #-32]\n"
+ "fmla v25.4s, v13.4s, v15.s[0]\n"
+ "fmla v27.4s, v13.4s, v15.s[1]\n"
+ "fmla v29.4s, v13.4s, v15.s[2]\n"
+ "fmla v31.4s, v13.4s, v15.s[3]\n"
+ "ldr q15, [%[rhs_ptr], #-16]\n"
+ "fmla v17.4s, v13.4s, v14.s[0]\n"
+ "fmla v19.4s, v13.4s, v14.s[1]\n"
+ "fmla v21.4s, v13.4s, v14.s[2]\n"
+ "fmla v23.4s, v13.4s, v14.s[3]\n"
+ "ldr q13, [%[lhs_ptr], #-16]\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "ldr q14, [%[rhs_ptr], #-32]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ "add w1, w1, #4\n"
+ "cmp w1, w2\n"
+ "blt 80b\n"
+
+ "fmla v16.4s, v4.4s, v6.s[0]\n"
+ "fmla v18.4s, v4.4s, v6.s[1]\n"
+ "fmla v20.4s, v4.4s, v6.s[2]\n"
+ "fmla v22.4s, v4.4s, v6.s[3]\n"
+ "fmla v24.4s, v4.4s, v7.s[0]\n"
+ "fmla v26.4s, v4.4s, v7.s[1]\n"
+ "fmla v28.4s, v4.4s, v7.s[2]\n"
+ "fmla v30.4s, v4.4s, v7.s[3]\n"
+ "fmla v25.4s, v5.4s, v7.s[0]\n"
+ "fmla v27.4s, v5.4s, v7.s[1]\n"
+ "fmla v29.4s, v5.4s, v7.s[2]\n"
+ "fmla v31.4s, v5.4s, v7.s[3]\n"
+ "fmla v17.4s, v5.4s, v6.s[0]\n"
+ "fmla v19.4s, v5.4s, v6.s[1]\n"
+ "fmla v21.4s, v5.4s, v6.s[2]\n"
+ "fmla v23.4s, v5.4s, v6.s[3]\n"
+
+ "fmla v16.4s, v8.4s, v10.s[0]\n"
+ "fmla v18.4s, v8.4s, v10.s[1]\n"
+ "fmla v20.4s, v8.4s, v10.s[2]\n"
+ "fmla v22.4s, v8.4s, v10.s[3]\n"
+ "fmla v24.4s, v8.4s, v11.s[0]\n"
+ "fmla v26.4s, v8.4s, v11.s[1]\n"
+ "fmla v28.4s, v8.4s, v11.s[2]\n"
+ "fmla v30.4s, v8.4s, v11.s[3]\n"
+ "fmla v25.4s, v9.4s, v11.s[0]\n"
+ "fmla v27.4s, v9.4s, v11.s[1]\n"
+ "fmla v29.4s, v9.4s, v11.s[2]\n"
+ "fmla v31.4s, v9.4s, v11.s[3]\n"
+ "fmla v17.4s, v9.4s, v10.s[0]\n"
+ "fmla v19.4s, v9.4s, v10.s[1]\n"
+ "fmla v21.4s, v9.4s, v10.s[2]\n"
+ "fmla v23.4s, v9.4s, v10.s[3]\n"
+
+ "fmla v16.4s, v12.4s, v14.s[0]\n"
+ "fmla v18.4s, v12.4s, v14.s[1]\n"
+ "fmla v20.4s, v12.4s, v14.s[2]\n"
+ "fmla v22.4s, v12.4s, v14.s[3]\n"
+ "fmla v24.4s, v12.4s, v15.s[0]\n"
+ "fmla v26.4s, v12.4s, v15.s[1]\n"
+ "fmla v28.4s, v12.4s, v15.s[2]\n"
+ "fmla v30.4s, v12.4s, v15.s[3]\n"
+ "fmla v25.4s, v13.4s, v15.s[0]\n"
+ "fmla v27.4s, v13.4s, v15.s[1]\n"
+ "fmla v29.4s, v13.4s, v15.s[2]\n"
+ "fmla v31.4s, v13.4s, v15.s[3]\n"
+ "fmla v17.4s, v13.4s, v14.s[0]\n"
+ "fmla v19.4s, v13.4s, v14.s[1]\n"
+ "fmla v21.4s, v13.4s, v14.s[2]\n"
+ "fmla v23.4s, v13.4s, v14.s[3]\n"
+
+ "78:\n"
+#endif
+
+ // Accumulation loop
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "add w1, w1, #1\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "cmp w1, w12\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "mov v2.16b, v4.16b\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "blt 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // float accumulator values of the current 8x8 destination block.
+ // We now have to add the bias, apply the clamp bounds, store these
+ // values to the destination, and advance to the next 8x8 block. We
+ // intertwine these two aspects whenever possible for optimal pipelining,
+ // both at the data flow level (prefetch data for the next block as early
+ // as possible) and at the instruction pipelining level (some of the
+ // next-block work can dual-issue with some of the final work on the
+ // current block).
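+
+ // As a rough scalar sketch of the epilogue below (illustrative only;
+ // these variable names are not part of the kernel), each element of the
+ // 8x8 block ends up as
+ //   float v = acc[r][c] + bias[channel];  // channel = row + r or col + c,
+ //                                         // per the channel-dimension flag
+ //   v = std::fmin(std::fmax(v, clamp_min), clamp_max);
+ //   dst(row + r, col + c) = v;  // stored directly, or staged through
+ //                               // dst_tmp_buf when the block only
+ //                               // partly fits the destination.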
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Determine the channel index.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Perform the bias-addition.
+ // Jump based on channel dimension.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v8.4s, v14.s[0]\n"
+ "dup v9.4s, v14.s[1]\n"
+ "dup v10.4s, v14.s[2]\n"
+ "dup v11.4s, v14.s[3]\n"
+ "dup v12.4s, v15.s[0]\n"
+ "dup v13.4s, v15.s[1]\n"
+ "dup v14.4s, v15.s[2]\n"
+ "dup v15.4s, v15.s[3]\n"
+ "fadd v16.4s, v16.4s, v8.4s\n"
+ "fadd v17.4s, v17.4s, v8.4s\n"
+ "fadd v18.4s, v18.4s, v9.4s\n"
+ "fadd v19.4s, v19.4s, v9.4s\n"
+ "fadd v20.4s, v20.4s, v10.4s\n"
+ "fadd v21.4s, v21.4s, v10.4s\n"
+ "fadd v22.4s, v22.4s, v11.4s\n"
+ "fadd v23.4s, v23.4s, v11.4s\n"
+ "fadd v24.4s, v24.4s, v12.4s\n"
+ "fadd v25.4s, v25.4s, v12.4s\n"
+ "fadd v26.4s, v26.4s, v13.4s\n"
+ "fadd v27.4s, v27.4s, v13.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v14.4s\n"
+ "fadd v30.4s, v30.4s, v15.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "7:\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+ // Compute how much of the 8x8 block of destination float values we
+ // have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
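+ // Informal scalar equivalent of the next few instructions (names are
+ // illustrative only):
+ //   rows_fit = std::min(dst_rows - row, 8);  // ends up in w1
+ //   cols_fit = std::min(dst_cols - col, 8);  // ends up in w2
+ // If both are 8, the block is stored straight to the destination;
+ // otherwise it is staged in dst_tmp_buf and copied by the scalar loop
+ // further below.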
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
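+ // Rough C equivalent of the copy loop below (illustrative only; offsets
+ // are in bytes, and the 32-bit lanes being copied hold float values):
+ //   for (int c = 0; c < cols_fit; ++c)
+ //     for (int r = 0; r < rows_fit; ++r)
+ //       memcpy(dst_bytes + c * dst_stride + r * 4,
+ //              tmp_bytes + c * 32 + r * 4, 4);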
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov w1, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Variant of KernelFloatNeon tuned for in-order CPUs that do not
+// support dotprod (while dotprod by itself is not relevant to floating-point,
+// this additional bit of information that we have about the target happens to
+// be useful here).
+//
+// So a typical target CPU here would be ARM Cortex-A53 or the original
+// Cortex-A55.
+//
+// This kernel is similar to and inspired by gemmlowp's
+// NEON_64bit_GEMM_Float32_WithScalar_A53,
+// which was contributed by David Mansell with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A53:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
+void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParamsFloat(params);
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------|
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------|
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+ // v4 is used as a staging register in the accumulation loop; v5 -- v7 are
+ // unused, and v8 -- v15 are used to load parameters for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "cmp w1, #0\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ // Accumulation loop
+ "beq 79f\n"
+
+ "2:\n"
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #8]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ldr x3, [%[lhs_ptr], #24]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "ldr x5, [%[rhs_ptr], #24]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr x4, [%[rhs_ptr], #8]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "subs w1, w1, #1\n"
+ "ldr d0, [%[lhs_ptr]], #32\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ins v0.d[1], x2\n"
+ "ldr d3, [%[rhs_ptr], #16]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "ins v3.d[1], x5\n"
+ "ldr d4, [%[rhs_ptr]], #32\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "ins v4.d[1], x4\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "ins v1.d[1], x3\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ "mov v2.16b, v4.16b\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "bne 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // float accumulator values of the current 8x8 destination block.
+ // We now have to add the bias, apply the clamp bounds, store these
+ // values to the destination, and advance to the next 8x8 block. We
+ // intertwine these two aspects whenever possible for optimal pipelining,
+ // both at the data flow level (prefetch data for the next block as early
+ // as possible) and at the instruction pipelining level (some of the
+ // next-block work can dual-issue with some of the final work on the
+ // current block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Determine the channel index.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Perform the bias-addition.
+ // Jump based on channel dimension.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v8.4s, v14.s[0]\n"
+ "dup v9.4s, v14.s[1]\n"
+ "fadd v16.4s, v16.4s, v8.4s\n"
+ "dup v10.4s, v14.s[2]\n"
+ "fadd v17.4s, v17.4s, v8.4s\n"
+ "dup v11.4s, v14.s[3]\n"
+ "fadd v18.4s, v18.4s, v9.4s\n"
+ "dup v12.4s, v15.s[0]\n"
+ "fadd v19.4s, v19.4s, v9.4s\n"
+ "dup v13.4s, v15.s[1]\n"
+ "fadd v20.4s, v20.4s, v10.4s\n"
+ "dup v14.4s, v15.s[2]\n"
+ "fadd v21.4s, v21.4s, v10.4s\n"
+ "dup v15.4s, v15.s[3]\n"
+ "fadd v22.4s, v22.4s, v11.4s\n"
+ "fadd v23.4s, v23.4s, v11.4s\n"
+ "fadd v24.4s, v24.4s, v12.4s\n"
+ "fadd v25.4s, v25.4s, v12.4s\n"
+ "fadd v26.4s, v26.4s, v13.4s\n"
+ "fadd v27.4s, v27.4s, v13.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v14.4s\n"
+ "fadd v30.4s, v30.4s, v15.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "7:\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+ // Compute how much of the 8x8 block of destination float values we
+ // have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Variant of KernelFloatNeonA55ish tuned for in-order CPUs that do
+// support dotprod (while dotprod by itself is not relevant to floating-point,
+// this additional bit of information that we have about the target happens to
+// be useful here).
+//
+// So a typical target CPU here would be ARM Cortex-A55r1.
+//
+// This kernel is similar to and inspired by gemmlowp's
+// NEON_64bit_GEMM_Float32_WithScalar_A55r1,
+// which was contributed by David Mansell with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A55r1:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
+void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParamsFloat(params);
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------|
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------|
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+ // v4 is used as a staging register in the accumulation loop; v5 -- v7 are
+ // unused, and v8 -- v15 are used to load parameters for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "cmp w1, #0\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ // Accumulation loop
+ "beq 79f\n"
+
+ "2:\n"
+
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #8]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ldr x3, [%[lhs_ptr], #24]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "ldr x5, [%[rhs_ptr], #24]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr d0, [%[lhs_ptr]], #32\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "ldr x4, [%[rhs_ptr], #8]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "subs w1, w1, #1\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "ins v0.d[1], x2\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr d3, [%[rhs_ptr], #16]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "ins v3.d[1], x5\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "ldr d4, [%[rhs_ptr]], #32\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "ins v4.d[1], x4\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "ins v1.d[1], x3\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "mov v2.16b, v4.16b\n"
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "bne 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // float accumulator values of the current 8x8 destination block.
+ // We now have to add the bias, apply the clamp bounds, store these
+ // values to the destination, and advance to the next 8x8 block. We
+ // intertwine these two aspects whenever possible for optimal pipelining,
+ // both at the data flow level (prefetch data for the next block as early
+ // as possible) and at the instruction pipelining level (some of the
+ // next-block work can dual-issue with some of the final work on the
+ // current block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Determine the channel index.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
+
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Perform the bias-addition.
+ // Jump based on channel dimension.
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v8.4s, v14.s[0]\n"
+ "dup v9.4s, v14.s[1]\n"
+ "fadd v16.4s, v16.4s, v8.4s\n"
+ "dup v10.4s, v14.s[2]\n"
+ "fadd v17.4s, v17.4s, v8.4s\n"
+ "dup v11.4s, v14.s[3]\n"
+ "fadd v18.4s, v18.4s, v9.4s\n"
+ "dup v12.4s, v15.s[0]\n"
+ "fadd v19.4s, v19.4s, v9.4s\n"
+ "dup v13.4s, v15.s[1]\n"
+ "fadd v20.4s, v20.4s, v10.4s\n"
+ "dup v14.4s, v15.s[2]\n"
+ "fadd v21.4s, v21.4s, v10.4s\n"
+ "dup v15.4s, v15.s[3]\n"
+ "fadd v22.4s, v22.4s, v11.4s\n"
+ "fadd v23.4s, v23.4s, v11.4s\n"
+ "fadd v24.4s, v24.4s, v12.4s\n"
+ "fadd v25.4s, v25.4s, v12.4s\n"
+ "fadd v26.4s, v26.4s, v13.4s\n"
+ "fadd v27.4s, v27.4s, v13.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v14.4s\n"
+ "fadd v30.4s, v30.4s, v15.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+ "7:\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+ // Compute how much of the 8x8 block of destination float values we
+ // have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_FLAGS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+
+#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
new file mode 100644
index 0000000..2405735
--- /dev/null
+++ b/ruy/kernel_avx.cc
@@ -0,0 +1,1476 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel_common.h"
+#include "ruy/kernel_x86.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX && RUY_OPT(ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
+
+void Kernel8bitAvx(const KernelParams8bit<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+
+static constexpr int kAvx8bitBlockSize = 8;
+static constexpr int kAvx8bitInnerSize = 4;
+
+namespace {
+namespace intrin_utils {
+
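+// Most of the helpers below emulate AVX2 integer intrinsics using only AVX
+// and SSE4 operations, since plain AVX lacks most 256-bit integer
+// instructions. The recurring pattern (an informal sketch, not an actual
+// helper in this file) is to split each 256-bit argument into its two
+// 128-bit lanes, apply the corresponding 128-bit instruction per lane, and
+// reassemble the result:
+//
+//   __m128i lo = op_128(_mm256_extractf128_si256(a, 0), ...);
+//   __m128i hi = op_128(_mm256_extractf128_si256(a, 1), ...);
+//   return _mm256_set_m128i(hi, lo);
+//
+// Here `op_128` stands for the per-lane SSE4 intrinsic being emulated.
+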
+template <>
+inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a,
+ const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo);
+ __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi);
+ return _mm256_set_m128i(dst_hi, dst_lo);
+}
+
+template <>
+inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
+ const int imm) {
+ switch (imm) {
+ case 0:
+ return _mm256_extractf128_si256(a, 0);
+ case 1:
+ return _mm256_extractf128_si256(a, 1);
+ default:
+ RUY_DCHECK_LT(imm, 2);
+ return _mm_setzero_si128();
+ }
+}
+
+template <Path path>
+inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
+ // Take the upper 64 bits of a and put them in the first 64 bits of 'hi'.
+ __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
+ return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
+}
+
+template <Path path>
+inline __m256i mm256_cvtepi32_epi64(const __m128i& a) {
+ // sign extend the 32-bit values in the lower 64 bits of a.
+ __m128i lo = _mm_cvtepi32_epi64(a);
+ __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128()));
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
+ const int imm) {
+ __m128i tmp = _mm_setzero_si128();
+ if (!(imm & 8)) {
+ switch (imm & 3) {
+ case 0:
+ return _mm256_extractf128_si256(a, 0);
+ case 1:
+ return _mm256_extractf128_si256(a, 1);
+ case 2:
+ return _mm256_extractf128_si256(b, 0);
+ case 3:
+ return _mm256_extractf128_si256(b, 1);
+ }
+ }
+ return tmp;
+}
+
+template <Path path>
+inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
+ const int imm) {
+ const int lo_imm = imm & 15;
+ __m128i lo = mm_permute_helper(a, b, lo_imm);
+ const int hi_imm = (imm >> 4) & 15;
+ __m128i hi = mm_permute_helper(a, b, hi_imm);
+ return _mm256_set_m128i(hi, lo);
+}
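+// For example (informal): imm = 0x20 selects the low lane of `a` for the low
+// half of the result and the low lane of `b` for the high half, matching
+// what _mm256_permute2x128_si256(a, b, 0x20) computes on AVX2. A 4-bit
+// selector with bit 3 set (e.g. 0x08) zeroes that half instead.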
+
+template <Path path>
+inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_max_epi32(a_lo, b_lo);
+ __m128i hi = _mm_max_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_min_epi32(a_lo, b_lo);
+ __m128i hi = _mm_min_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_add_epi32(a_lo, b_lo);
+ __m128i hi = _mm_add_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_add_epi64(a_lo, b_lo);
+ __m128i hi = _mm_add_epi64(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_slli_epi64(const __m256i& a, int imm) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i lo = _mm_slli_epi64(a_lo, imm);
+ __m128i hi = _mm_slli_epi64(a_hi, imm);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_mullo_epi32(a_lo, b_lo);
+ __m128i hi = _mm_mullo_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+// Defined as a macro since `imm` must be an immediate.
+#define BlendM128_epi32(a, b, imm) \
+ _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))
+
+// Defined as a macro since `imm` must be an immediate.
+#define BlendM128_epi64(a, b, imm) \
+ _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
+
+// Defined as a macro since `imm` must be an immediate.
+#define mm256_blend_epi32(ans, a, b, imm) \
+ __m128i a_lo = _mm256_extractf128_si256(a, 0); \
+ __m128i a_hi = _mm256_extractf128_si256(a, 1); \
+ __m128i b_lo = _mm256_extractf128_si256(b, 0); \
+ __m128i b_hi = _mm256_extractf128_si256(b, 1); \
+ __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \
+ __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \
+ ans = _mm256_set_m128i(hi, lo);
+
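+// Defined as a macro since `imm` must be an immediate (same reasoning as
+// above); the caller provides `a_lo` and `a_hi` as scratch __m128i lvalues.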
+#define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm) \
+ a_lo = _mm256_extractf128_si256(a, 0); \
+ a_hi = _mm256_extractf128_si256(a, 1); \
+ ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \
+ _mm_shuffle_epi32(a_lo, imm));
+
+template <Path path>
+inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_madd_epi16(a_lo, b_lo);
+ __m128i hi = _mm_madd_epi16(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) {
+ // shift both elements of a by the lower 64 bits of b.
+ __m128i res_lo = _mm_srl_epi64(a, b);
+ // shift both elements of a by the upper 64 bits of b.
+ __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
+ __m128i res_hi = _mm_srl_epi64(a, hi_count);
+ // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
+ // 1. Swap the upper and lower 64 bits of res_hi
+ __m128i tmp_hi =
+ _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
+ // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
+ return _mm_unpacklo_epi64(res_lo, tmp_hi);
+}
+
+template <Path path>
+inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = mm_srlv_epi64(a_lo, b_lo);
+ __m128i hi = mm_srlv_epi64(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) {
+ // shift both elements of a by lower 64bits of b.
+ __m128i res_lo = _mm_sll_epi64(a, b);
+ // shift both elements of a by upper 64bits of b.
+ __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
+ __m128i res_hi = _mm_sll_epi64(a, hi_count);
+  // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
+ // 1. Swap the upper and lower 64 bits of res_hi
+ __m128i tmp_hi =
+ _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
+ // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
+ return _mm_unpacklo_epi64(res_lo, tmp_hi);
+}
+
+template <Path path>
+inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo);
+ __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
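+// Defined as a macro since `imm` must be an immediate.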
+#define PermuteM128_epi32(a, imm) \
+  _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm))
+
+inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
+ // shift all elements of a by first 32bits of b.
+ __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
+
+ // put bits 32-63 of b in the first slot.
+ __m128i tmp1 = PermuteM128_epi32(b, 1);
+ // put bits 32-63 of a in the first slot.
+ __m128i a1 = PermuteM128_epi32(a, 1);
+ // shift all elements of a by second 32bits of b.
+ __m128i res1 =
+ _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));
+
+ // put bits 64-95 of b in the first slot.
+ __m128i tmp2 = PermuteM128_epi32(b, 2);
+ // shift all elements of a by third 32bits of b.
+ __m128i res2 =
+ _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));
+
+ // put bits 96-127 of b in the first slot.
+ __m128i tmp3 = PermuteM128_epi32(b, 3);
+ // put bits 96-127 of a in the third slot.
+ __m128i a3 = PermuteM128_epi32(a, 48);
+ // shift all elements of a3 by fourth 32bits of b.
+ __m128i res3 =
+ _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));
+
+ // Take bits 0-31 of res0, bits 0-31 of res1,
+ // bits 64-95 of res2, and bits 64-95 of res3.
+ // res0 _ _ _ 0
+ // res1 _ _ _ 1
+ // res2 _ 2 _ _
+ // res3 _ 3 _ _
+ // f_01 _ _ 1 0
+ // f_23 _ _ 3 2
+ __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
+ __m128i f_23 = _mm_unpackhi_epi32(res2, res3);
+  // The lower 64 bits of f_01 and the lower 64 bits of f_23.
+ return _mm_unpacklo_epi64(f_01, f_23);
+}
+
+template <Path path>
+inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = mm_sllv_epi32(a_lo, b_lo);
+ __m128i hi = mm_sllv_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_sub_epi32(a_lo, b_lo);
+ __m128i hi = _mm_sub_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_mul_epi32(a_lo, b_lo);
+ __m128i hi = _mm_mul_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+// Perform the equivalent of mm256_permutevar8x32 with
+// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
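+// For example, 32-bit lanes (lo->hi) {a0, a1, a2, a3, a4, a5, a6, a7} are
+// reordered to {a0, a2, a4, a6, a1, a3, a5, a7}.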
+template <Path path>
+inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
+ // a_lo = 3 2 1 0
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ // a_hi = 7 6 5 4
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ // shuffle a_lo to get 3 1 2 0
+ __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
+ // shuffle a_hi to get 7 5 6 4
+ __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
+  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
+  __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
+  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
+  __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
+ return _mm256_set_m128i(res_hi, res_lo);
+}
+
+template <Path path>
+inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
+ const __m256i bias0 = _mm256_set1_epi32(*(bias + offset));
+ return mm256_add_epi32<path>(a, bias0);
+}
+
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256 result =
+ _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+ _mm256_castsi256_ps(mask));
+ return _mm256_castps_si256(result);
+}
+
+template <Path path>
+inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
+ __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+
+ __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
+ __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
+ __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
+ __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
+ __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
+ __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
+ __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
+ __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));
+
+ // get element 0 from r0, element 1 from r1
+ __m128i r01 = BlendM128_epi32(r0, r1, 2);
+ // get element 2 from r2, element 3 from r3
+ __m128i r23 = BlendM128_epi32(r2, r3, 8);
+ // get element 0 from r4, element 1 from r5
+ __m128i r45 = BlendM128_epi32(r4, r5, 2);
+ // get element 2 from r6, element 3 from r7
+ __m128i r67 = BlendM128_epi32(r6, r7, 8);
+ // get lower 64 bits of r01, upper 64 bits of r23
+ __m128i r0123 = BlendM128_epi64(r01, r23, 2);
+ // get lower 64 bits of r45, upper 64 bits of r67
+ __m128i r4567 = BlendM128_epi64(r45, r67, 2);
+ return _mm256_set_m128i(r4567, r0123);
+}
+
+// AVX doesn't have fused multiply-add so we define an inline function to be
+// used in the common code following.
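+// Note that, unlike a true fused multiply-add, this rounds the intermediate
+// product, so results may differ in the last bit from the FMA-based paths.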
+template <>
+inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b,
+ const __m256& c) {
+ const __m256 prod = _mm256_mul_ps(a, b);
+ return _mm256_add_ps(prod, c);
+}
+
+} // namespace intrin_utils
+} // namespace
+
+template <Path path>
+void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx 8-bit");
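+  // Shuffle indices that, within each 128-bit lane, gather bytes
+  // {0, 1, 4, 5, 8, 9, 12, 13} into the lower half and bytes
+  // {2, 3, 6, 7, 10, 11, 14, 15} into the upper half, so that each half can
+  // be sign-extended to 16 bits and fed to 16-bit multiply-adds below.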
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
+
+ std::int32_t dst_stride = 0;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ int channel =
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
+ int multiplier_channel =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx = _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+ __m256i accum_data_v1;
+ __m256i accum_data_v2;
+ __m256i accum_data_v3;
+ __m256i accum_data_v4;
+ __m256i accum_data_v5;
+ __m256i accum_data_v6;
+ __m256i accum_data_v7;
+
+      // initial_accum_data will be the initial value of each of the
+      // accum_data_* accumulator registers. We compute into it the terms
+      // that are identical across columns.
+ __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth);
+ __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth);
+
+ // In the channels-are-rows case, we can load bias here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ initial_accum_data_lo = _mm_add_epi32(
+ initial_accum_data_lo,
+ _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(params.bias + row)));
+ initial_accum_data_hi = _mm_add_epi32(
+ initial_accum_data_hi,
+ _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(params.bias + row + 4)));
+ }
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point);
+ const __m128i lhs_sums_offset_lo = _mm_mullo_epi32(
+ rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
+ &params.lhs_sums[row])));
+ const __m128i lhs_sums_offset_hi = _mm_mullo_epi32(
+ rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
+ &params.lhs_sums[row + 4])));
+
+ initial_accum_data_lo =
+ _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo);
+ initial_accum_data_hi =
+ _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi);
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ __m256i initial_accum_data =
+ _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
+
+ accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
+ accum_data_v1 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
+ accum_data_v2 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
+ accum_data_v3 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
+ accum_data_v4 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
+ accum_data_v5 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
+ accum_data_v6 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
+ accum_data_v7 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
+ } else {
+ __m256i initial_accum_data =
+ _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
+
+ // Finally, in the channels-are-columns case, load bias data here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0,
+ params.bias + col, 0);
+ accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1,
+ params.bias + col, 1);
+ accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2,
+ params.bias + col, 2);
+ accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3,
+ params.bias + col, 3);
+ accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4,
+ params.bias + col, 4);
+ accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5,
+ params.bias + col, 5);
+ accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6,
+ params.bias + col, 6);
+ accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7,
+ params.bias + col, 7);
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m256i rhs_data_8bit =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ std::int32_t rhs_data[16];
+ const __m128i rhs_data_bottom_lane =
+ _mm256_castsi256_si128(rhs_data_8bit);
+ const __m128i rhs_data_top_lane =
+ _mm256_extractf128_si256(rhs_data_8bit, 1);
+ const __m256i rhs_16_bit_dup_low =
+ intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane);
+ const __m256i rhs_16_bit_dup_high =
+ intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
+ rhs_16_bit_dup_low);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
+ rhs_16_bit_dup_high);
+
+ // NOTE: There may be opportunities for permuting the data in the
+ // packing code instead of here.
+ const __m256i lhs_data_split =
+ intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ intrin_utils::mm256_cvtepi8_epi16<path>(
+ _mm256_extractf128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ intrin_utils::mm256_cvtepi8_epi16<path>(
+ _mm256_extractf128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low =
+ intrin_utils::mm256_permute2x128_si256<path>(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high =
+ intrin_utils::mm256_permute2x128_si256<path>(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+
+ __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
+ rhs_data)); // Load [0 1 2 3 4 5 6 7]
+ __m256i rhs1 = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15]
+ __m256i rhs0_3 =
+ _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3]
+ __m256i rhs4_7 =
+ _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7]
+ __m256i rhs8_11 =
+ _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11]
+ __m256i rhs12_15 =
+ _mm256_permute2f128_si256(rhs1, rhs1, 17); // [12 - 15, 12 - 15]
+
+ auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
+ __m256i& accum) {
+ // Perform mul-adds on low and high components of accum separately.
+ __m128i accum_lo = _mm256_extractf128_si256(accum, 0);
+ __m128i accum_hi = _mm256_extractf128_si256(accum, 1);
+
+ __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0);
+ __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1);
+ __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0);
+ __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1);
+ __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0);
+ __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1);
+
+ accum_lo = _mm_add_epi32(accum_lo, lo_0);
+ accum_hi = _mm_add_epi32(accum_hi, lo_1);
+
+ __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0);
+ __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1);
+ __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0);
+ __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1);
+ __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0);
+ __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1);
+
+ accum_lo = _mm_add_epi32(accum_lo, hi_0);
+ accum_hi = _mm_add_epi32(accum_hi, hi_1);
+ accum = _mm256_set_m128i(accum_hi, accum_lo);
+ };
+ __m256i tmp0, tmp1, tmp2, tmp3;
+ __m128i lo0, lo1, hi0, hi1;
+ mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0);
+ mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55);
+ process_column(tmp0, tmp1, accum_data_v0);
+ mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa);
+ mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff);
+ process_column(tmp2, tmp3, accum_data_v1);
+
+ mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0);
+ mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55);
+ process_column(tmp0, tmp1, accum_data_v2);
+ mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa);
+ mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff);
+ process_column(tmp2, tmp3, accum_data_v3);
+
+ mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0);
+ mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55);
+ process_column(tmp0, tmp1, accum_data_v4);
+ mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa);
+ mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff);
+ process_column(tmp2, tmp3, accum_data_v5);
+
+ mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0);
+ mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55);
+ process_column(tmp0, tmp1, accum_data_v6);
+ mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa);
+ mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff);
+ process_column(tmp2, tmp3, accum_data_v7);
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_fixedpoint + multiplier_channel));
+ e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_exponent + multiplier_channel));
+
+ const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(m_vector, 0));
+ const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift =
+ intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
+ const __m256i neg_e_vector =
+ intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
+ const __m256i right_shift =
+ intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
+ const __m256i final_right_shift_low =
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high =
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(final_right_shift, 1));
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+
+        // A "half" is added for rounding prior to truncating the 64-bit value.
+ const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+ intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
+ }
+
+ // We cannot do
+ //
+ // scaled_v_low =
+ // _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
+ // scaled_v_high =
+ // _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
+ //
+        // since this instruction is only available with AVX-512. Instead we
+        // use the unsigned shift mm256_srlv_epi64. To keep the result correct
+        // we add an offset of 1 << 63 beforehand (convert_to_unsigned_64);
+        // after the logical shift by 31 that offset sits at bit 32, so it is
+        // discarded when only the low 32 bits of each 64-bit lane are kept in
+        // the repacking below.
+        //
+        // The overall process for each 64-bit scaled accumulator is:
+        //   unsigned_accum = signed_accum + rounding_half + (1 << 63);
+        //   unsigned_accum = unsigned_accum >> 31;  // logical shift
+        //   result = low 32 bits of unsigned_accum;
+
+ // There are various ways to repack the results, in the absence of
+ // _mm256_cvtepi64_epi32() or anything like it.
+ // A.
+ // accum_data_v[j] =
+ // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
+ // _mm256_extract_epi32(scaled_v_high, 4),
+ // _mm256_extract_epi32(scaled_v_high, 2),
+ // _mm256_extract_epi32(scaled_v_high, 0),
+ // _mm256_extract_epi32(scaled_v_low, 6),
+ // _mm256_extract_epi32(scaled_v_low, 4),
+ // _mm256_extract_epi32(scaled_v_low, 2),
+ // _mm256_extract_epi32(scaled_v_low, 0));
+ // B.
+ // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
+ // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
+ // accum_data_v[j] =
+ // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
+ // _mm256_extract_epi64(scaled_v_high, 0),
+ // _mm256_extract_epi64(scaled_v_low, 2),
+ // _mm256_extract_epi64(scaled_v_low, 0));
+ // C.
+ // scaled_v_low =
+ // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
+ // scaled_v_high =
+ // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
+ // accum_data_v[j] =
+ // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
+ //
+ // However, we choose the following because it uses two lighter
+ // instructions. The permutation does have a longer latency, but this
+ // loop can be unrolled.
+ // D.
+ // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ // __m256i results =
+ // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ // results = _mm256_permutevar8x32_epi32(results, repack_perm);
+ // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
+ // post_scaling_offset);
+
+ // This multiplier code is complex and expensive enough on x86, that
+ // we prefer to implement the channels-are-columns case by transposing
+ // around it, rather than duplicate it (which would also require
+ // duplicating the above code computing the multiplier constants).
+ // This is one instance where channels-are-columns has lower performance
+ // than channels-are-rows.
+ const bool transpose_around_multiplier =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+ if (transpose_around_multiplier) {
+          // Transpose the 8x8 block of accumulators. It will be un-transposed
+          // below, after the multiplier implementation.
+ intrin_utils::mm256_transpose8x8_epi32<path>(
+ &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+ &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+ }
+
+ auto rounding_right_shift = [=](__m256i& results,
+ const __m256i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
+ const __m256i one_shift_exp_minus1 =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ exponent, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge =
+ intrin_utils::mm256_add_epi32<path>(results, nudge);
+ const __m256i shifted_sum =
+ intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(31), exponent));
+ const __m256i mask_num_plus_nudge_overflow =
+ intrin_utils::mm256_cmpgt_epi32<path>(
+ results, intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
+
+ auto apply_multiplier = [=](__m256i& accum) {
+ __m256i shifted_accum =
+ intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(shifted_accum, 1)),
+ m_64bit_high);
+ scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
+ offset_vector);
+ scaled_v_high = intrin_utils::mm256_add_epi64<path>(
+ scaled_v_high, offset_vector);
+
+ scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
+ scaled_v_low, final_right_shift_low);
+ scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
+ scaled_v_high, final_right_shift_high);
+ scaled_v_high =
+ intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
+ __m256i results;
+ mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
+ // Permute results to this ordering of int32 elements
+ // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
+ results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+
+ rounding_right_shift(results, right_shift);
+ accum =
+ intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
+ };
+ apply_multiplier(accum_data_v0);
+ apply_multiplier(accum_data_v1);
+ apply_multiplier(accum_data_v2);
+ apply_multiplier(accum_data_v3);
+ apply_multiplier(accum_data_v4);
+ apply_multiplier(accum_data_v5);
+ apply_multiplier(accum_data_v6);
+ apply_multiplier(accum_data_v7);
+ // See above comment: here we transpose again to undo the transposition
+ // of the 8x8 block of accumulators used to implement the
+ // channels-are-columns case.
+ if (transpose_around_multiplier) {
+ intrin_utils::mm256_transpose8x8_epi32<path>(
+ &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+ &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
+ __m256i accum_data_v[kAvx8bitBlockSize];
+ if (!store_full_block) {
+ accum_data_v[0] = accum_data_v0;
+ accum_data_v[1] = accum_data_v1;
+ accum_data_v[2] = accum_data_v2;
+ accum_data_v[3] = accum_data_v3;
+ accum_data_v[4] = accum_data_v4;
+ accum_data_v[5] = accum_data_v5;
+ accum_data_v[6] = accum_data_v6;
+ accum_data_v[7] = accum_data_v7;
+ }
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
+ accum_data_v0 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
+ accum_data_v1 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
+ accum_data_v1 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
+ accum_data_v2 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
+ accum_data_v2 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
+ accum_data_v3 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
+ accum_data_v3 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
+ accum_data_v4 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
+ accum_data_v4 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
+ accum_data_v5 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
+ accum_data_v5 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
+ accum_data_v6 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
+ accum_data_v6 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
+ accum_data_v7 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
+ accum_data_v7 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[0 * dst_stride], accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[1 * dst_stride], accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
+ accum_data_v0 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
+ accum_data_v1 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
+ accum_data_v1 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
+ accum_data_v2 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
+ accum_data_v2 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
+ accum_data_v3 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
+ accum_data_v3 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
+ accum_data_v4 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
+ accum_data_v4 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
+ accum_data_v5 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
+ accum_data_v5 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
+ accum_data_v6 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
+ accum_data_v6 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
+ accum_data_v7 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
+ accum_data_v7 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
+ accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
+ accum_data_v0 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
+ accum_data_v1 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
+ accum_data_v1 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
+ accum_data_v2 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
+ accum_data_v2 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
+ accum_data_v3 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
+ accum_data_v3 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
+ accum_data_v4 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
+ accum_data_v4 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
+ accum_data_v5 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
+ accum_data_v5 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
+ accum_data_v6 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
+ accum_data_v6 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
+ accum_data_v7 =
+ intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
+ accum_data_v7 =
+ intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
+ accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ intrin_utils::mm256_n_storeu_epi32<path>(
+ dst_block_ptr, residual_rows, accum_data_v[j]);
+ dst_block_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) {
+ Kernel8bitAvxImpl<Path::kAvx>(params);
+}
+
+template <Path path>
+void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
+  profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx =
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+
+ // Initialize with bias.
+ __m256i initial_accum_data =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
+ bias_ptr += bias_ptr_block_increment;
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, lhs_sums_offset);
+ }
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth) {
+ initial_accum_data = intrin_utils::mm256_add_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we process more data than we need and store only the
+      // data that we actually use.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+ // NOTE: There may be opportunities for permuting the data in the packing
+ // code instead of here.
+ const __m256i lhs_data_split =
+ intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ intrin_utils::mm256_cvtepi8_epi16<path>(
+ _mm256_extractf128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ intrin_utils::mm256_cvtepi8_epi16<path>(
+ _mm256_extractf128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low =
+ intrin_utils::mm256_permute2x128_si256<path>(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high =
+ intrin_utils::mm256_permute2x128_si256<path>(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+ // Accumulate for column 0.
+ const std::int32_t low_rhs_value = rhs_data[0];
+ const std::int32_t high_rhs_value = rhs_data[1];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
+ accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
+ lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
+ accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
+ lhs_16_bit_high, rhs_16_bit_dup_high));
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_fixedpoint + channel));
+ e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_exponent + channel));
+
+ const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(m_vector, 0));
+ const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift =
+ intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
+ const __m256i neg_e_vector =
+ intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
+ const __m256i right_shift =
+ intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
+ const __m256i final_right_shift_low =
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high =
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(final_right_shift, 1));
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+
+      // A "half" is added for rounding prior to truncating the 64-bit value.
+ const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+ intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
+ }
+
+ // See GEMM version for details of this process.
+ {
+ __m256i shifted_accum =
+ intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
+ intrin_utils::mm256_cvtepi32_epi64<path>(
+ _mm256_extractf128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
+ offset_vector);
+ scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
+ offset_vector);
+
+ scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
+ scaled_v_low, final_right_shift_low);
+ scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
+ scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
+ __m256i results;
+ mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
+ // Permute results to this ordering of int32 elements
+ // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
+ results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+
+ // Now perform the Rounding Right Shift.
+ // First, construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
+ const __m256i one_shift_exp_minus1 =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ right_shift, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge =
+ intrin_utils::mm256_add_epi32<path>(results, nudge);
+ const __m256i shifted_sum =
+ intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(31), right_shift));
+ const __m256i mask_num_plus_nudge_overflow =
+ intrin_utils::mm256_cmpgt_epi32<path>(
+ results, intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ accum_data_v0 =
+ intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
+ result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
+ accum_data_v0);
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
+ Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
+}
+
+void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx float");
+ KernelFloatAvxCommon<Path::kAvx>(params);
+}
+
+void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx float GEMV");
+ KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
+}
+
+#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
new file mode 100644
index 0000000..eae333c
--- /dev/null
+++ b/ruy/kernel_avx2_fma.cc
@@ -0,0 +1,1011 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel_common.h"
+#include "ruy/kernel_x86.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+
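+// The 8-bit kernels below work on 8x8 blocks of the destination
+// (kAvx8bitBlockSize) and consume the depth dimension 4 int8 values at a time
+// (kAvx8bitInnerSize).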
+static constexpr int kAvx8bitBlockSize = 8;
+static constexpr int kAvx8bitInnerSize = 4;
+
+namespace {
+namespace intrin_utils {
+
+template <>
+inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a,
+ const __m256i& b) {
+ return _mm256_shuffle_epi8(a, b);
+}
+
+// Make an inline function for FMA so we can share the float kernels
+// with non-FMA code.
+template <>
+inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
+ const __m256& c) {
+ return _mm256_fmadd_ps(a, b, c);
+}
+
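+// _mm256_extracti128_si256 requires an immediate index, so dispatch on the
+// runtime `imm` value with a switch.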
+template <>
+inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
+ const int imm) {
+ switch (imm) {
+ case 0:
+ return _mm256_extracti128_si256(a, 0);
+ case 1:
+ return _mm256_extracti128_si256(a, 1);
+ default:
+ RUY_DCHECK_LT(imm, 2);
+ return _mm_setzero_si128();
+ }
+}
+
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256 result =
+ _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+ _mm256_castsi256_ps(mask));
+ return _mm256_castps_si256(result);
+}
+
+} // namespace intrin_utils
+} // namespace
+
+template <Path path>
+void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
+
+ std::int32_t dst_stride = 0;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ int channel =
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
+ int multiplier_channel =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx = _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+ __m256i accum_data_v1;
+ __m256i accum_data_v2;
+ __m256i accum_data_v3;
+ __m256i accum_data_v4;
+ __m256i accum_data_v5;
+ __m256i accum_data_v6;
+ __m256i accum_data_v7;
+
+      // initial_accum_data will be the initial value of each of the
+      // accum_data_* accumulator registers. We compute into it the terms
+      // that are identical across columns.
+ __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth);
+
+ // In the channels-are-rows case, we can load bias here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ initial_accum_data = _mm256_add_epi32(
+ initial_accum_data,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(params.bias + row)));
+ }
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data =
+ _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
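+      // Together, these terms implement the usual zero-point expansion
+      //   (lhs - lhs_zp) * (rhs - rhs_zp)
+      //     = lhs*rhs - rhs_zp*sum(lhs) - lhs_zp*sum(rhs)
+      //       + depth*lhs_zp*rhs_zp.
+      // initial_accum_data now holds prod_zp_depth (plus the bias, when
+      // loaded above) minus the per-row term rhs_zp * lhs_sums[row]; the
+      // per-column term lhs_zp * rhs_sums[col] is subtracted next via
+      // rhs_sums_offsets.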
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
+ accum_data_v1 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
+ accum_data_v2 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
+ accum_data_v3 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
+ accum_data_v4 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
+ accum_data_v5 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
+ accum_data_v6 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
+ accum_data_v7 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
+
+ // Finally, in the channels-are-columns case, load bias data here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ const __m256i bias_data = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(params.bias + col));
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0)));
+ accum_data_v1 = _mm256_add_epi32(
+ accum_data_v1,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1)));
+ accum_data_v2 = _mm256_add_epi32(
+ accum_data_v2,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2)));
+ accum_data_v3 = _mm256_add_epi32(
+ accum_data_v3,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3)));
+ accum_data_v4 = _mm256_add_epi32(
+ accum_data_v4,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4)));
+ accum_data_v5 = _mm256_add_epi32(
+ accum_data_v5,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5)));
+ accum_data_v6 = _mm256_add_epi32(
+ accum_data_v6,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6)));
+ accum_data_v7 = _mm256_add_epi32(
+ accum_data_v7,
+ _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7)));
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m256i rhs_data_8bit =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ std::int32_t rhs_data[16];
+ const __m128i rhs_data_bottom_lane =
+ _mm256_castsi256_si128(rhs_data_8bit);
+ const __m128i rhs_data_top_lane =
+ _mm256_extracti128_si256(rhs_data_8bit, 1);
+ const __m256i rhs_16_bit_dup_low =
+ _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m256i rhs_16_bit_dup_high =
+ _mm256_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
+ rhs_16_bit_dup_low);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
+ rhs_16_bit_dup_high);
+
+ const __m256i lhs_data_split =
+ _mm256_shuffle_epi8(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+
+ __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
+ rhs_data)); // Load [0 1 2 3 4 5 6 7]
+ __m256i rhs1 = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15]
+ __m256i rhs0_3 =
+ _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3]
+ __m256i rhs4_7 =
+ _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7]
+ __m256i rhs8_11 =
+ _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11]
+ __m256i rhs12_15 =
+            _mm256_permute2f128_si256(rhs1, rhs1, 0x11);  // [12 - 15, 12 - 15]
+
+ auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
+ __m256i& accum) {
+ accum = _mm256_add_epi32(
+ accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo));
+ accum = _mm256_add_epi32(
+ accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi));
+ };
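+        // Each rhs_dup_* argument broadcasts one 32-bit element holding a
+        // pair of sign-extended 8-bit RHS values (two consecutive depth
+        // levels). _mm256_madd_epi16 multiplies each adjacent pair of 16-bit
+        // LHS values by that RHS pair and sums the two products into one
+        // 32-bit lane, so a single process_column call accumulates all
+        // kAvx8bitInnerSize depth levels of this iteration for one
+        // destination column.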
+ __m256i tmp0, tmp1, tmp2, tmp3;
+ tmp0 = _mm256_shuffle_epi32(rhs0_3, 0);
+ tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55);
+ process_column(tmp0, tmp1, accum_data_v0);
+ tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa);
+ tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff);
+ process_column(tmp2, tmp3, accum_data_v1);
+
+ tmp0 = _mm256_shuffle_epi32(rhs4_7, 0);
+ tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55);
+ process_column(tmp0, tmp1, accum_data_v2);
+ tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa);
+ tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff);
+ process_column(tmp2, tmp3, accum_data_v3);
+
+ tmp0 = _mm256_shuffle_epi32(rhs8_11, 0);
+ tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55);
+ process_column(tmp0, tmp1, accum_data_v4);
+ tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa);
+ tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff);
+ process_column(tmp2, tmp3, accum_data_v5);
+
+ tmp0 = _mm256_shuffle_epi32(rhs12_15, 0);
+ tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55);
+ process_column(tmp0, tmp1, accum_data_v6);
+ tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa);
+ tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff);
+ process_column(tmp2, tmp3, accum_data_v7);
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_fixedpoint + multiplier_channel));
+ e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_exponent + multiplier_channel));
+
+ const __m256i m_64bit_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
+ const __m256i m_64bit_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
+ const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
+ const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
+ const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(final_right_shift, 1));
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = _mm256_add_epi64(
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
+ }
+
+ const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+
+ // We cannot do
+ //
+ // scaled_v_low =
+ // _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
+ // scaled_v_high =
+ // _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
+ //
+      // since this instruction is not in AVX2. Instead we use
+      // _mm256_srlv_epi64, which is an unsigned shift, so we add an offset
+      // of 1 << 63 to each 64-bit lane beforehand (convert_to_unsigned_64,
+      // folded into offset_vector together with the rounding "half").
+      //
+      // The overall process for each 64-bit scaled accumulator is:
+      //   unsigned_accum = signed_accum + (1 << 63);
+      //   unsigned_accum = (unsigned_accum + (1 << 30)) >> 31;
+      // after which the offset has become exactly 1 << 32, and it is
+      // discarded when only the low 32 bits of each 64-bit lane are kept
+      // during the repacking below, leaving the correctly rounded signed
+      // result.
+
+ // There are various ways to repack the results, in the absence of
+ // _mm256_cvtepi64_epi32() or anything like it.
+ // A.
+ // accum_data_v[j] =
+ // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
+ // _mm256_extract_epi32(scaled_v_high, 4),
+ // _mm256_extract_epi32(scaled_v_high, 2),
+ // _mm256_extract_epi32(scaled_v_high, 0),
+ // _mm256_extract_epi32(scaled_v_low, 6),
+ // _mm256_extract_epi32(scaled_v_low, 4),
+ // _mm256_extract_epi32(scaled_v_low, 2),
+ // _mm256_extract_epi32(scaled_v_low, 0));
+ // B.
+ // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
+ // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
+ // accum_data_v[j] =
+ // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
+ // _mm256_extract_epi64(scaled_v_high, 0),
+ // _mm256_extract_epi64(scaled_v_low, 2),
+ // _mm256_extract_epi64(scaled_v_low, 0));
+ // C.
+ // scaled_v_low =
+ // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
+ // scaled_v_high =
+ // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
+ // accum_data_v[j] =
+ // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
+ //
+ // However, we choose the following because it uses two lighter
+ // instructions. The permutation does have a longer latency, but this
+ // loop can be unrolled.
+ // D.
+ // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ // __m256i results =
+ // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ // results = _mm256_permutevar8x32_epi32(results, repack_perm);
+ // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);
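+      // Concretely: after the 64-bit shifts, scaled_v_low holds the results
+      // for lanes 0..3 in its even 32-bit positions and scaled_v_high, once
+      // shifted left by 32, holds lanes 4..7 in its odd positions. The 0xaa
+      // blend interleaves them as [r0 r4 r1 r5 r2 r6 r3 r7] and repack_perm
+      // (0, 2, 4, 6, 1, 3, 5, 7) restores the order [r0 r1 r2 ... r7].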
+
+      // This multiplier code is complex and expensive enough on x86 that
+ // we prefer to implement the channels-are-columns case by transposing
+ // around it, rather than duplicate it (which would also require
+ // duplicating the above code computing the multiplier constants).
+ // This is one instance where channels-are-columns has lower performance
+ // than channels-are-rows.
+ const bool transpose_around_multiplier =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+ if (transpose_around_multiplier) {
+        // Transpose the 8x8 block of accumulators. It will be un-transposed
+        // below, after the multiplier implementation.
+ intrin_utils::mm256_transpose8x8_epi32<path>(
+ &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+ &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+ }
+
+ auto rounding_right_shift = [=](__m256i& results,
+ const __m256i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ _mm256_cmpgt_epi32(exponent, zeros);
+ const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+ const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
+ const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+ results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
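+      // Example (illustrative values): with exponent == 1, results == 5
+      // gives nudge == 1 and (5 + 1) >> 1 == 3 (nearest, halves rounded up),
+      // while results == 0x7fffffff would overflow the nudged add, so the
+      // mask selects 1 << 30 instead, the exact value of 0x80000000 >> 1.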
+
+ auto apply_multiplier = [=](__m256i& accum) {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+ // Now do a Rounding Right Shift.
+ rounding_right_shift(results, right_shift);
+ accum = _mm256_add_epi32(results, post_scaling_offset);
+ };
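+      // Net effect per 32-bit lane, with the exponent e split into
+      // left_shift = max(e, 0) and right_shift = max(-e, 0):
+      //   accum = rounding_right_shift(
+      //               round((accum << left_shift) * multiplier_fixedpoint
+      //                     / 2^31),
+      //               right_shift) + post_scaling_offset.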
+ apply_multiplier(accum_data_v0);
+ apply_multiplier(accum_data_v1);
+ apply_multiplier(accum_data_v2);
+ apply_multiplier(accum_data_v3);
+ apply_multiplier(accum_data_v4);
+ apply_multiplier(accum_data_v5);
+ apply_multiplier(accum_data_v6);
+ apply_multiplier(accum_data_v7);
+ // See above comment: here we transpose again to undo the transposition
+ // of the 8x8 block of accumulators used to implement the
+ // channels-are-columns case.
+ if (transpose_around_multiplier) {
+ intrin_utils::mm256_transpose8x8_epi32<path>(
+ &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+ &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
+ __m256i accum_data_v[kAvx8bitBlockSize];
+ if (!store_full_block) {
+ accum_data_v[0] = accum_data_v0;
+ accum_data_v[1] = accum_data_v1;
+ accum_data_v[2] = accum_data_v2;
+ accum_data_v[3] = accum_data_v3;
+ accum_data_v[4] = accum_data_v4;
+ accum_data_v[5] = accum_data_v5;
+ accum_data_v[6] = accum_data_v6;
+ accum_data_v[7] = accum_data_v7;
+ }
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[0 * dst_stride], accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[1 * dst_stride], accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
+ accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
+ accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[2 * dst_stride], accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[3 * dst_stride], accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[4 * dst_stride], accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[5 * dst_stride], accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[6 * dst_stride], accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
+ &tmp_ptr[7 * dst_stride], accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
+ tmp_ptr, residual_rows, result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ intrin_utils::mm256_n_storeu_epi32<path>(
+ dst_block_ptr, residual_rows, accum_data_v[j]);
+ dst_block_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
+ Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
+}
+
+template <Path path>
+void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
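+  // GEMV variant: the destination has a single column, so only one
+  // accumulator register per row block is needed and the RHS is read a few
+  // bytes at a time rather than as a full 8-column block.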
+
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx =
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+
+ // Initialize with bias.
+ __m256i initial_accum_data =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
+ bias_ptr += bias_ptr_block_increment;
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data =
+ _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth) {
+ initial_accum_data = _mm256_add_epi32(initial_accum_data,
+ _mm256_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
+ _mm256_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we load 4x the data that we need, process twice the
+      // data that we need, and store only the data that we need.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+ // NOTE: There may be opportunities for permuting the data in the packing
+ // code instead of here.
+ const __m256i lhs_data_split =
+ _mm256_shuffle_epi8(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+ // Accumulate for column 0.
+ const std::int32_t low_rhs_value = rhs_data[0];
+ const std::int32_t high_rhs_value = rhs_data[1];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_fixedpoint + channel));
+ e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ params.multiplier_exponent + channel));
+
+ const __m256i m_64bit_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
+ const __m256i m_64bit_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
+ const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
+ const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
+ const __m256i final_right_shift_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = _mm256_add_epi64(
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
+ }
+
+ const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+
+ // See GEMM version for details of this process.
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ // Now do a Rounding Right Shift.
+ // First, construct the nudge value for each lane.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ _mm256_cmpgt_epi32(right_shift, zeros);
+ const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+ const __m256i shifted_sum =
+ _mm256_srav_epi32(r_plus_nudge, right_shift);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
+ const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+ results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+
+ accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
+ accum_data_v0);
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
+ Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params);
+}
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2Fma float");
+ KernelFloatAvxCommon<Path::kAvx2Fma>(params);
+}
+
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
+ KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
+}
+
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
new file mode 100644
index 0000000..fddb482
--- /dev/null
+++ b/ruy/kernel_avx512.cc
@@ -0,0 +1,1550 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel_x86.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+
+namespace {
+namespace intrin_utils {
+
+__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256d result =
+ _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
+ _mm256_castsi256_pd(mask));
+ return _mm256_castpd_si256(result);
+}
+
+__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
+ const __m512i& mask) {
+ __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+ __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+ __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+ __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+ __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
+ __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
+ __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
+ __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
+ __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+ return _mm512_inserti64x4(result, hi, 1);
+}
+
+__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
+ __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+ __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+ __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+ __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+ __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
+ __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
+ __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+ return _mm512_inserti64x4(result, hi, 1);
+}
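+// These helpers emulate 64-bit vector compares and blends that produce and
+// consume vector-typed masks. The native AVX-512 compares return __mmask8
+// mask registers instead, so the 512-bit inputs are split into 256-bit
+// halves and handled with AVX2 intrinsics, keeping the vector-mask style
+// used by the callers below.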
+
+} // namespace intrin_utils
+} // namespace
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 8-bit");
+
+ std::int32_t dst_stride = 0;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+
+ for (int col = params.start_col; col <= params.last_col; col += 16) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_si512(&params.rhs_sums[col]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ int channel =
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
+ int multiplier_channel =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
+
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+ const int residual_cols = std::min(params.dst_cols - col, 16);
+
+ __m512i accum_data_v0;
+ __m512i accum_data_v1;
+ __m512i accum_data_v2;
+ __m512i accum_data_v3;
+ __m512i accum_data_v4;
+ __m512i accum_data_v5;
+ __m512i accum_data_v6;
+ __m512i accum_data_v7;
+ __m512i accum_data_v8;
+ __m512i accum_data_v9;
+ __m512i accum_data_va;
+ __m512i accum_data_vb;
+ __m512i accum_data_vc;
+ __m512i accum_data_vd;
+ __m512i accum_data_ve;
+ __m512i accum_data_vf;
+
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
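+      // A 16-bit mask with the low residual_rows bits set, used by the
+      // masked stores at the end of this row block to avoid writing past
+      // the last valid row of a partial block.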
+
+      // initial_accum_data will be used to initialize each of the
+      // accum_data_* accumulator registers. We compute into it the terms
+      // that are identical across columns.
+ __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);
+
+ // In the channels-are-rows case, we can load bias here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ initial_accum_data = _mm512_add_epi32(
+ initial_accum_data,
+ _mm512_loadu_si512(
+ reinterpret_cast<const __m512i*>(params.bias + row)));
+ }
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_si512(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
+ accum_data_v1 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
+ accum_data_v2 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
+ accum_data_v3 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
+ accum_data_v4 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
+ accum_data_v5 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
+ accum_data_v6 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
+ accum_data_v7 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
+ accum_data_v8 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
+ accum_data_v9 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
+ accum_data_va = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
+ accum_data_vb = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
+ accum_data_vc = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
+ accum_data_vd = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
+ accum_data_ve = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
+ accum_data_vf = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ accum_data_v8 = initial_accum_data;
+ accum_data_v9 = initial_accum_data;
+ accum_data_va = initial_accum_data;
+ accum_data_vb = initial_accum_data;
+ accum_data_vc = initial_accum_data;
+ accum_data_vd = initial_accum_data;
+ accum_data_ve = initial_accum_data;
+ accum_data_vf = initial_accum_data;
+ }
+
+ // Finally, in the channels-are-columns case, load bias data here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ const __m512i bias_data = _mm512_loadu_si512(
+ reinterpret_cast<const __m512i*>(params.bias + col));
+ accum_data_v0 = _mm512_add_epi32(
+ accum_data_v0,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
+ accum_data_v1 = _mm512_add_epi32(
+ accum_data_v1,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
+ accum_data_v2 = _mm512_add_epi32(
+ accum_data_v2,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
+ accum_data_v3 = _mm512_add_epi32(
+ accum_data_v3,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
+ accum_data_v4 = _mm512_add_epi32(
+ accum_data_v4,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
+ accum_data_v5 = _mm512_add_epi32(
+ accum_data_v5,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
+ accum_data_v6 = _mm512_add_epi32(
+ accum_data_v6,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
+ accum_data_v7 = _mm512_add_epi32(
+ accum_data_v7,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
+ accum_data_v8 = _mm512_add_epi32(
+ accum_data_v8,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
+ accum_data_v9 = _mm512_add_epi32(
+ accum_data_v9,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
+ accum_data_va = _mm512_add_epi32(
+ accum_data_va,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
+ accum_data_vb = _mm512_add_epi32(
+ accum_data_vb,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
+ accum_data_vc = _mm512_add_epi32(
+ accum_data_vc,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
+ accum_data_vd = _mm512_add_epi32(
+ accum_data_vd,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
+ accum_data_ve = _mm512_add_epi32(
+ accum_data_ve,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
+ accum_data_vf = _mm512_add_epi32(
+ accum_data_vf,
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += 4) {
+ const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
+ __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ std::int32_t rhs_data[32];
+ const __m256i rhs_data_bottom_lane =
+ _mm512_castsi512_si256(rhs_data_8bit);
+ const __m256i rhs_data_top_lane =
+ _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
+ const __m512i rhs_16_bit_dup_low =
+ _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data),
+                            rhs_16_bit_dup_low);
+        _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data + 16),
+                            rhs_16_bit_dup_high);
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_low =
+ _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+ // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+ _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
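+        // _mm512_cvtepi32_epi16 truncates each 32-bit group to its low two
+        // bytes and _mm512_cvtepi8_epi16 then sign-extends each of those
+        // bytes, which selects bytes {0,1, 4,5, ...} (or, after the 16-bit
+        // logical shift, bytes {2,3, 6,7, ...}) without the byte shuffle
+        // used in the AVX2 kernel.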
+
+ auto process_column = [=](int col, __m512i& accum) {
+ const __m512i rhs_16_bit_dup_low =
+ _mm512_set1_epi32(rhs_data[2 * col]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[2 * col + 1]);
+
+ accum = _mm512_add_epi32(
+ accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum = _mm512_add_epi32(
+ accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ };
+ process_column(0, accum_data_v0);
+ process_column(1, accum_data_v1);
+ process_column(2, accum_data_v2);
+ process_column(3, accum_data_v3);
+ process_column(4, accum_data_v4);
+ process_column(5, accum_data_v5);
+ process_column(6, accum_data_v6);
+ process_column(7, accum_data_v7);
+ process_column(8, accum_data_v8);
+ process_column(9, accum_data_v9);
+ process_column(10, accum_data_va);
+ process_column(11, accum_data_vb);
+ process_column(12, accum_data_vc);
+ process_column(13, accum_data_vd);
+ process_column(14, accum_data_ve);
+ process_column(15, accum_data_vf);
+
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 4;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ // The non-per-channel case could equivalently be handled in the per-row
+ // or per-column code path. The per-row code path is slightly more
+ // efficient so we handle it there.
+ const bool per_column_multiplier =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+
+ __m512i m_vector;
+ __m512i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
+ params.multiplier_fixedpoint + multiplier_channel));
+ e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
+ params.multiplier_exponent + multiplier_channel));
+
+ const __m512i m_64bit_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+ const __m512i m_64bit_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+ const __m512i zero_vector = _mm512_setzero_epi32();
+ const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+ const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+ const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+ const __m512i final_right_shift = _mm512_set1_epi32(31);
+ const __m512i right_shift_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+ const __m512i right_shift_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
+ const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 0));
+ const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m512i offset_vector =
+ _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+
+ auto rounding_right_shift = [=](__m512i& results,
+ const __m512i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m512i zeros = _mm512_setzero_si512();
+ const __m512i mask_rightshift_gtz =
+ intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+ __m512i nudge = intrin_utils::mm512_blendv_epi64(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+ const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+ const __m512i mask_num_plus_nudge_overflow =
+ intrin_utils::mm512_cmpgt_epi64(
+ results,
+ _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm512_blendv_epi64(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
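+        // Note that, unlike the AVX2 kernel, this rounding shift operates on
+        // 64-bit lanes: it is applied to scaled_v_low/high before they are
+        // narrowed back to 32 bits with _mm512_cvtepi64_epi32.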
+
+ if (per_column_multiplier) {
+ auto apply_multiplier = [=](__m512i& accum, int col) {
+ __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
+ // Apply the fixed-point part of the multiplier.
+ __m512i left_shift_val =
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
+ __m512i m_64bit_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
+ __m512i offset_vector_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals, offset_vector);
+ __m512i final_right_shift_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals,
+ col < 8 ? final_right_shift_low : final_right_shift_high);
+ __m512i right_shift_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);
+
+ accum = _mm512_sllv_epi32(accum, left_shift_val);
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+ m_64bit_val);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+ m_64bit_val);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);
+
+ scaled_v_low =
+ _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
+
+ rounding_right_shift(scaled_v_low, right_shift_val);
+ rounding_right_shift(scaled_v_high, right_shift_val);
+
+ accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum = _mm512_inserti32x8(accum,
+ _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ };
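+        // In this per-column case each destination column gets its own
+        // multiplier: the _mm512_permutexvar_* calls broadcast element `col`
+        // (or element `col % 8` of the appropriate 64-bit half) of the
+        // loaded multiplier vectors to every lane before scaling.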
+ apply_multiplier(accum_data_v0, 0);
+ apply_multiplier(accum_data_v1, 1);
+ apply_multiplier(accum_data_v2, 2);
+ apply_multiplier(accum_data_v3, 3);
+ apply_multiplier(accum_data_v4, 4);
+ apply_multiplier(accum_data_v5, 5);
+ apply_multiplier(accum_data_v6, 6);
+ apply_multiplier(accum_data_v7, 7);
+ apply_multiplier(accum_data_v8, 8);
+ apply_multiplier(accum_data_v9, 9);
+ apply_multiplier(accum_data_va, 10);
+ apply_multiplier(accum_data_vb, 11);
+ apply_multiplier(accum_data_vc, 12);
+ apply_multiplier(accum_data_vd, 13);
+ apply_multiplier(accum_data_ve, 14);
+ apply_multiplier(accum_data_vf, 15);
+ } else { // not per-column, so per-row
+ auto apply_multiplier = [=](__m512i& accum) {
+ accum = _mm512_sllv_epi32(accum, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
+
+ scaled_v_low =
+ _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ rounding_right_shift(scaled_v_low, right_shift_low);
+ rounding_right_shift(scaled_v_high, right_shift_high);
+ accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum = _mm512_inserti32x8(accum,
+ _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ };
+ apply_multiplier(accum_data_v0);
+ apply_multiplier(accum_data_v1);
+ apply_multiplier(accum_data_v2);
+ apply_multiplier(accum_data_v3);
+ apply_multiplier(accum_data_v4);
+ apply_multiplier(accum_data_v5);
+ apply_multiplier(accum_data_v6);
+ apply_multiplier(accum_data_v7);
+ apply_multiplier(accum_data_v8);
+ apply_multiplier(accum_data_v9);
+ apply_multiplier(accum_data_va);
+ apply_multiplier(accum_data_vb);
+ apply_multiplier(accum_data_vc);
+ apply_multiplier(accum_data_vd);
+ apply_multiplier(accum_data_ve);
+ apply_multiplier(accum_data_vf);
+ }
+
+ if (params.dst_zero_point != 0) {
+ __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+ accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+ accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
+ accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
+ accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
+ accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
+ accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
+ accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
+ accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
+ accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
+ accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
+ accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
+ accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
+ accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
+ accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
+ accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
+ accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
+ }
+ }
+
+ const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+ const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+ const bool store_full_block =
+ (residual_rows == 16) && (residual_cols == 16);
+
+ __m512i accum_data_v[16];
+
+      // In most cases we would make this conditional on (!store_full_block)
+      // and unroll the clamp-and-store loop, but the benefit appears small.
+ {
+ accum_data_v[0] = accum_data_v0;
+ accum_data_v[1] = accum_data_v1;
+ accum_data_v[2] = accum_data_v2;
+ accum_data_v[3] = accum_data_v3;
+ accum_data_v[4] = accum_data_v4;
+ accum_data_v[5] = accum_data_v5;
+ accum_data_v[6] = accum_data_v6;
+ accum_data_v[7] = accum_data_v7;
+ accum_data_v[8] = accum_data_v8;
+ accum_data_v[9] = accum_data_v9;
+ accum_data_v[10] = accum_data_va;
+ accum_data_v[11] = accum_data_vb;
+ accum_data_v[12] = accum_data_vc;
+ accum_data_v[13] = accum_data_vd;
+ accum_data_v[14] = accum_data_ve;
+ accum_data_v[15] = accum_data_vf;
+ }
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+ for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
+ _mm512_cvtepi32_epi8(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi8(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
+ _mm512_cvtepi32_epi8(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi8(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+ for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
+ _mm512_cvtepi32_epi16(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi16(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < 16; ++j) {
+ _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
+ }
+ } else {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
+ accum_data_v[j]);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += 16 * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ 16 * params.dst_stride);
+ rhs_col_ptr += 16 * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
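+  // GEMV case: there is a single destination column, so each 16-row block
+  // needs only one accumulator vector and no column loop.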
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
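+  // Quantization bookkeeping: sum_k (lhs[k]-lhs_zp)*(rhs[k]-rhs_zp) is
+  // computed as sum_k lhs[k]*rhs[k], minus lhs_zp*sum_k rhs[k] (folded into
+  // rhs_sums_offsets below), minus rhs_zp*sum_k lhs[k] (subtracted via
+  // lhs_sums), plus depth*lhs_zp*rhs_zp (prod_zp_depth).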
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_si512(&params.rhs_sums[0]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+
+ __m512i accum_data_v0;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ __m512i initial_accum_data =
+ _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
+ bias_ptr += bias_ptr_block_increment;
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_si512(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth != 0) {
+ initial_accum_data = _mm512_add_epi32(initial_accum_data,
+ _mm512_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
+ _mm512_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += 4) {
+ const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
+ const __m128i rhs_data_8bit =
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+
+      // Each "int32" is two 16-bit RHS values, sign-extended from 8-bit.
+      // For simplicity we load 4x the data we need, convert twice the data
+      // we need, and store only the data we actually use.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
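+      // Concretely: if the four packed RHS bytes for this column are
+      // r0..r3, then rhs_data[0] now holds the 16-bit pair (r0, r1) and
+      // rhs_data[1] holds (r2, r3). Broadcasting each pair with
+      // _mm512_set1_epi32 and multiplying with _mm512_madd_epi16 against the
+      // matching pairs of LHS values accumulates two depth levels per madd.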
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_low =
+ _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+ // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+ _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+ // Process column 0.
+ __m512i accum_v = accum_data_v0;
+ constexpr int index = 0;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v0 = accum_v;
+
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 4;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m512i m_vector;
+ __m512i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
+ m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
+ params.multiplier_fixedpoint + channel));
+ e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
+ params.multiplier_exponent + channel));
+
+ const __m512i m_64bit_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+ const __m512i m_64bit_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+ const __m512i zero_vector = _mm512_setzero_epi32();
+ const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+ const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+ const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+ const __m512i final_right_shift = _mm512_set1_epi32(31);
+ const __m512i right_shift_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+ const __m512i right_shift_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
+ const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 0));
+ const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+      // A rounding "half": adding 1 << 30 before the arithmetic right shift
+      // by 31 rounds the 64-bit products to nearest rather than truncating.
+ const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+
+ auto rounding_right_shift = [=](__m512i& results,
+ const __m512i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m512i zeros = _mm512_setzero_si512();
+ const __m512i mask_rightshift_gtz =
+ intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const __m512i one_shift_exp_minus1 =
+ _mm512_sllv_epi64(_mm512_set1_epi64(1),
+ _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+ __m512i nudge = intrin_utils::mm512_blendv_epi64(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+ const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+ const __m512i mask_num_plus_nudge_overflow =
+ intrin_utils::mm512_cmpgt_epi64(
+ results,
+ _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm512_blendv_epi64(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
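+      // Example: with exponent == 3 the nudge is 1 << 2 == 4, so a lane value
+      // of 21 becomes (21 + 4) >> 3 == 3, i.e. 21/8 rounded to nearest. Lanes
+      // whose value exceeds 0x7fffffff - nudge would overflow the addition,
+      // so they are replaced by 1 << (31 - exponent) instead.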
+
+ // Shift and round column 0.
+ accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ rounding_right_shift(scaled_v_low, right_shift_low);
+ rounding_right_shift(scaled_v_high, right_shift_high);
+
+ accum_data_v0 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v0 = _mm512_inserti32x8(
+ accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+
+ if (params.dst_zero_point != 0) {
+ __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+ accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+ }
+ }
+
+ const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+ const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
+ _mm512_cvtepi32_epi16(result));
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += 16 * params.lhs_stride;
+ } // End row-block loop.
+} // NOLINT(readability/fn_size)
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 float");
+
+  // Strides in params are given in bytes; convert to float-element strides.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ const std::int64_t dst_stride = params.dst_stride >> 2;
+ const std::int64_t rhs_stride = params.rhs_stride >> 2;
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ const int end_row = std::min(params.dst_rows, params.last_row + 16);
+ const int end_col = std::min(params.dst_cols, params.last_col + 16);
+
+ const float* adj_rhs_col_ptr =
+ params.rhs_base_ptr - params.start_col * rhs_stride;
+ float* adj_dst_col_ptr =
+ params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_ptr = params.bias;
+
+ const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+ const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+ const bool channel_dimension_is_col =
+ params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+
+ int col = params.start_col;
+ for (; col <= end_col - 16; col += 16) {
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ int row = params.start_row;
+ for (; row <= end_row - 16; row += 16) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ // Process block in two halves, split by columns.
+ {
+ constexpr int mmm = 0;
+
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
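+        // With a column channel dimension, each of the 8 destination columns
+        // gets its own scalar bias broadcast across the 16 rows; otherwise a
+        // vector of 16 per-row biases is loaded once and shared by all
+        // columns.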
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          // In this version, RHS values are loaded individually rather than
+          // loaded together and then extracted with broadcasts. This is
+          // because the combination of AVX flavors, intrinsics, and compilers
+          // does not handle that extraction pattern very well.
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+
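+            // _mm512_shuffle_f32x4 with immediates 0 and 0x55 replicates the
+            // first and second 128-bit lanes across the whole register; the
+            // _mm512_permute_ps calls below then broadcast element j within
+            // each lane, yielding rhs_data[j] in all 16 lanes for the FMA.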
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+ }
+ }
+ } // Inner half-block loop, unrolled, first iteration.
+ {
+ constexpr int mmm = 1;
+
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+ }
+ }
+ } // Inner half-block loop, unrolled, second iteration.
+ } // End row-block loop.
+
+    // The unrolling within this conditional may be of limited benefit; how
+    // much it helps depends on the shapes used by typical models.
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+
+ // Process block in two halves, split by columns.
+ for (int mmm = 0; mmm < 2; ++mmm) {
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ // Load 8 float32 values.
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
+ __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
+ __m512 rhs4_7 =
+ _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
+ const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
+ accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
+ accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
+ accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
+ accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
+ accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
+ accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
+ accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
+ accum_data_v7);
+ }
+ }
+ } // Inner half-block loop.
+ } // Residual rows, main col-block loop.
+ } // End col-block loop.
+
+ if (col < end_col) {
+ RUY_DCHECK_GE(end_col - col, 0);
+ RUY_DCHECK_LT(end_col - col, 16);
+
+ __m512 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ for (int row = params.start_row; row < end_row; row += 16) {
+ const int residual_rows = std::min(end_row - row, 16);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+
+ // Process block in two halves, split by columns.
+ for (int mmm = 0; mmm < 2; ++mmm) {
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
+ }
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
+
+ if (residual_rows == 16) {
+ if (residual_cols == 8) {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
+ }
+ }
+ } // Inner half-block loop.
+ } // End row-block loop.
+ } // Residual cols.
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 float GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+  // The stride in params is given in bytes; convert to a float-element
+  // stride.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ const int end_row = std::min(params.dst_rows, params.last_row + 16);
+
+ float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+ const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+ __m512 accum_data_v;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = adj_dst_col_ptr;
+
+ int row = params.start_row;
+ for (; row <= end_row - 16; row += 16) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm512_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
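+    // The packed RHS stores 16 column slots per depth level; in this
+    // single-column case only the first slot is read, hence the stride of 16
+    // on rhs_ptr below.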
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float rhs_data = *rhs_ptr;
+
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
+ accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+ _mm512_storeu_ps(dst_ptr, accum_data_v);
+ } // End row-block loop.
+
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+ RUY_CHECK_GE(residual_rows, 1);
+ RUY_CHECK_LT(residual_rows, 16);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ accum_data_v = _mm512_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float rhs_data = *rhs_ptr;
+
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
+ accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+ _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
+ } // End handling of residual rows.
+}
+
+#endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
new file mode 100644
index 0000000..9509b8f
--- /dev/null
+++ b/ruy/kernel_common.h
@@ -0,0 +1,287 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_KERNEL_COMMON_H_
+#define RUY_RUY_KERNEL_COMMON_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/apply_multiplier.h"
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+struct Kernel;
+
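+// RUY_INHERIT_KERNEL(PARENT, CHILD) declares that, for any scalar-type
+// combination not explicitly specialized for CHILD, the CHILD path falls back
+// to the PARENT path's kernel.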
+#define RUY_INHERIT_KERNEL(PARENT, CHILD) \
+ template <typename LhsScalar, typename RhsScalar, typename DstScalar, \
+ typename AccumScalar> \
+ struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar> \
+ : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \
+ explicit Kernel(Tuning tuning) \
+ : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \
+ tuning) {} \
+ };
+
+// KernelParams are shared across 32-bit and 64-bit NEON code as well as x86
+// code.
+//
+// On other platforms we still define (empty) versions, so that dummy kernels
+// can use the classes in function signatures.
+#if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
+ RUY_PLATFORM_X86
+
+#define RUY_ASM_FLAG_HAS_BIAS 0x1
+#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
+#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
+#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
+#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
+#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20
+
+#define RUY_ASM_TYPE_ID_UINT8 1
+#define RUY_ASM_TYPE_ID_INT8 2
+#define RUY_ASM_TYPE_ID_INT16 3
+#define RUY_ASM_TYPE_ID_INT32 4
+
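+// Compile-time mapping from a destination scalar type to the runtime type id
+// stored in KernelParams8bit::dst_type_id, which the 8-bit kernels branch on
+// when storing results.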
+template <typename DstScalar>
+struct DstTypeId {};
+
+template <>
+struct DstTypeId<std::uint8_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
+};
+
+template <>
+struct DstTypeId<std::int8_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
+};
+
+template <>
+struct DstTypeId<std::int16_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
+};
+
+template <>
+struct DstTypeId<std::int32_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
+};
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {
+ static constexpr int kMaxDstTypeSize = 4;
+
+ const std::int32_t* bias;
+ const std::int32_t* lhs_sums;
+ const std::int32_t* rhs_sums;
+ const std::int8_t* lhs_base_ptr;
+ const std::int32_t* multiplier_fixedpoint;
+ const std::int32_t* multiplier_exponent;
+ const std::int8_t* rhs_base_ptr;
+ void* dst_base_ptr;
+ std::int32_t lhs_zero_point;
+ std::int32_t rhs_zero_point;
+ std::int32_t dst_zero_point;
+ std::int32_t prod_zp_depth;
+ std::int32_t start_row;
+ std::int32_t start_col;
+ std::int32_t last_row;
+ std::int32_t last_col;
+ std::int32_t dst_rows;
+ std::int32_t dst_cols;
+ std::int32_t lhs_stride;
+ std::int32_t rhs_stride;
+ std::int32_t dst_stride;
+ std::int32_t depth;
+ std::int32_t clamp_min;
+ std::int32_t clamp_max;
+ std::uint8_t flags;
+ std::uint8_t dst_type_id;
+ const std::int32_t zero_data[LhsCols] = {0};
+ std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
+ std::int32_t multiplier_fixedpoint_buf[LhsCols];
+ std::int32_t multiplier_exponent_buf[LhsCols];
+};
+
+template <typename DstScalar, int LhsCols, int RhsCols>
+void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
+ const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params,
+ int start_row, int start_col, int end_row,
+ int end_col, Mat<DstScalar>* dst,
+ KernelParams8bit<LhsCols, RhsCols>* params) {
+ using Params = KernelParams8bit<LhsCols, RhsCols>;
+
+ static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
+
+ const int depth = lhs.layout.rows;
+ RUY_DCHECK_EQ(start_row % LhsCols, 0);
+ RUY_DCHECK_EQ(start_col % RhsCols, 0);
+ RUY_DCHECK_EQ(end_row % LhsCols, 0);
+ RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+ params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+ params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+ params->flags = 0;
+ params->bias = params->zero_data;
+ if (mul_params.bias()) {
+ params->bias = mul_params.bias();
+ params->flags |= RUY_ASM_FLAG_HAS_BIAS;
+ }
+ if (lhs.sums) {
+ params->lhs_sums = lhs.sums;
+ params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
+ }
+ if (rhs.sums) {
+ params->rhs_sums = rhs.sums;
+ params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
+ }
+ if (mul_params.channel_dimension() == ChannelDimension::kCol) {
+ params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+ }
+ params->start_row = start_row;
+ params->start_col = start_col;
+ params->last_row = end_row - LhsCols;
+ params->last_col = end_col - RhsCols;
+ params->lhs_stride = lhs.layout.stride;
+ params->rhs_stride = rhs.layout.stride;
+ params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
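+  // lhs_stride and rhs_stride are element strides (equal to byte strides for
+  // 8-bit data); dst_stride is converted to bytes since the destination type
+  // varies at runtime.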
+ params->lhs_zero_point = lhs.zero_point;
+ params->rhs_zero_point = rhs.zero_point;
+ params->dst_zero_point = dst->zero_point;
+ params->depth = depth;
+ params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
+ params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+ if (mul_params.multiplier_fixedpoint_perchannel()) {
+ params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
+ params->multiplier_fixedpoint =
+ mul_params.multiplier_fixedpoint_perchannel();
+ params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
+ } else {
+ params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
+ params->multiplier_exponent = params->multiplier_exponent_buf;
+ for (int i = 0; i < LhsCols; i++) {
+ params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
+ params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
+ }
+ }
+ params->clamp_min = mul_params.clamp_min();
+ params->clamp_max = mul_params.clamp_max();
+ params->dst_rows = dst->layout.rows;
+ params->dst_cols = dst->layout.cols;
+
+ RUY_DCHECK_LT(params->last_row, params->dst_rows);
+ RUY_DCHECK_LT(params->last_col, params->dst_cols);
+
+ params->dst_type_id = DstTypeId<DstScalar>::kValue;
+ params->dst_base_ptr =
+ dst->data.get() + start_col * dst->layout.stride + start_row;
+}
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {
+ const float* lhs_base_ptr;
+ const float* rhs_base_ptr;
+ float* dst_base_ptr;
+ const float* bias;
+ std::int32_t start_row;
+ std::int32_t start_col;
+ std::int32_t last_row;
+ std::int32_t last_col;
+ std::int32_t dst_rows;
+ std::int32_t dst_cols;
+ std::int32_t lhs_stride;
+ std::int32_t rhs_stride;
+ std::int32_t dst_stride;
+ std::int32_t depth;
+ float clamp_min;
+ float clamp_max;
+ std::uint8_t flags;
+ const float zero_data[LhsCols] = {0};
+ float dst_tmp_buf[LhsCols * RhsCols];
+};
+
+template <int LhsCols, int RhsCols>
+inline void MakeKernelParamsFloat(const PMat<float>& lhs,
+ const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params,
+ int start_row, int start_col, int end_row,
+ int end_col, Mat<float>* dst,
+ KernelParamsFloat<LhsCols, RhsCols>* params) {
+ const int depth = lhs.layout.rows;
+ RUY_DCHECK_EQ(start_row % LhsCols, 0);
+ RUY_DCHECK_EQ(start_col % RhsCols, 0);
+ RUY_DCHECK_EQ(end_row % LhsCols, 0);
+ RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+ params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+ params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+ params->dst_base_ptr =
+ dst->data.get() + start_col * dst->layout.stride + start_row;
+
+ std::uint8_t flags = 0;
+ params->bias = params->zero_data;
+ if (mul_params.bias()) {
+ params->bias = mul_params.bias();
+ flags |= RUY_ASM_FLAG_HAS_BIAS;
+ }
+ if (mul_params.channel_dimension() == ChannelDimension::kCol) {
+ flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+ }
+ params->flags = flags;
+ params->start_row = start_row;
+ params->start_col = start_col;
+ params->last_row = end_row - LhsCols;
+ params->last_col = end_col - RhsCols;
+ params->lhs_stride = sizeof(float) * lhs.layout.stride;
+ params->rhs_stride = sizeof(float) * rhs.layout.stride;
+ params->dst_stride = sizeof(float) * dst->layout.stride;
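+  // All three strides are stored in bytes; the float kernels shift them right
+  // by 2 to recover float-element strides.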
+ params->depth = depth;
+ params->clamp_min = mul_params.clamp_min();
+ params->clamp_max = mul_params.clamp_max();
+ params->dst_rows = dst->layout.rows;
+ params->dst_cols = dst->layout.cols;
+
+ RUY_DCHECK_LT(params->last_row, params->dst_rows);
+ RUY_DCHECK_LT(params->last_col, params->dst_cols);
+}
+
+#else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
+ // RUY_OPT(ASM)) || RUY_PLATFORM_X86
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {};
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {};
+
+#endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
+ // RUY_OPT(ASM)) || RUY_PLATFORM_X86
+
+} // namespace ruy
+
+#endif // RUY_RUY_KERNEL_COMMON_H_
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
new file mode 100644
index 0000000..2f8fe19
--- /dev/null
+++ b/ruy/kernel_x86.h
@@ -0,0 +1,874 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_KERNEL_X86_H_
+#define RUY_RUY_KERNEL_X86_H_
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/kernel_common.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
+RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
+
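+// 8-bit AVX-512 kernel: 16x16 destination blocks, with packed LHS/RHS laid
+// out in column-major 4x16 cells so the inner loop can consume 4 levels of
+// depth per step.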
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+ static constexpr Path kPath = Path::kAvx512;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitAvx512SingleCol(params);
+ } else {
+ Kernel8bitAvx512(params);
+ }
+ }
+};
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);
+
+template <>
+struct Kernel<Path::kAvx512, float, float, float, float> {
+ static constexpr Path kPath = Path::kAvx512;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ KernelFloatAvx512SingleCol(params);
+ } else {
+ KernelFloatAvx512(params);
+ }
+ }
+};
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
+ DstScalar> {
+ static constexpr Path kPath = Path::kAvx2Fma;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitAvx2SingleCol(params);
+ } else {
+ Kernel8bitAvx2(params);
+ }
+ }
+};
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kAvx2Fma, float, float, float, float> {
+ static constexpr Path kPath = Path::kAvx2Fma;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ KernelFloatAvx2SingleCol(params);
+ } else {
+ KernelFloatAvx2(params);
+ }
+ }
+};
+
+void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
+void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kAvx, float, float, float, float> {
+ static constexpr Path kPath = Path::kAvx;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
+ const MulParams<float, float>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ KernelFloatAvxSingleCol(params);
+ } else {
+ KernelFloatAvx(params);
+ }
+ }
+};
+
+void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
+void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
+  static constexpr Path kPath = Path::kAvx;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitAvxSingleCol(params);
+ } else {
+ Kernel8bitAvx(params);
+ }
+ }
+};
+
+#endif // RUY_PLATFORM_X86
+} // namespace ruy
+
+#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
+
+#include <immintrin.h> // IWYU pragma: keep
+
+namespace ruy {
+namespace {
+namespace intrin_utils {
+
+// Defined as a template so clang won't detect it as an unneeded
+// definition.
+template <Path path>
+inline float mm256_get1_ps(const __m256 a, int i) {
+ __m256i ai = _mm256_castps_si256(a);
+ int float_val_as_int;
+ switch (i) {
+ case 0:
+ float_val_as_int = _mm256_extract_epi32(ai, 0);
+ break;
+ case 1:
+ float_val_as_int = _mm256_extract_epi32(ai, 1);
+ break;
+ case 2:
+ float_val_as_int = _mm256_extract_epi32(ai, 2);
+ break;
+ case 3:
+ float_val_as_int = _mm256_extract_epi32(ai, 3);
+ break;
+ case 4:
+ float_val_as_int = _mm256_extract_epi32(ai, 4);
+ break;
+ case 5:
+ float_val_as_int = _mm256_extract_epi32(ai, 5);
+ break;
+ case 6:
+ float_val_as_int = _mm256_extract_epi32(ai, 6);
+ break;
+ case 7:
+ float_val_as_int = _mm256_extract_epi32(ai, 7);
+ break;
+ default:
+ RUY_DCHECK_LT(i, 8);
+ return .0f;
+ }
+ float float_val;
+ std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
+ return float_val;
+}
+
+// Defined as a template so clang won't detect it as an unneeded
+// definition.
+template <Path path>
+inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
+ }
+}
+
+template <Path path>
+inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
+ // Specializations added for AVX and AVX2FMA paths in their respective kernel
+ // files.
+ RUY_DCHECK(false);
+ return _mm256_set1_ps(0);
+}
+
+template <Path path>
+inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
+ // Specializations added for AVX and AVX2FMA paths in their respective kernel
+ // files.
+ RUY_DCHECK(false);
+ return _mm256_set1_epi32(0);
+}
+
+// Polyfill for _mm_storeu_si16(dst, v).
+template <Path path>
+inline void mm_storeu_si16(void* dst, __m128i v) {
+#if (defined __clang__) || (defined _MSC_VER)
+ _mm_storeu_si16(dst, v);
+#else
+ // GCC 9 lacks support for _mm_storeu_si16.
+ *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0);
+#endif
+}
+
+// Polyfill for _mm_storeu_si32(dst, v).
+template <Path path>
+inline void mm_storeu_si32(void* dst, __m128i v) {
+#if (defined __clang__) || (defined _MSC_VER)
+ _mm_storeu_si32(dst, v);
+#else
+ // GCC 9 lacks support for _mm_storeu_si32.
+ *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0);
+#endif
+}
+
+// Polyfill for _mm_loadu_si32(src).
+template <Path path>
+inline __m128i mm_loadu_si32(const void* src) {
+#if (defined __clang__) || (defined _MSC_VER)
+ return _mm_loadu_si32(src);
+#else
+ // GCC 9 lacks support for _mm_loadu_si32.
+ __m128i res;
+ asm("movss %[src], %[res]"
+ : [res] "=x"(res)
+ : [src] "m"(*static_cast<const int*>(src)));
+ return res;
+#endif
+}
+
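+// Illustrative usage of the three polyfills above (a sketch, not code that
+// ruy itself contains in this form): they provide narrow unaligned loads and
+// stores regardless of whether the compiler exposes the intrinsics, e.g.
+//
+//   std::int32_t tmp = 0;
+//   mm_storeu_si32<Path::kAvx>(&tmp, _mm_set1_epi32(42));  // tmp == 42
+//   __m128i reloaded = mm_loadu_si32<Path::kAvx>(&tmp);    // low lane == 42
+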
+template <Path path>
+inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
+ RUY_DCHECK(false);
+ return _mm_setzero_si128();
+}
+
+template <Path path>
+inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
+ const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ __m256i shuffled_v;
+ if (residual_rows > 1) {
+ // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
+ // in each 128-bit lane.
+ shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm);
+ }
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ dst[0] = _mm256_extract_epi8(v, 0);
+ break;
+ case 2:
+ mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ break;
+ case 3: {
+ __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0);
+ mm_storeu_si16<path>(dst, trailing_packed);
+ dst[2] = _mm_extract_epi8(trailing_packed, 2);
+ break;
+ }
+ case 4:
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ break;
+ case 5:
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ dst[4] = _mm256_extract_epi8(shuffled_v, 16);
+ break;
+ case 6:
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ mm_storeu_si16<path>(dst + 4,
+ mm256_extracti128_si256<path>(shuffled_v, 1));
+ break;
+ case 7: {
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
+ mm_storeu_si16<path>(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi8(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ mm_storeu_si32<path>(dst + 4,
+ mm256_extracti128_si256<path>(shuffled_v, 1));
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+template <Path path>
+inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
+}
+
+template <Path path>
+inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
+ const __m256i v) {
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
+ reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
+}
+
+template <Path path>
+inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
+ mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
+}
+
+template <Path path>
+inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
+ const __m256i v) {
+ // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
+ // truncating each 32-bit integer to 16 bits.
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
+ __m256i shuffled_v;
+ __m128i shuffled_v_low;
+ if (residual_rows > 1) {
+ shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
+ shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0);
+ } else {
+ shuffled_v_low = mm256_extracti128_si256<path>(v, 0);
+ }
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ mm_storeu_si16<path>(dst, shuffled_v_low);
+ break;
+ case 2:
+ mm_storeu_si32<path>(dst, shuffled_v_low);
+ break;
+ case 3: {
+ mm_storeu_si32<path>(dst, shuffled_v_low);
+ dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
+ break;
+ }
+ case 4:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ break;
+ case 5:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ dst[4] = _mm256_extract_epi16(shuffled_v, 8);
+ break;
+ case 6:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ mm_storeu_si32<path>(dst + 4,
+ mm256_extracti128_si256<path>(shuffled_v, 1));
+ break;
+ case 7: {
+ _mm_storeu_si64(dst, shuffled_v_low);
+ __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
+ mm_storeu_si32<path>(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi16(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+template <Path path>
+inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
+ // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
+ // truncating each 32-bit integer to 16 bits.
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
+ const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
+ _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
+ _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
+}
+
+template <Path path>
+inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
+ const __m256i v) {
+ const __m128i v_low = mm256_extracti128_si256<path>(v, 0);
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ mm_storeu_si32<path>(dst, v_low);
+ break;
+ case 2:
+ _mm_storeu_si64(dst, v_low);
+ break;
+ case 3: {
+ __m128i trailing_packed = v_low;
+ _mm_storeu_si64(dst, trailing_packed);
+ dst[2] = _mm_extract_epi32(trailing_packed, 2);
+ break;
+ }
+ case 4:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ break;
+ case 5:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ dst[4] = _mm256_extract_epi32(v, 4);
+ break;
+ case 6:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1));
+ break;
+ case 7: {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1);
+ _mm_storeu_si64(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi32(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+template <Path path>
+inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+}
+
+// Transpose an 8x8 matrix of floats.
+template <Path path>
+void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
+ __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
+ __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1);
+ __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1);
+ __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3);
+ __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3);
+ __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5);
+ __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5);
+ __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7);
+ __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7);
+ __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0));
+ __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2));
+ __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0));
+ __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2));
+ __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0));
+ __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2));
+ __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0));
+ __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2));
+ *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20);
+ *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20);
+ *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20);
+ *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20);
+ *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31);
+ *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31);
+ *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31);
+ *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31);
+}
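+
+// For illustration: on return, the j-th lane of *v_i holds what was the i-th
+// lane of *v_j on entry, i.e. the 8x8 tile is transposed in place.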
+
+// Transpose an 8x8 matrix of int32s.
+template <Path path>
+void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
+ __m256i* v3, __m256i* v4, __m256i* v5,
+ __m256i* v6, __m256i* v7) {
+ mm256_transpose8x8_ps<path>(
+ reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1),
+ reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3),
+ reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
+ reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
+}
+
+} // namespace intrin_utils
+} // namespace
+
+template <Path path>
+inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
+ // Strides in params are in bytes; shift right by 2 to get float-element strides.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ const std::int64_t dst_stride = params.dst_stride >> 2;
+ const std::int64_t rhs_stride = params.rhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX/AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+ const int end_col = std::min(params.dst_cols, params.last_col + 8);
+ //
+ const float* adj_rhs_col_ptr =
+ params.rhs_base_ptr - params.start_col * rhs_stride;
+ float* adj_dst_col_ptr =
+ params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+ const bool channel_dimension_is_col =
+ params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+
+ int col = params.start_col;
+ // Loop through cols by the float block size; the incomplete remainder is
+ // handled after this loop.
+ for (; col <= end_col - 8; col += 8) {
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ // Load 8 RHS values, then use permute instructions to
+ // broadcast each value to a register.
+ __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
+ __m256 rhs0_3 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ __m256 rhs4_7 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+
+ const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
+ accum_data_v[0] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_0, accum_data_v[0]);
+
+ const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
+ accum_data_v[1] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_1, accum_data_v[1]);
+
+ const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
+ accum_data_v[2] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_2, accum_data_v[2]);
+
+ const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
+ accum_data_v[3] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_3, accum_data_v[3]);
+
+ const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
+ accum_data_v[4] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_4, accum_data_v[4]);
+
+ const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
+ accum_data_v[5] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_5, accum_data_v[5]);
+
+ const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
+ accum_data_v[6] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_6, accum_data_v[6]);
+
+ const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
+ accum_data_v[7] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_7, accum_data_v[7]);
+
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ if (residual_rows == 8) {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ } else {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ }
+ } // End row-block loop.
+ } // End col-block loop.
+
+ if (col < end_col) {
+ // Remaining cols in [0, float block size).
+ RUY_DCHECK_GE(end_col - col, 0);
+ RUY_DCHECK_LT(end_col - col, 8);
+
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+ const int residual_cols = std::min(end_col - col, 8);
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
+ __m256 rhs0_3 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ __m256 rhs4_7 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+
+ const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
+ accum_data_v[0] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_0, accum_data_v[0]);
+
+ const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
+ accum_data_v[1] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_1, accum_data_v[1]);
+
+ const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
+ accum_data_v[2] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_2, accum_data_v[2]);
+
+ const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
+ accum_data_v[3] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_3, accum_data_v[3]);
+
+ const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
+ accum_data_v[4] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_4, accum_data_v[4]);
+
+ const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
+ accum_data_v[5] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_5, accum_data_v[5]);
+
+ const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
+ accum_data_v[6] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_6, accum_data_v[6]);
+
+ const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
+ accum_data_v[7] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_7, accum_data_v[7]);
+
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ } // End row-block loop.
+ } // End col-block terminal conditional.
+}
+
+template <Path path>
+inline void KernelFloatAvxCommonSingleCol(
+ const KernelParamsFloat<8, 8>& params) {
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ // Strides in params are in bytes; shift right by 2 to get float-element strides.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX/AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+
+ float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+ __m256 accum_data_v;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = adj_dst_col_ptr;
+
+ int row = params.start_row;
+ for (; row <= end_row - 8; row += 8) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ int d = 0;
+ for (; d <= params.depth - 4; d += 4) {
+ const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
+ const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
+ accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0,
+ accum_data_v);
+ const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
+ const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
+ accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1,
+ accum_data_v);
+
+ const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
+ const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
+ accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2,
+ accum_data_v);
+ const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
+ const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
+ accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3,
+ accum_data_v);
+ lhs_ptr += 32; // Loaded 8 * 4 floats.
+ rhs_ptr += 32;
+ }
+ for (; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v =
+ intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ _mm256_storeu_ps(dst_ptr, accum_data_v);
+ } // End row-block loop.
+
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+ RUY_CHECK_GE(residual_rows, 1);
+ RUY_CHECK_LT(residual_rows, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v =
+ intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
+ } // End handling of residual rows.
+}
+} // namespace ruy
+#endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)
+
+#endif // RUY_RUY_KERNEL_X86_H_
diff --git a/ruy/mat.h b/ruy/mat.h
new file mode 100644
index 0000000..587b208
--- /dev/null
+++ b/ruy/mat.h
@@ -0,0 +1,492 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Internal types and helpers for matrices.
+// "Mat" is the name we use to refer to our internal matrix classes; it can be
+// thought of as a shorter version of "InternalMatrix".
+//
+// Ruy has four internal matrix classes, besides the
+// Matrix<T> class that we expose to the user-facing API.
+//
+// TODO(silvasean): Put parts of this architecture description somewhere more
+// prominent.
+//
+// The 4 internal matrix classes are named Mat, EMat, PMat, PEMat, where:
+// - "E" indicates a type-erased class, storing a void* pointer and a runtime
+// enum value to track the scalar type, as opposed to being templatized
+// on a Scalar type and storing a Scalar* pointer.
+// - "P" indicates a packed matrix class, the output of the packing code and
+// input of the kernel code. See comments in pack.h regarding packing.
+//
+// In other words:
+//
+//                    Plain matrices   Packed matrices
+//                 +------------------------------------
+//   Templated     |  Mat, Matrix      PMat
+//   Type-erased   |  EMat             PEMat
+//
+// Note that Matrix<T> is *not* implemented in terms of the internal types. It
+// is an independent, simple, and user-facing type. Matrix<T> is functionally
+// equivalent to Mat, but we keep it separate to insulate internals from
+// interface and to be able to make different compromises in internals
+// vs interface: in internals we prefer Mat to be a C-style struct with
+// raw data member access and to be similar to the other PMat/EMat/PEMat
+// classes for consistency.
+//
+// The use of type-erasure might seem surprising for a library like Ruy with a
+// heavily-templated entry point, but it is motivated by the desire for most of
+// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3
+// main parts:
+// - "entry-point" (ruy.h) - this is the highly templated ruy::Mul entry
+// point.
+// - "front-end" (frontend.*, validate.*, create_trmul_params.*,
+// prepare_packed_matrices.*) - the work to handle the entry-point call down to
+// the point where it can be handed off to the middle/back ends below. That
+// includes routines that select RunKernel and RunPack
+// implementations statically based on the entry-point's template parameters.
+// - "back-end" (kernel_*.*, pack_*.*) - this consists of the implementations of
+// RunKernel and RunPack, often in assembly code, which are the building blocks
+// that Ruy calls to perform matrix multiplication. These are templated so that
+// only the requested types/Path's are actually emitted by the compiler.
+// - "middle-end" (trmul.*) - this is the part of Ruy that orchestrates the
+// calls to the "back-end" optimized building blocks. This layer has to deal
+// with issues like cache locality and low-overhead multi-threading.
+//
+// There is a desire for the "middle-end" to be non-templated in order to
+// simplify the implementation and reduce code-size. We type-erase when going
+// from the "front-end" to the "middle-end", and un-type-erase going from the
+// "middle-end" to the "back-end". The un-type-erasure is possible because the
+// "front-end" is responsible for instantiating the needed "back-end" templates,
+// and thus the static type information is still present.
+//
+// Each layer of Ruy uses matrix types:
+// - "entry-point": Matrix<T>
+// - "front-end": Mat
+// - "middle-end": EMat, PEMat
+// - "back-end": Mat, PMat
+//
+// The use of separate types for packed matrices is not essential, but makes it
+// obvious at a glance whether a matrix is a packed matrix or not. We would
+// reconsider this decision if there was significant duplication between packed
+// and unpacked matrices, but that doesn't seem to be the case at the moment.
+//
+// Another goal is to keep the user-facing Matrix<T> as simple and
+// understandable as possible. Ideally, a user should be able to read the struct
+// definition for Matrix<T> and see a very simple definition with no internal
+// details like sums and kernel block layout.
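+//
+// As an illustration (editorial sketch, not an actual ruy call chain), the
+// type-erasure round trip between these layers looks like:
+//
+//   Mat<float> mat = ...;                         // "front-end": typed
+//   EMat emat = EraseType(mat);                   // "middle-end": type-erased
+//   Mat<float> again = UneraseType<float>(emat);  // "back-end": typed again
+//
+// where EraseType() and UneraseType() are defined later in this file.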
+
+#ifndef RUY_RUY_INTERNAL_MATRIX_H_
+#define RUY_RUY_INTERNAL_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+// Internal counterpart of Layout, used by Mat.
+struct MatLayout final {
+ std::int32_t rows = 0;
+ std::int32_t cols = 0;
+ // Stride is the offset between two adjacent matrix elements
+ // in the non-contiguous direction.
+ std::int32_t stride = 0;
+ Order order = Order::kColMajor;
+};
+
+inline MatLayout ToInternal(const Layout& src) {
+ MatLayout ret;
+ ret.rows = src.rows();
+ ret.cols = src.cols();
+ ret.stride = src.stride();
+ ret.order = src.order();
+ return ret;
+}
+
+// Internal counterpart of Matrix
+template <typename Scalar>
+struct Mat final {
+ detail::ConstCheckingPtr<Scalar> data;
+ MatLayout layout;
+ Scalar zero_point = 0;
+ CachePolicy cache_policy = CachePolicy::kNeverCache;
+};
+
+template <typename Scalar>
+inline Mat<Scalar> ToInternal(const Matrix<Scalar>& src) {
+ Mat<Scalar> ret;
+ ret.data.set(src.data());
+ ret.layout = ToInternal(src.layout());
+ ret.zero_point = src.zero_point();
+ ret.cache_policy = src.cache_policy();
+ return ret;
+}
+
+template <typename Scalar>
+inline Mat<Scalar> ToInternal(Matrix<Scalar>& src) {
+ Mat<Scalar> ret;
+ ret.data.set(src.data());
+ ret.layout = ToInternal(src.layout());
+ ret.zero_point = src.zero_point();
+ ret.cache_policy = src.cache_policy();
+ return ret;
+}
+
+// KernelLayout describes small-scale block structure in a packed matrix layout.
+// It's a runtime (as opposed to compile-time-constant) version of the
+// FixedKernelLayout struct used to declare kernel layouts.
+//
+// This is sometimes known as "tiling" in other contexts.
+//
+// For example, consider a packed matrix in column-major format with a
+// column-major KernelLayout. The matrix logically has a shape of
+// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array
+// of shape `[cols / kcols, rows / krows, kcols, krows]`.
+//
+// Note that in the case of kcols=1, krows=1, this degenerates to
+// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block
+// structure.
+struct KernelLayout final {
+ Order order = Order::kColMajor;
+ std::uint8_t rows = 1;
+ std::uint8_t cols = 1;
+};
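+
+// Worked example (for illustration only): with a column-major packed layout of
+// stride S and a column-major KernelLayout with rows=4, cols=2, the helper
+// Offset(const PMatLayout&, row, col) defined later in this file computes
+//   (row & ~3) * 2 + (col & ~1) * S   // offset of the enclosing 4x2 block
+//   + (row & 3) + (col & 1) * 4       // offset inside that block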
+
+// A packed matrix has a small-scale block structure that is not present in
+// the input matrices. This block structure is necessary for the kernels to
+// process data efficiently.
+//
+// This struct is very similar to MatLayout, but has the extra KernelLayout
+// field.
+struct PMatLayout final {
+ std::int32_t rows = 0;
+ std::int32_t cols = 0;
+ // Stride is the offset between two adjacent matrix elements
+ // in the non-contiguous direction.
+ std::int32_t stride = 0;
+ Order order = Order::kColMajor;
+ // Small scale layout shuffling, potentially departing from
+ // linear row-major or column-major storage. See KernelLayout.
+ KernelLayout kernel;
+};
+
+inline bool operator==(const PMatLayout& a, const PMatLayout& b) {
+ return a.cols == b.cols && a.rows == b.rows && a.stride == b.stride &&
+ a.order == b.order && a.kernel.rows == b.kernel.rows &&
+ a.kernel.cols == b.kernel.cols && a.kernel.order == b.kernel.order;
+}
+
+// Dynamic representation for a type.
+//
+// The most important field in this struct is the size, which Ruy uses to know
+// how much memory to allocate without having to be templated on a type.
+// Signed-ness and floating-point-ness are mainly present as debugging checks.
+//
+// Note: Ruy does not use this struct to dynamically dispatch between
+// different typed implementations. As described in the comment at the top of
+// this file, Ruy's "front-end", which is templated, instantiates all the
+// necessary "back-end" routines with complete static knowledge of all the
+// types.
+struct Type final {
+ template <typename T>
+ static Type Create() {
+ Type ret;
+ ret.is_signed = std::is_signed<T>::value;
+ ret.is_floating_point = std::is_floating_point<T>::value;
+ ret.size = sizeof(T);
+ return ret;
+ }
+
+ template <typename T>
+ void AssertIs() const {
+ RUY_DCHECK_EQ(is_signed, Create<T>().is_signed);
+ RUY_DCHECK_EQ(is_floating_point, Create<T>().is_floating_point);
+ RUY_DCHECK_EQ(size, Create<T>().size);
+ }
+
+ bool is_signed = false;
+ bool is_floating_point = false;
+ std::uint8_t size = 0;
+};
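+
+// For example (illustration): Type::Create<std::int8_t>() yields
+// {is_signed=true, is_floating_point=false, size=1}; calling
+// AssertIs<std::int8_t>() on that value passes, while AssertIs<float>() would
+// RUY_DCHECK-fail in debug builds.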
+
+inline bool operator==(const Type& type1, const Type& type2) {
+ return type1.is_signed == type2.is_signed &&
+ type1.is_floating_point == type2.is_floating_point &&
+ type1.size == type2.size;
+}
+
+// Type-erased matrix.
+struct EMat final {
+ Type data_type;
+ void* data = nullptr;
+ MatLayout layout;
+ std::int32_t zero_point = 0;
+ CachePolicy cache_policy = CachePolicy::kNeverCache;
+};
+
+// Type-erased packed matrix.
+struct PEMat final {
+ Type data_type;
+ void* data = nullptr;
+ Type sums_type;
+ void* sums = nullptr;
+ PMatLayout layout;
+ std::int32_t zero_point = 0;
+};
+
+// Convenient typed helper for packed matrices.
+template <typename Scalar>
+struct PMat final {
+ // The row/column sums needed for quantized matrix multiplication when
+ // the opposite operand of the multiplication uses a non-symmetric zero
+ // point.
+ // This member is only relevant for packed matrices.
+ // Additionally, Ruy always uses 32-bit signed accumulators for quantized
+ // matrix multiplication.
+ // For floating point types, there is no quantization, so this pointer
+ // will always be null. We still need code referencing it to compile
+ // though, even if it is always branched around. Hence we use Scalar*
+ // itself as the type in that case.
+ using SumsType =
+ typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
+ std::int32_t>::type;
+
+ Scalar* data = nullptr;
+ SumsType* sums = nullptr;
+ PMatLayout layout;
+ std::int32_t zero_point = 0;
+};
+
+template <typename T>
+EMat EraseType(const Mat<T>& matrix) {
+ EMat ret;
+ ret.data_type = Type::Create<T>();
+ ret.data = const_cast<void*>(static_cast<const void*>(matrix.data.get()));
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
+ ret.cache_policy = matrix.cache_policy;
+ return ret;
+}
+
+template <typename T>
+Mat<T> UneraseType(const EMat& matrix) {
+ matrix.data_type.AssertIs<T>();
+ Mat<T> ret;
+ ret.data.set(static_cast<T*>(matrix.data));
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
+ ret.cache_policy = matrix.cache_policy;
+ return ret;
+}
+
+template <typename T>
+PMat<T> UneraseType(const PEMat& matrix) {
+ using SumsType = typename PMat<T>::SumsType;
+ matrix.data_type.AssertIs<T>();
+ matrix.sums_type.AssertIs<SumsType>();
+ PMat<T> ret;
+ ret.data = static_cast<T*>(matrix.data);
+ ret.sums = static_cast<SumsType*>(matrix.sums);
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
+ return ret;
+}
+
+// Helpers for MatLayout / PMatLayout.
+
+inline bool IsUnstrided(const MatLayout& layout) {
+ if (layout.order == Order::kColMajor) {
+ return layout.stride == layout.rows;
+ } else {
+ return layout.stride == layout.cols;
+ }
+}
+
+inline bool IsRowMajor(const MatLayout& layout) {
+ return layout.order == Order::kRowMajor;
+}
+
+inline bool IsColMajor(const MatLayout& layout) {
+ return layout.order == Order::kColMajor;
+}
+
+inline int FlatSize(const MatLayout& layout) {
+ const int outerdim =
+ layout.order == Order::kColMajor ? layout.cols : layout.rows;
+ return layout.stride * outerdim;
+}
+
+inline bool IsUnstrided(const PMatLayout& layout) {
+ if (layout.order == Order::kColMajor) {
+ return layout.stride == layout.rows;
+ } else {
+ return layout.stride == layout.cols;
+ }
+}
+
+inline bool IsRowMajor(const PMatLayout& layout) {
+ return layout.order == Order::kRowMajor;
+}
+
+inline bool IsColMajor(const PMatLayout& layout) {
+ return layout.order == Order::kColMajor;
+}
+
+inline int FlatSize(const PMatLayout& layout) {
+ const int outerdim =
+ layout.order == Order::kColMajor ? layout.cols : layout.rows;
+ return layout.stride * outerdim;
+}
+
+// TODO(b/130417400) add a unit test
+inline int Offset(const MatLayout& layout, int row, int col) {
+ // TODO(benoitjacob) - should check this, but it makes the _slow tests take
+ // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
+ // bypassing the check?
+ // RUY_DCHECK_GE(row, 0);
+ // RUY_DCHECK_GE(col, 0);
+ // RUY_DCHECK_LT(row, layout.rows);
+ // RUY_DCHECK_LT(col, layout.cols);
+ int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride;
+ int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride;
+ return row * row_stride + col * col_stride;
+}
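+
+// For example (illustration): in a column-major layout with stride 7, the
+// element at (row=3, col=2) is at offset 3 * 1 + 2 * 7 = 17.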
+
+// TODO(b/130417400) add a unit test
+inline int Offset(const PMatLayout& layout, int row, int col) {
+ RUY_DCHECK(is_pot(layout.kernel.rows));
+ RUY_DCHECK(is_pot(layout.kernel.cols));
+ int row_outer = row & ~(layout.kernel.rows - 1);
+ int col_outer = col & ~(layout.kernel.cols - 1);
+ int row_stride_outer =
+ layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride;
+ int col_stride_outer =
+ layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride;
+ int offset_outer =
+ row_outer * row_stride_outer + col_outer * col_stride_outer;
+ int row_inner = row - row_outer;
+ int col_inner = col - col_outer;
+ int row_stride_inner =
+ layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols;
+ int col_stride_inner =
+ layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows;
+ int offset_inner =
+ row_inner * row_stride_inner + col_inner * col_stride_inner;
+ return offset_outer + offset_inner;
+}
+
+// Helpers for Mat<T>.
+
+template <typename Scalar>
+const Scalar* ElementPtr(const Mat<Scalar>& mat, int row, int col) {
+ return mat.data.get() + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(Mat<Scalar>* mat, int row, int col) {
+ return mat->data.get() + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const Mat<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
+// Helpers for PMat<T>.
+// Duplicated from Matrix<T>, but the duplication seems acceptable.
+
+template <typename Scalar>
+const Scalar* ElementPtr(const PMat<Scalar>& mat, int row, int col) {
+ return mat.data + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(PMat<Scalar>* mat, int row, int col) {
+ return mat->data + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const PMat<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
+// Helpers for PEMat.
+
+inline int DataBytes(const PEMat& packed) {
+ return FlatSize(packed.layout) * packed.data_type.size;
+}
+
+inline int SumsBytes(const PEMat& packed) {
+ // Packed matrices are only relevant for Ruy's TrMul implementations. For
+ // TrMul, the number of sums is always equal to the number of columns.
+ return packed.layout.cols * packed.sums_type.size;
+}
+
+// Transpose helpers.
+
+inline Order Transpose(Order order) {
+ return order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor;
+}
+
+inline MatLayout Transpose(const MatLayout& layout) {
+ MatLayout result(layout);
+ result.order = Transpose(result.order);
+ std::swap(result.rows, result.cols);
+ return result;
+}
+
+template <typename Scalar>
+Mat<Scalar> Transpose(const Mat<Scalar>& matrix) {
+ Mat<Scalar> result(matrix);
+ result.layout = Transpose(result.layout);
+ return result;
+}
+
+// Compile-time version of KernelLayout, used to declare kernel layouts in a
+// way that can be consumed by compile-time logic.
+template <Order tOrder, int tRows, int tCols>
+struct FixedKernelLayout {
+ static constexpr Order kOrder = tOrder;
+ static constexpr int kRows = tRows;
+ static constexpr int kCols = tCols;
+};
+
+template <typename FixedKernelLayout>
+KernelLayout ToKernelLayout() {
+ KernelLayout ret;
+ ret.order = FixedKernelLayout::kOrder;
+ ret.rows = FixedKernelLayout::kRows;
+ ret.cols = FixedKernelLayout::kCols;
+ return ret;
+}
+
+#if (__cplusplus < 201703L)
+// From C++17 onwards, a static constexpr data member is implicitly inline, so
+// an out-of-class redeclaration without an initializer is unnecessary (and in
+// fact deprecated). Before C++17, however, these definitions are still needed:
+// Clang with -O0 can otherwise fail to link.
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kCols;
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows;
+#endif
+
+} // namespace ruy
+
+#endif // RUY_RUY_INTERNAL_MATRIX_H_
diff --git a/ruy/matrix.h b/ruy/matrix.h
new file mode 100644
index 0000000..c9353e6
--- /dev/null
+++ b/ruy/matrix.h
@@ -0,0 +1,218 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_MATRIX_H_
+#define RUY_RUY_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint> // IWYU pragma: keep
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+
+namespace ruy {
+
+// Layout storage order. Here and elsewhere, 'col' is short for 'column'.
+// 'column-major' means that each column is contiguous in memory.
+enum class Order : std::uint8_t { kColMajor, kRowMajor };
+
+// Describes the shape and storage layout of a matrix.
+class Layout final {
+ public:
+ int rows() const { return rows_; }
+ void set_rows(int val) { rows_ = val; }
+ int cols() const { return cols_; }
+ void set_cols(int val) { cols_ = val; }
+ int stride() const { return stride_; }
+ void set_stride(int val) { stride_ = val; }
+ Order order() const { return order_; }
+ void set_order(Order val) { order_ = val; }
+
+ private:
+ int rows_ = 0;
+ int cols_ = 0;
+ // Stride is the offset between two adjacent matrix elements
+ // in the non-contiguous direction.
+ int stride_ = 0;
+ Order order_ = Order::kColMajor;
+};
+
+namespace detail {
+
+// Thin wrapper around a pointer with a constness model that works for the
+// purposes of the Matrix class.
+//
+// A typical conundrum for any C++ container class is how type constness should
+// encode, at compile time, the constancy of the contained data:
+// `Matrix<const T>` or `const Matrix<T>`?
+// With either approach it is very difficult to achieve perfect
+// const-correctness; it can only be done with some combination of an
+// inconvenient interface and C++ complexity/abstraction.
+//
+// Here we opt for something that's entirely tailored to the needs of the Ruy
+// interface. The only purpose of the Matrix class is to pass matrix data
+// pointers to ruy. There is an asymmetry here: the caller of ruy::Mul only
+// needs to `set` the data; ruy itself only needs to `get` the data. In the
+// caller's code, it's convenient to be able to just deal with `Matrix<T>`
+// without having to sprinkle `const` keywords in the right places, so we want
+// to track whether the data is constant in a way that's decoupled from the
+// constness of `this`, and we never want to have Matrix<const T>. Inside ruy
+// code, as input matrices are passed by const-reference and output matrices are
+// passed by pointer (to non-const), the constness of `this` is telling whether
+// the data is constant. See the `get` and `set` methods below and the comment
+// explaining the core logic that they encapsulate.
+template <typename T>
+class ConstCheckingPtr final {
+ public:
+ using element_type = T;
+
+ // Core accessors. These encapsulate the main logic:
+ // - for `set`, the constness of the argument determines whether internal
+ // pointer should be tracked as const/mutable.
+ // - for `get`, the constness of `this` determines whether the call
+ // counts as a const or mutable use of the internal pointer.
+ void set(T* ptr) {
+ ptr_ = ptr;
+ set_mutable(true);
+ }
+ void set(const T* ptr) {
+ ptr_ = ptr;
+ set_mutable(false);
+ }
+ void set(std::nullptr_t) { ptr_ = nullptr; }
+ T* get() /* NOT const */ {
+ assert_mutable();
+ return const_cast<T*>(ptr_);
+ }
+ const T* get() const { return ptr_; }
+
+ private:
+ // There's never a need for Matrix<const T>.
+ static_assert(!std::is_const<T>::value, "");
+ const T* ptr_ = nullptr;
+#ifndef NDEBUG
+ bool is_mutable_ = true;
+ void set_mutable(bool val) { is_mutable_ = val; }
+ void assert_mutable() { RUY_DCHECK(is_mutable_); }
+#else
+ void set_mutable(bool) {}
+ void assert_mutable() {}
+#endif
+};
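+
+// Sketch of the intended behavior (see also the ConstCheckingPtrSanity test in
+// matrix_test.cc):
+//
+//   detail::ConstCheckingPtr<int> ptr;
+//   int some_nonconst = 0;
+//   const int some_const = 0;
+//   ptr.set(&some_nonconst);
+//   ptr.get();             // OK: the data was set through a mutable pointer.
+//   ptr.set(&some_const);
+//   // From here, the non-const ptr.get() RUY_DCHECK-fails in debug builds;
+//   // only the const overload, returning const int*, remains valid.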
+
+} // namespace detail
+
+enum class CachePolicy : std::uint8_t {
+ kNeverCache,
+ kCacheIfLargeSpeedup,
+ kCacheIfSignificantSpeedup,
+ kAlwaysCache,
+};
+
+// A Matrix merely wraps existing data as a matrix. It doesn't own any buffer.
+// The purpose of Matrix is only to be used in ruy's interface -- it's just
+// a structured way for the user to pass to ruy::Mul the matrix data pointers
+// together with other matrix parameters.
+// Scalar may be any floating-point or integral type. When integral, it may be
+// signed or unsigned. It's never const: use Matrix<T> for both input and output
+// matrices, never use Matrix<const T>.
+// See the comments on detail::ConstCheckingPtr.
+template <typename Scalar>
+class Matrix final {
+ public:
+ static_assert(!std::is_const<Scalar>::value,
+ "Never use Matrix<const T>. Just use Matrix<T>. Constness of "
+ "the data is guarded by debug-only runtime assertions. See "
+ "detail::ConstCheckingPtr.");
+
+ Scalar* data() { return data_.get(); }
+ const Scalar* data() const { return data_.get(); }
+ void set_data(Scalar* ptr) { data_.set(ptr); }
+ void set_data(const Scalar* ptr) { data_.set(ptr); }
+ void set_data(std::nullptr_t) { data_.set(nullptr); }
+ const Layout& layout() const { return layout_; }
+ Layout* mutable_layout() { return &layout_; }
+ Scalar zero_point() const { return zero_point_; }
+ void set_zero_point(Scalar value) { zero_point_ = value; }
+ CachePolicy cache_policy() const { return cache_policy_; }
+ void set_cache_policy(CachePolicy value) { cache_policy_ = value; }
+
+ private:
+ // The underlying buffer wrapped by this matrix.
+ detail::ConstCheckingPtr<Scalar> data_;
+ // The shape and data layout of this matrix.
+ Layout layout_;
+ // The zero_point, i.e. which Scalar value is to be interpreted as zero.
+ // When Scalar is floating-point, this must be 0.
+ Scalar zero_point_ = 0;
+ // When the data pointed to by this matrix is constant data, so that it is
+ // valid to assume that equality of pointers implies equality of data,
+ // a CachePolicy may be used instead of the default kNeverCache,
+ // which will enable ruy to take advantage of this constancy of the data to
+ // cache the packing work, which can be a large speedup in matrix*vector
+ // and other narrow shapes.
+ CachePolicy cache_policy_ = CachePolicy::kNeverCache;
+};
+
+inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
+ layout->set_rows(rows);
+ layout->set_cols(cols);
+ layout->set_order(order);
+ layout->set_stride(order == Order::kColMajor ? rows : cols);
+}
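+
+// Example of wrapping existing data in a Matrix (an illustrative sketch of
+// caller-side code, not part of this header's API):
+//
+//   const float lhs_data[6] = {1, 2, 3, 4, 5, 6};
+//   ruy::Matrix<float> lhs;
+//   ruy::MakeSimpleLayout(2, 3, ruy::Order::kRowMajor, lhs.mutable_layout());
+//   lhs.set_data(lhs_data);  // const overload: data is tracked as immutable.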
+
+template <typename StreamType, typename Scalar>
+StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
+ for (int row = 0; row < mat.layout().rows(); row++) {
+ for (int col = 0; col < mat.layout().cols(); col++) {
+ stream << static_cast<double>(Element(mat, row, col)) << " ";
+ }
+ stream << "\n";
+ }
+ return stream;
+}
+
+// TODO(b/130417400) add a unit test
+inline int Offset(const Layout& layout, int row, int col) {
+ // TODO(benoitjacob) - should check this, but it makes the _slow tests take
+ // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
+ // bypassing the check?
+ // RUY_DCHECK_GE(row, 0);
+ // RUY_DCHECK_GE(col, 0);
+ // RUY_DCHECK_LT(row, layout.rows());
+ // RUY_DCHECK_LT(col, layout.cols());
+ int row_stride = layout.order() == Order::kColMajor ? 1 : layout.stride();
+ int col_stride = layout.order() == Order::kRowMajor ? 1 : layout.stride();
+ return row * row_stride + col * col_stride;
+}
+
+template <typename Scalar>
+const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
+ return mat.data() + Offset(mat.layout(), row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
+ return mat->data() + Offset(mat->layout(), row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_MATRIX_H_
diff --git a/ruy/matrix_test.cc b/ruy/matrix_test.cc
new file mode 100644
index 0000000..0f3fd13
--- /dev/null
+++ b/ruy/matrix_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/matrix.h"
+
+#include "ruy/gtest_wrapper.h"
+
+namespace ruy {
+namespace {
+
+TEST(MatrixTest, LayoutClassSanity) {
+ Layout layout;
+ EXPECT_EQ(layout.rows(), 0);
+ EXPECT_EQ(layout.cols(), 0);
+ EXPECT_EQ(layout.stride(), 0);
+ EXPECT_EQ(layout.order(), Order::kColMajor);
+ layout.set_rows(123);
+ layout.set_cols(456);
+ layout.set_stride(789);
+ layout.set_order(Order::kRowMajor);
+ EXPECT_EQ(layout.rows(), 123);
+ EXPECT_EQ(layout.cols(), 456);
+ EXPECT_EQ(layout.stride(), 789);
+ EXPECT_EQ(layout.order(), Order::kRowMajor);
+}
+
+TEST(MatrixTest, MakeSimpleLayout) {
+ Layout layout;
+ MakeSimpleLayout(123, 456, Order::kColMajor, &layout);
+ EXPECT_EQ(layout.rows(), 123);
+ EXPECT_EQ(layout.cols(), 456);
+ EXPECT_EQ(layout.stride(), 123);
+ EXPECT_EQ(layout.order(), Order::kColMajor);
+ MakeSimpleLayout(321, 654, Order::kRowMajor, &layout);
+ EXPECT_EQ(layout.rows(), 321);
+ EXPECT_EQ(layout.cols(), 654);
+ EXPECT_EQ(layout.stride(), 654);
+ EXPECT_EQ(layout.order(), Order::kRowMajor);
+}
+
+TEST(MatrixTest, ConstCheckingPtrSanity) {
+ using PtrType = detail::ConstCheckingPtr<int>;
+ PtrType ptr;
+ int some_nonconst;
+ const int some_const = 0;
+ EXPECT_EQ(ptr.get(), nullptr);
+ ptr.set(&some_nonconst);
+ EXPECT_EQ(static_cast<const PtrType&>(ptr).get(), &some_nonconst);
+ EXPECT_EQ(ptr.get(), &some_nonconst);
+ ptr.set(&some_const);
+ EXPECT_EQ(static_cast<const PtrType&>(ptr).get(), &some_const);
+#ifndef NDEBUG
+ RUY_ASSERT_DEATH(ptr.get(), "");
+#endif
+}
+
+TEST(MatrixTest, MatrixClassSanity) {
+ Matrix<int> matrix;
+ EXPECT_EQ(matrix.data(), nullptr);
+ EXPECT_EQ(matrix.zero_point(), 0);
+ EXPECT_EQ(matrix.cache_policy(), CachePolicy::kNeverCache);
+ EXPECT_EQ(matrix.layout().rows(), 0);
+ EXPECT_EQ(matrix.layout().cols(), 0);
+ EXPECT_EQ(matrix.layout().stride(), 0);
+ EXPECT_EQ(matrix.layout().order(), Order::kColMajor);
+ const int some_const = 0;
+ matrix.set_data(&some_const);
+ matrix.set_zero_point(123);
+ matrix.set_cache_policy(CachePolicy::kAlwaysCache);
+ MakeSimpleLayout(12, 34, Order::kRowMajor, matrix.mutable_layout());
+ EXPECT_EQ(static_cast<const Matrix<int>&>(matrix).data(), &some_const);
+#ifndef NDEBUG
+ RUY_ASSERT_DEATH(matrix.data(), "");
+#endif
+ EXPECT_EQ(matrix.zero_point(), 123);
+ EXPECT_EQ(matrix.cache_policy(), CachePolicy::kAlwaysCache);
+ EXPECT_EQ(matrix.layout().rows(), 12);
+ EXPECT_EQ(matrix.layout().cols(), 34);
+ EXPECT_EQ(matrix.layout().stride(), 34);
+ EXPECT_EQ(matrix.layout().order(), Order::kRowMajor);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
new file mode 100644
index 0000000..d5aa27b
--- /dev/null
+++ b/ruy/mul_params.h
@@ -0,0 +1,299 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_MUL_PARAMS_H_
+#define RUY_RUY_MUL_PARAMS_H_
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+// Enumeration to designate which dimension is the 'channels', for MulParams
+// features that are 'per-channel', namely the bias-vector and the quantized
+// multiplier.
+enum class ChannelDimension : std::int8_t {
+ // kRow means that 'per-channel' means 'per row of the destination matrix'
+ kRow,
+ // kCol means that 'per-channel' means 'per column of the destination matrix'
+ kCol
+};
+
+namespace detail {
+template <typename tAccumScalar, typename tDstScalar>
+struct MulParamsStorage;
+}
+
+// MulParams describes all about a matrix multiplication that
+// isn't encoded in the LHS, RHS and destination matrices. Some of that
+// information is encoded as compile-time constants and types (for instance, the
+// choice of accumulator type, AccumScalar). Some of that information is encoded
+// as runtime values (for instance, the optional bias vector).
+//
+// Template parameters:
+// AccumScalar: Accumulator type. The type of accumulators used to compute the
+// dot-products before ultimately being cast to the destination type.
+// DstScalar: The destination scalar type.
+//
+// Constraints on these template parameters (see also the ruy::Mul comment):
+// * If DstScalar is floating-point then AccumScalar must also be.
+// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
+// in that integral case, there is a mode switch:
+// - If DstScalar is std::int32_t then the multiplier_* fields are all
+// disabled, and ruy::Mul will just return raw (unscaled) accumulators.
+// - If DstScalar is not std::int32_t then the multiplier_* fields are
+// enabled, and ruy::Mul will use them to scale internal std::int32_t
+// accumulators before casting them to the DstScalar type. The default
+// values are such that the effective multiplier is 1 (no scaling).
+//
+// For the latter case (DstScalar integral and narrower than std::int32_t),
+// reference code can be found in the implementation of ruy::ApplyMultiplier.
+// If you look there, you'll find warnings like this:
+//
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+//
+// The explanation of this warning is that as of early 2021, we still don't know
+// whether it is advisable to let this code as-is have normative value, or
+// whether that would become advisable after some specific final change.
+//
+// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
+// bit-exactly to this reference, but we also know that x86 could be faster if
+// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
+// know that this particular reference code is inherently better than other
+// forms that could perform better on these architectures --- in fact, the
+// alternative that was proposed in [2] as better performing on ARM Cortex-M
+// is also inherently more accurate thanks to rounding only once, but it would
+// perform worse on both ARM NEON and x86.
+//
+// In fact, if we look at other hardware architectures beyond current Ruy
+// targets, namely "hardware accelerators", it becomes clear that there is no
+// hope for any form of this to be efficiently implementable simultaneously on
+// all current relevant hardware. Indeed, some accelerators prefer to perform
+// the multiplication in IEEE float32, others in IEEE float16, others in
+// bfloat16, others in 16-bit fixed-point...
+//
+// See:
+// [1] https://github.com/google/ruy/pull/227
+// [2] https://github.com/tensorflow/tensorflow/issues/25087
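+//
+// Illustrative usage sketch (the specific values below are made up for
+// illustration and are not normative; `bias_data` stands for a caller-owned
+// buffer holding one AccumScalar value per channel):
+//
+//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
+//   mul_params.set_multiplier_fixedpoint(1 << 30);  // i.e. multiply by 1/2...
+//   mul_params.set_multiplier_exponent(-1);         // ...then by 2^-1.
+//   mul_params.set_bias(bias_data);                 // Optional bias vector.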
+template <typename tAccumScalar, typename tDstScalar>
+class MulParams final {
+ public:
+ using AccumScalar = tAccumScalar;
+ using DstScalar = tDstScalar;
+
+ // The bias vector data, if not null.
+ const AccumScalar* bias() const { return storage_.bias; }
+ void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
+ // Only for non-floating-point cases. The fixed-point part of the multiplier
+  // by which accumulators are multiplied before being cast to the destination
+ // type. This is a fixed-point quantity with 0 integer bits. Since
+ // (as explained in the class comment) AccumScalar must be std::int32_t,
+ // that means that the fixed-point format is Q0.31. For example,
+ // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
+ // by one half (1/2). More generally, the effect is to multiply by
+ // (multiplier_fixedpoint / (2^31)).
+ AccumScalar multiplier_fixedpoint() const {
+ return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
+ }
+ void set_multiplier_fixedpoint(const AccumScalar value) {
+ set_perchannel(false);
+ storage_.multiplier_fixedpoint = value;
+ }
+ // Only for non-floating-point cases. The exponent part of the aforementioned
+ // multiplier.
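+  // Roughly speaking (see ruy::ApplyMultiplier for the exact rounding
+  // behavior), the overall effective multiplier is
+  //   (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent.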
+ int multiplier_exponent() const {
+ return storage_.perchannel ? 0 : storage_.multiplier_exponent;
+ }
+ void set_multiplier_exponent(const int value) {
+ set_perchannel(false);
+ storage_.multiplier_exponent = value;
+ }
+ // Per-channel variant of multiplier_fixedpoint. Setting this switches
+ // to per-channel mode, where `multiplier_fixedpoint` and
+ // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
+ // and `multiplier_exponent_perchannel` are used instead.
+ //
+ // This must point to a buffer of as many values as there are rows or columns
+ // in the destination matrix, whichever is the channels dimension. Each
+ // channel of the destination matrix will use the corresponding buffer element
+ // instead of multiplier_fixedpoint.
+ const AccumScalar* multiplier_fixedpoint_perchannel() const {
+ return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
+ : nullptr;
+ }
+ void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
+ set_perchannel(true);
+ storage_.multiplier_fixedpoint_perchannel = ptr;
+ }
+ // Per-channel variant of multiplier_exponent. Same comments as for
+ // multiplier_fixedpoint_perchannel.
+ const int* multiplier_exponent_perchannel() const {
+ return storage_.perchannel ? storage_.multiplier_exponent_perchannel
+ : nullptr;
+ }
+ void set_multiplier_exponent_perchannel(const int* ptr) {
+ set_perchannel(true);
+ storage_.multiplier_exponent_perchannel = ptr;
+ }
+ // min clamp bound of destination values.
+ DstScalar clamp_min() const { return storage_.clamp_min; }
+ void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
+ // max clamp bound of destination values.
+ DstScalar clamp_max() const { return storage_.clamp_max; }
+ void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
+ // Designates which dimension is the 'channels', for per-channel features
+ // such as bias-addition and per-channel quantization multipliers.
+ ChannelDimension channel_dimension() const {
+ return storage_.channel_dimension;
+ }
+ void set_channel_dimension(ChannelDimension value) {
+ storage_.channel_dimension = value;
+ }
+ // Specifies the upward rounding of the allocated capacity of per-channel
+ // buffers such as bias vectors and per-channel quantization multipliers.
+ // The unit is matrix entries, not bytes.
+ //
+ // This value must be a power of two.
+ //
+ // The default value, 1, means no upward rounding, meaning that the buffers
+ // are not required to have a capacity greater than the size of the
+ // corresponding matrix dimension, i.e. the number of rows (respectively
+ // columns) of the destination matrix if `channel_dimension()` is kRow
+ // (respectively kCol).
+ //
+ // Higher values allow the implementation to assume that it is OK to access
+ // these buffers a little past this boundary, which is useful in SIMD
+ // optimized kernels. In practice, when this value is lower than what the
+ // kernel requires, ruy has to internally reallocate and copy per-channel
+ // buffers. When this value is high enough, this reallocation and copy is
+ // avoided.
+ //
+ // When a value greater than 1 is specified, the tail region of the buffer
+ // (past the end of the values actually corresponding to channels) is required
+ // to be zero-initialized.
+ //
+ // As of 2020, values as high as 16 may be useful on some CPU architectures
+ // (corresponding to the widest kernels used on any CPU architecture).
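+  //
+  // For illustration (just restating the description above on a concrete
+  // case): with a destination matrix of 10 rows, channel_dimension() == kRow
+  // and a rounding value of 16, per-channel buffers such as the bias vector
+  // would need a capacity of 16 entries, with entries 10..15
+  // zero-initialized.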
+ int perchannel_buffers_capacity_rounding() const {
+ return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
+ }
+ void set_perchannel_buffers_capacity_rounding(int value) {
+ // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
+ storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
+ }
+
+ private:
+ detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
+
+ void set_perchannel(bool perchannel) {
+ storage_.perchannel = perchannel;
+ }
+};
+
+namespace detail {
+
+// Floating-point case.
+template <typename AccumScalar, typename DstScalar>
+struct MulParamsStorage final {
+ static_assert(std::is_floating_point<AccumScalar>::value, "");
+ static_assert(std::is_floating_point<DstScalar>::value, "");
+ static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");
+
+ const AccumScalar* bias = nullptr;
+ DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
+ DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
+ ChannelDimension channel_dimension = ChannelDimension::kRow;
+ std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
+
+ // Data members that are disabled in this case are left as `static constexpr`
+ // so that one can write some generic code.
+ static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
+ nullptr;
+ static constexpr const int* multiplier_exponent_perchannel = nullptr;
+ static constexpr AccumScalar multiplier_fixedpoint = 0;
+ static constexpr int multiplier_exponent = 0;
+ static constexpr bool perchannel = false;
+};
+
+// Specialization for the integer-quantized type, with down-quantization of
+// int32 accumulators to a narrower destination scalar type.
+template <typename DstScalar>
+struct MulParamsStorage<std::int32_t, DstScalar> final {
+ using AccumScalar = std::int32_t;
+ static_assert(std::is_integral<DstScalar>::value, "");
+ static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
+
+ const AccumScalar* bias = nullptr;
+ union {
+ const AccumScalar* multiplier_fixedpoint_perchannel;
+    // Let the default multiplier be effectively a multiplication by 1, so that
+ // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+ // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+ // using the highest representable value is a sufficiently good
+ // approximation: since this specialization of MulParams is for the case
+    // where DstScalar is at least 2x narrower than AccumScalar, the values
+ // for which there would be a difference will get saturated anyway.
+ AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
+ };
+ union {
+ const int* multiplier_exponent_perchannel;
+ // See the above comment about the default value of multiplier_fixedpoint.
+ int multiplier_exponent = 0;
+ };
+ DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
+ DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
+ ChannelDimension channel_dimension = ChannelDimension::kRow;
+ bool perchannel = false;
+ std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
+};
+
+// Specialization used in the integer case when outputting raw int32
+// accumulators, without down-quantization to a narrower destination scalar
+// type. In this case, the feature of clamping destination values is not
+// available.
+template <>
+struct MulParamsStorage<std::int32_t, std::int32_t> final {
+ using AccumScalar = std::int32_t;
+ using DstScalar = std::int32_t;
+
+ const AccumScalar* bias = nullptr;
+ ChannelDimension channel_dimension = ChannelDimension::kRow;
+ std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
+
+ // Data members that are disabled in this case are left as `static constexpr`
+ // so that one can write some generic code.
+ static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
+ nullptr;
+ static constexpr const int* multiplier_exponent_perchannel = nullptr;
+ static constexpr AccumScalar multiplier_fixedpoint = 0;
+ static constexpr int multiplier_exponent = 0;
+ static constexpr DstScalar clamp_min =
+ std::numeric_limits<DstScalar>::lowest();
+ static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
+ static constexpr bool perchannel = false;
+};
+
+} // namespace detail
+
+} // namespace ruy
+
+#endif // RUY_RUY_MUL_PARAMS_H_
diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc
new file mode 100644
index 0000000..4bc9f87
--- /dev/null
+++ b/ruy/mul_params_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/mul_params.h"
+
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/gtest_wrapper.h"
+
+namespace ruy {
+namespace {
+
+TEST(MulParamsTest, SpecClassSanity) {
+ using MulParamsType = MulParams<std::int32_t, std::int8_t>;
+ static_assert(std::is_same<MulParamsType::AccumScalar, std::int32_t>::value,
+ "");
+ static_assert(std::is_same<MulParamsType::DstScalar, std::int8_t>::value, "");
+
+ MulParamsType mul_params;
+ EXPECT_EQ(mul_params.bias(), nullptr);
+  EXPECT_EQ(mul_params.multiplier_fixedpoint(),
+            std::numeric_limits<std::int32_t>::max());
+ EXPECT_EQ(mul_params.multiplier_exponent(), 0);
+ EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr);
+ EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr);
+ EXPECT_EQ(mul_params.clamp_min(), -128);
+ EXPECT_EQ(mul_params.clamp_max(), 127);
+ EXPECT_EQ(mul_params.channel_dimension(), ChannelDimension::kRow);
+ EXPECT_EQ(mul_params.perchannel_buffers_capacity_rounding(), 1);
+ std::int32_t bias_data[1];
+ mul_params.set_bias(bias_data);
+ mul_params.set_multiplier_fixedpoint(123);
+ mul_params.set_multiplier_exponent(4);
+ mul_params.set_channel_dimension(ChannelDimension::kCol);
+ mul_params.set_perchannel_buffers_capacity_rounding(8);
+ EXPECT_EQ(mul_params.bias(), bias_data);
+ EXPECT_EQ(mul_params.multiplier_fixedpoint(), 123);
+ EXPECT_EQ(mul_params.multiplier_exponent(), 4);
+ EXPECT_EQ(mul_params.channel_dimension(), ChannelDimension::kCol);
+ EXPECT_EQ(mul_params.perchannel_buffers_capacity_rounding(), 8);
+ mul_params.set_multiplier_fixedpoint(0);
+ mul_params.set_multiplier_exponent(0);
+ std::int32_t multiplier_fixedpoint_perchannel_data[1];
+ int multiplier_exponent_perchannel_data[1];
+ mul_params.set_multiplier_fixedpoint_perchannel(
+ multiplier_fixedpoint_perchannel_data);
+ mul_params.set_multiplier_exponent_perchannel(
+ multiplier_exponent_perchannel_data);
+ mul_params.set_clamp_min(-10);
+ mul_params.set_clamp_max(10);
+ EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0);
+ EXPECT_EQ(mul_params.multiplier_exponent(), 0);
+ EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(),
+ multiplier_fixedpoint_perchannel_data);
+ EXPECT_EQ(mul_params.multiplier_exponent_perchannel(),
+ multiplier_exponent_perchannel_data);
+ EXPECT_EQ(mul_params.clamp_min(), -10);
+ EXPECT_EQ(mul_params.clamp_max(), 10);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/opt_set.h b/ruy/opt_set.h
new file mode 100644
index 0000000..244d9da
--- /dev/null
+++ b/ruy/opt_set.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_OPT_SET_H_
+#define RUY_RUY_OPT_SET_H_
+
+// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling
+// certain optimizations. It should be used by defining that macro on the
+// compiler command line.
+//
+// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy.
+#define RUY_OPT_BIT_INTRINSICS 0x1
+#define RUY_OPT_BIT_ASM 0x2
+#define RUY_OPT_BIT_TUNING 0x4
+#define RUY_OPT_BIT_FAT_KERNEL 0x8
+// 0x10 used to be RUY_OPT_BIT_NATIVE_ROUNDING
+#define RUY_OPT_BIT_AVOID_ALIASING 0x20
+#define RUY_OPT_BIT_MAX_STREAMING 0x40
+#define RUY_OPT_BIT_PACK_AHEAD 0x80
+#define RUY_OPT_BIT_PREFETCH_LOAD 0x100
+#define RUY_OPT_BIT_PREFETCH_STORE 0x200
+#define RUY_OPT_BIT_FRACTAL_Z 0x400
+#define RUY_OPT_BIT_FRACTAL_U 0x800
+#define RUY_OPT_BIT_FRACTAL_HILBERT 0x1000
+
+#if !defined(RUY_OPT_SET)
+#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK
+// Load prefetching is detrimental in matrix multiplication benchmarks.
+// Store prefetching is not.
+#define RUY_OPT_SET (~RUY_OPT_BIT_PREFETCH_LOAD)
+#else
+// Default to all optimizations.
+#define RUY_OPT_SET (~0)
+#endif
+#endif
+
+#define RUY_OPT(X) ((RUY_OPT_SET & RUY_OPT_BIT_##X) != 0)
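+
+// Example (illustrative, not a recommended configuration): building with
+// -DRUY_OPT_SET=0x7 enables only the INTRINSICS, ASM and TUNING
+// optimizations; code then queries individual bits as in
+// `#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)` in pack_arm.cc.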
+
+#endif // RUY_RUY_OPT_SET_H_
diff --git a/ruy/pack.h b/ruy/pack.h
new file mode 100644
index 0000000..744f9bc
--- /dev/null
+++ b/ruy/pack.h
@@ -0,0 +1,155 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and the other dimensions are in the
+// hundreds), with very significant packing costs at 8 or below. This varies
+// by data type, Path, and
+// tuning, so these numbers are only rough guides.
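+//
+// A crude worked example of the analysis above (treating one packed entry and
+// one multiply-accumulate as comparable units of work, which is only a rough
+// approximation): with N = K = M = 512, packing touches about 2*512*512 ~ 5e5
+// entries while the kernel performs 512^3 ~ 1.3e8 multiply-accumulates, well
+// under 1% of the total. With N = 4 and K = M = 512, packing the RHS alone
+// touches 512*512 ~ 2.6e5 entries while the kernel performs only
+// 4*512*512 ~ 1e6 multiply-accumulates, on the order of 25%.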
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy has an optional feature, accessed by Matrix::set_cache_policy(), to
+// cache the packed forms of constant matrices.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef RUY_RUY_PACK_H_
+#define RUY_RUY_PACK_H_
+
+#include "ruy/mat.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/trace.h"
+
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM_NEON
+#include "ruy/pack_arm.h"
+#elif RUY_PLATFORM_X86
+#include "ruy/pack_x86.h"
+#endif
+// IWYU pragma: end_exports
+
+namespace ruy {
+
+// General implementation of the PackImpl template, overridden by template
+// specializations for specific SIMD code paths. This general implementation
+// is used with Path::kStandardCpp and its internal test-only variants.
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+ typename PackedScalar, typename SumsType, Order SrcOrder>
+struct PackImpl {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<PackedScalar>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (generic)");
+ RUY_DCHECK_EQ(SrcOrder, src_matrix.layout.order);
+ RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0);
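+    // Optional column sums, used to handle quantization with non-symmetric
+    // zero points (see the comment at the top of this file); may be null.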
+ SumsType* sums = packed_matrix->sums;
+ for (int col = start_col; col < end_col; col++) {
+ SumsType accum = 0;
+ for (int row = 0; row < packed_matrix->layout.rows; row++) {
+ PackedScalar packed_val;
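+        // Entries past the bounds of the source matrix are padded with the
+        // zero_point, so the packed block is always full-sized.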
+ if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) {
+ packed_val = Pack<PackedScalar>(Element(src_matrix, row, col));
+ } else {
+ packed_val = packed_matrix->zero_point;
+ }
+ accum += packed_val;
+ *ElementPtr(packed_matrix, row, col) = packed_val;
+ }
+ if (sums) {
+ sums[col] = accum;
+ }
+ }
+ }
+};
+
+// Main entry point for packing.
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+ typename PackedScalar>
+void RunPack(Tuning tuning, const EMat& src_matrix, PEMat* packed_matrix,
+ int start_col, int end_col) {
+ RUY_TRACE_SCOPE;
+ using SumsType = typename PMat<PackedScalar>::SumsType;
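+  // Convert the type-erased matrix views (EMat/PEMat) back to concrete typed
+  // views before dispatching to the PackImpl specialization.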
+ Mat<Scalar> src = UneraseType<Scalar>(src_matrix);
+ PMat<PackedScalar> packed = UneraseType<PackedScalar>(*packed_matrix);
+ RUY_TRACE_INFO(RUN_PACK);
+ if (src.layout.order == Order::kColMajor) {
+ PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType,
+ Order::kColMajor>::Run(tuning, src, &packed, start_col, end_col);
+ } else {
+ PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType,
+ Order::kRowMajor>::Run(tuning, src, &packed, start_col, end_col);
+ }
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_PACK_H_
diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc
new file mode 100644
index 0000000..c337986
--- /dev/null
+++ b/ruy/pack_arm.cc
@@ -0,0 +1,2480 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/pack_arm.h"
+
+#include <cstdint>
+
+#include "ruy/asm_helpers.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_common.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_NEON
+#include <arm_neon.h>
+#endif
+
+namespace ruy {
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeon)");
+ asm volatile(
+ // clang-format off
+ // v26 will be the vector to XOR input values with to perform
+ // any input data type conversion (e.g. uint8 to int8).
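+      // (For instance, an input_xor of 0x80 flips the sign bit, which
+      // converts uint8 source values to int8 by effectively subtracting 128.)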
+ "dup v26.16b, %w[input_xor]\n"
+ // w1 will be the number of rows already loaded.
+ "mov w1, #0\n"
+          // v28--v31 will be used to accumulate the sums
+ "movi v28.4s, #0\n"
+ "movi v29.4s, #0\n"
+ "movi v30.4s, #0\n"
+ "movi v31.4s, #0\n"
+ // Let w2 be `rows` rounded down to multiple of 16.
+ "ands w2, %w[rows], #-16\n"
+ // If there are no full blocks of 16 rows to process, jump to the
+ // code handling the last < 16 rows.
+ "beq 3f\n"
+ // Load the first block of 16 rows.
+ "add w1, w1, #16\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ // Check if these were the only full block of 16 rows to load.
+ "cmp w1, w2\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ // In that case, jump to the code handling the last loaded block of
+ // 16 rows.
+ "beq 2f\n"
+ // Main loop processing blocks of 16 rows.
+ "1:\n"
+ // Load the next 16 rows, interleaved with the XOR input type
+ // conversion (e.g. uint8->int8) on the already loaded inputs.
+ "add w1, w1, #16\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ // Compute the sums, interleaved with storing to the packed matrix.
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ // Was this the last block of 16 rows to load?
+ "cmp w1, w2\n"
+ "sadalp v29.4s, v17.8h\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+ // End of main loop on blocks of 16 rows.
+ "bne 1b\n"
+
+ // Code handling the last already-loaded block of 16 rows.
+ "2:\n"
+
+ // Process the last loaded full 16x4 block.
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ // End of code handling full blocks of 16 rows.
+ // Now we handle any remaining rows.
+ "3:\n"
+ // Let w2 be the number of rows left to handle.
+ "ands w2, %w[rows], #15\n"
+ // If w2==0, there are no remaining rows, jump to the end.
+ "beq 4f\n"
+ // Zero out a 16x4 block in registers, which we'll partially overwrite
+ // with any remaining rows.
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ // Process the last zero-padded 16x4 block.
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "saddlp v17.8h, v5.16b\n"
+ "saddlp v18.8h, v6.16b\n"
+ "saddlp v19.8h, v7.16b\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "str q4, [%[packed_ptr], #0]\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "4:\n"
+
+ // Horizontal reduction of the registers used to accumulate sums.
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+ "addp v28.4s, v28.4s, v30.4s\n"
+
+ // Store the sums.
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
+ [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
+ [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
+ : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
+ [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
+ [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
+ [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
+ [rows] "r"(src_rows), [src_zero_point] "r"(src_zero_point),
+ [input_xor] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+}
+#endif
+
+#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+#define RUY_OFFSET_SRC_PTR0 0
+#define RUY_OFFSET_SRC_PTR1 4
+#define RUY_OFFSET_SRC_PTR2 8
+#define RUY_OFFSET_SRC_PTR3 12
+#define RUY_OFFSET_SUMS_PTR 16
+#define RUY_OFFSET_PACKED_PTR 20
+#define RUY_OFFSET_SRC_INC0 24
+#define RUY_OFFSET_SRC_INC1 28
+#define RUY_OFFSET_SRC_INC2 32
+#define RUY_OFFSET_SRC_INC3 36
+#define RUY_OFFSET_SRC_ROWS 40
+#define RUY_OFFSET_SRC_ZERO_POINT 44
+#define RUY_OFFSET_INPUT_XOR 48
+
+template <typename Params>
+void CheckOffsetsInPackParams8bit(const Params&) {
+ static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, "");
+ static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, "");
+ static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, "");
+ static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, "");
+ static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, "");
+ static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, "");
+ static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, "");
+ static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, "");
+ static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, "");
+ static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, "");
+ static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, "");
+ static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, "");
+}
+
+// No attempt made at making this code efficient on A55-ish cores yet.
+void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params) {
+ CheckOffsetsInPackParams8bit(params);
+ profiler::ScopeLabel label("Pack (kNeon)");
+ const void* src_ptr0 = params.src_ptr0;
+ const void* src_ptr1 = params.src_ptr1;
+ const void* src_ptr2 = params.src_ptr2;
+ const void* src_ptr3 = params.src_ptr3;
+ const int src_inc0 = params.src_inc0;
+ const int src_inc1 = params.src_inc1;
+ const int src_inc2 = params.src_inc2;
+ const int src_inc3 = params.src_inc3;
+ const std::int8_t* packed_ptr = params.packed_ptr;
+
+ asm volatile(
+ // clang-format off
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
+ "vdup.8 q11, r2\n"
+ "mov r1, #0\n"
+ // Zero-out the accumulators
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+ "vmov.i32 q14, #0\n"
+ "vmov.i32 q15, #0\n"
+
+ // Round down src_rows to nearest multiple of 16.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "and r2, r3, #-16\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
+
+ "1:\n"
+ "add r1, r1, #16\n"
+ /* Load q0 */
+ "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
+ "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n")
+
+ /* Load q1 */
+ "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
+ "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n")
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ // Pairwise add in to 16b accumulators.
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q12 and q13 contain 4x32b accumulators
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ // Now do the same for src_ptr2 and src_ptr3.
+ "vld1.8 {d0, d1}, [%[src_ptr2]]\n"
+ "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n")
+
+ "vld1.8 {d2, d3}, [%[src_ptr3]]\n"
+ "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n")
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q14 and q15 contain 4x32b accumulators
+ "vpadal.s16 q14, q8\n"
+ "vpadal.s16 q15, q9\n"
+
+ "cmp r1, r2\n"
+ "bne 1b\n"
+
+ "3:\n"
+
+ // Now pack the last (num_rows % 16) rows.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "ands r2, r3, #15\n"
+ "beq 4f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// First, read/accumulate/write for src_ptr0 and src_ptr1.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Reset to src_zero for src_ptr2 and src_ptr3.
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// Next, read/accumulate/write for src_ptr2 and src_ptr3.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q14, q8\n"
+ "vpadal.s16 q15, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ "4:\n"
+ // Pairwise add 32-bit accumulators
+ "vpadd.i32 d24, d24, d25\n"
+ "vpadd.i32 d26, d26, d27\n"
+ "vpadd.i32 d28, d28, d29\n"
+ "vpadd.i32 d30, d30, d31\n"
+ // Final 32-bit values per row
+ "vpadd.i32 d25, d24, d26\n"
+ "vpadd.i32 d27, d28, d30\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
+ "cmp r3, #0\n"
+ "beq 6f\n"
+ "vst1.32 {d25}, [r3]!\n"
+ "vst1.32 {d27}, [r3]!\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3)
+ : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
+ [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3),
+ [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
+ : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+ "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
+}
+
+// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9.
+// No attempt made at making this code efficient on in-order cores yet.
+// This version differs from the above in that we only handle two columns
+// at a time.
+void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params) {
+ CheckOffsetsInPackParams8bit(params);
+ profiler::ScopeLabel label("Pack (kNeon)");
+ const void* src_ptr0 = params.src_ptr0;
+ const void* src_ptr1 = params.src_ptr1;
+ const int src_inc0 = params.src_inc0;
+ const int src_inc1 = params.src_inc1;
+ const std::int8_t* packed_ptr = params.packed_ptr;
+
+ asm volatile(
+ // clang-format off
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
+ "vdup.8 q11, r2\n"
+ "mov r1, #0\n"
+ // Zero-out the accumulators
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+
+ // Round down src_rows to nearest multiple of 16.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "and r2, r3, #-16\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
+
+ "1:\n"
+ "add r1, r1, #16\n"
+ /* Load q0 */
+ "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
+ "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
+
+ /* Load q1 */
+ "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
+ "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ // Pairwise add in to 16b accumulators.
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q12 and q13 contain 4x32b accumulators
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "cmp r1, r2\n"
+
+ "bne 1b\n"
+
+ "3:\n"
+
+ // Now pack the last (num_rows % 16) rows.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "ands r2, r3, #15\n"
+ "beq 4f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// Read/accumulate/write for src_ptr0 and src_ptr1.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ "4:\n"
+
+ // Pairwise add 32-bit accumulators
+ "vpadd.i32 d24, d24, d25\n"
+ "vpadd.i32 d26, d26, d27\n"
+ // Final 32-bit values per row
+ "vpadd.i32 d25, d24, d26\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
+ "cmp r3, #0\n"
+ "beq 6f\n"
+ "vst1.32 {d25}, [r3]!\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1)
+ : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
+ [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
+ : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+ "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
+}
+
+#undef RUY_OFFSET_SRC_PTR0
+#undef RUY_OFFSET_SRC_PTR1
+#undef RUY_OFFSET_SRC_PTR2
+#undef RUY_OFFSET_SRC_PTR3
+#undef RUY_OFFSET_SUMS_PTR
+#undef RUY_OFFSET_PACKED_PTR
+#undef RUY_OFFSET_SRC_INC0
+#undef RUY_OFFSET_SRC_INC1
+#undef RUY_OFFSET_SRC_INC2
+#undef RUY_OFFSET_SRC_INC3
+#undef RUY_OFFSET_SRC_ROWS
+#undef RUY_OFFSET_SRC_ZERO_POINT
+#undef RUY_OFFSET_INPUT_XOR
+
+#endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
+ asm volatile(
+ // clang-format off
+ // v26 will be the vector to XOR input values with to perform
+ // any input data type conversion (e.g. uint8 to int8).
+ "dup v26.16b, %w[input_xor]\n"
+ // w1 will be the number of rows already loaded.
+ "mov w1, #0\n"
+        // v28--v31 will be used to accumulate the sums
+ "movi v28.4s, #0\n"
+ "movi v29.4s, #0\n"
+ "movi v30.4s, #0\n"
+ "movi v31.4s, #0\n"
+ // Let w2 be `rows` rounded down to multiple of 16.
+ "ands w2, %w[rows], #-16\n"
+ // If there are no full blocks of 16 rows to process, jump to the
+ // code handling the last < 16 rows.
+ "beq 3f\n"
+ // Load the first block of 16 rows.
+ "add w1, w1, #16\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ // Check if these were the only full block of 16 rows to load.
+ "cmp w1, w2\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ // In that case, jump to the code handling the last loaded block of
+ // 16 rows.
+ "beq 2f\n"
+ // Main loop processing blocks of 16 rows.
+ "1:\n"
+ // Load the next 16 rows, interleaved with the XOR input type
+ // conversion (e.g. uint8->int8) on the already loaded inputs.
+ "add w1, w1, #16\n"
+ "ins v0.d[1], x10\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ins v1.d[1], x11\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ins v2.d[1], x12\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ins v3.d[1], x13\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ // Compute the sums, interleaved with storing to the packed matrix.
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ // Was this the last block of 16 rows to load?
+ "cmp w1, w2\n"
+ "sadalp v29.4s, v17.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+ "sadalp v30.4s, v18.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "sadalp v31.4s, v19.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+ // End of main loop on blocks of 16 rows.
+ "bne 1b\n"
+
+ // Code handling the last already-loaded block of 16 rows.
+ "2:\n"
+ // Process the last loaded full 16x4 block.
+ "ins v0.d[1], x10\n"
+ "ins v1.d[1], x11\n"
+ "ins v2.d[1], x12\n"
+ "ins v3.d[1], x13\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ // End of code handling full blocks of 16 rows.
+ // Now we handle any remaining rows.
+ "3:\n"
+ // Let w2 be the number of rows left to handle.
+ "ands w2, %w[rows], #15\n"
+ // If w2==0, there are no remaining rows, jump to the end.
+ "beq 4f\n"
+ // Zero out a 16x4 block in registers, which we'll partially overwrite
+ // with any remaining rows.
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ // Process the last zero-padded 16x4 block.
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "saddlp v17.8h, v5.16b\n"
+ "saddlp v18.8h, v6.16b\n"
+ "saddlp v19.8h, v7.16b\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "str q4, [%[packed_ptr], #0]\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "4:\n"
+
+ // Horizontal reduction of the registers used to accumulate sums.
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+ "addp v28.4s, v28.4s, v30.4s\n"
+
+ // Store the sums.
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [ rows ] "r"(src_rows),
+ [ src_zero_point ] "r"(src_zero_point),
+ [input_xor] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5",
+ "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+ "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
+ "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitColMajorForNeonDotprodA55ish(
+ const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
+ const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label(
+ "Pack (kNeonDotprod, optimized for in-order cores)");
+ asm volatile(
+ // clang-format off
+ // v26 will be the vector to XOR input values with to perform
+ // any input data type conversion (e.g. uint8 to int8).
+ "dup v26.16b, %w[input_xor]\n"
+ // v27 will be filled with 1's. It will be used as an operand
+ // to SDOT to compute the sums.
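+      // (The dot product of four packed int8 values with four 1's is just
+      // their sum; that is how SDOT is used to accumulate sums here.)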
+ "mov w1, #1\n"
+ "dup v27.16b, w1\n"
+ // w1 will be the number of rows already loaded.
+ "mov w1, #0\n"
+      // v28--v31 will be used to accumulate the sums
+ "movi v28.4s, #0\n"
+ "movi v29.4s, #0\n"
+ "movi v30.4s, #0\n"
+ "movi v31.4s, #0\n"
+
+ // Let w2 be `rows` rounded down to multiple of 16.
+ "ands w2, %w[rows], #-16\n"
+ // If there are no full blocks of 16 rows to process, jump to the
+ // code handling the last < 16 rows.
+ "beq 3f\n"
+ // Load the first block of 16 rows.
+ "add w1, w1, #16\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ // Check if these were the only full block of 16 rows to load.
+ "cmp w1, w2\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ // In that case, jump to the code handling the last loaded block of
+ // 16 rows.
+ "beq 2f\n"
+
+ // Main loop processing blocks of 16 rows.
+ "1:\n"
+ "add w1, w1, #16\n"
+ // Prepare the already-loaded 16 rows by inserting the parts
+ // loaded into general purpose registers x10--x13 into the
+ // NEON registers v0--v3 where the other parts had already been
+ // loaded.
+ "ins v0.d[1], x10\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ins v1.d[1], x11\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ins v2.d[1], x12\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ins v3.d[1], x13\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+
+ // Load the next 16 rows and, interleaved with that,
+ // perform the input type conversion (e.g. uint8->int8) on the
+ // current 16 rows.
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+
+ // Transposition of 4x4 blocks, part 1
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+ // Transposition of 4x4 blocks, part 2
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ "cmp w1, w2\n"
+
+ // Store the block to the packed matrix and, interleaved with
+ // that, compute sums using sdot instructions.
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+ // End of main loop on blocks of 16 rows.
+ "bne 1b\n"
+
+ // Code handling the last already-loaded block of 16 rows.
+ "2:\n"
+ // Process the last loaded full 16x4 block.
+ "ins v0.d[1], x10\n"
+ "ins v1.d[1], x11\n"
+ "ins v2.d[1], x12\n"
+ "ins v3.d[1], x13\n"
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // End of code handling full blocks of 16 rows.
+ // Now we handle any remaining rows.
+ "3:\n"
+ // Let w2 be the number of rows left to handle.
+ "ands w2, %w[rows], #15\n"
+ // If w2==0, there are no remaining rows, jump to the end.
+ "beq 4f\n"
+ // Zero out a 16x4 block in registers, which we'll partially overwrite
+ // with any remaining rows.
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
+#undef RUY_LOAD_ONE_ROW
+
+ "5:\n"
+ // Process the last zero-padded 16x4 block.
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "cmp w2, #4\n"
+ "ble 4f\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "cmp w2, #8\n"
+ "ble 4f\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "cmp w2, #12\n"
+ "ble 4f\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "4:\n"
+
+ // Reduction of the registers used to accumulate sums.
+ "add v28.4s, v28.4s, v29.4s\n"
+ "add v30.4s, v30.4s, v31.4s\n"
+ "add v28.4s, v28.4s, v30.4s\n"
+
+ // Store the sums.
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
+ [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [rows] "r"(src_rows),
+ [src_zero_point] "r"(static_cast<int>(src_zero_point)),
+ [input_xor] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeonDotprod)");
+ asm volatile(
+ // clang-format off
+ // v26 will be the vector to XOR input values with to perform
+ // any input data type conversion (e.g. uint8 to int8).
+ "dup v26.16b, %w[input_xor]\n"
+ // v27 will be filled with 1's. It will be used as an operand
+ // to SDOT to compute the sums.
+ "mov w1, #1\n"
+ "dup v27.16b, w1\n"
+ // w1 will be the number of rows already loaded.
+ "mov w1, #0\n"
+      // v28--v31 will be used to accumulate the sums
+ "movi v28.4s, #0\n"
+ "movi v29.4s, #0\n"
+ "movi v30.4s, #0\n"
+ "movi v31.4s, #0\n"
+
+ // 4x partially unrolled code processing blocks of 64 rows.
+      // Read the original loop below first; it has more comments.
+#if RUY_OPT(MAX_STREAMING)
+ // Let w2 be `rows` rounded down to multiple of 64.
+ // Each iteration of this 4x partially unrolled loop handles
+ // 64 rows.
+ "ands w2, %w[rows], #-64\n"
+ // If there are no full blocks of 64 rows to process, jump to
+ // the main loop below handling 16 rows per iteration.
+ "beq 9f\n"
+ // Load the first block of 64 rows.
+ "add w1, w1, #64\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ // Was that the last full block of 64 rows to load?
+ "cmp w1, w2\n"
+ "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ // Then jump to the end of the 64-rows-at-a-time code.
+ "beq 8f\n"
+
+ // Start of the main 4x partially unrolled loop.
+ "7:\n"
+ // Process rows 0 -- 15 out of 64.
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
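+      // The sdot instructions below are emitted as raw .word encodings so
+      // that this file still assembles with toolchains that don't know the
+      // dotprod extension; the intended mnemonic is given in each comment.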
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 16 -- 31 out of 64.
+ "eor v4.16b, v4.16b, v26.16b\n"
+ "eor v5.16b, v5.16b, v26.16b\n"
+ "eor v6.16b, v6.16b, v26.16b\n"
+ "eor v7.16b, v7.16b, v26.16b\n"
+
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+
+ "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 32 -- 47 out of 64.
+ "eor v8.16b, v8.16b, v26.16b\n"
+ "eor v9.16b, v9.16b, v26.16b\n"
+ "eor v10.16b, v10.16b, v26.16b\n"
+ "eor v11.16b, v11.16b, v26.16b\n"
+
+ "trn1 v16.4s, v8.4s, v9.4s\n"
+ "trn2 v17.4s, v8.4s, v9.4s\n"
+ "trn1 v18.4s, v10.4s, v11.4s\n"
+ "trn2 v19.4s, v10.4s, v11.4s\n"
+
+ "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 48 -- 63 out of 64.
+ "eor v12.16b, v12.16b, v26.16b\n"
+ "eor v13.16b, v13.16b, v26.16b\n"
+ "eor v14.16b, v14.16b, v26.16b\n"
+ "eor v15.16b, v15.16b, v26.16b\n"
+
+ "trn1 v16.4s, v12.4s, v13.4s\n"
+ "trn2 v17.4s, v12.4s, v13.4s\n"
+ "trn1 v18.4s, v14.4s, v15.4s\n"
+ "trn2 v19.4s, v14.4s, v15.4s\n"
+
+ "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "cmp w1, w2\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // End of main 4x partially unrolled loop.
+ "bne 7b\n"
+
+ // Last part of the 4x partially unrolled code:
+ // handle the last already-loaded 64 rows.
+ "8:\n"
+
+ // Process rows 0 -- 15 out of 64.
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 16 -- 31 out of 64.
+ "eor v4.16b, v4.16b, v26.16b\n"
+ "eor v5.16b, v5.16b, v26.16b\n"
+ "eor v6.16b, v6.16b, v26.16b\n"
+ "eor v7.16b, v7.16b, v26.16b\n"
+
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 32 -- 47 out of 64.
+ "eor v8.16b, v8.16b, v26.16b\n"
+ "eor v9.16b, v9.16b, v26.16b\n"
+ "eor v10.16b, v10.16b, v26.16b\n"
+ "eor v11.16b, v11.16b, v26.16b\n"
+
+ "trn1 v16.4s, v8.4s, v9.4s\n"
+ "trn2 v17.4s, v8.4s, v9.4s\n"
+ "trn1 v18.4s, v10.4s, v11.4s\n"
+ "trn2 v19.4s, v10.4s, v11.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // Process rows 48 -- 63 out of 64.
+ "eor v12.16b, v12.16b, v26.16b\n"
+ "eor v13.16b, v13.16b, v26.16b\n"
+ "eor v14.16b, v14.16b, v26.16b\n"
+ "eor v15.16b, v15.16b, v26.16b\n"
+
+ "trn1 v16.4s, v12.4s, v13.4s\n"
+ "trn2 v17.4s, v12.4s, v13.4s\n"
+ "trn1 v18.4s, v14.4s, v15.4s\n"
+ "trn2 v19.4s, v14.4s, v15.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "9:\n"
+#endif // #if RUY_OPT(MAX_STREAMING)
+ // End of 4x partially unrolled code processing blocks of 64 rows.
+
+ // Main part of the code, processing blocks of 16 rows.
+
+ // Let w2 be `rows` rounded down to multiple of 16.
+ "and w2, %w[rows], #-16\n"
+ // If there are no full blocks of 16 rows to process, jump to the
+ // code handling the last < 16 rows.
+ "cmp w1, w2\n"
+ "beq 3f\n"
+
+ // Load the first block of 16 rows.
+ "add w1, w1, #16\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+      // Check if this was the only full block of 16 rows to load.
+ "cmp w1, w2\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ // In that case, jump to the code handling the last loaded block of
+ // 16 rows.
+ "beq 2f\n"
+ // Main loop processing blocks of 16 rows.
+ "1:\n"
+ // Input type conversion (e.g. uint8->int8).
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+ // Transposition of 4x4 blocks, part 1
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+ // Load the next 16 rows
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+ // Transposition of 4x4 blocks, part 2
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
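+      // At this point v20 holds rows 0--3, v21 rows 4--7, v22 rows 8--11 and
+      // v23 rows 12--15, each register containing one 4-byte group per source
+      // column -- exactly the depth-4 groups of the packed layout.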
+ // Compute sums using sdot instructions.
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
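+      // Since v27 is all ones, each sdot adds up, per 32-bit lane, the four
+      // int8 values of that lane's 4-byte group, so lane j of v28--v31
+      // accumulates a partial sum of source column j. These are reduced into
+      // a single per-column sums vector at the end (label 4).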
+ // Store the block to the packed matrix.
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "cmp w1, w2\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+ // End of main loop on blocks of 16 rows.
+ "bne 1b\n"
+
+ // Code handling the last already-loaded block of 16 rows.
+ "2:\n"
+
+ // Process the last loaded full 16x4 block.
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // End of code handling full blocks of 16 rows.
+ // Now we handle any remaining rows.
+ "3:\n"
+ // Let w2 be the number of rows left to handle.
+ "ands w2, %w[rows], #15\n"
+ // If w2==0, there are no remaining rows, jump to the end.
+ "beq 4f\n"
+      // Fill a 16x4 block of registers with the source zero_point value;
+      // we'll partially overwrite it with any remaining rows.
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
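+      // Padding with the zero_point rather than with 0 means that, after the
+      // XOR below, the padded lanes hold the zero point in the packed (int8)
+      // domain, so they behave like genuine zero-point entries in both the
+      // packed data and the sums.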
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
+#undef RUY_LOAD_ONE_ROW
+
+ "5:\n"
+ // Process the last zero-padded 16x4 block.
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "cmp w2, #4\n"
+ "ble 4f\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "cmp w2, #8\n"
+ "ble 4f\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "cmp w2, #12\n"
+ "ble 4f\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "4:\n"
+
+ // Reduction of the registers used to accumulate sums.
+ "add v28.4s, v28.4s, v29.4s\n"
+ "add v30.4s, v30.4s, v31.4s\n"
+ "add v28.4s, v28.4s, v30.4s\n"
+
+ // Store the sums.
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
+ [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
+ [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
+ : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
+ [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
+ [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
+ [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
+ [rows] "r"(src_rows),
+ [src_zero_point] "r"(static_cast<int>(src_zero_point)),
+ [input_xor] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_cols,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, std::int32_t* sums_ptr,
+ int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)");
+ asm(
+ // clang-format off
+ // Prefetch data. This was tuned on Cortex-A55-rev1 cores.
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1]]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2]]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3]]\n")
+ // Let w0 = (number of columns to compute) - 8.
+ "subs w0, %w[src_cols], 8\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 64]\n")
+ // Let v26 duplicate the input_xor value in all lanes.
+ "dup v26.16b, %w[input_xor]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 64]\n")
+ // Let v27 be 1 in all lanes. Used with sdot to compute sums.
+ "movi v27.16b, 1\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 64]\n")
+ // If there isn't a full block of 8 columns to load from, jump to the
+ // code after the loop handling leftovers.
+ "blt 2f\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 64]\n")
+ // Main loop, each iteration handles a full block of 8 cols.
+ "1:\n"
+ // Load the 4x8 block from the source matrix, or zero if we're
+ // past the bottom of the source matrix.
+ "ld1 {v0.8b}, [%[src_ptr0]]\n"
+ "ld1 {v1.8b}, [%[src_ptr1]]\n"
+ "ld1 {v2.8b}, [%[src_ptr2]]\n"
+ "ld1 {v3.8b}, [%[src_ptr3]]\n"
+ // Load values from the sums buffer, and start the reordering
+ // of the loaded 4x8 block by interleaving 8bit values.
+ "zip1 v0.16b, v0.16b, v1.16b\n"
+ "ldr q8, [%[sums_ptr], 0]\n"
+ "zip1 v1.16b, v2.16b, v3.16b\n"
+ "ldr q9, [%[sums_ptr], 16]\n"
+ // Finish the reordering of the 4x8 block, putting it into
+ // column-major order.
+ "zip1 v2.8h, v0.8h, v1.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 128]\n")
+ "zip2 v3.8h, v0.8h, v1.8h\n"
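+      // After the byte-wise zips of rows 0/1 and rows 2/3 and the
+      // halfword-wise zip1/zip2 just done, v2 holds columns 0--3 and v3 holds
+      // columns 4--7, each column as 4 consecutive bytes (rows 0--3).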
+ // Apply input_xor, i.e. convert source values from uint8 to int8
+ // if needed.
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 128]\n")
+ "eor v2.16b, v2.16b, v26.16b\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 128]\n")
+ "eor v3.16b, v3.16b, v26.16b\n"
+ // Update the sums.
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 128]\n")
+ ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n"
+ ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n"
+ // Store the column-major 4x8 block to the packed matrix, and
+ // increment some source pointers.
+ "str q2, [%[packed_ptr], 0]\n"
+ "add %[src_ptr0], %[src_ptr0], %w[src_inc0], sxtw\n"
+ "str q3, [%[packed_ptr], 16]\n"
+ "add %[src_ptr1], %[src_ptr1], %w[src_inc1], sxtw\n"
+ // Store the updated sums, and increment the remaining pointers
+ // and the block_col loop index.
+ "st1 {v8.4s}, [%[sums_ptr]], 16\n"
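+      // The next block of 8 source columns lands 8 packed columns later,
+      // hence packed_ptr advances by 8 * packed_stride (the lsl 3).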
+ "add %[packed_ptr], %[packed_ptr], %[packed_stride], lsl 3\n"
+ "st1 {v9.4s}, [%[sums_ptr]], 16\n"
+ // Advance by 8 columns and set the condition code.
+ "subs w0, w0, 8\n"
+ "add %[src_ptr2], %[src_ptr2], %w[src_inc2], sxtw\n"
+ "add %[src_ptr3], %[src_ptr3], %w[src_inc3], sxtw\n"
+ // End of the main loop.
+ "bge 1b\n"
+
+ "2:\n"
+ // We add back 8 to w0 so that w0 is the number of columns remaining
+ // to handle.
+ "adds w0, w0, 8\n"
+ // Nothing left? Then jump to the end.
+ "beq 3f\n"
+      // Here w0 is between 1 and 7. We fill v0--v3 with the zero_point ...
+ "dup v0.8b, %w[src_zero_point]\n"
+ "dup v1.8b, %w[src_zero_point]\n"
+ "dup v2.8b, %w[src_zero_point]\n"
+ "dup v3.8b, %w[src_zero_point]\n"
+ // ... and now we fill lanes one by one with leftover columns.
+#define RUY_LOAD_ONE_COL(C)\
+ "cmp w0, " #C "\n" \
+ "beq 4f\n" \
+ "ld1 { v0.b }[" #C "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #C "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #C "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #C "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_COL(0)
+ RUY_LOAD_ONE_COL(1)
+ RUY_LOAD_ONE_COL(2)
+ RUY_LOAD_ONE_COL(3)
+ RUY_LOAD_ONE_COL(4)
+ RUY_LOAD_ONE_COL(5)
+ RUY_LOAD_ONE_COL(6)
+ // Here we know that w0==7, so RUY_LOAD_ONE_COL(7) would be a no-op.
+#undef RUY_LOAD_ONE_COL
+
+ "4:\n"
+      // The leftover source data is loaded; now we can perform the
+      // computation as usual.
+ // Load values from the sums buffer, and start the reordering
+ // of the loaded 4x8 block by interleaving 8bit values.
+ "zip1 v0.16b, v0.16b, v1.16b\n"
+ "ldr q8, [%[sums_ptr], 0]\n"
+ "zip1 v1.16b, v2.16b, v3.16b\n"
+ "ldr q9, [%[sums_ptr], 16]\n"
+ // Finish the reordering of the 4x8 block, putting it into
+ // column-major order.
+ "zip1 v2.8h, v0.8h, v1.8h\n"
+ "zip2 v3.8h, v0.8h, v1.8h\n"
+ // Apply input_xor, i.e. convert source values from uint8 to int8
+ // if needed.
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+ // Update the sums.
+ ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n"
+ ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n"
+ // Store the column-major 4x8 block to the packed matrix, and
+ // increment some source pointers.
+ "str q2, [%[packed_ptr], 0]\n"
+ "str q3, [%[packed_ptr], 16]\n"
+ // Store the updated sums, and increment the remaining pointers
+ // and the block_col loop index.
+ "st1 {v8.4s}, [%[sums_ptr]], 16\n"
+ "st1 {v9.4s}, [%[sums_ptr]], 16\n"
+
+ // End label.
+ "3:\n"
+ // clang-format on
+ : [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr),
+ [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
+ [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3)
+ : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1),
+ [src_inc2] "r"(src_inc2), [src_inc3] "r"(src_inc3),
+ [input_xor] "r"(input_xor), [src_zero_point] "r"(src_zero_point),
+ [packed_stride] "r"(static_cast<std::int64_t>(packed_stride)),
+ [src_cols] "r"(src_cols)
+ : "cc", "memory", "x0", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5",
+ "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+}
+
+void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack (kNeon)");
+ asm volatile(
+ // clang-format off
+ // w1 will be the number of rows already loaded.
+ "mov w1, #0\n"
+ // Let w2 be `rows` rounded down to multiple of 4.
+ "ands w2, %w[rows], #-4\n"
+ // If there are no full blocks of 4 rows to process, jump to the
+ // code handling the last < 4 rows.
+ "beq 3f\n"
+      // Load the first block of 4 rows.
+ "add w1, w1, #4\n"
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+      // Check if this was the only full block of 4 rows to load.
+ "cmp w1, w2\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+ // In that case, jump to the code handling the last loaded block of
+ // 4 rows.
+ "beq 2f\n"
+ // Main loop processing blocks of 4 rows.
+ "1:\n"
+ // Advance by 4 rows.
+ "add w1, w1, #4\n"
+ // Transposition of the already-loaded 4x4 block, part 1.
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+ // Load the next 4x4 block.
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+ // Transposition of the already-loaded 4x4 block, part 2.
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ // Was this the last full 4x4 block to load?
+ "cmp w1, w2\n"
+ // Store the transposed 4x4 block.
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+ // End of main loop on 4x4 blocks.
+ "bne 1b\n"
+
+ // Code handling the last already-loaded 4x4 block.
+ "2:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ // End of code handling full 4x4 blocks.
+ // Now we handle any remaining rows.
+ "3:\n"
+ // Let w2 be the number of rows left to handle.
+ "ands w2, %w[rows], #3\n"
+ // If w2==0, there are no remaining rows, jump to the end.
+ "beq 4f\n"
+ // Zero out a 4x4 block in registers, which we'll partially overwrite
+ // with any remaining rows.
+ "movi v0.16b, #0\n"
+ "movi v1.16b, #0\n"
+ "movi v2.16b, #0\n"
+ "movi v3.16b, #0\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+ "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+ "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+ "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ // Here we know that w2==3, so RUY_LOAD_ONE_ROW(3) would be a no-op.
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ // Transpose that last zero-padded 4x4 block.
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ // Store that last zero-padded block to the packed matrix.
+ "mov x1, #32\n"
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp w2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+ RUY_STORE_ONE_ROW(0, v20)
+ RUY_STORE_ONE_ROW(1, v21)
+ RUY_STORE_ONE_ROW(2, v22)
+ RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+
+ : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
+ [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
+ [packed_ptr] "+r"(packed_ptr)
+ : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
+ [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
+ [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
+ [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
+ [rows] "r"(src_rows)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
+ "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
+ "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif
+
+#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc, int src_rows, float* packed_ptr,
+ int output_stride) {
+ profiler::ScopeLabel label("Pack (kNeon)");
+ asm volatile(
+ // clang-format off
+ "mov r1, #0\n"
+ "and r2, %[rows], #-4\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
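+// Note: there is a single src_inc bitmask here: bit i of src_inc tells
+// whether src_ptr<i> should advance by 16 bytes. Each "and" below isolates
+// one bit, and the decreasing shift amounts (lsl #4, #3, #2, #1) compensate
+// for the increasing bit positions, so each add is either 0 or 16.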
+#define RUY_LOAD_FOUR_BY_FOUR() \
+ /* Load q0 */ \
+ "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \
+ /* if src_inc0 != 0, add 16 to src_ptr0 */ \
+ "and r3, %[src_inc], #1\n" \
+ "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
+ /* Load q1 */ \
+ "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \
+  /* if src_inc1 != 0, add 16 to src_ptr1 */ \
+ "and r3, %[src_inc], #2\n" \
+ "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
+ /* Load q2 */ \
+ "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \
+  /* if src_inc2 != 0, add 16 to src_ptr2 */ \
+ "and r3, %[src_inc], #4\n" \
+ "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
+ /* Load q3 */ \
+ "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \
+  /* if src_inc3 != 0, add 16 to src_ptr3 */ \
+ "and r3, %[src_inc], #8\n" \
+ "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\
+
+ RUY_LOAD_FOUR_BY_FOUR()
+ "add r1, r1, #4\n"
+ "cmp r1, r2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add r1, r1, #4\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ RUY_LOAD_FOUR_BY_FOUR()
+#undef RUY_LOAD_FOUR_BY_FOUR
+
+#define RUY_STORE_FOUR_BY_FOUR() \
+ /* Store q8, q10, q9, q11 */ \
+ /* q8 = d16, d17 */ \
+ "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \
+ /* q10 = d20, d21 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \
+ /* q9 = d18, d19 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \
+ /* q11 = d22, d23 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+
+ RUY_STORE_FOUR_BY_FOUR()
+ "cmp r1, r2\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ RUY_STORE_FOUR_BY_FOUR()
+#undef RUY_STORE_FOUR_BY_FOUR
+ "3:\n"
+
+ "ands r2, %[rows], #3\n"
+ "beq 4f\n"
+ "mov r0, #0\n"
+ // Zero out q0 - q3
+ "vdup.32 q0, r0\n"
+ "vdup.32 q1, r0\n"
+ "vdup.32 q2, r0\n"
+ "vdup.32 q3, r0\n"
+#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \
+ "cmp r2, #" #R "\n" \
+ "beq 5f\n" \
+ "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
+ "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
+ "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
+ "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"
+
+#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \
+ "cmp r2, #" #R "\n" \
+ "beq 5f\n" \
+ "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
+ "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
+ "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
+ "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"
+
+ RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
+ RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
+ RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
+ RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
+#undef RUY_LOAD_ONE_ROW_SECOND_HALF
+#undef RUY_LOAD_ONE_ROW_FIRST_HALF
+ "5:\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ "mov r1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp r2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n"
+
+ // Store q8
+ RUY_STORE_ONE_ROW(0, q8)
+ // Store q10
+ RUY_STORE_ONE_ROW(1, q10)
+ // Store q9
+ RUY_STORE_ONE_ROW(2, q9)
+ // Store q11
+ RUY_STORE_ONE_ROW(3, q11)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr)
+ : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)),
+ [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
+ : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+ "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
+}
+
+#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
+ const float* src_ptr1,
+ const float* src_ptr2,
+ const float* src_ptr3, int src_inc0,
+ int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
+
+ asm volatile(
+ // clang-format off
+ "mov w1, #0\n"
+
+ "and w2, %w[rows], #-4\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ "add w1, w1, #4\n"
+ "cmp w1, w2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add w1, w1, #4\n"
+
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+ "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ "cmp w1, w2\n"
+
+ "ins v0.d[1], x10\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "ins v1.d[1], x11\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "ins v2.d[1], x12\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "ins v3.d[1], x13\n"
+ "str q23, [%[packed_ptr], #96]\n"
+
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #3\n"
+ "beq 4f\n"
+ "movi v0.16b, #0\n"
+ "movi v1.16b, #0\n"
+ "movi v2.16b, #0\n"
+ "movi v3.16b, #0\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+ "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+ "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+ "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "mov x1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp w2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+ RUY_STORE_ONE_ROW(0, v20)
+ RUY_STORE_ONE_ROW(1, v21)
+ RUY_STORE_ONE_ROW(2, v22)
+ RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
+ [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
+ [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), [rows] "r"(src_rows)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON
+
+namespace {
+// transpose_*bit_vals are wrappers around the ARM TRN instructions, allowing
+// us to use these instructions as we would in assembly --- this is one
+// instance where assembly is more idiomatic than intrinsics.
+//
+// The way that TRN is exposed by the vtrn_* intrinsics makes its usage very
+// cumbersome. The issue is that transposing groups of values is exposed only
+// as transposing values of a wider type, so this requires many
+// vreinterpret's, and to make it worse, vtrn_* return NEON array types like
+// int8x8x2_t, for which vreinterpret's are not defined!
+void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) {
+ int8x8x2_t t = vtrn_s8(a, b);
+ a = t.val[0];
+ b = t.val[1];
+}
+
+void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) {
+ int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b));
+ a = vreinterpret_s8_s16(t.val[0]);
+ b = vreinterpret_s8_s16(t.val[1]);
+}
+
+void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) {
+ int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b));
+ a = vreinterpret_s8_s32(t.val[0]);
+ b = vreinterpret_s8_s32(t.val[1]);
+}
+} // namespace
+
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+ int src_rows, int src_cols, int block_row,
+ int start_col, int end_col,
+ std::int8_t* packed_ptr, int packed_stride,
+ int packed_zero_point, std::int32_t* sums,
+ int input_xor, int kernel_cols) {
+ profiler::ScopeLabel label("Pack (kNeon, from row-major)");
+
+ int src_end_col = std::min(end_col, src_cols);
+ int col = start_col;
+ for (; col <= src_end_col - 8; col += 8) {
+ // Each iteration of this loop handles 8 columns, and the kernel format
+ // has 16 rows, so each iteration handles a 16x8 block.
+ //
+ // Since the source is row-major, handling 8 columns at a time means
+ // loading only 8 bytes i.e. 64bit from each row. This may seem surprising
+ // on 128bit SIMD like NEON. While we could handle 16 columns at a time,
+ // we prefer to stick with 8 for the following reasons:
+ // 1. The arithmetic (computing sums and transposing data) done on these
+ // values is such that even though we initially start from 64bit vectors,
+ // most of our NEON instructions are full 128bit instructions. For the
+ // sums computation, that is because summing 8bit values requires
+ // expansion to 16bit anyway. For the matrix transposition code, that is
+ // because the ARM ZIP instructions take 64bit of data from two input
+ // registers and zip it into a 128bit output. If we had 128bit of data
+ // in each input registers, we would need 2x more ARM NEON instructions
+ // to zip it.
+ // 2. The main optimization target for this (ARM, 8bit, non-dotprod)
+ // code path is in-order ARM cores such as the Cortex-A53, which prefer
+ // 64bit loads anyway.
+ // 3. Handling only 8 columns at a time limits the size of the final
+ // leftover columns handled with slow scalar code.
+ //
+    // This code is not heavily optimized anyway, as evidenced by the facts
+    // that (1) it's written in intrinsics and (2) it does not use separate
+    // versions tuned for different types of CPU cores. At its current level
+    // of optimization, this seems like a fair compromise. If one wanted to
+ // maximize performance at the cost of more code complexity/size, one could
+ // have code handling 16 columns at a time (maybe limited to
+ // Tuning::kGeneric), then 8, then 4 to minimize the amount of slow
+ // leftovers.
+ //
+ // Load 8 sums in sums0, sums1.
+ int32x4_t sums0 = vld1q_s32(sums + col);
+ int32x4_t sums1 = vld1q_s32(sums + col + 4);
+ // Load the 8x16 block from the source matrix.
+ // Each val* here is the data from one row.
+ int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10,
+ val11, val12, val13, val14, val15;
+ // Even though this function takes a uint8_t* src_ptr, that's only a
+ // type-erased pointer (using uint8_t* so that pointer arithmetic is
+ // allowed). The actual type may be either uint8_t or int8_t. The only
+ // difference it makes is that if it's uint8_t then we need to flip the
+ // sign bit. This is specified by the input_xor value (which is 0x80 if the
+ // input data is uint8_t, and 0x0 otherwise).
+ auto load_and_convert = [=](const std::uint8_t* from) {
+ return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from)));
+ };
+ if (block_row <= src_rows - 16) {
+ // Load data in the regular case: there are still 16 rows to be read from
+ // the source matrix.
+ val0 = load_and_convert(src_ptr + 0 * src_stride);
+ val1 = load_and_convert(src_ptr + 1 * src_stride);
+ val2 = load_and_convert(src_ptr + 2 * src_stride);
+ val3 = load_and_convert(src_ptr + 3 * src_stride);
+ val4 = load_and_convert(src_ptr + 4 * src_stride);
+ val5 = load_and_convert(src_ptr + 5 * src_stride);
+ val6 = load_and_convert(src_ptr + 6 * src_stride);
+ val7 = load_and_convert(src_ptr + 7 * src_stride);
+ val8 = load_and_convert(src_ptr + 8 * src_stride);
+ val9 = load_and_convert(src_ptr + 9 * src_stride);
+ val10 = load_and_convert(src_ptr + 10 * src_stride);
+ val11 = load_and_convert(src_ptr + 11 * src_stride);
+ val12 = load_and_convert(src_ptr + 12 * src_stride);
+ val13 = load_and_convert(src_ptr + 13 * src_stride);
+ val14 = load_and_convert(src_ptr + 14 * src_stride);
+ val15 = load_and_convert(src_ptr + 15 * src_stride);
+ } else {
+ // Boundary case: there are fewer than 16 rows to be read from the source
+ // matrix. We pad by the zero_point.
+ val0 = vdup_n_s8(packed_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ val4 = val0;
+ val5 = val0;
+ val6 = val0;
+ val7 = val0;
+ val8 = val0;
+ val9 = val0;
+ val10 = val0;
+ val11 = val0;
+ val12 = val0;
+ val13 = val0;
+ val14 = val0;
+ val15 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = load_and_convert(src_ptr + 0 * src_stride);
+ if (block_row + 1 < src_rows)
+ val1 = load_and_convert(src_ptr + 1 * src_stride);
+ if (block_row + 2 < src_rows)
+ val2 = load_and_convert(src_ptr + 2 * src_stride);
+ if (block_row + 3 < src_rows)
+ val3 = load_and_convert(src_ptr + 3 * src_stride);
+ if (block_row + 4 < src_rows)
+ val4 = load_and_convert(src_ptr + 4 * src_stride);
+ if (block_row + 5 < src_rows)
+ val5 = load_and_convert(src_ptr + 5 * src_stride);
+ if (block_row + 6 < src_rows)
+ val6 = load_and_convert(src_ptr + 6 * src_stride);
+ if (block_row + 7 < src_rows)
+ val7 = load_and_convert(src_ptr + 7 * src_stride);
+ if (block_row + 8 < src_rows)
+ val8 = load_and_convert(src_ptr + 8 * src_stride);
+ if (block_row + 9 < src_rows)
+ val9 = load_and_convert(src_ptr + 9 * src_stride);
+ if (block_row + 10 < src_rows)
+ val10 = load_and_convert(src_ptr + 10 * src_stride);
+ if (block_row + 11 < src_rows)
+ val11 = load_and_convert(src_ptr + 11 * src_stride);
+ if (block_row + 12 < src_rows)
+ val12 = load_and_convert(src_ptr + 12 * src_stride);
+ if (block_row + 13 < src_rows)
+ val13 = load_and_convert(src_ptr + 13 * src_stride);
+ if (block_row + 14 < src_rows)
+ val14 = load_and_convert(src_ptr + 14 * src_stride);
+ if (block_row + 15 < src_rows)
+ val15 = load_and_convert(src_ptr + 15 * src_stride);
+ }
+ src_ptr += 8;
+ // Compute sums.
+ int16x8_t sums16_0 = vaddl_s8(val0, val1);
+ int16x8_t sums16_1 = vaddl_s8(val2, val3);
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7));
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11));
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15));
+ int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1);
+ sums0 = vaddw_s16(sums0, vget_low_s16(sums16));
+ sums1 = vaddw_s16(sums1, vget_high_s16(sums16));
+ // Store sums.
+ vst1q_s32(sums + col, sums0);
+ vst1q_s32(sums + col + 4, sums1);
+
+ // Transpose the data, i.e. change the storage order of the
+ // 16x8 block, to convert from the row-major source to the
+ // column-major packed format.
+ //
+ // Before, for i in [0, 15], val<i> is the i-th row.
+ // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column.
+ transpose_8bit_vals(val0, val1);
+ transpose_8bit_vals(val2, val3);
+ transpose_8bit_vals(val4, val5);
+ transpose_8bit_vals(val6, val7);
+ transpose_8bit_vals(val8, val9);
+ transpose_8bit_vals(val10, val11);
+ transpose_8bit_vals(val12, val13);
+ transpose_8bit_vals(val14, val15);
+ transpose_16bit_vals(val0, val2);
+ transpose_16bit_vals(val1, val3);
+ transpose_16bit_vals(val4, val6);
+ transpose_16bit_vals(val5, val7);
+ transpose_16bit_vals(val8, val10);
+ transpose_16bit_vals(val9, val11);
+ transpose_16bit_vals(val12, val14);
+ transpose_16bit_vals(val13, val15);
+ transpose_32bit_vals(val0, val4);
+ transpose_32bit_vals(val1, val5);
+ transpose_32bit_vals(val2, val6);
+ transpose_32bit_vals(val3, val7);
+ transpose_32bit_vals(val8, val12);
+ transpose_32bit_vals(val9, val13);
+ transpose_32bit_vals(val10, val14);
+ transpose_32bit_vals(val11, val15);
+ // Store to the packed_matrix.
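+    // Note on the layout: within a block of kernel_cols packed columns, the
+    // 16-byte depth-16 runs of consecutive columns are contiguous. So for
+    // kernel_cols == 4, columns 2..3 go 32 bytes after columns 0..1 in the
+    // same block, while for kernel_cols == 2 they start the next block,
+    // 2 * packed_stride further on. After 4 source columns we advance
+    // packed_ptr by 4 * packed_stride either way.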
+ std::int8_t* dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
+ packed_ptr += 4 * packed_stride;
+ dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
+ packed_ptr += 4 * packed_stride;
+ }
+  // Handle remaining columns that don't fit in a full block of 8 columns but
+  // are still genuine columns from the source matrix (as opposed to the
+  // final padding columns below).
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ for (int r = 0; r < 16; r++) {
+ std::int8_t packed_val = (block_row + r < src_rows)
+ ? (src_ptr[r * src_stride] ^ input_xor)
+ : packed_zero_point;
+ accum += packed_val;
+ dst_ptr[r] = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+ // Handle the final columns of the packed matrix, beyond the last column of
+ // the source matrix. The values here don't matter, we just want to avoid
+ // leaving uninitialized data. Since the sums are already initialized above,
+ // we don't need to do anything about them here.
+ for (; col < end_col; col++) {
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ std::memset(dst_ptr, 0, 16);
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+}
+
+#endif
+
+} // namespace ruy
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
new file mode 100644
index 0000000..ba8964d
--- /dev/null
+++ b/ruy/pack_arm.h
@@ -0,0 +1,613 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PACK_ARM_H_
+#define RUY_RUY_PACK_ARM_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/asm_helpers.h"
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_NEON
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
+
+RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8)
+#if RUY_PLATFORM_NEON_32
+RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4)
+#endif
+
+template <>
+struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
+ using Type = std::int8_t;
+};
+#endif
+
+#if RUY_PLATFORM_NEON
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+ int src_rows, int src_cols, int block_row,
+ int start_col, int end_col,
+ std::int8_t* packed_ptr, int packed_stride,
+ int packed_zero_point, std::int32_t* sums_ptr,
+ int input_xor, int kernel_cols);
+#endif
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ int input_xor);
+void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitColMajorForNeonDotprodA55ish(
+ const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
+ const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_cols,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, std::int32_t* sums_ptr,
+ int input_xor);
+#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+struct PackParams8bit {
+ const void* src_ptr0;
+ const void* src_ptr1;
+ const void* src_ptr2;
+ const void* src_ptr3;
+ const std::int32_t* sums_ptr;
+ const std::int8_t* packed_ptr;
+ int src_inc0;
+ int src_inc1;
+ int src_inc2;
+ int src_inc3;
+ int src_rows;
+ int src_zero_point;
+ int input_xor;
+};
+
+inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ const std::int32_t* sums_ptr,
+ const std::int8_t* packed_ptr, int src_inc0,
+ int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, int src_zero_point, int input_xor,
+ PackParams8bit* params) {
+ params->src_ptr0 = src_ptr0;
+ params->src_ptr1 = src_ptr1;
+ params->src_ptr2 = src_ptr2;
+ params->src_ptr3 = src_ptr3;
+ params->sums_ptr = sums_ptr;
+ params->packed_ptr = packed_ptr;
+ params->src_inc0 = src_inc0;
+ params->src_inc1 = src_inc1;
+ params->src_inc2 = src_inc2;
+ params->src_inc3 = src_inc3;
+ params->src_rows = src_rows;
+ params->src_zero_point = src_zero_point;
+ params->input_xor = input_xor;
+}
+
+void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params);
+void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params);
+
+#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+
+#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
+
+template <typename Scalar>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
+ std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 4, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ const Scalar* src_ptr2 = src_ptr1 + src_stride;
+ const Scalar* src_ptr3 = src_ptr2 + src_stride;
+ int src_inc0 = 16;
+ int src_inc1 = 16;
+ int src_inc2 = 16;
+ int src_inc3 = 16;
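+      // Past the last source column, point the pointer at zerobuf (filled
+      // with the zero_point) and set the corresponding src_inc to 0, so the
+      // asm keeps re-reading zero_point padding without advancing.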
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ std::int8_t* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * block_col;
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+#if RUY_PLATFORM_NEON_64
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ Pack8bitColMajorForNeonA55ish(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, sums_ptr, kInputXor);
+ } else {
+ Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
+ src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, sums_ptr, kInputXor);
+ }
+#else
+ (void)tuning;
+ // We have a more limited set of general purpose registers in ARMv7, so
+ // we use the "params" struct technique from the kernel code to save
+ // registers.
+ PackParams8bit params;
+ MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr,
+ packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ kInputXor, &params);
+ Pack8bitColMajorForNeon4Cols(params);
+#endif // RUY_PLATFORM_NEON_64
+ }
+ }
+};
+
+#endif // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) &&
+ // RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+// On 32-bit ARM, the 8-bit kernel is 4 rows X 2 columns, so we need an
+// additional partial specialization for the RHS, which has a
+// FixedKernelLayout with 2 columns.
+template <typename Scalar>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
+ std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 2, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 2) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ int src_inc0 = 16;
+ int src_inc1 = 16;
+ if (block_col >= src_matrix.layout.cols - 2) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ }
+ std::int8_t* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * block_col;
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ PackParams8bit params;
+ MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr,
+ packed_ptr, src_inc0, src_inc1, -1, -1,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ kInputXor, &params);
+ Pack8bitColMajorForNeon2Cols(params);
+ }
+ }
+};
+#endif // (RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+template <typename Scalar>
+struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ const Scalar* src_ptr2 = src_ptr1 + src_stride;
+ const Scalar* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & ~7) +
+ ((block_col & 4) * 4);
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ Pack8bitColMajorForNeonDotprodA55ish(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, sums_ptr, kInputXor);
+ } else {
+ Pack8bitColMajorForNeonDotprod(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, sums_ptr, kInputXor);
+ }
+ }
+ }
+};
+#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, float* packed_ptr);
+void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
+ const float* src_ptr1,
+ const float* src_ptr2,
+ const float* src_ptr3, int src_inc0,
+ int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, float* packed_ptr);
+
+#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
+void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc, int src_rows, float* packed_ptr,
+ int stride);
+#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
+
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float, Order::kColMajor> {
+ static void Run(Tuning tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ const float zerobuf[4] = {0};
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ float* packed_ptr = packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & ~7) +
+ ((block_col & 4));
+#if RUY_PLATFORM_NEON_64
+ if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
+ PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
+ src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, packed_ptr);
+ } else {
+ PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
+ src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, packed_ptr);
+ }
+#else
+ (void)tuning;
+ // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
+ // to save on registers (we have fewer general purpose registers in
+ // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
+ // values that are each either 16 or 0 and use them directly. For the
+ // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
+ // use the value 16 (bit is set) or 0 (bit is not set) for the
+ // respective increment value.
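+      // For example, if src_inc0 and src_inc2 are 16 while src_inc1 and
+      // src_inc3 are 0, the encoding below yields src_inc = 0b0101 = 5.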
+ std::int64_t src_inc = 0;
+ src_inc += src_inc0 == 16 ? 1 : 0;
+ src_inc += src_inc1 == 16 ? 2 : 0;
+ src_inc += src_inc2 == 16 ? 4 : 0;
+ src_inc += src_inc3 == 16 ? 8 : 0;
+ const int kOutputStride = 32;
+ PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+ src_matrix.layout.rows, packed_ptr,
+ kOutputStride);
+#endif // RUY_PLATFORM_NEON_64
+ }
+ }
+};
+
+#if RUY_PLATFORM_NEON_32
+// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
+// specialization for a FixedKernelLayout with 4 columns.
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
+ float, float, Order::kColMajor> {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 4, 0);
+ const float zerobuf[4] = {0};
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ float* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * (block_col);
+      // Encode each of src_inc0, ..., src_inc3 in the lowest 4 bits of src_inc
+      // to save registers.
+ std::int64_t src_inc = 0;
+ src_inc += src_inc0 == 16 ? 1 : 0;
+ src_inc += src_inc1 == 16 ? 2 : 0;
+ src_inc += src_inc2 == 16 ? 4 : 0;
+ src_inc += src_inc3 == 16 ? 8 : 0;
+ const int kOutputStride = 16;
+ PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+ src_matrix.layout.rows, packed_ptr,
+ kOutputStride);
+ }
+ }
+};
+#endif // (RUY_PLATFORM_NEON_32)
+#endif // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \
+ // RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+template <typename Scalar>
+struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsRowMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ Scalar zerobuf[8];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ int src_stride = src_matrix.layout.stride;
+ // As the source matrix is row-major and the destination packed matrix is
+ // column-major, there is no traversal order that will be optimal for both
+ // so we choose to favor the source matrix with a row-major traversal order.
+ // Loop over groups of 4 rows.
+ for (int block_row = 0; block_row < packed_matrix->layout.rows;
+ block_row += 4) {
+ // src_ptr[0-3] shall point to the positions in the 4 rows of the source
+ // matrix that we are loading from, and will be incremented by
+ // src_inc[0-3] after each 4x8 block is loaded.
+ // First we compute these src_ptr and src_inc values for the case where
+ // there are 4 rows left to be loaded from in the source matrix ...
+ const Scalar* src_ptr0 =
+ src_matrix.data.get() + src_stride * block_row + start_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ const Scalar* src_ptr2 = src_ptr1 + src_stride;
+ const Scalar* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ // ... and now we adjust these values in case there are fewer than 4 rows
+ // left to load from in the source matrix. In that case, we set the
+ // corresponding src_ptr pointer to load from `zerobuf` and set src_inc
+ // to 0 to avoid overrunning that small buffer.
+ if (block_row >= src_matrix.layout.rows - 3) {
+ if (block_row >= src_matrix.layout.rows - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_row >= src_matrix.layout.rows - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_row >= src_matrix.layout.rows - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_row >= src_matrix.layout.rows - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ // Let src_cols be the number of source matrix columns to handle.
+ int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col;
+ std::int8_t* packed_ptr = packed_matrix->data +
+ packed_matrix->layout.stride * start_col +
+ 8 * block_row;
+ std::int32_t* sums_ptr = sums + start_col;
+ Pack8bitRowMajorForNeonDotprod(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2,
+ src_inc3, src_cols, src_matrix.zero_point, packed_ptr,
+ packed_matrix->layout.stride, sums_ptr, kInputXor);
+ }
+ }
+};
+
+#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON
+
+template <typename Scalar, int KernelCols>
+struct PackImpl<Path::kNeon,
+ FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
+ std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (KNeon, from row-major source)");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 16) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr = packed_matrix->data +
+ start_col * packed_stride +
+ block_row * KernelCols;
+
+ Pack8bitRowMajorForNeon(
+ reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
+ src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
+ end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
+ kInputXor, KernelCols);
+ }
+ }
+};
+#endif
+
+} // namespace ruy
+
+#endif // RUY_RUY_PACK_ARM_H_
diff --git a/ruy/pack_avx.cc b/ruy/pack_avx.cc
new file mode 100644
index 0000000..2b929e7
--- /dev/null
+++ b/ruy/pack_avx.cc
@@ -0,0 +1,831 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_x86.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
+
+void Pack8bitColMajorForAvx(const std::int8_t*, std::int8_t, const std::int8_t*,
+ int, int, int, std::int8_t*, std::int32_t*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatColMajorForAvx(const float*, const float*, int, int, int,
+ float*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Pack8bitRowMajorForAvx(const std::uint8_t*, int, int, std::int8_t*, int,
+ int, int, int, int, int, int, std::int32_t*) {
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx =
+ PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t,
+ std::int8_t, std::int32_t, Order::kColMajor>;
+
+using PackImplFloatAvx =
+ PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float, Order::kColMajor>;
+
+namespace {
+
+// Perform the equivalent of mm256_permutevar8x32 with
+// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
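+// For example, an input whose lanes are {a0, a1, ..., a7} (from lane 0) comes
+// out as {a0, a2, a4, a6, a1, a3, a5, a7}: even-indexed elements first,
+// followed by odd-indexed ones.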
+inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
+ // a_lo = 3 2 1 0
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ // a_hi = 7 6 5 4
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ // shuffle a_lo to get 3 1 2 0
+ __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
+ // shuffle a_hi to get 7 5 6 4
+ __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
+ // unpack lo 64 of res_lo and res hi to get 6 4 2 0
+ __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
+ // unpack hi 64 of res_lo and res hi to get 7 5 3 1
+ __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
+ return _mm256_set_m128i(res_hi, res_lo);
+}
+
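+// The helpers below emulate a subset of AVX2 256-bit integer intrinsics on
+// plain AVX, which lacks most 256-bit integer operations: each splits its
+// operands into 128-bit halves, applies the SSE counterpart, and recombines
+// the results.
+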
+inline __m128i mm256_extracti128_si256(const __m256i& a, const int imm) {
+ switch (imm) {
+ case 0:
+ return _mm256_extractf128_si256(a, 0);
+ case 1:
+ return _mm256_extractf128_si256(a, 1);
+ default:
+ RUY_DCHECK_LT(imm, 2);
+ return _mm_setzero_si128();
+ }
+}
+
+inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
+ // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
+ __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
+ return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
+}
+
+inline __m256i mm256_cvtepi16_epi32(const __m128i& a) {
+ // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
+ __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
+ return _mm256_set_m128i(_mm_cvtepi16_epi32(hi), _mm_cvtepi16_epi32(a));
+}
+
+inline __m256i mm256_xor_si256(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_xor_si128(a_lo, b_lo);
+ __m128i hi = _mm_xor_si128(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_unpacklo_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_unpacklo_epi32(a_lo, b_lo);
+ __m128i hi = _mm_unpacklo_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_unpacklo_epi64(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_unpacklo_epi64(a_lo, b_lo);
+ __m128i hi = _mm_unpacklo_epi64(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_unpackhi_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_unpackhi_epi32(a_lo, b_lo);
+ __m128i hi = _mm_unpackhi_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_unpackhi_epi64(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_unpackhi_epi64(a_lo, b_lo);
+ __m128i hi = _mm_unpackhi_epi64(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_add_epi32(a_lo, b_lo);
+ __m128i hi = _mm_add_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_add_epi16(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_add_epi16(a_lo, b_lo);
+ __m128i hi = _mm_add_epi16(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_madd_epi16(a_lo, b_lo);
+ __m128i hi = _mm_madd_epi16(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
+ const int imm) {
+ __m128i tmp = _mm_setzero_si128();
+ if (!(imm & 8)) {
+ switch (imm & 3) {
+ case 0:
+ return _mm256_extractf128_si256(a, 0);
+ case 1:
+ return _mm256_extractf128_si256(a, 1);
+ case 2:
+ return _mm256_extractf128_si256(b, 0);
+ case 3:
+ return _mm256_extractf128_si256(b, 1);
+ }
+ }
+ return tmp;
+}
+
+inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
+ const int imm) {
+ const int lo_imm = imm & 15;
+ __m128i lo = mm_permute_helper(a, b, lo_imm);
+ const int hi_imm = (imm >> 4) & 15;
+ __m128i hi = mm_permute_helper(a, b, hi_imm);
+ return _mm256_set_m128i(hi, lo);
+}
+
+inline void Pack8bitColMajorForAvxPacker(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitAvx::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each group of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous once packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have Layout::kCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: Layout::kCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m256i ones_16bit = _mm256_set1_epi16(1);
+ __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
+ __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
+
+ // The overall packing effectively pads the source rows to
+ // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
+ // only pack for (src_rows + 31) & ~31. When there is an incomplete
+ // destination block, this is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
+ // Available source rows.
+ // If this is less than 0 (for m=1), we skip, having filled trailing
+ // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+ // exactly to the end of the column in the packed buffer.
+ const int available_src_rows = src_rows - k;
+ // Effectively,
+ // available rows = std::max(0, std::min(8, src_rows - k));
+ // treat each case separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ if (sums_ptr) {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = mm256_unpacklo_epi32(t0, t1);
+ r4 = mm256_unpacklo_epi32(t4, t5);
+ r2 = mm256_unpackhi_epi32(t0, t1);
+ r6 = mm256_unpackhi_epi32(t4, t5);
+ r1 = mm256_unpacklo_epi32(t2, t3);
+ r5 = mm256_unpacklo_epi32(t6, t7);
+ r3 = mm256_unpackhi_epi32(t2, t3);
+ r7 = mm256_unpackhi_epi32(t6, t7);
+
+ t0 = mm256_unpacklo_epi64(r0, r1);
+ t4 = mm256_unpacklo_epi64(r4, r5);
+ t2 = mm256_unpackhi_epi64(r0, r1);
+ t6 = mm256_unpackhi_epi64(r4, r5);
+ t1 = mm256_unpacklo_epi64(r2, r3);
+ t5 = mm256_unpacklo_epi64(r6, r7);
+ t3 = mm256_unpackhi_epi64(r2, r3);
+ t7 = mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+        // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = mm256_xor_si256(r0, input_xor_v);
+ r1 = mm256_xor_si256(r1, input_xor_v);
+ r2 = mm256_xor_si256(r2, input_xor_v);
+ r3 = mm256_xor_si256(r3, input_xor_v);
+ r4 = mm256_xor_si256(r4, input_xor_v);
+ r5 = mm256_xor_si256(r5, input_xor_v);
+ r6 = mm256_xor_si256(r6, input_xor_v);
+ r7 = mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi = mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ } else {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = mm256_unpacklo_epi32(t0, t1);
+ r4 = mm256_unpacklo_epi32(t4, t5);
+ r2 = mm256_unpackhi_epi32(t0, t1);
+ r6 = mm256_unpackhi_epi32(t4, t5);
+ r1 = mm256_unpacklo_epi32(t2, t3);
+ r5 = mm256_unpacklo_epi32(t6, t7);
+ r3 = mm256_unpackhi_epi32(t2, t3);
+ r7 = mm256_unpackhi_epi32(t6, t7);
+
+ t0 = mm256_unpacklo_epi64(r0, r1);
+ t4 = mm256_unpacklo_epi64(r4, r5);
+ t2 = mm256_unpackhi_epi64(r0, r1);
+ t6 = mm256_unpackhi_epi64(r4, r5);
+ t1 = mm256_unpacklo_epi64(r2, r3);
+ t5 = mm256_unpacklo_epi64(r6, r7);
+ t3 = mm256_unpackhi_epi64(r2, r3);
+ t7 = mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+        // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = mm256_xor_si256(r0, input_xor_v);
+ r1 = mm256_xor_si256(r1, input_xor_v);
+ r2 = mm256_xor_si256(r2, input_xor_v);
+ r3 = mm256_xor_si256(r3, input_xor_v);
+ r4 = mm256_xor_si256(r4, input_xor_v);
+ r5 = mm256_xor_si256(r5, input_xor_v);
+ r6 = mm256_xor_si256(r6, input_xor_v);
+ r7 = mm256_xor_si256(r7, input_xor_v);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset, effectively
+ // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ //
+ // Note that (zero_point ^ input_xor) is performed in 8-bits and then
+ // cast.
+ sums_adjustment +=
+ -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));
+
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr0);
+ t4 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr4);
+ t1 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr1);
+ t5 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr5);
+ t2 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr2);
+ t6 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr6);
+ t3 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr3);
+ t7 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr7);
+
+ r0 = mm256_unpacklo_epi32(t0, t1);
+ r4 = mm256_unpacklo_epi32(t4, t5);
+ r2 = mm256_unpackhi_epi32(t0, t1);
+ r6 = mm256_unpackhi_epi32(t4, t5);
+ r1 = mm256_unpacklo_epi32(t2, t3);
+ r5 = mm256_unpacklo_epi32(t6, t7);
+ r3 = mm256_unpackhi_epi32(t2, t3);
+ r7 = mm256_unpackhi_epi32(t6, t7);
+
+ t0 = mm256_unpacklo_epi64(r0, r1);
+ t4 = mm256_unpacklo_epi64(r4, r5);
+ t2 = mm256_unpackhi_epi64(r0, r1);
+ t6 = mm256_unpackhi_epi64(r4, r5);
+ t1 = mm256_unpacklo_epi64(r2, r3);
+ t5 = mm256_unpacklo_epi64(r6, r7);
+ t3 = mm256_unpackhi_epi64(r2, r3);
+ t7 = mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = mm256_xor_si256(r0, input_xor_v);
+ r1 = mm256_xor_si256(r1, input_xor_v);
+ r2 = mm256_xor_si256(r2, input_xor_v);
+ r3 = mm256_xor_si256(r3, input_xor_v);
+ r4 = mm256_xor_si256(r4, input_xor_v);
+ r5 = mm256_xor_si256(r5, input_xor_v);
+ r6 = mm256_xor_si256(r6, input_xor_v);
+ r7 = mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo = mm256_add_epi16(
+ sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi =
+ mm256_add_epi16(sums_4x4_16bit_hi,
+ mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
+ r7);
+ }
+
+ packed_ptr += 8 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+
+ // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
+    // neighbours, finishing up by adding them to the stored accumulated sums.
+ const __m256i sums_2x4_32bit_lo = PermuteEpi32EvenOdds(sums_4x2_32bit_lo);
+ const __m256i sums_2x4_32bit_hi = PermuteEpi32EvenOdds(sums_4x2_32bit_hi);
+ const __m256i sums_2x4_32bit_a =
+ mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
+ const __m256i sums_2x4_32bit_b =
+ mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
+ sums = mm256_add_epi32(sums, sums_adjustment_v);
+ sums = mm256_add_epi32(sums, sums_2x4_32bit_a);
+ sums = mm256_add_epi32(sums, sums_2x4_32bit_b);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+ }
+}
+
+// Greater-than comparison for plain AVX, which lacks 256-bit integer compares:
+// the operands are compared as floats and the result converted back to
+// integers.
+template <>
+inline __m256i CompareGreaterThan<Path::kAvx>(const __m256i& a,
+ const __m256i& b) {
+ constexpr int kGreaterThanSignalling = 14;
+ const __m256 v = _mm256_cmp_ps(_mm256_cvtepi32_ps(a), _mm256_cvtepi32_ps(b),
+ kGreaterThanSignalling);
+ return _mm256_cvtps_epi32(v);
+}
+
+} // namespace.
+
+void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx 8bit");
+
+ using Layout = PackImpl8bitAvx::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each group of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous once packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ static constexpr int kNumRowChunks = 8; // Short input is padded.
+
+ // Each packed block is 4*8, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+ Pack8bitColMajorForAvxPacker(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr,
+ sums_ptr, trailing_buf);
+
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
+  // there will be data in the trailing buffer.
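+  // For example, with src_rows = 70, non_trailing_rows = 64 and dst_rows = 72,
+  // so the last 8 rows of the packed 8-column block are copied from
+  // trailing_buf.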
+ if (trailing_data) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+ // Destination "rows" are padded to next highest multiple of Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvx float");
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+ float trailing_buf[(kPackRows - 1) * kPackCols];
+ if (remaining_src_cols < 8) {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ }
+ PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx, Path::kAvx>(
+ src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
+ trailing_buf);
+
+ const int trailing_rows = src_rows & (kPackRows - 1);
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+ memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+ kPackCols * trailing_rows * sizeof(float));
+ }
+}
+
+void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums) {
+ int col = start_col;
+ int src_end_col = std::min(end_col, src_cols);
+
+ for (; col <= src_end_col - 8; col += 8) {
+ std::int8_t* dst_ptr = packed_ptr;
+ __m128i val0, val1, val2, val3;
+ __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+ // Load a 4x8 block.
+ if (block_row + 4 <= src_rows) {
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ } else {
+ val0 = _mm_set1_epi8(src_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ if (block_row + 1 < src_rows)
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ if (block_row + 2 < src_rows)
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ if (block_row + 3 < src_rows)
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ }
+ // Maybe xor the sign bit to convert from uint8 to int8.
+ val0 = _mm_xor_si128(val0, input_xor_dup);
+ val1 = _mm_xor_si128(val1, input_xor_dup);
+ val2 = _mm_xor_si128(val2, input_xor_dup);
+ val3 = _mm_xor_si128(val3, input_xor_dup);
+ // Update the sums.
+ __m128i val16_0 = _mm_cvtepi8_epi16(val0);
+ __m128i val16_1 = _mm_cvtepi8_epi16(val1);
+ __m128i val16_2 = _mm_cvtepi8_epi16(val2);
+ __m128i val16_3 = _mm_cvtepi8_epi16(val3);
+ __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
+ _mm_add_epi16(val16_2, val16_3));
+ __m256i sum =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
+ sum = mm256_add_epi32(sum, mm256_cvtepi16_epi32(new_sum16));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
+ // Perform the transposition of 4x4 blocks
+ __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
+ __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
+ __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
+ __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
+ src_ptr += 8;
+ packed_ptr += packed_stride * 8;
+ }
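+  // Handle any remaining source columns (fewer than 8) one at a time.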
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ for (int r = 0; r < 4; r++) {
+ std::int8_t packed_val;
+ if (block_row + r < src_rows) {
+ packed_val = input_xor ^ src_ptr[r * src_stride];
+ } else {
+ packed_val = input_xor ^ src_zero_point;
+ }
+ accum += packed_val;
+ *packed_ptr++ = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ }
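+  // Zero-fill the packed data for destination columns beyond the source
+  // matrix, i.e. between src_end_col and end_col.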
+ for (; col < end_col; col++) {
+ std::memset(packed_ptr, 0, 4);
+ packed_ptr += 4;
+ }
+}
+
+#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_avx2_fma.cc b/ruy/pack_avx2_fma.cc
new file mode 100644
index 0000000..2564b72
--- /dev/null
+++ b/ruy/pack_avx2_fma.cc
@@ -0,0 +1,689 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_x86.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
+
+void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t,
+ const std::int8_t*, int, int, int, std::int8_t*,
+ std::int32_t*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatColMajorForAvx2(const float*, const float*, int, int, int,
+ float*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
+ int, int, int, int, int, int, std::int32_t*) {
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx2 =
+ PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
+
+using PackImplFloatAvx2 =
+ PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float, Order::kColMajor>;
+
+namespace {
+
+inline void Pack8bitColMajorForAvx2Packer(
+ const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride, int remaining_src_cols,
+ int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitAvx2::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each group of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous once packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have Layout::kCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: Layout::kCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m256i ones_16bit = _mm256_set1_epi16(1);
+ __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
+ __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
+
+ // The overall packing effectively pads the source rows to
+ // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
+ // only pack for (src_rows + 31) & ~31. When there is an incomplete
+ // destination block, this is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
+ // Available source rows.
+ // If this is less than 0 (for m=1), we skip, having filled trailing
+ // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+ // exactly to the end of the column in the packed buffer.
+ const int available_src_rows = src_rows - k;
+ // Effectively,
+ // available rows = std::max(0, std::min(8, src_rows - k));
+ // treat each case separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ if (sums_ptr) {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+        // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ } else {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+      // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by
+      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
+      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
+      // from the way that AVX is centered around 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
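+      // A sketch of the selector semantics used above, writing x and y for
+      // the two source vectors: _mm256_permute2x128_si256(x, y, 0x20) yields
+      // {x.lane0, y.lane0} and selector 0x31 yields {x.lane1, y.lane1}, where
+      // laneN is the Nth 128-bit half. Each (t0, t4)-style pair is thus
+      // merged into two results holding the matching low and high lanes.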
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset, effectively
+ // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ //
+ // Note that (zero_point ^ input_xor) is performed in 8-bits and then
+ // cast.
+ sums_adjustment +=
+ -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));
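+      // A worked example with illustrative numbers: for available_src_rows =
+      // 5, (available_src_rows + 3) >> 2 = 2 four-row groups still belong to
+      // the padded destination, so the remaining 8 - 2 = 6 groups (4 * 6 = 24
+      // values per column) lie entirely beyond it. Each of those values is
+      // filled with zero_point and becomes (zero_point ^ input_xor) after the
+      // XOR below, so we pre-subtract 24 * (zero_point ^ input_xor) from the
+      // column sums.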
+
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr0);
+ t4 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr4);
+ t1 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr1);
+ t5 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr5);
+ t2 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr2);
+ t6 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr6);
+ t3 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr3);
+ t7 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr7);
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+      // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by
+      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
+      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
+      // from the way that AVX is centered around 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
+ r7);
+ }
+
+ packed_ptr += 8 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+ const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
+    // the neighbours, finishing up by adding them to the stored accumulated
+    // sums.
+ const __m256i sums_2x4_32bit_lo =
+ _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx);
+ const __m256i sums_2x4_32bit_hi =
+ _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx);
+ const __m256i sums_2x4_32bit_a =
+ _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
+ const __m256i sums_2x4_32bit_b =
+ _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
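+    // A sketch of the deinterlacing above: idx reads, in element order, as
+    // {0, 2, 4, 6, 1, 3, 5, 7}, so _mm256_permutevar8x32_epi32 gathers the
+    // even-indexed 32-bit lanes (the first partial sum of each column) into
+    // the low 128 bits and the odd-indexed lanes (the second partial sum)
+    // into the high 128 bits. The two permute2x128 merges then yield one
+    // vector of each partial sum covering all 8 columns, and both are added
+    // into sums below.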
+ sums = _mm256_add_epi32(sums, sums_adjustment_v);
+ sums = _mm256_add_epi32(sums, sums_2x4_32bit_a);
+ sums = _mm256_add_epi32(sums, sums_2x4_32bit_b);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+ }
+}
+
+// Use the AVX2-specific intrinsic for greater-than comparison.
+template <>
+inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a,
+ const __m256i& b) {
+ return _mm256_cmpgt_epi32(a, b);
+}
+
+} // namespace.
+
+void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx2Fma 8bit");
+
+ using Layout = PackImpl8bitAvx2::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each Layout::kRows group is 4 elements that are contiguous in the input
+  // and stay contiguous when packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ static constexpr int kNumRowChunks = 8; // Short input is padded.
+
+ // Each packed block is 4*8, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+ Pack8bitColMajorForAvx2Packer(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr,
+ sums_ptr, trailing_buf);
+
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
+  // there will be data in the trailing buffer.
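+  // A worked example with illustrative numbers: for src_rows = 70,
+  // kChunkedRowMask = 31, so non_trailing_rows = 64, dst_rows = 72 and
+  // trailing_rows = 8; the memcpy below then moves 8 * 8 = 64 bytes from
+  // trailing_buf into rows 64..71 of the packed block.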
+ if (trailing_data) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+    // Destination "rows" are padded to the next highest multiple of
+    // Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvx2Fma float");
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+ float trailing_buf[(kPackRows - 1) * kPackCols];
+ if (remaining_src_cols < 8) {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ }
+ PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx2, Path::kAvx2Fma>(
+ src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
+ trailing_buf);
+
+ const int trailing_rows = src_rows & (kPackRows - 1);
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+ memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+ kPackCols * trailing_rows * sizeof(float));
+ }
+}
+
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums) {
+ int col = start_col;
+ int src_end_col = std::min(end_col, src_cols);
+
+ for (; col <= src_end_col - 8; col += 8) {
+ std::int8_t* dst_ptr = packed_ptr;
+ __m128i val0, val1, val2, val3;
+ __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+ // Load a 4x8 block.
+ if (block_row + 4 <= src_rows) {
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ } else {
+ val0 = _mm_set1_epi8(src_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ if (block_row + 1 < src_rows)
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ if (block_row + 2 < src_rows)
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ if (block_row + 3 < src_rows)
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ }
+ // Maybe xor the sign bit to convert from uint8 to int8.
+ val0 = _mm_xor_si128(val0, input_xor_dup);
+ val1 = _mm_xor_si128(val1, input_xor_dup);
+ val2 = _mm_xor_si128(val2, input_xor_dup);
+ val3 = _mm_xor_si128(val3, input_xor_dup);
+ // Update the sums.
+ __m128i val16_0 = _mm_cvtepi8_epi16(val0);
+ __m128i val16_1 = _mm_cvtepi8_epi16(val1);
+ __m128i val16_2 = _mm_cvtepi8_epi16(val2);
+ __m128i val16_3 = _mm_cvtepi8_epi16(val3);
+ __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
+ _mm_add_epi16(val16_2, val16_3));
+ __m256i sum =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
+ sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
+    // Perform the transposition of 4x4 blocks.
+ __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
+ __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
+ __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
+ __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
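+    // A sketch of the interleaving above, writing aN/bN/cN/dN for the bytes
+    // of rows 0..3: _mm_unpacklo_epi8 yields a0,b0,a1,b1,... and
+    // c0,d0,c1,d1,..., and _mm_unpacklo/hi_epi16 then yield
+    // a0,b0,c0,d0, a1,b1,c1,d1, ... so each group of 4 packed bytes holds
+    // rows 0..3 of one source column, as the column-major 4x8 kernel layout
+    // expects.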
+ src_ptr += 8;
+ packed_ptr += packed_stride * 8;
+ }
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ for (int r = 0; r < 4; r++) {
+ std::int8_t packed_val;
+ if (block_row + r < src_rows) {
+ packed_val = input_xor ^ src_ptr[r * src_stride];
+ } else {
+ packed_val = input_xor ^ src_zero_point;
+ }
+ accum += packed_val;
+ *packed_ptr++ = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ }
+ for (; col < end_col; col++) {
+ std::memset(packed_ptr, 0, 4);
+ packed_ptr += 4;
+ }
+}
+
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
+
+} // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
new file mode 100644
index 0000000..5281fa8
--- /dev/null
+++ b/ruy/pack_avx512.cc
@@ -0,0 +1,828 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_x86.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
+
+void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
+ const std::int8_t*, int, int, int, std::int8_t*,
+ std::int32_t*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
+ float*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
+ int, int, int, int, int, int, std::int32_t*) {
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx512 =
+ PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
+
+namespace {
+
+inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
+ std::int8_t* packed_ptr) {
+ using Layout = PackImpl8bitAvx512::Layout;
+ static constexpr int kHalfLayoutCols =
+ PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
+ RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+ const int non_trailing_blocks = (src_rows & ~31) >> 2;
+ // This routine fills half blocks, and typically fills the second halves.
+ // Thus packed_ptr is already offset by 8 * 4.
+ for (int k = 0; k < non_trailing_blocks; ++k) {
+ for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
+ packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
+ }
+ }
+}
+
+inline __m512i LoaduTwo(const std::int8_t* addr_lo,
+ const std::int8_t* addr_hi) {
+ __m512i lower_filled = _mm512_castsi256_si512(
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
+ return _mm512_inserti32x8(
+ lower_filled,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
+}
+
+inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
+ const std::int8_t* addr_lo,
+ const std::int8_t* addr_hi) {
+ const __m512i lower_filled = _mm512_castsi256_si512(
+ _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
+ return _mm512_inserti32x8(
+ lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
+ 1);
+}
+
+inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitAvx512::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each Layout::kRows group is 4 elements that are contiguous in the input
+  // and stay contiguous when packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have kHalfLayoutCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: kHalfLayoutCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m512i ones_16bit = _mm512_set1_epi16(1);
+ __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
+
+ // The overall packing effectively pads the source rows to
+  // (src_rows + 63) & ~63. The m=1 pass of the inner loop may find no rows
+  // left to pack, in which case we only pack up to (src_rows + 31) & ~31.
+  // When there is an incomplete destination block, it is stored into
+  // trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
+ // m: {0, 1} for 2 chunks of rows.
+ for (int m = 0; m < 2; ++m) {
+ // Available source rows.
+      // If this is less than 0 (for m=1), we skip, having already filled the
+      // trailing buffer for m=0. Also, if the available source rows are zero
+      // at m=1, then we filled exactly to the end of the column in the packed
+      // buffer.
+ const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
+      // Effectively,
+      // available_src_rows =
+      //     std::max(0, std::min(8 * 4, src_rows - k - 8 * 4 * m));
+      // but we treat each case separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ // i: chunks, s: Layout::Rows.
+ if (sums_ptr) {
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+
+ __m512i sums_8x4_16bit;
+ sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
+ // The sums have been performed across columns, and now we have
+ // 4x16-bit sums packed together. We use madd for pairwise 32-bit
+ // sums.
+ const __m512i sums_8x2_32bit_new =
+ _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
+ } else {
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
+ const __mmask32 row_mask =
+ (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
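+        // For example (illustrative value), available_src_rows = 5 gives
+        // row_mask = 0b11111, so each _mm256_mask_loadu_epi8 below reads only
+        // the first 5 bytes from memory and takes the remaining 27 bytes from
+        // zero_point_v.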
+
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset, effectively
+ // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ //
+ // Note that (zero_point ^ input_xor) is performed in 8-bits and then
+ // cast.
+ sums_adjustment += -(zero_point ^ input_xor) * 4 *
+ (8 - ((available_src_rows + 3) >> 2));
+
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+ const __m256i zero_point_v = _mm256_set1_epi8(zero_point);
+
+ t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
+ t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
+ t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
+ t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+
+ __m512i sums_8x4_16bit;
+ sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
+ // The sums have been performed across columns, and now we have
+ // 4x16-bit sums packed together. We use madd for pairwise 32-bit
+ // sums.
+ const __m512i sums_8x2_32bit_new =
+ _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
+ }
+
+ packed_ptr += 16 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+ const __m512i idx =
+ _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
+
+    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
+    // the neighbours, finishing up by adding them to the stored accumulated
+    // sums.
+ const __m512i sums_2x8_32bit =
+ _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
+ sums = _mm256_add_epi32(sums, sums_adjustment_v);
+ sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
+ sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+ }
+}
+
+inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
+ const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
+ return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
+}
+
+inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
+ const float* addr_hi) {
+ const __m512 lower_filled =
+ _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
+ return _mm512_insertf32x8(lower_filled,
+ _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
+}
+
+inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
+ return _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
+ return _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ for (int k = 0; k < src_rows; k += 16) {
+ for (int m = 0; m < 2; ++m) {
+ const int available_src_rows = src_rows - k - 8 * m;
+ // Effectively,
+ // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
+      // but we treat each case separately.
+ if (available_src_rows > 7) {
+ __m512 t0, t1, t2, t3;
+ __m512 r0, r1, r2, r3;
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_ps(t0, t1);
+ r2 = _mm512_unpackhi_ps(t0, t1);
+ r1 = _mm512_unpacklo_ps(t2, t3);
+ r3 = _mm512_unpackhi_ps(t2, t3);
+
+ t0 = Mm512UnpackloPsx2(r0, r1);
+ t2 = Mm512UnpackhiPsx2(r0, r1);
+ t1 = Mm512UnpackloPsx2(r2, r3);
+ t3 = Mm512UnpackhiPsx2(r2, r3);
+
+ r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+ _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
+ _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+ _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
+ _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+ _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
+ _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+ _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
+ _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
+ } else if (available_src_rows > 0) {
+ const __mmask8 row_mask =
+ (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
+
+ __m512 t0, t1, t2, t3;
+ __m512 r0, r1, r2, r3;
+
+ t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
+ t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
+ t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
+ t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_ps(t0, t1);
+ r2 = _mm512_unpackhi_ps(t0, t1);
+ r1 = _mm512_unpacklo_ps(t2, t3);
+ r3 = _mm512_unpackhi_ps(t2, t3);
+
+ t0 = Mm512UnpackloPsx2(r0, r1);
+ t2 = Mm512UnpackhiPsx2(r0, r1);
+ t1 = Mm512UnpackloPsx2(r2, r3);
+ t3 = Mm512UnpackhiPsx2(r2, r3);
+
+ r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+ _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
+ _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+ _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
+ _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+ _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
+ _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+ _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
+        // Do not store _mm512_extractf32x8_ps(r3, 1): that would be trailing
+        // row 7, which cannot hold real data in this branch
+        // (available_src_rows < 8) and would not fit in the 7-row
+        // trailing_buf.
+ }
+
+ packed_ptr += 16 * 8;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+}
+
+inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
+ const int non_trailing_rows = src_rows & ~7;
+ for (int k = 0; k < non_trailing_rows; ++k) {
+ for (int j = 0; j < 8; ++j) {
+ packed_ptr[j] = 0.0f;
+ }
+ packed_ptr += 16;
+ }
+}
+
+} // namespace.
+
+void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx512 8bit");
+
+ using Layout = PackImpl8bitAvx512::Layout;
+ constexpr int kHalfBlockOffset = 32;
+ RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
+ static constexpr int kHalfLayoutCols =
+ PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
+ RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each Layout::kRows group is 4 elements that are contiguous in the input
+  // and stay contiguous when packed. We process 8 of these chunks at a time,
+  // padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+
+ // Each packed block is 4*16, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+ std::int32_t* second_sums_ptr =
+ sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
+ if (remaining_src_cols > kHalfLayoutCols) {
+ HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
+ zerobuf, src_stride,
+ remaining_src_cols - kHalfLayoutCols, src_rows,
+ packed_ptr + kHalfBlockOffset, second_sums_ptr,
+ trailing_buf + kHalfBlockOffset);
+ } else {
+ HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
+ packed_ptr + kHalfBlockOffset);
+    // The kernel may not need the second half-block's sums to be set.
+ if (second_sums_ptr) {
+ for (int i = 0; i < kHalfLayoutCols; ++i) {
+ second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
+ }
+ }
+ }
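+  // A sketch of the resulting layout: each packed 16x4 block is 64 bytes, of
+  // which bytes 0..31 hold columns 0..7 (first HalfPack8bitAvx512 call) and
+  // bytes 32..63, at kHalfBlockOffset, hold columns 8..15 (second call, or
+  // the zero fill above). When the second half is absent, each of its column
+  // sums is simply the padded value (zerobuf[0] ^ input_xor) times the padded
+  // row count ((src_rows + 3) & ~3), as set above.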
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
+  // there will be data in the trailing buffer.
+ if (trailing_data) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+    // Destination "rows" are padded to the next highest multiple of
+    // Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvx512 float");
+ float trailing_buf[7 * 16];
+ if (remaining_src_cols > 8) {
+ HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
+ remaining_src_cols - 8, src_rows, packed_ptr + 8,
+ trailing_buf + 8);
+ } else {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
+ }
+ const int trailing_rows = src_rows & 7;
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~7;
+ memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+ 16 * trailing_rows * sizeof(float));
+ }
+}
+
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums) {
+ int col = start_col;
+ int src_end_col = std::min(end_col, src_cols);
+
+ for (; col <= src_end_col - 16; col += 16) {
+ std::int8_t* dst_ptr = packed_ptr;
+ __m128i val0, val1, val2, val3;
+ __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+ // Load a 4x16 block.
+ if (block_row + 4 <= src_rows) {
+ val0 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+ val1 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+ val2 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+ val3 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+ } else {
+ val0 = _mm_set1_epi8(src_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+ if (block_row + 1 < src_rows)
+ val1 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+ if (block_row + 2 < src_rows)
+ val2 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+ if (block_row + 3 < src_rows)
+ val3 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+ }
+ // Maybe xor the sign bit to convert from uint8 to int8.
+ val0 = _mm_xor_si128(val0, input_xor_dup);
+ val1 = _mm_xor_si128(val1, input_xor_dup);
+ val2 = _mm_xor_si128(val2, input_xor_dup);
+ val3 = _mm_xor_si128(val3, input_xor_dup);
+ // Update the sums.
+ __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
+ __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
+ __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
+ __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
+ __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
+ _mm256_add_epi16(val16_2, val16_3));
+ __m512i sum =
+ _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
+ sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
+ auto zip = [](__m128i x, __m128i y) {
+ auto perm_64_0_64_0 = [](__m128i x) {
+ return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
+ _mm256_castsi128_si256(x));
+ };
+ return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
+ };
+ __m256i t2_val0 = zip(val0, val1);
+ __m256i t2_val1 = zip(val2, val3);
+ __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
+ __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
+ _mm256_extractf128_si256(t4_val0, 0));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
+ _mm256_extractf128_si256(t4_val1, 0));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
+ _mm256_extractf128_si256(t4_val0, 1));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
+ _mm256_extractf128_si256(t4_val1, 1));
+ src_ptr += 16;
+ packed_ptr += packed_stride * 16;
+ }
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ for (int r = 0; r < 4; r++) {
+ std::int8_t packed_val;
+ if (block_row + r < src_rows) {
+ packed_val = input_xor ^ src_ptr[r * src_stride];
+ } else {
+ packed_val = input_xor ^ src_zero_point;
+ }
+ accum += packed_val;
+ *packed_ptr++ = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ }
+ for (; col < end_col; col++) {
+ std::memset(packed_ptr, 0, 4);
+ packed_ptr += 4;
+ }
+}
+
+#endif // RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
+
+} // namespace ruy
diff --git a/ruy/pack_common.h b/ruy/pack_common.h
new file mode 100644
index 0000000..8f07658
--- /dev/null
+++ b/ruy/pack_common.h
@@ -0,0 +1,143 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PACK_COMMON_H_
+#define RUY_RUY_PACK_COMMON_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+template <typename Scalar>
+Scalar SymmetricZeroPoint() {
+ if (std::is_floating_point<Scalar>::value) {
+ return 0;
+ }
+ if (std::is_signed<Scalar>::value) {
+ return 0;
+ }
+ return std::numeric_limits<Scalar>::max() / 2 + 1;
+}
+
+template <Path ThePath, typename Scalar>
+struct PackedTypeImpl {
+ using Type = Scalar;
+};
+
+template <Path ThePath, typename Scalar>
+using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
+
+template <typename PackedScalar, typename Scalar>
+PackedScalar Pack(Scalar x) {
+ return x - SymmetricZeroPoint<Scalar>() + SymmetricZeroPoint<PackedScalar>();
+}
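+// For example, with Scalar = std::uint8_t and PackedScalar = std::int8_t,
+// SymmetricZeroPoint<std::uint8_t>() is 255 / 2 + 1 = 128 and
+// SymmetricZeroPoint<std::int8_t>() is 0, so Pack<std::int8_t>(x) maps
+// x = 128 to 0 and, e.g., x = 200 to 72, recentering the unsigned range onto
+// the signed one.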
+
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+ typename PackedScalar, typename SumsType, Order SrcOrder>
+struct PackImpl;
+
+#define RUY_INHERIT_PACK(PARENT, CHILD) \
+ template <typename FixedKernelLayout, typename Scalar, \
+ typename PackedScalar, typename SumsType, Order SrcOrder> \
+ struct PackImpl<CHILD, FixedKernelLayout, Scalar, PackedScalar, SumsType, \
+ SrcOrder> : PackImpl<PARENT, FixedKernelLayout, Scalar, \
+ PackedScalar, SumsType, SrcOrder> {};
+
+// A generic yet fairly fast implementation of
+//
+// PackImpl<ThePath, FixedKernelLayout<Order::kRowMajor, 1, KernelCols>,
+// float, float, float, Order::kRowMajor>
+//
+// that is, a packing code path for the case of floating-point, row-major
+// source matrix, targeting typical float kernel layouts consisting of a
+// single row.
+//
+// The only reason why this isn't a partial specialization of PackImpl is that
+// doing so would lead to ambiguous partial specializations, as it would
+// conflict with the ones defined by RUY_INHERIT_PACK.
+//
+// What's special about floating-point kernels is that they tend to use
+// FixedKernelLayout<Order::kRowMajor, 1, KernelCols> for some value of
+// KernelCols, making it easy to implement the packing code as essentially
+// a bunch of memcpy's with compile-time-fixed size
+// (KernelCols * sizeof(float)), typically 16, 32 or 64 bytes. Unlike the
+// quantized case, there are no sums to compute, and the float kernels tend
+// to use this kind of simple layout on multiple architectures, unlike the
+// heavily architecture-specific layouts of quantized kernels.
+//
+// Here are the current instantiations of this template (as of 2020):
+// Path | KernelCols
+// --------------+---------------------------------
+// kNeon (ARM32) | 8 and 4 (for LHS and RHS sides)
+// kNeon (ARM64) | 8
+//   kAvx2Fma      | 8
+// kAvx512 | 16
+template <Path ThePath, int KernelCols>
+struct MemcpyRowMajorFloatPackImpl {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ RUY_DCHECK(IsRowMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % KernelCols, 0);
+ int src_stride = src_matrix.layout.stride;
+    // As the source matrix is row-major and the destination packed matrix is
+    // column-major, there is no traversal order that will be optimal for
+    // both, so we choose to favor the source matrix with a row-major
+    // traversal order.
+ for (int block_row = 0; block_row < src_matrix.layout.rows;
+ block_row += 1) {
+ const float* src_ptr =
+ src_matrix.data.get() + src_stride * block_row + start_col;
+ float* packed_ptr = packed_matrix->data +
+ packed_matrix->layout.stride * start_col +
+ KernelCols * block_row;
+ int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col;
+ int col = 0;
+ for (; col <= src_cols - KernelCols; col += KernelCols) {
+ memcpy(packed_ptr, src_ptr, KernelCols * sizeof(float));
+ packed_ptr += KernelCols * packed_matrix->layout.stride;
+ src_ptr += KernelCols;
+ }
+ int remaining_cols = src_cols - col;
+ if (remaining_cols > 0) {
+ memcpy(packed_ptr, src_ptr, remaining_cols * sizeof(float));
+ memset(packed_ptr + remaining_cols, 0,
+ (KernelCols - remaining_cols) * sizeof(float));
+ }
+ }
+ }
+};
+
+#define RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(ThePath, KernelCols) \
+ template <> \
+ struct PackImpl<ThePath, FixedKernelLayout<Order::kRowMajor, 1, KernelCols>, \
+ float, float, float, Order::kRowMajor> \
+ : MemcpyRowMajorFloatPackImpl<ThePath, KernelCols> {};
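+// As a usage example, pack_x86.h (later in this change) instantiates this as
+//   RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
+//   RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)
+// matching the KernelCols values in the table above.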
+
+} // namespace ruy
+
+#endif // RUY_RUY_PACK_COMMON_H_
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
new file mode 100644
index 0000000..f3ea54e
--- /dev/null
+++ b/ruy/pack_x86.h
@@ -0,0 +1,659 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PACK_X86_H_
+#define RUY_RUY_PACK_X86_H_
+
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM_X86
+
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx)
+RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma)
+RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)
+
+RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
+RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)
+
+template <>
+struct PackedTypeImpl<Path::kAvx, std::uint8_t> {
+ using Type = std::int8_t;
+};
+
+template <>
+struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
+ using Type = std::int8_t;
+};
+
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
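+  // For a std::uint8_t source, XOR-ing a value x with 0x80 and reinterpreting
+  // the result as std::int8_t is equivalent to x - 128, recentering [0, 255]
+  // onto [-128, 127] before packing; for std::int8_t sources the XOR with 0
+  // is a no-op.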
+
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX2 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[Layout::kCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ Layout::kCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitColMajorForAvx2(
+ reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
+ }
+ }
+};
+
+void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
+ std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[Layout::kCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ Layout::kCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitColMajorForAvx(
+ reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
+ }
+ }
+};
+
+void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float, Order::kColMajor> {
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ profiler::ScopeLabel label("Pack (AVX float)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remaining elements are value-initialized to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+
+void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
+ float, float, float, Order::kColMajor> {
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ profiler::ScopeLabel label("Pack (AVX2 float)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remaining elements are value-initialized to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+
+// Note that the source and zero buffers may hold uint8 data, but in the
+// packing function they are reinterpreted as int8 and XOR-ed with input_xor.
+void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ static constexpr int kHalfLayoutCols =
+ 8; // Half the number of cols in a block.
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitColMajorForAvx512(
+ reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
+ }
+ }
+};
+
+void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
+ float, float, float, Order::kColMajor> {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 float)");
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remaining elements are value-initialized to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride,
+ remaining_src_cols, src_matrix.layout.rows,
+ packed_ptr);
+ }
+ }
+};
+
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX2 8-bit row-major)");
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr =
+ packed_matrix->data + start_col * packed_stride + block_row * 8;
+ Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
+ src_stride, src_matrix.zero_point, packed_ptr,
+ packed_stride, start_col, end_col,
+ src_matrix.layout.cols, block_row,
+ src_matrix.layout.rows, kInputXor, sums);
+ }
+ }
+};
+
+void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
+ std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX 8-bit row-major)");
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr =
+ packed_matrix->data + start_col * packed_stride + block_row * 8;
+ Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr),
+ src_stride, src_matrix.zero_point, packed_ptr,
+ packed_stride, start_col, end_col,
+ src_matrix.layout.cols, block_row,
+ src_matrix.layout.rows, kInputXor, sums);
+ }
+ }
+};
+
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 8-bit row-major)");
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr =
+ packed_matrix->data + start_col * packed_stride + block_row * 16;
+ Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
+ src_stride, src_matrix.zero_point, packed_ptr,
+ packed_stride, start_col, end_col,
+ src_matrix.layout.cols, block_row,
+ src_matrix.layout.rows, kInputXor, sums);
+ }
+ }
+};
+#endif // RUY_PLATFORM_X86
+
+} // namespace ruy
+
+#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
+
+#include <immintrin.h> // IWYU pragma: keep
+
+namespace ruy {
+namespace {
+
+template <Path path>
+inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
+ return _mm256_castpd_ps(
+ _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
+}
+
+template <Path path>
+inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
+ return _mm256_castpd_ps(
+ _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
+}
+
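+// Generic fallback that should never be reached at runtime: the AVX and
+// AVX2+FMA translation units are expected to provide their own
+// CompareGreaterThan specializations alongside their packing code.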
+template <Path path>
+inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
+ RUY_DCHECK(false);
+ return _mm256_set1_epi32(0);
+}
+
+// Shared between AVX and AVX2+FMA.
+template <Path path>
+inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
+ const std::int8_t* addr) {
+ RUY_DCHECK_LT(available_src_rows, 32);
+ __m256i padded_data;
+
+ if (available_src_rows >= 16) {
+ __m128i load_hi = _mm_set1_epi8(zero_point);
+ __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
+ memcpy(&load_hi, addr + 16, available_src_rows - 16);
+ padded_data = _mm256_set_m128i(load_hi, load_lo);
+ } else {
+ __m128i load_hi = _mm_set1_epi8(zero_point);
+ __m128i load_lo = load_hi;
+ memcpy(&load_lo, addr, available_src_rows);
+ padded_data = _mm256_set_m128i(load_hi, load_lo);
+ }
+ return padded_data;
+}
+
+} // namespace
+
+template <typename PackImpl, Path path>
+inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr,
+ const float* zerobuf,
+ int src_stride,
+ int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8);
+ RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1);
+
+ // This packing amounts to transposition of 8x8 blocks.
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+ // Handle cases where the source does not have kPackCols (8) columns.
+ if (remaining_src_cols < kPackCols) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ for (int k = 0; k < src_rows; k += kPackRows) {
+ const int available_src_rows = src_rows - k;
+ // Effectively,
+ // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
+ // but treat each case separately.
+ if (available_src_rows >= kPackRows) {
+ __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256 r0, r1, r2, r3, r4, r5, r6, r7;
+
+ t0 = _mm256_loadu_ps(src_ptr0);
+ t4 = _mm256_loadu_ps(src_ptr4);
+ t1 = _mm256_loadu_ps(src_ptr1);
+ t5 = _mm256_loadu_ps(src_ptr5);
+ t2 = _mm256_loadu_ps(src_ptr2);
+ t6 = _mm256_loadu_ps(src_ptr6);
+ t3 = _mm256_loadu_ps(src_ptr3);
+ t7 = _mm256_loadu_ps(src_ptr7);
+
+ r0 = _mm256_unpacklo_ps(t0, t1);
+ r4 = _mm256_unpacklo_ps(t4, t5);
+ r2 = _mm256_unpackhi_ps(t0, t1);
+ r6 = _mm256_unpackhi_ps(t4, t5);
+ r1 = _mm256_unpacklo_ps(t2, t3);
+ r5 = _mm256_unpacklo_ps(t6, t7);
+ r3 = _mm256_unpackhi_ps(t2, t3);
+ r7 = _mm256_unpackhi_ps(t6, t7);
+
+ t0 = Mm256UnpackloPsx2<path>(r0, r1);
+ t4 = Mm256UnpackloPsx2<path>(r4, r5);
+ t2 = Mm256UnpackhiPsx2<path>(r0, r1);
+ t6 = Mm256UnpackhiPsx2<path>(r4, r5);
+ t1 = Mm256UnpackloPsx2<path>(r2, r3);
+ t5 = Mm256UnpackloPsx2<path>(r6, r7);
+ t3 = Mm256UnpackhiPsx2<path>(r2, r3);
+ t7 = Mm256UnpackhiPsx2<path>(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+ // and then by 8 bytes *within* lanes. The following set interleaves by 16
+ // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
+ // are interleaved to create (r0, r1). This complexity follows from the
+ // way that AVX is centered around 128-bit lanes.
+ r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
+ r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
+ r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
+ r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
+ r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
+ r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
+ r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
+ r7 = _mm256_permute2f128_ps(t3, t7, 0x31);
+
+ _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
+ _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
+ _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
+ _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
+ _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
+ _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
+ _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
+ _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
+ } else if (available_src_rows > 0) {
+ const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+ const __m256i row_mask_v = CompareGreaterThan<path>(
+ _mm256_set1_epi32(available_src_rows), series);
+
+ __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256 r0, r1, r2, r3, r4, r5, r6, r7;
+
+ t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
+ t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
+ t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
+ t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
+ t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
+ t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
+ t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
+ t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);
+
+ r0 = _mm256_unpacklo_ps(t0, t1);
+ r4 = _mm256_unpacklo_ps(t4, t5);
+ r2 = _mm256_unpackhi_ps(t0, t1);
+ r6 = _mm256_unpackhi_ps(t4, t5);
+ r1 = _mm256_unpacklo_ps(t2, t3);
+ r5 = _mm256_unpacklo_ps(t6, t7);
+ r3 = _mm256_unpackhi_ps(t2, t3);
+ r7 = _mm256_unpackhi_ps(t6, t7);
+
+ t0 = Mm256UnpackloPsx2<path>(r0, r1);
+ t4 = Mm256UnpackloPsx2<path>(r4, r5);
+ t2 = Mm256UnpackhiPsx2<path>(r0, r1);
+ t6 = Mm256UnpackhiPsx2<path>(r4, r5);
+ t1 = Mm256UnpackloPsx2<path>(r2, r3);
+ t5 = Mm256UnpackloPsx2<path>(r6, r7);
+ t3 = Mm256UnpackhiPsx2<path>(r2, r3);
+ t7 = Mm256UnpackhiPsx2<path>(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+ // and then by 8 bytes *within* lanes. The following set interleaves by 16
+ // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
+ // are interleaved to create (r0, r1). This complexity follows from the
+ // way that AVX is centered around 128-bit lanes.
+ r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
+ r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
+ r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
+ r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
+ r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
+ r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
+ r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
+ // r7 no longer needed.
+
+ _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
+ _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
+ _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
+ _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
+ _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
+ _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
+ _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
+ // No store to (trailing_buf + 7 * 8), space not allocated.
+ }
+
+ packed_ptr += kPackRows * kPackCols;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+}
+} // namespace ruy
+#endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)
+
+#endif // RUY_RUY_PACK_X86_H_
diff --git a/ruy/path.h b/ruy/path.h
new file mode 100644
index 0000000..d3c5b06
--- /dev/null
+++ b/ruy/path.h
@@ -0,0 +1,203 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PATH_H_
+#define RUY_RUY_PATH_H_
+
+#include <cstdint>
+
+#include "ruy/platform.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+// A Path is an implementation path, typically corresponding to a SIMD
+// instruction set being targeted. For example, on the ARM architecture,
+// Path::kNeon means using NEON instructions, and Path::kNeonDotprod means
+// also using the newer NEON dot-product instructions.
+//
+// Different Path enum values are defined on different CPU architectures,
+// corresponding to different SIMD ISA extensions available there.
+//
+// Path::kStandardCpp is the one Path that is always available.
+//
+// Path enum values are bits and may be OR-ed to form "sets of Paths".
+// Ruy entry points such as ruy::Mul either implicitly use such a set of Paths,
+// or allow passing an explicit one as a template parameter. The meaning of
+// such an OR-ed Path combination is "compile all of these paths; which path
+// is used will be determined at runtime". This is why
+// for most users, it is enough to call ruy::Mul(...), which will compile a
+// reasonable selection of paths for the target CPU architecture's various
+// SIMD ISA extensions, and let ruy determine at runtime which one to use.
+// Internally, after the actual path has been resolved, ruy's internal functions
+// templatized on a Path tend to require that to be a single bit.
+//
+// An element of ruy's internal design was to allow for code compiled for
+// multiple such paths to coexist without violating the C++ One Definition Rule
+// (ODR). This is achieved by having all ruy internal functions, whose
+// definition depends on a choice of Path, be templatized on a Path, so that
+// each path-specific specialization is a separate symbol. There is never
+// a need to compile ruy code with different compilation flags to enable
+// different SIMD extensions and dispatch at runtime between them, as this is
+// taken care of internally by ruy in an ODR-correct way.
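+//
+// For illustration, an explicit set of Paths can be passed as the template
+// argument of ruy::Mul, e.g.
+//   ruy::Mul<ruy::Path::kStandardCpp>(lhs, rhs, mul_params, &context, &dst);
+// whereas plain ruy::Mul(lhs, rhs, mul_params, &context, &dst) uses the
+// kDefaultPaths set defined below.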
+enum class Path : std::uint8_t {
+ // This is a special null value, representing the absence of any path.
+ kNone = 0,
+ // Standard C++ implementation of Ruy's architecture-specific parts.
+ //
+ // This is intended for testing/development, and as a fallback for when
+ // the SIMD ISA extensions required by other paths are unavailable at runtime.
+ kStandardCpp = 0x1,
+ // Internal, test-only variants of StandardCpp used to exercise more corners
+ // of internal ruy logic.
+ // They are intentionally omitted from ruy::kAllPaths and ruy::kNonArchPaths,
+ // and are only ever used in dedicated ruy tests explicitly referencing them.
+ kInternalStandardCppVariant1 = 0x2,
+ kInternalStandardCppVariant2 = 0x4,
+ kInternalStandardCppVariant3 = 0x8,
+
+#if RUY_PLATFORM_ARM
+ // Optimized path using a widely available subset of ARM NEON instructions.
+ kNeon = 0x10,
+ // Optimized path making use of ARM NEON dot product instructions that are
+ // available on newer ARM cores.
+ kNeonDotprod = 0x20,
+#endif // RUY_PLATFORM_ARM
+
+#if RUY_PLATFORM_X86
+ // Optimized for AVX.
+ // Compiled with -mavx.
+ kAvx = 0x10,
+ // Optimized for AVX2+FMA.
+ // Compiled with -mavx2 -mfma.
+ kAvx2Fma = 0x20,
+ // Optimized for AVX-512.
+ // Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq.
+ kAvx512 = 0x40,
+#endif // RUY_PLATFORM_X86
+};
+
+inline constexpr Path operator|(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) |
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator&(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) &
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator^(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) ^
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator~(Path p) {
+ return static_cast<Path>(~static_cast<std::uint32_t>(p));
+}
+
+inline constexpr bool Disjoint(Path p, Path q) {
+ return (p & q) == Path::kNone;
+}
+
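+// Returns the single Path corresponding to the most significant bit set in
+// path_mask. Worked example using the always-available values:
+// GetMostSignificantPath(Path::kStandardCpp | Path::kInternalStandardCppVariant1)
+// yields Path::kInternalStandardCppVariant1, since round_down_pot rounds its
+// argument down to a power of two.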
+inline Path GetMostSignificantPath(Path path_mask) {
+ return static_cast<Path>(round_down_pot(static_cast<int>(path_mask)));
+}
+
+// We define three disjoint sets of paths.
+//
+// kNonArchPaths is the set of paths that are defined regardless of
+// the CPU architecture (excluding some internal test-only paths).
+// These paths are slow, but portable. At the moment,
+// that is only kStandardCpp. In the past, that used to also include a
+// kReference path providing an even more basic implementation, but that has
+// been split out into a separate library, see the ReferenceMul function.
+constexpr Path kNonArchPaths = Path::kStandardCpp;
+
+// The other two are specific to each CPU architecture. Note that these sets
+// do NOT include a fallback for when none of these architecture paths are
+// supported at runtime by the CPU. For that, see the other constants defined
+// further below.
+//
+// kDefaultArchPaths is the set of architecture-specific paths that
+// we recommend for most users. It is part of kDefaultPaths defined
+// below.
+//
+// kExtraArchPaths is the set of all other architecture-specific paths
+// that for whatever reason we're not recommending to most users at the moment.
+// Typically that would include work-in-progress paths, or paths targeting
+// minority hardware that isn't the best compromise of code size to performance
+// for most users.
+
+#if RUY_PLATFORM_NEON_64
+constexpr Path kDefaultArchPaths = Path::kNeon | Path::kNeonDotprod;
+constexpr Path kExtraArchPaths = Path::kNone;
+#elif RUY_PLATFORM_NEON_32
+constexpr Path kDefaultArchPaths = Path::kNeon;
+constexpr Path kExtraArchPaths = Path::kNone;
+#elif RUY_PLATFORM_X86
+constexpr Path kDefaultArchPaths = Path::kAvx | Path::kAvx2Fma | Path::kAvx512;
+constexpr Path kExtraArchPaths = Path::kNone;
+#else
+constexpr Path kDefaultArchPaths = Path::kNone;
+constexpr Path kExtraArchPaths = Path::kNone;
+#endif
+
+// kNonArchPathsIncludingInternalVariants is the set of all
+// non-architecture-specific paths without exception. This includes some paths
+// that are internal-only and test-only and not useful to any user.
+static constexpr Path kNonArchPathsIncludingInternalVariants =
+ kNonArchPaths | Path::kInternalStandardCppVariant1 |
+ Path::kInternalStandardCppVariant2 | Path::kInternalStandardCppVariant3;
+
+// Enforce that kDefaultArchPaths, kExtraArchPaths and
+// kNonArchPathsIncludingInternalVariants are mutually disjoint,
+// and that kNonArchPaths is a subset of kNonArchPathsIncludingInternalVariants.
+static_assert(Disjoint(kDefaultArchPaths, kExtraArchPaths), "");
+static_assert(Disjoint(kDefaultArchPaths,
+ kNonArchPathsIncludingInternalVariants),
+ "");
+static_assert(Disjoint(kExtraArchPaths, kNonArchPathsIncludingInternalVariants),
+ "");
+static_assert(Disjoint(kNonArchPaths, ~kNonArchPathsIncludingInternalVariants),
+ "");
+
+// We now define two aggregate sets of paths for convenience, including
+// both architecture-specific paths and some portable fallbacks.
+//
+// kDefaultPaths is the set of paths that we recommend most users to use.
+// It is what ruy::Mul(...), the entry point not taking an explicit Path value,
+// uses.
+constexpr Path kDefaultPaths = Path::kStandardCpp | kDefaultArchPaths;
+
+// kAllPaths is the set of all paths that are available to compile, except
+// some internal test-only paths that no user would ever want to use.
+// In addition to the Default paths, it also includes the extra
+// architecture paths, as well as any other non-arch path besides kStandardCpp
+// (there is none at the moment).
+constexpr Path kAllPaths = kNonArchPaths | kDefaultArchPaths | kExtraArchPaths;
+
+// kAllPathsIncludingInternalVariants is the set of all paths without exception.
+// This includes some paths that are internal-only and test-only and not useful
+// to any user.
+static constexpr Path kAllPathsIncludingInternalVariants =
+ kAllPaths | kNonArchPathsIncludingInternalVariants;
+
+static_assert(Disjoint(kDefaultPaths, ~kAllPaths), "");
+static_assert(Disjoint(kAllPaths, ~kAllPathsIncludingInternalVariants), "");
+
+} // namespace ruy
+
+#endif // RUY_RUY_PATH_H_
diff --git a/ruy/perchannel_buffers_reallocation_test.cc b/ruy/perchannel_buffers_reallocation_test.cc
new file mode 100644
index 0000000..0754aef
--- /dev/null
+++ b/ruy/perchannel_buffers_reallocation_test.cc
@@ -0,0 +1,120 @@
+#include "ruy/context.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/kernel.h"
+#include "ruy/matrix.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/ruy.h"
+
+namespace ruy {
+namespace {
+
+constexpr Path kPath = Path::kInternalStandardCppVariant3;
+constexpr int kBufferSize = 64;
+
+template <typename AccumScalar, typename DstScalar,
+ bool HaveQuantizedMultipliers =
+ std::is_same<AccumScalar, std::int32_t>::value &&
+ !std::is_same<DstScalar, std::int32_t>::value>
+struct PopulatePerChannelBuffersImpl {
+ static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
+ static const AccumScalar bias_buf[kBufferSize] = {0};
+ static const AccumScalar multiplier_fixedpoint_buf[kBufferSize] = {0};
+ static const int multiplier_exponent_buf[kBufferSize] = {0};
+ mul_params->set_bias(bias_buf);
+ mul_params->set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint_buf);
+ mul_params->set_multiplier_exponent_perchannel(multiplier_exponent_buf);
+ }
+};
+
+template <typename AccumScalar, typename DstScalar>
+struct PopulatePerChannelBuffersImpl<AccumScalar, DstScalar, false> {
+ static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
+ static const AccumScalar bias_buf[kBufferSize] = {0};
+ mul_params->set_bias(bias_buf);
+ }
+};
+
+template <typename AccumScalar, typename DstScalar>
+void PopulatePerChannelBuffers(MulParams<AccumScalar, DstScalar>* mul_params) {
+ PopulatePerChannelBuffersImpl<AccumScalar, DstScalar>::Run(mul_params);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestPerChannelBuffersReallocation() {
+ using KernelType = Kernel<kPath, float, float, float, float>;
+
+ MulParams<AccumScalar, DstScalar> mul_params;
+ PopulatePerChannelBuffers(&mul_params);
+
+ const int kMatrixSize = 3;
+ ruy::Matrix<LhsScalar> lhs;
+ ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kRowMajor,
+ lhs.mutable_layout());
+ const LhsScalar lhs_data[kMatrixSize * kMatrixSize] = {0};
+ lhs.set_data(lhs_data);
+ ruy::Matrix<RhsScalar> rhs;
+ ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
+ rhs.mutable_layout());
+ const RhsScalar rhs_data[kMatrixSize * kMatrixSize] = {0};
+ rhs.set_data(rhs_data);
+ DstScalar dst_data[kMatrixSize * kMatrixSize] = {0};
+ ruy::Matrix<DstScalar> dst;
+ ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
+ dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ ruy::Context context;
+
+ auto test_advisory = [&](bool expect_advisory,
+ ChannelDimension channel_dimension,
+ int capacity_rounding) {
+ mul_params.set_channel_dimension(channel_dimension);
+ mul_params.set_perchannel_buffers_capacity_rounding(capacity_rounding);
+ ruy::Mul<kPath>(lhs, rhs, mul_params, &context, &dst);
+ EXPECT_EQ(context.performance_advisory(
+ PerformanceAdvisory::kReallocatedPerChannelBuffer),
+ expect_advisory);
+ };
+
+ static_assert(KernelType::LhsLayout::kCols == 16, "");
+ test_advisory(true, ChannelDimension::kRow, 1);
+ test_advisory(true, ChannelDimension::kRow, 2);
+ test_advisory(true, ChannelDimension::kRow, 4);
+ test_advisory(true, ChannelDimension::kRow, 8);
+ test_advisory(false, ChannelDimension::kRow, 16);
+ test_advisory(false, ChannelDimension::kRow, 32);
+ test_advisory(false, ChannelDimension::kRow, 64);
+
+ static_assert(KernelType::RhsLayout::kCols == 8, "");
+ test_advisory(true, ChannelDimension::kCol, 1);
+ test_advisory(true, ChannelDimension::kCol, 2);
+ test_advisory(true, ChannelDimension::kCol, 4);
+ test_advisory(false, ChannelDimension::kCol, 8);
+ test_advisory(false, ChannelDimension::kCol, 16);
+ test_advisory(false, ChannelDimension::kCol, 32);
+ test_advisory(false, ChannelDimension::kCol, 64);
+}
+
+TEST(PerChannelBuffersReallocationTest, Float) {
+ TestPerChannelBuffersReallocation<float, float, float, float>();
+}
+
+TEST(PerChannelBuffersReallocationTest, Quantized) {
+ TestPerChannelBuffersReallocation<std::int8_t, std::int8_t, std::int32_t,
+ std::int8_t>();
+}
+
+TEST(PerChannelBuffersReallocationTest, RawInt32) {
+ TestPerChannelBuffersReallocation<std::int8_t, std::int8_t, std::int32_t,
+ std::int32_t>();
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/performance_advisory.h b/ruy/performance_advisory.h
new file mode 100644
index 0000000..02dd5d3
--- /dev/null
+++ b/ruy/performance_advisory.h
@@ -0,0 +1,40 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PERFORMANCE_ADVISORY_H_
+#define RUY_RUY_PERFORMANCE_ADVISORY_H_
+
+namespace ruy {
+
+enum class PerformanceAdvisory {
+ kNone = 0,
+ kReallocatedPerChannelBuffer = 0x1
+};
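+
+// Advisory values are bits and are queried through
+// Context::performance_advisory, e.g. (sketch mirroring the usage in ruy's
+// tests):
+//   if (context.performance_advisory(
+//           PerformanceAdvisory::kReallocatedPerChannelBuffer)) {
+//     // The per-channel buffers were reallocated during this Mul call.
+//   }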
+
+inline constexpr PerformanceAdvisory operator|(PerformanceAdvisory p,
+ PerformanceAdvisory q) {
+ return static_cast<PerformanceAdvisory>(static_cast<int>(p) |
+ static_cast<int>(q));
+}
+
+inline constexpr PerformanceAdvisory operator&(PerformanceAdvisory p,
+ PerformanceAdvisory q) {
+ return static_cast<PerformanceAdvisory>(static_cast<int>(p) &
+ static_cast<int>(q));
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_PERFORMANCE_ADVISORY_H_
diff --git a/ruy/platform.h b/ruy/platform.h
new file mode 100644
index 0000000..eb51931
--- /dev/null
+++ b/ruy/platform.h
@@ -0,0 +1,159 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Preprocessor platform check macros.
+// Note that ruy_copts contains '-Wundef', which ensures that we get a compile
+// error if these macros are mistyped or if they are used without having
+// #included this header.
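+//
+// Each RUY_PLATFORM_* macro below is always defined, to either 1 or 0, so the
+// intended usage is '#if', not '#ifdef', e.g. (illustrative):
+//   #if RUY_PLATFORM_NEON
+//   // NEON-specific code.
+//   #endif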
+
+#ifndef RUY_RUY_PLATFORM_H_
+#define RUY_RUY_PLATFORM_H_
+
+#ifdef __ANDROID_NDK__
+#include <android/ndk-version.h>
+#endif
+
+// Detect APPLE.
+#ifdef __APPLE__
+#define RUY_PLATFORM_APPLE 1
+#else
+#define RUY_PLATFORM_APPLE 0
+#endif
+
+// Detect PPC.
+#if defined(__ppc__) || defined(__powerpc__)
+#define RUY_PLATFORM_PPC 1
+#else
+#define RUY_PLATFORM_PPC 0
+#endif
+
+// Detect Fuchsia
+#ifdef __Fuchsia__
+#define RUY_PLATFORM_FUCHSIA 1
+#else
+#define RUY_PLATFORM_FUCHSIA 0
+#endif
+
+// Architecture-level platform detection.
+//
+// Ruy requires these to be mutually exclusive.
+
+// Detect x86.
+#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \
+ defined(__x86__) || defined(__X86__) || defined(_X86_) || \
+ defined(_M_IX86) || defined(_M_X64)
+#define RUY_PLATFORM_X86 1
+#else
+#define RUY_PLATFORM_X86 0
+#endif
+
+// Detect ARM 32-bit.
+#ifdef __arm__
+#define RUY_PLATFORM_ARM_32 1
+#else
+#define RUY_PLATFORM_ARM_32 0
+#endif
+
+// Detect ARM 64-bit.
+#ifdef __aarch64__
+#define RUY_PLATFORM_ARM_64 1
+#else
+#define RUY_PLATFORM_ARM_64 0
+#endif
+
+// Combined ARM.
+#define RUY_PLATFORM_ARM (RUY_PLATFORM_ARM_64 || RUY_PLATFORM_ARM_32)
+
+// Feature and capability platform detection.
+//
+// These are mostly sub-selections of architectures.
+
+// Detect NEON. Explicitly avoid emulation, or anything like it, on x86.
+#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM_X86
+#define RUY_PLATFORM_NEON 1
+#else
+#define RUY_PLATFORM_NEON 0
+#endif
+
+// Define ARM 32-bit NEON.
+#define RUY_PLATFORM_NEON_32 (RUY_PLATFORM_NEON && RUY_PLATFORM_ARM_32)
+
+// Define ARM 64-bit NEON.
+// Note: NEON is implied by ARM64, so this define is redundant. It is still
+// useful to convey intent.
+#define RUY_PLATFORM_NEON_64 (RUY_PLATFORM_NEON && RUY_PLATFORM_ARM_64)
+
+// Determine whether to enable X86 non-portable performance improvements,
+// typically x86 SIMD paths (AVX, etc).
+#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS)
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__EMSCRIPTEN__)
+// We use some x86 asm e.g. in runtime CPU detection and to implement missing
+// intrinsics. This can't build to Emscripten.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 0
+#elif defined(__ANDROID_NDK__) && defined(__NDK_MAJOR__) && \
+ (__NDK_MAJOR__ >= 20)
+// Enable on sufficiently recent Android NDK. Earlier versions had broken
+// intrinsics headers.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__linux__) && defined(__clang__) && (__clang_major__ >= 8)
+// Enable on recent versions of Clang on Linux. Might be possible
+// to relax this version requirement.
+// Not enabling on Apple at the moment because b/138922878, see comment #8, we
+// may only need to disable this on XCode <= 10.2.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__GNUC__) && (__GNUC__ >= 9)
+// Enable on recent versions of GCC. Might be possible
+// to relax this version requirement.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+// Things are working on MSVC 2019. This should also enable on sufficiently
+// recent Clang-CL.
+#elif defined(_MSC_VER) && (_MSC_VER >= 1920)
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#else
+#define RUY_PLATFORM_X86_ENHANCEMENTS 0
+#endif
+
+// These CPU capabilities will all be true when Skylake, etc, are enabled during
+// compilation.
+#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && \
+ defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \
+ defined(__AVX512BW__) && defined(__AVX512VL__)
+#define RUY_PLATFORM_AVX512 1
+#else
+#define RUY_PLATFORM_AVX512 0
+#endif
+
+#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__) && \
+ (defined(__FMA__) || defined(_MSC_VER))
+#define RUY_PLATFORM_AVX2_FMA 1
+#else
+#define RUY_PLATFORM_AVX2_FMA 0
+#endif
+
+#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX__)
+#define RUY_PLATFORM_AVX 1
+#else
+#define RUY_PLATFORM_AVX 0
+#endif
+
+// Detect Emscripten, typically Wasm.
+#ifdef __EMSCRIPTEN__
+#define RUY_PLATFORM_EMSCRIPTEN 1
+#else
+#define RUY_PLATFORM_EMSCRIPTEN 0
+#endif
+
+#endif // RUY_RUY_PLATFORM_H_
diff --git a/ruy/pmu.cc b/ruy/pmu.cc
new file mode 100644
index 0000000..d1a60be
--- /dev/null
+++ b/ruy/pmu.cc
@@ -0,0 +1,297 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/pmu.h"
+
+#include "ruy/check_macros.h"
+
+#ifdef __linux__
+#include <asm/unistd.h>
+#include <linux/perf_event.h>
+#include <sys/ioctl.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <cstdio>
+#endif
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+namespace ruy {
+
+// Linux-specific. Not ARM-specific.
+#ifdef __linux__
+class PerfEvent {
+ public:
+ PerfEvent(std::uint32_t type, std::uint64_t config) {
+ perf_event_attr pe;
+ memset(&pe, 0, sizeof(pe));
+ pe.size = sizeof(pe);
+ pe.type = type;
+ pe.config = config;
+ pe.disabled = 1;
+ pe.exclude_kernel = 1;
+ pe.exclude_hv = 1;
+ pe.inherit = 1;
+ fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0);
+ if (fd_ == -1) {
+ fprintf(stderr, "perf_event_open failed for config 0x%lx\n",
+ static_cast<unsigned long>(config));
+ // abort();
+ }
+ }
+
+ ~PerfEvent() {
+ RUY_CHECK(!started_);
+ close(fd_);
+ }
+
+ void Start() {
+ RUY_CHECK(!started_);
+ started_ = true;
+ ioctl(fd_, PERF_EVENT_IOC_RESET, 0);
+ ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0);
+ count_at_start_ = Read();
+ }
+
+ void Stop() {
+ RUY_CHECK(started_);
+ started_ = false;
+ ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0);
+ count_at_stop_ = Read();
+ }
+
+ std::int64_t Count() const {
+ RUY_CHECK(!started_);
+ return count_at_stop_ - count_at_start_;
+ }
+
+ private:
+ std::int64_t Read() const {
+ std::int64_t count;
+ RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1);
+ return count;
+ }
+ std::int64_t count_at_start_ = -1;
+ std::int64_t count_at_stop_ = -1;
+ bool started_ = false;
+ int fd_ = -1;
+};
+#else
+// Placeholder implementation to at least compile outside of Linux.
+#define PERF_TYPE_RAW 0
+class PerfEvent {
+ public:
+ PerfEvent(std::uint32_t, std::uint64_t) {}
+ ~PerfEvent() {}
+ void Start() {}
+ void Stop() {}
+ std::int64_t Count() const { return 0; }
+};
+#endif
+
+// ARM-specific. Query ARM PMU counters as Linux perf events using
+// PERF_TYPE_RAW.
+namespace arm_pmuv3 {
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-const-variable"
+
+// These event numbers are listed in the ARMv8 architecture reference manual.
+constexpr std::uint16_t L1I_CACHE_REFILL = 0x01;
+constexpr std::uint16_t L1I_TLB_REFILL = 0x02;
+constexpr std::uint16_t L1D_CACHE_REFILL = 0x03;
+constexpr std::uint16_t L1D_CACHE = 0x04;
+constexpr std::uint16_t L1D_TLB_REFILL = 0x05;
+constexpr std::uint16_t LD_RETIRED = 0x06;
+constexpr std::uint16_t ST_RETIRED = 0x07;
+constexpr std::uint16_t INST_RETIRED = 0x08;
+constexpr std::uint16_t EXC_TAKEN = 0x09;
+constexpr std::uint16_t EXC_RETURN = 0x0A;
+constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B;
+constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C;
+constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D;
+constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E;
+constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F;
+constexpr std::uint16_t BR_MIS_PRED = 0x10;
+constexpr std::uint16_t CPU_CYCLES = 0x11;
+constexpr std::uint16_t BR_PRED = 0x12;
+constexpr std::uint16_t MEM_ACCESS = 0x13;
+constexpr std::uint16_t L1I_CACHE = 0x14;
+constexpr std::uint16_t L1D_CACHE_WB = 0x15;
+constexpr std::uint16_t L2D_CACHE = 0x16;
+constexpr std::uint16_t L2D_CACHE_REFILL = 0x17;
+constexpr std::uint16_t L2D_CACHE_WB = 0x18;
+constexpr std::uint16_t BUS_ACCESS = 0x19;
+constexpr std::uint16_t MEMORY_ERROR = 0x1A;
+constexpr std::uint16_t INST_SPEC = 0x1B;
+constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C;
+constexpr std::uint16_t BUS_CYCLES = 0x1D;
+constexpr std::uint16_t CHAIN = 0x1E;
+constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F;
+constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20;
+constexpr std::uint16_t BR_RETIRED = 0x21;
+constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22;
+constexpr std::uint16_t STALL_FRONTEND = 0x23;
+constexpr std::uint16_t STALL_BACKEND = 0x24;
+constexpr std::uint16_t L1D_TLB = 0x25;
+constexpr std::uint16_t L1I_TLB = 0x26;
+constexpr std::uint16_t L2I_CACHE = 0x27;
+constexpr std::uint16_t L2I_CACHE_REFILL = 0x28;
+constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29;
+constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A;
+constexpr std::uint16_t L3D_CACHE = 0x2B;
+constexpr std::uint16_t L3D_CACHE_WB = 0x2C;
+constexpr std::uint16_t L2D_TLB_REFILL = 0x2D;
+constexpr std::uint16_t L2I_TLB_REFILL = 0x2E;
+constexpr std::uint16_t L2D_TLB = 0x2F;
+constexpr std::uint16_t L2I_TLB = 0x30;
+constexpr std::uint16_t LL_CACHE = 0x32;
+constexpr std::uint16_t LL_CACHE_MISS = 0x33;
+constexpr std::uint16_t DTLB_WALK = 0x34;
+constexpr std::uint16_t LL_CACHE_RD = 0x36;
+constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37;
+
+// Additional implementation-defined events found by googling around.
+constexpr std::uint16_t L1D_CACHE_RD = 0x40;
+constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42;
+constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C;
+constexpr std::uint16_t L1D_TLB_RD = 0x4E;
+constexpr std::uint16_t L2D_CACHE_RD = 0x50;
+constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52;
+constexpr std::uint16_t BUS_ACCESS_RD = 0x60;
+constexpr std::uint16_t MEM_ACCESS_RD = 0x66;
+constexpr std::uint16_t L3D_CACHE_RD = 0xA0;
+constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2;
+
+#pragma GCC diagnostic pop
+
+} // namespace arm_pmuv3
+
+class PmuEventsPrivate {
+ public:
+ PmuEventsPrivate()
+ : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL),
+ l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL),
+ l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL),
+ ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS),
+ l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL),
+ l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL),
+ stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND),
+ stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND),
+ br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED),
+ l1d_cache_writeback(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_WB),
+ l2d_cache_writeback(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_WB) {}
+
+ private:
+ friend class PmuEvents;
+ PerfEvent l1d_cache_refill;
+ PerfEvent l2d_cache_refill;
+ PerfEvent l3d_cache_refill;
+ PerfEvent ll_cache_miss;
+ PerfEvent l1d_tlb_refill;
+ PerfEvent l2d_tlb_refill;
+ PerfEvent stall_frontend;
+ PerfEvent stall_backend;
+ PerfEvent br_mis_pred;
+ PerfEvent l1d_cache_writeback;
+ PerfEvent l2d_cache_writeback;
+};
+
+PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {}
+PmuEvents::~PmuEvents() { delete priv; }
+
+void PmuEvents::StartRecording() {
+ priv->l1d_cache_refill.Start();
+ priv->l2d_cache_refill.Start();
+ priv->l3d_cache_refill.Start();
+ priv->ll_cache_miss.Start();
+ priv->l1d_tlb_refill.Start();
+ priv->l2d_tlb_refill.Start();
+ priv->stall_frontend.Start();
+ priv->stall_backend.Start();
+ priv->br_mis_pred.Start();
+ priv->l1d_cache_writeback.Start();
+ priv->l2d_cache_writeback.Start();
+}
+
+void PmuEvents::StopRecording() {
+ priv->l1d_cache_refill.Stop();
+ priv->l2d_cache_refill.Stop();
+ priv->l3d_cache_refill.Stop();
+ priv->ll_cache_miss.Stop();
+ priv->l1d_tlb_refill.Stop();
+ priv->l2d_tlb_refill.Stop();
+ priv->stall_frontend.Stop();
+ priv->stall_backend.Stop();
+ priv->br_mis_pred.Stop();
+ priv->l1d_cache_writeback.Stop();
+ priv->l2d_cache_writeback.Stop();
+}
+
+float PmuEvents::BranchMispredictionCount() const {
+ return static_cast<float>(priv->br_mis_pred.Count());
+}
+
+float PmuEvents::FrontendStallCount() const {
+ return static_cast<float>(priv->stall_frontend.Count());
+}
+
+float PmuEvents::BackendStallCount() const {
+ return static_cast<float>(priv->stall_backend.Count());
+}
+
+float PmuEvents::L1RefillCount() const {
+ return static_cast<float>(priv->l1d_cache_refill.Count());
+}
+
+float PmuEvents::L2RefillCount() const {
+ return static_cast<float>(priv->l2d_cache_refill.Count());
+}
+
+float PmuEvents::L3RefillCount() const {
+ // Some CPUs implement LL_CACHE_MISS[_RD], some implement
+ // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is
+ // zero, or they roughly agree with each other. Therefore, taking the max of
+ // them is a reasonable way to get something more portable across various
+ // CPUs.
+ //
+ // Important: this behavior was observed in experiments that also tested the
+ // _RD variants of these counters, so it's possible that it's just not needed
+ // here with the default (non _RD) counters.
+ return static_cast<float>(
+ std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count()));
+}
+
+float PmuEvents::L1TLBRefillCount() const {
+ return static_cast<float>(priv->l1d_tlb_refill.Count());
+}
+
+float PmuEvents::L2TLBRefillCount() const {
+ return static_cast<float>(priv->l2d_tlb_refill.Count());
+}
+
+float PmuEvents::L1WritebackCount() const {
+ return static_cast<float>(priv->l1d_cache_writeback.Count());
+}
+
+float PmuEvents::L2WritebackCount() const {
+ return static_cast<float>(priv->l2d_cache_writeback.Count());
+}
+
+} // namespace ruy
diff --git a/ruy/pmu.h b/ruy/pmu.h
new file mode 100644
index 0000000..2d769e2
--- /dev/null
+++ b/ruy/pmu.h
@@ -0,0 +1,46 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PMU_H_
+#define RUY_RUY_PMU_H_
+
+namespace ruy {
+
+class PmuEventsPrivate;
+
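+// Thin wrapper around a fixed set of CPU performance counters, implemented in
+// pmu.cc as Linux perf events reading the ARM PMU, and as no-op placeholders
+// elsewhere. Intended usage, sketched from the interface below:
+//   PmuEvents pmu;
+//   pmu.StartRecording();
+//   // ... run the code being measured ...
+//   pmu.StopRecording();
+//   float l1_refills = pmu.L1RefillCount();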
+class PmuEvents {
+ public:
+ PmuEvents();
+ ~PmuEvents();
+ void StartRecording();
+ void StopRecording();
+ float L1RefillCount() const;
+ float L2RefillCount() const;
+ float L3RefillCount() const;
+ float BranchMispredictionCount() const;
+ float FrontendStallCount() const;
+ float BackendStallCount() const;
+ float L1TLBRefillCount() const;
+ float L2TLBRefillCount() const;
+ float L1WritebackCount() const;
+ float L2WritebackCount() const;
+
+ private:
+ PmuEventsPrivate* priv = nullptr;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_PMU_H_
diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc
new file mode 100644
index 0000000..ee891cb
--- /dev/null
+++ b/ruy/prepacked_cache.cc
@@ -0,0 +1,129 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/prepacked_cache.h"
+
+#include "ruy/mat.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/system_aligned_alloc.h"
+
+namespace ruy {
+
+namespace {
+
+// Allocates the `data` and `sums` buffers, and sets the corresponding
+// pointer fields, in a PEMat whose other fields, particularly `layout`
+// and the runtime data types, are already populated.
+int AllocateBuffers(PEMat* packed_matrix) {
+ const int data_bytes = DataBytes(*packed_matrix);
+ packed_matrix->data = detail::SystemAlignedAlloc(data_bytes);
+ int sums_bytes = 0;
+ if (!packed_matrix->sums_type.is_floating_point) {
+ // Integer quantized matrices also need the `sums` buffer.
+ sums_bytes = SumsBytes(*packed_matrix);
+ packed_matrix->sums = detail::SystemAlignedAlloc(sums_bytes);
+ }
+ return data_bytes + sums_bytes;
+}
+
+// Frees the `data` and `sums` buffers held by a PEMat.
+void FreeBuffers(const PEMat& packed_matrix) {
+ detail::SystemAlignedFree(packed_matrix.data);
+ detail::SystemAlignedFree(packed_matrix.sums);
+}
+
+} // end anonymous namespace
+
+std::size_t PrepackedCache::KeyHash::operator()(
+ const PrepackedCache::Key& key) const {
+ std::size_t src_data_hash = reinterpret_cast<std::size_t>(key.src_data);
+ // Naive hash of the layout. Based on some heuristic reasoning, not any
+ // benchmarking.
+ // A choice of hash function here is just an optimization matter
+ // anyway, since a hash collision only results in some Key::operator== calls
+ // to disambiguate, and even just returning src_data_hash, ignoring the layout
+ // altogether, would probably be good enough, as the case of multiple entries
+ // with the same data pointer will be uncommon.
+ // Here we multiply-add the layout fields using some small constant prime
+ // numbers as multipliers. The conventional approach of xor-ing bit-rotations
+ // would result in some hash collisions because these values are typically
+ // small positive integers, so bit-rotations are essentially bit-shifts,
+ // and powers of two are common.
+ std::size_t packed_layout_hash =
+ static_cast<int>(key.packed_layout.order) +
+ static_cast<int>(key.packed_layout.kernel.order) * 2 +
+ key.packed_layout.stride * 3 + key.packed_layout.kernel.rows * 5 +
+ key.packed_layout.kernel.cols * 7 + key.packed_layout.rows * 11 +
+ key.packed_layout.cols * 13;
+ return src_data_hash ^ packed_layout_hash;
+}
+
+PrepackedCache::~PrepackedCache() {
+ for (const auto& pair : cache_) {
+ FreeBuffers(pair.second.packed_matrix);
+ }
+}
+
+PrepackedCache::Action PrepackedCache::Get(const void* src_data,
+ PEMat* packed_matrix) {
+ // Construct a Key and look up the cache.
+ Key key;
+ key.src_data = src_data;
+ key.packed_layout = packed_matrix->layout;
+ key.zero_point = packed_matrix->zero_point;
+ const auto& itr = cache_.find(key);
+
+ if (itr != cache_.end()) {
+ // Found existing entry. Update its timestamp and return it.
+ itr->second.timestamp = timestamp_++;
+ *packed_matrix = itr->second.packed_matrix;
+ return Action::kGotExistingEntry;
+ }
+
+ // No existing entry found. Allocate new buffers now and insert in the cache.
+ const int new_bytes = AllocateBuffers(packed_matrix);
+ EjectUntilRoomFor(new_bytes);
+ Entry entry{*packed_matrix, timestamp_++};
+ cache_.emplace(key, entry);
+ buffers_bytes_ += new_bytes;
+ return Action::kInsertedNewEntry;
+}
+
+void PrepackedCache::EjectUntilRoomFor(int new_bytes) {
+ profiler::ScopeLabel label("PrepackedCacheEjection");
+ // While we are above the threshold of ejection, eject the LRU entry.
+ while (!cache_.empty() && buffers_bytes_ + new_bytes > max_buffers_bytes_) {
+ EjectOne();
+ }
+}
+
+void PrepackedCache::EjectOne() {
+ auto oldest = cache_.begin();
+ Timestamp oldest_timestamp = oldest->second.timestamp;
+ {
+ for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) {
+ if (itr->second.timestamp < oldest_timestamp) {
+ oldest = itr;
+ oldest_timestamp = itr->second.timestamp;
+ }
+ }
+ }
+ const PEMat& packed_matrix = oldest->second.packed_matrix;
+ buffers_bytes_ -= DataBytes(packed_matrix) + SumsBytes(packed_matrix);
+ FreeBuffers(packed_matrix);
+ cache_.erase(oldest);
+}
+
+} // namespace ruy
diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h
new file mode 100644
index 0000000..cb3a113
--- /dev/null
+++ b/ruy/prepacked_cache.h
@@ -0,0 +1,141 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PREPACKED_CACHE_H_
+#define RUY_RUY_PREPACKED_CACHE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <unordered_map>
+
+#include "ruy/mat.h"
+
+namespace ruy {
+
+// "Low effort" Least Recently Used Cache for Prepacked Matrices
+// A cache mechanism for prepacked matrices that ejects oldest entries.
+// The implementation is "low effort" in the following ways:
+// - we just linearly search for the oldest entry when doing an ejection
+// - the ejection policy is very simple: if the new size would be above the
+//   threshold, we will eject entries until the size is below the threshold.
+// Current use cases (RNNs with GEMV operations) indicate that ejection is rare
+// and memory constraints are tight, so we devote no additional storage to the
+// LRU mechanism and accept an O(n) search to eject the oldest entry. In
+// practice, the number of total entries has not been shown to be large.
+//
+// An instance of PrepackedCache is always owned by a Context. Just like
+// Context, no "thread safety" consideration is applicable to this class:
+// no two threads may simultaneously be accessing the same instance.
+class PrepackedCache final {
+ public:
+ enum class Action { kGotExistingEntry, kInsertedNewEntry };
+
+ static constexpr int kDefaultMaxBuffersBytes = 1 << 28;
+
+ // A key in this key-value cache. Equality of keys implies interchangeable
+ // packed matrices, so we must be careful to make this Key type specific
+ // enough, and its equality comparison operator strict enough.
+ //
+  // These keys need to be used before the packed matrix buffers are allocated
+  // (since they are used to determine whether to allocate a new buffer), so
+  // they use the source matrix's data pointer instead. On the other hand,
+ // the packed matrix's layout structure is already available by the time we
+ // need to handle Keys, and that's fortunate because it is more specific
+ // than the source matrix's layout: it also encodes details about the kernel's
+ // small-scale block layout. In the past, we made the kernel function pointer
+ // part of the cache key, but all that is relevant here is the packed layout.
+ //
+ // The only other field that needs to be involved is the zero_point, for
+ // quantized matrices, although it seems far-fetched that the same matrix
+ // data would be reused with different zero_point values.
+ //
+ // The data types (PEMat::data_type and PEMat::sums_type) are omitted based on
+ // the "strict aliasing" model: each memory location should contain data of
+ // only one type. This could be relaxed in the future, by adding data types
+ // to this Key type, if a use case arises.
+ struct Key {
+ // The source matrix's data pointer.
+ const void* src_data;
+ // The packed matrix's layout, see PEMat::layout.
+ PMatLayout packed_layout;
+ // The packed matrix's zero point (for integer-quantized matrices only).
+ std::int32_t zero_point;
+ };
+
+ friend bool operator==(const Key& a, const Key& b) {
+ return a.src_data == b.src_data && a.packed_layout == b.packed_layout &&
+ a.zero_point == b.zero_point;
+ }
+
+ struct KeyHash {
+ std::size_t operator()(const Key&) const;
+ };
+
+  // A dummy timestamp to associate with each entry (see struct Entry) to
+ // determine which entry is "least recently used" when ejecting entries.
+ // This is just an integer counter, not related to physical time.
+ // It does not need to be atomic because only one thread uses an instance
+ // of PrepackedCache at a time (see class comment).
+ using Timestamp = std::uint64_t;
+
+ struct Entry {
+ PEMat packed_matrix;
+ Timestamp timestamp;
+ };
+
+ explicit PrepackedCache(int max_buffers_bytes = kDefaultMaxBuffersBytes)
+ : max_buffers_bytes_(max_buffers_bytes) {}
+
+ ~PrepackedCache();
+
+ // Returns the total size in bytes of buffers held in this cache.
+ int BuffersBytes() const { return buffers_bytes_; }
+
+ // Returns the number of packed matrices held in this cache.
+ int MatrixCount() const { return cache_.size(); }
+
+ // This is the method by which new matrices are cached, and existing cache
+ // entries are queried.
+ // `src_data` is the source matrix data pointer.
+ // `packed_matrix` is a packed matrix structure where all fields have already
+ // been populated, except for the `data` and `sums` pointers which have not
+ // yet been allocated.
+ //
+ // This method:
+ // 1. Queries the cache for an entry matching the given `src_data` pointer and
+ // the relevant fields of `packed_matrix`, particularly its `layout`.
+ // 2. If a matching cache entry does not exist, it is created and inserted
+ // into the cache, and its `data` and `sums` buffers are allocated.
+ // 3. The `packed_matrix` has its `data` and `sums` pointers set to point
+ // to the allocated buffers.
+ // 4. The cache entry's timestamp is updated so it's the most recently used
+ // entry.
+ // 5. The return value is Action::kInsertedNewEntry if at step 2 a new
+ // entry was created. Otherwise it is Action::kGotExistingEntry.
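+  //
+  // A minimal usage sketch (mirroring PreparePackedMatrices in
+  // prepare_packed_matrices.cc; the names used here are illustrative only):
+  //
+  //   PrepackedCache* cache = ctx->GetPrepackedCache();
+  //   if (cache->Get(src_data, &packed_matrix) ==
+  //       PrepackedCache::Action::kInsertedNewEntry) {
+  //     // Freshly allocated buffers: pack the source data into them now.
+  //     RunPack(...);
+  //   }
+  //   // Otherwise, `packed_matrix` already points to previously packed data.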
+ Action Get(const void* src_data, PEMat* packed_matrix);
+
+ private:
+ void EjectOne();
+ void EjectUntilRoomFor(int new_bytes);
+
+ std::unordered_map<Key, Entry, KeyHash> cache_;
+ const int max_buffers_bytes_;
+ int buffers_bytes_ = 0;
+ Timestamp timestamp_ = 0;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_PREPACKED_CACHE_H_
diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc
new file mode 100644
index 0000000..9625931
--- /dev/null
+++ b/ruy/prepacked_cache_test.cc
@@ -0,0 +1,309 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/prepacked_cache.h"
+
+#include <cstdint>
+#include <cstring>
+#include <thread>  // NOLINT(build/c++11)
+#include <vector>
+
+#include "ruy/context.h"
+#include "ruy/context_get_ctx.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/ruy.h"
+#include "ruy/time.h"
+
+namespace ruy {
+namespace {
+
+PEMat MakeDummyPEMat(Type data_type, int rows, int cols) {
+ PEMat ret;
+ ret.data_type = data_type;
+ if (!data_type.is_floating_point) {
+ ret.sums_type = Type::Create<std::int32_t>();
+ }
+ ret.layout.rows = rows;
+ ret.layout.cols = cols;
+ ret.layout.stride = rows;
+ ret.layout.order = Order::kColMajor;
+ // The kernel block layout is not relevant to this test, so we leave it
+ // trivial 1x1.
+ ret.layout.kernel.rows = 1;
+ ret.layout.kernel.cols = 1;
+ return ret;
+}
+
+template <typename T>
+void DummyPack(const std::vector<T>& data, PEMat* packed_matrix) {
+ EXPECT_EQ(data.size(), FlatSize(packed_matrix->layout));
+ memcpy(packed_matrix->data, data.data(), data.size() * sizeof(T));
+}
+
+TEST(PrepackedCacheTest, TestCacheBasic) {
+ PrepackedCache prepacked_cache(307);
+ // Allocate the prepacked matrix.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data1(10 * 20);
+ PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+ DummyPack(data1, &mat1);
+
+ // DataBytes=15, SumsBytes=3*4=12, Total: 27 bytes
+ std::vector<std::uint8_t> data2(5 * 3);
+ PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 5, 3);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+ DummyPack(data2, &mat2);
+
+ // Both should now be in cache.
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 2);
+ EXPECT_EQ(prepacked_cache.BuffersBytes(), 307);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kGotExistingEntry);
+}
+
+TEST(PrepackedCacheTest, TestCacheBasicFloat) {
+ PrepackedCache prepacked_cache(860);
+ // Allocate the prepacked matrix.
+ // DataBytes=200*4, SumsBytes=0 because float, Total: 800 bytes
+ std::vector<float> data1(10 * 20);
+ PEMat mat1 = MakeDummyPEMat(Type::Create<float>(), 10, 20);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+ DummyPack(data1, &mat1);
+
+ // DataBytes=15*4, SumsBytes=0 because float, Total: 60 bytes
+ std::vector<float> data2(5 * 3);
+ PEMat mat2 = MakeDummyPEMat(Type::Create<float>(), 5, 3);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+ DummyPack(data2, &mat2);
+
+ // Both should now be in cache.
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 2);
+ EXPECT_EQ(prepacked_cache.BuffersBytes(), 860);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kGotExistingEntry);
+}
+
+TEST(PrepackedCacheTest, TestCacheEjection) {
+ PrepackedCache prepacked_cache(306);
+ // Allocate the prepacked matrix.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data1(10 * 20);
+ PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ prepacked_cache.Get(data1.data(), &mat1);
+ DummyPack(data1, &mat1);
+
+ // DataBytes=15, SumsBytes=3*4=12, Total: 27 bytes
+ std::vector<std::uint8_t> data2(5 * 3);
+ PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 5, 3);
+ prepacked_cache.Get(data2.data(), &mat2);
+ DummyPack(data2, &mat2);
+
+ // The first matrix should have been ejected from the cache.
+ // Only the second matrix should now be in cache.
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 1);
+ EXPECT_EQ(prepacked_cache.BuffersBytes(), 27);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ // The second matrix should have been ejected from the cache.
+ // Only the first matrix should now be in cache.
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 1);
+ EXPECT_EQ(prepacked_cache.BuffersBytes(), 280);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+}
+
+TEST(PrepackedCacheTest, TestCacheEjection2) {
+ PrepackedCache prepacked_cache(1000);
+ // Allocate the prepacked matrix 1.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data1(10 * 20);
+ PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ prepacked_cache.Get(data1.data(), &mat1);
+ DummyPack(data1, &mat1);
+
+ // Allocate the prepacked matrix 2.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data2(10 * 20);
+ PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ prepacked_cache.Get(data2.data(), &mat2);
+ DummyPack(data2, &mat2);
+
+ // Allocate the prepacked matrix 3.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data3(10 * 20);
+ PEMat mat3 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ prepacked_cache.Get(data3.data(), &mat3);
+ DummyPack(data3, &mat3);
+
+ // The next insertion will cause the cache size to go over the ejection
+  // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest.
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data3.data(), &mat3) ==
+ PrepackedCache::Action::kGotExistingEntry);
+
+ // Allocate the prepacked matrix 4.
+ // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes
+ std::vector<std::uint8_t> data4(10 * 20);
+ PEMat mat4 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ prepacked_cache.Get(data4.data(), &mat4);
+ DummyPack(data4, &mat4);
+
+ // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not.
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 3);
+ EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data3.data(), &mat3) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data4.data(), &mat4) ==
+ PrepackedCache::Action::kGotExistingEntry);
+ EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) ==
+ PrepackedCache::Action::kInsertedNewEntry);
+}
+
+TEST(PrepackedCacheTest, TestDistinguishSubtlyDifferentMatrices) {
+ PrepackedCache prepacked_cache;
+
+ std::vector<std::uint8_t> data(10 * 20);
+ PEMat mat = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20);
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ // Same layout, different source data pointer
+ EXPECT_EQ(prepacked_cache.Get(data.data() + 1, &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ // Layout tweaks
+ mat.layout.rows = 9;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ mat.layout.cols = 19;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ mat.layout.order = Order::kRowMajor;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ mat.layout.kernel.rows = 2;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ mat.layout.kernel.cols = 2;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ mat.layout.kernel.order = Order::kRowMajor;
+ EXPECT_EQ(prepacked_cache.Get(data.data(), &mat),
+ PrepackedCache::Action::kInsertedNewEntry);
+
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 8);
+}
+
+void TestCachePolicies(CachePolicy cache_policy, bool expected_cached) {
+ ruy::Context context;
+ ruy::Ctx* ctx = get_ctx(&context);
+ PrepackedCache* cache = ctx->GetPrepackedCache();
+ EXPECT_EQ(cache->MatrixCount(), 0);
+
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ ruy::MulParams<float, float> mul_params;
+ // Perform the multiplication and confirm no caching occurred.
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
+ EXPECT_EQ(cache->MatrixCount(), 0);
+
+ // Set cache policy for the LHS, repeat the multiplication, and see
+ // that caching did occur.
+ lhs.set_cache_policy(cache_policy);
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
+ const bool actual_cached = cache->MatrixCount() == 1;
+ EXPECT_EQ(actual_cached, expected_cached);
+}
+
+TEST(PrepackedCacheTest, TestCachePolicies) {
+ for (CachePolicy cache_policy :
+ {CachePolicy::kNeverCache, CachePolicy::kCacheIfLargeSpeedup,
+ CachePolicy::kCacheIfSignificantSpeedup, CachePolicy::kAlwaysCache}) {
+ TestCachePolicies(cache_policy,
+ cache_policy != CachePolicy::kNeverCache);
+ }
+}
+
+TEST(PrepackedCacheTest, TestClearCache) {
+ ruy::Context context;
+ PrepackedCache* cache = get_ctx(&context)->GetPrepackedCache();
+ EXPECT_EQ(cache->MatrixCount(), 0);
+
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+ lhs.set_data(lhs_data);
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout());
+ rhs.set_data(rhs_data);
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout());
+ dst.set_data(dst_data);
+
+ ruy::MulParams<float, float> mul_params;
+ // Set cache policy for the LHS and see that caching occurs.
+ lhs.set_cache_policy(CachePolicy::kAlwaysCache);
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
+ EXPECT_NE(cache->MatrixCount(), 0);
+
+ // Clear the cache via the Context.
+ context.ClearPrepackedCache();
+ // Verify that the cache is now empty.
+ cache = get_ctx(&context)->GetPrepackedCache();
+ EXPECT_EQ(cache->MatrixCount(), 0);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/prepare_packed_matrices.cc b/ruy/prepare_packed_matrices.cc
new file mode 100644
index 0000000..5a01af7
--- /dev/null
+++ b/ruy/prepare_packed_matrices.cc
@@ -0,0 +1,94 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/prepare_packed_matrices.h"
+
+#include "ruy/allocator.h"
+#include "ruy/ctx.h"
+#include "ruy/matrix.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/side_pair.h"
+#include "ruy/trace.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+namespace {
+
+// Returns true if the operand on the given side should use caching of the
+// packed form. This may either be explicitly dictated by its cache_policy
+// (if it is kNeverCache, the default, or kAlwaysCache), or it may depend
+// on a heuristic decision based on the other operand's width. For example,
+// in a matrix*vector product, for the LHS matrix operand, the other side is
+// the RHS vector, with a width of 1, causing the packing of the LHS to be
+// a large fraction of the overall work, so a heuristic would typically
+// decide in favor of caching, if permitted at all by the cache_policy.
+bool ShouldCache(const TrMulParams& params, Side side) {
+ const CachePolicy cache_policy = params.src[side].cache_policy;
+ // The width that matters is that of the other side, it is what determines
+ // the amortization of the packing work done on the present side.
+ const Side other_side = OtherSide(side);
+ const int other_width = params.src[other_side].layout.cols;
+ const int other_kernel_width =
+ params.packed_matrix[other_side].layout.kernel.cols;
+ switch (cache_policy) {
+ case CachePolicy::kNeverCache:
+ return false;
+ case CachePolicy::kAlwaysCache:
+ return true;
+ case CachePolicy::kCacheIfLargeSpeedup:
+ // The condition (other_width <= other_kernel_width) means that the kernel
+ // will traverse each value of the present side only once, meaning that
+ // the overhead of the packing work will be maximal, hence maximally
+ // worth caching.
+ return (other_width <= other_kernel_width);
+ case CachePolicy::kCacheIfSignificantSpeedup:
+ // Variant of the heuristic used in the kCacheIfLargeSpeedup case. The
+ // kernel will run on each value of the present side only a few times,
+ // so packing overhead will be significant.
+ return (other_width <= 4 * other_kernel_width);
+ default:
+ RUY_DCHECK(false);
+ return false;
+ }
+}
+
+} // namespace
+
+void PreparePackedMatrices(Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ PEMat& packed_matrix = params->packed_matrix[side];
+ if (ShouldCache(*params, side)) {
+ // Use a cached packed matrix (possibly packing and caching now).
+ auto* cache = ctx->GetPrepackedCache();
+ auto action = cache->Get(params->src[side].data, &packed_matrix);
+ RUY_TRACE_INFO(PREPARE_PACKED_MATRICES_SHOULD_CACHE);
+ if (action == PrepackedCache::Action::kInsertedNewEntry) {
+ params->RunPack(side, ctx->GetMainThreadTuning(), 0,
+ packed_matrix.layout.cols);
+ }
+ params->is_prepacked[side] = true;
+ } else {
+ RUY_TRACE_INFO(PREPARE_PACKED_MATRICES_NO_CACHE);
+ // Do not use a cached packed matrix. Only need to allocate buffers now.
+ Allocator* allocator = ctx->GetMainAllocator();
+ packed_matrix.data = allocator->AllocateBytesAvoidingAliasingWith(
+ DataBytes(packed_matrix), params->src[side].data);
+ packed_matrix.sums = allocator->AllocateBytes(SumsBytes(packed_matrix));
+ }
+ }
+}
+
+} // namespace ruy
diff --git a/ruy/prepare_packed_matrices.h b/ruy/prepare_packed_matrices.h
new file mode 100644
index 0000000..1092dc9
--- /dev/null
+++ b/ruy/prepare_packed_matrices.h
@@ -0,0 +1,42 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PREPARE_PACKED_MATRICES_H_
+#define RUY_RUY_PREPARE_PACKED_MATRICES_H_
+
+#include "ruy/ctx.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+// Ensures that the packed matrices are ready for TrMul's work. In the generic
+// case, this is merely allocating their buffers.
+//
+// In the non-default case where a matrix has a cache_policy allowing caching,
+// this is where we implement this caching feature: determining whether to
+// cache each matrix, performing the cache lookup, and possibly performing the
+// packing and cache update if not already cached.
+//
+// Assumes that the packed matrices have previously been created, with their
+// fields already set except for the buffer allocations, as part of
+// CreateTrMulParams. The reason for separating this preparation from the
+// creation is that the creation needs to be templatized and this preparation
+// does not.
+void PreparePackedMatrices(Ctx* ctx, TrMulParams* params);
+
+} // namespace ruy
+
+#endif // RUY_RUY_PREPARE_PACKED_MATRICES_H_
diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD
new file mode 100644
index 0000000..64754bf
--- /dev/null
+++ b/ruy/profiler/BUILD
@@ -0,0 +1,66 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+load("//ruy:build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+config_setting(
+ name = "ruy_profiler",
+ define_values = {"ruy_profiler": "true"},
+)
+
+# Used to build TFLite Micro RUY dependency for embedded targets outside of the
+# RUY source tree.
+filegroup(
+ name = "ruy_instrumentation_header",
+ srcs = ["instrumentation.h"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "instrumentation",
+ srcs = ["instrumentation.cc"],
+ hdrs = ["instrumentation.h"],
+ defines = select({
+ ":ruy_profiler": ["RUY_PROFILER"],
+ "//conditions:default": [],
+ }),
+ linkopts = ruy_linkopts_thread_standard_library(),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "profiler",
+ srcs = [
+ "profiler.cc",
+ "treeview.cc",
+ ],
+ hdrs = [
+ "profiler.h",
+ "treeview.h",
+ ],
+ linkopts = ruy_linkopts_thread_standard_library(),
+ visibility = ["//visibility:public"],
+ deps = [":instrumentation"],
+)
+
+cc_library(
+ name = "test_instrumented_library",
+ testonly = True,
+ srcs = ["test_instrumented_library.cc"],
+ hdrs = ["test_instrumented_library.h"],
+ deps = [":instrumentation"],
+)
+
+cc_test(
+ name = "test",
+ srcs = ["test.cc"],
+ linkopts = ruy_linkopts_thread_standard_library(),
+ deps = [
+ ":profiler",
+ ":test_instrumented_library",
+ "//ruy:gtest_wrapper",
+ ],
+)
diff --git a/ruy/profiler/CMakeLists.txt b/ruy/profiler/CMakeLists.txt
new file mode 100644
index 0000000..df4b30a
--- /dev/null
+++ b/ruy/profiler/CMakeLists.txt
@@ -0,0 +1,72 @@
+# This file is generated (whence no license header). Do not edit!
+# To regenerate, run:
+# cmake/bazel_to_cmake.sh
+
+if(${RUY_PROFILER})
+ set(ruy_profiler_0_RUY_PROFILER "RUY_PROFILER")
+else()
+ set(ruy_profiler_0_RUY_PROFILER "")
+endif()
+
+if(CMAKE_SYSTEM_NAME STREQUAL Windows)
+ set(ruy_profiler_1_pthread "")
+else()
+ set(ruy_profiler_1_pthread "-pthread")
+endif()
+
+ruy_cc_library(
+ NAME
+ ruy_profiler_instrumentation
+ SRCS
+ instrumentation.cc
+ HDRS
+ instrumentation.h
+ DEFINES
+ ${ruy_profiler_0_RUY_PROFILER}
+ LINKOPTS
+ ${ruy_profiler_1_pthread}
+ PUBLIC
+)
+
+ruy_cc_library(
+ NAME
+ ruy_profiler_profiler
+ SRCS
+ profiler.cc
+ treeview.cc
+ HDRS
+ profiler.h
+ treeview.h
+ LINKOPTS
+ ${ruy_profiler_1_pthread}
+ PUBLIC
+ DEPS
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_library(
+ NAME
+ ruy_profiler_test_instrumented_library
+ TESTONLY
+ SRCS
+ test_instrumented_library.cc
+ HDRS
+ test_instrumented_library.h
+ DEPS
+ ruy_profiler_instrumentation
+)
+
+ruy_cc_test(
+ NAME
+ ruy_profiler_test
+ SRCS
+ test.cc
+ LINKOPTS
+ ${ruy_profiler_1_pthread}
+ DEPS
+ ruy_profiler_profiler
+ ruy_profiler_test_instrumented_library
+ ruy_gtest_wrapper
+)
+
+ruy_add_all_subdirs()
diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md
new file mode 100644
index 0000000..4ff344d
--- /dev/null
+++ b/ruy/profiler/README.md
@@ -0,0 +1,149 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+## Overview
+
+The present directory is the "ruy profiler". As a time profiler, it allows one
+to measure where code is spending time.
+
+Unlike most typical profilers, what it samples is not real call stacks, but
+"pseudo-stacks", which are just simple data structures constructed from within
+the program being profiled. Using this profiler requires manually instrumenting
+code to construct such pseudo-stack information.
+
+Another unusual characteristic of this profiler is that it uses only the C++11
+standard library. It does not use any non-portable feature; in particular, it
+does not rely on signal handlers. The sampling is performed by a thread, the
+"profiler thread".
+
+A discussion of pros/cons of this approach is appended below.
+
+## How to use this profiler
+
+### How to instrument code
+
+An example of instrumented code is given in `test_instrumented_library.cc`.
+
+Code is instrumented by constructing `ScopeLabel` objects. These are RAII
+helpers, ensuring that the thread pseudo-stack contains the label during their
+lifetime. In the most common use case, one would construct such an object at the
+start of a function, so that its scope is the function scope and it allows to
+measure how much time is spent in this function.
+
+```c++
+#include "ruy/profiler/instrumentation.h"
+
+...
+
+void SomeFunction() {
+ ruy::profiler::ScopeLabel function_label("SomeFunction");
+ ... do something ...
+}
+```
+
+A `ScopeLabel` may however have any scope, for instance:
+
+```c++
+if (some_case) {
+ ruy::profiler::ScopeLabel extra_work_label("Some more work");
+ ... do some more work ...
+}
+```
+
+The string passed to the `ScopeLabel` constructor must be just a pointer to a
+literal string (a `const char*`). The profiler will assume that these pointers
+stay valid until the profile is finalized.
+
+However, that literal string may be a `printf` format string, and labels may
+have up to 4 parameters, of type `int`. For example:
+
+```c++
+void SomeFunction(int size) {
+ ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size);
+  ... do something ...
+}
+```
+
+### How to run the profiler
+
+Profiling instrumentation is a no-op unless the preprocessor token
+`RUY_PROFILER` is defined, so defining it is the first step when actually
+profiling. When building with Bazel, the preferred way to enable that is to pass
+this flag on the Bazel command line:
+
+```
+--define=ruy_profiler=true
+```
+
+To actually profile a code scope, it is enough to construct a `ScopeProfile`
+object, also a RAII helper. It will start the profiler on construction, and on
+destruction it will terminate the profiler and report the profile treeview on
+standard output by default. Example:
+
+```c++
+void SomeProfiledBenchmark() {
+ ruy::profiler::ScopeProfile profile;
+
+ CallSomeInstrumentedCode();
+}
+```
+
+An example is provided by the `:test` target in the present directory. Run it
+with `--define=ruy_profiler=true` as explained above:
+
+```
+bazel run -c opt \
+ --define=ruy_profiler=true \
+  //ruy/profiler:test
+```
+
+The default behavior of dumping the treeview on standard output may be
+overridden by calling `SetUserTreeView` on the `ScopeProfile` object, passing a
+pointer to a `TreeView` object. This causes the treeview to be stored in that
+`TreeView` object, where it may be accessed and manipulated using the functions
+declared in `treeview.h`. The aforementioned `:test` target provides examples
+of doing so.
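+
+For example, the following sketch (modeled on the `:test` target; the
+instrumented function being called is illustrative) stores the treeview for
+later inspection instead of printing it immediately:
+
+```c++
+ruy::profiler::TreeView treeview;
+{
+  ruy::profiler::ScopeProfile profile;
+  profile.SetUserTreeView(&treeview);
+  CallSomeInstrumentedCode();
+}
+// The treeview may now be inspected with the functions declared in
+// treeview.h, or printed explicitly:
+ruy::profiler::Print(treeview);
+```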
+
+## Advantages and disadvantages
+
+Compared to a traditional profiler, e.g. Linux's "perf", the present kind of
+profiler has the following disadvantages:
+
+* Requires manual instrumentation of code being profiled.
+* Substantial overhead, modifying the performance characteristics of the code
+ being measured.
+* Questionable accuracy.
+
+But also the following advantages:
+
+* Profiling can be driven from within a benchmark program, allowing the entire
+ profiling procedure to be a single command line.
+*   Not relying on symbol information removes exposure to toolchain
+ details and means less hassle in some build environments, especially
+    embedded/mobile (single command line to run and profile, no symbol files
+ required).
+* Fully portable (all of this is standard C++11).
+* Fully testable (see `:test`). Profiling becomes just another feature of the
+ code like any other.
+* Customized instrumentation can result in easier to read treeviews (only
+ relevant functions, and custom labels may be more readable than function
+ names).
+*   Parametrized/formatted labels make it possible to do things that aren't
+    possible with call-stack-sampling profilers. For example, breaking down a
+    profile where much time is being spent in matrix multiplications by the
+    various matrix multiplication shapes involved.
+
+The philosophy underlying this profiler is that software performance depends on
+software engineers profiling often, and a key factor limiting that in practice
+is how difficult and cumbersome it can be to profile with more serious profilers
+such as Linux's "perf", especially in embedded/mobile development: multiple
+command lines are involved to copy symbol files to devices, retrieve profile
+data from the device, etc. In that context, it is useful to make profiling as
+easy as benchmarking, even on embedded targets, even if the price to pay for
+that is lower accuracy, higher overhead, and some intrusive instrumentation
+requirement.
+
+Another key aspect determining which profiling approach is suitable for a given
+context is whether one already has a priori knowledge of where much of the time
+is likely being spent. When one has such a priori knowledge, it is feasible to
+instrument the known possibly-critical code as per the present approach. On the
+other hand, in situations where one doesn't have such a priori knowledge, a real
+profiler such as Linux's "perf" makes it possible to get a profile of real
+stacks right away, from just the symbol information generated by the toolchain.
diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc
new file mode 100644
index 0000000..cc8122b
--- /dev/null
+++ b/ruy/profiler/instrumentation.cc
@@ -0,0 +1,132 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/instrumentation.h"
+
+#ifdef RUY_PROFILER
+
+#include <cstring>
+
+namespace ruy {
+namespace profiler {
+
+void Label::operator=(const Label& other) {
+ format_ = other.format_;
+ args_count_ = other.args_count_;
+ for (int i = 0; i < args_count_; i++) {
+ args_[i] = other.args_[i];
+ }
+}
+
+bool Label::operator==(const Label& other) const {
+ if (std::string(format_) != std::string(other.format_)) {
+ return false;
+ }
+ if (args_count_ != other.args_count_) {
+ return false;
+ }
+ for (int i = 0; i < args_count_; i++) {
+ if (args_[i] != other.args_[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string Label::Formatted() const {
+ static constexpr int kBufSize = 256;
+ char buf[kBufSize];
+ if (args_count_ == 0) {
+ return format_;
+ }
+ if (args_count_ == 1) {
+ snprintf(buf, kBufSize, format_, args_[0]);
+ } else if (args_count_ == 2) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1]);
+ } else if (args_count_ == 3) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]);
+ } else if (args_count_ == 4) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]);
+ } else {
+ abort();
+ }
+ return buf;
+}
+
+namespace detail {
+
+std::mutex* GlobalsMutex() {
+ static std::mutex mutex;
+ return &mutex;
+}
+
+bool& GlobalIsProfilerRunning() {
+ static bool b;
+ return b;
+}
+
+std::vector<ThreadStack*>* GlobalAllThreadStacks() {
+ static std::vector<ThreadStack*> all_stacks;
+ return &all_stacks;
+}
+
+ThreadStack* ThreadLocalThreadStack() {
+ thread_local static ThreadStack thread_stack;
+ return &thread_stack;
+}
+
+ThreadStack::ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ static std::uint32_t global_next_thread_stack_id = 0;
+ stack_.id = global_next_thread_stack_id++;
+ GlobalAllThreadStacks()->push_back(this);
+}
+
+ThreadStack::~ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ std::vector<ThreadStack*>* all_stacks = GlobalAllThreadStacks();
+ for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) {
+ if (*it == this) {
+ all_stacks->erase(it);
+ return;
+ }
+ }
+}
+int GetBufferSize(const Stack& stack) {
+ return sizeof(stack.id) + sizeof(stack.size) +
+ stack.size * sizeof(stack.labels[0]);
+}
+
+void CopyToBuffer(const Stack& stack, char* dst) {
+ memcpy(dst, &stack.id, sizeof(stack.id));
+ dst += sizeof(stack.id);
+ memcpy(dst, &stack.size, sizeof(stack.size));
+ dst += sizeof(stack.size);
+ memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0]));
+}
+
+void ReadFromBuffer(const char* src, Stack* stack) {
+ memcpy(&stack->id, src, sizeof(stack->id));
+ src += sizeof(stack->id);
+ memcpy(&stack->size, src, sizeof(stack->size));
+ src += sizeof(stack->size);
+ memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0]));
+}
+
+} // namespace detail
+} // namespace profiler
+} // namespace ruy
+
+#endif
diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h
new file mode 100644
index 0000000..c4df1e6
--- /dev/null
+++ b/ruy/profiler/instrumentation.h
@@ -0,0 +1,203 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PROFILER_INSTRUMENTATION_H_
+#define RUY_RUY_PROFILER_INSTRUMENTATION_H_
+
+#ifdef RUY_PROFILER
+#include <cstdint>
+#include <cstdio>
+#include <mutex>
+#include <string>
+#include <vector>
+#endif
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// A label is how a code scope is annotated to appear in profiles.
+// The stacks that are sampled by the profiler are stacks of such labels.
+// A label consists of a literal string, plus optional integer arguments.
+class Label {
+ public:
+ Label() {}
+ template <typename... Args>
+ explicit Label(Args... args) {
+ Set(args...);
+ }
+ void Set(const char* format) {
+ format_ = format;
+ args_count_ = 0;
+ }
+ template <typename... Args>
+ void Set(const char* format, Args... args) {
+ format_ = format;
+ args_count_ = sizeof...(args);
+ SetArgs(0, args...);
+ }
+
+ void operator=(const Label& other);
+
+ bool operator==(const Label& other) const;
+
+ std::string Formatted() const;
+ const char* format() const { return format_; }
+
+ private:
+ void SetArgs(int position, int arg0) { args_[position] = arg0; }
+
+ template <typename... Args>
+ void SetArgs(int position, int arg0, Args... args) {
+ SetArgs(position, arg0);
+ SetArgs(position + 1, args...);
+ }
+
+ static constexpr int kMaxArgs = 4;
+ const char* format_ = nullptr;
+ int args_count_ = 0;
+ int args_[kMaxArgs];
+};
+
+namespace detail {
+
+// Forward-declaration, see class ThreadStack below.
+class ThreadStack;
+
+bool& GlobalIsProfilerRunning();
+
+// Returns the global vector of pointers to all stacks, there being one stack
+// per thread executing instrumented code.
+std::vector<ThreadStack*>* GlobalAllThreadStacks();
+
+// Returns the mutex to be locked around any access to GlobalAllThreadStacks().
+std::mutex* GlobalsMutex();
+
+// Returns the thread-local stack, specific to the current thread.
+ThreadStack* ThreadLocalThreadStack();
+
+// This 'stack' is what may be more appropriately called a 'pseudostack':
+// It contains Label entries that are 'manually' entered by instrumentation
+// code. It's unrelated to real call stacks.
+struct Stack {
+ std::uint32_t id = 0;
+ static constexpr int kMaxSize = 64;
+ int size = 0;
+ Label labels[kMaxSize];
+};
+
+// Returns the buffer byte size required by CopyToBuffer.
+int GetBufferSize(const Stack& stack);
+
+// Copies this Stack into a byte buffer, called a 'sample'.
+void CopyToBuffer(const Stack& stack, char* dst);
+
+// Populates this Stack from an existing sample buffer, typically
+// produced by CopyToBuffer.
+void ReadFromBuffer(const char* src, Stack* stack);
+
+// ThreadStack is meant to be used as a thread-local singleton, assigning to
+// each thread a Stack object holding its pseudo-stack of profile labels,
+// plus a mutex allowing to synchronize accesses to this pseudo-stack between
+// this thread and a possible profiler thread sampling it.
+class ThreadStack {
+ public:
+ ThreadStack();
+ ~ThreadStack();
+
+ const Stack& stack() const { return stack_; }
+
+ // Returns the mutex to lock around any access to this stack. Each stack is
+ // accessed by potentially two threads: the thread that it belongs to
+ // (which calls Push and Pop) and the profiler thread during profiling
+  // (which calls CopyToBuffer when taking a sample).
+ std::mutex& Mutex() const { return mutex_; }
+
+ // Pushes a new label on the top of this Stack.
+ template <typename... Args>
+ void Push(Args... args) {
+ // This mutex locking is needed to guard against race conditions as both
+ // the current thread and the profiler thread may be concurrently accessing
+ // this stack. In addition to that, this mutex locking also serves the other
+ // purpose of acting as a barrier (of compiler code reordering, of runtime
+ // CPU instruction reordering, and of memory access reordering), which
+ // gives a measure of correctness to this profiler. The downside is some
+    // latency. As this lock will be uncontended most of the time, the cost
+    // should be roughly that of a sequentially-consistent atomic access,
+ // comparable to an access to the level of CPU data cache that is shared
+ // among all cores, typically 60 cycles on current ARM CPUs, plus side
+ // effects from barrier instructions.
+ std::lock_guard<std::mutex> lock(mutex_);
+ // Avoid overrunning the stack, even in 'release' builds. This profiling
+ // instrumentation code should not ship in release builds anyway, the
+ // overhead of this check is negligible, and overrunning a stack array would
+ // be bad.
+ if (stack_.size >= Stack::kMaxSize) {
+ abort();
+ }
+ stack_.labels[stack_.size++].Set(args...);
+ }
+
+ // Pops the top-most label from this Stack.
+ void Pop() {
+ // See the comment in Push about this lock. While it would be tempting to
+ // try to remove this lock and just atomically decrement size_ with a
+ // store-release, that would not necessarily be a substitute for all of the
+ // purposes that this lock serves, or if it was done carefully to serve all
+ // of the same purposes, then that wouldn't be faster than this (mostly
+ // uncontended) lock.
+ std::lock_guard<std::mutex> lock(mutex_);
+ stack_.size--;
+ }
+
+ private:
+ mutable std::mutex mutex_;
+ Stack stack_;
+};
+
+} // namespace detail
+
+// RAII user-facing way to construct Labels associated with their life scope
+// and get them pushed to / popped from the current thread stack.
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) {
+ thread_stack_->Push(args...);
+ }
+
+ ~ScopeLabel() { thread_stack_->Pop(); }
+
+ private:
+ detail::ThreadStack* thread_stack_;
+};
+
+#else // no RUY_PROFILER
+
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ explicit ScopeLabel(Args...) {}
+
+ // This destructor is needed to consistently silence clang's -Wunused-variable
+ // which seems to trigger semi-randomly.
+ ~ScopeLabel() {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_RUY_PROFILER_INSTRUMENTATION_H_
diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc
new file mode 100644
index 0000000..ae3a2e2
--- /dev/null
+++ b/ruy/profiler/profiler.cc
@@ -0,0 +1,109 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/profiler.h"
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono> // NOLINT
+#include <cstdio>
+#include <cstdlib>
+#include <thread> // NOLINT
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+ScopeProfile::ScopeProfile() { Start(); }
+ScopeProfile::ScopeProfile(bool enable) {
+ if (enable) {
+ Start();
+ }
+}
+ScopeProfile::~ScopeProfile() {
+ if (!thread_) {
+ return;
+ }
+ finishing_.store(true);
+ thread_->join();
+ Finish();
+}
+
+void ScopeProfile::Start() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler already running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = true;
+ }
+ finishing_ = false;
+ thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this));
+}
+
+void ScopeProfile::ThreadFunc() {
+ while (!finishing_.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ auto* thread_stacks = detail::GlobalAllThreadStacks();
+ for (detail::ThreadStack* thread_stack : *thread_stacks) {
+ Sample(*thread_stack);
+ }
+ }
+}
+
+void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) {
+ std::lock_guard<std::mutex> lock(thread_stack.Mutex());
+ // Drop empty stacks.
+ // This ensures that profiles aren't polluted by uninteresting threads.
+ if (thread_stack.stack().size == 0) {
+ return;
+ }
+ int sample_size = detail::GetBufferSize(thread_stack.stack());
+ int old_buf_size = samples_buf_.size();
+ samples_buf_.resize(old_buf_size + sample_size);
+ detail::CopyToBuffer(thread_stack.stack(),
+ samples_buf_.data() + old_buf_size);
+}
+
+void ScopeProfile::Finish() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (!detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler is not running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = false;
+ }
+ if (user_treeview_) {
+ user_treeview_->Populate(samples_buf_);
+ } else {
+ TreeView treeview;
+ treeview.Populate(samples_buf_);
+ Print(treeview);
+ }
+}
+
+#endif // RUY_PROFILER
+
+} // namespace profiler
+} // namespace ruy
diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h
new file mode 100644
index 0000000..10db8df
--- /dev/null
+++ b/ruy/profiler/profiler.h
@@ -0,0 +1,106 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PROFILER_PROFILER_H_
+#define RUY_RUY_PROFILER_PROFILER_H_
+
+#include <cstdio>
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono>
+#include <thread>
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// RAII user-facing way to create a profiler and let it profile a code scope,
+// and print out an ASCII/MarkDown treeview upon leaving the scope.
+class ScopeProfile {
+ public:
+ // Default constructor, unconditionally profiling.
+ ScopeProfile();
+
+  // Constructor that allows choosing at runtime whether to profile.
+ explicit ScopeProfile(bool enable);
+
+ // Destructor. It's where the profile is reported.
+ ~ScopeProfile();
+
+ // See treeview_ member.
+ void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; }
+
+ private:
+ void Start();
+
+ // Thread entry point function for the profiler thread. This thread is
+ // created on construction.
+ void ThreadFunc();
+
+ // Record a stack as a sample.
+ void Sample(const detail::ThreadStack& stack);
+
+ // Finalize the profile. Called on destruction.
+ // If user_treeview_ is non-null, it will receive the treeview.
+ // Otherwise the treeview will just be printed.
+ void Finish();
+
+ // Buffer where samples are recorded during profiling.
+ std::vector<char> samples_buf_;
+
+ // Used to synchronize thread termination.
+ std::atomic<bool> finishing_;
+
+ // Underlying profiler thread, which will perform the sampling.
+ // This profiler approach relies on a thread rather than on signals.
+ std::unique_ptr<std::thread> thread_;
+
+ // TreeView to populate upon destruction. If left null (the default),
+ // a temporary treeview will be used and dumped on stdout. The user
+ // may override that by passing their own TreeView object for other
+ // output options or to directly inspect the TreeView.
+ TreeView* user_treeview_ = nullptr;
+};
+
+#else // no RUY_PROFILER
+
+struct ScopeProfile {
+ ScopeProfile() {
+#ifdef GEMMLOWP_PROFILING
+ fprintf(
+ stderr,
+ "\n\n\n**********\n\nWARNING:\n\nLooks like you defined "
+ "GEMMLOWP_PROFILING, but this code has been ported to the new ruy "
+ "profiler replacing the old gemmlowp profiler. You should now be "
+ "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using "
+ "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n");
+#endif
+ }
+ explicit ScopeProfile(bool) {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_RUY_PROFILER_PROFILER_H_
diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc
new file mode 100644
index 0000000..0405ac7
--- /dev/null
+++ b/ruy/profiler/test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <chrono>
+#include <random>
+#include <thread>
+
+#include "ruy/gtest_wrapper.h"
+#include "ruy/profiler/profiler.h"
+#include "ruy/profiler/test_instrumented_library.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+namespace {
+
+void DoSomeMergeSort(int size) {
+ std::vector<int> data(size);
+
+ std::default_random_engine engine;
+ for (auto& val : data) {
+ val = engine();
+ }
+
+ MergeSort(size, data.data());
+}
+
+// The purpose of this basic test is to cover the basic path that will be taken
+// by a majority of users, not inspecting treeviews but just implicitly printing
+// them on stdout, and to have this test enabled even when RUY_PROFILER is not
+// defined, so that we have coverage for the non-RUY_PROFILER case.
+TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) {
+ {
+ ScopeProfile profile;
+ DoSomeMergeSort(1 << 20);
+ }
+}
+
+#ifdef RUY_PROFILER
+
+TEST(ProfilerTest, MergeSortSingleThread) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ DoSomeMergeSort(1 << 20);
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 1);
+ const auto& thread_root = *treeview.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(thread_root), 22);
+ EXPECT_GE(
+ WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"),
+ 0.1 * thread_root.weight);
+ EXPECT_GE(WeightBelowNodeMatchingFormatted(
+ thread_root, "MergeSortRecurse (level=20, size=1)"),
+ 0.01 * thread_root.weight);
+
+ TreeView treeview_collapsed;
+ CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)",
+ &treeview_collapsed);
+ Print(treeview_collapsed);
+ const auto& collapsed_thread_root =
+ *treeview_collapsed.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6);
+ EXPECT_EQ(
+ WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"),
+ WeightBelowNodeMatchingUnformatted(collapsed_thread_root,
+ "MergeSort (size=%d)"));
+}
+
+TEST(ProfilerTest, MemcpyFourThreads) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ std::vector<std::unique_ptr<std::thread>> threads;
+ for (int i = 0; i < 4; i++) {
+ threads.emplace_back(new std::thread([i]() {
+ ScopeLabel thread_label("worker thread #%d", i);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ ScopeLabel some_more_work_label("some more work");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ }));
+ }
+ for (int i = 0; i < 4; i++) {
+ threads[i]->join();
+ }
+ }
+ Print(treeview);
+ // Since we cleared GlobalAllThreadStacks and the current thread hasn't
+ // created any ScopeLabel, only the 4 worker threads should be recorded.
+ EXPECT_EQ(treeview.thread_roots().size(), 4);
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root_node = *thread_root.second;
+ // The root node may have 1 or 2 children depending on whether there is
+ // an "[other]" child.
+ EXPECT_GE(root_node.children.size(), 1);
+ EXPECT_LE(root_node.children.size(), 2);
+ const TreeView::Node& child_node = *root_node.children[0];
+ EXPECT_EQ(child_node.label.format(), "worker thread #%d");
+ // There must be 2 children, since roughly half the time will be in
+ // "some more work" leaving the other half in "[other]".
+ EXPECT_EQ(child_node.children.size(), 2);
+ const TreeView::Node& child_child_node = *child_node.children[0];
+ // Since we sample every millisecond and the threads run for >= 2000
+ // milliseconds, the "thread func" label should get roughly 2000 samples.
+ // Not very rigorous, as we're depending on the profiler thread getting
+ // scheduled, so to avoid this test being flaky, we use a much more
+ // conservative value of 500, one quarter of that normal value 2000.
+ EXPECT_GE(child_node.weight, 500);
+ // Likewise, allow up to four times more than the normal value 2000.
+ EXPECT_LE(child_node.weight, 8000);
+ // Roughly half of time should be spent under the "some more work" label.
+ float some_more_work_percentage =
+ 100.f * child_child_node.weight / child_node.weight;
+ EXPECT_GE(some_more_work_percentage, 40.0f);
+ EXPECT_LE(some_more_work_percentage, 60.0f);
+ }
+}
+
+TEST(ProfilerTest, OneThreadAfterAnother) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 0");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 1");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 2);
+}
+
+#endif // RUY_PROFILER
+
+} // namespace
+} // namespace profiler
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc
new file mode 100644
index 0000000..b017ea9
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace {
+
+void MergeSortRecurse(int level, int size, int* data, int* workspace) {
+ ruy::profiler::ScopeLabel function_label(
+ "MergeSortRecurse (level=%d, size=%d)", level, size);
+ if (size <= 1) {
+ return;
+ }
+ int half_size = size / 2;
+ MergeSortRecurse(level + 1, half_size, data, workspace);
+ MergeSortRecurse(level + 1, size - half_size, data + half_size,
+ workspace + half_size);
+
+ ruy::profiler::ScopeLabel merging_sorted_halves_label(
+ "Merging sorted halves");
+ int dst_index = 0;
+ int left_index = 0;
+ int right_index = half_size;
+ while (dst_index < size) {
+ int val;
+ if (left_index < half_size &&
+ ((right_index >= size) || data[left_index] < data[right_index])) {
+ val = data[left_index++];
+ } else {
+ val = data[right_index++];
+ }
+ workspace[dst_index++] = val;
+ }
+ for (int i = 0; i < size; i++) {
+ data[i] = workspace[i];
+ }
+}
+
+} // namespace
+
+void MergeSort(int size, int* data) {
+ ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size);
+ std::vector<int> workspace(size);
+ MergeSortRecurse(0, size, data, workspace.data());
+}
diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h
new file mode 100644
index 0000000..4882962
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+#define RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+
+#include "ruy/profiler/instrumentation.h"
+
+void MergeSort(int size, int* data);
+
+#endif // RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc
new file mode 100644
index 0000000..be0a944
--- /dev/null
+++ b/ruy/profiler/treeview.cc
@@ -0,0 +1,252 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef RUY_PROFILER
+
+#include "ruy/profiler/treeview.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdio>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace ruy {
+namespace profiler {
+
+namespace {
+
+void SortNode(TreeView::Node* node) {
+ using NodePtr = std::unique_ptr<TreeView::Node>;
+ std::sort(node->children.begin(), node->children.end(),
+ [](const NodePtr& n1, const NodePtr& n2) {
+ return n1->weight > n2->weight;
+ });
+ for (const auto& child : node->children) {
+ SortNode(child.get());
+ }
+}
+
+// Records a stack, i.e. a sample, in a treeview, by incrementing the weights
+// of matching existing nodes and/or by creating new nodes as needed,
+// recursively, below the given node.
+void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) {
+ node->weight++;
+ if (stack.size == level) {
+ return;
+ }
+ TreeView::Node* child_to_add_to = nullptr;
+ for (const auto& child : node->children) {
+ if (child->label == stack.labels[level]) {
+ child_to_add_to = child.get();
+ break;
+ }
+ }
+ if (!child_to_add_to) {
+ child_to_add_to = new TreeView::Node;
+ child_to_add_to->label = stack.labels[level];
+ node->children.emplace_back(child_to_add_to);
+ }
+ AddStack(stack, child_to_add_to, level + 1);
+}
+
+// Recursively populates the treeview below the given node with 'other'
+// entries documenting, for each node, the difference between its weight and
+// the sum of its children's weights.
+void AddOther(TreeView::Node* node) {
+ int top_level_children_weight = 0;
+ for (const auto& child : node->children) {
+ AddOther(child.get());
+ top_level_children_weight += child->weight;
+ }
+ if (top_level_children_weight != 0 &&
+ top_level_children_weight != node->weight) {
+ auto* new_child = new TreeView::Node;
+ new_child->label = Label("[other]");
+ new_child->weight = node->weight - top_level_children_weight;
+ node->children.emplace_back(new_child);
+ }
+}
+
+} // namespace
+
+void TreeView::Populate(const std::vector<char>& samples_buf_) {
+ thread_roots_.clear();
+ // Populate the treeview with regular nodes coming from samples.
+ const char* buf_ptr = samples_buf_.data();
+ const char* const buf_ptr_end = buf_ptr + samples_buf_.size();
+ while (buf_ptr < buf_ptr_end) {
+ detail::Stack stack;
+ detail::ReadFromBuffer(buf_ptr, &stack);
+ // Empty stacks should have been dropped during sampling.
+ assert(stack.size > 0);
+ buf_ptr += GetBufferSize(stack);
+ const int id = stack.id;
+ if (!thread_roots_[id]) {
+ thread_roots_[id].reset(new Node);
+ }
+ AddStack(stack, thread_roots_[id].get(), 0);
+ }
+ // Populate the treeview with additional 'other' nodes, sort, and set
+ // root labels.
+ for (const auto& thread_root : thread_roots_) {
+ std::uint32_t id = thread_root.first;
+ Node* root = thread_root.second.get();
+ AddOther(root);
+ SortNode(root);
+ root->label.Set("Thread %x (%d samples)", id, root->weight);
+ }
+}
+
+// Recursively prints the treeview below the given node. The 'root' node
+// argument is only needed to compute weight ratios, with the root weight
+// as the denominator.
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root,
+ int level) {
+ if (&node == &root) {
+ printf("%s\n\n", node.label.Formatted().c_str());
+ } else {
+ for (int i = 1; i < level; i++) {
+ printf(" ");
+ }
+ printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight,
+ node.label.Formatted().c_str());
+ }
+ for (const auto& child : node.children) {
+ PrintTreeBelow(*child, root, level + 1);
+ }
+}
+
+void Print(const TreeView& treeview) {
+ printf("\n");
+ printf("Profile (%d threads):\n\n",
+ static_cast<int>(treeview.thread_roots().size()));
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root = *thread_root.second;
+ PrintTreeBelow(root, root, 0);
+ printf("\n");
+ }
+}
+
+int DepthOfTreeBelow(const TreeView::Node& node) {
+ if (node.children.empty()) {
+ return 0;
+ } else {
+ int max_child_depth = 0;
+ for (const auto& child : node.children) {
+ max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child));
+ }
+ return 1 + max_child_depth;
+ }
+}
+
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node,
+ const std::function<bool(const Label&)>& match) {
+ int weight = 0;
+ if (match(node.label)) {
+ weight += node.weight;
+ }
+ for (const auto& child : node.children) {
+ weight += WeightBelowNodeMatchingFunction(*child, match);
+ }
+ return weight;
+}
+
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&format](const Label& label) { return label.format() == format; });
+}
+
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ });
+}
+
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out) {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+ if (depth > 0) {
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseNode(*child_in, depth - 1, child_out);
+ }
+ }
+}
+
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out) {
+ if (match(node_in.label)) {
+ CollapseNode(node_in, depth, node_out);
+ } else {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out);
+ }
+ }
+}
+
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out) {
+ treeview_out->mutable_thread_roots()->clear();
+ for (const auto& thread_root_in : treeview_in.thread_roots()) {
+ std::uint32_t id = thread_root_in.first;
+ const auto& root_in = *thread_root_in.second;
+ auto* root_out = new TreeView::Node;
+ treeview_out->mutable_thread_roots()->emplace(
+ id, std::unique_ptr<TreeView::Node>(root_out));
+ CollapseSubnodesMatchingFunction(root_in, depth, match, root_out);
+ }
+}
+
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&format](const Label& label) { return label.format() == format; },
+ treeview_out);
+}
+
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ },
+ treeview_out);
+}
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h
new file mode 100644
index 0000000..de3313a
--- /dev/null
+++ b/ruy/profiler/treeview.h
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_PROFILER_TREEVIEW_H_
+#define RUY_RUY_PROFILER_TREEVIEW_H_
+
+#ifdef RUY_PROFILER
+
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+namespace profiler {
+
+// A tree view of a profile.
+class TreeView {
+ public:
+ struct Node {
+ std::vector<std::unique_ptr<Node>> children;
+ Label label;
+ int weight = 0;
+ };
+
+ void Populate(const std::vector<char>& samples_buf_);
+
+  // Intentionally an *ordered* map so that threads are enumerated in a
+  // consistent order, typically putting the 'main thread' first.
+ using ThreadRootsMap = std::map<std::uint32_t, std::unique_ptr<Node>>;
+
+ const ThreadRootsMap& thread_roots() const { return thread_roots_; }
+ ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; }
+
+ private:
+ ThreadRootsMap thread_roots_;
+};
+
+/* Below are API functions for manipulating and printing treeviews. */
+
+// Prints the treeview to stdout.
+void Print(const TreeView& treeview);
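+
+// A minimal usage sketch (assuming the ScopeProfile class from
+// "ruy/profiler/profiler.h", as used in the tests, to collect the samples):
+//
+//   TreeView treeview;
+//   {
+//     ScopeProfile profile;
+//     profile.SetUserTreeView(&treeview);
+//     // ... run the instrumented code to be profiled ...
+//   }
+//   Print(treeview);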
+
+// Prints the treeview below the given node on stdout.
+void PrintTreeBelow(const TreeView::Node& node);
+
+// Returns the tree depth below the given node.
+int DepthOfTreeBelow(const TreeView::Node& node);
+
+// Returns the sum of weights of the nodes below the given node that are
+// selected by the `match` predicate.
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node, const std::function<bool(const Label&)>& match);
+
+// Returns the sum of weights of the nodes below the given node whose
+// unformatted label (i.e. raw format string) matches the given `format`
+// string.
+//
+// This makes it possible to aggregate nodes whose labels differ only by
+// parameter values.
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format);
+
+// Returns the sum of weights of the nodes below the given node whose
+// formatted label matches the `formatted` string.
+//
+// In the case of nodes with parametrized labels, this makes it possible to
+// count only nodes with specific parameter values. For that purpose, one may
+// also use WeightBelowNodeMatchingFunction directly, with a `match` predicate
+// comparing raw integer parameter values, instead of going through formatted
+// strings.
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted);
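+
+// Illustration (a sketch, not part of the API), using the "worker thread #%d"
+// label format from the tests, with `node` some TreeView::Node:
+//
+//   // Total weight across all worker threads, regardless of thread index:
+//   int all_workers =
+//       WeightBelowNodeMatchingUnformatted(node, "worker thread #%d");
+//   // Weight of worker thread 0 only:
+//   int worker_0 = WeightBelowNodeMatchingFormatted(node, "worker thread #0");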
+
+// Produces a `node_out` that is a copy of `node_in` but with tree depth below
+// it clamped at `depth`, with further subtrees aggregated into single leaf
+// nodes.
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every subnode filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every node filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching unformatted labels,
+// i.e. raw format strings.
+// See the comment on WeightBelowNodeMatchingUnformatted.
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching formatted labels.
+// See the comment on WeightBelowNodeMatchingFormatted.
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out);
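+
+// Illustration: collapse every "worker thread #%d" subtree of a populated
+// `treeview` down to depth 1 (a sketch using the label format from the tests):
+//
+//   TreeView collapsed;
+//   CollapseNodesMatchingUnformatted(treeview, 1, "worker thread #%d",
+//                                    &collapsed);
+//   Print(collapsed);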
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
+
+#endif // RUY_RUY_PROFILER_TREEVIEW_H_
diff --git a/ruy/reference_mul.h b/ruy/reference_mul.h
new file mode 100644
index 0000000..6c00dc7
--- /dev/null
+++ b/ruy/reference_mul.h
@@ -0,0 +1,56 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_REFERENCE_MUL_H_
+#define RUY_RUY_REFERENCE_MUL_H_
+
+#include <algorithm>
+
+#include "ruy/apply_multiplier.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+
+namespace ruy {
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParams<AccumScalar, DstScalar>& mul_params,
+ Matrix<DstScalar>* dst) {
+ for (int i = 0; i < lhs.layout().rows(); i++) {
+ for (int j = 0; j < rhs.layout().cols(); j++) {
+ AccumScalar accum = 0;
+ for (int k = 0; k < lhs.layout().cols(); k++) {
+ AccumScalar lhs_val = Element(lhs, i, k);
+ AccumScalar rhs_val = Element(rhs, k, j);
+ accum += (lhs_val - lhs.zero_point()) * (rhs_val - rhs.zero_point());
+ }
+ int channel =
+ mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
+ if (mul_params.bias()) {
+ accum += mul_params.bias()[channel];
+ }
+ ApplyMultiplier(mul_params, channel, &accum);
+ accum += dst->zero_point();
+ accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
+ accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
+ *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+ }
+ }
+}
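+
+// ReferenceMul takes the same Matrix and MulParams arguments as ruy::Mul
+// (see ruy.h), just without a Context. E.g., for the float case:
+//
+//   ruy::ReferenceMul(lhs, rhs, mul_params, &dst);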
+
+} // namespace ruy
+
+#endif // RUY_RUY_REFERENCE_MUL_H_
diff --git a/ruy/ruy.h b/ruy/ruy.h
new file mode 100644
index 0000000..ddbe192
--- /dev/null
+++ b/ruy/ruy.h
@@ -0,0 +1,114 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the main Ruy public header.
+
+#ifndef RUY_RUY_RUY_H_
+#define RUY_RUY_RUY_H_
+
+#include "ruy/context.h"
+#include "ruy/context_get_ctx.h"
+#include "ruy/frontend.h"
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/path.h"
+#include "ruy/trace.h"
+
+namespace ruy {
+
+// Entry point allowing the caller to specify a custom OR-ed set of Paths to
+// compile. See the comments in path.h for more details about that.
+// Most users should use the other ruy::Mul overload, which does not take a
+// Path template parameter; the main documentation comment is on that overload.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParams<AccumScalar, DstScalar>& mul_params, Context* context,
+ Matrix<DstScalar>* dst) {
+ RUY_TRACE_SCOPE;
+ RUY_TRACE_INFO(MUL);
+ Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+ Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+ Mat<DstScalar> internal_dst = ToInternal(*dst);
+ MulFrontEnd<CompiledPaths>(internal_lhs, internal_rhs, mul_params,
+ get_ctx(context), &internal_dst);
+}
+
+// Performs a multiplication of matrices, with some extra features for
+// neural network applications. The basic operation is:
+//
+// dst = lhs * rhs // matrix multiplication
+//
+// The `mul_params` argument conveys additional parameters that are not
+// naturally associated with lhs, rhs, dst. That includes features specific to
+// the neural-network application domain, such as a bias vector and clamp
+// bounds, as well as integer quantization parameters.
+//
+// A simple reference implementation of the operation performed by ruy::Mul
+// is provided by the ruy::ReferenceMul function in reference_mul.h.
+//
+// The `context` argument can be any ruy::Context object as long as no other
+// thread is going to concurrently access that ruy::Context. The simplest
+// correct (but not efficient) calling pattern is
+//
+// ruy::Context context;
+// ruy::Mul(lhs, rhs, mul_params, &context, dst);
+//
+// However, creating and destroying a new context every time is inefficient
+// because it doesn't allow for resources to persist across ruy calls. Such
+// resources may include heap allocations, a thread pool, and hardware detection
+// results, and can be expensive to obtain. So the recommended usage pattern is
+// more like this:
+//
+// // Once during initialization:
+// ruy::Context* context = new ruy::Context;
+//
+// // Many times
+// ruy::Mul(lhs, rhs, mul_params, context, dst);
+//
+// If multiple threads may concurrently be calling ruy::Mul, they must either
+// use separate Contexts, or use a lock to ensure that no two threads are
+// concurrently accessing the Context object. There is no lock inside Context;
+// nothing is done to ensure reentrancy with shared Context objects.
+//
+// Ruy defaults to using only 1 thread. Multi-threading is always opted into
+// explicitly, by calling Context::set_max_num_threads() with an explicit
+// thread count. If multiple threads may concurrently be calling ruy::Mul, it
+// is advisable to set up their respective Context objects with
+// set_max_num_threads so that the overall number of threads doesn't exceed
+// the number of threads that the system can usefully execute concurrently
+// (e.g. the number of CPU cores in typical scenarios). At least ruy forces
+// each invocation to make an explicit decision here; there is no automatic
+// detection of the best number of threads to use in ruy.
+//
+// Constraints on the template parameters:
+// * If DstScalar is floating-point then AccumScalar must also be.
+// * If DstScalar is integral then AccumScalar must be std::int32_t.
+// Please refer to MulParams' class comment for more information. When
+// DstScalar is integral and is narrower than AccumScalar, additional
+// MulParams fields must be set to control the scaling of internal accumulators
+// before the final saturating cast to the DstScalar type.
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParams<AccumScalar, DstScalar>& mul_params, Context* context,
+ Matrix<DstScalar>* dst) {
+ Mul<kDefaultPaths>(lhs, rhs, mul_params, context, dst);
+}
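+
+// A minimal end-to-end float usage sketch (assuming the MakeSimpleLayout
+// helper declared in matrix.h):
+//
+//   ruy::Context context;
+//   float lhs_data[] = {1, 2, 3, 4};  // 2x2, row-major
+//   float rhs_data[] = {1, 0, 0, 1};  // 2x2 identity, row-major
+//   float dst_data[4];
+//   ruy::Matrix<float> lhs, rhs, dst;
+//   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
+//   lhs.set_data(lhs_data);
+//   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, rhs.mutable_layout());
+//   rhs.set_data(rhs_data);
+//   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, dst.mutable_layout());
+//   dst.set_data(dst_data);
+//   ruy::MulParams<float, float> mul_params;
+//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);
+//   // dst_data is now {1, 2, 3, 4}: multiplying by the identity leaves lhs
+//   // unchanged.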
+
+} // namespace ruy
+
+#endif // RUY_RUY_RUY_H_
diff --git a/ruy/ruy_test.bzl b/ruy/ruy_test.bzl
new file mode 100644
index 0000000..ef7e8b1
--- /dev/null
+++ b/ruy/ruy_test.bzl
@@ -0,0 +1,34 @@
+# Provides the ruy_test macro for type-parametrized tests.
+"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination."""
+
+def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None):
+ for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
+ native.cc_test(
+ name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
+ srcs = srcs,
+ copts = copts + [
+ "-DRUY_TEST_LHSSCALAR=%s" % lhs,
+ "-DRUY_TEST_RHSSCALAR=%s" % rhs,
+ "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
+ "-DRUY_TEST_DSTSCALAR=%s" % dst,
+ ],
+ deps = deps,
+ tags = tags,
+ )
+
+def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None):
+ tags = ["req_dep=//third_party/gemmlowp:profiler"]
+ for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
+ native.cc_binary(
+ name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
+ testonly = True,
+ srcs = srcs,
+ copts = copts + [
+ "-DRUY_TEST_LHSSCALAR=%s" % lhs,
+ "-DRUY_TEST_RHSSCALAR=%s" % rhs,
+ "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
+ "-DRUY_TEST_DSTSCALAR=%s" % dst,
+ ],
+ deps = deps,
+ tags = tags,
+ )
diff --git a/ruy/ruy_test_ext.oss.bzl b/ruy/ruy_test_ext.oss.bzl
new file mode 100644
index 0000000..5701fff
--- /dev/null
+++ b/ruy/ruy_test_ext.oss.bzl
@@ -0,0 +1,7 @@
+"""Allows to specialize the ruy BUILD to availability of external libraries"""
+
+def ruy_test_ext_defines():
+ return []
+
+def ruy_test_ext_deps():
+ return []
diff --git a/ruy/side_pair.h b/ruy/side_pair.h
new file mode 100644
index 0000000..2f40277
--- /dev/null
+++ b/ruy/side_pair.h
@@ -0,0 +1,68 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_SIDE_PAIR_H_
+#define RUY_RUY_SIDE_PAIR_H_
+
+#include "ruy/check_macros.h"
+
+namespace ruy {
+
+// Enumeration of the sides, i.e. the operands 'slots', in a matrix
+// multiplication. The numerical values of these enumeration constants matter
+// because these will be used as indices into the array underlying a SidePair.
+enum class Side {
+ // Left-hand side
+ kLhs = 0,
+ // Right-hand side
+ kRhs = 1
+};
+
+inline Side OtherSide(Side side) {
+ return side == Side::kLhs ? Side::kRhs : Side::kLhs;
+}
+
+// SidePair is a pair container where the two elements are indexed by a Side
+// enum.
+template <typename T>
+class SidePair final {
+ public:
+ SidePair() {}
+ SidePair(const T& a, const T& b) : elem_{a, b} {}
+ const T& operator[](Side side) const {
+ const int index = static_cast<int>(side);
+ // Technically this check is vacuous, since other values would be
+ // out-of-range for enum Side.
+ RUY_DCHECK(index == 0 || index == 1);
+ return elem_[index];
+ }
+
+ T& operator[](Side side) {
+ const int index = static_cast<int>(side);
+ // Technically this check is vacuous, since other values would be
+ // out-of-range for enum Side.
+ RUY_DCHECK(index == 0 || index == 1);
+ return elem_[index];
+ }
+
+ private:
+ static_assert(static_cast<int>(Side::kLhs) == 0, "");
+ static_assert(static_cast<int>(Side::kRhs) == 1, "");
+ T elem_[2];
+};
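+
+// A minimal usage sketch:
+//
+//   SidePair<int> sizes(/*lhs=*/3, /*rhs=*/4);
+//   sizes[Side::kLhs] += 1;                   // lhs slot is now 4
+//   const int rhs_size = sizes[Side::kRhs];   // 4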
+
+} // namespace ruy
+
+#endif // RUY_RUY_SIDE_PAIR_H_
diff --git a/ruy/size_util.h b/ruy/size_util.h
new file mode 100644
index 0000000..a144522
--- /dev/null
+++ b/ruy/size_util.h
@@ -0,0 +1,105 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_SIZE_UTIL_H_
+#define RUY_RUY_SIZE_UTIL_H_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+
+#ifdef _WIN32
+#include <intrin.h>
+#endif
+
+namespace ruy {
+
+template <typename Integer>
+inline Integer floor_log2(Integer n) {
+ static_assert(std::is_integral<Integer>::value, "");
+ static_assert(std::is_signed<Integer>::value, "");
+ static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
+
+ RUY_DCHECK_GE(n, 1);
+#ifdef _MSC_VER
+ unsigned long result; // NOLINT[runtime/int]
+ if (sizeof(Integer) == 4) {
+ _BitScanReverse(&result, n);
+ } else {
+#if defined(_M_X64) || defined(_M_ARM64)
+ // _BitScanReverse64 is supported only on 64-bit MSVC platforms
+ _BitScanReverse64(&result, static_cast<unsigned __int64>(n));
+#else
+ // Emulate using 32-bit _BitScanReverse
+ const uint32_t n_hi = uint64_t(n) >> 32;
+ if (n_hi == 0) {
+ _BitScanReverse(&result, static_cast<unsigned long>(n));
+ } else {
+ _BitScanReverse(&result, static_cast<unsigned long>(n_hi));
+ result += 32;
+ }
+#endif // defined(_M_X64) || defined(_M_ARM64)
+ }
+ return result;
+#else
+ if (sizeof(Integer) == 4) {
+ return 31 - __builtin_clz(n);
+ } else {
+ return 63 - __builtin_clzll(n);
+ }
+#endif
+}
+
+template <typename Integer>
+Integer ceil_log2(Integer n) {
+ RUY_DCHECK_GE(n, 1);
+ return n == 1 ? 0 : floor_log2(n - 1) + 1;
+}
+
+template <typename Integer>
+constexpr bool is_pot(Integer value) {
+ return (value > 0) && ((value & (value - 1)) == 0);
+}
+
+template <typename Integer>
+Integer pot_log2(Integer n) {
+ RUY_DCHECK(is_pot(n));
+ return floor_log2(n);
+}
+
+template <typename Integer>
+Integer round_down_pot(Integer value) {
+ return static_cast<Integer>(1) << floor_log2(value);
+}
+
+template <typename Integer>
+Integer round_up_pot(Integer value) {
+ return static_cast<Integer>(1) << ceil_log2(value);
+}
+
+template <typename Integer, typename Modulo>
+Integer round_down_pot(Integer value, Modulo modulo) {
+ RUY_DCHECK_EQ(modulo & (modulo - 1), 0);
+ return value & ~(modulo - 1);
+}
+
+template <typename Integer, typename Modulo>
+Integer round_up_pot(Integer value, Modulo modulo) {
+ return round_down_pot(value + modulo - 1, modulo);
+}
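+
+// Worked examples of the helpers above:
+//
+//   floor_log2(12)        == 3   (since 2^3 = 8 <= 12 < 16)
+//   ceil_log2(12)         == 4
+//   is_pot(12)            == false,  is_pot(16) == true
+//   round_down_pot(12)    == 8,      round_up_pot(12) == 16
+//   round_down_pot(37, 8) == 32,     round_up_pot(37, 8) == 40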
+
+} // namespace ruy
+
+#endif // RUY_RUY_SIZE_UTIL_H_
diff --git a/ruy/size_util_test.cc b/ruy/size_util_test.cc
new file mode 100644
index 0000000..eb4c9ca
--- /dev/null
+++ b/ruy/size_util_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/size_util.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+
+#include "ruy/gtest_wrapper.h"
+
+namespace ruy {
+namespace {
+
+template <typename Integer>
+void SizeUtilTestValue(Integer value) {
+ if (value == 0) {
+ return;
+ }
+
+ EXPECT_LE(0, floor_log2(value));
+ EXPECT_LE(floor_log2(value), ceil_log2(value));
+ EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer));
+
+ if (is_pot(value)) {
+ EXPECT_EQ(floor_log2(value), ceil_log2(value));
+ EXPECT_EQ(floor_log2(value), pot_log2(value));
+ } else {
+ EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value));
+ }
+ EXPECT_EQ(value >> floor_log2(value), 1);
+ EXPECT_EQ(round_down_pot(value), static_cast<Integer>(1)
+ << floor_log2(value));
+ EXPECT_LE(round_down_pot(value), value);
+ EXPECT_GE(round_down_pot(value), value >> 1);
+ EXPECT_TRUE(is_pot(round_down_pot(value)));
+
+ if (ceil_log2(value) < static_cast<int>(8 * sizeof(Integer) - 1)) {
+ EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0);
+ EXPECT_EQ(round_up_pot(value), static_cast<Integer>(1) << ceil_log2(value));
+ EXPECT_GE(round_up_pot(value), value);
+ EXPECT_LE(round_up_pot(value) >> 1, value);
+ EXPECT_TRUE(is_pot(round_up_pot(value)));
+ }
+
+ for (std::uint8_t modulo : {1, 2, 8, 32, 128}) {
+ EXPECT_GE(value, round_down_pot(value, modulo));
+ EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0);
+
+ if (value <= std::numeric_limits<Integer>::max() - modulo) {
+ EXPECT_LE(value, round_up_pot(value, modulo));
+ EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0);
+ }
+ }
+}
+
+template <typename Integer>
+void SizeUtilTest() {
+ for (unsigned exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) {
+ const Integer pot = static_cast<Integer>(1) << exponent;
+ SizeUtilTestValue(pot - 1);
+ SizeUtilTestValue(pot);
+ SizeUtilTestValue(pot + 1);
+ SizeUtilTestValue(pot + 12);
+ SizeUtilTestValue(pot + 123);
+ }
+ SizeUtilTestValue(std::numeric_limits<Integer>::max() - 1);
+ SizeUtilTestValue(std::numeric_limits<Integer>::max());
+}
+
+TEST(SizeUtilTest, Int) { SizeUtilTest<int>(); }
+
+TEST(SizeUtilTest, Long) { SizeUtilTest<long int>(); } // NOLINT
+
+TEST(SizeUtilTest, LongLong) { SizeUtilTest<long long int>(); } // NOLINT
+
+TEST(SizeUtilTest, Int32) { SizeUtilTest<std::int32_t>(); }
+
+TEST(SizeUtilTest, Int64) { SizeUtilTest<std::int64_t>(); }
+
+TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest<std::ptrdiff_t>(); }
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/system_aligned_alloc.cc b/ruy/system_aligned_alloc.cc
new file mode 100644
index 0000000..7c86691
--- /dev/null
+++ b/ruy/system_aligned_alloc.cc
@@ -0,0 +1,51 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/system_aligned_alloc.h"
+
+#include <cstddef>
+#include <cstdlib>
+
+#ifdef _WIN32
+#include <malloc.h>
+#endif
+
+namespace ruy {
+
+namespace detail {
+
+void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+#ifdef _WIN32
+ return _aligned_malloc(num_bytes, kMinimumBlockAlignment);
+#else
+ void *ptr;
+ if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
+ return nullptr;
+ }
+ return ptr;
+#endif
+}
+
+void SystemAlignedFree(void *ptr) {
+#ifdef _WIN32
+ _aligned_free(ptr);
+#else
+ free(ptr);
+#endif
+}
+
+} // namespace detail
+
+} // namespace ruy
diff --git a/ruy/system_aligned_alloc.h b/ruy/system_aligned_alloc.h
new file mode 100644
index 0000000..cdf73af
--- /dev/null
+++ b/ruy/system_aligned_alloc.h
@@ -0,0 +1,53 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_
+#define RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_
+
+#include <cstddef>
+
+namespace ruy {
+
+namespace detail {
+
+// Minimum alignment for blocks.
+//
+// Considerations:
+// - This needs to be at least the alignment of any usual data type.
+// - It's useful that this is at least the size of a cache line to limit
+// possible cache side effects (if only on performance behavior).
+// - It's useful that this is at least the size of SIMD registers, as
+// some SIMD instruction sets have at least performance behavior
+// differences (e.g. NEON) or even different requirements (e.g. SSE)
+// based on that.
+// - It's useful that this is at least the size of an "exclusive reservation
+// granule" on ARM, meaning that if we use this Allocator to allocate
+// an atomic variable, there will be no side effects from other things
+// contending for exclusive/atomic memory accesses to it. While the
+// ARM reference manual mentions that this granule size may be as large
+// as 2048 bytes, in practice we observe it to be 64 bytes. It can
+// be queried cheaply, at runtime, from userspace, if needed.
+constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
+
+// Primitive allocation functions obtaining aligned memory from the
+// operating system.
+void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+void SystemAlignedFree(void* ptr);
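+
+// A minimal usage sketch:
+//
+//   void* p = ruy::detail::SystemAlignedAlloc(1024);
+//   // ... use the 64-byte-aligned, 1024-byte block at p ...
+//   ruy::detail::SystemAlignedFree(p);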
+
+} // namespace detail
+
+} // namespace ruy
+
+#endif // RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_
diff --git a/ruy/test.h b/ruy/test.h
new file mode 100644
index 0000000..5aa4c41
--- /dev/null
+++ b/ruy/test.h
@@ -0,0 +1,2308 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_TEST_H_
+#define RUY_RUY_TEST_H_
+
+#include <math.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <iostream>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <set>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/context.h"
+#include "ruy/context_get_ctx.h"
+#include "ruy/ctx.h"
+#include "ruy/gtest_wrapper.h" // IWYU pragma: export
+#include "ruy/matrix.h" // IWYU pragma: export
+#include "ruy/mul_params.h" // IWYU pragma: export
+#include "ruy/pack_common.h"
+#include "ruy/platform.h"
+#include "ruy/pmu.h"
+#include "ruy/reference_mul.h"
+#include "ruy/ruy.h"
+#include "ruy/size_util.h"
+#include "ruy/time.h"
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#pragma GCC diagnostic pop
+#include "third_party/gemmlowp/public/gemmlowp.h"
+#include "third_party/lapack/blas.h"
+#endif
+
+#ifdef RUY_PROFILER
+#include "ruy/profiler/profiler.h"
+#endif
+
+#ifdef __linux__
+#include <sys/mman.h>
+#include <unistd.h>
+#endif
+
+namespace ruy {
+
+const float kClampRatio = 0.1f;
+
+enum class ExternalPath {
+ kNone,
+ kReference,
+ kGemmlowp,
+ kEigen,
+ kEigenTensor,
+ kOpenBlas
+};
+
+inline std::vector<std::string>* CoveredPaths() {
+ static std::vector<std::string> covered_paths;
+ return &covered_paths;
+}
+
+inline const char* PathName(Path path) {
+#define RUY_PATHNAME_CASE(NAME) \
+ case Path::NAME: \
+ return #NAME;
+ switch (path) {
+ RUY_PATHNAME_CASE(kStandardCpp)
+ RUY_PATHNAME_CASE(kInternalStandardCppVariant1)
+ RUY_PATHNAME_CASE(kInternalStandardCppVariant2)
+ RUY_PATHNAME_CASE(kInternalStandardCppVariant3)
+
+#if RUY_PLATFORM_NEON
+ RUY_PATHNAME_CASE(kNeon)
+ RUY_PATHNAME_CASE(kNeonDotprod)
+#elif RUY_PLATFORM_X86
+ RUY_PATHNAME_CASE(kAvx2Fma)
+ RUY_PATHNAME_CASE(kAvx512)
+ RUY_PATHNAME_CASE(kAvx)
+#endif
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_PATHNAME_CASE
+}
+
+inline const char* TuningName(Tuning tuning) {
+#define RUY_SUBPATHNAME_CASE(NAME) \
+ case Tuning::NAME: \
+ return #NAME;
+ switch (tuning) {
+ RUY_SUBPATHNAME_CASE(kA55ish)
+ RUY_SUBPATHNAME_CASE(kGeneric)
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_SUBPATHNAME_CASE
+}
+
+inline const char* PathName(ExternalPath path) {
+#define RUY_PATHNAME_CASE(NAME) \
+ case ExternalPath::NAME: \
+ return #NAME;
+ switch (path) {
+ RUY_PATHNAME_CASE(kReference)
+ RUY_PATHNAME_CASE(kGemmlowp)
+ RUY_PATHNAME_CASE(kEigen)
+ RUY_PATHNAME_CASE(kEigenTensor)
+ RUY_PATHNAME_CASE(kOpenBlas)
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_PATHNAME_CASE
+}
+
+inline std::ostream& operator<<(std::ostream& stream, Path path) {
+ return stream << PathName(path);
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ ExternalPath external_path) {
+ return stream << PathName(external_path);
+}
+
+template <typename ContainerType>
+std::string Join(const ContainerType& container) {
+ if (container.empty()) {
+ return "<empty>";
+ }
+ std::ostringstream stream;
+ auto it = container.begin();
+ stream << *it++;
+ for (; it != container.end(); ++it) {
+ stream << ", ";
+ stream << *it;
+ }
+ return stream.str();
+}
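+
+// E.g. Join(std::vector<std::string>{"kNeon", "kNeonDotprod"}) returns
+// "kNeon, kNeonDotprod"; an empty container yields "<empty>".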
+
+struct LogCoveredPathsOnDestruction final {
+ ~LogCoveredPathsOnDestruction() {
+ std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl;
+
+ // When we know that it would be abnormal for some path not to be covered,
+ // we check it here. Accidentally disabling SIMD paths has occurred in the
+ // past and is one of the biggest performance regressions imaginable.
+ //
+ // TODO: we should be able to require some x86 paths as well, at least
+ // SSE4.2.
+
+#if RUY_PLATFORM_ARM
+ // When testing on ARM32 or ARM64, make sure that we covered the NEON path.
+ // NEON is always available on ARM64, and we treat it as always available
+ // also on ARM32.
+ bool found_neon = false;
+ for (const std::string& covered_path : *CoveredPaths()) {
+ if (covered_path == "kNeon") {
+ found_neon = true;
+ }
+ }
+ if (!found_neon) {
+ std::cerr
+ << "Error: we haven't tested the kNeon path as we should have.\n"
+ << std::endl;
+ abort();
+ }
+#endif
+
+ // When testing on ARM64 ChromiumOS emulator, make sure that we covered
+ // the dotprod path. We're getting such coverage at the moment thanks to
+ // using a sufficiently recent emulator, and we don't want to regress that.
+#if RUY_PLATFORM_ARM_64 && defined RUY_TESTING_ON_CHROMIUMOS
+ bool found_dotprod = false;
+ for (const std::string& covered_path : *CoveredPaths()) {
+ if (covered_path == "kNeonDotprod") {
+ found_dotprod = true;
+ }
+ }
+ if (!found_dotprod) {
+ std::cerr
+ << "Error: we haven't tested the kNeonDotprod path as we should "
+             "have. At the moment, this is required on ChromiumOS because the "
+             "emulator we run these tests in supports dot-product "
+             "instructions, and we care very much about not regressing that. "
+ "If this test was run in an emulator, please upgrade to a newer "
+ "emulator version. If this test was run on an actual device, and "
+ "you need to be able to run ruy tests on devices not supporting "
+ "dot-product instructions, get in touch with us.\n"
+ << std::endl;
+ abort();
+ }
+#endif
+ }
+ static void Singleton() { static LogCoveredPathsOnDestruction singleton; }
+};
+
+enum class RandomRange {
+ kGeneral,
+ kAvoidMinValue,
+ kOffCenterAvoidMinValue,
+ kReasonableSrcZeroPoint,
+ kReasonableDstZeroPoint,
+ kBias
+};
+
+template <typename Scalar,
+ bool IsFloatingPoint = std::is_floating_point<Scalar>::value>
+struct RandomRangeBounds {};
+
+template <typename Scalar>
+struct RandomRangeBounds<Scalar, true> {
+ static Scalar GetMinBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return -1;
+ case RandomRange::kAvoidMinValue:
+ return -1;
+ case RandomRange::kOffCenterAvoidMinValue:
+ return -1;
+ case RandomRange::kReasonableSrcZeroPoint:
+ return 0;
+ case RandomRange::kReasonableDstZeroPoint:
+ return 0;
+ case RandomRange::kBias:
+ return -1;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+ static Scalar GetMaxBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return 1;
+ case RandomRange::kAvoidMinValue:
+ return 1;
+ case RandomRange::kOffCenterAvoidMinValue:
+ return 1;
+ case RandomRange::kReasonableSrcZeroPoint:
+ return 0;
+ case RandomRange::kReasonableDstZeroPoint:
+ return 0;
+ case RandomRange::kBias:
+ return 1;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+};
+
+template <typename Scalar>
+Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) {
+ float sum = s1 * weight1 + s2 * weight2;
+ float clamped = std::min<float>(
+ std::numeric_limits<Scalar>::max(),
+ std::max<float>(std::numeric_limits<Scalar>::lowest(), sum));
+ return static_cast<Scalar>(clamped);
+}
+
+template <typename Scalar>
+Scalar Parametrized(float param) {
+ return WeightedSum(std::numeric_limits<Scalar>::max(), param,
+ std::numeric_limits<Scalar>::lowest(), 1 - param);
+}
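+
+// E.g. Parametrized<std::uint8_t>(0.4) == 102, i.e. 40% of the way from the
+// type's lowest value (0) to its max (255). This is how the
+// kReasonableDstZeroPoint bounds below are derived.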
+
+template <typename Scalar>
+struct RandomRangeBounds<Scalar, false> {
+ static Scalar GetMinBound(RandomRange range) {
+ static constexpr double offcenteredness =
+ 0.02; // Shift lower limit by about 5 for range of 255.
+ switch (range) {
+ case RandomRange::kGeneral:
+ return std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kAvoidMinValue:
+ return 1 + std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kOffCenterAvoidMinValue:
+ return 1 + std::numeric_limits<Scalar>::lowest() +
+ static_cast<Scalar>(
+ offcenteredness * std::numeric_limits<Scalar>::max() -
+ offcenteredness *
+ (std::numeric_limits<Scalar>::lowest() + 1));
+ case RandomRange::kReasonableSrcZeroPoint:
+ return std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kReasonableDstZeroPoint:
+ return Parametrized<Scalar>(0.4);
+ case RandomRange::kBias:
+ return std::is_same<Scalar, std::int32_t>::value
+ ? static_cast<Scalar>(-10000)
+ : 0;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+ static Scalar GetMaxBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kAvoidMinValue:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kOffCenterAvoidMinValue:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kReasonableSrcZeroPoint:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kReasonableDstZeroPoint:
+ return Parametrized<Scalar>(0.6);
+ case RandomRange::kBias:
+ return std::is_same<Scalar, std::int32_t>::value
+ ? static_cast<Scalar>(10000)
+ : 0;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+};
+
+inline std::default_random_engine& global_random_engine() {
+ static std::default_random_engine engine;
+ return engine;
+}
+
+template <typename Scalar>
+struct UniformRandomDistribution {
+ UniformRandomDistribution(RandomRange range)
+ : dist(RandomRangeBounds<Scalar>::GetMinBound(range),
+ RandomRangeBounds<Scalar>::GetMaxBound(range)) {}
+ Scalar Get() { return dist(global_random_engine()); }
+ // std::uniform_int_distribution is specified not to support char types,
+ // only short and wider types. MSVC actually generates an error on
+ // std::uniform_int_distribution<std::int8_t>.
+ using StdDistType = typename std::conditional<
+ std::is_floating_point<Scalar>::value,
+ std::uniform_real_distribution<Scalar>,
+ std::uniform_int_distribution<std::int32_t>>::type;
+ StdDistType dist;
+};
+
+template <typename Scalar>
+void MakeRandomScalar(UniformRandomDistribution<Scalar>* uniform_dist,
+ Scalar* dst) {
+ *dst = uniform_dist->Get();
+}
+
+#if defined(__has_feature)
+#if __has_feature(address_sanitizer)
+#define RUY_TEST_BUILT_WITH_ASAN
+#endif
+#endif
+
+// Don't use separate mappings when building with Address Sanitizer, as the
+// manual mappings could hide actual address errors from ASan (ASan can't see
+// the buffer's actual valid address range inside the manual mapping).
+#if defined __linux__ && !defined RUY_TEST_BUILT_WITH_ASAN
+#define RUY_TEST_USE_SEPARATE_MAPPINGS
+#endif
+
+template <typename T>
+struct SeparateMappingAllocator {
+ using value_type = T;
+
+ T* allocate(std::size_t n) {
+#ifdef RUY_TEST_USE_SEPARATE_MAPPINGS
+ const std::size_t page_size = getpagesize();
+ std::size_t buffer_size = n * sizeof(T);
+ std::size_t rounded_buffer_size = round_up_pot(buffer_size, page_size);
+ // We are going to map an additional page at the end of our buffer, then
+ // unmap it, to ensure that our buffer's end is the last mapped byte, so as
+ // to catch any overrun.
+ void* mapping =
+ mmap(nullptr, rounded_buffer_size + page_size, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ RUY_CHECK_NE(mapping, MAP_FAILED);
+ int unmap_result =
+ munmap(static_cast<char*>(mapping) + rounded_buffer_size, page_size);
+ RUY_CHECK_EQ(unmap_result, 0);
+    // Clearing bytes should be redundant since mmap has to do it already, but
+ // it does not hurt and acts as an assertion checking that we got the above
+ // mapping and unmapping right.
+ std::memset(mapping, 0, rounded_buffer_size);
+ // Compute the offset to make our buffer end at the last mapped byte.
+ std::size_t buffer_offset = rounded_buffer_size - buffer_size;
+ void* buffer =
+ static_cast<void*>(static_cast<char*>(mapping) + buffer_offset);
+ return static_cast<T*>(buffer);
+#else
+ T* ret = new T[n];
+ std::memset(ret, 0, n * sizeof(T));
+ return ret;
+#endif
+ }
+ void deallocate(T* p, std::size_t n) {
+#ifdef RUY_TEST_USE_SEPARATE_MAPPINGS
+ // The mapped bytes are the buffer address range, rounded on both ends
+ // to page boundary.
+ const std::size_t page_size = getpagesize();
+ std::size_t buffer_size = n * sizeof(T);
+ std::size_t rounded_buffer_size = round_up_pot(buffer_size, page_size);
+ std::uintptr_t p_addr = reinterpret_cast<std::uintptr_t>(p);
+ void* mapping = reinterpret_cast<void*>(p_addr & ~(page_size - 1));
+ int ret = munmap(mapping, rounded_buffer_size);
+ RUY_CHECK_EQ(ret, 0);
+#else
+ (void)n;
+ delete[] p;
+#endif
+ }
+};
+
+template <typename T>
+using SeparateMappingVector = std::vector<T, SeparateMappingAllocator<T>>;
+
+template <typename Scalar, typename Allocator>
+void MakeRandomVector(UniformRandomDistribution<Scalar>* uniform_dist, int size,
+ std::vector<Scalar, Allocator>* dst) {
+ dst->resize(size);
+ for (auto& x : *dst) {
+ MakeRandomScalar(uniform_dist, &x);
+ }
+}
+
+template <typename Scalar>
+void MakeRandomScalar(RandomRange range, Scalar* dst) {
+ UniformRandomDistribution<Scalar> dist(range);
+ *dst = dist.Get();
+ if (range == RandomRange::kReasonableDstZeroPoint ||
+ range == RandomRange::kReasonableSrcZeroPoint) {
+ if (global_random_engine()() & 1) {
+ *dst = SymmetricZeroPoint<Scalar>();
+ }
+ }
+}
+
+template <typename Scalar, typename Allocator>
+void MakeRandomVector(RandomRange range, int size,
+ std::vector<Scalar, Allocator>* dst) {
+ UniformRandomDistribution<Scalar> dist(range);
+ dst->resize(size);
+ for (auto& x : *dst) {
+ MakeRandomScalar(&dist, &x);
+ }
+}
+
+enum class LayoutStyle { kUnstridedLinear, kLinear };
+
+inline void MakeLayout(int rows, int cols, Order order,
+ LayoutStyle layout_style, Layout* layout) {
+ layout->set_rows(rows);
+ layout->set_cols(cols);
+ layout->set_order(order);
+
+ const int min_stride = order == Order::kColMajor ? rows : cols;
+
+ RUY_CHECK(layout_style == LayoutStyle::kUnstridedLinear ||
+ layout_style == LayoutStyle::kLinear);
+ if (layout_style == LayoutStyle::kUnstridedLinear) {
+ layout->set_stride(min_stride);
+ } else {
+ layout->set_stride(min_stride + 1);
+ }
+}
+
+template <typename Scalar>
+struct StorageMatrix {
+ StorageMatrix() = default;
+ StorageMatrix(const StorageMatrix&) = delete;
+ SeparateMappingVector<Scalar> data;
+ Matrix<Scalar> matrix;
+};
+
+inline bool IsUnstrided(const Layout& layout) {
+ if (layout.order() == Order::kColMajor) {
+ return layout.stride() == layout.rows();
+ } else {
+ return layout.stride() == layout.cols();
+ }
+}
+
+inline bool IsRowMajor(const Layout& layout) {
+ return layout.order() == Order::kRowMajor;
+}
+
+inline bool IsColMajor(const Layout& layout) {
+ return layout.order() == Order::kColMajor;
+}
+
+inline int FlatSize(const Layout& layout) {
+ const int outerdim =
+ layout.order() == Order::kColMajor ? layout.cols() : layout.rows();
+ return layout.stride() * outerdim;
+}
+
+template <typename Scalar>
+void VerifyConsistentFields(const StorageMatrix<Scalar>& storage_matrix) {
+ if (storage_matrix.data.empty()) {
+ RUY_CHECK_EQ(storage_matrix.matrix.data(), nullptr);
+ RUY_CHECK_EQ(storage_matrix.matrix.layout().rows(), 0);
+ RUY_CHECK_EQ(storage_matrix.matrix.layout().cols(), 0);
+ } else {
+ RUY_CHECK_EQ(storage_matrix.matrix.data(), storage_matrix.data.data());
+ RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout()),
+ static_cast<int>(storage_matrix.data.size()));
+ }
+}
+
+template <typename Scalar>
+void MakeRandom(int rows, int cols, Order order, Scalar zero_point,
+ LayoutStyle layout_style, RandomRange range,
+ StorageMatrix<Scalar>* storage_matrix) {
+ MakeLayout(rows, cols, order, layout_style,
+ storage_matrix->matrix.mutable_layout());
+ storage_matrix->matrix.set_zero_point(zero_point);
+ UniformRandomDistribution<Scalar> data_dist(range);
+ MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout()),
+ &storage_matrix->data);
+ storage_matrix->matrix.set_data(storage_matrix->data.data());
+ VerifyConsistentFields(*storage_matrix);
+}
+
+template <typename Scalar>
+struct TestResult {
+ void operator=(const TestResult&) = delete;
+ void operator=(const TestResult&&) = delete;
+ StorageMatrix<Scalar> storage_matrix;
+ Path path = Path::kNone;
+ Tuning tuning = Tuning::kAuto;
+ ExternalPath external_path = ExternalPath::kNone;
+ float latency;
+ float l1_refill_rate;
+ float l2_refill_rate;
+ float l3_refill_rate;
+ float l1tlb_refill_rate;
+ float l2tlb_refill_rate;
+ float mispred_rate;
+ float frontend_stall_rate;
+ float backend_stall_rate;
+};
+
+template <typename Scalar>
+std::string PathName(const TestResult<Scalar>& result) {
+ std::string pathname;
+ if (result.path != Path::kNone) {
+ pathname.assign(PathName(result.path));
+ } else if (result.external_path != ExternalPath::kNone) {
+ pathname.assign(PathName(result.external_path));
+ } else {
+ RUY_CHECK(false);
+ }
+ if (result.tuning != Tuning::kAuto) {
+ pathname.append("/");
+ pathname.append(TuningName(result.tuning));
+ }
+ return pathname;
+}
+
+template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
+ typename tDstScalar>
+struct TestSet final {
+ using LhsScalar = tLhsScalar;
+ using RhsScalar = tRhsScalar;
+ using AccumScalar = tAccumScalar;
+ using DstScalar = tDstScalar;
+ using MulParamsType = MulParams<AccumScalar, DstScalar>;
+ using TestResultType = TestResult<DstScalar>;
+
+ void Run() {
+ MakeZeroPoints();
+ MakeLhsRhs();
+ MakeMulParams();
+ MakeOtherParams();
+ MakeResultPaths();
+ Eval();
+ Verify();
+ }
+
+ private:
+ void MakeZeroPoints();
+ void MakeLhsRhs();
+ void MakeMulParams();
+ void MakeResultPaths();
+ void MakeOtherParams();
+ void EvalAndVerify();
+ void Eval();
+ void Verify();
+
+ void EvalResult(TestResultType* result);
+ void EvalRuy(TestResultType* result);
+ void DoMul(TestResultType* result);
+ void Benchmark(TestResultType* result);
+ void VerifyTestResults() const;
+
+ public:
+ enum class LifeStage {
+ kInitial,
+ kHasZeroPoints,
+ kHasLhsRhs,
+ kHasMulParams,
+ kHasOtherParams,
+ kHasResultPaths,
+ kEvaluated,
+ kFinal
+ };
+
+ ~TestSet();
+
+ LifeStage life_stage = LifeStage::kInitial;
+
+ int rows = 0;
+ int cols = 0;
+ int depth = 0;
+ Order lhs_order = Order::kRowMajor;
+ Order rhs_order = Order::kColMajor;
+ Order dst_order = Order::kColMajor;
+ LayoutStyle layout_style = LayoutStyle::kUnstridedLinear;
+
+ bool use_specified_zero_points = false;
+ LhsScalar lhs_zero_point = 0;
+ RhsScalar rhs_zero_point = 0;
+ DstScalar dst_zero_point = 0;
+
+ SeparateMappingVector<AccumScalar> per_channel_multiplier_fixedpoint;
+ SeparateMappingVector<int> per_channel_multiplier_exponent;
+
+ StorageMatrix<LhsScalar> lhs;
+ StorageMatrix<RhsScalar> rhs;
+ MulParamsType mul_params;
+ SeparateMappingVector<AccumScalar> bias_data;
+ std::vector<std::unique_ptr<TestResultType>> results;
+
+ std::vector<Path> paths;
+ std::vector<ExternalPath> external_paths;
+
+ bool benchmark = false;
+ bool perchannel = false;
+ int max_num_threads = 0;
+
+ bool cache_lhs = false;
+ bool cache_rhs = false;
+};
+
+inline PmuEvents& GlobalPmuEvents() {
+ static PmuEvents pmu;
+ return pmu;
+}
+
+inline Context& GlobalContext() {
+ // Ensure that GlobalPmuEvents is constructed before we create any context.
+ // This ensures that pmu counters are opened before we create any worker
+ // thread, which is necessary to count events from worker threads.
+ GlobalPmuEvents();
+
+ static Context context;
+ return context;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::~TestSet() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kFinal);
+ LogCoveredPathsOnDestruction::Singleton();
+ GlobalContext().ClearPrepackedCache();
+}
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define RUY_TSAN
+#endif
+#if __has_feature(address_sanitizer)
+#define RUY_ASAN
+#endif
+#endif // defined(__has_feature)
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::DoMul(
+ TestResultType* result) {
+ Mul<kAllPathsIncludingInternalVariants>(lhs.matrix, rhs.matrix, mul_params,
+ &GlobalContext(),
+ &result->storage_matrix.matrix);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::EvalRuy(
+ TestResultType* result) {
+ GlobalContext().set_explicit_tuning(result->tuning);
+ if (max_num_threads) {
+ GlobalContext().set_max_num_threads(max_num_threads);
+ } else if (benchmark) {
+ GlobalContext().set_max_num_threads(1);
+ } else {
+ GlobalContext().set_max_num_threads(1 + global_random_engine()() % 8);
+ }
+ get_ctx(&GlobalContext())->SetRuntimeEnabledPaths(result->path);
+ DoMul(result);
+  // If caching is enabled, Mul is stateful, so we run it a second time to
+  // also cover the cache-hit case.
+ if (!benchmark && (cache_lhs || cache_rhs)) {
+ DoMul(result);
+ }
+ RUY_CHECK_EQ(GlobalContext().last_used_path(), result->path);
+ GlobalContext().set_explicit_tuning(Tuning::kAuto);
+ GlobalContext().set_max_num_threads(1);
+}
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+
+template <typename Scalar, gemmlowp::MapOrder tOrder>
+void WrapGemmlowp(const Matrix<Scalar>& src,
+ gemmlowp::MatrixMap<const Scalar, tOrder>* dst) {
+ RUY_CHECK(src.layout().order() == (tOrder == gemmlowp::MapOrder::ColMajor
+ ? Order::kColMajor
+ : Order::kRowMajor));
+ *dst = gemmlowp::MatrixMap<const Scalar, tOrder>(
+ src.data(), src.layout().rows(), src.layout().cols(),
+ src.layout().stride());
+}
+
+template <typename Scalar, gemmlowp::MapOrder tOrder>
+void WrapGemmlowpMutable(Matrix<Scalar>* src,
+ gemmlowp::MatrixMap<Scalar, tOrder>* dst) {
+ RUY_CHECK(src->layout().order() == (tOrder == gemmlowp::MapOrder::ColMajor
+ ? Order::kColMajor
+ : Order::kRowMajor));
+ *dst = gemmlowp::MatrixMap<Scalar, tOrder>(src->data(), src->layout().rows(),
+ src->layout().cols(),
+ src->layout().stride());
+}
+
+template <Order tOrder>
+struct GemmlowpOrder {};
+
+template <>
+struct GemmlowpOrder<Order::kColMajor> {
+ static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor;
+};
+
+template <>
+struct GemmlowpOrder<Order::kRowMajor> {
+ static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor;
+};
+
+inline gemmlowp::GemmContext& GlobalGemmlowpContext() {
+ static gemmlowp::GemmContext context;
+ return context;
+}
+
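+// Evaluates the same multiplication through gemmlowp, mapping ruy's MulParams
+// onto a gemmlowp output pipeline: optional bias addition, fixed-point
+// requantization (per-layer or per-channel), clamping, and a saturating cast
+// to the destination type.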
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename MulParamsType>
+void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder =
+ GemmlowpOrder<LhsOrder>::kValue;
+ static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder =
+ GemmlowpOrder<RhsOrder>::kValue;
+ static constexpr gemmlowp::MapOrder kGemmlowpDstOrder =
+ GemmlowpOrder<DstOrder>::kValue;
+ gemmlowp::MatrixMap<const LhsScalar, kGemmlowpLhsOrder> gemmlowp_lhs;
+ gemmlowp::MatrixMap<const RhsScalar, kGemmlowpRhsOrder> gemmlowp_rhs;
+ gemmlowp::MatrixMap<DstScalar, kGemmlowpDstOrder> gemmlowp_dst;
+ WrapGemmlowp(lhs, &gemmlowp_lhs);
+ WrapGemmlowp(rhs, &gemmlowp_rhs);
+ WrapGemmlowpMutable(dst, &gemmlowp_dst);
+
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
+ quantize_down_stage.result_offset_after_shift = dst->zero_point();
+ quantize_down_stage.result_fixedpoint_multiplier =
+ mul_params.multiplier_fixedpoint();
+ quantize_down_stage.result_exponent = mul_params.multiplier_exponent();
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
+ gemmlowp::VectorShape::Col>
+ quantize_down_stage_pc;
+ quantize_down_stage_pc.result_offset_after_shift = dst->zero_point();
+ using ColVectorMap =
+ gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+ quantize_down_stage_pc.result_fixedpoint_multiplier = ColVectorMap(
+ mul_params.multiplier_fixedpoint_perchannel(), lhs.layout().rows());
+ quantize_down_stage_pc.result_exponent = ColVectorMap(
+ mul_params.multiplier_exponent_perchannel(), lhs.layout().rows());
+
+ gemmlowp::OutputStageClamp clamp_stage;
+ clamp_stage.min = mul_params.clamp_min();
+ clamp_stage.max = mul_params.clamp_max();
+ using OutputStageSaturatingCast = typename std::conditional<
+ std::is_same<DstScalar, std::uint8_t>::value,
+ gemmlowp::OutputStageSaturatingCastToUint8,
+ gemmlowp::OutputStageSaturatingCastToInt16>::type;
+ OutputStageSaturatingCast saturating_cast_stage;
+
+ GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? max_num_threads
+ : 1);
+ if (mul_params.bias()) {
+ using ColVectorMap =
+ gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+ gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_add_stage;
+ bias_add_stage.bias_vector =
+ ColVectorMap(mul_params.bias(), dst->layout().rows());
+#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE
+ if (mul_params.multiplier_exponent_perchannel()) {
+ const auto& output_pipeline =
+ std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage,
+ saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point(), -rhs.zero_point(), output_pipeline);
+ } else // NOLINT[readability/braces]
+#endif
+ {
+ const auto& output_pipeline =
+ std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage,
+ saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point(), -rhs.zero_point(), output_pipeline);
+ }
+ } else {
+#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE
+ if (mul_params.multiplier_exponent_perchannel()) {
+ const auto& output_pipeline = std::make_tuple(
+ quantize_down_stage_pc, clamp_stage, saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point(), -rhs.zero_point(), output_pipeline);
+ } else // NOLINT[readability/braces]
+#endif
+ {
+ const auto& output_pipeline = std::make_tuple(
+ quantize_down_stage, clamp_stage, saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point(), -rhs.zero_point(), output_pipeline);
+ }
+ }
+}
+
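+// Packs the three storage orders into a 3-bit index used by the dispatch
+// switches below. For example, Mash(Order::kRowMajor, Order::kColMajor,
+// Order::kRowMajor) == 4 + 0 + 1 == 5.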
+inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) {
+ return (LhsOrder == Order::kRowMajor ? 4 : 0) +
+ (RhsOrder == Order::kRowMajor ? 2 : 0) +
+ (DstOrder == Order::kRowMajor ? 1 : 0);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename MulParamsType>
+void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ int index =
+ Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order());
+ switch (index) {
+#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalGemmlowp<LHS, RHS, DST>(lhs, rhs, mul_params, max_num_threads, \
+ dst);
+#define EVALGEMMLOWP_CASE2(LHS, RHS) \
+ EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALGEMMLOWP_CASE1(LHS) \
+ EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \
+ EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor)
+
+ EVALGEMMLOWP_CASE1(Order::kColMajor)
+ EVALGEMMLOWP_CASE1(Order::kRowMajor)
+
+#undef EVALGEMMLOWP_CASE1
+#undef EVALGEMMLOWP_CASE2
+#undef EVALGEMMLOWP_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <Order tOrder>
+struct EigenOrder {};
+
+template <>
+struct EigenOrder<Order::kColMajor> {
+ static constexpr int kValue = Eigen::ColMajor;
+};
+
+template <>
+struct EigenOrder<Order::kRowMajor> {
+ static constexpr int kValue = Eigen::RowMajor;
+};
+
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename MulParamsType>
+void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point(), 0);
+ RUY_CHECK_EQ(rhs.zero_point(), 0);
+ RUY_CHECK_EQ(dst->zero_point(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0);
+
+ static constexpr int kEigenLhsOrder = EigenOrder<LhsOrder>::kValue;
+ static constexpr int kEigenRhsOrder = EigenOrder<RhsOrder>::kValue;
+ static constexpr int kEigenDstOrder = EigenOrder<DstOrder>::kValue;
+
+ using EigenLhsType = typename Eigen::Matrix<LhsScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenLhsOrder>::
+ template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenRhsType = typename Eigen::Matrix<RhsScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenRhsOrder>::
+ template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenDstType = typename Eigen::Matrix<DstScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenDstOrder>::
+ template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenBiasType =
+ typename Eigen::Matrix<DstScalar, Eigen::Dynamic, 1>::ConstMapType;
+
+ EigenLhsType eigen_lhs(
+ lhs.data(), lhs.layout().rows(), lhs.layout().cols(),
+ Eigen::OuterStride<Eigen::Dynamic>(lhs.layout().stride()));
+ EigenRhsType eigen_rhs(
+ rhs.data(), rhs.layout().rows(), rhs.layout().cols(),
+ Eigen::OuterStride<Eigen::Dynamic>(rhs.layout().stride()));
+ EigenDstType eigen_dst(
+ dst->data(), dst->layout().rows(), dst->layout().cols(),
+ Eigen::OuterStride<Eigen::Dynamic>(dst->layout().stride()));
+ Eigen::setNbThreads(max_num_threads ? max_num_threads : 1);
+
+ if (mul_params.bias()) {
+ EigenBiasType eigen_bias(mul_params.bias(), dst->layout().rows());
+ if (mul_params.clamp_max() == std::numeric_limits<DstScalar>::infinity() &&
+ mul_params.clamp_min() == -std::numeric_limits<DstScalar>::infinity()) {
+ eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias;
+ } else {
+ eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias)
+ .cwiseMin(mul_params.clamp_max())
+ .cwiseMax(mul_params.clamp_min());
+ }
+ } else {
+ if (mul_params.clamp_max() == std::numeric_limits<DstScalar>::infinity() &&
+ mul_params.clamp_min() == -std::numeric_limits<DstScalar>::infinity()) {
+ eigen_dst.noalias() = eigen_lhs * eigen_rhs;
+ } else {
+ eigen_dst.noalias() = (eigen_lhs * eigen_rhs)
+ .cwiseMin(mul_params.clamp_max())
+ .cwiseMax(mul_params.clamp_min());
+ }
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename MulParamsType>
+void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ int index =
+ Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order());
+ switch (index) {
+#define EVALEIGEN_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalEigen<LHS, RHS, DST>(lhs, rhs, mul_params, max_num_threads, dst);
+#define EVALEIGEN_CASE2(LHS, RHS) \
+ EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALEIGEN_CASE1(LHS) \
+ EVALEIGEN_CASE2(LHS, Order::kColMajor) \
+ EVALEIGEN_CASE2(LHS, Order::kRowMajor)
+
+ EVALEIGEN_CASE1(Order::kColMajor)
+ EVALEIGEN_CASE1(Order::kRowMajor)
+
+#undef EVALEIGEN_CASE1
+#undef EVALEIGEN_CASE2
+#undef EVALEIGEN_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename Scalar,
+ typename MulParamsType>
+void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<Scalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point(), 0);
+ RUY_CHECK_EQ(rhs.zero_point(), 0);
+ RUY_CHECK_EQ(dst->zero_point(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0);
+
+ // Eigen::TensorMap only supports unstrided layouts
+ RUY_CHECK(IsUnstrided(lhs.layout()));
+ RUY_CHECK(IsUnstrided(rhs.layout()));
+ RUY_CHECK(IsUnstrided(dst->layout()));
+
+ using TensorLhsType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>;
+ using TensorRhsType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>;
+ using TensorDstType =
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 2, Eigen::ColMajor>>;
+ using TensorBiasType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 1, Eigen::ColMajor>>;
+
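+  // The TensorMaps above are all column-major. A row-major destination is
+  // handled by viewing its memory as the (column-major) transpose and
+  // computing the transposed product, i.e. swapping LHS and RHS.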
+ const bool tr = DstOrder == Order::kRowMajor;
+ const auto& contract_lhs = tr ? rhs : lhs;
+ const auto& contract_rhs = tr ? lhs : rhs;
+
+ TensorLhsType tensor_lhs(
+ contract_lhs.data(),
+ LhsOrder == Order::kColMajor ? contract_lhs.layout().rows()
+ : contract_lhs.layout().cols(),
+ LhsOrder == Order::kColMajor ? contract_lhs.layout().cols()
+ : contract_lhs.layout().rows());
+ TensorRhsType tensor_rhs(
+ contract_rhs.data(),
+ RhsOrder == Order::kColMajor ? contract_rhs.layout().rows()
+ : contract_rhs.layout().cols(),
+ RhsOrder == Order::kColMajor ? contract_rhs.layout().cols()
+ : contract_rhs.layout().rows());
+ TensorDstType tensor_dst(dst->data(),
+ DstOrder == Order::kColMajor ? dst->layout().rows()
+ : dst->layout().cols(),
+ DstOrder == Order::kColMajor ? dst->layout().cols()
+ : dst->layout().rows());
+ using DimPair =
+ typename Eigen::Tensor<Scalar, 1, 0, Eigen::Index>::DimensionPair;
+ Eigen::array<DimPair, 1> contract_dims(
+ {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0,
+ (RhsOrder == Order::kColMajor) ? 0 : 1)});
+ Eigen::array<int, 2> shuffle(DstOrder == Order::kColMajor ? 0 : 1,
+ DstOrder == Order::kColMajor ? 1 : 0);
+ static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1);
+ static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
+ if (mul_params.bias()) {
+ TensorBiasType tensor_bias(mul_params.bias(), dst->layout().rows());
+ Eigen::array<int, 2> bias_2d_shape(tr ? 1 : dst->layout().rows(),
+ tr ? dst->layout().rows() : 1);
+ Eigen::array<int, 2> bcast(tr ? dst->layout().cols() : 1,
+ tr ? 1 : dst->layout().cols());
+    if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() &&
+        mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) {
+      tensor_dst.device(device) =
+          tensor_lhs.contract(tensor_rhs, contract_dims) +
+          tensor_bias.reshape(bias_2d_shape).broadcast(bcast);
+ } else {
+ tensor_dst.device(device) =
+ (tensor_lhs.contract(tensor_rhs, contract_dims) +
+ tensor_bias.reshape(bias_2d_shape).broadcast(bcast))
+ .cwiseMin(mul_params.clamp_max())
+ .cwiseMax(mul_params.clamp_min());
+ }
+ } else {
+ if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() &&
+ mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) {
+ tensor_dst.device(device) =
+ tensor_lhs.contract(tensor_rhs, contract_dims);
+ } else {
+ tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims)
+ .cwiseMin(mul_params.clamp_max())
+ .cwiseMax(mul_params.clamp_min());
+ }
+ }
+}
+
+template <typename Scalar, typename MulParamsType>
+void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<Scalar>* dst) {
+ int index =
+ Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order());
+ switch (index) {
+#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalEigenTensor<LHS, RHS, DST>(lhs, rhs, mul_params, \
+ max_num_threads, dst);
+#define EVALEIGENTENSOR_CASE2(LHS, RHS) \
+ EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALEIGENTENSOR_CASE1(LHS) \
+ EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \
+ EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor)
+
+ EVALEIGENTENSOR_CASE1(Order::kColMajor)
+ EVALEIGENTENSOR_CASE1(Order::kRowMajor)
+
+#undef EVALEIGENTENSOR_CASE1
+#undef EVALEIGENTENSOR_CASE2
+#undef EVALEIGENTENSOR_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <typename Scalar>
+struct GenericBlasGemm {};
+
+template <>
+struct GenericBlasGemm<lapack::doublereal> {
+ static void Run(char* transa, char* transb, lapack::integer* m,
+ lapack::integer* n, lapack::integer* k,
+ lapack::doublereal* alpha, lapack::doublereal* a,
+ lapack::integer* lda, lapack::doublereal* b,
+ lapack::integer* ldb, lapack::doublereal* beta,
+ lapack::doublereal* c, lapack::integer* ldc) {
+ dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ }
+};
+
+template <>
+struct GenericBlasGemm<lapack::real> {
+ static void Run(char* transa, char* transb, lapack::integer* m,
+ lapack::integer* n, lapack::integer* k, lapack::real* alpha,
+ lapack::real* a, lapack::integer* lda, lapack::real* b,
+ lapack::integer* ldb, lapack::real* beta, lapack::real* c,
+ lapack::integer* ldc) {
+ sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ }
+};
+
+inline void TransposeLayout(Layout* layout) {
+ layout->set_order((layout->order() == Order::kRowMajor) ? Order::kColMajor
+ : Order::kRowMajor);
+ int tmp_rows = layout->rows();
+ layout->set_rows(layout->cols());
+ layout->set_cols(tmp_rows);
+}
+
+template <typename Scalar>
+void Transpose(Matrix<Scalar>* matrix) {
+ TransposeLayout(matrix->mutable_layout());
+}
+
+template <typename Scalar, typename MulParamsType>
+void EvalOpenBlas(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const MulParamsType& mul_params, int max_num_threads,
+ Matrix<Scalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point(), 0);
+ RUY_CHECK_EQ(rhs.zero_point(), 0);
+ RUY_CHECK_EQ(dst->zero_point(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0);
+ RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0);
+
+ Matrix<Scalar> gemm_lhs;
+ Matrix<Scalar> gemm_rhs;
+ Matrix<Scalar> gemm_dst;
+ gemm_dst = *dst;
+
+  // Use Transpose to reduce to the all-column-major case.
+  // Note that ruy::Matrix merely holds a pointer and does not own the data,
+  // so Transpose is cheap -- no actual matrix data is transposed here.
+ if (dst->layout().order() == Order::kColMajor) {
+ gemm_lhs = lhs;
+ gemm_rhs = rhs;
+ } else {
+ gemm_lhs = rhs;
+ gemm_rhs = lhs;
+ Transpose(&gemm_lhs);
+ Transpose(&gemm_rhs);
+ Transpose(&gemm_dst);
+ }
+ bool transposed_lhs = false;
+ bool transposed_rhs = false;
+
+ if (gemm_lhs.layout().order() == Order::kRowMajor) {
+ Transpose(&gemm_lhs);
+ transposed_lhs = true;
+ }
+ if (gemm_rhs.layout().order() == Order::kRowMajor) {
+ Transpose(&gemm_rhs);
+ transposed_rhs = true;
+ }
+
+ RUY_CHECK_EQ(gemm_lhs.layout().order(), Order::kColMajor);
+ RUY_CHECK_EQ(gemm_rhs.layout().order(), Order::kColMajor);
+ RUY_CHECK_EQ(gemm_dst.layout().order(), Order::kColMajor);
+
+ char transa = transposed_lhs ? 'T' : 'N';
+ char transb = transposed_rhs ? 'T' : 'N';
+ int m = gemm_lhs.layout().rows();
+ int n = gemm_rhs.layout().cols();
+ int k = gemm_lhs.layout().cols();
+  Scalar alpha = 1;
+ Scalar* a = gemm_lhs.data();
+ int lda = gemm_lhs.layout().stride();
+ Scalar* b = gemm_rhs.data();
+ int ldb = gemm_rhs.layout().stride();
+  Scalar beta = 0;
+ Scalar* c = gemm_dst.data();
+ int ldc = gemm_dst.layout().stride();
+ GenericBlasGemm<Scalar>::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b,
+ &ldb, &beta, c, &ldc);
+
+ // BLAS does not allow us to express the bias-addition and clamping, so
+ // we use Eigen for that.
+
+ using EigenDstType =
+ typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>::
+ template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenBiasType =
+ typename Eigen::Matrix<Scalar, Eigen::Dynamic, 1>::ConstMapType;
+
+ EigenDstType eigen_dst(
+ gemm_dst.data(), gemm_dst.layout().rows(), gemm_dst.layout().cols(),
+ Eigen::OuterStride<Eigen::Dynamic>(gemm_dst.layout().stride()));
+ Eigen::setNbThreads(max_num_threads ? max_num_threads : 1);
+
+ if (mul_params.bias()) {
+ EigenBiasType eigen_bias(mul_params.bias(), dst->layout().rows());
+ if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() &&
+ mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) {
+ eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias;
+ } else {
+ eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias)
+ .cwiseMin(mul_params.clamp_max())
+ .cwiseMax(mul_params.clamp_min());
+ }
+  } else {
+    if (mul_params.clamp_max() != std::numeric_limits<Scalar>::infinity() ||
+        mul_params.clamp_min() != -std::numeric_limits<Scalar>::infinity()) {
+      eigen_dst.noalias() = eigen_dst.cwiseMin(mul_params.clamp_max())
+                                .cwiseMax(mul_params.clamp_min());
+    }
+  }
+}
+
+template <typename TestSetType>
+struct SupportsGemmlowp {
+ static constexpr bool kValue =
+ std::is_same<typename TestSetType::LhsScalar, std::uint8_t>::value &&
+ std::is_same<typename TestSetType::RhsScalar, std::uint8_t>::value;
+};
+
+template <typename TestSetType>
+struct UsesSingleScalarType {
+ static constexpr bool kValue =
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::LhsScalar>::value &&
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::RhsScalar>::value &&
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::AccumScalar>::value;
+};
+
+template <typename TestSetType,
+ bool IsFloatingPoint =
+ std::is_floating_point<typename TestSetType::AccumScalar>::value,
+ bool EnableGemmlowp = SupportsGemmlowp<TestSetType>::kValue,
+ bool SingleScalarType = UsesSingleScalarType<TestSetType>::kValue>
+struct EvalExternalPathImpl {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType*, TestResult<DstScalar>*) { RUY_CHECK(false); }
+};
+
+template <typename TestSetType>
+struct EvalExternalPathImpl<TestSetType, true, false, true> {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) {
+ if (test_result->external_path == ExternalPath::kEigen) {
+ EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->mul_params, test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else if (test_result->external_path == ExternalPath::kEigenTensor) {
+ EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->mul_params, test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else if (test_result->external_path == ExternalPath::kOpenBlas) {
+ EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->mul_params, test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else {
+ RUY_CHECK(false);
+ }
+ }
+};
+
+template <typename TestSetType, bool SingleScalarType>
+struct EvalExternalPathImpl<TestSetType, false, true, SingleScalarType> {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) {
+ if (test_result->external_path == ExternalPath::kGemmlowp) {
+ EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->mul_params, test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else {
+ RUY_CHECK(false);
+ }
+ }
+};
+
+#endif // RUY_TEST_EXTERNAL_PATHS
+
+template <typename TestSetType>
+void EvalExternalPath(
+ TestSetType* test_set,
+ TestResult<typename TestSetType::DstScalar>* test_result) {
+ if (test_result->external_path == ExternalPath::kReference) {
+ // kReference is special because it's always available (the implementation
+ // is provided by ruy) and supports all cases (quantized and float).
+ ruy::ReferenceMul(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->mul_params,
+ &test_result->storage_matrix.matrix);
+ } else {
+#ifdef RUY_TEST_EXTERNAL_PATHS
+ EvalExternalPathImpl<TestSetType>::Run(test_set, test_result);
+#endif // RUY_TEST_EXTERNAL_PATHS
+ }
+}
+
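+// Returns whether two result matrices agree within a tolerance that depends
+// on the scalar type and the paths being compared: exact equality for raw
+// int32 accumulators, a depth-dependent relative tolerance for floating-point,
+// and a small absolute tolerance (up to 2) for requantized 8-bit/16-bit
+// results, depending on the platform and on whether an external library is
+// involved.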
+template <typename Scalar>
+bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1,
+ ExternalPath external_path2, const Matrix<Scalar>& matrix2,
+ int depth) {
+ RUY_CHECK_EQ(matrix1.layout().rows(), matrix2.layout().rows());
+ RUY_CHECK_EQ(matrix1.layout().cols(), matrix2.layout().cols());
+ RUY_CHECK_EQ(matrix1.zero_point(), matrix2.zero_point());
+ const int size = matrix1.layout().rows() * matrix1.layout().cols();
+ double tolerated_max_diff = 0;
+ double tolerated_mean_diff = 0;
+ const float kSmallestAllowedDifference =
+ 4. * std::numeric_limits<Scalar>::epsilon();
+ if (std::is_floating_point<Scalar>::value) {
+ double max_abs_val = 0;
+ for (int row = 0; row < matrix1.layout().rows(); row++) {
+ for (int col = 0; col < matrix1.layout().cols(); col++) {
+ max_abs_val =
+ std::max(max_abs_val,
+ std::abs(static_cast<double>(Element(matrix1, row, col))));
+ max_abs_val =
+ std::max(max_abs_val,
+ std::abs(static_cast<double>(Element(matrix2, row, col))));
+ }
+ }
+ tolerated_max_diff = max_abs_val * std::numeric_limits<Scalar>::epsilon() *
+ 64 * std::sqrt(static_cast<float>(depth));
+ if (tolerated_max_diff < kSmallestAllowedDifference) {
+ // Clamp the tolerated max diff to be a bit above machine epsilon if the
+ // calculated value is too small.
+ tolerated_max_diff = kSmallestAllowedDifference;
+ if (external_path1 == ExternalPath::kEigen ||
+ external_path2 == ExternalPath::kEigen ||
+ external_path1 == ExternalPath::kEigenTensor ||
+ external_path2 == ExternalPath::kEigenTensor) {
+ // Make additional allowance for Eigen differences.
+ tolerated_max_diff *= 10.0f;
+ }
+ }
+ tolerated_mean_diff = tolerated_max_diff / std::sqrt(size);
+ } else if (std::is_same<Scalar, std::int32_t>::value) {
+ // raw integer case, no rounding, so we can require exactness
+ tolerated_max_diff = 0;
+ tolerated_mean_diff = 0;
+ } else {
+    // Quantized case, with rounding errors from downscaling int32 accumulators
+    // to final 8-bit or 16-bit values.
+
+ if (external_path1 != ExternalPath::kNone ||
+ external_path2 != ExternalPath::kNone) {
+ // In this case, we are comparing against some other library than ruy.
+ //
+ // We may have to tolerate an error of +/- 1 from using different
+ // rounding in fixed-point multiplication, and then again an error of +/-
+ // 1 from using different rounding in right shifts, so the tolerance on
+ // the difference may have to be as large as 2.
+ tolerated_max_diff = 2;
+ } else if (RUY_PLATFORM_ARM) {
+ // All our code paths on ARM (32 and 64-bit) are bit-exact
+ // with the reference code (by design of the reference code).
+ tolerated_max_diff = 0;
+ } else if (RUY_PLATFORM_X86) {
+ // Our reference and ARM paths have diverged from x86 paths in PR #227.
+ // TODO: update the x86 path to adapt to that and reset that tolerance
+ // to 0.
+ tolerated_max_diff = 1;
+ } else {
+ // Other architectures, which we don't have dedicated code paths for
+ // at the moment.
+ // TODO: try resetting that tolerance to 0, since by
+ // definition we're only using non-optimized code paths here.
+ tolerated_max_diff = 1;
+ }
+
+ // totally empirical
+ tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.18));
+ }
+ double sum_diff = 0;
+ for (int row = 0; row < matrix1.layout().rows(); row++) {
+ for (int col = 0; col < matrix1.layout().cols(); col++) {
+ double elem1 = Element(matrix1, row, col);
+ double elem2 = Element(matrix2, row, col);
+ double diff = elem1 - elem2;
+
+ sum_diff += diff;
+      // Equivalent to (std::abs(diff) > tolerated_max_diff), except that it
+      // is also true when diff is NaN.
+ if (!(std::abs(diff) <= tolerated_max_diff)) {
+ return false;
+ }
+ }
+ }
+ double mean_diff = sum_diff / size;
+ if (std::abs(mean_diff) > tolerated_mean_diff) {
+ return false;
+ }
+ return true;
+}
+
+template <typename Scalar>
+bool Agree(ExternalPath external_path1,
+ const StorageMatrix<Scalar>& storage_matrix1,
+ ExternalPath external_path2,
+ const StorageMatrix<Scalar>& storage_matrix2, int depth) {
+ VerifyConsistentFields(storage_matrix1);
+ VerifyConsistentFields(storage_matrix2);
+ return Agree(external_path1, storage_matrix1.matrix, external_path2,
+ storage_matrix2.matrix, depth);
+}
+
+template <typename Scalar>
+bool Agree(const TestResult<Scalar>& result1, const TestResult<Scalar>& result2,
+ int depth) {
+ return Agree(result1.external_path, result1.storage_matrix,
+ result2.external_path, result2.storage_matrix, depth);
+}
+
+template <typename Scalar>
+void AddTestResultToCluster(
+ TestResult<Scalar>** result,
+ std::vector<std::vector<TestResult<Scalar>*>>& clusters, int depth) {
+ bool inserted = false;
+ for (auto& cluster : clusters) {
+ bool agreement = true;
+ // Test for agreement with every result in the cluster.
+ for (auto& other_result : cluster) {
+ agreement &= Agree(**result, *other_result, depth);
+ }
+ if (agreement) {
+ cluster.push_back(*result);
+ inserted = true;
+ }
+ }
+ if (!inserted) {
+ std::vector<TestResult<Scalar>*> new_results;
+ new_results.push_back(*result);
+ clusters.push_back(new_results);
+ }
+}
+
+template <typename Scalar>
+void PrintPathsInAgreement(
+ const std::vector<std::unique_ptr<TestResult<Scalar>>>& results,
+ int depth) {
+ // A container holding vectors of TestResults, where membership indicates
+ // that all TestResults agree with each other.
+ std::vector<std::vector<TestResult<Scalar>*>> clusters;
+ for (const auto& result : results) {
+ TestResult<Scalar>* test_result = result.get();
+ AddTestResultToCluster(&test_result, clusters, depth);
+ }
+
+ std::cerr << "Error: Not all paths agree. \n";
+ for (auto& cluster : clusters) {
+ std::cerr << "These paths all agree with each other: ";
+ for (auto& result : cluster) {
+ std::cerr << PathName(*result) << ", ";
+ }
+ std::cerr << "but disagree with the rest.\n";
+ }
+}
+
+struct Stats {
+ double median;
+ double mean;
+ double min;
+ double max;
+};
+
+inline std::string StatsAsString(const Stats& stats) {
+ char buf[256];
+ snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)",
+ stats.median, stats.mean, stats.min, stats.max);
+ return std::string(buf);
+}
+
+template <typename Scalar>
+void GetMatrixStats(const Matrix<Scalar>& matrix, Stats* stats) {
+ double min = std::numeric_limits<double>::infinity();
+ double max = -std::numeric_limits<double>::infinity();
+ double sum = 0;
+ std::vector<double> allvals;
+ for (int row = 0; row < matrix.layout().rows(); row++) {
+ for (int col = 0; col < matrix.layout().cols(); col++) {
+ double val = Element(matrix, row, col);
+ min = std::min(min, val);
+ max = std::max(max, val);
+ sum += val;
+ allvals.push_back(val);
+ }
+ }
+ std::sort(allvals.begin(), allvals.end());
+ stats->min = min;
+ stats->max = max;
+ stats->mean = sum / allvals.size();
+ stats->median = allvals[allvals.size() / 2];
+}
+
+struct ErrorAnalysis {
+ Stats stats_good;
+ Stats stats_bad;
+  // The fields below help document departures from bit exactness. They are
+  // unlikely to be relevant to floating-point results.
+ std::set<int> error_rows;
+ std::set<int> error_cols;
+ int row_of_first_error = 0;
+ int col_of_first_error = 0;
+ double first_error_good_value = 0;
+ double first_error_bad_value = 0;
+};
+
+template <typename TestSetType>
+void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index,
+ ErrorAnalysis* error_analysis) {
+ const auto& good_matrix = test_set.results[0]->storage_matrix.matrix;
+ const auto& bad_matrix =
+ test_set.results[first_bad_result_index]->storage_matrix.matrix;
+ GetMatrixStats(good_matrix, &error_analysis->stats_good);
+ GetMatrixStats(bad_matrix, &error_analysis->stats_bad);
+ bool found_first_error = false;
+ for (int row = 0; row < good_matrix.layout().rows(); row++) {
+ for (int col = 0; col < good_matrix.layout().cols(); col++) {
+ if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) {
+ if (!found_first_error) {
+ found_first_error = true;
+ error_analysis->row_of_first_error = row;
+ error_analysis->col_of_first_error = col;
+ error_analysis->first_error_good_value =
+ Element(good_matrix, row, col);
+ error_analysis->first_error_bad_value = Element(bad_matrix, row, col);
+ }
+ error_analysis->error_rows.insert(row);
+ error_analysis->error_cols.insert(col);
+ }
+ }
+ }
+}
+
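+// Chooses a multiplier such that accumulators near their maximum magnitude
+// map close to the maximum representable destination value (or 0, meaning
+// "not applicable", for float and raw-int32 destinations). For example, with
+// std::uint8_t operands and destination and a depth of 10, this yields
+// 255 / (10 * 255 * 255), i.e. 1 / 2550.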
+template <typename TestSetType>
+void ComputeReasonableMultiplier(
+ const Matrix<typename TestSetType::LhsScalar>& lhs,
+ const Matrix<typename TestSetType::RhsScalar>&, double* multiplier) {
+ using LhsScalar = typename TestSetType::LhsScalar;
+ using RhsScalar = typename TestSetType::RhsScalar;
+ using DstScalar = typename TestSetType::DstScalar;
+ if (std::is_floating_point<DstScalar>::value ||
+ std::is_same<DstScalar, std::int32_t>::value) {
+ *multiplier = 0;
+ return;
+ }
+ *multiplier = static_cast<double>(std::numeric_limits<DstScalar>::max()) /
+ (static_cast<double>(lhs.layout().cols()) *
+ std::numeric_limits<LhsScalar>::max() *
+ std::numeric_limits<RhsScalar>::max());
+}
+
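+// Decomposes a positive real multiplier into the Q0.31 fixed-point value and
+// power-of-two exponent expected by ruy's quantized paths, i.e.
+// multiplier_double ~= multiplier_fixedpoint * 2^(multiplier_exponent - 31).
+// For example, 0.75 maps to multiplier_fixedpoint = round(0.75 * 2^31) =
+// 1610612736 and multiplier_exponent = 0.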
+inline void QuantizeMultiplier(double multiplier_double,
+ std::int32_t* multiplier_fixedpoint,
+ int* multiplier_exponent) {
+ RUY_CHECK_GT(multiplier_double, 0);
+ if (multiplier_double == 0.) {
+ *multiplier_fixedpoint = 0;
+ *multiplier_exponent = 0;
+ return;
+ }
+ const double q = std::frexp(multiplier_double, multiplier_exponent);
+ auto q_fixed = static_cast<std::int64_t>(std::round(q * (1ll << 31)));
+ RUY_CHECK_LE(q_fixed, (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ ++*multiplier_exponent;
+ }
+ RUY_CHECK_LE(q_fixed, std::numeric_limits<std::int32_t>::max());
+ *multiplier_fixedpoint = static_cast<std::int32_t>(q_fixed);
+}
+
+template <typename TestSetType>
+void SwitchMultiplierToPerChannel(TestSetType* test_set) {
+ ChannelDimension channel_dimension =
+ (test_set->benchmark || global_random_engine()() % 2)
+ ? ChannelDimension::kRow
+ : ChannelDimension::kCol;
+ test_set->mul_params.set_channel_dimension(channel_dimension);
+ const int num_channels = channel_dimension == ChannelDimension::kRow
+ ? test_set->rows
+ : test_set->cols;
+ test_set->per_channel_multiplier_fixedpoint.resize(num_channels);
+ test_set->per_channel_multiplier_exponent.resize(num_channels);
+ for (int i = 0; i < num_channels; i++) {
+    // Multipliers typically lie in [2^30, 2^31 - 1].
+    // Values in [0, 2^30 - 1] are normally unused, but harmless.
+    // Thus a good way to randomize multipliers is to subtract from them
+    // a random value that is smaller than 2^30 but still significant
+    // compared to it.
+ std::int32_t nudged_multiplier =
+ test_set->mul_params.multiplier_fixedpoint() -
+ (global_random_engine()() % (1 << 26));
+ int nudged_exponent = test_set->mul_params.multiplier_exponent() - 1 +
+ (global_random_engine()() % 4);
+ test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier;
+ test_set->per_channel_multiplier_exponent[i] = nudged_exponent;
+ }
+ test_set->mul_params.set_multiplier_fixedpoint(0);
+ test_set->mul_params.set_multiplier_exponent(0);
+ test_set->mul_params.set_multiplier_fixedpoint_perchannel(
+ test_set->per_channel_multiplier_fixedpoint.data());
+ test_set->mul_params.set_multiplier_exponent_perchannel(
+ test_set->per_channel_multiplier_exponent.data());
+}
+
+template <
+ typename TestSetType,
+ bool IsApplicable =
+ std::is_same<typename TestSetType::AccumScalar, std::int32_t>::value &&
+ !std::is_same<typename TestSetType::DstScalar, std::int32_t>::value>
+struct MakeSpecMultiplierFieldsImpl {};
+
+template <typename TestSetType>
+struct MakeSpecMultiplierFieldsImpl<TestSetType, true> {
+ static void Run(TestSetType* test_set) {
+ double multiplier;
+ ComputeReasonableMultiplier<TestSetType>(test_set->lhs.matrix,
+ test_set->rhs.matrix, &multiplier);
+ typename TestSetType::AccumScalar multiplier_fixedpoint;
+ int multiplier_exponent;
+ QuantizeMultiplier(multiplier, &multiplier_fixedpoint,
+ &multiplier_exponent);
+ test_set->mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint);
+ test_set->mul_params.set_multiplier_exponent(multiplier_exponent);
+
+ if (!test_set->benchmark) {
+ test_set->perchannel = global_random_engine()() & 1;
+ }
+ if (test_set->perchannel) {
+ SwitchMultiplierToPerChannel(test_set);
+ }
+ }
+};
+
+template <typename TestSetType>
+struct MakeSpecMultiplierFieldsImpl<TestSetType, false> {
+ static void Run(TestSetType*) {}
+};
+
+template <typename MulParamsType>
+void MakeSpecClampFields(MulParamsType* mul_params) {
+ using DstScalar = typename MulParamsType::DstScalar;
+
+ if (getenv("BENCHMARK_ONLY_MATMUL")) {
+ if (std::is_floating_point<DstScalar>::value) {
+ mul_params->set_clamp_min(-std::numeric_limits<DstScalar>::infinity());
+ mul_params->set_clamp_max(std::numeric_limits<DstScalar>::infinity());
+ } else {
+ mul_params->set_clamp_min(std::numeric_limits<DstScalar>::lowest());
+ mul_params->set_clamp_max(std::numeric_limits<DstScalar>::max());
+ }
+ return;
+ }
+
+ mul_params->set_clamp_min(std::numeric_limits<DstScalar>::lowest() + 1);
+ mul_params->set_clamp_max(std::numeric_limits<DstScalar>::max() - 1);
+}
+
+inline void MakeSpecClampFields(MulParams<std::int32_t, std::int32_t>*) {
+  // When returning raw int32 accumulators, clamping is not supported.
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeZeroPoints() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kInitial);
+ if (!benchmark && !use_specified_zero_points) {
+ MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point);
+ MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point);
+ // If destination is std::int32_t, no dst_zero_point is necessary.
+ if (std::is_same<DstScalar, std::int32_t>::value) {
+ dst_zero_point = 0;
+ } else {
+ MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point);
+ }
+ }
+ life_stage = LifeStage::kHasZeroPoints;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeLhsRhs() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints);
+ MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style,
+ RandomRange::kOffCenterAvoidMinValue, &lhs);
+ MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style,
+ RandomRange::kGeneral, &rhs);
+ if (!benchmark) {
+ cache_lhs = (global_random_engine()() & 0xf) == 0;
+ cache_rhs = (global_random_engine()() & 0xf) == 0;
+ }
+ if (cache_lhs) {
+ lhs.matrix.set_cache_policy(CachePolicy::kAlwaysCache);
+ }
+ if (cache_rhs) {
+ rhs.matrix.set_cache_policy(CachePolicy::kAlwaysCache);
+ }
+ life_stage = LifeStage::kHasLhsRhs;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeMulParams() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs);
+
+ if (!getenv("BENCHMARK_ONLY_MATMUL") &&
+ (benchmark || (global_random_engine()() & 1))) {
+ MakeRandomVector(RandomRange::kBias, std::max(rows, cols), &bias_data);
+ mul_params.set_bias(bias_data.data());
+ }
+ if (lhs.matrix.zero_point() == std::numeric_limits<LhsScalar>::lowest() &&
+ rhs.matrix.zero_point() == std::numeric_limits<RhsScalar>::lowest()) {
+ lhs.matrix.set_zero_point(lhs.matrix.zero_point() + 1);
+ }
+ MakeSpecMultiplierFieldsImpl<TestSet>::Run(this);
+ MakeSpecClampFields(&mul_params);
+ life_stage = LifeStage::kHasMulParams;
+}
+
+inline int GetIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val);
+}
+
+inline float GetFloatEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stof(val);
+}
+
+inline int GetHexIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val, nullptr, 16);
+}
+
+inline bool GetBoolEnvVarOrFalse(const char* name) {
+ return static_cast<bool>(GetIntEnvVarOrZero(name));
+}
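+
+// Environment variables consulted by this test harness include THREADS
+// (thread count), PATHS (hex bitfield of ruy Paths to test), NOEXT (disable
+// external comparison paths), BENCHMARK_ONLY_MATMUL, NO_OPENBLAS, and the
+// RUY_BENCHMARK_* variables read by Benchmark() below.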
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeOtherParams() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasMulParams);
+ if (max_num_threads == 0) {
+ max_num_threads = GetIntEnvVarOrZero("THREADS");
+ }
+ life_stage = LifeStage::kHasOtherParams;
+}
+
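+// Expands a Path bitfield into a vector of its individual set bits, each as a
+// single-bit Path value. For example, a bitfield of value 0x5 yields
+// {Path(0x1), Path(0x4)}.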
+inline std::vector<Path> PathsBitfieldAsVector(Path paths_bitfield) {
+ std::vector<Path> result;
+ std::uint32_t remaining_paths = static_cast<std::uint32_t>(paths_bitfield);
+ std::uint32_t test_bit = 1;
+ while (remaining_paths) {
+ if (remaining_paths & test_bit) {
+ result.push_back(static_cast<Path>(test_bit));
+ }
+ remaining_paths &= ~test_bit;
+ test_bit <<= 1;
+ }
+ return result;
+}
+
+inline std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
+ if (benchmark) {
+ return {Tuning::kAuto};
+ }
+#if RUY_PLATFORM_ARM
+ if (path == Path::kNeon || path == Path::kNeonDotprod) {
+ return {Tuning::kA55ish, Tuning::kGeneric, Tuning::kAuto};
+ }
+#endif
+ (void)path;
+ return {Tuning::kAuto};
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeResultPaths() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams);
+
+ Path paths_bitfield = static_cast<Path>(GetHexIntEnvVarOrZero("PATHS"));
+
+ if (paths_bitfield == Path::kNone) {
+    // Use a dummy Context just to resolve the specific runtime-enabled paths.
+ Context context;
+ paths_bitfield = get_ctx(&context)->GetRuntimeEnabledPaths();
+ }
+
+ // Disable the internal test-only variants of the StandardCpp path on large
+ // tests.
+  // This constant should be large enough to exercise some interesting
+  // BlockMap logic, yet small enough to avoid large test-latency increases
+  // from running these slow code paths on large matrix multiplications.
+ const int kMaxSizeToTestInternalStandardCppVariants = 300;
+ if (rows > kMaxSizeToTestInternalStandardCppVariants ||
+ cols > kMaxSizeToTestInternalStandardCppVariants ||
+ depth > kMaxSizeToTestInternalStandardCppVariants) {
+ paths_bitfield = paths_bitfield & kAllPaths;
+ }
+
+ // Trim bits that don't correspond to a compiled path,
+ // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all
+ // those bits exist as actual paths.
+ paths_bitfield = paths_bitfield & kAllPathsIncludingInternalVariants;
+ RUY_CHECK_NE(paths_bitfield, Path::kNone);
+ paths = PathsBitfieldAsVector(paths_bitfield);
+
+ // kReference is a special 'external path' that's always available.
+ // It can still be disabled by NOEXT.
+ if (!GetBoolEnvVarOrFalse("NOEXT")) {
+ external_paths.push_back(ExternalPath::kReference);
+ }
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+
+ if (!GetBoolEnvVarOrFalse("NOEXT")) {
+ if (SupportsGemmlowp<TestSet>::kValue) {
+#ifdef GEMMLOWP_SSE4
+ const bool gemmlowp_supported =
+ !mul_params.multiplier_fixedpoint_perchannel() &&
+ mul_params.channel_dimension() == ChannelDimension::kRow;
+#else
+ const bool gemmlowp_supported =
+ mul_params.channel_dimension() == ChannelDimension::kRow;
+#endif
+ if (gemmlowp_supported) {
+ external_paths.push_back(ExternalPath::kGemmlowp);
+ }
+ }
+ if (UsesSingleScalarType<TestSet>::kValue &&
+ std::is_floating_point<AccumScalar>::value) {
+ external_paths.push_back(ExternalPath::kEigen);
+ if (layout_style == LayoutStyle::kUnstridedLinear) {
+ external_paths.push_back(ExternalPath::kEigenTensor);
+ }
+// We link against a generic BLAS target that only maps to OpenBLAS on specific
+// architectures.
+#if RUY_PLATFORM_ARM_32 || RUY_PLATFORM_ARM_64
+ // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded
+ // and multi-threaded benchmark results.
+ if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) {
+ external_paths.push_back(ExternalPath::kOpenBlas);
+ }
+#endif
+ }
+ }
+
+#endif // RUY_TEST_EXTERNAL_PATHS
+
+ for (Path path : paths) {
+ for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) {
+ results.emplace_back(new TestResultType);
+ TestResultType& result = *results.back();
+ result.path = path;
+ result.tuning = tuning;
+ MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
+ RandomRange::kGeneral, &result.storage_matrix);
+ }
+ }
+
+ for (ExternalPath external_path : external_paths) {
+ results.emplace_back(new TestResultType);
+ TestResultType& result = *results.back();
+ result.external_path = external_path;
+ MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
+ RandomRange::kGeneral, &result.storage_matrix);
+ }
+
+ life_stage = LifeStage::kHasResultPaths;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::EvalResult(
+ TestResult<DstScalar>* result) {
+ RUY_CHECK(result->path != Path::kNone ||
+ result->external_path != ExternalPath::kNone);
+ if (result->path != Path::kNone) {
+ EvalRuy(result);
+ } else {
+ EvalExternalPath(this, result);
+ }
+ const std::string& pathname = PathName(*result);
+ if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) ==
+ CoveredPaths()->end()) {
+ CoveredPaths()->push_back(pathname);
+ }
+}
+
+using f32 = float;
+using f64 = double;
+using u8 = std::uint8_t;
+using i8 = std::int8_t;
+using u16 = std::uint16_t;
+using i16 = std::int16_t;
+using u32 = std::uint32_t;
+using i32 = std::int32_t;
+using u64 = std::uint64_t;
+using i64 = std::int64_t;
+
+template <typename Scalar>
+const char* TypeName() {
+ return nullptr;
+}
+
+#define RUY_TYPENAME(TYPE) \
+ template <> \
+ inline const char* TypeName<TYPE>() { \
+ return #TYPE; \
+ }
+
+RUY_TYPENAME(f32)
+RUY_TYPENAME(f64)
+RUY_TYPENAME(u8)
+RUY_TYPENAME(i8)
+RUY_TYPENAME(u16)
+RUY_TYPENAME(i16)
+RUY_TYPENAME(u32)
+RUY_TYPENAME(i32)
+RUY_TYPENAME(u64)
+RUY_TYPENAME(i64)
+
+#undef RUY_TYPENAME
+
+template <typename Scalar>
+const char* SymmetryName(const Matrix<Scalar>& matrix) {
+ if (matrix.zero_point() == SymmetricZeroPoint<Scalar>()) {
+ return "symm";
+ } else {
+ return "asymm";
+ }
+}
+
+template <typename Scalar>
+int StorageSize(const Matrix<Scalar>& matrix) {
+ return sizeof(Scalar) * FlatSize(matrix.layout());
+}
+
+// Helper that replicates a buffer and gives out pointers to the replicas.
+// This is useful when one wants each traversal to read data that is cold in
+// the cache. With a sufficiently large num_repeats, the working set covered by
+// the replicas is guaranteed to exceed the cache size.
+template <typename T>
+class RepeatedBuffer {
+ public:
+ RepeatedBuffer() = default;
+ void Init(const T* elems, int num_elems, int num_repeats) {
+ buffers_.clear();
+ allocator_.FreeAll();
+ for (int i = 0; i < num_repeats; i++) {
+ T* p;
+ allocator_.Allocate(num_elems, &p);
+ memcpy(p, elems, num_elems * sizeof(T));
+ buffers_.push_back(p);
+ }
+ }
+ T* Next() {
+ T* ret = buffers_[current_];
+ current_ = (current_ + 1) % buffers_.size();
+ return ret;
+ }
+
+ private:
+ Allocator allocator_;
+ std::vector<T*> buffers_;
+ int current_ = 0;
+};
+
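+// Usage sketch for the cold-cache mode of Benchmark() below: Init() the
+// buffers with the original matrix data and enough replicas to exceed the
+// cache size, then call Next() before each Mul so that every iteration reads
+// data that is not already in the cache.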
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Benchmark(
+ TestResult<DstScalar>* result) {
+ const bool cold = getenv("RUY_BENCHMARK_COLD");
+ LhsScalar* orig_lhs_data = lhs.matrix.data();
+ RhsScalar* orig_rhs_data = rhs.matrix.data();
+ DstScalar* orig_dst_data = result->storage_matrix.matrix.data();
+
+ int num_matmul_sets = 0;
+
+ RepeatedBuffer<LhsScalar> cold_lhs;
+ RepeatedBuffer<RhsScalar> cold_rhs;
+ RepeatedBuffer<DstScalar> cold_dst;
+
+ if (cold) {
+ const int kWorkingSetSize = 100 << 20;
+ const int each_matmul_set_size = StorageSize(lhs.matrix) +
+ StorageSize(rhs.matrix) +
+ StorageSize(result->storage_matrix.matrix);
+ num_matmul_sets =
+ (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size;
+
+ cold_lhs.Init(lhs.matrix.data(), FlatSize(lhs.matrix.layout()),
+ num_matmul_sets);
+ cold_rhs.Init(rhs.matrix.data(), FlatSize(rhs.matrix.layout()),
+ num_matmul_sets);
+ cold_dst.Init(result->storage_matrix.matrix.data(),
+ FlatSize(result->storage_matrix.matrix.layout()),
+ num_matmul_sets);
+ }
+ const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU");
+ int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS");
+ if (!repeats) {
+ repeats = 4;
+ }
+ float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS");
+ if (!benchmark_min_secs) {
+ benchmark_min_secs = 0.5;
+ }
+#ifdef RUY_PROFILER
+ {
+ const char* lhstype = TypeName<LhsScalar>();
+ const char* lhssymm = SymmetryName(lhs.matrix);
+ const char* rhstype = TypeName<RhsScalar>();
+ const char* rhssymm = SymmetryName(rhs.matrix);
+
+ printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n",
+ PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm,
+ rhstype, rhssymm);
+ ruy::profiler::ScopeProfile profile;
+#endif
+
+ float latency = std::numeric_limits<float>::infinity();
+ float l1_refill_rate = std::numeric_limits<float>::infinity();
+ float l2_refill_rate = std::numeric_limits<float>::infinity();
+ float l3_refill_rate = std::numeric_limits<float>::infinity();
+ float l1tlb_refill_rate = std::numeric_limits<float>::infinity();
+ float l2tlb_refill_rate = std::numeric_limits<float>::infinity();
+ float mispred_rate = std::numeric_limits<float>::infinity();
+ float frontend_stall_rate = std::numeric_limits<float>::infinity();
+ float backend_stall_rate = std::numeric_limits<float>::infinity();
+
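+  // Measurement strategy: run batches of multiplications, doubling the batch
+  // size until at least benchmark_min_secs has elapsed, and report the best
+  // (minimum) per-iteration latency over the requested number of repeats.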
+ for (int repeat = 0; repeat < repeats; repeat++) {
+ auto& pmu_events = GlobalPmuEvents();
+ if (record_pmu) {
+ pmu_events.StartRecording();
+ }
+ TimePoint time_start = Now();
+ TimePoint t = time_start;
+ int iters = 0;
+ int iters_at_a_time = 1;
+ while (ToFloatSeconds(t - time_start) < benchmark_min_secs) {
+ for (int i = 0; i < iters_at_a_time; i++) {
+ if (cold) {
+ lhs.matrix.set_data(cold_lhs.Next());
+ rhs.matrix.set_data(cold_rhs.Next());
+ result->storage_matrix.matrix.set_data(cold_dst.Next());
+ }
+ EvalResult(result);
+ iters++;
+ }
+ iters_at_a_time *= 2;
+ t = Now();
+ }
+ latency = std::min(
+ latency, static_cast<float>(ToFloatSeconds(t - time_start) / iters));
+ if (record_pmu) {
+ pmu_events.StopRecording();
+ const float normalization_factor =
+ 1.0f / (static_cast<float>(iters) * rows * cols * depth);
+ l1_refill_rate = std::min(
+ l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor);
+ l2_refill_rate = std::min(
+ l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor);
+ l3_refill_rate = std::min(
+ l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor);
+ l1tlb_refill_rate =
+ std::min(l1tlb_refill_rate,
+ pmu_events.L1TLBRefillCount() * normalization_factor);
+ l2tlb_refill_rate =
+ std::min(l2tlb_refill_rate,
+ pmu_events.L2TLBRefillCount() * normalization_factor);
+ mispred_rate =
+ std::min(mispred_rate, pmu_events.BranchMispredictionCount() *
+ normalization_factor);
+ frontend_stall_rate =
+ std::min(frontend_stall_rate,
+ pmu_events.FrontendStallCount() * normalization_factor);
+ backend_stall_rate =
+ std::min(backend_stall_rate,
+ pmu_events.BackendStallCount() * normalization_factor);
+ }
+ }
+ result->latency = latency;
+ if (record_pmu) {
+ result->l1_refill_rate = l1_refill_rate;
+ result->l2_refill_rate = l2_refill_rate;
+ result->l3_refill_rate = l3_refill_rate;
+ result->l1tlb_refill_rate = l1tlb_refill_rate;
+ result->l2tlb_refill_rate = l2tlb_refill_rate;
+ result->mispred_rate = mispred_rate;
+ result->frontend_stall_rate = frontend_stall_rate;
+ result->backend_stall_rate = backend_stall_rate;
+ }
+
+#ifdef RUY_PROFILER
+ }
+ fflush(stdout);
+#endif
+
+ if (cold) {
+ lhs.matrix.set_data(orig_lhs_data);
+ rhs.matrix.set_data(orig_rhs_data);
+ memcpy(orig_dst_data, result->storage_matrix.matrix.data(),
+ StorageSize(result->storage_matrix.matrix));
+ result->storage_matrix.matrix.set_data(orig_dst_data);
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Eval() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths);
+ for (auto& result : results) {
+ if (benchmark) {
+ Benchmark(result.get());
+ } else {
+ EvalResult(result.get());
+ }
+ }
+ life_stage = LifeStage::kEvaluated;
+}
+
+template <typename Scalar>
+std::string DumpRegion(const Matrix<Scalar>& matrix, int center_row,
+ int center_col) {
+ static constexpr int kRadius = 20;
+ int first_row = std::max(0, center_row - kRadius);
+ int last_row = std::min(matrix.layout().rows() - 1, center_row + kRadius);
+ int first_col = std::max(0, center_col - kRadius);
+ int last_col = std::min(matrix.layout().cols() - 1, center_col + kRadius);
+ std::ostringstream stream;
+ for (int row = first_row; row <= last_row; row++) {
+ for (int col = first_col; col <= last_col; col++) {
+ stream << static_cast<double>(Element(matrix, row, col)) << " ";
+ }
+ stream << "\n";
+ }
+ return stream.str();
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::VerifyTestResults()
+ const {
+ const int depth = lhs.matrix.layout().cols();
+ for (int i = 0; i < static_cast<int>(results.size()) - 1; i++) {
+ if (!Agree(*results[i], *results[i + 1], depth)) {
+ PrintPathsInAgreement(results, depth);
+ ErrorAnalysis error_analysis;
+ AnalyzeTestError(*this, i + 1, &error_analysis);
+ std::cerr << "Shape: rows = " << rows << ", cols = " << cols
+ << ", depth = " << depth << std::endl;
+ std::cerr << "Stats of the good result matrix: "
+ << StatsAsString(error_analysis.stats_good) << std::endl;
+ std::cerr << "Stats of the bad result matrix: "
+ << StatsAsString(error_analysis.stats_bad) << std::endl;
+ if (static_cast<int>(error_analysis.error_rows.size()) < rows) {
+ std::cerr << "Rows containing errors: "
+ << Join(error_analysis.error_rows) << std::endl;
+ } else {
+ std::cerr << "Errors found in ALL rows." << std::endl;
+ }
+ if (static_cast<int>(error_analysis.error_cols.size()) < cols) {
+ std::cerr << "Cols containing errors: "
+ << Join(error_analysis.error_cols) << std::endl;
+ } else {
+ std::cerr << "Errors found in ALL cols." << std::endl;
+ }
+ std::cerr << "The first error occurs at row "
+ << error_analysis.row_of_first_error << ", col "
+ << error_analysis.col_of_first_error << std::endl;
+ std::cerr << "Good value: " << error_analysis.first_error_good_value
+ << std::endl;
+ std::cerr << "Bad value : " << error_analysis.first_error_bad_value
+ << std::endl;
+ std::cerr << "Region of Good result matrix around first error:\n\n"
+ << DumpRegion(results[0]->storage_matrix.matrix,
+ error_analysis.row_of_first_error,
+ error_analysis.col_of_first_error)
+ << std::endl;
+ std::cerr << "Region of Bad result matrix around first error:\n\n"
+ << DumpRegion(results[i + 1]->storage_matrix.matrix,
+ error_analysis.row_of_first_error,
+ error_analysis.col_of_first_error)
+ << std::endl;
+ RUY_CHECK(false);
+ }
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+ typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Verify() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated);
+ VerifyTestResults();
+ life_stage = LifeStage::kFinal;
+}
+
+template <typename TestSetType>
+void TestRCC(int rows, int depth, int cols) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = Order::kRowMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kUnstridedLinear;
+ test_set.Run();
+}
+
+template <typename TestSetType>
+void TestNonRCC(int rows, int depth, int cols) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = Order::kColMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kUnstridedLinear;
+ test_set.Run();
+}
+
+template <typename TestSetType>
+void TestLinearAllOrders(int rows, int depth, int cols) {
+ const std::vector<Order> orders{Order::kColMajor, Order::kRowMajor};
+
+ for (Order lhs_order : orders) {
+ for (Order rhs_order : orders) {
+ for (Order dst_order : orders) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = lhs_order;
+ test_set.rhs_order = rhs_order;
+ test_set.dst_order = dst_order;
+ test_set.layout_style = LayoutStyle::kLinear;
+ test_set.Run();
+ }
+ }
+ }
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_TEST_H_
diff --git a/ruy/test_fast.cc b/ruy/test_fast.cc
new file mode 100644
index 0000000..5598afb
--- /dev/null
+++ b/ruy/test_fast.cc
@@ -0,0 +1,109 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test contains cheap test cases; it completes in a few seconds.
+
+#include <vector>
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+
+using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>;
+
+TEST(RuyTest, TestSquareMuls) {
+ const std::vector<int> sizes{
+ // small sizes
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+      // multiples of 16
+ 16,
+ 32,
+ 48,
+ 64,
+ // pot-minus-1 sizes
+ 15,
+ 31,
+ 63,
+ // pot-plus-1 sizes
+ 17,
+ 33,
+ 65,
+ };
+
+ for (int size : sizes) {
+ TestRCC<TestSetType>(size, size, size);
+ TestLinearAllOrders<TestSetType>(size, size, size);
+ }
+}
+
+TEST(RuyTest, TestMiscMuls) {
+ const int shapes[][3] = {
+ {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17},
+ {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76},
+ {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16},
+ {16, 81, 16}, {16, 95, 16}, {3, 128, 5}};
+ for (const auto& shape : shapes) {
+ TestLinearAllOrders<TestSetType>(shape[0], shape[1], shape[2]);
+ }
+}
+
+TEST(RuyTest, TestDeepMuls) {
+  // TODO(b/137649322): clarify what the maximum allowed matrix size is.
+ TestRCC<TestSetType>(1, 32767, 1);
+ TestLinearAllOrders<TestSetType>(5, 5001, 4);
+ TestLinearAllOrders<TestSetType>(9, 1025, 10);
+}
+
+TEST(RuyTest, TestShallowMuls) {
+ TestLinearAllOrders<TestSetType>(101, 1, 103);
+ TestLinearAllOrders<TestSetType>(71, 2, 53);
+ TestLinearAllOrders<TestSetType>(51, 3, 73);
+ TestLinearAllOrders<TestSetType>(51, 4, 43);
+}
+
+TEST(RuyTest, TestNarrowMuls) {
+ for (int width : {1, 2, 3, 4, 5, 8}) {
+ TestLinearAllOrders<TestSetType>(width, 12, 13);
+ TestLinearAllOrders<TestSetType>(15, 19, width);
+ TestLinearAllOrders<TestSetType>(width, 123, 137);
+ TestLinearAllOrders<TestSetType>(158, 119, width);
+ }
+}
+
+TEST(RuyTest, TestGEMV) {
+ for (int size = 1; size < 1024; size *= 2) {
+ for (int depth = 1; depth < 500; depth += 47) {
+ TestLinearAllOrders<TestSetType>(size, depth, 1);
+ }
+ }
+ TestLinearAllOrders<TestSetType>(5, 5001, 1);
+ TestLinearAllOrders<TestSetType>(8193, 17, 1);
+}
+
+} // namespace ruy
diff --git a/ruy/test_slow.cc b/ruy/test_slow.cc
new file mode 100644
index 0000000..0ddc40b
--- /dev/null
+++ b/ruy/test_slow.cc
@@ -0,0 +1,70 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test contains more expensive test cases.
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+
+using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>;
+
+TEST(RuyTest, TestBigNarrowMuls) {
+ for (int width : {1, 2, 3, 4, 5, 8}) {
+ TestRCC<TestSetType>(width, 401, 601);
+ TestRCC<TestSetType>(587, 443, width);
+ }
+ TestRCC<TestSetType>(7, 45984,
+ 5); // Large enough to trigger row-sum overflows.
+ TestRCC<TestSetType>(512, 256, 16);
+}
+
+TEST(RuyTest, TestBigShallowMuls) {
+ TestLinearAllOrders<TestSetType>(501, 1, 321);
+ TestLinearAllOrders<TestSetType>(301, 5, 403);
+ TestLinearAllOrders<TestSetType>(256, 32, 512);
+}
+
+TEST(RuyTest, TestBigMuls) {
+ TestRCC<TestSetType>(225, 303, 199);
+ TestLinearAllOrders<TestSetType>(256, 192, 128);
+}
+
+TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) {
+ // Important to test some power-of-two depths: that's when the
+ // RUY_AVOID_ALIASING optimization kicks in and makes unstrided matrices
+ // strided, exposing bugs in kernels mixing up size and stride.
+ // Moreover, it's important that the test matrices be sufficiently wide
+ // that they will result in multiple blocks, exposing bugs in the
+ // computation of the base address of each block.
+ TestLinearAllOrders<TestSetType>(70, 1024, 80);
+ TestLinearAllOrders<TestSetType>(60, 2048, 70);
+ TestLinearAllOrders<TestSetType>(40, 4096, 50);
+}
+
+TEST(RuyTest, TestGEMV) {
+ for (int size = 1025; size <= 1409; size += 384) {
+ for (int depth = 350; depth < 500; depth += 47) {
+ TestLinearAllOrders<TestSetType>(size, depth, 1);
+ }
+ }
+}
+
+} // namespace ruy
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
new file mode 100644
index 0000000..100cfe3
--- /dev/null
+++ b/ruy/thread_pool.cc
@@ -0,0 +1,218 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/thread_pool.h"
+
+#include <atomic>
+#include <chrono> // NOLINT(build/c++11)
+#include <condition_variable> // NOLINT(build/c++11)
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "ruy/check_macros.h"
+#include "ruy/trace.h"
+#include "ruy/wait.h"
+
+namespace ruy {
+
+// A worker thread.
+class Thread {
+ public:
+ enum class State {
+ Startup, // The initial state before the thread main loop runs.
+ Ready, // Is not working, has not yet received new work to do.
+ HasWork, // Has work to do.
+ ExitAsSoonAsPossible // Should exit at earliest convenience.
+ };
+
+ explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
+ Duration spin_duration)
+ : task_(nullptr),
+ state_(State::Startup),
+ counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
+ spin_duration_(spin_duration) {
+ thread_.reset(new std::thread(ThreadFunc, this));
+ }
+
+ ~Thread() {
+ ChangeState(State::ExitAsSoonAsPossible);
+ thread_->join();
+ }
+
+  // Changes State. May be called from either the worker thread
+  // or the main thread; however, not all state transitions are legal,
+  // and illegal transitions are caught by assertions.
+ //
+ // The Task argument is to be used only with new_state==HasWork.
+ // It specifies the Task being handed to this Thread.
+ void ChangeState(State new_state, Task* task = nullptr) {
+ state_mutex_.lock();
+ State old_state = state_.load(std::memory_order_relaxed);
+ RUY_DCHECK_NE(old_state, new_state);
+ switch (old_state) {
+ case State::Startup:
+ RUY_DCHECK_EQ(new_state, State::Ready);
+ break;
+ case State::Ready:
+ RUY_DCHECK(new_state == State::HasWork ||
+ new_state == State::ExitAsSoonAsPossible);
+ break;
+ case State::HasWork:
+ RUY_DCHECK(new_state == State::Ready ||
+ new_state == State::ExitAsSoonAsPossible);
+ break;
+ default:
+ abort();
+ }
+ switch (new_state) {
+ case State::Ready:
+ if (task_) {
+ // Doing work is part of reverting to 'ready' state.
+ task_->Run();
+ task_ = nullptr;
+ }
+ break;
+ case State::HasWork:
+ RUY_DCHECK(!task_);
+ task_ = task;
+ break;
+ default:
+ break;
+ }
+ state_.store(new_state, std::memory_order_relaxed);
+ state_cond_.notify_all();
+ state_mutex_.unlock();
+ if (new_state == State::Ready) {
+ counter_to_decrement_when_ready_->DecrementCount();
+ }
+ }
+
+ static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
+
+  // Called by the main thread to give this thread work to do.
+ void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+
+ private:
+ // Thread entry point.
+ void ThreadFuncImpl() {
+ RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
+ ChangeState(State::Ready);
+
+ // Thread main loop
+ while (true) {
+ RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
+ // In the 'Ready' state, we have nothing to do but to wait until
+ // we switch to another state.
+ const auto& condition = [this]() {
+ return state_.load(std::memory_order_acquire) != State::Ready;
+ };
+ RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
+ Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
+
+ // Act on new state.
+ switch (state_.load(std::memory_order_acquire)) {
+ case State::HasWork: {
+ RUY_TRACE_SCOPE_NAME("Worker thread task");
+ // Got work to do! So do it, and then revert to 'Ready' state.
+ ChangeState(State::Ready);
+ break;
+ }
+ case State::ExitAsSoonAsPossible:
+ return;
+ default:
+ abort();
+ }
+ }
+ }
+
+ // The underlying thread.
+ std::unique_ptr<std::thread> thread_;
+
+ // The task to be worked on.
+ Task* task_;
+
+ // The condition variable and mutex guarding state changes.
+ std::condition_variable state_cond_;
+ std::mutex state_mutex_;
+
+  // The state enum tells whether we're currently working, waiting for work,
+  // etc. Its accesses in ChangeState are guarded by state_mutex_ and can
+  // thus use memory_order_relaxed. It still needs to be a std::atomic
+  // because it is also read, with acquire semantics, by the wait condition
+  // in ThreadFuncImpl (see Wait()).
+ std::atomic<State> state_;
+
+  // Pointer to the main thread's BlockingCounter object, used to notify the
+  // main thread when this thread switches to the 'Ready' state.
+ BlockingCounter* const counter_to_decrement_when_ready_;
+
+ // See ThreadPool::spin_duration_.
+ const Duration spin_duration_;
+};
+
+void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
+ RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
+ RUY_DCHECK_GE(task_count, 1);
+
+ // Case of 1 thread: just run the single task on the current thread.
+ if (task_count == 1) {
+ (tasks + 0)->Run();
+ return;
+ }
+
+ // Task #0 will be run on the current thread.
+ CreateThreads(task_count - 1);
+ counter_to_decrement_when_ready_.Reset(task_count - 1);
+ for (int i = 1; i < task_count; i++) {
+ RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
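+    // `stride` is sizeof(TaskType), as passed by the templated Execute()
+    // wrapper, so this pointer arithmetic addresses tasks[i] without needing
+    // to know the concrete task type here.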
+ auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
+ threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
+ }
+
+ RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
+ // Execute task #0 immediately on the current thread.
+ (tasks + 0)->Run();
+
+ RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
+ // Wait for the threads submitted above to finish.
+ counter_to_decrement_when_ready_.Wait(spin_duration_);
+}
+
+// Ensures that the pool has at least the given count of threads.
+// If any new thread has to be created, this function waits for it to
+// be ready.
+void ThreadPool::CreateThreads(int threads_count) {
+ RUY_DCHECK_GE(threads_count, 0);
+ unsigned int unsigned_threads_count = threads_count;
+ if (threads_.size() >= unsigned_threads_count) {
+ return;
+ }
+ counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
+ while (threads_.size() < unsigned_threads_count) {
+ threads_.push_back(
+ new Thread(&counter_to_decrement_when_ready_, spin_duration_));
+ }
+ counter_to_decrement_when_ready_.Wait(spin_duration_);
+}
+
+ThreadPool::~ThreadPool() {
+ for (auto w : threads_) {
+ delete w;
+ }
+}
+
+} // end namespace ruy
diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h
new file mode 100644
index 0000000..e3b6803
--- /dev/null
+++ b/ruy/thread_pool.h
@@ -0,0 +1,127 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0
+// license.
+
+#ifndef RUY_RUY_THREAD_POOL_H_
+#define RUY_RUY_THREAD_POOL_H_
+
+#include <vector>
+
+#include "ruy/blocking_counter.h"
+#include "ruy/time.h"
+
+namespace ruy {
+
+// A workload for a thread.
+struct Task {
+ virtual ~Task() {}
+ virtual void Run() = 0;
+};
+
+class Thread;
+
+// A simple pool of threads that only allows the very
+// specific parallelization pattern that we use here:
+// One thread, which we call the 'main thread', calls Execute, distributing
+// one Task each to N threads: N-1 'worker threads' plus the main thread
+// itself. After the main thread has completed its own Task, it waits for
+// the worker threads to have all completed. That is the only synchronization
+// performed by this ThreadPool.
+//
+// In particular, there is a naive 1:1 mapping of Tasks to threads.
+// This ThreadPool considers it outside of its own scope to try to work
+// with fewer threads than there are Tasks. The idea is that such N:M mappings
+// of tasks to threads can be implemented as a higher-level feature on top of
+// the present low-level 1:1 threadpool. For example, a user might have a
+// Task subclass referencing a shared atomic counter indexing into a vector of
+// finer-granularity subtasks. Different threads would then concurrently
+// increment this atomic counter, each getting their own subtasks to work on.
+// That approach is the one used in ruy's multi-threaded matrix multiplication
+// implementation --- see ruy's TrMulTask. A basic usage sketch of this
+// ThreadPool follows below.
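+//
+// A minimal usage sketch --- the task type and helper function here are
+// illustrative, not part of ruy's API; only Task, ThreadPool and Execute come
+// from this header (and <vector> is assumed to be included):
+//
+//   struct AddConstantTask final : ruy::Task {
+//     float* data = nullptr;
+//     int size = 0;
+//     float constant = 0;
+//     void Run() override {
+//       for (int i = 0; i < size; i++) data[i] += constant;
+//     }
+//   };
+//
+//   // Splits data[0, size) into num_tasks contiguous slices and adds
+//   // `constant` to each slice on its own thread (requires num_tasks >= 1).
+//   void AddConstantInParallel(ruy::ThreadPool* pool, float* data, int size,
+//                              float constant, int num_tasks) {
+//     std::vector<AddConstantTask> tasks(num_tasks);
+//     for (int i = 0; i < num_tasks; i++) {
+//       const int begin = (size * i) / num_tasks;
+//       const int end = (size * (i + 1)) / num_tasks;
+//       tasks[i].data = data + begin;
+//       tasks[i].size = end - begin;
+//       tasks[i].constant = constant;
+//     }
+//     pool->Execute(num_tasks, tasks.data());
+//   }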
+class ThreadPool {
+ public:
+ ThreadPool() {}
+
+ ~ThreadPool();
+
+ // Executes task_count tasks on task_count threads.
+ // Grows the threadpool as needed to have at least (task_count-1) threads.
+ // The 0-th task is run on the thread on which Execute is called: that
+ // is by definition what we call the "main thread". Synchronization of all
+ // threads is performed before this function returns.
+ //
+ // As explained in the class comment, there is a 1:1 mapping of tasks to
+ // threads. If you need something smarter than that, for instance if you
+ // want to run an unbounded number of tasks on a bounded number of threads,
+  // then you need something higher-level than this ThreadPool, which can
+  // be layered on top of it by appropriately subclassing Task.
+ //
+ // TaskType must be a subclass of ruy::Task. That is implicitly guarded by
+ // the static_cast in this inline implementation.
+ template <typename TaskType>
+ void Execute(int task_count, TaskType* tasks) {
+ ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks));
+ }
+
+ void set_spin_milliseconds(float milliseconds) {
+ spin_duration_ = DurationFromMilliseconds(milliseconds);
+ }
+
+ float spin_milliseconds() const {
+ return ToFloatMilliseconds(spin_duration_);
+ }
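+
+  // For example, a caller that wants to reduce busy-waiting (at some latency
+  // cost) might lower this from the 2 ms default, e.g.
+  //   pool.set_spin_milliseconds(0.5f);  // illustrative value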
+
+ private:
+ // Ensures that the pool has at least the given count of threads.
+ // If any new thread has to be created, this function waits for it to
+ // be ready.
+ void CreateThreads(int threads_count);
+
+ // Non-templatized implementation of the public Execute method.
+ // See the inline implementation of Execute for how this is used.
+ void ExecuteImpl(int task_count, int stride, Task* tasks);
+
+ // copy construction disallowed
+ ThreadPool(const ThreadPool&) = delete;
+
+ // The threads in this pool. They are owned by the pool:
+ // the pool creates threads and destroys them in its destructor.
+ std::vector<Thread*> threads_;
+
+ // The BlockingCounter used to wait for the threads.
+ BlockingCounter counter_to_decrement_when_ready_;
+
+  // This value was empirically derived with a microbenchmark; we don't have
+  // high confidence in it.
+  //
+  // That this value may mean waiting substantially longer than a scheduler
+  // timeslice's duration is not necessarily surprising. The idea is to pick
+  // up new work quickly after having finished the previous workload. When the
+  // new work is within the same GEMM as the previous work, the time interval
+  // during which we might be busy-waiting is very small, so for that purpose
+  // a spin duration of 1 ms would be more than enough. That is all we would
+  // observe on a GEMM benchmark. However, in a real application, after having
+  // finished a GEMM, we might do unrelated work for a little while, then
+  // start on a new GEMM. In that case the wait interval may be a little
+  // longer. There may also not be another GEMM for a long time, in which
+  // case we'll end up waiting passively.
+ Duration spin_duration_ = DurationFromMilliseconds(2);
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_THREAD_POOL_H_
diff --git a/ruy/time.h b/ruy/time.h
new file mode 100644
index 0000000..aa02245
--- /dev/null
+++ b/ruy/time.h
@@ -0,0 +1,87 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_TIME_H_
+#define RUY_RUY_TIME_H_
+
+#include <chrono> // NOLINT(build/c++11)
+#include <cstdint> // IWYU pragma: keep
+#include <ratio> // NOLINT(build/c++11)
+
+#ifdef __linux__
+#include <sys/time.h>
+// IWYU pragma: no_include <type_traits>
+
+#include <ctime>
+#endif
+
+namespace ruy {
+
+using InternalDefaultClock = std::chrono::steady_clock;
+
+using TimePoint = InternalDefaultClock::time_point;
+using Duration = InternalDefaultClock::duration;
+
+template <typename RepresentationType>
+Duration DurationFromSeconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType>(representation));
+}
+
+template <typename RepresentationType>
+Duration DurationFromMilliseconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType, std::milli>(representation));
+}
+
+template <typename RepresentationType>
+Duration DurationFromNanoseconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType, std::nano>(representation));
+}
+
+inline float ToFloatSeconds(const Duration& duration) {
+ return std::chrono::duration_cast<std::chrono::duration<float>>(duration)
+ .count();
+}
+
+inline float ToFloatMilliseconds(const Duration& duration) {
+ return std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(
+ duration)
+ .count();
+}
+
+inline std::int64_t ToInt64Nanoseconds(const Duration& duration) {
+ return std::chrono::duration_cast<
+ std::chrono::duration<std::int64_t, std::nano>>(duration)
+ .count();
+}
+
+inline TimePoint Now() { return InternalDefaultClock::now(); }
+
+inline TimePoint CoarseNow() {
+#ifdef __linux__
+ timespec t;
+ clock_gettime(CLOCK_MONOTONIC_COARSE, &t);
+ return TimePoint(
+ DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec));
+#else
+ return Now();
+#endif
+}
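+
+// Example (an illustrative sketch) of composing these helpers to time a
+// region of code:
+//
+//   const TimePoint t_start = Now();
+//   DoWork();  // hypothetical workload
+//   const Duration elapsed = Now() - t_start;
+//   const float elapsed_ms = ToFloatMilliseconds(elapsed);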
+
+} // namespace ruy
+
+#endif // RUY_RUY_TIME_H_
diff --git a/ruy/trace.h b/ruy/trace.h
new file mode 100644
index 0000000..4f5059e
--- /dev/null
+++ b/ruy/trace.h
@@ -0,0 +1,836 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_TRACE_H_
+#define RUY_RUY_TRACE_H_
+
+#ifdef RUY_TRACE
+
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+
+// Helper for `formatted` so we don't have to put .c_str() on strings.
+template <typename T>
+T value_for_snprintf(T value) {
+ return value;
+}
+
+inline const char* value_for_snprintf(const std::string& s) {
+ return s.c_str();
+}
+
+// A sprintf-like function returning a std::string.
+// Remove this once we can rely on std::format (C++20).
+template <typename... Args>
+std::string formatted(const char* format, Args... args) {
+ char buf[1024];
+#pragma GCC diagnostic push
+#pragma GCC diagnostic warning "-Wformat-security"
+ int size = snprintf(buf, sizeof buf, format, value_for_snprintf(args)...);
+#pragma GCC diagnostic pop
+ if (size <= 0) {
+ abort();
+ }
+ return std::string(buf);
+}
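+// For instance, formatted("%dx%d, %s", 3, 4, std::string("row-major"))
+// returns the std::string "3x4, row-major".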
+
+// An entry in the trace.
+struct ThreadTraceEntry final {
+ std::string text;
+ int indent = 0;
+ const char* source_file = nullptr;
+ int source_line = 0;
+};
+
+// Trace for one thread.
+class ThreadTrace final {
+ public:
+ ~ThreadTrace() {}
+
+ void set_thread_id(int thread_id) { thread_id_ = thread_id; }
+ int thread_id() const { return thread_id_; }
+
+ bool is_in_run_ahead_packing_loop() const {
+ return is_in_run_ahead_packing_loop_;
+ }
+ void set_is_in_run_ahead_packing_loop(bool value) {
+ is_in_run_ahead_packing_loop_ = value;
+ }
+
+ void set_current_source_file(const char* source_file) {
+ current_source_file_ = source_file;
+ }
+
+ void set_current_source_line(int source_line) {
+ current_source_line_ = source_line;
+ }
+
+ const std::vector<ThreadTraceEntry>& entries() const { return entries_; }
+
+ template <typename... Args>
+ void Write(const char* format, Args... args) {
+ ThreadTraceEntry entry;
+ entry.text = formatted(format, args...);
+ entry.indent = indent_;
+ entry.source_file = current_source_file_;
+ entry.source_line = current_source_line_;
+ entries_.emplace_back(std::move(entry));
+ }
+
+ template <typename... Args>
+ void EnterScope(const char* scope_name) {
+ Write("%s {", scope_name);
+ indent_++;
+ }
+ void LeaveScope(const char* scope_name) {
+ indent_--;
+ Write("} // end of %s", scope_name);
+ }
+
+ private:
+ // The trace contents
+ std::vector<ThreadTraceEntry> entries_;
+
+ // Current indentation level.
+ int indent_ = 0;
+ // Thread's ID as set by Ruy, e.g. [0,N-1]. Not OS TID.
+ int thread_id_ = -1;
+ // The run-ahead loop in `EnsurePacked` may run many iterations when the
+ // thread is waiting for a block to be packed by another thread --- it's
+ // a busy wait. We track whether we are in that mode to avoid generating
+ // many uninteresting trace entries.
+ bool is_in_run_ahead_packing_loop_ = false;
+ // Last recorded value of __FILE__ and __LINE__, as a convenience so we don't
+ // have to pass these in every call to `Write`.
+ const char* current_source_file_ = nullptr;
+ int current_source_line_ = 0;
+};
+
+// Main components of ruy. Used for trace colorization.
+enum class Component { kNone, kFrontEnd, kMiddleEnd, kBackEnd, kThreadPool };
+
+// Output format for the trace.
+enum class TraceOutputFormat { kNone, kTerminal, kHtml };
+
+inline std::string IndentString(int indent) {
+ std::string s;
+ for (int i = 0; i < indent; i++) {
+ s += " ";
+ }
+ return s;
+}
+
+// Returns the text to write to the trace to open a colored section.
+inline const char* ColorSectionStart(TraceOutputFormat output_format,
+ Component component) {
+ switch (output_format) {
+ case TraceOutputFormat::kTerminal:
+ switch (component) {
+ case Component::kFrontEnd:
+ return "\x1b[36m";
+ case Component::kMiddleEnd:
+ return "\x1b[32m";
+ case Component::kBackEnd:
+ return "\x1b[31m";
+ case Component::kThreadPool:
+ return "\x1b[33m";
+ default:
+ abort();
+ return nullptr;
+ }
+ case TraceOutputFormat::kHtml:
+ switch (component) {
+ case Component::kFrontEnd:
+ return "<span style=\"background-color:#B2EBF2\">";
+ case Component::kMiddleEnd:
+ return "<span style=\"background-color:#C8E6C9\">";
+ case Component::kBackEnd:
+ return "<span style=\"background-color:#FFCDD2\">";
+ case Component::kThreadPool:
+ return "<span style=\"background-color:#FFF9C4\">";
+ default:
+ abort();
+ return nullptr;
+ }
+ default:
+ abort();
+ return nullptr;
+ }
+}
+
+// Returns the text to write to the trace to close a colored section.
+inline const char* ColorSectionEnd(TraceOutputFormat output_format) {
+ switch (output_format) {
+ case TraceOutputFormat::kTerminal:
+ return "\x1b[0m";
+ case TraceOutputFormat::kHtml:
+ return "</span>";
+ default:
+ abort();
+ return nullptr;
+ }
+}
+
+// Returns the output format to use for the trace.
+inline TraceOutputFormat GetOutputFormat() {
+ const char* html_env = getenv("RUY_TRACE_HTML");
+ if (html_env && strtol(html_env, nullptr, 10) != 0) {
+ return TraceOutputFormat::kHtml;
+ } else {
+ return TraceOutputFormat::kTerminal;
+ }
+}
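+
+// How a trace is typically obtained (a sketch): build with RUY_TRACE defined,
+// then set environment variables at run time, e.g.
+//
+//   RUY_TRACE_HTML=1 RUY_TRACE_FILE=/tmp/ruy_trace.html ./some_ruy_binary
+//
+// where some_ruy_binary stands for any program using ruy. With RUY_TRACE_HTML
+// unset or set to 0, a terminal-colorized trace is written instead, to stdout
+// unless RUY_TRACE_FILE is set (see AllThreadTraces below).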
+
+// A `basename` function that's good enough for ruy __FILE__'s.
+// Note: `basename` is POSIX-only and annoying (takes a char*, may mutate).
+inline const char* GetBaseName(const char* path) {
+ std::size_t len = strlen(path);
+ if (len == 0) {
+ return path;
+ }
+ const char* ptr = path + len - 1;
+ while (ptr != path) {
+ if (*ptr == '/' || *ptr == '\\') {
+ return ptr + 1;
+ }
+ --ptr;
+ }
+ // Path did not contain any path separator.
+ return path;
+}
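+// For example, GetBaseName("ruy/thread_pool.cc") returns "thread_pool.cc".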
+
+// Determines a Component (used for colorization) by source file.
+inline Component GetComponentBySourceFile(const char* base_name) {
+ if (!strcmp(base_name, "pack.h") || !strcmp(base_name, "kernel.h")) {
+ return Component::kBackEnd;
+ } else if (!strcmp(base_name, "trmul.cc") ||
+ !strcmp(base_name, "block_map.cc")) {
+ return Component::kMiddleEnd;
+ } else if (!strcmp(base_name, "thread_pool.cc")) {
+ return Component::kThreadPool;
+ } else {
+ return Component::kFrontEnd;
+ }
+}
+
+inline std::string EscapeText(TraceOutputFormat output_format,
+ const std::string& text) {
+ if (output_format == TraceOutputFormat::kHtml) {
+ std::string escaped_text;
+ for (char c : text) {
+ if (c == '<') {
+ escaped_text += "&lt;";
+ } else if (c == '>') {
+ escaped_text += "&gt;";
+ } else {
+ escaped_text += c;
+ }
+ }
+ return escaped_text;
+ } else {
+ return text;
+ }
+}
+
+// Prints an entry from the trace to the destination trace file.
+inline void Print(const ThreadTraceEntry& entry,
+ TraceOutputFormat output_format, FILE* file) {
+ const char* base_name = GetBaseName(entry.source_file);
+ Component component = GetComponentBySourceFile(base_name);
+ const std::string& source_location =
+ formatted("%s:%d", base_name, entry.source_line);
+ const std::string& escaped_text = EscapeText(output_format, entry.text);
+ fprintf(file, "%s%-32s%s%s%s\n", ColorSectionStart(output_format, component),
+ source_location.c_str(), IndentString(entry.indent).c_str(),
+ escaped_text.c_str(), ColorSectionEnd(output_format));
+}
+
+// Prints a thread's entire trace to the destination trace file.
+inline void Print(const ThreadTrace& trace, TraceOutputFormat output_format,
+ FILE* file) {
+ if (output_format == TraceOutputFormat::kHtml) {
+ fprintf(file, "<html><body><pre>\n<span style=\"font-weight:bold\">\n");
+ }
+ fprintf(file, "Ruy trace for thread %d:\n", trace.thread_id());
+ if (output_format == TraceOutputFormat::kHtml) {
+ fprintf(file, "</span>\n");
+ }
+ for (const ThreadTraceEntry& entry : trace.entries()) {
+ Print(entry, output_format, file);
+ }
+ fprintf(file, "\n");
+ if (output_format == TraceOutputFormat::kHtml) {
+ fprintf(file, "</pre></body></html>\n");
+ }
+}
+
+// Holds all the threads' traces. This is a global singleton class.
+// On exit, when the singleton is destroyed, the destructor prints out the
+// traces.
+class AllThreadTraces final {
+ public:
+ // Add a new ThreadTrace for the current thread. Should be called only once
+ // on each thread.
+ ThreadTrace* AddCurrentThread() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ ThreadTrace* thread_trace = new ThreadTrace;
+ thread_traces_.emplace_back(thread_trace);
+ return thread_trace;
+ }
+ ~AllThreadTraces() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ // Open the destination file.
+ const char* file_env = getenv("RUY_TRACE_FILE");
+ FILE* file = stdout;
+ if (file_env) {
+ file = fopen(file_env, "w");
+ if (!file) {
+ fprintf(stderr, "Failed to open %s for write\n", file_env);
+ exit(1);
+ }
+ }
+ // Sort the threads by Ruy Thread ID (not OS TID).
+ auto output_format = GetOutputFormat();
+ std::sort(std::begin(thread_traces_), std::end(thread_traces_),
+ [](const auto& a, const auto& b) {
+ return a->thread_id() < b->thread_id();
+ });
+ // Print all the threads' traces.
+ for (const auto& trace : thread_traces_) {
+ Print(*trace, output_format, file);
+ }
+ if (file_env) {
+ fclose(file);
+ }
+ }
+ static AllThreadTraces* Singleton() {
+ static AllThreadTraces all_thread_traces;
+ return &all_thread_traces;
+ }
+
+ private:
+ std::vector<std::unique_ptr<ThreadTrace>> thread_traces_;
+ std::mutex mutex_;
+};
+
+// Returns the thread-local ThreadTrace singleton, constructing it as needed.
+inline ThreadTrace* ThreadLocalTrace() {
+ static thread_local ThreadTrace* thread_local_trace =
+ AllThreadTraces::Singleton()->AddCurrentThread();
+ return thread_local_trace;
+}
+
+// RAII helper to trace a scope, e.g. a function scope.
+class RuyTraceScope {
+ const char* source_file_;
+ int source_line_;
+ const char* scope_name_;
+
+ public:
+ RuyTraceScope(const char* source_file, int source_line,
+ const char* scope_name)
+ : source_file_(source_file),
+ source_line_(source_line),
+ scope_name_(scope_name) {
+ ThreadLocalTrace()->set_current_source_file(source_file_);
+ ThreadLocalTrace()->set_current_source_line(source_line_);
+ ThreadLocalTrace()->EnterScope(scope_name_);
+ }
+ ~RuyTraceScope() {
+ ThreadLocalTrace()->set_current_source_file(source_file_);
+ ThreadLocalTrace()->set_current_source_line(source_line_);
+ ThreadLocalTrace()->LeaveScope(scope_name_);
+ }
+};
+
+#define RUY_TRACE_SCOPE_NAME_IMPL(file, line, name) \
+ RuyTraceScope ruy_trace_scope##line(file, line, name)
+#define RUY_TRACE_SCOPE_NAME(name) \
+ RUY_TRACE_SCOPE_NAME_IMPL(__FILE__, __LINE__, name)
+#define RUY_TRACE_SCOPE \
+ RUY_TRACE_SCOPE_NAME_IMPL(__FILE__, __LINE__, __FUNCTION__)
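+
+// Usage sketch (with an illustrative function name): placing RUY_TRACE_SCOPE
+// at the top of a function records an indented
+// "FunctionName { ... } // end of FunctionName" section in the calling
+// thread's trace.
+//
+//   void SomeRuyInternalFunction() {
+//     RUY_TRACE_SCOPE;
+//     // ... traced work ...
+//   }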
+
+// Helpers to trace Ruy objects.
+
+inline std::string str(Order o) {
+ return o == Order::kRowMajor ? "row-major" : "column-major";
+}
+
+inline std::string str(Side s) { return s == Side::kLhs ? "LHS" : "RHS"; }
+
+inline std::string str(const Layout& layout) {
+ std::string s =
+ formatted("%dx%d, %s", layout.rows(), layout.cols(), str(layout.order()));
+ int inner_size =
+ layout.order() == Order::kRowMajor ? layout.cols() : layout.rows();
+ if (inner_size != layout.stride()) {
+ s += formatted(", stride=%d", layout.stride());
+ } else {
+ s += formatted(", unstrided");
+ }
+ return s;
+}
+
+inline std::string str(const MatLayout& layout) {
+ std::string s =
+ formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
+ int inner_size = layout.order == Order::kRowMajor ? layout.cols : layout.rows;
+ if (inner_size != layout.stride) {
+ s += formatted(", stride=%d", layout.stride);
+ } else {
+ s += formatted(", unstrided");
+ }
+ return s;
+}
+
+inline std::string str(const PMatLayout& layout) {
+ std::string s =
+ formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
+ int inner_size = layout.order == Order::kRowMajor ? layout.cols : layout.rows;
+ if (inner_size != layout.stride) {
+ s += formatted(", stride=%d", layout.stride);
+ } else {
+ s += formatted(", unstrided");
+ }
+ s += formatted(", kernel blocks: %dx%d %s", layout.kernel.rows,
+ layout.kernel.cols, str(layout.kernel.order));
+ return s;
+}
+
+template <typename T>
+std::string str() {
+ return "<unknown type>";
+}
+#define RUY_IMPL_STR_TYPE_STD(T) \
+ template <> \
+ inline std::string str<std::T>() { \
+ return #T; \
+ }
+#define RUY_IMPL_STR_TYPE(T) \
+ template <> \
+ inline std::string str<T>() { \
+ return #T; \
+ }
+
+RUY_IMPL_STR_TYPE(float)
+RUY_IMPL_STR_TYPE(double)
+RUY_IMPL_STR_TYPE_STD(int8_t)
+RUY_IMPL_STR_TYPE_STD(uint8_t)
+RUY_IMPL_STR_TYPE_STD(int16_t)
+RUY_IMPL_STR_TYPE_STD(uint16_t)
+RUY_IMPL_STR_TYPE_STD(int32_t)
+RUY_IMPL_STR_TYPE_STD(uint32_t)
+RUY_IMPL_STR_TYPE_STD(int64_t)
+RUY_IMPL_STR_TYPE_STD(uint64_t)
+
+template <typename T>
+std::string str(const Matrix<T>& matrix) {
+ std::string s = formatted("Matrix<%s>, %s", str<T>(), str(matrix.layout()));
+ if (matrix.zero_point()) {
+ s += formatted(", zero_point=%d", static_cast<int>(matrix.zero_point()));
+ }
+ if (matrix.cache_policy() != CachePolicy::kNeverCache) {
+ s +=
+ formatted(", cache_policy=%d", static_cast<int>(matrix.cache_policy()));
+ }
+ return s;
+}
+
+inline std::string str(const Type& type) {
+ char c;
+ if (type.is_floating_point) {
+ c = 'f';
+ } else if (type.is_signed) {
+ c = 'i';
+ } else {
+ c = 'u';
+ }
+ return formatted("%c%d", c, type.size * 8);
+}
+
+inline std::string str(const EMat& mat) {
+ std::string s =
+ formatted("EMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
+ if (mat.zero_point) {
+ s += formatted(", zero_point=%d", static_cast<int>(mat.zero_point));
+ }
+ if (mat.cache_policy != CachePolicy::kNeverCache) {
+ s += formatted(", cache_policy=%d", static_cast<int>(mat.cache_policy));
+ }
+ return s;
+}
+
+inline std::string str(const PEMat& mat) {
+ std::string s =
+ formatted("PEMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
+ if (mat.zero_point) {
+ s += formatted(", zero_point=%d", static_cast<int>(mat.zero_point));
+ }
+ return s;
+}
+
+inline std::string str(Path paths) {
+ bool first = true;
+ std::string s;
+ for (int bit = 0; bit < 16; bit++) {
+ Path cur_path = static_cast<Path>(1 << bit);
+ if ((paths & cur_path) != Path::kNone) {
+ if (!first) {
+ s += " | ";
+ }
+ first = false;
+ switch (cur_path) {
+ case Path::kNone:
+ continue;
+#define RUY_HANDLE_PATH(p) \
+ case Path::p: \
+ s += #p; \
+ break;
+ RUY_HANDLE_PATH(kStandardCpp)
+ RUY_HANDLE_PATH(kInternalStandardCppVariant1)
+ RUY_HANDLE_PATH(kInternalStandardCppVariant2)
+ RUY_HANDLE_PATH(kInternalStandardCppVariant3)
+#if RUY_PLATFORM_ARM
+ RUY_HANDLE_PATH(kNeon)
+ RUY_HANDLE_PATH(kNeonDotprod)
+#endif // RUY_PLATFORM_ARM
+#if RUY_PLATFORM_X86
+ RUY_HANDLE_PATH(kAvx)
+ RUY_HANDLE_PATH(kAvx2Fma)
+ RUY_HANDLE_PATH(kAvx512)
+#endif // RUY_PLATFORM_X86
+#undef RUY_HANDLE_PATH
+ default:
+ fprintf(stderr, "Unhandled Path value 0x%x\n",
+ static_cast<int>(cur_path));
+ abort();
+ }
+ }
+ }
+ return s;
+}
+
+// Implementation of RUY_TRACE_INFO(X) macros.
+
+#define RUY_TRACE_INFO_MUL \
+ ThreadLocalTrace()->Write("CompiledPaths: %s", str(CompiledPaths)); \
+ ThreadLocalTrace()->Write("LHS: %s", str(lhs)); \
+ ThreadLocalTrace()->Write("RHS: %s", str(rhs)); \
+ ThreadLocalTrace()->Write("Destination: %s", str(*dst));
+
+#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_TRANSPOSING \
+ ThreadLocalTrace()->Write("Canonicalizing to column-major destination:"); \
+ ThreadLocalTrace()->Write( \
+ "Swapping LHS<->RHS and flipping all storage orders.");
+
+#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST \
+ ThreadLocalTrace()->Write("Runtime-selected path: %s", str(the_path)); \
+ ThreadLocalTrace()->Write("LHS: %s", str(params->src[Side::kLhs])); \
+ ThreadLocalTrace()->Write("RHS: %s", str(params->src[Side::kRhs])); \
+ ThreadLocalTrace()->Write("Destination: %s", str(params->dst));
+
+#define RUY_TRACE_INFO_POPULATE_TRMUL_PARAMS \
+ ThreadLocalTrace()->Write( \
+ "Here we have this Path as a template parameter: %s", str(ThePath)); \
+ ThreadLocalTrace()->Write("PackedLhsScalar: %s", str<PackedLhsScalar>()); \
+ ThreadLocalTrace()->Write("PackedRhsScalar: %s", str<PackedRhsScalar>()); \
+ ThreadLocalTrace()->Write("Kernel function pointer: %p", \
+ params->run_kernel); \
+ ThreadLocalTrace()->Write("Kernel LHS block layout: %dx%d %s", \
+ LhsKernelLayout::kRows, LhsKernelLayout::kCols, \
+ str(LhsKernelLayout::kOrder)); \
+ ThreadLocalTrace()->Write("Kernel RHS block layout: %dx%d %s", \
+ RhsKernelLayout::kRows, RhsKernelLayout::kCols, \
+ str(RhsKernelLayout::kOrder)); \
+ ThreadLocalTrace()->Write("Created packed matrices:"); \
+ ThreadLocalTrace()->Write("Packed LHS matrix: %s", \
+ str(params->packed_matrix[Side::kLhs])); \
+ ThreadLocalTrace()->Write("Packed RHS matrix: %s", \
+ str(params->packed_matrix[Side::kRhs])); \
+ ThreadLocalTrace()->Write("LHS packing function pointer: %p", \
+ params->run_pack[Side::kLhs]); \
+ ThreadLocalTrace()->Write("RHS packing function pointer: %p", \
+ params->run_pack[Side::kRhs]);
+
+#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE \
+ ThreadLocalTrace()->Write("SetRuntimeEnabledPaths forcing paths: %s", \
+ str(*paths));
+
+#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR \
+ ThreadLocalTrace()->Write("Environment variable forcing paths: %s", \
+ str(*paths));
+
+#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_DETECTION \
+ ThreadLocalTrace()->Write( \
+ "Runtime-detected paths: %s", \
+ str(*paths & ~kNonArchPathsIncludingInternalVariants));
+
+#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_SHOULD_CACHE \
+ ThreadLocalTrace()->Write( \
+ "Caching the packed %s matrix. Already in cache: %s", str(side), \
+ action == PrepackedCache::Action::kInsertedNewEntry ? "no" : "yes");
+
+#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_NO_CACHE \
+ ThreadLocalTrace()->Write("Not caching the packed %s matrix.", str(side));
+
+#define RUY_TRACE_INFO_GET_TENTATIVE_THREAD_COUNT \
+ ThreadLocalTrace()->Write( \
+ "tentative_thread_count=%d (determined based on shape %dx%dx%d and " \
+ "max_num_threads=%d)", \
+ tentative_thread_count, rows, depth, cols, ctx->max_num_threads());
+
+#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_TRUE \
+ ThreadLocalTrace()->Write( \
+ "Choosing to use the simple loop code path in TrMul because of the " \
+ "linear traversal and single thread.");
+
+#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_FALSE \
+ ThreadLocalTrace()->Write( \
+ "Choosing to use the general case code path in TrMul because of: %s", \
+ tentative_thread_count > 1 ? "multi-threading" \
+ : "non-linear traversal order");
+
+#define RUY_TRACE_INFO_TRMUL_SIMPLE_LOOP \
+ ThreadLocalTrace()->Write("Entering the simple loop code path of TrMul");
+
+#define RUY_TRACE_INFO_TRMUL_GENERAL_CASE \
+ ThreadLocalTrace()->Write("Entering the general case code path of TrMul");
+
+#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_START \
+ ThreadLocalTrace()->Write("Kernel block: %dx%d", kernel_rows, kernel_cols); \
+ ThreadLocalTrace()->Write( \
+ "BlockMap shape: %dx%d (destination matrix shape rounded to next " \
+ "kernel blocks)", \
+ rows, cols); \
+ ThreadLocalTrace()->Write( \
+ "Rectangularness log2: %dx%d (powers of two factors bringing the shape " \
+ "closest to square)", \
+ rows_rectangularness_log2, cols_rectangularness_log2); \
+ ThreadLocalTrace()->Write("Accumulation depth: %d", depth); \
+ ThreadLocalTrace()->Write("LHS scalar type size: %d", lhs_scalar_size); \
+ ThreadLocalTrace()->Write("RHS scalar type size: %d", rhs_scalar_size); \
+ ThreadLocalTrace()->Write("Tentative thread count: %d", \
+ tentative_thread_count); \
+ ThreadLocalTrace()->Write( \
+ "CPU cache params: local_cache_size=%d, last_level_cache_size=%d", \
+ cpu_cache_params.local_cache_size, \
+ cpu_cache_params.last_level_cache_size); \
+ ThreadLocalTrace()->Write( \
+ "For the sizes below, when rows!=cols, we always retain the min of the " \
+ "two."); \
+ ThreadLocalTrace()->Write("Kernel block size_log2: %d", kernel_size_log2); \
+ ThreadLocalTrace()->Write( \
+ "BlockMap size_log2: %d (destination matrix shape rounded to next " \
+ "kernel blocks)", \
+ size_log2); \
+ ThreadLocalTrace()->Write( \
+ "Now we will pick the optimal log2 of BlockMap block size");
+
+#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE \
+ ThreadLocalTrace()->Write( \
+ "For BlockMap block size_log2 %d, score=%d (" \
+ "multithreading_score=%d + cache_locality_score=%d + " \
+ "kernel_amortization_score=%d)", \
+ block_size_log2, score, multithreading_score, cache_locality_score, \
+ kernel_amortization_score);
+
+#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_END \
+ ThreadLocalTrace()->Write("Selecting BlockMap block size_log2: %d", \
+ best_score_block_size_log2); \
+ ThreadLocalTrace()->Write( \
+ "BlockMap has %dx%d blocks, each of size between %dx%d and %dx%d.", \
+ 1 << num_blocks_of_rows_log2, 1 << num_blocks_of_cols_log2, \
+ block_map->small_block_dims[Side::kLhs], \
+ block_map->small_block_dims[Side::kRhs], \
+ block_map->small_block_dims[Side::kLhs] + \
+ block_map->kernel_dims[Side::kLhs], \
+ block_map->small_block_dims[Side::kRhs] + \
+ block_map->kernel_dims[Side::kRhs]); \
+ ThreadLocalTrace()->Write( \
+ "The first %d rows of blocks have %d rows, the remaining ones have %d " \
+ "rows ", \
+ block_map->large_blocks[Side::kLhs], \
+ block_map->small_block_dims[Side::kLhs] + \
+ block_map->kernel_dims[Side::kLhs], \
+ block_map->small_block_dims[Side::kLhs]); \
+ ThreadLocalTrace()->Write( \
+ "The first %d columns of blocks have %d columns, the remaining ones " \
+ "have %d columns ", \
+ block_map->large_blocks[Side::kRhs], \
+ block_map->small_block_dims[Side::kRhs] + \
+          block_map->kernel_dims[Side::kRhs],                                \
+ block_map->small_block_dims[Side::kRhs]); \
+ ThreadLocalTrace()->Write( \
+ "Traversal order: %s", \
+ block_map->traversal_order == BlockMapTraversalOrder::kLinear ? "linear" \
+ : block_map->traversal_order == BlockMapTraversalOrder::kFractalZ \
+ ? "fractal Z-curve" \
+ : block_map->traversal_order == BlockMapTraversalOrder::kFractalU \
+ ? "fractal U-curve" \
+ : block_map->traversal_order == BlockMapTraversalOrder::kFractalHilbert \
+ ? "fractal Hilbert curve" \
+ : nullptr); \
+ ThreadLocalTrace()->Write("Finalized thread count: %d", \
+ block_map->thread_count);
+
+#define RUY_TRACE_SET_THEAD_ID(thread_id) \
+ ThreadLocalTrace()->set_thread_id(thread_id);
+
+#define RUY_TRACE_INFO_TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS \
+ ThreadLocalTrace()->Write( \
+ "Block #%d is at position (%d, %d) in the BlockMap.", block_id, \
+ block[Side::kLhs], block[Side::kRhs]); \
+ ThreadLocalTrace()->Write( \
+ "Block #%d has shape %dx%d and starts at position (%d, %d) in the " \
+ "destination matrix.", \
+ block_id, end[Side::kLhs] - start[Side::kLhs], \
+ end[Side::kRhs] - start[Side::kRhs], start[Side::kLhs], \
+ start[Side::kRhs]); \
+ ThreadLocalTrace()->Write( \
+ "Block #%d depends on LHS panel #%d and RHS panel #%d.", block_id, \
+ block[Side::kLhs], block[Side::kRhs]);
+
+#define RUY_TRACE_INFO_TRYPACK_PACKING \
+ ThreadLocalTrace()->Write( \
+ "%s panel #%d is not already packed. Packing it now.", str(side), \
+ block);
+
+#define RUY_TRACE_INFO_TRYPACK_ANOTHER_THREAD_PACKING \
+ if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \
+ ThreadLocalTrace()->Write( \
+ "%s panel #%d is currently being packed by another thread.", \
+ str(side), block); \
+ }
+
+#define RUY_TRACE_INFO_TRYPACK_PREVIOUSLY_PACKED \
+ if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \
+ ThreadLocalTrace()->Write("%s panel #%d had previously been packed.", \
+ str(side), block); \
+ }
+
+#define RUY_TRACE_INFO_TRYPACK_PACKED_BY_ANOTHER_THREAD \
+ ThreadLocalTrace()->Write( \
+ "%s panel #%d has just been packed by another thread.", str(side), \
+ block);
+
+#define RUY_TRACE_INFO_ENSURE_PACKED_ENTER_RUN_AHEAD \
+ if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \
+ ThreadLocalTrace()->set_is_in_run_ahead_packing_loop(true); \
+ ThreadLocalTrace()->Write( \
+ "We're blocked on other threads packing the panels that we need. " \
+ "Packing some other panels while we wait..."); \
+ }
+
+#define RUY_TRACE_INFO_ENSURE_PACKED_END \
+ if (ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \
+ ThreadLocalTrace()->set_is_in_run_ahead_packing_loop(false); \
+ ThreadLocalTrace()->Write( \
+ "Other threads have finished packing what we were waiting for."); \
+ }
+
+#define RUY_TRACE_INFO_RUN_PACK \
+ ThreadLocalTrace()->Write("Path: %s", str(ThePath)); \
+ ThreadLocalTrace()->Write("Packing panel consisting of columns [%d, %d)", \
+ start_col, end_col); \
+ ThreadLocalTrace()->Write("Source: columns [%d, %d) of %s", start_col, \
+ end_col, str(src_matrix)); \
+ ThreadLocalTrace()->Write("Destination: columns [%d, %d) of %s", start_col, \
+ end_col, str(*packed_matrix)); \
+ if (end_col > src_matrix.layout.cols) { \
+ ThreadLocalTrace()->Write( \
+ "This runs past the last column of the source matrix. Padding as " \
+ "needed."); \
+ } \
+ if (packed_matrix->layout.rows > src_matrix.layout.rows) { \
+ ThreadLocalTrace()->Write( \
+ "The packed matrix has more rows than the source matrix due to " \
+ "rounding up to the kernel block size. Padding as needed."); \
+ }
+
+#define RUY_TRACE_INFO_RUN_KERNEL \
+ { \
+ ThreadLocalTrace()->Write("Path: %s", str(KernelArgs<KernelType>::kPath)); \
+ int lhs_cols = end[Side::kLhs] - start[Side::kLhs]; \
+ int rhs_cols = end[Side::kRhs] - start[Side::kRhs]; \
+ int kernel_lhs_cols = src[Side::kLhs].layout.kernel.cols; \
+ int kernel_rhs_cols = src[Side::kRhs].layout.kernel.cols; \
+ ThreadLocalTrace()->Write("LHS: columns [%d, %d) of %s", \
+ start[Side::kLhs], end[Side::kLhs], \
+ str(src[Side::kLhs])); \
+ ThreadLocalTrace()->Write("RHS: columns [%d, %d) of %s", \
+ start[Side::kRhs], end[Side::kRhs], \
+ str(src[Side::kRhs])); \
+ ThreadLocalTrace()->Write("Destination: block [%d, %d)x[%d, %d) of %s", \
+ start[Side::kLhs], end[Side::kLhs], \
+ start[Side::kRhs], end[Side::kRhs], str(*dst)); \
+ if (end[Side::kLhs] > dst->layout.rows || \
+ end[Side::kRhs] > dst->layout.cols) { \
+ ThreadLocalTrace()->Write( \
+ "This runs over the destination matrix boundaries. The kernel will " \
+ "internally clamp stores to avoid overruns."); \
+ } \
+ ThreadLocalTrace()->Write( \
+ "The kernel's inner loop only produces a %dx%d block, so the " \
+ "kernel's outer loops will run %dx%d times.", \
+ kernel_lhs_cols, kernel_rhs_cols, lhs_cols / kernel_lhs_cols, \
+ rhs_cols / kernel_rhs_cols); \
+ }
+
+#define RUY_TRACE_INFO_THREAD_FUNC_IMPL_WAITING \
+ ThreadLocalTrace()->Write("Waiting for a task...");
+
+#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK \
+ ThreadLocalTrace()->Write("Sending task #%d to a worker thread...", i);
+
+#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD \
+ ThreadLocalTrace()->Write("Running task #0 on the current thread...");
+
+#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_WAITING_FOR_THREADS \
+ ThreadLocalTrace()->Write("Waiting for worker threads to finish..");
+
+#define RUY_TRACE_INFO(id) \
+ [=]() { \
+ ThreadLocalTrace()->set_current_source_file(__FILE__); \
+ ThreadLocalTrace()->set_current_source_line(__LINE__); \
+ RUY_TRACE_INFO_##id \
+ }()
+
+} // namespace ruy
+
+#else
+
+// Vacuous implementation when RUY_TRACE is not defined.
+#define RUY_TRACE_SCOPE_NAME(name)
+#define RUY_TRACE_SCOPE
+#define RUY_TRACE_SET_THEAD_ID(thread_id)
+#define RUY_TRACE_INFO(id)
+
+#endif
+
+#endif // RUY_RUY_TRACE_H_
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
new file mode 100644
index 0000000..9345f0c
--- /dev/null
+++ b/ruy/trmul.cc
@@ -0,0 +1,397 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The 'middle-end' in ruy. See TrMul function comment.
+
+#include "ruy/trmul.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/block_map.h"
+#include "ruy/check_macros.h"
+#include "ruy/cpu_cache_params.h"
+#include "ruy/cpuinfo.h"
+#include "ruy/ctx.h"
+#include "ruy/mat.h"
+#include "ruy/matrix.h"
+#include "ruy/mul_params.h"
+#include "ruy/opt_set.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/thread_pool.h"
+#include "ruy/trace.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+namespace {
+
+// Enum to track the packing status of a block of the LHS or RHS matrix.
+enum class PackingStatus : std::uint8_t {
+ kNotStarted, // No thread has started packing this block yet.
+ kInProgress, // Some thread is currently packing this block.
+ kFinished // This block has already been packed.
+};
+
+// TrMulTask is the task that a ruy thread runs to perform the TrMul operation.
+class TrMulTask final : public Task {
+ public:
+ TrMulTask(TrMulParams* params, const BlockMap& block_map,
+ std::atomic<int>* atomic_block_id, int thread_id, bool need_atomics,
+ SidePair<std::atomic<PackingStatus>*> packing_status,
+ TuningResolver* tuning_resolver, Allocator* local_allocator,
+ CpuInfo* cpuinfo)
+ : params_(params),
+ block_map_(block_map),
+ atomic_block_id_(atomic_block_id),
+ thread_id_(thread_id),
+ need_atomics_(need_atomics),
+ packing_status_(packing_status),
+ tuning_resolver_(tuning_resolver),
+ local_allocator_(local_allocator),
+ local_already_packed_{nullptr, nullptr},
+ cpuinfo_(cpuinfo) {}
+
+ // Thread main function. This is one thread's share of the TrMul work.
+ void Run() override {
+ RUY_TRACE_SCOPE_NAME("TrMulTask::Run");
+ RUY_TRACE_SET_THEAD_ID(thread_id_);
+ // Allocate and initialize `local_packed`.
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params_->is_prepacked[side]) {
+ const int size = NumBlocksPerSide(side, block_map_);
+ local_allocator_->Allocate(size, &local_already_packed_[side]);
+ memset(local_already_packed_[side], 0, size * sizeof(bool));
+ }
+ }
+
+ const Tuning tuning = tuning_resolver_->Resolve(cpuinfo_);
+ const int num_blocks = NumBlocks(block_map_);
+
+    // Each thread initially reserves the block whose id is the thread id.
+ int block_id = thread_id_;
+ // Loop until all blocks have been computed.
+ while (block_id < num_blocks) {
+ RUY_TRACE_SCOPE_NAME("Main loop iteration");
+ // Reserve the next block to handle, hiding the latency of this atomic op.
+ const int next_block_id =
+ atomic_block_id_->fetch_add(1, std::memory_order_relaxed);
+ // Get coordinates of the current block to handle, in "block space".
+ SidePair<int> block;
+ GetBlockByIndex(block_map_, block_id, &block);
+ // Get coordinates of the current block to handle, in matrix space.
+ SidePair<int> start, end;
+ GetBlockMatrixCoords(block_map_, block, &start, &end);
+ RUY_TRACE_INFO(TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS);
+ // Maybe pack the current LHS/RHS block, if not already packed.
+ EnsurePacked(block, start, end, tuning);
+ // Actually do matrix multiplication work
+ params_->RunKernel(tuning, start, end);
+ // Move on to the next block as obtained by the atomic increment
+ // at the start of this while loop iteration.
+ block_id = next_block_id;
+ }
+
+ local_allocator_->FreeAll();
+ }
+
+ private:
+ // Tries to pack a block, without blocking.
+ // If the block was already packed, returns true.
+  // If packing of the block had not yet started, packs it and returns true.
+  // If the block is currently being packed by another thread, returns false.
+ bool TryPack(Side side, int block, int start, int end, Tuning tuning) {
+ if (params_->is_prepacked[side]) {
+ return true;
+ }
+ if (!local_already_packed_[side][block]) {
+ if (need_atomics_) {
+ // Explanation of this compare_exchange_strong operation:
+ // This atomically performs all of the following:
+ // 1. Read `status` with "acquire" memory order.
+ // * That this read uses "acquire" is because both memory orders
+ // specified have "acquire" as their read-component.
+ // 2. Compare (bitwise) with `exchanged_status`.
+ // 3. If equal, stores the value kInProgress to `status` with "release"
+ // memory order, and returns true, so we take this 'if' branch.
+ // * That this store uses "release" is because of the _rel part in
+ // memory_order_acq_rel passed as the first memory order argument.
+ // 4. If not equal, stores the loaded value of `status` to
+ // `exchanged_status` with "relaxed" semantics, and returns false,
+ // so we take the 'else' branch.
+ // * That this store uses "relaxed" is because the second memory
+ // order argument, memory_order_acquire, implies no particular
+ // store semantics. "relaxed" is acceptable here because this
+ // stores to a local stack variable.
+ //
+ // Rationale for compare_exchange_strong as opposed to
+ // compare_exchange_weak:
+ // The spurious-failure case with compare_exchange_weak will actually
+ // happen a lot here, because the atomic 'status' bytes are stored
+ // contiguously in arrays and neighboring values will be accessed
+ // by multiple threads concurrently. On a typical ARM CPU, an exclusives
+ // reservation granule is 64 bytes, so a lot of false-sharing may
+ // happen. Using compare_exchange_weak would thus result in often having
+ // TryPack return 'false' when it could instead have done the packing
+ // work and returned 'true'. Heuristically, that is not a good thing.
+ // Moreover, this changes the TryPack contract, loosening it and making
+ // it harder for the caller to reason about. Finally, the overhead of
+ // atomic operations is mitigated by the enclosing check on
+ // local_already_packed, so maybe the overhead of
+ // compare_exchange_strong isn't such a problem. But we don't really
+      //   know for sure; that would be interesting to experiment with more.
+ PackingStatus exchanged_status = PackingStatus::kNotStarted;
+ std::atomic<PackingStatus>& status = packing_status_[side][block];
+ if (status.compare_exchange_strong(
+ exchanged_status, PackingStatus::kInProgress,
+ std::memory_order_acq_rel, std::memory_order_acquire)) {
+ // In this branch, the status was kNotStarted and we just atomically
+ // changed it to kInProgress as we are about to handle the packing
+ // ourselves.
+ RUY_TRACE_INFO(TRYPACK_PACKING);
+ params_->RunPack(side, tuning, start, end);
+ status.store(PackingStatus::kFinished, std::memory_order_release);
+ } else if (exchanged_status == PackingStatus::kInProgress) {
+ // Another thread is currently packing this block.
+ RUY_TRACE_INFO(TRYPACK_ANOTHER_THREAD_PACKING);
+ return false;
+ } else {
+ RUY_TRACE_INFO(TRYPACK_PACKED_BY_ANOTHER_THREAD);
+ }
+ RUY_DCHECK(status.load(std::memory_order_acquire) ==
+ PackingStatus::kFinished);
+ } else {
+ // Single-threaded case: no need for expensive atomics,
+ // local_already_packed is the truth already.
+ params_->RunPack(side, tuning, start, end);
+ }
+ local_already_packed_[side][block] = true;
+ } else {
+ RUY_TRACE_INFO(TRYPACK_PREVIOUSLY_PACKED);
+ }
+ return true;
+ }
+
+ // Ensures that both the LHS and RHS blocks required by the specified block
+ // are packed. In the event that they are already being packed on another
+ // thread, this function may perform the packing of some other block while
+ // waiting for that other thread to finish packing the requested block.
+ void EnsurePacked(const SidePair<int>& block, const SidePair<int>& start,
+ const SidePair<int>& end, Tuning tuning) {
+#if RUY_OPT(PACK_AHEAD)
+ SidePair<int> next_runahead_block{block[Side::kLhs] + 1,
+ block[Side::kRhs] + 1};
+ Side next_runahead_side = Side::kLhs;
+#endif
+ while (true) {
+ bool both_sides_packed = true;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ both_sides_packed &=
+ TryPack(side, block[side], start[side], end[side], tuning);
+ }
+ if (both_sides_packed) {
+ break;
+ }
+#if RUY_OPT(PACK_AHEAD)
+ RUY_TRACE_INFO(ENSURE_PACKED_ENTER_RUN_AHEAD);
+ const Side runahead_side = next_runahead_side;
+ const int runahead_block = next_runahead_block[runahead_side];
+ next_runahead_side = OtherSide(next_runahead_side);
+ if (runahead_block >= NumBlocksPerSide(runahead_side, block_map_)) {
+ continue;
+ }
+ int runahead_block_start, runahead_block_end;
+ GetBlockMatrixCoords(runahead_side, block_map_, runahead_block,
+ &runahead_block_start, &runahead_block_end);
+ TryPack(runahead_side, runahead_block, runahead_block_start,
+ runahead_block_end, tuning);
+ next_runahead_block[runahead_side] = runahead_block + 1;
+#endif
+ }
+ RUY_TRACE_INFO(ENSURE_PACKED_END);
+ }
+
+ TrMulParams* params_;
+ const BlockMap& block_map_;
+ std::atomic<int>* atomic_block_id_;
+ int thread_id_;
+ bool need_atomics_;
+ SidePair<std::atomic<PackingStatus>*> packing_status_;
+ TuningResolver* tuning_resolver_;
+ Allocator* local_allocator_;
+
+ // Local indicators of packedness to avoid the overhead of atomic ops.
+ SidePair<bool*> local_already_packed_;
+
+ CpuInfo* cpuinfo_;
+};
+
+int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
+#if RUY_PLATFORM_EMSCRIPTEN
+ // b/139927184, std::thread constructor raises exception
+ return 1;
+#endif
+ RUY_TRACE_SCOPE;
+ // Empirically determined rule for reasonable number of
+ // threads to use. This is proportional to the number of arithmetic ops
+ // in this Mul (product of the 3 sizes).
+ static constexpr int kDivisorLog2 = 15;
+ const int guess_log2 = std::max(
+ 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2);
+ int tentative_thread_count =
+ std::min(1 << guess_log2, ctx->max_num_threads());
+ RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT);
+ return tentative_thread_count;
+}
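+// Worked example of the heuristic above (illustration only): for a
+// hypothetical 512x512x512 Mul, ceil_log2 of each dimension is 9, so
+// guess_log2 = max(0, 9 + 9 + 9 - 15) = 12 and the tentative thread count is
+// min(1 << 12, ctx->max_num_threads()), which in practice is just
+// ctx->max_num_threads().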
+
+bool GetUseSimpleLoop(int tentative_thread_count, int rows, int cols, int depth,
+ int lhs_scalar_size, int rhs_scalar_size,
+ const CpuCacheParams& cpu_cache_params) {
+ RUY_TRACE_SCOPE;
+ if (tentative_thread_count == 1) {
+ if (IsObviouslyLinearTraversal(rows, cols, depth, lhs_scalar_size,
+ rhs_scalar_size, cpu_cache_params)) {
+ RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_TRUE);
+ return true;
+ }
+ }
+ RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_FALSE);
+ return false;
+}
+
+} // namespace
+
+// TrMul is the ruy middle-end. It contains the high-level logic to perform
+// a ruy::Mul's work, down to calls to back-end Kernel and Pack functions.
+// This includes determining how many threads to use, computing the BlockMap,
+// and executing tasks on a thread pool. The TrMul function itself runs on the
+// main thread; the code that potentially runs on worker threads is in
+// TrMulTask::Run().
+void TrMul(Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ profiler::ScopeLabel label(
+ "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))",
+ static_cast<int>(params->path), ctx->max_num_threads(),
+ params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]);
+
+ PEMat& packed_lhs = params->packed_matrix[Side::kLhs];
+ PEMat& packed_rhs = params->packed_matrix[Side::kRhs];
+ EMat& lhs = params->src[Side::kLhs];
+ EMat& rhs = params->src[Side::kRhs];
+
+ const int rows = lhs.layout.cols;
+ const int cols = rhs.layout.cols;
+ const int depth = lhs.layout.rows;
+
+ const int tentative_thread_count =
+ GetTentativeThreadCount(ctx, rows, cols, depth);
+ const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams();
+
+ // Case of running this TrMul as a simple loop.
+ // This is a good place to start reading this function: all the rest
+ // of this function is just an optimized, but functionally equivalent,
+ // version of that.
+ if (GetUseSimpleLoop(tentative_thread_count, rows, cols, depth,
+ lhs.data_type.size, rhs.data_type.size,
+ cpu_cache_params)) {
+ profiler::ScopeLabel label_simple("TrMulImpl, simple loop");
+ Tuning tuning = ctx->GetMainThreadTuning();
+ RUY_TRACE_INFO(TRMUL_SIMPLE_LOOP);
+
+ const SidePair<int> origin{0, 0};
+ const SidePair<int> rounded_dims{packed_lhs.layout.cols,
+ packed_rhs.layout.cols};
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ params->RunPack(side, tuning, origin[side], rounded_dims[side]);
+ }
+ }
+ params->RunKernel(tuning, origin, rounded_dims);
+ return;
+ }
+
+ profiler::ScopeLabel label_general("TrMulImpl, general case");
+ RUY_TRACE_INFO(TRMUL_GENERAL_CASE);
+ Allocator* main_allocator = ctx->GetMainAllocator();
+
+ // Initialize block map.
+ BlockMap block_map;
+ MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth,
+ packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols,
+ packed_lhs.data_type.size, packed_rhs.data_type.size,
+ tentative_thread_count, cpu_cache_params, &block_map);
+
+ // Initialize per-thread state.
+ const int thread_count = block_map.thread_count;
+ const bool need_atomics = thread_count > 1;
+ ctx->EnsureThreadSpecificResources(thread_count);
+ for (int i = 0; i < thread_count; i++) {
+ ctx->GetThreadSpecificTuningResolver(i)->SetTuning(ctx->explicit_tuning());
+ }
+
+ // In the need_atomics case, allocate and initialize atomic values tracking
+ // the packing status of blocks.
+ SidePair<std::atomic<PackingStatus>*> packing_status{nullptr, nullptr};
+ if (need_atomics) {
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ const int size = NumBlocksPerSide(side, block_map);
+ main_allocator->Allocate(size, &packing_status[side]);
+ for (int i = 0; i < size; i++) {
+ packing_status[side][i].store(PackingStatus::kNotStarted,
+ std::memory_order_relaxed);
+ }
+ }
+ }
+ }
+
+ // Create the atomic block id. Allocate it using the Allocator so that it
+ // gets the alignment ensuring that it sits alone in its exclusives
+ // reservation granule.
+ std::atomic<int>* atomic_block_id;
+ main_allocator->Allocate(1, &atomic_block_id);
+
+ // Create task objects.
+ TrMulTask* tasks;
+ main_allocator->Allocate(thread_count, &tasks);
+
+ atomic_block_id->store(thread_count);
+
+ for (int i = 0; i < thread_count; i++) {
+ auto* allocator = ctx->GetThreadSpecificAllocator(i);
+ auto* tuning_resolver = ctx->GetThreadSpecificTuningResolver(i);
+ new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i,
+ need_atomics, packing_status, tuning_resolver,
+ allocator, ctx->mutable_cpuinfo());
+ }
+
+ // Do the computation.
+ ctx->mutable_thread_pool()->Execute(thread_count, tasks);
+
+ // Finish up.
+ for (int i = 0; i < thread_count; i++) {
+ tasks[i].~TrMulTask();
+ }
+}
+
+} // namespace ruy
diff --git a/ruy/trmul.h b/ruy/trmul.h
new file mode 100644
index 0000000..b4af287
--- /dev/null
+++ b/ruy/trmul.h
@@ -0,0 +1,39 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// As a matrix multiplication library, Ruy offers a Mul entry point, performing
+// matrix multiplication. For implementation purposes, it is much nicer to deal
+// with the transpose-and-multiply operation,
+// Destination = Transpose(LHS) * RHS
+// Indeed, the latter performs dot-products between the *columns* of LHS and
+// the columns of RHS, whereas a plain matrix multiplication performs
+// dot-products between the *rows* of LHS and the columns of RHS.
+// That is why TrMul is nicer to implement, allowing for a more symmetric
+// treatment of LHS and RHS.
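+// In index form: Destination(r, c) = sum over k of LHS(k, r) * RHS(k, c),
+// i.e. the dot-product runs down a column of LHS and a column of RHS, with k
+// (the depth dimension) indexing the rows of both.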
+
+#ifndef RUY_RUY_TRMUL_H_
+#define RUY_RUY_TRMUL_H_
+
+#include "ruy/ctx.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+struct ContextInternal;
+void TrMul(Ctx* ctx, TrMulParams* params);
+
+} // namespace ruy
+
+#endif // RUY_RUY_TRMUL_H_
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
new file mode 100644
index 0000000..e68d909
--- /dev/null
+++ b/ruy/trmul_params.h
@@ -0,0 +1,92 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_TRMUL_PARAMS_H_
+#define RUY_RUY_TRMUL_PARAMS_H_
+
+#include <cstdint>
+
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/path.h"
+#include "ruy/side_pair.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+using RunKernelFn = void(Tuning, const SidePair<PEMat>&, const void*,
+ const SidePair<int>&, const SidePair<int>&, EMat*);
+
+using RunPackFn = void(Tuning, const EMat&, PEMat*, int, int);
+
+// This should not be needed since we require C++14, where std::max is already
+// constexpr, but TensorFlow continuous integration uses Ubuntu 16 with a
+// libstdc++ that does not support that.
+constexpr int constexpr_max(int a, int b) { return a > b ? a : b; }
+
+// Under-estimating these values would be caught by a static_assert in
+// StoreMulParams. Over-estimating these values cannot easily be caught, and
+// would cause unnecessary inflation of the TrMulParams data structure.
+constexpr int kMaxMulParamsAlignment =
+ constexpr_max(alignof(void*), alignof(double));
+constexpr int kMaxMulParamsSizeFloatingPointCase =
+ sizeof(MulParams<double, double>);
+constexpr int kMaxMulParamsSizeRawIntegerCase =
+ sizeof(MulParams<std::int32_t, std::int32_t>);
+constexpr int kMaxMulParamsSizeQuantizedIntegerCase =
+ sizeof(MulParams<std::int32_t, std::int16_t>);
+constexpr int kMaxMulParamsSize =
+ constexpr_max(kMaxMulParamsSizeFloatingPointCase,
+ constexpr_max(kMaxMulParamsSizeRawIntegerCase,
+ kMaxMulParamsSizeQuantizedIntegerCase));
+
+// OK to adjust as needed, but we want to avoid inflating this value unnecessarily.
+static_assert(kMaxMulParamsSize <= 32, "");
+
+// Type-erased data needed for implementing TrMul.
+struct TrMulParams {
+ TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {}
+ // Helper functions for invoking the function pointers.
+ void RunPack(Side side, Tuning tuning, int start, int end) {
+ run_pack[side](tuning, src[side], &packed_matrix[side], start, end);
+ }
+ void RunKernel(Tuning tuning, const SidePair<int>& start,
+ const SidePair<int>& end) {
+ run_kernel(tuning, packed_matrix, mul_params_bytes, start, end, &dst);
+ }
+
+ // Path id; can be useful for some fine-tuning, e.g. to guess reasonable
+ // cache sizes when they are not runtime-detectable.
+ Path path;
+
+ // Function pointers to type-erased entry points for kernels and packers.
+ SidePair<RunPackFn*> run_pack;
+ RunKernelFn* run_kernel = nullptr;
+
+ // Matrices and packed matrices.
+ SidePair<EMat> src;
+ EMat dst;
+ SidePair<PEMat> packed_matrix;
+ SidePair<bool> is_prepacked;
+
+ // Bytes underlying the MulParams, used as type-erased storage. The MulParams
+ // data isn't needed until we reach the kernel code, where it is cast back to
+ // the original MulParams type.
+ alignas(kMaxMulParamsAlignment) char mul_params_bytes[kMaxMulParamsSize];
+};
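+// Illustrative sketch only, not actual library code: kernel-side code
+// recovers the typed MulParams from the type-erased storage roughly like
+// const auto& mul_params =
+// *reinterpret_cast<const MulParams<AccumScalar, DstScalar>*>(
+// params.mul_params_bytes);
+// where AccumScalar and DstScalar stand for whatever types the selected
+// kernel was instantiated with.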
+
+} // namespace ruy
+
+#endif // RUY_RUY_TRMUL_PARAMS_H_
diff --git a/ruy/tune.cc b/ruy/tune.cc
new file mode 100644
index 0000000..1f615bf
--- /dev/null
+++ b/ruy/tune.cc
@@ -0,0 +1,46 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/tune.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/cpuinfo.h"
+
+namespace ruy {
+
+Tuning TuningResolver::ResolveNow(CpuInfo* cpuinfo) {
+ return cpuinfo->CurrentCpuIsA55ish() ? Tuning::kA55ish : Tuning::kGeneric;
+}
+
+TuningResolver::TuningResolver()
+ : expiry_duration_(DurationFromMilliseconds(250)) {}
+
+Tuning TuningResolver::Resolve(CpuInfo* cpuinfo) {
+ if (unresolved_tuning_ != Tuning::kAuto) {
+ return unresolved_tuning_;
+ }
+ TimePoint new_timepoint = CoarseNow();
+ if (last_resolved_tuning_ != Tuning::kAuto &&
+ (new_timepoint - last_resolved_timepoint_) < expiry_duration_) {
+ return last_resolved_tuning_;
+ }
+ last_resolved_timepoint_ = new_timepoint;
+ last_resolved_tuning_ = ResolveNow(cpuinfo);
+ return last_resolved_tuning_;
+}
+
+} // namespace ruy
diff --git a/ruy/tune.h b/ruy/tune.h
new file mode 100644
index 0000000..c9beed9
--- /dev/null
+++ b/ruy/tune.h
@@ -0,0 +1,117 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Library doing minimal CPU detection to decide what to tune asm code for.
+//
+// # Tuning vs Path
+//
+// Tunings are merely local variations of optimized code paths that are
+// drop-in replacements for each other: the input and output data layouts
+// are identical. By contrast, what ruy calls a Path dictates its own
+// data layouts. For example, Path::kNeonDotprod will use different
+// layouts compared to Path::kNeon; but within each Path, different tunings
+// share the same layout.
+//
+// # Tuning is for now only based on 1 bit: Generic / A55ish
+//
+// In practice, each of our asm code paths needs only one bit of information to
+// decide on tuning: whether the CPU is out-of-order or in-order.
+// That is because out-of-order CPUs are by definition relatively insensitive
+// to small-scale asm details (which is what "tuning" is about); and for each
+// asm code path, there tends to be one main in-order CPU architecture that
+// we focus our tuning effort on. Examples:
+// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod)
+// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod)
+//
+// Because having tuned code paths is a trade-off between efficiency gains and
+// implementation effort and code size, we are happy to stop at just this
+// single bit of information, Generic / A55ish, at least in the current CPU
+// landscape. This could change in the future.
+#ifndef RUY_RUY_TUNE_H_
+#define RUY_RUY_TUNE_H_
+
+#include "ruy/cpuinfo.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/time.h"
+
+namespace ruy {
+
+enum class Tuning {
+ // kAuto means please use auto-detection. It's the default in the
+ // user-visible parts (see Context). It's meant to be resolved to an
+ // actual tuning at some point by means of TuningResolver.
+ kAuto,
+ // Use code not tuned for any particular CPU, typically performing well
+ // on out-of-order cores that don't require as much tuning.
+ kGeneric,
+ // Use code tuned for "Cortex-A55-ish" CPUs, by which we mean mostly:
+ // A53, A55r0 (pre-dotprod), A55r1 (with dotprod). These CPUs have in common
+ // that they are in-order CPU cores with largely similar code-tuning
+ // requirements. The most important such requirement is to use only 64-bit loads
+ // to maximize dual-issuing.
+ //
+ // A55r1 differs from A55r0 and A53 in that it dual-issues 64-bit NEON loads
+ // whereas A55r0 and A53 require using non-NEON ARM 64-bit loads together with
+ // INS instructions to insert 64-bit lanes into NEON registers. However, since
+ // A55r1 supports dotprod unlike A55r0 and A53, they do not use the same
+ // kernels in practice anyway, so there was no need to distinguish them with
+ // separate Tuning values.
+ kA55ish
+};
+
+// Why a TuningResolver class?
+//
+// Ideally, this library would offer a single function,
+// Tuning GetCurrentCPUTuning();
+//
+// However, determining information about the current CPU is not necessarily
+// cheap, so we currently cache the result and only invalidate/re-evaluate it after
+// a fixed amount of time. This need to store state is why this library
+// has to expose a class, TuningResolver, not just a function.
+class TuningResolver {
+ public:
+ TuningResolver();
+
+ // Allows the user to specify an explicit Tuning value, bypassing auto
+ // detection; or to specify Tuning::kAuto, reverting to auto detection.
+ void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; }
+
+ // Returns an actual tuning; this is the function that this class exists to provide.
+ Tuning Resolve(CpuInfo* cpuinfo);
+
+ private:
+ TuningResolver(const TuningResolver&) = delete;
+
+ // Performs the tuning resolution now, typically by querying CpuInfo about
+ // the current CPU; an implementation may use a different approach instead.
+ Tuning ResolveNow(CpuInfo* cpuinfo);
+
+ // The tuning as specified by the user, before actual resolution happens,
+ // i.e. before querying any specifics of the current CPU.
+ // The default value kAuto means try to auto-detect. Other values mean
+ // bypass auto-detection and use the explicit value instead. See SetTuning().
+ Tuning unresolved_tuning_ = Tuning::kAuto;
+ // Cached last resolved tuning.
+ Tuning last_resolved_tuning_ = Tuning::kAuto;
+ // Timepoint of cached last resolved tuning, for invalidation purposes.
+ TimePoint last_resolved_timepoint_;
+ // A cached last-resolved tuning older than this duration is invalid.
+ const Duration expiry_duration_;
+};
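+// Minimal usage sketch (illustration only, assuming a CpuInfo instance named
+// `cpuinfo` is available):
+// TuningResolver resolver;
+// resolver.SetTuning(Tuning::kAuto); // or an explicit Tuning to bypass detection
+// Tuning tuning = resolver.Resolve(&cpuinfo); // resolves to kGeneric or kA55ish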
+
+} // namespace ruy
+
+#endif // RUY_RUY_TUNE_H_
diff --git a/ruy/tune_test.cc b/ruy/tune_test.cc
new file mode 100644
index 0000000..c5f2342
--- /dev/null
+++ b/ruy/tune_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/tune.h"
+
+#include <chrono> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "ruy/cpuinfo.h"
+#include "ruy/gtest_wrapper.h"
+
+namespace ruy {
+namespace {
+
+TEST(TuneTest, TuneTest) {
+ TuningResolver tuning_resolver;
+ CpuInfo cpuinfo;
+ ASSERT_FALSE(tuning_resolver.Resolve(&cpuinfo) == Tuning::kAuto);
+ // 1 second is likely longer than TuningResolver's internal cache expiry,
+ // exercising the logic that invalidates earlier tuning resolutions.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ ASSERT_FALSE(tuning_resolver.Resolve(&cpuinfo) == Tuning::kAuto);
+
+ tuning_resolver.SetTuning(Tuning::kAuto);
+
+#ifdef RUY_IMPLEMENT_TUNING
+ for (auto tuning : {Tuning::kGeneric, Tuning::kA55ish}) {
+ tuning_resolver.SetTuning(tuning);
+ ASSERT_TRUE(tuning_resolver.Resolve(&cpuinfo) == tuning);
+ // See above comment about 1 second.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ ASSERT_TRUE(tuning_resolver.Resolve(&cpuinfo) == tuning);
+ }
+#endif
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/validate.h b/ruy/validate.h
new file mode 100644
index 0000000..b164530
--- /dev/null
+++ b/ruy/validate.h
@@ -0,0 +1,75 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Front-end validation code, see the Validate function.
+
+#ifndef RUY_RUY_VALIDATE_H_
+#define RUY_RUY_VALIDATE_H_
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+namespace detail {
+
+template <typename Scalar>
+void CheckZeroPoint(Scalar zero_point) {
+ if (std::is_floating_point<Scalar>::value) {
+ RUY_DCHECK(!zero_point);
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar>
+void ValidateZeroPoints(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
+ DstScalar dst_zero_point) {
+ CheckZeroPoint(lhs_zero_point);
+ CheckZeroPoint(rhs_zero_point);
+ CheckZeroPoint(dst_zero_point);
+
+ // Guard against the case when both LHS and RHS zero_points are equal to
+ // the minimum representable value. In that case, padding with zero_point
+ // values will generate the bad case for fast int8 kernels on NEON
+ // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
+ // into an int16: this is safe except in the bad case
+ // -128*-128 + -128*-128 = 32768, which overflows the int16 maximum of 32767.
+ // See b/131609283. This only affects the kNeon path, but we ban this for all
+ // paths so that ruy has the same supported parameter space on all paths.
+ // We disable this check for now for the case of LhsScalar==RhsScalar==uint8,
+ // for backwards compatibility with gemmlowp. The issue is still relevant
+ // because we convert from uint8 to int8 for the backend kernels.
+ if (!std::is_same<LhsScalar, uint8_t>::value ||
+ !std::is_same<RhsScalar, uint8_t>::value) {
+ RUY_DCHECK(lhs_zero_point != std::numeric_limits<LhsScalar>::lowest() ||
+ rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
+ }
+}
+
+} // namespace detail
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar>
+void Validate(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const Mat<DstScalar>& dst) {
+ detail::ValidateZeroPoints(lhs.zero_point, rhs.zero_point, dst.zero_point);
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_VALIDATE_H_
diff --git a/ruy/wait.cc b/ruy/wait.cc
new file mode 100644
index 0000000..fc33832
--- /dev/null
+++ b/ruy/wait.cc
@@ -0,0 +1,44 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <chrono> // NOLINT(build/c++11)
+
+namespace ruy {
+
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
+ std::condition_variable* condvar, std::mutex* mutex) {
+ // First, the trivial case where `condition` is already true.
+ if (condition()) {
+ return;
+ }
+
+ // Then, if spin_duration is nonzero, try busy-waiting.
+ if (spin_duration.count() > 0) {
+ const TimePoint wait_start = Now();
+ while (Now() - wait_start < spin_duration) {
+ if (condition()) {
+ return;
+ }
+ }
+ }
+
+ // Finally, do real passive waiting.
+ std::unique_lock<std::mutex> lock(*mutex);
+ condvar->wait(lock, condition);
+}
+
+} // namespace ruy
diff --git a/ruy/wait.h b/ruy/wait.h
new file mode 100644
index 0000000..24f9912
--- /dev/null
+++ b/ruy/wait.h
@@ -0,0 +1,68 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_WAIT_H_
+#define RUY_RUY_WAIT_H_
+
+#include <condition_variable> // NOLINT(build/c++11)
+#include <functional>
+#include <mutex> // NOLINT(build/c++11)
+
+#include "ruy/time.h"
+
+namespace ruy {
+
+// Waits until some evaluation of `condition` has returned true.
+//
+// There is no guarantee that calling `condition` again after this function
+// has returned would still return true. The only contract is that at some
+// point during the execution of this function, `condition` has returned true.
+//
+// First does some spin-waiting for the specified `spin_duration`,
+// then falls back to passive waiting for the given condvar, guarded
+// by the given mutex. At that point it acquires the mutex lock around the
+// wait on the condition variable.
+// Therefore, this function expects that the calling thread has not already
+// locked the mutex before calling it.
+// This function will always release the mutex lock before returning.
+//
+// The idea of doing some initial spin-waiting is to help get
+// better and more consistent multithreading benefits for small GEMM sizes.
+// Spin-waiting helps ensure that if we need to wake up soon after having
+// started waiting, then we can wake up quickly (as opposed to, say,
+// having to wait to be scheduled again by the OS). On the other hand,
+// we must still eventually revert to passive waiting for longer waits
+// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
+// so as to avoid permanently spinning.
+//
+// In situations where other threads might have more useful things to do with
+// these CPU cores than our spin-waiting, it may be best to reduce the value
+// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
+//
+// There is a risk that the std::function used here might use a heap allocation
+// to store its context. The expected usage pattern is that these functions'
+// contexts will consist of a single pointer value (typically capturing only
+// [this]), and that in this case the std::function implementation will use
+// inline storage, avoiding a heap allocation. However, we can't effectively
+// guard that assumption, and that's not a big concern anyway because the
+// latency of a small heap allocation is probably low compared to the intrinsic
+// latency of what this Wait function does.
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
+ std::condition_variable* condvar, std::mutex* mutex);
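+// Minimal usage sketch (illustration only): wait for a worker thread to set
+// `done`, spinning for up to 1 millisecond before blocking on the condvar.
+// The worker is expected to notify `condvar` under `mutex` after setting
+// `done`.
+// std::mutex mutex;
+// std::condition_variable condvar;
+// std::atomic<bool> done(false);
+// ... start a worker thread that eventually sets `done` and notifies ...
+// Wait([&done] { return done.load(); }, DurationFromSeconds(1e-3), &condvar,
+// &mutex);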
+
+} // namespace ruy
+
+#endif // RUY_RUY_WAIT_H_
diff --git a/ruy/wait_test.cc b/ruy/wait_test.cc
new file mode 100644
index 0000000..fa02b0d
--- /dev/null
+++ b/ruy/wait_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11)
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "ruy/gtest_wrapper.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+// Thread functor that takes a `value` atomic counter and increments it until
+// it equals `end_value`, then notifies the condition variable as long as
+// `value == end_value`. If `end_value` is increased, it resumes incrementing
+// `value`, etc. Terminates when `end_value == -1`.
+class ThreadCountingUpToValue {
+ public:
+ ThreadCountingUpToValue(const std::atomic<int>& end_value,
+ std::atomic<int>* value,
+ std::condition_variable* condvar, std::mutex* mutex)
+ : end_value_(end_value),
+ value_(value),
+ condvar_(condvar),
+ mutex_(mutex) {}
+ void operator()() {
+ // end_value_==-1 is how the master thread will tell us it's OK to terminate
+ while (end_value_.load() != -1) {
+ // wait until end_value is set to a higher value
+ while (value_->load() == end_value_.load()) {
+ }
+ // increment value as long as it's lower than end_value
+ while (value_->fetch_add(1) < end_value_.load() - 1) {
+ }
+ // when value has reached end_value, notify the master thread.
+ while (value_->load() == end_value_.load()) {
+ std::lock_guard<std::mutex> lock(*mutex_);
+ condvar_->notify_all();
+ }
+ }
+ }
+
+ private:
+ const std::atomic<int>& end_value_;
+ std::atomic<int>* value_;
+ std::condition_variable* condvar_;
+ std::mutex* mutex_;
+};
+
+void WaitTest(const Duration& spin_duration, const Duration& delay) {
+#if RUY_PLATFORM_EMSCRIPTEN
+ // b/139927184, std::thread constructor raises exception
+ return;
+#endif
+ std::condition_variable condvar;
+ std::mutex mutex;
+ std::atomic<int> value(0);
+ std::atomic<int> end_value(0);
+ ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
+ std::thread thread(thread_callable);
+ std::this_thread::sleep_for(delay);
+ for (int i = 1; i < 10; i++) {
+ end_value.store(1000 * i);
+ const auto& condition = [&value, &end_value]() {
+ return value.load() == end_value.load();
+ };
+ ruy::Wait(condition, spin_duration, &condvar, &mutex);
+ EXPECT_EQ(value.load(), end_value.load());
+ }
+ end_value.store(-1);
+ thread.join();
+}
+
+TEST(WaitTest, WaitTestNoSpin) {
+ WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMicrosecond) {
+ WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMillisecond) {
+ WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneSecond) {
+ WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
+}
+
+// Testcase to consistently reproduce the hang in b/139062384.
+TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) {
+ WaitTest(DurationFromSeconds(0), DurationFromSeconds(1));
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/third_party/BUILD b/third_party/BUILD
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/third_party/BUILD
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
new file mode 100644
index 0000000..4c82d83
--- /dev/null
+++ b/third_party/CMakeLists.txt
@@ -0,0 +1,5 @@
+# This file is generated (whence no license header). Do not edit!
+# To regenerate, run:
+# cmake/bazel_to_cmake.sh
+
+ruy_add_all_subdirs()
diff --git a/third_party/cpuinfo b/third_party/cpuinfo
new file mode 160000
+Subproject 5916273f79a21551890fd3d56fc5375a78d1598
diff --git a/third_party/cpuinfo.BUILD b/third_party/cpuinfo.BUILD
new file mode 100644
index 0000000..6d68cd5
--- /dev/null
+++ b/third_party/cpuinfo.BUILD
@@ -0,0 +1,386 @@
+# cpuinfo, a library to detect information about the host CPU
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+C99OPTS = [
+ "-std=gnu99", # gnu99, not c99, because dprintf is used
+ "-Wno-vla",
+ "-D_GNU_SOURCE=1", # to use CPU_SETSIZE
+ "-DCPUINFO_INTERNAL=",
+ "-DCPUINFO_PRIVATE=",
+]
+
+# Source code common to all platforms.
+COMMON_SRCS = [
+ "src/api.c",
+ "src/init.c",
+ "src/cache.c",
+]
+
+# Architecture-specific sources and headers.
+X86_SRCS = [
+ "src/x86/cache/descriptor.c",
+ "src/x86/cache/deterministic.c",
+ "src/x86/cache/init.c",
+ "src/x86/info.c",
+ "src/x86/init.c",
+ "src/x86/isa.c",
+ "src/x86/name.c",
+ "src/x86/topology.c",
+ "src/x86/uarch.c",
+ "src/x86/vendor.c",
+]
+
+ARM_SRCS = [
+ "src/arm/cache.c",
+ "src/arm/uarch.c",
+]
+
+# Platform-specific sources and headers
+LINUX_SRCS = [
+ "src/linux/cpulist.c",
+ "src/linux/multiline.c",
+ "src/linux/processors.c",
+ "src/linux/smallfile.c",
+]
+
+MOCK_LINUX_SRCS = [
+ "src/linux/mockfile.c",
+]
+
+MACH_SRCS = [
+ "src/mach/topology.c",
+]
+
+EMSCRIPTEN_SRCS = [
+ "src/emscripten/init.c",
+]
+
+LINUX_X86_SRCS = [
+ "src/x86/linux/cpuinfo.c",
+ "src/x86/linux/init.c",
+]
+
+LINUX_ARM_SRCS = [
+ "src/arm/linux/chipset.c",
+ "src/arm/linux/clusters.c",
+ "src/arm/linux/cpuinfo.c",
+ "src/arm/linux/hwcap.c",
+ "src/arm/linux/init.c",
+ "src/arm/linux/midr.c",
+]
+
+LINUX_ARM32_SRCS = LINUX_ARM_SRCS + ["src/arm/linux/aarch32-isa.c"]
+
+LINUX_ARM64_SRCS = LINUX_ARM_SRCS + ["src/arm/linux/aarch64-isa.c"]
+
+ANDROID_ARM_SRCS = [
+ "src/arm/android/properties.c",
+]
+
+WINDOWS_X86_SRCS = [
+ "src/x86/windows/init.c",
+]
+
+MACH_X86_SRCS = [
+ "src/x86/mach/init.c",
+]
+
+MACH_ARM_SRCS = [
+ "src/arm/mach/init.c",
+]
+
+cc_library(
+ name = "clog",
+ srcs = [
+ "deps/clog/src/clog.c",
+ ],
+ hdrs = [
+ "deps/clog/include/clog.h",
+ ],
+ copts = select({
+ ":windows_x86_64": [],
+ "//conditions:default": ["-Wno-unused-result"],
+ }),
+ linkopts = select({
+ ":android": ["-llog"],
+ "//conditions:default": [],
+ }),
+ linkstatic = select({
+ # https://github.com/bazelbuild/bazel/issues/11552
+ ":macos_x86_64": False,
+ "//conditions:default": True,
+ }),
+ defines = select({
+ # When linkstatic=False, we need default visibility
+ ":macos_x86_64": ["CLOG_VISIBILITY="],
+ "//conditions:default": [],
+ }),
+ strip_include_prefix = "deps/clog/include",
+)
+
+cc_library(
+ name = "cpuinfo_impl",
+ srcs = select({
+ ":linux_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS,
+ ":linux_arm": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS,
+ ":linux_armhf": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS,
+ ":linux_armv7a": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS,
+ ":linux_armeabi": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS,
+ ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS,
+ ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":windows_x86_64": COMMON_SRCS + X86_SRCS + WINDOWS_X86_SRCS,
+ ":android_armv7": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS + ANDROID_ARM_SRCS,
+ ":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS,
+ ":android_x86": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS,
+ ":android_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS,
+ ":ios_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":ios_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":ios_armv7": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":ios_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":ios_arm64e": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":watchos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":emscripten": COMMON_SRCS + EMSCRIPTEN_SRCS,
+ }),
+ copts = select({
+ ":windows_x86_64": [],
+ "//conditions:default": C99OPTS,
+ }) + [
+ "-Iexternal/cpuinfo/include",
+ "-Iexternal/cpuinfo/src",
+ ],
+ linkstatic = select({
+ # https://github.com/bazelbuild/bazel/issues/11552
+ ":macos_x86_64": False,
+ "//conditions:default": True,
+ }),
+ # Headers must be in textual_hdrs to allow us to set the standard to C99
+ textual_hdrs = [
+ "include/cpuinfo.h",
+ "src/linux/api.h",
+ "src/mach/api.h",
+ "src/cpuinfo/common.h",
+ "src/cpuinfo/internal-api.h",
+ "src/cpuinfo/log.h",
+ "src/cpuinfo/utils.h",
+ "src/x86/api.h",
+ "src/x86/cpuid.h",
+ "src/x86/linux/api.h",
+ "src/arm/android/api.h",
+ "src/arm/linux/api.h",
+ "src/arm/linux/cp.h",
+ "src/arm/api.h",
+ "src/arm/midr.h",
+ ],
+ deps = [
+ ":clog",
+ ],
+)
+
+cc_library(
+ name = "cpuinfo",
+ hdrs = [
+ "include/cpuinfo.h",
+ ],
+ strip_include_prefix = "include",
+ deps = [
+ ":cpuinfo_impl",
+ ],
+)
+
+cc_library(
+ name = "cpuinfo_with_unstripped_include_path",
+ hdrs = [
+ "include/cpuinfo.h",
+ ],
+ deps = [
+ ":cpuinfo_impl",
+ ],
+)
+
+############################# Build configurations #############################
+
+config_setting(
+ name = "linux_x86_64",
+ values = {"cpu": "k8"},
+)
+
+config_setting(
+ name = "linux_arm",
+ values = {"cpu": "arm"},
+)
+
+config_setting(
+ name = "linux_armhf",
+ values = {"cpu": "armhf"},
+)
+
+config_setting(
+ name = "linux_armv7a",
+ values = {"cpu": "armv7a"},
+)
+
+config_setting(
+ name = "linux_armeabi",
+ values = {"cpu": "armeabi"},
+)
+
+config_setting(
+ name = "linux_aarch64",
+ values = {"cpu": "aarch64"},
+)
+
+config_setting(
+ name = "macos_x86_64",
+ values = {
+ "apple_platform_type": "macos",
+ "cpu": "darwin",
+ },
+)
+
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+)
+
+config_setting(
+ name = "windows_x86_64",
+ values = {"cpu": "x64_windows"},
+)
+
+config_setting(
+ name = "android_armv7",
+ values = {
+ "crosstool_top": "//external:android/crosstool",
+ "cpu": "armeabi-v7a",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "android_arm64",
+ values = {
+ "crosstool_top": "//external:android/crosstool",
+ "cpu": "arm64-v8a",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "android_x86",
+ values = {
+ "crosstool_top": "//external:android/crosstool",
+ "cpu": "x86",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "android_x86_64",
+ values = {
+ "crosstool_top": "//external:android/crosstool",
+ "cpu": "x86_64",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "ios_armv7",
+ values = {
+ "apple_platform_type": "ios",
+ "cpu": "ios_armv7",
+ },
+)
+
+config_setting(
+ name = "ios_arm64",
+ values = {
+ "apple_platform_type": "ios",
+ "cpu": "ios_arm64",
+ },
+)
+
+config_setting(
+ name = "ios_arm64e",
+ values = {
+ "apple_platform_type": "ios",
+ "cpu": "ios_arm64e",
+ },
+)
+
+config_setting(
+ name = "ios_x86",
+ values = {
+ "apple_platform_type": "ios",
+ "cpu": "ios_i386",
+ },
+)
+
+config_setting(
+ name = "ios_x86_64",
+ values = {
+ "apple_platform_type": "ios",
+ "cpu": "ios_x86_64",
+ },
+)
+
+config_setting(
+ name = "watchos_armv7k",
+ values = {
+ "apple_platform_type": "watchos",
+ "cpu": "watchos_armv7k",
+ },
+)
+
+config_setting(
+ name = "watchos_arm64_32",
+ values = {
+ "apple_platform_type": "watchos",
+ "cpu": "watchos_arm64_32",
+ },
+)
+
+config_setting(
+ name = "watchos_x86",
+ values = {
+ "apple_platform_type": "watchos",
+ "cpu": "watchos_i386",
+ },
+)
+
+config_setting(
+ name = "watchos_x86_64",
+ values = {
+ "apple_platform_type": "watchos",
+ "cpu": "watchos_x86_64",
+ },
+)
+
+config_setting(
+ name = "tvos_arm64",
+ values = {
+ "apple_platform_type": "tvos",
+ "cpu": "tvos_arm64",
+ },
+)
+
+config_setting(
+ name = "tvos_x86_64",
+ values = {
+ "apple_platform_type": "tvos",
+ "cpu": "tvos_x86_64",
+ },
+)
+
+config_setting(
+ name = "emscripten",
+ values = {"crosstool_top": "//toolchain:emscripten"},
+)
diff --git a/third_party/googletest b/third_party/googletest
new file mode 160000
+Subproject 6c58c11d5497b6ee1df3cb400ce30deb72fc28c