author    | Lev Proleev <levp@google.com> | 2021-03-12 18:55:26 +0000
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-03-12 18:55:26 +0000
commit    | 2c005aca73d9a32151a040aa476fed0ec89a14ae (patch)
tree      | 45b033783360ed59b9efbfd8bd2dfdb4fc8bdb72
parent    | d23d5384ee2dad29223e8c57248ea83ec23da4bf (diff)
parent    | 998a6df5933c918fe486c3cbd3cb1e699e0211b5 (diff)
download  | ruy-2c005aca73d9a32151a040aa476fed0ec89a14ae.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' into tflite-rebase-feb-2021 am: 713d254ecf am: dd1f1778c2 am: 039d972c6e am: 998a6df593
Original change: https://android-review.googlesource.com/c/platform/external/ruy/+/1610773
MUST ONLY BE SUBMITTED BY AUTOMERGER
Change-Id: Ie195bc9d9ec4d8755daeb0c6b0c52f52ada4df51
140 files changed, 41802 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..950a145
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,30 @@
+# Visual Studio files
+.vs/
+.vscode/
+*.sdf
+*.opensdf
+*.VC.opendb
+*.suo
+*.user
+
+# macOS files
+.DS_Store
+
+# CMake artifacts
+build/
+build-*/
+
+# Bazel artifacts
+**/bazel-*
+
+# Emacs autosaves
+*~
+\#*\#
+
+# Vim swap files
+[._]*.sw[a-p]
+
+# Source indexing files
+compile_commands.json
+.cache/clangd
+.clangd/
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..0e03f4e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "googletest"]
+	path = third_party/googletest
+	url = https://github.com/google/googletest
+[submodule "cpuinfo"]
+	path = third_party/cpuinfo
+	url = https://github.com/pytorch/cpuinfo
diff --git a/BUILD b/BUILD
new file mode 100644
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,7 @@
+# Ruy is not BLAS
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..98d480d
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,90 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_policy(SET CMP0012 NEW)
+cmake_policy(SET CMP0048 NEW)
+project(ruy CXX)
+cmake_minimum_required(VERSION 3.13)  # Copied from IREE
+set(CMAKE_CXX_STANDARD 14)
+
+
+if (PROJECT_NAME STREQUAL CMAKE_PROJECT_NAME)
+  set(RUY_IS_TOPLEVEL TRUE)
+  set(RUY_MINIMAL_BUILD_DEFAULT_VALUE OFF)
+else()
+  set(RUY_IS_TOPLEVEL FALSE)
+  set(RUY_MINIMAL_BUILD_DEFAULT_VALUE ON)
+endif()
+
+option(RUY_MINIMAL_BUILD "Disable ruy's tests, examples, etc. Build only ruy public libraries." ${RUY_MINIMAL_BUILD_DEFAULT_VALUE})
+if (NOT RUY_MINIMAL_BUILD)
+  enable_testing()
+endif()
+
+option(RUY_PROFILER "Enable ruy's built-in profiler (harms performance)" OFF)
+
+include(cmake/ruy_add_all_subdirs.cmake)
+include(cmake/ruy_cc_library.cmake)
+include(cmake/ruy_cc_binary.cmake)
+include(cmake/ruy_cc_test.cmake)
+
+# Skip cpuinfo if it was already generated, which can happen when ruy is
+# a subdirectory in a wider project that already uses cpuinfo.
+if (NOT TARGET cpuinfo)
+  # Test if the third_party/cpuinfo submodule was checked out before
+  # adding that subdirectory, so we can do more helpful things below in the
+  # else() block when it's not.
+  set(RUY_CPUINFO_CMAKELISTS_FILE "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cpuinfo/CMakeLists.txt")
+  if (EXISTS "${RUY_CPUINFO_CMAKELISTS_FILE}")
+    # Disabling cpuinfo's tests and benchmarks to prevent a copy of its
+    # googletest dependency getting downloaded into a 'deps' directory in the
+    # source tree!
+    set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
+    set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE)
+    set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE)
+    add_subdirectory("third_party/cpuinfo" EXCLUDE_FROM_ALL)
+  else()
+    # third_party/cpuinfo is not checked out. That could be intentional when
+    # ruy is a subdirectory in a wider project that is already providing
+    # the cpuinfo target.
Maybe that wider project's CMakeLists is ordered + # in such a way that cpuinfo gets generated after ruy. In that case, + # it's helpful that we continue silently. In the worst case if the cpuinfo + # target never gets defined, ruy will fail to compile. + # On the other hand, if ruy is the top-level project here (not part of a + # wider project) then nothing will define the cpuinfo target for us, + # so we will definitely fail to compile, so we may as well fail right here. + if (RUY_IS_TOPLEVEL) + message(FATAL_ERROR "This file does not exist:\n${RUY_CPUINFO_CMAKELISTS_FILE}\n" + "That typically means that the git submodules of the ruy " + "repository haven't been checked out. Try this in the ruy " + "git directory:\n git submodule update --init") + endif() + endif() +endif() + +# googletest is only needed for tests. Projects embedding ruy as a subdirectory +# and not needing to build ruy tests may proceed without a local checkout of +# third_party/googletest. +if (NOT RUY_MINIMAL_BUILD + AND NOT TARGET gtest + AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/googletest/CMakeLists.txt") + add_subdirectory("third_party/googletest" EXCLUDE_FROM_ALL) +endif() + +add_subdirectory("ruy") + +if (NOT RUY_MINIMAL_BUILD) + add_subdirectory("example") +endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..654a071 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to <https://cla.developers.google.com/> to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google/conduct/). @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f341563 --- /dev/null +++ b/README.md @@ -0,0 +1,25 @@ +# The ruy matrix multiplication library + +This is not an officially supported Google product. + +ruy is a matrix multiplication library. Its focus is to cover the matrix +multiplication needs of neural network inference engines. 
Its initial user has
+been TensorFlow Lite, where it is used by default on the ARM CPU architecture.
+
+ruy supports both floating-point and 8bit-integer-quantized matrices.
+
+## Efficiency
+
+ruy is designed to achieve high performance not just on very large sizes, as
+is the focus of many established libraries, but on whatever are the actual sizes
+and shapes of matrices most critical in current TensorFlow Lite applications.
+This often means quite small sizes, e.g. 100x100 or even 50x50, and all sorts of
+rectangular shapes. It's not as fast as completely specialized code for each
+shape, but it aims to offer a good compromise of speed across all shapes and a
+small binary size.
+
+## Documentation
+
+Some documentation will eventually be available in the doc/ directory, see
+[doc/README.md](doc/README.md).
+
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..da4fe9f
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,44 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Workspace file for the Ruy project.
+
+workspace(name = "com_google_ruy")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
+
+maybe(
+    local_repository,
+    name = "com_google_googletest",
+    path = "third_party/googletest",
+)
+
+maybe(
+    new_local_repository,
+    name = "cpuinfo",
+    path = "third_party/cpuinfo",
+    build_file = "@//third_party:cpuinfo.BUILD",
+)
+
+# skylib utility for additional bazel functionality.
+skylib_version = "0.9.0"
+http_archive(
+    name = "bazel_skylib",
+    type = "tar.gz",
+    url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel_skylib-{}.tar.gz".format(skylib_version, skylib_version),
+    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
+)
+load("@bazel_skylib//lib:versions.bzl", "versions")
+versions.check(minimum_bazel_version = "2.0.0")
diff --git a/cmake/bazel_to_cmake.py b/cmake/bazel_to_cmake.py
new file mode 100755
index 0000000..ba1a38b
--- /dev/null
+++ b/cmake/bazel_to_cmake.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This is yet another bazel-to-cmake converter.
It's independently written from +scratch but relies on the same basic idea as others (including IREE's), namely: +to let the python interpreter do the bulk of the work, exploiting the fact that +both Bazel's BUILD syntax and Starlark (".bzl") languages are more or less +subsets of Python. + +The main features that this converter supports and that others don't, justifying +its existence as of early 2021, are: + 1. Ad-hoc support for select(), generating CMake if()...elseif()... chains + parsing the condition keys (e.g. anything ending in ":windows" is + interpreted as the condition "the target platform is Windows"). This allows + to just ignore config_setting, as we only care about the config_setting + names, not their actual implementation, as well as all the variants from + the Bazel 'selects' library. + 2. Support for load(), loading macros from Starlark files. +""" + +import re +import os +import os.path +import pickle +import sys +import datetime +import itertools + +# Ruy's dependencies. +external_targets = ['gtest', 'gtest_main', 'cpuinfo'] + +# Text replacements [oldstring, newstring] pairs, applied on all BUILD and +# Starlark files that we load. Only used by preprocess_input_text. +replacements = [ + ['$(STACK_FRAME_UNLIMITED)', ''], + ['native.cc_', 'cc_'], + ['selects.config_setting_group', 'config_setting_group'], + ['@com_google_googletest//:gtest', 'gtest'], + ['@com_google_googletest//:gtest_main', 'gtest_main'], + ['@cpuinfo//:cpuinfo_with_unstripped_include_path', 'cpuinfo'], +] + + +def preprocess_input_text(text): + result = text + for replacement in replacements: + result = result.replace(replacement[0], replacement[1]) + return result + + +def set_cmake_list(list_name, values, indent): + semicolon_separated = ";".join(values) + print(f'{indent}set({list_name} "{semicolon_separated}")') + + +def generate_cmake_select(select_name, dict): + new_if_branch_keyword = 'if' + default_value = [] + for key in dict: + condition = '' + if key == '//conditions:default': + default_value = dict[key] + continue + elif re.search(r':windows$', key): + condition = 'CMAKE_SYSTEM_NAME STREQUAL Windows' + elif re.search(r':ppc$', key): + condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le' + elif re.search(r':s390x$', key): + condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x' + elif re.search(r':fuchsia$', key): + condition = 'CMAKE_SYSTEM_NAME STREQUAL Fuchsia' + elif re.search(r':arm32_assuming_neon$', key): + condition = 'CMAKE_SYSTEM_PROCESSOR STREQUAL arm' + elif re.search(r':do_not_want_O3$', key): + # Ruy is a specialist library: we always want code to be compiled + # with -O3 unless the build type is Debug or the compiler does not + # support that flag syntax. 
+ condition = '(CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC' + elif re.search(r':x86_64_and_not_msvc$', key): + condition = '(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC' + elif re.search(r':windows_msvc$', key): + condition = 'MSVC' + elif re.search(r':ruy_profiler$', key): + condition = '${RUY_PROFILER}' + else: + raise ValueError(f'Unhandled key in select: {key}') + + print(f'{new_if_branch_keyword}({condition})') + set_cmake_list(select_name, dict[key], ' ') + new_if_branch_keyword = 'elseif' + + print('else()') + set_cmake_list(select_name, default_value, ' ') + + print('endif()\n') + + +def trim_multiple_ruy_prefixes(name): + return re.sub(r'(ruy_)+ruy', 'ruy', name) + +def get_cmake_local_target_name(name): + global package_prefix + return trim_multiple_ruy_prefixes(f'ruy_{package_prefix}_{name}') + + +def get_cmake_dep_target_name(name): + if name in external_targets: + return name + if name.startswith('$'): + # Happens for deps that are the result of expanding a select() that we + # have compiled to expanding a variable. + return name + if name.startswith('//'): + after_last_slash = name.split('/')[-1] + if not ':' in after_last_slash: + name = f'{name}:{after_last_slash}' + raw=name[2:].replace('/', '_').replace(':', '_') + return trim_multiple_ruy_prefixes(raw) + if name.startswith(':'): + name = name[1:] + return get_cmake_local_target_name(name) + + +# +# Functions implementing BUILD functions +# + + +def package(**kwargs): + pass + + +def exports_files(*args): + pass + + +def load(filename, *args): + if filename.startswith('@'): + return + elif filename.startswith(':'): + filename = os.path.join(bazel_package_dir, filename[1:]) + elif filename.startswith('//'): + split = filename[2:].split(':') + filename = os.path.join(bazel_workspace_dir, split[0], split[1]) + + src_file_content = open(filename).read() + processed_file_content = preprocess_input_text(src_file_content) + exec(processed_file_content, globals(), globals()) + + +def config_setting(**kwargs): + # Nothing to do since our implementation of select() is based on parsing + # the names of config_settings, not looking deep into their actual + # implementation. + pass + + +def filegroup(**kwargs): + pass + + +def config_setting_group(**kwargs): + # See config_setting. 
+ pass + + +def bzl_library(**kwargs): + pass + + +select_index = 0 +select_cache = {} + + +def select(select_dict): + global select_index + global select_cache + global package_prefix + key = pickle.dumps(sorted(select_dict.items())) + if key in select_cache: + select_name = select_cache[key] + else: + unique_values = sorted(set(itertools.chain.from_iterable(select_dict.values()))) # sorting ensures determinism, no spurious diffs + description = '_'.join(unique_values) + select_name = f'{package_prefix}_{select_index}_{description}' + select_name = select_name.replace('c++', 'cxx') + select_name = re.sub(r'[^a-zA-Z0-9]+', '_', select_name) + select_index = select_index + 1 + select_cache[key] = select_name + generate_cmake_select(select_name, select_dict) + + return [f'${{{select_name}}}'] + + +def generic_rule(rule_name, **kwargs): + print(f'{rule_name}(') + for key in kwargs.keys(): + values = kwargs[key] + if type(values) is bool: + if values: + print(f' {key.upper()}') + continue + else: + raise ValueError( + 'Cannot specify FALSE boolean args in CMake') + if key == 'visibility': + if values == ['//visibility:public']: + print(f' PUBLIC') + continue + if key == 'tags': + values = list(filter(lambda x : not x.startswith('req_dep'), values)) + if not values: + continue + print(f' {key.upper()}') + if type(values) is list: + for value in values: + if key == 'deps': + target_name = get_cmake_dep_target_name(value) + print(f' {target_name}') + else: + print(f' {value}') + else: + if key == 'name': + target_name = get_cmake_local_target_name(values) + print(f' {target_name}') + else: + print(f' {values}') + print(')\n') + + +def cc_library(**kwargs): + generic_rule('ruy_cc_library', **kwargs) + + +def cc_test(**kwargs): + generic_rule('ruy_cc_test', **kwargs) + + +def cc_binary(**kwargs): + generic_rule('ruy_cc_binary', **kwargs) + + +# +# Program entry point. +# +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: bazel_to_cmake.py bazel_workspace_dir bazel_package_dir") + sys.exit(1) + + bazel_workspace_dir = sys.argv[1] + bazel_package_dir = sys.argv[2] + bazel_package_relative_dir = os.path.relpath( + bazel_package_dir, bazel_workspace_dir) + package_prefix = bazel_package_relative_dir.replace(os.path.sep, '_') + + print("""# This file is generated (whence no license header). Do not edit! +# To regenerate, run: +# cmake/bazel_to_cmake.sh +""") + + src_build_file = os.path.join(bazel_package_dir, "BUILD") + src_build_content = open(src_build_file).read() + processed_build_content = preprocess_input_text(src_build_content) + exec(processed_build_content) + + print("ruy_add_all_subdirs()") diff --git a/cmake/bazel_to_cmake.sh b/cmake/bazel_to_cmake.sh new file mode 100755 index 0000000..296219e --- /dev/null +++ b/cmake/bazel_to_cmake.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +this_script_dir="$(dirname "$0")" + +root_dir="$(git -C "${this_script_dir}" rev-parse --show-toplevel)" + +build_files="$(find "${root_dir}" -type f -name BUILD)" + +if ! command -v python3 &> /dev/null; then + python_command=python +else + python_command=python3 +fi + +for build_file in ${build_files}; do + package_dir="$(dirname "${build_file}")" + if [[ "${package_dir}" == "${root_dir}" ]]; then + # The root CMakeLists.txt is not generated. + continue + fi + "${python_command}" "${this_script_dir}/bazel_to_cmake.py" "${root_dir}" "${package_dir}" > "${package_dir}/CMakeLists.txt" +done diff --git a/cmake/run_android_test.sh b/cmake/run_android_test.sh new file mode 100755 index 0000000..d643232 --- /dev/null +++ b/cmake/run_android_test.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Minimal script pushing and running a file on device! +# Contemporary versions of ADB properly propagate exit codes so nothing more +# is needed to let CTest report test success/failure. + +# TODO: consider clearing temporary files after testing, although that will +# get in the way of debugging and will make code more complex... also, +# Ruy's test files aren't huge and people running these probably have +# bigger clutter issues in their /data/local/tmp anyway. Anyway, if we want +# to do this, we could copy IREE's code. + +device_tmpdir=/data/local/tmp + +adb push "$1" "${device_tmpdir}" +adb shell "${device_tmpdir}/$(basename "$1")" diff --git a/cmake/ruy_add_all_subdirs.cmake b/cmake/ruy_add_all_subdirs.cmake new file mode 100644 index 0000000..1a7d126 --- /dev/null +++ b/cmake/ruy_add_all_subdirs.cmake @@ -0,0 +1,37 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Forked from IREE's iree_add_all_subdirs.cmake. + +# add_all_subidrs +# +# CMake function to add all subdirectories of the current directory that contain +# a CMakeLists.txt file +# +# Takes no arguments. +function(ruy_add_all_subdirs) + FILE(GLOB _CHILDREN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*) + SET(_DIRLIST "") + foreach(_CHILD ${_CHILDREN}) + if((NOT(subdir MATCHES third_party)) AND + (IS_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/${_CHILD}) AND + (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${_CHILD}/CMakeLists.txt)) + LIST(APPEND _DIRLIST ${_CHILD}) + endif() + endforeach() + + foreach(subdir ${_DIRLIST}) + add_subdirectory(${subdir}) + endforeach() +endfunction() diff --git a/cmake/ruy_cc_binary.cmake b/cmake/ruy_cc_binary.cmake new file mode 100644 index 0000000..93d1adf --- /dev/null +++ b/cmake/ruy_cc_binary.cmake @@ -0,0 +1,57 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Forked from IREE's iree_cc_binary.cmake. + +include(CMakeParseArguments) +include(cmake/ruy_include_directories.cmake) + +# ruy_cc_binary() +# +# CMake function to imitate Bazel's cc_binary rule. +function(ruy_cc_binary) + cmake_parse_arguments( + _RULE + "TESTONLY" + "NAME" + "SRCS;COPTS;LINKOPTS;DEPS;TAGS" + ${ARGN} + ) + + if(_RULE_TESTONLY AND RUY_MINIMAL_BUILD) + return() + endif() + + set(_NAME "${_RULE_NAME}") + + add_executable(${_NAME} "") + target_sources(${_NAME} + PRIVATE + ${_RULE_SRCS} + ) + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") + ruy_include_directories(${_NAME} "${_RULE_DEPS}") + target_compile_options(${_NAME} + PRIVATE + ${_RULE_COPTS} + ) + target_link_options(${_NAME} + PRIVATE + ${_RULE_LINKOPTS} + ) + target_link_libraries(${_NAME} + PUBLIC + ${_RULE_DEPS} + ) +endfunction() diff --git a/cmake/ruy_cc_library.cmake b/cmake/ruy_cc_library.cmake new file mode 100644 index 0000000..38accc5 --- /dev/null +++ b/cmake/ruy_cc_library.cmake @@ -0,0 +1,85 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Forked from IREE's iree_cc_library.cmake. + +include(CMakeParseArguments) +include(cmake/ruy_include_directories.cmake) + +# ruy_cc_library() +# +# CMake function to imitate Bazel's cc_library rule. +function(ruy_cc_library) + cmake_parse_arguments( + _RULE + "PUBLIC;TESTONLY" + "NAME" + "HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DEPS" + ${ARGN} + ) + + if(_RULE_TESTONLY AND RUY_MINIMAL_BUILD) + return() + endif() + + set(_NAME "${_RULE_NAME}") + + # Check if this is a header-only library. + if("${_RULE_SRCS}" STREQUAL "") + set(_RULE_IS_INTERFACE 1) + else() + set(_RULE_IS_INTERFACE 0) + endif() + + if(_RULE_IS_INTERFACE) + # Generating a header-only library. + add_library(${_NAME} INTERFACE) + target_include_directories(${_NAME} + INTERFACE + "${PROJECT_SOURCE_DIR}" + ) + target_link_libraries(${_NAME} + INTERFACE + ${_RULE_DEPS} + ${_RULE_LINKOPTS} + ) + target_compile_definitions(${_NAME} + INTERFACE + ${_RULE_DEFINES} + ) + else() + # Generating a static binary library. 
+ add_library(${_NAME} STATIC "") + target_sources(${_NAME} + PRIVATE + ${_RULE_SRCS} + ${_RULE_HDRS} + ) + ruy_include_directories(${_NAME} "${_RULE_DEPS}") + target_compile_options(${_NAME} + PRIVATE + ${_RULE_COPTS} + ) + target_link_libraries(${_NAME} + PUBLIC + ${_RULE_DEPS} + PRIVATE + ${_RULE_LINKOPTS} + ) + target_compile_definitions(${_NAME} + PUBLIC + ${_RULE_DEFINES} + ) + endif() +endfunction() diff --git a/cmake/ruy_cc_test.cmake b/cmake/ruy_cc_test.cmake new file mode 100644 index 0000000..6ad247e --- /dev/null +++ b/cmake/ruy_cc_test.cmake @@ -0,0 +1,76 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Forked from IREE's iree_cc_test.cmake. + +include(CMakeParseArguments) +include(cmake/ruy_include_directories.cmake) + +# ruy_cc_test() +# +# CMake function to imitate Bazel's cc_test rule. +function(ruy_cc_test) + cmake_parse_arguments( + _RULE + "" + "NAME" + "SRCS;COPTS;LINKOPTS;DEPS;TAGS" + ${ARGN} + ) + + if(RUY_MINIMAL_BUILD) + return() + endif() + + set(_NAME "${_RULE_NAME}") + + add_executable(${_NAME} "") + target_sources(${_NAME} + PRIVATE + ${_RULE_SRCS} + ) + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") + ruy_include_directories(${_NAME} "${_RULE_DEPS}") + target_compile_options(${_NAME} + PRIVATE + ${_RULE_COPTS} + ) + target_link_options(${_NAME} + PRIVATE + ${_RULE_LINKOPTS} + ) + target_link_libraries(${_NAME} + PUBLIC + ${_RULE_DEPS} + ) + if(ANDROID) + add_test( + NAME + ${_NAME} + COMMAND + "${CMAKE_SOURCE_DIR}/cmake/run_android_test.sh" + "$<TARGET_FILE:${_NAME}>" + ) + else() + add_test( + NAME + ${_NAME} + COMMAND + "$<TARGET_FILE:${_NAME}>" + ) + endif() + if (_RULE_TAGS) + set_property(TEST ${_NAME} PROPERTY LABELS ${_RULE_TAGS}) + endif() +endfunction() diff --git a/cmake/ruy_include_directories.cmake b/cmake/ruy_include_directories.cmake new file mode 100644 index 0000000..e9b50a9 --- /dev/null +++ b/cmake/ruy_include_directories.cmake @@ -0,0 +1,33 @@ +# Copyright 2019-2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +function(ruy_include_directories NAME DEPS) + target_include_directories(${NAME} + PUBLIC + "${PROJECT_SOURCE_DIR}" + ) + if (cpuinfo IN_LIST DEPS) + target_include_directories(${NAME} + PRIVATE + "${PROJECT_SOURCE_DIR}/third_party/cpuinfo/include" + ) + endif() + if ((gtest IN_LIST DEPS) OR + (gtest_main IN_LIST DEPS)) + target_include_directories(${NAME} + PRIVATE + "${PROJECT_SOURCE_DIR}/third_party/googletest/googletest" + ) + endif() +endfunction()
\ No newline at end of file diff --git a/doc/README.md b/doc/README.md new file mode 100644 index 0000000..2a03104 --- /dev/null +++ b/doc/README.md @@ -0,0 +1,3 @@ +# Ruy documentation + +This directory will eventually contain ruy documentation. diff --git a/doc/depgraph.sh b/doc/depgraph.sh new file mode 100755 index 0000000..d66d44f --- /dev/null +++ b/doc/depgraph.sh @@ -0,0 +1,153 @@ +#!/bin/bash + +# Generates a graphviz dependency graph for :ruy, with details trimmed. +# Suggested rendering: pipe to `neato` (part of graphviz standard distribution) +# doc/depgraph.sh | dot -Tsvg > depgraph.svg + +drop=( + ':platform' + ':check_macros' + ':asm_helpers' + ':size_util' + ':system_aligned_alloc' + ':side_pair' + ':opt_set' + ':blocking_counter' + ':wait' + ':time' + ':path' + ':performance_advisory' + ':tune' + ':matrix' + ':mat' + ':mul_params' + ':context_get_ctx' + ':have_built_path_for' + ':pack_common' + ':kernel_common' + ':trace' + ':validate' + 'profiler:instrumentation' + '\bclog\b' + '\bcpuinfo_impl\b' + ':apply_multiplier' + '\blabel=' +) + +graph="$(bazel query 'kind("cc_library", deps(//ruy))' --output graph --noimplicit_deps 2>/dev/null)" + +graph="$(echo "${graph}" | sed 's|//ruy/\?||g')" + +for t in "${drop[@]}"; do + graph="$(echo "${graph}" | grep -v "${t}")" +done + +graph="$(echo "${graph}" | sed 's|//:cpuinfo_with_unstripped_include_path||g')" +graph="$(echo "${graph}" | sed 's|//third_party/cpuinfo:[a-z0-9_]*|@cpuinfo|g')" + +frontend=( + ':ruy' + ':context' + ':frontend' + ':prepare_packed_matrices' + ':create_trmul_params' +) + +middleend=( + ':ctx' + ':trmul_params' + ':trmul' + ':block_map' + ':cpuinfo' + ':cpu_cache_params' + ':allocator' + ':prepacked_cache' +) + +backend=( + ':kernel.*' + ':pack.*' +) + +threadpool=( + ':thread_pool' +) + +frontend_lines=() +middleend_lines=() +backend_lines=() +threadpool_lines=() +misc_lines=() +arrow_lines=() + +while IFS= read -r line; do + if [[ "${line}" =~ '->' ]]; then + arrow_lines+=("${line}") + else + handled=false + if [ $handled = false ]; then + for f in "${frontend[@]}"; do + if [[ "${line}" =~ ${f} ]]; then + frontend_lines+=("${line}") + handled=true + break + fi + done + fi + if [ $handled = false ]; then + for f in "${middleend[@]}"; do + if [[ "${line}" =~ ${f} ]]; then + middleend_lines+=("${line}") + handled=true + break + fi + done + fi + if [ $handled = false ]; then + for f in "${backend[@]}"; do + if [[ "${line}" =~ ${f} ]]; then + backend_lines+=("${line}") + handled=true + break + fi + done + fi + if [ $handled = false ]; then + for f in "${threadpool[@]}"; do + if [[ "${line}" =~ ${f} ]]; then + threadpool_lines+=("${line}") + handled=true + break + fi + done + fi + if [ $handled = false ]; then + if [[ "${line}" =~ ^[[:space:]]+\" ]]; then + misc_lines+=("${line}") + fi + fi + fi +done <<< "${graph}" + +echo "digraph ruy {" +echo " splines = true" +echo " node [shape=box]" +for f in "${frontend_lines[@]}"; do + echo " $f [style=filled, color=\"#B2EBF2\"];" +done +for m in "${middleend_lines[@]}"; do + echo " $m [style=filled, color=\"#C8E6C9\"];" +done +for b in "${backend_lines[@]}"; do + echo " $b [style=filled, color=\"#FFCDD2\"];" +done +for b in "${threadpool_lines[@]}"; do + echo " $b [style=filled, color=\"#FFF9C4\"];" +done +for m in "${misc_lines[@]}"; do + echo "$m" +done +for a in "${arrow_lines[@]}"; do + echo "$a" +done +echo "}" diff --git a/doc/depgraph.svg b/doc/depgraph.svg new file mode 100644 index 0000000..79fb27c --- /dev/null +++ b/doc/depgraph.svg @@ -0,0 +1,377 @@ +<?xml 
version="1.0" encoding="UTF-8" standalone="no"?> +<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" + "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> +<!-- Generated by graphviz version 2.43.0 (0) + --> +<!-- Title: ruy Pages: 1 --> +<svg width="1007pt" height="421pt" + viewBox="0.00 0.00 1007.00 421.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"> +<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 417)"> +<title>ruy</title> +<polygon fill="white" stroke="transparent" points="-4,4 -4,-417 1003,-417 1003,4 -4,4"/> +<!-- :ruy --> +<g id="node1" class="node"> +<title>:ruy</title> +<polygon fill="#b2ebf2" stroke="#b2ebf2" points="233.5,-413 179.5,-413 179.5,-377 233.5,-377 233.5,-413"/> +<text text-anchor="middle" x="206.5" y="-391.3" font-family="Times,serif" font-size="14.00">:ruy</text> +</g> +<!-- :frontend --> +<g id="node2" class="node"> +<title>:frontend</title> +<polygon fill="#b2ebf2" stroke="#b2ebf2" points="375.5,-341 309.5,-341 309.5,-305 375.5,-305 375.5,-341"/> +<text text-anchor="middle" x="342.5" y="-319.3" font-family="Times,serif" font-size="14.00">:frontend</text> +</g> +<!-- :ruy->:frontend --> +<g id="edge2" class="edge"> +<title>:ruy->:frontend</title> +<path fill="none" stroke="black" d="M233.69,-380C252.69,-370.23 278.42,-356.98 300.09,-345.83"/> +<polygon fill="black" stroke="black" points="301.81,-348.88 309.09,-341.19 298.6,-342.66 301.81,-348.88"/> +</g> +<!-- :context --> +<g id="node5" class="node"> +<title>:context</title> +<polygon fill="#b2ebf2" stroke="#b2ebf2" points="100.5,-269 40.5,-269 40.5,-233 100.5,-233 100.5,-269"/> +<text text-anchor="middle" x="70.5" y="-247.3" font-family="Times,serif" font-size="14.00">:context</text> +</g> +<!-- :ruy->:context --> +<g id="edge1" class="edge"> +<title>:ruy->:context</title> +<path fill="none" stroke="black" d="M190.1,-376.87C166.2,-351.92 121.69,-305.45 94.23,-276.77"/> +<polygon fill="black" stroke="black" points="96.41,-273.99 86.96,-269.19 91.35,-278.83 96.41,-273.99"/> +</g> +<!-- :prepare_packed_matrices --> +<g id="node3" class="node"> +<title>:prepare_packed_matrices</title> +<polygon fill="#b2ebf2" stroke="#b2ebf2" points="315,-269 156,-269 156,-233 315,-233 315,-269"/> +<text text-anchor="middle" x="235.5" y="-247.3" font-family="Times,serif" font-size="14.00">:prepare_packed_matrices</text> +</g> +<!-- :frontend->:prepare_packed_matrices --> +<g id="edge6" class="edge"> +<title>:frontend->:prepare_packed_matrices</title> +<path fill="none" stroke="black" d="M316.32,-304.88C302.46,-295.81 285.26,-284.55 270.29,-274.76"/> +<polygon fill="black" stroke="black" points="272.06,-271.74 261.78,-269.19 268.23,-277.59 272.06,-271.74"/> +</g> +<!-- :create_trmul_params --> +<g id="node4" class="node"> +<title>:create_trmul_params</title> +<polygon fill="#b2ebf2" stroke="#b2ebf2" points="700.5,-269 564.5,-269 564.5,-233 700.5,-233 700.5,-269"/> +<text text-anchor="middle" x="632.5" y="-247.3" font-family="Times,serif" font-size="14.00">:create_trmul_params</text> +</g> +<!-- :frontend->:create_trmul_params --> +<g id="edge4" class="edge"> +<title>:frontend->:create_trmul_params</title> +<path fill="none" stroke="black" d="M375.77,-313.97C419.01,-303.53 495.91,-284.97 554.3,-270.88"/> +<polygon fill="black" stroke="black" points="555.34,-274.22 564.24,-268.48 553.7,-267.42 555.34,-274.22"/> +</g> +<!-- :trmul --> +<g id="node6" class="node"> +<title>:trmul</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="424.5,-269 370.5,-269 370.5,-233 424.5,-233 
424.5,-269"/> +<text text-anchor="middle" x="397.5" y="-247.3" font-family="Times,serif" font-size="14.00">:trmul</text> +</g> +<!-- :frontend->:trmul --> +<g id="edge7" class="edge"> +<title>:frontend->:trmul</title> +<path fill="none" stroke="black" d="M356.1,-304.7C362.62,-296.39 370.57,-286.28 377.75,-277.14"/> +<polygon fill="black" stroke="black" points="380.63,-279.13 384.06,-269.1 375.13,-274.81 380.63,-279.13"/> +</g> +<!-- :trmul_params --> +<g id="node8" class="node"> +<title>:trmul_params</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="643,-197 546,-197 546,-161 643,-161 643,-197"/> +<text text-anchor="middle" x="594.5" y="-175.3" font-family="Times,serif" font-size="14.00">:trmul_params</text> +</g> +<!-- :frontend->:trmul_params --> +<g id="edge8" class="edge"> +<title>:frontend->:trmul_params</title> +<path fill="none" stroke="black" d="M372.9,-304.87C418.49,-279.18 504.61,-230.65 555.02,-202.25"/> +<polygon fill="black" stroke="black" points="557,-205.15 564,-197.19 553.56,-199.05 557,-205.15"/> +</g> +<!-- :ctx --> +<g id="node9" class="node"> +<title>:ctx</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="200.5,-197 146.5,-197 146.5,-161 200.5,-161 200.5,-197"/> +<text text-anchor="middle" x="173.5" y="-175.3" font-family="Times,serif" font-size="14.00">:ctx</text> +</g> +<!-- :frontend->:ctx --> +<g id="edge5" class="edge"> +<title>:frontend->:ctx</title> +<path fill="none" stroke="black" d="M309.39,-317.89C258.84,-310.86 166.21,-294.75 146.5,-269 132.35,-250.52 142.7,-224.62 154.44,-205.53"/> +<polygon fill="black" stroke="black" points="157.47,-207.31 160.04,-197.03 151.62,-203.46 157.47,-207.31"/> +</g> +<!-- :allocator --> +<g id="node13" class="node"> +<title>:allocator</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="382.5,-116.5 314.5,-116.5 314.5,-80.5 382.5,-80.5 382.5,-116.5"/> +<text text-anchor="middle" x="348.5" y="-94.8" font-family="Times,serif" font-size="14.00">:allocator</text> +</g> +<!-- :frontend->:allocator --> +<g id="edge3" class="edge"> +<title>:frontend->:allocator</title> +<path fill="none" stroke="black" d="M342.96,-304.91C343.99,-266.57 346.51,-173.41 347.76,-126.89"/> +<polygon fill="black" stroke="black" points="351.27,-126.74 348.04,-116.65 344.27,-126.55 351.27,-126.74"/> +</g> +<!-- :prepare_packed_matrices->:trmul_params --> +<g id="edge20" class="edge"> +<title>:prepare_packed_matrices->:trmul_params</title> +<path fill="none" stroke="black" d="M315.21,-238.43C373.78,-229.36 455.42,-215.53 536.04,-196.97"/> +<polygon fill="black" stroke="black" points="536.92,-200.35 545.87,-194.68 535.34,-193.54 536.92,-200.35"/> +</g> +<!-- :prepare_packed_matrices->:ctx --> +<g id="edge18" class="edge"> +<title>:prepare_packed_matrices->:ctx</title> +<path fill="none" stroke="black" d="M220.17,-232.7C212.74,-224.3 203.68,-214.07 195.52,-204.86"/> +<polygon fill="black" stroke="black" points="197.9,-202.27 188.65,-197.1 192.66,-206.91 197.9,-202.27"/> +</g> +<!-- :prepacked_cache --> +<g id="node10" class="node"> +<title>:prepacked_cache</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="113,-116.5 0,-116.5 0,-80.5 113,-80.5 113,-116.5"/> +<text text-anchor="middle" x="56.5" y="-94.8" font-family="Times,serif" font-size="14.00">:prepacked_cache</text> +</g> +<!-- :prepare_packed_matrices->:prepacked_cache --> +<g id="edge19" class="edge"> +<title>:prepare_packed_matrices->:prepacked_cache</title> +<path fill="none" stroke="black" d="M195.68,-232.78C176.98,-223.65 155,-211.35 137.5,-197 112.04,-176.13 88.67,-146.35 
73.64,-125.2"/> +<polygon fill="black" stroke="black" points="76.32,-122.92 67.73,-116.72 70.58,-126.92 76.32,-122.92"/> +</g> +<!-- :prepare_packed_matrices->:allocator --> +<g id="edge17" class="edge"> +<title>:prepare_packed_matrices->:allocator</title> +<path fill="none" stroke="black" d="M248.46,-232.74C268.39,-206.2 306.64,-155.25 329.63,-124.63"/> +<polygon fill="black" stroke="black" points="332.48,-126.66 335.69,-116.56 326.89,-122.46 332.48,-126.66"/> +</g> +<!-- :create_trmul_params->:trmul_params --> +<g id="edge25" class="edge"> +<title>:create_trmul_params->:trmul_params</title> +<path fill="none" stroke="black" d="M623.11,-232.7C618.74,-224.64 613.44,-214.89 608.6,-205.98"/> +<polygon fill="black" stroke="black" points="611.63,-204.22 603.79,-197.1 605.48,-207.56 611.63,-204.22"/> +</g> +<!-- :create_trmul_params->:ctx --> +<g id="edge22" class="edge"> +<title>:create_trmul_params->:ctx</title> +<path fill="none" stroke="black" d="M564.28,-239.6C466.01,-224.61 287.97,-197.46 210.82,-185.69"/> +<polygon fill="black" stroke="black" points="211.04,-182.18 200.62,-184.14 209.98,-189.1 211.04,-182.18"/> +</g> +<!-- :create_trmul_params->:allocator --> +<g id="edge21" class="edge"> +<title>:create_trmul_params->:allocator</title> +<path fill="none" stroke="black" d="M643.65,-232.97C654.87,-213.46 668.3,-181.55 651.5,-161 635.2,-141.06 473.1,-116.46 392.93,-105.39"/> +<polygon fill="black" stroke="black" points="393.03,-101.87 382.65,-103.98 392.08,-108.81 393.03,-101.87"/> +</g> +<!-- :pack --> +<g id="node14" class="node"> +<title>:pack</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="753.5,-197 699.5,-197 699.5,-161 753.5,-161 753.5,-197"/> +<text text-anchor="middle" x="726.5" y="-175.3" font-family="Times,serif" font-size="14.00">:pack</text> +</g> +<!-- :create_trmul_params->:pack --> +<g id="edge24" class="edge"> +<title>:create_trmul_params->:pack</title> +<path fill="none" stroke="black" d="M655.74,-232.7C667.69,-223.8 682.42,-212.82 695.35,-203.2"/> +<polygon fill="black" stroke="black" points="697.6,-205.88 703.53,-197.1 693.42,-200.27 697.6,-205.88"/> +</g> +<!-- :kernel --> +<g id="node17" class="node"> +<title>:kernel</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="866.5,-197 812.5,-197 812.5,-161 866.5,-161 866.5,-197"/> +<text text-anchor="middle" x="839.5" y="-175.3" font-family="Times,serif" font-size="14.00">:kernel</text> +</g> +<!-- :create_trmul_params->:kernel --> +<g id="edge23" class="edge"> +<title>:create_trmul_params->:kernel</title> +<path fill="none" stroke="black" d="M682.87,-232.97C719.86,-220.46 769.25,-203.75 802.61,-192.48"/> +<polygon fill="black" stroke="black" points="804.06,-195.68 812.41,-189.16 801.82,-189.05 804.06,-195.68"/> +</g> +<!-- :context->:ctx --> +<g id="edge31" class="edge"> +<title>:context->:ctx</title> +<path fill="none" stroke="black" d="M95.7,-232.88C108.91,-223.89 125.29,-212.76 139.61,-203.03"/> +<polygon fill="black" stroke="black" points="141.9,-205.71 148.21,-197.19 137.97,-199.92 141.9,-205.71"/> +</g> +<!-- :context->:prepacked_cache --> +<g id="edge32" class="edge"> +<title>:context->:prepacked_cache</title> +<path fill="none" stroke="black" d="M51.9,-232.76C42.97,-223.2 33.21,-210.53 28.5,-197 20.19,-173.13 30,-145.38 40.22,-125.6"/> +<polygon fill="black" stroke="black" points="43.43,-127.03 45.19,-116.59 37.3,-123.65 43.43,-127.03"/> +</g> +<!-- :context->:allocator --> +<g id="edge30" class="edge"> +<title>:context->:allocator</title> +<path fill="none" stroke="black" 
d="M78.54,-232.97C89.1,-212.53 109.63,-178.8 137.5,-161 201.86,-119.9 234.39,-152.84 305.5,-125 308.23,-123.93 310.98,-122.7 313.71,-121.37"/> +<polygon fill="black" stroke="black" points="315.61,-124.33 322.81,-116.55 312.33,-118.14 315.61,-124.33"/> +</g> +<!-- :thread_pool --> +<g id="node20" class="node"> +<title>:thread_pool</title> +<polygon fill="#fff9c4" stroke="#fff9c4" points="216,-116.5 131,-116.5 131,-80.5 216,-80.5 216,-116.5"/> +<text text-anchor="middle" x="173.5" y="-94.8" font-family="Times,serif" font-size="14.00">:thread_pool</text> +</g> +<!-- :context->:thread_pool --> +<g id="edge33" class="edge"> +<title>:context->:thread_pool</title> +<path fill="none" stroke="black" d="M64.6,-232.97C59.14,-214.05 53.64,-183.22 66.5,-161 70.31,-154.41 102.87,-136.18 131.08,-121.27"/> +<polygon fill="black" stroke="black" points="132.82,-124.31 140.04,-116.56 129.56,-118.11 132.82,-124.31"/> +</g> +<!-- :block_map --> +<g id="node7" class="node"> +<title>:block_map</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="528,-197 447,-197 447,-161 528,-161 528,-197"/> +<text text-anchor="middle" x="487.5" y="-175.3" font-family="Times,serif" font-size="14.00">:block_map</text> +</g> +<!-- :trmul->:block_map --> +<g id="edge10" class="edge"> +<title>:trmul->:block_map</title> +<path fill="none" stroke="black" d="M419.75,-232.7C431.08,-223.88 445.03,-213.03 457.32,-203.47"/> +<polygon fill="black" stroke="black" points="459.76,-206.01 465.51,-197.1 455.47,-200.48 459.76,-206.01"/> +</g> +<!-- :trmul->:trmul_params --> +<g id="edge15" class="edge"> +<title>:trmul->:trmul_params</title> +<path fill="none" stroke="black" d="M424.89,-240.27C453.72,-230.02 499.94,-213.6 536.85,-200.48"/> +<polygon fill="black" stroke="black" points="538.25,-203.7 546.51,-197.05 535.91,-197.1 538.25,-203.7"/> +</g> +<!-- :trmul->:ctx --> +<g id="edge13" class="edge"> +<title>:trmul->:ctx</title> +<path fill="none" stroke="black" d="M370.23,-243.12C334.21,-233.79 268.96,-216.11 214.5,-197 213.09,-196.51 211.66,-195.99 210.22,-195.46"/> +<polygon fill="black" stroke="black" points="211.27,-192.11 200.68,-191.77 208.75,-198.64 211.27,-192.11"/> +</g> +<!-- :cpuinfo --> +<g id="node11" class="node"> +<title>:cpuinfo</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="296.5,-116.5 234.5,-116.5 234.5,-80.5 296.5,-80.5 296.5,-116.5"/> +<text text-anchor="middle" x="265.5" y="-94.8" font-family="Times,serif" font-size="14.00">:cpuinfo</text> +</g> +<!-- :trmul->:cpuinfo --> +<g id="edge12" class="edge"> +<title>:trmul->:cpuinfo</title> +<path fill="none" stroke="black" d="M382.36,-232.74C358.98,-206.09 314.02,-154.82 287.19,-124.23"/> +<polygon fill="black" stroke="black" points="289.69,-121.77 280.47,-116.56 284.43,-126.39 289.69,-121.77"/> +</g> +<!-- :cpu_cache_params --> +<g id="node12" class="node"> +<title>:cpu_cache_params</title> +<polygon fill="#c8e6c9" stroke="#c8e6c9" points="480.5,-36 356.5,-36 356.5,0 480.5,0 480.5,-36"/> +<text text-anchor="middle" x="418.5" y="-14.3" font-family="Times,serif" font-size="14.00">:cpu_cache_params</text> +</g> +<!-- :trmul->:cpu_cache_params --> +<g id="edge11" class="edge"> +<title>:trmul->:cpu_cache_params</title> +<path fill="none" stroke="black" d="M399.08,-232.64C402.71,-192.74 411.66,-94.27 416.02,-46.24"/> +<polygon fill="black" stroke="black" points="419.51,-46.52 416.93,-36.25 412.54,-45.89 419.51,-46.52"/> +</g> +<!-- :trmul->:allocator --> +<g id="edge9" class="edge"> +<title>:trmul->:allocator</title> +<path fill="none" stroke="black" 
d="M391.88,-232.74C383.39,-206.65 367.22,-156.99 357.2,-126.21"/> +<polygon fill="black" stroke="black" points="360.48,-124.99 354.06,-116.56 353.82,-127.16 360.48,-124.99"/> +</g> +<!-- :trmul->:thread_pool --> +<g id="edge14" class="edge"> +<title>:trmul->:thread_pool</title> +<path fill="none" stroke="black" d="M371.15,-232.94C355.27,-222.61 334.67,-209.14 316.5,-197 278.68,-171.74 235.6,-142.27 206.7,-122.4"/> +<polygon fill="black" stroke="black" points="208.43,-119.34 198.2,-116.55 204.46,-125.1 208.43,-119.34"/> +</g> +<!-- :block_map->:cpu_cache_params --> +<g id="edge16" class="edge"> +<title>:block_map->:cpu_cache_params</title> +<path fill="none" stroke="black" d="M480.12,-160.98C468.14,-133.4 444.4,-78.69 430.14,-45.82"/> +<polygon fill="black" stroke="black" points="433.23,-44.16 426.04,-36.38 426.81,-46.94 433.23,-44.16"/> +</g> +<!-- :ctx->:prepacked_cache --> +<g id="edge36" class="edge"> +<title>:ctx->:prepacked_cache</title> +<path fill="none" stroke="black" d="M148.11,-160.97C131.17,-149.6 108.7,-134.53 90.26,-122.15"/> +<polygon fill="black" stroke="black" points="92.09,-119.17 81.84,-116.5 88.19,-124.98 92.09,-119.17"/> +</g> +<!-- :ctx->:cpuinfo --> +<g id="edge35" class="edge"> +<title>:ctx->:cpuinfo</title> +<path fill="none" stroke="black" d="M193.46,-160.97C206.48,-149.86 223.65,-135.21 237.96,-123"/> +<polygon fill="black" stroke="black" points="240.24,-125.65 245.58,-116.5 235.7,-120.33 240.24,-125.65"/> +</g> +<!-- :ctx->:allocator --> +<g id="edge34" class="edge"> +<title>:ctx->:allocator</title> +<path fill="none" stroke="black" d="M200.57,-168.52C227.63,-158.72 270.18,-142.43 305.5,-125 307.75,-123.89 310.03,-122.71 312.32,-121.48"/> +<polygon fill="black" stroke="black" points="314.06,-124.52 321.09,-116.6 310.65,-118.41 314.06,-124.52"/> +</g> +<!-- :ctx->:thread_pool --> +<g id="edge37" class="edge"> +<title>:ctx->:thread_pool</title> +<path fill="none" stroke="black" d="M173.5,-160.97C173.5,-150.99 173.5,-138.15 173.5,-126.8"/> +<polygon fill="black" stroke="black" points="177,-126.5 173.5,-116.5 170,-126.5 177,-126.5"/> +</g> +<!-- :cpuinfo->:cpu_cache_params --> +<g id="edge38" class="edge"> +<title>:cpuinfo->:cpu_cache_params</title> +<path fill="none" stroke="black" d="M291.46,-80.27C296.07,-77.41 300.88,-74.54 305.5,-72 326.07,-60.68 349.47,-49.45 369.61,-40.26"/> +<polygon fill="black" stroke="black" points="371.25,-43.36 378.92,-36.06 368.36,-36.98 371.25,-43.36"/> +</g> +<!-- @cpuinfo --> +<g id="node21" class="node"> +<title>@cpuinfo</title> +<polygon fill="none" stroke="black" points="301,-36 230,-36 230,0 301,0 301,-36"/> +<text text-anchor="middle" x="265.5" y="-14.3" font-family="Times,serif" font-size="14.00">@cpuinfo</text> +</g> +<!-- :cpuinfo->@cpuinfo --> +<g id="edge39" class="edge"> +<title>:cpuinfo->@cpuinfo</title> +<path fill="none" stroke="black" d="M265.5,-80.47C265.5,-70.49 265.5,-57.65 265.5,-46.3"/> +<polygon fill="black" stroke="black" points="269,-46 265.5,-36 262,-46 269,-46"/> +</g> +<!-- :pack_avx2_fma\n:pack_avx512\n:pack_avx --> +<g id="node15" class="node"> +<title>:pack_avx2_fma\n:pack_avx512\n:pack_avx</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="671,-125 564,-125 564,-72 671,-72 671,-125"/> +<text text-anchor="middle" x="617.5" y="-109.8" font-family="Times,serif" font-size="14.00">:pack_avx2_fma</text> +<text text-anchor="middle" x="617.5" y="-94.8" font-family="Times,serif" font-size="14.00">:pack_avx512</text> +<text text-anchor="middle" x="617.5" y="-79.8" font-family="Times,serif" 
font-size="14.00">:pack_avx</text> +</g> +<!-- :pack->:pack_avx2_fma\n:pack_avx512\n:pack_avx --> +<g id="edge27" class="edge"> +<title>:pack->:pack_avx2_fma\n:pack_avx512\n:pack_avx</title> +<path fill="none" stroke="black" d="M702.85,-160.97C690.71,-152.22 675.52,-141.28 661.33,-131.07"/> +<polygon fill="black" stroke="black" points="663.28,-128.16 653.12,-125.16 659.19,-133.84 663.28,-128.16"/> +</g> +<!-- :pack_arm --> +<g id="node16" class="node"> +<title>:pack_arm</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="763.5,-116.5 689.5,-116.5 689.5,-80.5 763.5,-80.5 763.5,-116.5"/> +<text text-anchor="middle" x="726.5" y="-94.8" font-family="Times,serif" font-size="14.00">:pack_arm</text> +</g> +<!-- :pack->:pack_arm --> +<g id="edge26" class="edge"> +<title>:pack->:pack_arm</title> +<path fill="none" stroke="black" d="M726.5,-160.97C726.5,-150.99 726.5,-138.15 726.5,-126.8"/> +<polygon fill="black" stroke="black" points="730,-126.5 726.5,-116.5 723,-126.5 730,-126.5"/> +</g> +<!-- :kernel_avx\n:kernel_avx512\n:kernel_avx2_fma --> +<g id="node18" class="node"> +<title>:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="897.5,-125 781.5,-125 781.5,-72 897.5,-72 897.5,-125"/> +<text text-anchor="middle" x="839.5" y="-109.8" font-family="Times,serif" font-size="14.00">:kernel_avx</text> +<text text-anchor="middle" x="839.5" y="-94.8" font-family="Times,serif" font-size="14.00">:kernel_avx512</text> +<text text-anchor="middle" x="839.5" y="-79.8" font-family="Times,serif" font-size="14.00">:kernel_avx2_fma</text> +</g> +<!-- :kernel->:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma --> +<g id="edge29" class="edge"> +<title>:kernel->:kernel_avx\n:kernel_avx512\n:kernel_avx2_fma</title> +<path fill="none" stroke="black" d="M839.5,-160.97C839.5,-153.45 839.5,-144.31 839.5,-135.4"/> +<polygon fill="black" stroke="black" points="843,-135.16 839.5,-125.16 836,-135.16 843,-135.16"/> +</g> +<!-- :kernel_arm --> +<g id="node19" class="node"> +<title>:kernel_arm</title> +<polygon fill="#ffcdd2" stroke="#ffcdd2" points="999,-116.5 916,-116.5 916,-80.5 999,-80.5 999,-116.5"/> +<text text-anchor="middle" x="957.5" y="-94.8" font-family="Times,serif" font-size="14.00">:kernel_arm</text> +</g> +<!-- :kernel->:kernel_arm --> +<g id="edge28" class="edge"> +<title>:kernel->:kernel_arm</title> +<path fill="none" stroke="black" d="M865.1,-160.97C882.19,-149.6 904.85,-134.53 923.45,-122.15"/> +<polygon fill="black" stroke="black" points="925.56,-124.95 931.95,-116.5 921.68,-119.12 925.56,-124.95"/> +</g> +</g> +</svg> diff --git a/example/BUILD b/example/BUILD new file mode 100644 index 0000000..738c33e --- /dev/null +++ b/example/BUILD @@ -0,0 +1,16 @@ +package( + licenses = ["notice"], # Apache 2.0 +) + +# Usage examples. +cc_binary( + name = "example", + srcs = ["example.cc"], + deps = ["//ruy"], +) + +cc_binary( + name = "parametrized_example", + srcs = ["parametrized_example.cc"], + deps = ["//ruy"], +) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt new file mode 100644 index 0000000..63fbd66 --- /dev/null +++ b/example/CMakeLists.txt @@ -0,0 +1,23 @@ +# This file is generated (whence no license header). Do not edit! 
+# To regenerate, run: +# cmake/bazel_to_cmake.sh + +ruy_cc_binary( + NAME + ruy_example_example + SRCS + example.cc + DEPS + ruy +) + +ruy_cc_binary( + NAME + ruy_example_parametrized_example + SRCS + parametrized_example.cc + DEPS + ruy +) + +ruy_add_all_subdirs() diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000..e29ee12 --- /dev/null +++ b/example/README.md @@ -0,0 +1,14 @@ +## Introduction + +These are some examples about how to use RUY. + +## BUILD + +Build the example with bazel commands: +``` +bazel build //ruy/example:example +``` +You can find the generated target under directory: +``` +./bazel-bin/ruy/example +``` diff --git a/example/example.cc b/example/example.cc new file mode 100644 index 0000000..3bb95f4 --- /dev/null +++ b/example/example.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <iostream> + +#include "ruy/ruy.h" + +void ExampleMulFloat(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + float dst_data[4]; + + ruy::Matrix<float> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<float> rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<float> dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::MulParams<float, float> mul_params; + ruy::Mul(lhs, rhs, mul_params, context, &dst); + + std::cout << "Example Mul, float:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + const float bias_data[] = {1, 0}; + float dst_data[4]; + + ruy::Matrix<float> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<float> rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<float> dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::MulParams<float, float> mul_params; + mul_params.set_bias(bias_data); + mul_params.set_clamp_min(0); + mul_params.set_clamp_max(15); + ruy::Mul(lhs, rhs, mul_params, context, &dst); + + std::cout << "Example Mul, float with bias addition and clamp:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { + const std::uint8_t lhs_data[] = {124, 125, 126, 127}; + const std::uint8_t rhs_data[] = {129, 130, 131, 132}; + std::uint8_t dst_data[4]; + + ruy::Matrix<std::uint8_t> lhs; + 
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + lhs.set_zero_point(125); + ruy::Matrix<std::uint8_t> rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + rhs.set_zero_point(132); + ruy::Matrix<std::uint8_t> dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + dst.set_zero_point(129); + + ruy::MulParams<std::int32_t, std::uint8_t> mul_params; + mul_params.set_multiplier_fixedpoint(1 << 30); + + mul_params.set_multiplier_exponent(0); + ruy::Mul(lhs, rhs, mul_params, context, &dst); + + std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} +void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { + const std::int8_t lhs_data[] = {1, 2, 3, 4}; + const std::int8_t rhs_data[] = {1, 2, 3, 4}; + const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; + const int exponent_data[] = {1, -2}; + std::int8_t dst_data[4]; + + ruy::Matrix<std::int8_t> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<std::int8_t> rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<std::int8_t> dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::MulParams<std::int32_t, std::int8_t> mul_params; + mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data); + mul_params.set_multiplier_exponent_perchannel(exponent_data); + ruy::Mul(lhs, rhs, mul_params, context, &dst); + + std::cout << "Example Mul, int8 quantized with per-channel multipliers\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} +void ExampleMulInt8GetRawAccumulators(ruy::Context *context) { + const std::int8_t lhs_data[] = {1, 2, 3, 4}; + const std::int8_t rhs_data[] = {1, 2, 3, 4}; + std::int32_t dst_data[4]; + + ruy::Matrix<std::int8_t> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<std::int8_t> rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<std::int32_t> dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + // When Dst is int32, mul_params is unused. 
+ ruy::MulParams<std::int32_t, std::int32_t> mul_params; + ruy::Mul(lhs, rhs, mul_params, context, &dst); + + std::cout << "Example Mul, returning raw int32 accumulators:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +int main() { + ruy::Context context; + ExampleMulFloat(&context); + ExampleMulFloatWithBiasAddAndClamp(&context); + ExampleMulUint8AsymmetricQuantized(&context); + ExampleMulInt8PerChannelQuantized(&context); + ExampleMulInt8GetRawAccumulators(&context); +} diff --git a/example/parametrized_example.cc b/example/parametrized_example.cc new file mode 100644 index 0000000..ef6ad23 --- /dev/null +++ b/example/parametrized_example.cc @@ -0,0 +1,198 @@ +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <type_traits> + +#include "ruy/context.h" +#include "ruy/matrix.h" +#include "ruy/mul_params.h" +#include "ruy/ruy.h" + +template <typename... Dst> +void read_cmdline_args(bool help, int argc, char* argv[], const char* name, + const char* format, const char* default_value, + const char* allowed_values, Dst... dst) { + if (help) { + fprintf(stderr, "%-20s %-12s %-16s %s\n", name, format, default_value, + allowed_values ? allowed_values : ""); + return; + } + const char* value = default_value; + for (int i = 1; i < argc; i++) { + if (std::strstr(argv[i], name) == argv[i]) { + const char* equal_sign = std::strchr(argv[i], '='); + if (equal_sign == argv[i] + std::strlen(name)) { + value = equal_sign + 1; + } + break; + } + } + if (allowed_values) { + if (!std::strstr(allowed_values, value)) { + fprintf(stderr, "Illegal value %s. The legal values are %s.\n", value, + allowed_values); + exit(1); + } + } + if (sizeof...(Dst) != sscanf(value, format, dst...)) { + fprintf(stderr, "Failed to parse %s\n", value); + exit(1); + } +} + +struct Params { + char types[100]; + int m, k, n; // matmul shape m*k*n + int paths; + int num_threads; + int repeat; + int lhs_cache_policy; + int rhs_cache_policy; + int lhs_stride; + int rhs_stride; + int dst_stride; + int lhs_zero_point; + int rhs_zero_point; + int dst_zero_point; + char lhs_order[100]; + char rhs_order[100]; + char dst_order[100]; +}; + +template <typename LhsType, typename RhsType, typename DstType> +void run(const Params& params) { + using AccumType = + typename std::conditional<std::is_floating_point<DstType>::value, DstType, + std::int32_t>::type; + + ruy::Matrix<LhsType> lhs; + ruy::Matrix<RhsType> rhs; + ruy::Matrix<DstType> dst; + + auto parse_order = [](const char* name) { + if (!std::strcmp(name, "row-major")) { + return ruy::Order::kRowMajor; + } else if (!std::strcmp(name, "column-major")) { + return ruy::Order::kColMajor; + } else { + fprintf(stderr, "Failed to parse %s\n", name); + exit(1); + } + }; + + auto make_layout = [](int rows, int cols, int stride, ruy::Order order, + ruy::Layout* layout) { + layout->set_rows(rows); + layout->set_cols(cols); + layout->set_order(order); + int base_stride = order == ruy::Order::kRowMajor ? cols : rows; + layout->set_stride(stride ? 
stride : base_stride); + }; + + make_layout(params.m, params.k, params.lhs_stride, + parse_order(params.lhs_order), lhs.mutable_layout()); + make_layout(params.k, params.n, params.rhs_stride, + parse_order(params.rhs_order), rhs.mutable_layout()); + make_layout(params.m, params.n, params.dst_stride, + parse_order(params.dst_order), dst.mutable_layout()); + + lhs.set_zero_point(params.lhs_zero_point); + rhs.set_zero_point(params.rhs_zero_point); + dst.set_zero_point(params.dst_zero_point); + + lhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.lhs_cache_policy)); + rhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.rhs_cache_policy)); + + auto flat_size = [](const ruy::Layout& layout) { + int outer_size = + layout.order() == ruy::Order::kRowMajor ? layout.rows() : layout.cols(); + return outer_size * layout.stride(); + }; + + std::vector<LhsType> lhs_buf(flat_size(lhs.layout())); + std::vector<RhsType> rhs_buf(flat_size(rhs.layout())); + std::vector<DstType> dst_buf(flat_size(dst.layout())); + + lhs.set_data(lhs_buf.data()); + rhs.set_data(rhs_buf.data()); + dst.set_data(dst_buf.data()); + + ruy::Context context; + context.set_max_num_threads(params.num_threads); + context.set_runtime_enabled_paths(static_cast<ruy::Path>(params.paths)); + + ruy::MulParams<AccumType, DstType> mul_params; + // Here an actual application might set some mul_params fields. + // Quantization multipliers, bias-vector, clamp bounds, etc. + + for (int r = 0; r < params.repeat; r++) { + ruy::Mul(lhs, rhs, mul_params, &context, &dst); + } +} + +int main(int argc, char* argv[]) { + bool help = argc == 1 || (argc == 2 && !strcmp(argv[1], "--help")); + if (help) { + fprintf(stderr, "Command-line flags (all in the form --flag=value):\n"); + fprintf(stderr, "%-20s %-12s %-16s %s\n", "flag", "format", "default", + "allowed"); + } + Params params; + const char* allowed_types = + "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8"; + const char* allowed_orders = "row-major, column-major"; + read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32", + allowed_types, &params.types); + read_cmdline_args(help, argc, argv, "--shape", "%dx%dx%d", "100x100x100", + nullptr, &params.m, &params.k, &params.n); + read_cmdline_args(help, argc, argv, "--paths", "%x", "0", nullptr, + &params.paths); + read_cmdline_args(help, argc, argv, "--num_threads", "%d", "1", nullptr, + &params.num_threads); + read_cmdline_args(help, argc, argv, "--repeat", "%d", "1", nullptr, + &params.repeat); + read_cmdline_args(help, argc, argv, "--lhs_cache_policy", "%d", "0", + "0, 1, 2, 3", &params.lhs_cache_policy); + read_cmdline_args(help, argc, argv, "--rhs_cache_policy", "%d", "0", + "0, 1, 2, 3", &params.rhs_cache_policy); + read_cmdline_args(help, argc, argv, "--lhs_stride", "%d", "0", nullptr, + &params.lhs_stride); + read_cmdline_args(help, argc, argv, "--rhs_stride", "%d", "0", nullptr, + &params.rhs_stride); + read_cmdline_args(help, argc, argv, "--dst_stride", "%d", "0", nullptr, + &params.dst_stride); + read_cmdline_args(help, argc, argv, "--lhs_zero_point", "%d", "0", nullptr, + &params.lhs_zero_point); + read_cmdline_args(help, argc, argv, "--rhs_zero_point", "%d", "0", nullptr, + &params.rhs_zero_point); + read_cmdline_args(help, argc, argv, "--dst_zero_point", "%d", "0", nullptr, + &params.dst_zero_point); + read_cmdline_args(help, argc, argv, "--lhs_order", "%s", "row-major", + allowed_orders, &params.lhs_order); + read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major", + allowed_orders, &params.rhs_order); + read_cmdline_args(help, argc, argv, 
"--rhs_order", "%s", "row-major", + allowed_orders, ¶ms.dst_order); + + if (help) { + exit(1); + } + + if (!strcmp(params.types, "f32xf32->f32")) { + run<float, float, float>(params); + } else if (!strcmp(params.types, "i8xi8->i8")) { + run<std::int8_t, std::int8_t, std::int8_t>(params); + } else if (!strcmp(params.types, "i8xi8->i16")) { + run<std::int8_t, std::int8_t, std::int16_t>(params); + } else if (!strcmp(params.types, "i8xi8->i32")) { + run<std::int8_t, std::int8_t, std::int32_t>(params); + } else if (!strcmp(params.types, "u8xu8->i16")) { + run<std::uint8_t, std::uint8_t, std::int16_t>(params); + } else if (!strcmp(params.types, "u8xi8->u8")) { + run<std::uint8_t, std::int8_t, std::uint8_t>(params); + } else { + fprintf(stderr, "Unknown types: %s\n", params.types); + exit(1); + } +} diff --git a/ruy/BUILD b/ruy/BUILD new file mode 100644 index 0000000..37e89ab --- /dev/null +++ b/ruy/BUILD @@ -0,0 +1,1220 @@ +# Ruy is not BLAS + +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx", "ruy_copts_avx2_fma", "ruy_copts_avx512") +load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library") +load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") +load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "armeabi-v7a", + values = {"cpu": "armeabi-v7a"}, +) + +config_setting( + name = "armv7a", + values = {"cpu": "armv7a"}, +) + +# Detect ARM 32-bit targets where we are going to just assume NEON support. +selects.config_setting_group( + name = "arm32_assuming_neon", + match_any = [ + ":armeabi-v7a", + ":armv7a", + ], +) + +config_setting( + name = "x86_64_k8", + values = {"cpu": "k8"}, +) + +config_setting( + name = "x86_64_haswell", + values = {"cpu": "haswell"}, +) + +# MSVC toolchains define a different "cpu" value, which helps us as we need +# to pass different flags on MSVC vs GCC-compatible toolchains to enable +# x86 SIMD extensions. 
+selects.config_setting_group( + name = "x86_64_and_not_msvc", + match_any = [ + ":x86_64_k8", + ":x86_64_haswell", + ], +) + +config_setting( + name = "ppc", + values = { + "cpu": "ppc", + }, +) + +config_setting( + name = "s390x", + values = { + "cpu": "s390x", + }, +) + +config_setting( + name = "fuchsia", + values = {"cpu": "fuchsia"}, +) + +config_setting( + name = "dbg_build", + values = { + "compilation_mode": "dbg", + }, +) + +config_setting( + name = "fastbuild_build", + values = { + "compilation_mode": "fastbuild", + }, +) + +selects.config_setting_group( + name = "do_not_want_O3", + match_any = [ + "@bazel_tools//src/conditions:windows_msvc", + ":dbg_build", + ":fastbuild_build", + ], +) + +cc_library( + name = "trace", + hdrs = ["trace.h"], + copts = ruy_copts(), + deps = [ + ":mat", + ":matrix", + ":path", + ":platform", + ":side_pair", + ], +) + +cc_library( + name = "platform", + hdrs = ["platform.h"], + copts = ruy_copts(), +) + +cc_library( + name = "gtest_wrapper", + testonly = True, + hdrs = ["gtest_wrapper.h"], + visibility = [":__subpackages__"], + deps = ["@com_google_googletest//:gtest"], +) + +cc_library( + name = "check_macros", + hdrs = ["check_macros.h"], + copts = ruy_copts(), +) + +cc_test( + name = "check_macros_test", + srcs = ["check_macros_test.cc"], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":gtest_wrapper", + ], +) + +cc_library( + name = "opt_set", + hdrs = ["opt_set.h"], + copts = ruy_copts(), +) + +cc_library( + name = "time", + hdrs = ["time.h"], + copts = ruy_copts(), +) + +cc_library( + name = "wait", + srcs = ["wait.cc"], + hdrs = ["wait.h"], + copts = ruy_copts(), + linkopts = ruy_linkopts_thread_standard_library(), + deps = [":time"], +) + +cc_test( + name = "wait_test", + srcs = ["wait_test.cc"], + copts = ruy_copts(), + linkopts = ruy_linkopts_thread_standard_library(), + deps = [ + ":gtest_wrapper", + ":platform", + ":wait", + ], +) + +cc_library( + name = "size_util", + hdrs = ["size_util.h"], + copts = ruy_copts(), + deps = [":check_macros"], +) + +cc_test( + name = "size_util_test", + srcs = ["size_util_test.cc"], + copts = ruy_copts(), + deps = [ + ":gtest_wrapper", + ":size_util", + ], +) + +cc_library( + name = "tune", + srcs = [ + "tune.cc", + ], + hdrs = [ + "tune.h", + ], + copts = ruy_copts(), + deps = [ + ":cpu_cache_params", + ":cpuinfo", + ":opt_set", + ":platform", + ":time", + ], +) + +cc_library( + name = "system_aligned_alloc", + srcs = [ + "system_aligned_alloc.cc", + ], + hdrs = [ + "system_aligned_alloc.h", + ], + copts = ruy_copts(), +) + +cc_library( + name = "prepacked_cache", + srcs = [ + "prepacked_cache.cc", + ], + hdrs = [ + "prepacked_cache.h", + ], + copts = ruy_copts(), + deps = [ + ":mat", + ":system_aligned_alloc", + "//ruy/profiler:instrumentation", + ], +) + +cc_test( + name = "tune_test", + srcs = ["tune_test.cc"], + copts = ruy_copts(), + deps = [ + ":cpuinfo", + ":gtest_wrapper", + ":tune", + ], +) + +cc_test( + name = "prepacked_cache_test", + srcs = ["prepacked_cache_test.cc"], + copts = ruy_copts(), + deps = [ + ":context", + ":context_get_ctx", + ":ctx", + ":gtest_wrapper", + ":mat", + ":matrix", + ":prepacked_cache", + ":ruy", + ":time", + ], +) + +cc_library( + name = "allocator", + srcs = [ + "allocator.cc", + ], + hdrs = [ + "allocator.h", + ], + copts = ruy_copts(), + deps = [ + ":opt_set", + ":size_util", + ":system_aligned_alloc", + ], +) + +cc_test( + name = "allocator_test", + srcs = ["allocator_test.cc"], + copts = ruy_copts(), + deps = [ + ":allocator", + ":gtest_wrapper", + ], 
+) + +cc_library( + name = "side_pair", + hdrs = ["side_pair.h"], + copts = ruy_copts(), + deps = [":check_macros"], +) + +cc_library( + name = "block_map", + srcs = [ + "block_map.cc", + ], + hdrs = [ + "block_map.h", + ], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":cpu_cache_params", + ":opt_set", + ":side_pair", + ":size_util", + ":trace", + "//ruy/profiler:instrumentation", + ], +) + +cc_test( + name = "block_map_test", + srcs = ["block_map_test.cc"], + copts = ruy_copts(), + deps = [ + ":block_map", + ":cpu_cache_params", + ":gtest_wrapper", + ":path", + ":platform", + ":side_pair", + ], +) + +cc_library( + name = "blocking_counter", + srcs = [ + "blocking_counter.cc", + ], + hdrs = [ + "blocking_counter.h", + ], + copts = ruy_copts(), + linkopts = ruy_linkopts_thread_standard_library(), + deps = [ + ":check_macros", + ":time", + ":wait", + ], +) + +cc_library( + name = "thread_pool", + srcs = [ + "thread_pool.cc", + ], + hdrs = [ + "thread_pool.h", + ], + copts = ruy_copts(), + linkopts = ruy_linkopts_thread_standard_library(), + visibility = ["//visibility:public"], + deps = [ + ":blocking_counter", + ":check_macros", + ":time", + ":trace", + ":wait", + ], +) + +cc_library( + name = "cpu_cache_params", + hdrs = ["cpu_cache_params.h"], + copts = ruy_copts(), +) + +cc_library( + name = "cpuinfo", + srcs = [ + "cpuinfo.cc", + ], + hdrs = [ + "cpuinfo.h", + ], + copts = ruy_copts() + + select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": [ + # ruy_copts contains -Wundef, but cpuinfo's header warns with that. + "-Wno-undef", + ], + }) + select({ + # This select must match the similar select in `deps`. + # We intentionally define this token in the BUILD + # file so that ports to other build-systems do not + # use cpuinfo by default - they need to port the + # cpuinfo BUILD first, then can define this token. 
+ ":ppc": [], + ":s390x": [], + ":fuchsia": [], + "//conditions:default": ["-DRUY_HAVE_CPUINFO"], + }), + deps = [ + ":platform", + ":check_macros", + ":cpu_cache_params", + ] + select({ + # This select must match the similar select in `copts` + ":ppc": [], + ":s390x": [], + ":fuchsia": [], + "//conditions:default": ["@cpuinfo"], + }), +) + +cc_library( + name = "path", + hdrs = ["path.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [ + ":platform", + ":size_util", + ], +) + +cc_library( + name = "performance_advisory", + hdrs = ["performance_advisory.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], +) + +cc_library( + name = "matrix", + hdrs = ["matrix.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [":check_macros"], +) + +cc_test( + name = "matrix_test", + srcs = ["matrix_test.cc"], + copts = ruy_copts(), + deps = [ + ":gtest_wrapper", + ":matrix", + ], +) + +cc_library( + name = "mul_params", + hdrs = ["mul_params.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [ + ":check_macros", + ":size_util", + ], +) + +cc_test( + name = "mul_params_test", + srcs = ["mul_params_test.cc"], + copts = ruy_copts(), + deps = [ + ":gtest_wrapper", + ":mul_params", + ], +) + +cc_library( + name = "mat", + hdrs = ["mat.h"], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":matrix", + ":size_util", + ], +) + +cc_library( + name = "asm_helpers", + hdrs = [ + "asm_helpers.h", + ], + copts = ruy_copts(), + deps = [ + ":opt_set", + ], +) + +cc_library( + name = "apply_multiplier", + srcs = ["apply_multiplier.cc"], + hdrs = ["apply_multiplier.h"], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":mul_params", + ], +) + +cc_test( + name = "apply_multiplier_test", + srcs = ["apply_multiplier_test.cc"], + copts = ruy_copts(), + deps = [ + ":apply_multiplier", + ":gtest_wrapper", + ":mul_params", + ], +) + +cc_library( + name = "kernel_common", + hdrs = [ + "kernel_common.h", + ], + copts = ruy_copts(), + deps = [ + ":apply_multiplier", + ":check_macros", + ":mat", + ":matrix", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_common", + hdrs = [ + "pack_common.h", + ], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":mat", + ":matrix", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "kernel_arm", + srcs = [ + "kernel_arm32.cc", + "kernel_arm64.cc", + ], + hdrs = ["kernel_arm.h"], + copts = ruy_copts(), + deps = [ + ":asm_helpers", + ":check_macros", + ":kernel_common", + ":mat", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_arm", + srcs = [ + "pack_arm.cc", + ], + hdrs = [ + "pack_arm.h", + ], + copts = ruy_copts(), + deps = [ + ":asm_helpers", + ":check_macros", + ":mat", + ":opt_set", + ":pack_common", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "kernel_avx512", + srcs = [ + "kernel_avx512.cc", + ], + hdrs = [ + "kernel_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx512(), + deps = [ + ":check_macros", + ":kernel_common", + ":mat", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx512", + srcs = [ + 
"pack_avx512.cc", + ], + hdrs = [ + "pack_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx512(), + deps = [ + ":check_macros", + ":mat", + ":opt_set", + ":pack_common", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx512", + srcs = [ + "have_built_path_for_avx512.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = ruy_copts() + ruy_copts_avx512(), + deps = [ + ":opt_set", + ":platform", + ], +) + +cc_library( + name = "kernel_avx2_fma", + srcs = [ + "kernel_avx2_fma.cc", + ], + hdrs = [ + "kernel_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx2_fma(), + deps = [ + ":check_macros", + ":kernel_common", + ":mat", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx2_fma", + srcs = [ + "pack_avx2_fma.cc", + ], + hdrs = [ + "pack_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx2_fma(), + deps = [ + ":check_macros", + ":mat", + ":opt_set", + ":pack_common", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx2_fma", + srcs = [ + "have_built_path_for_avx2_fma.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = ruy_copts() + ruy_copts_avx2_fma(), + deps = [ + ":opt_set", + ":platform", + ], +) + +cc_library( + name = "kernel_avx", + srcs = [ + "kernel_avx.cc", + ], + hdrs = [ + "kernel_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx(), + deps = [ + ":check_macros", + ":kernel_common", + ":mat", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx", + srcs = [ + "pack_avx.cc", + ], + hdrs = [ + "pack_x86.h", + ], + copts = ruy_copts() + ruy_copts_avx(), + deps = [ + ":check_macros", + ":mat", + ":opt_set", + ":pack_common", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx", + srcs = [ + "have_built_path_for_avx.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = ruy_copts() + ruy_copts_avx(), + deps = [ + ":opt_set", + ":platform", + ], +) + +cc_library( + name = "kernel", + hdrs = [ + "kernel.h", + ], + copts = ruy_copts(), + deps = [ + ":apply_multiplier", + ":check_macros", + ":kernel_arm", # fixdeps: keep + ":kernel_avx", + ":kernel_avx2_fma", # fixdeps: keep + ":kernel_avx512", # fixdeps: keep + ":kernel_common", + ":mat", + ":matrix", + ":mul_params", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":trace", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack", + hdrs = [ + "pack.h", + ], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":mat", + ":matrix", + ":opt_set", + ":pack_arm", # fixdeps: keep + ":pack_avx", # fixdeps: keep + ":pack_avx2_fma", # fixdeps: keep + ":pack_avx512", # fixdeps: keep + ":pack_common", + ":path", + ":platform", + ":trace", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for", + hdrs = [ + "have_built_path_for.h", + ], + deps = [ + ":have_built_path_for_avx", + ":have_built_path_for_avx2_fma", + ":have_built_path_for_avx512", + ":platform", + ], +) + +cc_library( + name = "context", + srcs = ["context.cc"], + hdrs = [ + "context.h", + ], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [ + ":allocator", + ":check_macros", + ":ctx", + ":path", + ":performance_advisory", 
+ ":platform", + ":prepacked_cache", + ":thread_pool", + ":tune", + ], +) + +cc_test( + name = "context_test", + srcs = ["context_test.cc"], + copts = ruy_copts(), + deps = [ + ":context", + ":gtest_wrapper", + ":path", + ":platform", + ":prepacked_cache", + ":tune", + ], +) + +cc_library( + name = "ctx_header_only_should_not_include_other_ruy_headers", + testonly = True, + hdrs = [ + "ctx.h", + ], + # Intentionally no deps. This will cause the stand-alone build of ctx.h to + # fail if ctx.h #includes any other ruy header. +) + +cc_library( + name = "ctx", + srcs = [ + "ctx.cc", + ], + hdrs = [ + "ctx.h", + "ctx_impl.h", + ], + copts = ruy_copts(), + deps = [ + ":allocator", + ":check_macros", + ":cpuinfo", + ":have_built_path_for", + ":path", + ":performance_advisory", + ":platform", + ":prepacked_cache", + ":thread_pool", + ":trace", + ":tune", + ], +) + +cc_library( + name = "context_get_ctx", + srcs = [ + "context_get_ctx.cc", + ], + hdrs = [ + "context_get_ctx.h", + ], + copts = ruy_copts(), + deps = [ + ":context", + ":ctx", + ], +) + +cc_test( + name = "ctx_test", + srcs = ["ctx_test.cc"], + copts = ruy_copts(), + deps = [ + ":ctx", + ":gtest_wrapper", + ":path", + ":platform", + ], +) + +cc_library( + name = "trmul_params", + hdrs = ["trmul_params.h"], + copts = ruy_copts(), + deps = [ + ":mat", + ":mul_params", + ":path", + ":side_pair", + ":tune", + ], +) + +cc_library( + name = "trmul", + srcs = ["trmul.cc"], + hdrs = ["trmul.h"], + copts = ruy_copts(), + deps = [ + ":allocator", + ":block_map", + ":check_macros", + ":cpu_cache_params", + ":cpuinfo", + ":ctx", + ":mat", + ":matrix", + ":mul_params", + ":opt_set", + ":side_pair", + ":size_util", + ":thread_pool", + ":trace", + ":trmul_params", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "prepare_packed_matrices", + srcs = ["prepare_packed_matrices.cc"], + hdrs = ["prepare_packed_matrices.h"], + copts = ruy_copts(), + deps = [ + ":allocator", + ":ctx", + ":matrix", + ":prepacked_cache", + ":side_pair", + ":trace", + ":trmul_params", + ], +) + +cc_library( + name = "create_trmul_params", + hdrs = ["create_trmul_params.h"], + copts = ruy_copts(), + deps = [ + ":allocator", + ":check_macros", + ":ctx", + ":kernel", + ":mat", + ":mul_params", + ":pack", + ":path", + ":performance_advisory", + ":platform", + ":side_pair", + ":trace", + ":trmul_params", + ], +) + +cc_library( + name = "validate", + hdrs = ["validate.h"], + copts = ruy_copts(), + deps = [ + ":check_macros", + ":mat", + ":mul_params", + ":side_pair", + ], +) + +cc_library( + name = "frontend", + srcs = [ + "frontend.cc", + ], + hdrs = [ + "frontend.h", + ], + copts = ruy_copts(), + deps = [ + ":allocator", + ":create_trmul_params", + ":ctx", + ":mat", + ":mul_params", + ":prepare_packed_matrices", + ":trace", + ":trmul", + ":trmul_params", + ":validate", + "//ruy/profiler:instrumentation", + ], +) + +# The main library. 
+cc_library( + name = "ruy", + hdrs = [ + "context.h", + "matrix.h", + "mul_params.h", + "path.h", + "ruy.h", + ], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [ + ":check_macros", + ":context", + ":context_get_ctx", + ":frontend", + ":mat", + ":matrix", + ":mul_params", + ":path", + ":platform", + ":size_util", + ":trace", + ], +) + +cc_test( + name = "perchannel_buffers_reallocation_test", + srcs = ["perchannel_buffers_reallocation_test.cc"], + copts = ruy_copts(), + deps = [ + ":context", + ":gtest_wrapper", + ":kernel", + ":matrix", + ":path", + ":performance_advisory", + ":ruy", + ], +) + +# Small library to query PMU counters, for benchmark only +cc_library( + name = "pmu", + testonly = True, + srcs = ["pmu.cc"], + hdrs = ["pmu.h"], + copts = ruy_copts(), + deps = [":check_macros"], +) + +cc_library( + name = "reference_mul", + hdrs = ["reference_mul.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], + deps = [ + ":apply_multiplier", + ":matrix", + ":mul_params", + ], +) + +# Testing framework. +cc_library( + name = "test_lib", + testonly = True, + hdrs = ["test.h"], + copts = ruy_copts(), + # need defines, not copts, because it's controlling a header, test.h + defines = ruy_test_ext_defines(), + linkopts = select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": ["-lm"], + }), + deps = [ + ":allocator", + ":size_util", + ":reference_mul", + ":matrix", + ":pmu", + ":ruy", + ":mul_params", + ":time", + ":gtest_wrapper", + ":platform", + ":context", + ":ctx", + ":context_get_ctx", + ":pack_common", + "//ruy/profiler", + ] + ruy_test_ext_deps(), +) + +ruy_benchmark( + name = "benchmark", + srcs = ["benchmark.cc"], + copts = ruy_copts(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + deps = [ + ":test_lib", + "//ruy/profiler:instrumentation", + ], +) + +ruy_test( + name = "test_fast", + srcs = ["test_fast.cc"], + copts = ruy_copts(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("f64", "f32", "f64", "f32"), + ("f32", "f64", "f64", "f64"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("i8", "u8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ("i8", "u8", "i32", "i32"), + ], + deps = [ + ":test_lib", + "@com_google_googletest//:gtest_main", + ], +) + +ruy_test( + name = "test_slow", + srcs = ["test_slow.cc"], + copts = ruy_copts(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + tags = ["slow"], + deps = [ + ":test_lib", + "@com_google_googletest//:gtest_main", + ], +) + +bzl_library( + name = "ruy_test_ext.oss_bzl", + srcs = ["ruy_test_ext.oss.bzl"], + visibility = ["//visibility:private"], +) + +bzl_library( + name = "ruy_test_bzl", + srcs = ["ruy_test.bzl"], + visibility = ["//visibility:private"], +) + +bzl_library( + name = "build_defs.oss_bzl", + srcs = ["build_defs.oss.bzl"], + visibility = ["//visibility:private"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt new file mode 100644 index 0000000..4c3e394 --- /dev/null +++ b/ruy/CMakeLists.txt @@ -0,0 +1,1696 @@ +# This file is generated (whence no license header). Do not edit! 
+# To regenerate, run: +# cmake/bazel_to_cmake.sh + +if(CMAKE_SYSTEM_NAME STREQUAL Windows) + set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "") +else() + set(ruy_0_Wall_Wcxx14_compat_Wextra_Wundef "-Wall;-Wextra;-Wc++14-compat;-Wundef") +endif() + +if(CMAKE_SYSTEM_PROCESSOR STREQUAL arm) + set(ruy_1_mfpu_neon "-mfpu=neon") +else() + set(ruy_1_mfpu_neon "") +endif() + +if((CMAKE_BUILD_TYPE STREQUAL Debug) OR MSVC) + set(ruy_2_O3 "") +else() + set(ruy_2_O3 "-O3") +endif() + +ruy_cc_library( + NAME + ruy_trace + HDRS + trace.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_mat + ruy_matrix + ruy_path + ruy_platform + ruy_side_pair +) + +ruy_cc_library( + NAME + ruy_platform + HDRS + platform.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +ruy_cc_library( + NAME + ruy_gtest_wrapper + TESTONLY + HDRS + gtest_wrapper.h + DEPS + gtest +) + +ruy_cc_library( + NAME + ruy_check_macros + HDRS + check_macros.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +ruy_cc_test( + NAME + ruy_check_macros_test + SRCS + check_macros_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_gtest_wrapper +) + +ruy_cc_library( + NAME + ruy_opt_set + HDRS + opt_set.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +ruy_cc_library( + NAME + ruy_time + HDRS + time.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +if(CMAKE_SYSTEM_NAME STREQUAL Windows) + set(ruy_3_pthread "") +else() + set(ruy_3_pthread "-pthread") +endif() + +ruy_cc_library( + NAME + ruy_wait + SRCS + wait.cc + HDRS + wait.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + LINKOPTS + ${ruy_3_pthread} + DEPS + ruy_time +) + +ruy_cc_test( + NAME + ruy_wait_test + SRCS + wait_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + LINKOPTS + ${ruy_3_pthread} + DEPS + ruy_gtest_wrapper + ruy_platform + ruy_wait +) + +ruy_cc_library( + NAME + ruy_size_util + HDRS + size_util.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros +) + +ruy_cc_test( + NAME + ruy_size_util_test + SRCS + size_util_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_gtest_wrapper + ruy_size_util +) + +ruy_cc_library( + NAME + ruy_tune + SRCS + tune.cc + HDRS + tune.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_cpu_cache_params + ruy_cpuinfo + ruy_opt_set + ruy_platform + ruy_time +) + +ruy_cc_library( + NAME + ruy_system_aligned_alloc + SRCS + system_aligned_alloc.cc + HDRS + system_aligned_alloc.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +ruy_cc_library( + NAME + ruy_prepacked_cache + SRCS + prepacked_cache.cc + HDRS + prepacked_cache.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_mat + ruy_system_aligned_alloc + ruy_profiler_instrumentation +) + +ruy_cc_test( + NAME + ruy_tune_test + SRCS + tune_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_cpuinfo + ruy_gtest_wrapper + ruy_tune +) + +ruy_cc_test( + NAME + ruy_prepacked_cache_test + SRCS + prepacked_cache_test.cc + 
COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_context + ruy_context_get_ctx + ruy_ctx + ruy_gtest_wrapper + ruy_mat + ruy_matrix + ruy_prepacked_cache + ruy + ruy_time +) + +ruy_cc_library( + NAME + ruy_allocator + SRCS + allocator.cc + HDRS + allocator.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_opt_set + ruy_size_util + ruy_system_aligned_alloc +) + +ruy_cc_test( + NAME + ruy_allocator_test + SRCS + allocator_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_gtest_wrapper +) + +ruy_cc_library( + NAME + ruy_side_pair + HDRS + side_pair.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros +) + +ruy_cc_library( + NAME + ruy_block_map + SRCS + block_map.cc + HDRS + block_map.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_cpu_cache_params + ruy_opt_set + ruy_side_pair + ruy_size_util + ruy_trace + ruy_profiler_instrumentation +) + +ruy_cc_test( + NAME + ruy_block_map_test + SRCS + block_map_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_block_map + ruy_cpu_cache_params + ruy_gtest_wrapper + ruy_path + ruy_platform + ruy_side_pair +) + +ruy_cc_library( + NAME + ruy_blocking_counter + SRCS + blocking_counter.cc + HDRS + blocking_counter.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + LINKOPTS + ${ruy_3_pthread} + DEPS + ruy_check_macros + ruy_time + ruy_wait +) + +ruy_cc_library( + NAME + ruy_thread_pool + SRCS + thread_pool.cc + HDRS + thread_pool.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + LINKOPTS + ${ruy_3_pthread} + PUBLIC + DEPS + ruy_blocking_counter + ruy_check_macros + ruy_time + ruy_trace + ruy_wait +) + +ruy_cc_library( + NAME + ruy_cpu_cache_params + HDRS + cpu_cache_params.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} +) + +if(CMAKE_SYSTEM_NAME STREQUAL Windows) + set(ruy_4_Wno_undef "") +else() + set(ruy_4_Wno_undef "-Wno-undef") +endif() + +if(CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le) + set(ruy_5_DRUY_HAVE_CPUINFO "") +elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x) + set(ruy_5_DRUY_HAVE_CPUINFO "") +elseif(CMAKE_SYSTEM_NAME STREQUAL Fuchsia) + set(ruy_5_DRUY_HAVE_CPUINFO "") +else() + set(ruy_5_DRUY_HAVE_CPUINFO "-DRUY_HAVE_CPUINFO") +endif() + +if(CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL ppc64le) + set(ruy_6_cpuinfo "") +elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL s390 OR CMAKE_SYSTEM_PROCESSOR STREQUAL s390x) + set(ruy_6_cpuinfo "") +elseif(CMAKE_SYSTEM_NAME STREQUAL Fuchsia) + set(ruy_6_cpuinfo "") +else() + set(ruy_6_cpuinfo "cpuinfo") +endif() + +ruy_cc_library( + NAME + ruy_cpuinfo + SRCS + cpuinfo.cc + HDRS + cpuinfo.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_4_Wno_undef} + ${ruy_5_DRUY_HAVE_CPUINFO} + DEPS + ruy_platform + ruy_check_macros + ruy_cpu_cache_params + ${ruy_6_cpuinfo} +) + +ruy_cc_library( + NAME + ruy_path + HDRS + path.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_platform + ruy_size_util +) + +ruy_cc_library( + NAME + 
ruy_performance_advisory + HDRS + performance_advisory.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC +) + +ruy_cc_library( + NAME + ruy_matrix + HDRS + matrix.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_check_macros +) + +ruy_cc_test( + NAME + ruy_matrix_test + SRCS + matrix_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_gtest_wrapper + ruy_matrix +) + +ruy_cc_library( + NAME + ruy_mul_params + HDRS + mul_params.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_check_macros + ruy_size_util +) + +ruy_cc_test( + NAME + ruy_mul_params_test + SRCS + mul_params_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_gtest_wrapper + ruy_mul_params +) + +ruy_cc_library( + NAME + ruy_mat + HDRS + mat.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_matrix + ruy_size_util +) + +ruy_cc_library( + NAME + ruy_asm_helpers + HDRS + asm_helpers.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_opt_set +) + +ruy_cc_library( + NAME + ruy_apply_multiplier + SRCS + apply_multiplier.cc + HDRS + apply_multiplier.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_mul_params +) + +ruy_cc_test( + NAME + ruy_apply_multiplier_test + SRCS + apply_multiplier_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_apply_multiplier + ruy_gtest_wrapper + ruy_mul_params +) + +ruy_cc_library( + NAME + ruy_kernel_common + HDRS + kernel_common.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_apply_multiplier + ruy_check_macros + ruy_mat + ruy_matrix + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_side_pair + ruy_size_util + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack_common + HDRS + pack_common.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_mat + ruy_matrix + ruy_opt_set + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_kernel_arm + SRCS + kernel_arm32.cc + kernel_arm64.cc + HDRS + kernel_arm.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_asm_helpers + ruy_check_macros + ruy_kernel_common + ruy_mat + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_side_pair + ruy_size_util + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack_arm + SRCS + pack_arm.cc + HDRS + pack_arm.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_asm_helpers + ruy_check_macros + ruy_mat + ruy_opt_set + ruy_pack_common + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC) + set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 ";-mavx512f;-mavx512vl;-mavx512cd;-mavx512bw;-mavx512dq") +elseif(MSVC) + set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 "/arch:AVX512") +else() + 
set(ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512 "") +endif() + +ruy_cc_library( + NAME + ruy_kernel_avx512 + SRCS + kernel_avx512.cc + HDRS + kernel_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512} + DEPS + ruy_check_macros + ruy_kernel_common + ruy_mat + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack_avx512 + SRCS + pack_avx512.cc + HDRS + pack_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512} + DEPS + ruy_check_macros + ruy_mat + ruy_opt_set + ruy_pack_common + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_have_built_path_for_avx512 + SRCS + have_built_path_for_avx512.cc + HDRS + have_built_path_for.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_7_mavx512bw_mavx512cd_mavx512dq_mavx512f_mavx512vl_arch_AVX512} + DEPS + ruy_opt_set + ruy_platform +) + +if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC) + set(ruy_8_mavx2_mfma_arch_AVX2 "-mavx2;-mfma") +elseif(MSVC) + set(ruy_8_mavx2_mfma_arch_AVX2 "/arch:AVX2") +else() + set(ruy_8_mavx2_mfma_arch_AVX2 "") +endif() + +ruy_cc_library( + NAME + ruy_kernel_avx2_fma + SRCS + kernel_avx2_fma.cc + HDRS + kernel_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_8_mavx2_mfma_arch_AVX2} + DEPS + ruy_check_macros + ruy_kernel_common + ruy_mat + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack_avx2_fma + SRCS + pack_avx2_fma.cc + HDRS + pack_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_8_mavx2_mfma_arch_AVX2} + DEPS + ruy_check_macros + ruy_mat + ruy_opt_set + ruy_pack_common + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_have_built_path_for_avx2_fma + SRCS + have_built_path_for_avx2_fma.cc + HDRS + have_built_path_for.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_8_mavx2_mfma_arch_AVX2} + DEPS + ruy_opt_set + ruy_platform +) + +if((CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL amd64) AND NOT MSVC) + set(ruy_9_mavx_arch_AVX "-mavx") +elseif(MSVC) + set(ruy_9_mavx_arch_AVX "/arch:AVX") +else() + set(ruy_9_mavx_arch_AVX "") +endif() + +ruy_cc_library( + NAME + ruy_kernel_avx + SRCS + kernel_avx.cc + HDRS + kernel_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_9_mavx_arch_AVX} + DEPS + ruy_check_macros + ruy_kernel_common + ruy_mat + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack_avx + SRCS + pack_avx.cc + HDRS + pack_x86.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_9_mavx_arch_AVX} + DEPS + ruy_check_macros + ruy_mat + ruy_opt_set + ruy_pack_common + ruy_path + ruy_platform + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_have_built_path_for_avx + SRCS + have_built_path_for_avx.cc + HDRS + have_built_path_for.h + COPTS + 
${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + ${ruy_9_mavx_arch_AVX} + DEPS + ruy_opt_set + ruy_platform +) + +ruy_cc_library( + NAME + ruy_kernel + HDRS + kernel.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_apply_multiplier + ruy_check_macros + ruy_kernel_arm + ruy_kernel_avx + ruy_kernel_avx2_fma + ruy_kernel_avx512 + ruy_kernel_common + ruy_mat + ruy_matrix + ruy_mul_params + ruy_opt_set + ruy_path + ruy_platform + ruy_side_pair + ruy_size_util + ruy_trace + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_pack + HDRS + pack.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_mat + ruy_matrix + ruy_opt_set + ruy_pack_arm + ruy_pack_avx + ruy_pack_avx2_fma + ruy_pack_avx512 + ruy_pack_common + ruy_path + ruy_platform + ruy_trace + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_have_built_path_for + HDRS + have_built_path_for.h + DEPS + ruy_have_built_path_for_avx + ruy_have_built_path_for_avx2_fma + ruy_have_built_path_for_avx512 + ruy_platform +) + +ruy_cc_library( + NAME + ruy_context + SRCS + context.cc + HDRS + context.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_allocator + ruy_check_macros + ruy_ctx + ruy_path + ruy_performance_advisory + ruy_platform + ruy_prepacked_cache + ruy_thread_pool + ruy_tune +) + +ruy_cc_test( + NAME + ruy_context_test + SRCS + context_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_context + ruy_gtest_wrapper + ruy_path + ruy_platform + ruy_prepacked_cache + ruy_tune +) + +ruy_cc_library( + NAME + ruy_ctx_header_only_should_not_include_other_ruy_headers + TESTONLY + HDRS + ctx.h +) + +ruy_cc_library( + NAME + ruy_ctx + SRCS + ctx.cc + HDRS + ctx.h + ctx_impl.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_check_macros + ruy_cpuinfo + ruy_have_built_path_for + ruy_path + ruy_performance_advisory + ruy_platform + ruy_prepacked_cache + ruy_thread_pool + ruy_trace + ruy_tune +) + +ruy_cc_library( + NAME + ruy_context_get_ctx + SRCS + context_get_ctx.cc + HDRS + context_get_ctx.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_context + ruy_ctx +) + +ruy_cc_test( + NAME + ruy_ctx_test + SRCS + ctx_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_ctx + ruy_gtest_wrapper + ruy_path + ruy_platform +) + +ruy_cc_library( + NAME + ruy_trmul_params + HDRS + trmul_params.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_mat + ruy_mul_params + ruy_path + ruy_side_pair + ruy_tune +) + +ruy_cc_library( + NAME + ruy_trmul + SRCS + trmul.cc + HDRS + trmul.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_block_map + ruy_check_macros + ruy_cpu_cache_params + ruy_cpuinfo + ruy_ctx + ruy_mat + ruy_matrix + ruy_mul_params + ruy_opt_set + ruy_side_pair + ruy_size_util + ruy_thread_pool + ruy_trace + ruy_trmul_params + ruy_tune + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_prepare_packed_matrices + SRCS + prepare_packed_matrices.cc + HDRS + prepare_packed_matrices.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + 
${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_ctx + ruy_matrix + ruy_prepacked_cache + ruy_side_pair + ruy_trace + ruy_trmul_params +) + +ruy_cc_library( + NAME + ruy_create_trmul_params + HDRS + create_trmul_params.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_check_macros + ruy_ctx + ruy_kernel + ruy_mat + ruy_mul_params + ruy_pack + ruy_path + ruy_performance_advisory + ruy_platform + ruy_side_pair + ruy_trace + ruy_trmul_params +) + +ruy_cc_library( + NAME + ruy_validate + HDRS + validate.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros + ruy_mat + ruy_mul_params + ruy_side_pair +) + +ruy_cc_library( + NAME + ruy_frontend + SRCS + frontend.cc + HDRS + frontend.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_allocator + ruy_create_trmul_params + ruy_ctx + ruy_mat + ruy_mul_params + ruy_prepare_packed_matrices + ruy_trace + ruy_trmul + ruy_trmul_params + ruy_validate + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy + HDRS + context.h + matrix.h + mul_params.h + path.h + ruy.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_check_macros + ruy_context + ruy_context_get_ctx + ruy_frontend + ruy_mat + ruy_matrix + ruy_mul_params + ruy_path + ruy_platform + ruy_size_util + ruy_trace +) + +ruy_cc_test( + NAME + ruy_perchannel_buffers_reallocation_test + SRCS + perchannel_buffers_reallocation_test.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_context + ruy_gtest_wrapper + ruy_kernel + ruy_matrix + ruy_path + ruy_performance_advisory + ruy +) + +ruy_cc_library( + NAME + ruy_pmu + TESTONLY + SRCS + pmu.cc + HDRS + pmu.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + DEPS + ruy_check_macros +) + +ruy_cc_library( + NAME + ruy_reference_mul + HDRS + reference_mul.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC + DEPS + ruy_apply_multiplier + ruy_matrix + ruy_mul_params +) + +if(CMAKE_SYSTEM_NAME STREQUAL Windows) + set(ruy_10_lm "") +else() + set(ruy_10_lm "-lm") +endif() + +ruy_cc_library( + NAME + ruy_test_lib + TESTONLY + HDRS + test.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + LINKOPTS + ${ruy_10_lm} + DEPS + ruy_allocator + ruy_size_util + ruy_reference_mul + ruy_matrix + ruy_pmu + ruy + ruy_mul_params + ruy_time + ruy_gtest_wrapper + ruy_platform + ruy_context + ruy_ctx + ruy_context_get_ctx + ruy_pack_common + ruy_profiler_profiler +) + +ruy_cc_binary( + NAME + ruy_benchmark_f32_f32_f32_f32 + TESTONLY + SRCS + benchmark.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=f32 + -DRUY_TEST_RHSSCALAR=f32 + -DRUY_TEST_ACCUMSCALAR=f32 + -DRUY_TEST_DSTSCALAR=f32 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_binary( + NAME + ruy_benchmark_u8_u8_i32_u8 + TESTONLY + SRCS + benchmark.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=u8 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_binary( + NAME + ruy_benchmark_i8_i8_i32_u8 + TESTONLY + SRCS + benchmark.cc + COPTS + 
${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=u8 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_binary( + NAME + ruy_benchmark_i8_i8_i32_i8 + TESTONLY + SRCS + benchmark.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i8 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_binary( + NAME + ruy_benchmark_u8_u8_i32_i16 + TESTONLY + SRCS + benchmark.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i16 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_binary( + NAME + ruy_benchmark_i8_i8_i32_i32 + TESTONLY + SRCS + benchmark.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i32 + DEPS + ruy_test_lib + ruy_profiler_instrumentation +) + +ruy_cc_test( + NAME + ruy_test_fast_f32_f32_f32_f32 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=f32 + -DRUY_TEST_RHSSCALAR=f32 + -DRUY_TEST_ACCUMSCALAR=f32 + -DRUY_TEST_DSTSCALAR=f32 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_f64_f32_f64_f32 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=f64 + -DRUY_TEST_RHSSCALAR=f32 + -DRUY_TEST_ACCUMSCALAR=f64 + -DRUY_TEST_DSTSCALAR=f32 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_f32_f64_f64_f64 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=f32 + -DRUY_TEST_RHSSCALAR=f64 + -DRUY_TEST_ACCUMSCALAR=f64 + -DRUY_TEST_DSTSCALAR=f64 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_u8_u8_i32_u8 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=u8 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_i8_i8_i32_i8 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i8 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_i8_u8_i32_i8 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i8 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_u8_u8_i32_i16 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i16 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_i8_i8_i32_i32 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + 
${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i32 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_fast_i8_u8_i32_i32 + SRCS + test_fast.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i32 + DEPS + ruy_test_lib + gtest_main +) + +ruy_cc_test( + NAME + ruy_test_slow_f32_f32_f32_f32 + SRCS + test_slow.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=f32 + -DRUY_TEST_RHSSCALAR=f32 + -DRUY_TEST_ACCUMSCALAR=f32 + -DRUY_TEST_DSTSCALAR=f32 + DEPS + ruy_test_lib + gtest_main + TAGS + slow +) + +ruy_cc_test( + NAME + ruy_test_slow_u8_u8_i32_u8 + SRCS + test_slow.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=u8 + DEPS + ruy_test_lib + gtest_main + TAGS + slow +) + +ruy_cc_test( + NAME + ruy_test_slow_i8_i8_i32_i8 + SRCS + test_slow.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i8 + DEPS + ruy_test_lib + gtest_main + TAGS + slow +) + +ruy_cc_test( + NAME + ruy_test_slow_u8_u8_i32_i16 + SRCS + test_slow.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=u8 + -DRUY_TEST_RHSSCALAR=u8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i16 + DEPS + ruy_test_lib + gtest_main + TAGS + slow +) + +ruy_cc_test( + NAME + ruy_test_slow_i8_i8_i32_i32 + SRCS + test_slow.cc + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + -DRUY_TEST_LHSSCALAR=i8 + -DRUY_TEST_RHSSCALAR=i8 + -DRUY_TEST_ACCUMSCALAR=i32 + -DRUY_TEST_DSTSCALAR=i32 + DEPS + ruy_test_lib + gtest_main + TAGS + slow +) + +ruy_add_all_subdirs() diff --git a/ruy/allocator.cc b/ruy/allocator.cc new file mode 100644 index 0000000..64da664 --- /dev/null +++ b/ruy/allocator.cc @@ -0,0 +1,124 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/allocator.h" + +#include "ruy/opt_set.h" +#include "ruy/size_util.h" +#include "ruy/system_aligned_alloc.h" + +namespace ruy { + +Allocator::~Allocator() { + FreeAll(); + detail::SystemAlignedFree(ptr_); +} + +void* Allocator::AllocateFast(std::ptrdiff_t num_bytes) { + if (current_ + num_bytes > size_) { + return nullptr; + } + void* ret = static_cast<char*>(ptr_) + current_; + current_ += num_bytes; + return ret; +} + +void* Allocator::AllocateSlow(std::ptrdiff_t num_bytes) { + void* p = detail::SystemAlignedAlloc(num_bytes); + fallback_blocks_total_size_ += num_bytes; + fallback_blocks_.push_back(p); + return p; +} + +void* Allocator::AllocateBytes(std::ptrdiff_t num_bytes) { + if (num_bytes == 0) { + return nullptr; + } + const std::ptrdiff_t rounded_num_bytes = + round_up_pot(num_bytes, detail::kMinimumBlockAlignment); + if (void* p = AllocateFast(rounded_num_bytes)) { + return p; + } + return AllocateSlow(rounded_num_bytes); +} + +void* Allocator::AllocateBytesAvoidingAliasingWith(std::ptrdiff_t num_bytes, + const void* to_avoid) { +#if RUY_OPT(AVOID_ALIASING) + if (num_bytes == 0) { + return nullptr; + } + // The minimum L1D cache aliasing periodicity in bytes that we expect to + // encounter on any device. This does not seem to be documented, but + // empirically we observe the following: + // Cortex-A53: 1024 + // Cortex-A55r1: 2048 + // Cortex-A76: not as easily observable. + // Over-estimating this makes the AVOID_ALIASING optimization useless on + // devices with lower periodicity. + // Under-estimating this by 2x should be harmless. + // Under-estimating this by a larger factor should gradually degrade + // performance due to cache aliasing causing mutual eviction between + // the packed matrix data, and the source matrix data being prefetched by the + // CPU ahead of the packing code execution. + static constexpr std::uint32_t kMinPeriod = 1024; + static_assert(is_pot(kMinPeriod), ""); + void* p = AllocateBytes(num_bytes + kMinPeriod); + auto unsigned_low_bits = [](const void* p) { + return static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(p)); + }; + // This relies on unsigned integer overflow wrapping around. + std::uint32_t diff_modulus = + (unsigned_low_bits(p) - unsigned_low_bits(to_avoid)) % kMinPeriod; + // diff_modulus is in [0, kMinPeriod). + // We want it as close as possible to the middle of that interval, + // kMinPeriod/2. The bad 'aliasing' case, that we are working to avoid, + // is when diff_modulus is close to the ends of that interval, 0 or + // kMinPeriod. So we want to add an offset of kMinPeriod/2 if it is in the + // first or the last quarter of that interval. + bool need_offset = + diff_modulus < kMinPeriod / 4 || diff_modulus > 3 * kMinPeriod / 4; + return static_cast<char*>(p) + (need_offset ? (kMinPeriod / 2) : 0); +#else + (void)to_avoid; + return AllocateBytes(num_bytes); +#endif +} + +void Allocator::FreeAll() { + current_ = 0; + if (fallback_blocks_.empty()) { + return; + } + + // No rounding-up of the size means linear instead of logarithmic + // bound on the number of allocation in some worst-case calling patterns. + // This is considered worth it because minimizing memory usage is important + // and actual calling patterns in applications that we care about still + // reach the no-further-allocations steady state in a small finite number + // of iterations. 
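[Illustrative aside, not part of this change: the comment above explains why FreeAll consolidates fallback blocks into a single buffer without rounding up. The sketch below shows the calling pattern that design targets; after the first Allocate/FreeAll cycle, the consolidated buffer should be large enough that later cycles stay on the fast bump-pointer path. Names and sizes are made up. The FreeAll implementation resumes below the sketch.]

#include "ruy/allocator.h"

// Illustrative only: repeated scratch allocations of stable sizes, released
// together at the end of each iteration, as the Allocator class is meant to
// be used.
void ExampleSteadyStateUsage() {
  ruy::Allocator allocator;
  for (int i = 0; i < 100; ++i) {
    void* lhs_scratch = allocator.AllocateBytes(64 * 1024);
    void* rhs_scratch = allocator.AllocateBytes(16 * 1024);
    // ... use the scratch buffers for one multiplication ...
    (void)lhs_scratch;
    (void)rhs_scratch;
    // Frees any fallback blocks and grows the internal buffer, so the next
    // iteration's allocations are plain pointer bumps.
    allocator.FreeAll();
  }
}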
+ std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; + detail::SystemAlignedFree(ptr_); + ptr_ = detail::SystemAlignedAlloc(new_size); + size_ = new_size; + + for (void* p : fallback_blocks_) { + detail::SystemAlignedFree(p); + } + fallback_blocks_.clear(); + fallback_blocks_total_size_ = 0; +} + +} // namespace ruy diff --git a/ruy/allocator.h b/ruy/allocator.h new file mode 100644 index 0000000..8aee01c --- /dev/null +++ b/ruy/allocator.h @@ -0,0 +1,100 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_ALLOCATOR_H_ +#define RUY_RUY_ALLOCATOR_H_ + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <vector> + +namespace ruy { + +// Specialized allocator designed to converge to a steady-state where all +// allocations are bump-ptr allocations from an already-allocated buffer. +// +// To support these constraints, this allocator only supports two +// operations. +// - AllocateBytes/Allocate<Pointer>: allocates a pointer to storage of a +// specified size, which will be aligned to kMinimumBlockAlignment. +// - FreeAll: frees all previous allocations (but retains the internal +// buffer to minimize future calls into the system allocator). +// +// This class is specialized for supporting just those two operations +// under this specific steady-state usage pattern. Extending this class +// with new allocation interfaces that don't fit that pattern is probably not +// the right choice. Instead, build a new class on top of +// SystemAlignedAlloc/SystemAlignedFree. +// +// All operations happen on aligned blocks for simplicity. +// +// Theory of operation: +// +// - ptr_, current_, and size_ implement a basic bump-ptr allocator. +// +// - in AllocateBytes, the fast path is just a bump-ptr +// allocation. If our bump-ptr allocator doesn't have enough space for an +// allocation, then we allocate a block from the system allocator to +// service the allocation request. We save that block in fallback_blocks_ +// and track the total size of the fallback blocks in +// fallback_blocks_total_size_. +// +// - in FreeAll, the fast path just resets the bump-ptr allocator. If +// there are any fallback blocks, we free them and reallocate the +// bump-ptr allocator's buffer so that the next sequence of allocations +// will hopefully not need any fallback blocks. +class Allocator final { + public: + ~Allocator(); + + // Allocate a buffer. + void* AllocateBytes(std::ptrdiff_t num_bytes); + // Allocate a buffer, trying to avoid having its address close to aliasing + // the specified `to_avoid` in the L1D cache. + void* AllocateBytesAvoidingAliasingWith(std::ptrdiff_t num_bytes, + const void* to_avoid); + // Allocate an array of `count` elements of type T. 
+ template <typename T> + T* Allocate(std::ptrdiff_t count) { + return static_cast<T*>(AllocateBytes(count * sizeof(T))); + } + // Allocate an array of `count` elements of the given `Pointer` type's + // element_type. + template <typename Pointer> + void Allocate(std::ptrdiff_t count, Pointer* out) { + using T = typename std::pointer_traits<Pointer>::element_type; + *out = Allocate<T>(count); + } + + // Free all allocated blocks. Internally consolidate allocated buffers as + // explained in the class comment. + void FreeAll(); + + private: + void operator=(const Allocator&) = delete; + void* AllocateFast(std::ptrdiff_t num_bytes); + void* AllocateSlow(std::ptrdiff_t num_bytes); + + void* ptr_ = nullptr; + std::ptrdiff_t current_ = 0; + std::ptrdiff_t size_ = 0; + std::vector<void*> fallback_blocks_; + std::ptrdiff_t fallback_blocks_total_size_ = 0; +}; + +} // namespace ruy + +#endif // RUY_RUY_ALLOCATOR_H_ diff --git a/ruy/allocator_test.cc b/ruy/allocator_test.cc new file mode 100644 index 0000000..ee31669 --- /dev/null +++ b/ruy/allocator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/allocator.h" + +#include "ruy/gtest_wrapper.h" + +namespace ruy { +namespace { + +TEST(AllocatorTest, ReturnsValidMemory) { + Allocator allocator; + int *p; + allocator.Allocate(1, &p); + ASSERT_NE(p, nullptr); + + // If this is bogus memory, ASan will cause this test to fail. + *p = 42; + + allocator.FreeAll(); +} + +TEST(AllocatorTest, NoLeak) { + Allocator allocator; + // Allocate and free some ridiculously large total amount of memory, so + // that a leak will hopefully cause some sort of resource exhaustion. + // + // Despite the large number of allocations, this test is actually quite + // fast, since our fast-path allocation logic is very fast. + constexpr int kNumAllocations = 100 * 1024; + constexpr int kAllocationSize = 1024 * 1024; + for (int i = 0; i < kNumAllocations; i++) { + char *p; + allocator.Allocate(kAllocationSize, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, IncreasingSizes) { + Allocator allocator; + // Allocate sizes that increase by small amounts across FreeAll calls. + for (int i = 1; i < 100 * 1024; i++) { + char *p; + allocator.Allocate(i, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, ManySmallAllocations) { + Allocator allocator; + // Allocate many small allocations between FreeAll calls. + for (int i = 0; i < 10 * 1024; i += 100) { + for (int j = 0; j < i; j++) { + char *p; + allocator.Allocate(1, &p); + } + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { + // This is a white-box test. + Allocator allocator; + allocator.AllocateBytes(1); + allocator.FreeAll(); + // After the call to FreeAll, the allocator will consolidate all of the memory + // into the main bump-ptr allocator's block, which we then expect to be freed + // in the destructor. 
+ // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. +} + +TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { + // This is a white-box test. + Allocator allocator; + // Since we just created the allocator, this will allocate a fallback block, + // which we then expect to be freed in the destructor. + // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. + allocator.AllocateBytes(1); +} + +TEST(AllocatorTest, AvoidAliasing) { + Allocator allocator; + // Run twice with a FreeAll in between, just in case some future + // change of internal logic makes that bug-prone. + for (int repeat = 0; repeat < 2; repeat++) { + for (int i = 1; i < 100; i++) { + const void *to_avoid = + reinterpret_cast<const void *>(0x1234567890123ull + 123 * i); + void *ptr = allocator.AllocateBytesAvoidingAliasingWith(i * 10, to_avoid); + auto unsigned_low_bits = [](const void *p) { + return static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(p)); + }; + static constexpr int kMinPeriod = 1024; + std::uint32_t unsigned_diff = + (unsigned_low_bits(ptr) - unsigned_low_bits(to_avoid)) % kMinPeriod; + std::uint32_t unsigned_diff_mod = unsigned_diff % kMinPeriod; + ASSERT_TRUE(unsigned_diff_mod >= (kMinPeriod / 4) && + unsigned_diff_mod <= 3 * (kMinPeriod / 4)); + } + allocator.FreeAll(); + } +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc new file mode 100644 index 0000000..19bfd88 --- /dev/null +++ b/ruy/apply_multiplier.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/apply_multiplier.h" + +#include <cmath> +#include <cstdint> +#include <cstdlib> +#include <limits> + +namespace ruy { +namespace detail { + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Simplified multiplier application function +// +// Double rounding and symmetric rounding are removed compared to reference. +// Double rounding seems unnecessary and can complicate implementations. +// Symmetric rounding also adds implementation complexity. +// +// Composed of a single rounding shift right and can lead to more HW +// friendly implementations. +// +// On NEON this can be translated to a SQDMULH + rounding shift right sequence. 
+// The use of SQDMULH rather than SQRDMULH gives a result that is +// equivalent to a single rounded shift since the truncating shift of SQDMULH +// can be combined with the rounding right shift via the formula (for k>=1): +// ((x>>31)+(1<<(k-1)))>>k = (x + (1<<(30+k)))>>(31+k) +// +// Preconditions: +// - quantized_multiplier >= 0 +// - shift is -31 to +7 (negative for right shift) +std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x, + std::int32_t quantized_multiplier, + int shift) { + RUY_CHECK_GE(shift, -31); + RUY_CHECK_LE(shift, 7); + + int total_shift = 31 - shift; + + std::int64_t x_64(x); + std::int64_t quantized_multiplier_64(quantized_multiplier); + std::int64_t round = (int64_t)1 << (total_shift - 1); + int64_t result = x_64 * quantized_multiplier_64 + round; + result = result >> total_shift; + + RUY_DCHECK_GE(result, std::numeric_limits<std::int32_t>::lowest()); + RUY_DCHECK_LE(result, std::numeric_limits<std::int32_t>::max()); + + return static_cast<std::int32_t>(result); +} + +} // namespace detail + +} // namespace ruy diff --git a/ruy/apply_multiplier.h b/ruy/apply_multiplier.h new file mode 100644 index 0000000..120b990 --- /dev/null +++ b/ruy/apply_multiplier.h @@ -0,0 +1,92 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a reference (portable, non-optimized) ApplyMultiplier function. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +#ifndef RUY_RUY_APPLY_MULTIPLIER_H_ +#define RUY_RUY_APPLY_MULTIPLIER_H_ + +#include <cstdint> +#include <type_traits> + +#include "ruy/check_macros.h" +#include "ruy/mul_params.h" + +namespace ruy { + +// Applies the quantized multiplier to the `*accum` accumulator value, if +// applicable, that is, if AccumScalar==int32 and DstScalar!=int32. Otherwise, +// does nothing. +// +// This is slow, portable, 'reference' code. It should only be used in +// ReferenceMul and in Path::kStandardCpp. There isn't a point in optimizing it, +// either. Fast paths have that multiplier work done as part of the kernel, +// typically written in assembly anyway. +template <typename AccumScalar, typename DstScalar> +void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, + int channel, AccumScalar* accum); + +namespace detail { + +// Copied from TF Lite code. +std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x, + std::int32_t quantized_multiplier, + int shift); + +// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar +// is int32 (i.e. in all cases except floating-point) and if the destination is +// not int32 (i.e. unless the user wants to get raw accumulators).
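[Illustrative aside, not part of this change: a standalone restatement of the arithmetic in MultiplyByQuantizedMultiplier above, handy for checking the single rounding shift by hand. The helper name is made up; the expected outputs match test cases in apply_multiplier_test.cc further down. The ApplyMultiplierImpl helper introduced by the comment above follows after this aside.]

#include <cstdint>
#include <cstdio>
#include <limits>

// Same arithmetic as MultiplyByQuantizedMultiplier above, minus the RUY_CHECKs.
std::int32_t SketchMultiplyByQuantizedMultiplier(std::int32_t x,
                                                 std::int32_t quantized_multiplier,
                                                 int shift) {
  const int total_shift = 31 - shift;  // shift in [-31, 7] gives total_shift in [24, 62].
  const std::int64_t round = std::int64_t{1} << (total_shift - 1);
  return static_cast<std::int32_t>(
      (std::int64_t{x} * quantized_multiplier + round) >> total_shift);
}

int main() {
  const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max();
  // 1 << 30 encodes the multiplier 0.5, so 1000 maps to 500.
  std::printf("%d\n", SketchMultiplyByQuantizedMultiplier(1000, 1 << 30, 0));    // prints 500
  // max_int32 encodes a multiplier just under 1.0; shift -1 halves the result.
  std::printf("%d\n", SketchMultiplyByQuantizedMultiplier(1000, max_int32, -1)); // prints 500
  return 0;
}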
+template <typename AccumScalar, typename DstScalar, + bool IsApplicable = std::is_same<AccumScalar, std::int32_t>::value && + !std::is_same<DstScalar, std::int32_t>::value> +struct ApplyMultiplierImpl {}; + +// Specialization in non-applicable case: do nothing. +template <typename AccumScalar, typename DstScalar> +struct ApplyMultiplierImpl<AccumScalar, DstScalar, false> { + static void Run(const MulParams<AccumScalar, DstScalar>&, int, AccumScalar*) { + } +}; + +template <typename AccumScalar, typename DstScalar> +struct ApplyMultiplierImpl<AccumScalar, DstScalar, true> { + static void Run(const MulParams<AccumScalar, DstScalar>& mul_params, + int channel, AccumScalar* accum) { + AccumScalar m = mul_params.multiplier_fixedpoint_perchannel() + ? mul_params.multiplier_fixedpoint_perchannel()[channel] + : mul_params.multiplier_fixedpoint(); + int e = mul_params.multiplier_exponent_perchannel() + ? mul_params.multiplier_exponent_perchannel()[channel] + : mul_params.multiplier_exponent(); + *accum = MultiplyByQuantizedMultiplier(*accum, m, e); + } +}; + +} // namespace detail + +template <typename AccumScalar, typename DstScalar> +void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, + int channel, AccumScalar* accum) { + detail::ApplyMultiplierImpl<AccumScalar, DstScalar>::Run(mul_params, channel, + accum); +} + +} // namespace ruy + +#endif // RUY_RUY_APPLY_MULTIPLIER_H_ diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc new file mode 100644 index 0000000..2df80d7 --- /dev/null +++ b/ruy/apply_multiplier_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/apply_multiplier.h" + +#include <cstdint> +#include <limits> + +#include "ruy/gtest_wrapper.h" +#include "ruy/mul_params.h" + +namespace ruy { +namespace { + +void TestMultiplyByQuantizedMultiplier(std::int32_t input, + std::int32_t multiplier_fixedpoint, + int multiplier_exponent, + std::int32_t expected_output) { + EXPECT_EQ(expected_output, + detail::MultiplyByQuantizedMultiplier(input, multiplier_fixedpoint, + multiplier_exponent)); +} + +// These testcases exercise various multiplier_fixedpoint values, leaving +// multiplier_exponent = 0. They exercise the logic in +// SaturatingRoundingDoublingHighMul. +TEST(ApplyMultiplierTest, SaturatingRoundingDoublingHighMul) { + const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max(); + TestMultiplyByQuantizedMultiplier(1000, max_int32, 0, 1000); + TestMultiplyByQuantizedMultiplier(1000, 1 << 30, 0, 500); + TestMultiplyByQuantizedMultiplier(1000, 1 << 29, 0, 250); + TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 29), 0, 750); + TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 28), 0, 625); + // This 563 expected value comes from the breaking of the tie on 562.5. 
+ // As a positive value, it does not distinguish between 'upward' and + // 'away from zero' tie-breaking behavior. + TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 27), 0, 563); + TestMultiplyByQuantizedMultiplier(1000, (1 << 30) + (1 << 26), 0, 531); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, 0, -1000); + TestMultiplyByQuantizedMultiplier(-1000, 1 << 30, 0, -500); + TestMultiplyByQuantizedMultiplier(-1000, 1 << 29, 0, -250); + TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 29), 0, -750); + TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 28), 0, -625); + // This -562 expected value is because the SaturatingRoundingDoublingHighMul + // breaks ties upwards, not away-from-zero. The value before rounding is + // -562.5. No other test case here tests a tie on a negative value. + TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 27), 0, -562); + TestMultiplyByQuantizedMultiplier(-1000, (1 << 30) + (1 << 26), 0, -531); +} + +// These testcases exercise various negative multiplier_exponent values while +// keeping multiplier_fixedpoint trivial. +TEST(ApplyMultiplierTest, RoundingRightShift) { + const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max(); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -1, 500); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -2, 250); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -3, 125); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -4, 62); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -5, 31); + TestMultiplyByQuantizedMultiplier(1000, max_int32, -6, 16); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -1, -500); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -2, -250); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -3, -125); + // This -62 value comes from rounding -62.5, which is a tie. + // Is the only test case here that exercises a tie-break on a negative value, + // distinguishing between 'upward' and 'away from zero'. + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -4, -62); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -5, -31); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, -6, -16); +} + +// These testcases exercise various positive multiplier_exponent values while +// keeping multiplier_fixedpoint trivial. +TEST(ApplyMultiplierTest, LeftShift) { + const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max(); + TestMultiplyByQuantizedMultiplier(1000, max_int32, 1, 2000); + TestMultiplyByQuantizedMultiplier(1000, max_int32, 2, 4000); + TestMultiplyByQuantizedMultiplier(1000, max_int32, 3, 8000); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, 1, -2000); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, 2, -4000); + TestMultiplyByQuantizedMultiplier(-1000, max_int32, 3, -8000); +} + +template <typename AccumScalar, typename DstScalar> +void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, + int channel, AccumScalar input, + AccumScalar expected_output) { + AccumScalar actual_output = input; + ApplyMultiplier(mul_params, channel, &actual_output); + EXPECT_EQ(expected_output, actual_output); +} + +TEST(ApplyMultiplierTest, ApplyMultiplierUniform) { + MulParams<std::int32_t, std::int8_t> mul_params; + // Test that default values give a multiplication by 1. 
+ TestApplyMultiplier(mul_params, 0, 1000, 1000); + mul_params.set_multiplier_fixedpoint(1 << 30); + mul_params.set_multiplier_exponent(-1); + TestApplyMultiplier(mul_params, 0, 1000, 250); + mul_params.set_multiplier_fixedpoint(1 << 25); + mul_params.set_multiplier_exponent(3); + TestApplyMultiplier(mul_params, 0, 1000, 125); +} + +TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) { + const std::int32_t max_int32 = std::numeric_limits<std::int32_t>::max(); + const std::int32_t multiplier_fixedpoint[4] = {max_int32, 1 << 30, max_int32, + 1 << 30}; + const int multiplier_exponent[4] = {0, 0, -1, -1}; + MulParams<std::int32_t, std::int8_t> mul_params; + mul_params.set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint); + mul_params.set_multiplier_exponent_perchannel(multiplier_exponent); + TestApplyMultiplier(mul_params, 0, 1000, 1000); + TestApplyMultiplier(mul_params, 1, 1000, 500); + TestApplyMultiplier(mul_params, 2, 1000, 500); + TestApplyMultiplier(mul_params, 3, 1000, 250); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/asm_helpers.h b/ruy/asm_helpers.h new file mode 100644 index 0000000..f0bc303 --- /dev/null +++ b/ruy/asm_helpers.h @@ -0,0 +1,43 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Some helpers to write inline asm. + +#ifndef RUY_RUY_ASM_HELPERS_H_ +#define RUY_RUY_ASM_HELPERS_H_ + +#include "ruy/opt_set.h" + +// Enclose load-prefetch instructions in RUY_PREFETCH_LOAD() so we can +// conditionally enable them based on the RUY_OPT_SET. +#if RUY_OPT(PREFETCH_LOAD) +#define RUY_PREFETCH_LOAD(X) X +#else +#define RUY_PREFETCH_LOAD(X) +#endif + +// Enclose store-prefetch instructions in RUY_PREFETCH_STORE() so we can +// conditionally enable them based on the RUY_OPT_SET. +#if RUY_OPT(PREFETCH_STORE) +#define RUY_PREFETCH_STORE(X) X +#else +#define RUY_PREFETCH_STORE(X) +#endif + +// The usual stringification macro. +#define RUY_STR(s) RUY_STR_UNEXPANDED(s) +#define RUY_STR_UNEXPANDED(s) #s + +#endif // RUY_RUY_ASM_HELPERS_H_ diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc new file mode 100644 index 0000000..3c63249 --- /dev/null +++ b/ruy/benchmark.cc @@ -0,0 +1,223 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <cstdio> +#include <cstdlib> +#include <string> + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; +using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>; + +struct BenchmarkShape { + int rows; + int depth; + int cols; + int symm_lhs; + int symm_rhs; +}; + +template <typename TestSetType> +std::vector<std::unique_ptr<TestResult<DstScalar>>> Benchmark( + const BenchmarkShape& shape) { + TestSetType test_set; + test_set.rows = shape.rows; + test_set.depth = shape.depth; + test_set.cols = shape.cols; + const char* orders = "RCC"; + const char* orders_env = getenv("ORDERS"); + if (orders_env) { + bool error = false; + if (strlen(orders_env) != 3) { + error = true; + } else { + for (int i = 0; i < 3; i++) { + if (orders_env[i] != 'R' && orders_env[i] != 'C') { + error = true; + } + } + } + if (error) { + fprintf(stderr, + "ORDERS must contain 3 letters, each either R or C, indicating " + "whether to use Row-major or Column-major storage order for the " + "LHS, RHS and Destination matrix.\n"); + exit(EXIT_FAILURE); + } + orders = orders_env; + } + test_set.lhs_order = orders[0] == 'R' ? Order::kRowMajor : Order::kColMajor; + test_set.rhs_order = orders[1] == 'R' ? Order::kRowMajor : Order::kColMajor; + test_set.dst_order = orders[2] == 'R' ? Order::kRowMajor : Order::kColMajor; + test_set.layout_style = LayoutStyle::kUnstridedLinear; + test_set.benchmark = true; + const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; + const int asymmetry_rhs = shape.symm_rhs ? 0 : 1; + test_set.lhs_zero_point = SymmetricZeroPoint<LhsScalar>() + asymmetry_lhs; + test_set.rhs_zero_point = SymmetricZeroPoint<RhsScalar>() + asymmetry_rhs; + test_set.use_specified_zero_points = true; + test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); + if (getenv("PREPACK_LHS") || getenv("PREPACK_RHS")) { + fprintf(stderr, + "PREPACK_LHS and PREPACK_RHS are deprecated. 
Use CACHE_LHS and " + "CACHE_RHS instead.\n"); + exit(EXIT_FAILURE); + } + test_set.cache_lhs = GetBoolEnvVarOrFalse("CACHE_LHS"); + test_set.cache_rhs = GetBoolEnvVarOrFalse("CACHE_RHS"); + test_set.Run(); + return std::move(test_set.results); +} + +std::vector<int> ParseCommaSeparatedInts( + const std::string& comma_separated_ints) { + std::vector<int> result; + for (std::size_t pos = 0; pos < comma_separated_ints.size();) { + std::size_t delim_pos = comma_separated_ints.find(',', pos); + if (delim_pos == std::string::npos) { + delim_pos = comma_separated_ints.size(); + } + result.push_back( + std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); + pos = delim_pos + 1; + } + return result; +} + +void Benchmark() { + const bool symm_lhs = std::is_floating_point<LhsScalar>::value || + GetBoolEnvVarOrFalse("SYMM_LHS"); + const bool symm_rhs = std::is_floating_point<RhsScalar>::value || + GetBoolEnvVarOrFalse("SYMM_RHS"); + const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || + GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); + const int explicit_rows = GetIntEnvVarOrZero("ROWS"); + const int explicit_cols = GetIntEnvVarOrZero("COLS"); + const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); + + std::vector<BenchmarkShape> shapes; + + if (benchmark_cubic) { + std::vector<int> sizes; + const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); + if (benchmark_cubic_list_env) { + sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); + } else { + // Often 8 is used for this multiplier, but to check teeny sizes one can + // use 1. + static constexpr int cubic_size_multiplier = 8; + for (int i = 2 * cubic_size_multiplier; + i <= (512 * cubic_size_multiplier); i *= 2) { + sizes.push_back(i); + if (i < (512 * cubic_size_multiplier)) { + sizes.push_back(i * 3 / 2); + } + } + } + for (int i : sizes) { + BenchmarkShape shape; + // Even in cubic mode, one may still override an individual dimension + // to allow testing a batch of rectangular sizes. + shape.rows = explicit_rows ? explicit_rows : i; + shape.cols = explicit_cols ? explicit_cols : i; + shape.depth = explicit_depth ? 
explicit_depth : i; + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + } else { + BenchmarkShape shape; + shape.rows = explicit_rows; + shape.cols = explicit_cols; + shape.depth = explicit_depth; + if (!shape.rows || !shape.depth || !shape.cols) { + fprintf(stderr, + "Please specify positive sizes with these env vars: ROWS, DEPTH, " + "COLS.\n"); + exit(1); + } + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + + for (int i = 0; i < static_cast<int>(shapes.size()); i++) { + const auto& shape = shapes[i]; + const auto& results = Benchmark<TestSetType>(shape); + if (i == 0) { + if (benchmark_cubic) { + printf("size"); + for (const auto& result : results) { + if (results.size() > 1) { + printf(",%s:Gop/s", PathName(*result).c_str()); + } else { + printf(",Gop/s"); + } + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf( + ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," + "mispred,frontend_stall,backend_stall"); + } + } + printf("\n"); + } else { + printf("path,shape,Gop/s\n"); + } + fflush(stdout); + } + if (benchmark_cubic) { + printf("%d", shape.rows); + for (const auto& result : results) { + printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / + result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + } + printf("\n"); + fflush(stdout); + } else { + for (const auto& result : results) { + printf( + "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, + shape.depth, shape.cols, + 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + printf("\n"); + } + fflush(stdout); + } + } +} + +} // namespace ruy + +int main() { ruy::Benchmark(); } diff --git a/ruy/block_map.cc b/ruy/block_map.cc new file mode 100644 index 0000000..8240de2 --- /dev/null +++ b/ruy/block_map.cc @@ -0,0 +1,497 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/block_map.h" + +#include <algorithm> +#include <cstdint> + +#ifdef RUY_MAKEBLOCKMAP_DEBUG +#include <cstdio> +#include <cstdlib> +#include <string> +#endif + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/size_util.h" +#include "ruy/trace.h" + +namespace ruy { + +namespace { + +void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, + SidePair<int>* local_pos) { + (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); + (*local_pos)[Side::kRhs] = square_index >> size_log2; +} + +void DecodeTraversalFractalZ(std::uint32_t square_index, + SidePair<int>* local_pos) { + const std::uint32_t n1 = square_index; + const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | + ((n1 & 0x22222222u) << 1); + const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | + ((n2 & 0x0c0c0c0cu) << 2); + const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | + ((n4 & 0x00f000f0u) << 4); + const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | + ((n8 & 0x0000ff00u) << 8); + (*local_pos)[Side::kLhs] = n16 & 0xffff; + (*local_pos)[Side::kRhs] = n16 >> 16; +} + +void DecodeTraversalFractalU(std::uint32_t square_index, + SidePair<int>* local_pos) { + DecodeTraversalFractalZ(square_index, local_pos); + // Change fractal z-order to u-order + (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; +} + +// Code inspired by the sample code in +// https://en.wikipedia.org/wiki/Hilbert_curve +// The main optimization is to avoid hard-to-predict conditional branches +// based on the bits of the square_index parameter. +void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, + SidePair<int>* local_pos) { + std::uint32_t t = square_index; + std::uint32_t x = 0; + std::uint32_t y = 0; + // Easy-to-predict for loop, the number of iterations is the same for + // an entire GEMM. + for (int sb = 0; sb < size_log2; sb++) { + std::uint32_t s = 1 << sb; + bool rx = t & 2; + bool ry = (t & 1) ^ rx; + std::uint32_t tmp = rx ? (s - 1 - x) : x; + x = ry ? x : rx ? (s - 1 - y) : y; + y = ry ? (y + s) : tmp; + x = rx ? 
(x + s) : x; + t >>= 2; + } + (*local_pos)[Side::kLhs] = y; + (*local_pos)[Side::kRhs] = x; +} + +} // end anonymous namespace + +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair<int>* block) { + profiler::ScopeLabel label("GetBlockByIndex"); + const std::uint32_t index_u32 = index; + + const std::uint32_t num_blocks_per_local_curve = + 1u << (2 * block_map.num_blocks_base_log2); + const std::uint32_t square_index = + index_u32 & (num_blocks_per_local_curve - 1); + + const int size_log2 = block_map.num_blocks_base_log2; + SidePair<int> local_pos; + switch (block_map.traversal_order) { + case BlockMapTraversalOrder::kFractalZ: + DecodeTraversalFractalZ(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalU: + DecodeTraversalFractalU(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalHilbert: + DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); + break; + default: + RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); + DecodeTraversalLinear(size_log2, square_index, &local_pos); + break; + } + + const std::uint32_t rectangular_index = + index_u32 >> 2 * block_map.num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; + const int rectangular_offset = (rectangular_index & mask) + << block_map.num_blocks_base_log2; + (*block)[side] = local_pos[side] + rectangular_offset; + } +} + +namespace { + +BlockMapTraversalOrder GetTraversalOrder( + int rows_after_rectangularness_division, + int cols_after_rectangularness_division, int depth, int lhs_scalar_size, + int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) { + static constexpr bool kAnyFractal = + RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT); + const int working_set_size = + (lhs_scalar_size * rows_after_rectangularness_division + + rhs_scalar_size * cols_after_rectangularness_division) * + depth; + if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) { + if (RUY_OPT(FRACTAL_HILBERT) && + (working_set_size > cpu_cache_params.last_level_cache_size)) { + return BlockMapTraversalOrder::kFractalHilbert; + } else if (RUY_OPT(FRACTAL_U)) { + return BlockMapTraversalOrder::kFractalU; + } else { + return BlockMapTraversalOrder::kFractalZ; + } + } else { + return BlockMapTraversalOrder::kLinear; + } +} + +int floor_log2_quotient(int num, int denom) { + if (num <= denom) { + return 0; + } + int log2_quotient = floor_log2(num) - ceil_log2(denom); + if ((denom << (log2_quotient + 1)) <= num) { + log2_quotient++; + } + return log2_quotient; +} + +// Computes the rectangularness of the matrix shape (rows, cols). This is +// essentially just the log2 of the quotient (rows / cols). The kernel_rows and +// kernel_cols only get into the picture for clamping bounds but don't affect +// the generic computation. +void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, + int* rows_rectangularness_log2, + int* cols_rectangularness_log2) { + *rows_rectangularness_log2 = 0; + *cols_rectangularness_log2 = 0; + + // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel + // itself, we risk having too small kernel blocks for good kernel + // amortization. We avoid that by limiting recangularness so that kernel + // blocks are not too tiny at least in that dimension. 
Specifically, we try to + // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each + // kernel block along the large dimension. + const int min_kernel_inner_loop_runs_log2 = 3; + if (rows > cols) { + int cols_of_kernel_inner_loop_runs_log2 = + ceil_log2(cols) - pot_log2(kernel_cols); + int min_rows_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + cols_of_kernel_inner_loop_runs_log2); + *rows_rectangularness_log2 = + std::min(floor_log2_quotient(rows, cols), + std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - + min_rows_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate rows_rectangularness_log2. + RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); + } else if (cols > rows) { + int rows_of_kernel_inner_loop_runs_log2 = + ceil_log2(rows) - pot_log2(kernel_rows); + int min_cols_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + rows_of_kernel_inner_loop_runs_log2); + *cols_rectangularness_log2 = + std::min(floor_log2_quotient(cols, rows), + std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - + min_cols_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate cols_rectangularness_log2. + RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); + } + RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); +} + +// Computes a 'multithreading score'. When multithreading, we need there to +// be at least as many tiles as there are threads, and hopefully +// substantially more than that, so we benefit from ruy's ability to +// dispatch fine-grained workloads to threads. +int GetMultithreadingScore(int block_size_log2, int rows, int cols, + int tentative_thread_count) { + const int num_full_blocks_of_rows = rows >> block_size_log2; + const int num_full_blocks_of_cols = cols >> block_size_log2; + const int candidate_num_full_blocks_log2 = floor_log2( + std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); + + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (tentative_thread_count == 1) { + return 0; + } else { + const int blocks_per_thread_log2 = + candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); + if (blocks_per_thread_log2 < 0) { + return -64; + } else if (blocks_per_thread_log2 == 0) { + return -16; + } else if (blocks_per_thread_log2 == 1) { + return -8; + } else if (blocks_per_thread_log2 == 2) { + return 0; + } else if (blocks_per_thread_log2 == 3) { + return 8; + } else { + return 16; + } + } +} + +// Computes a 'cache locality score'. +int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, + int kernel_rows_log2, int kernel_cols_log2, + int lhs_scalar_size, int rhs_scalar_size, + const CpuCacheParams& cpu_cache_params) { + // In the narrow case (e.g. matrix*vector), each byte of the big operand + // matrix (either LHS or RHS) is traversed only once, so any notion of data + // locality is irrelevant. Ignore the 'cache locality score' by forcing it to + // be 0 in that case. 
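[Illustrative aside, not part of this change: a quick numeric walk-through of GetMultithreadingScore as defined just above, before the cache-locality computation resumes below. The log2 helpers stand in for ruy's floor_log2/ceil_log2 from size_util.h; the shape and thread count are arbitrary examples.]

#include <algorithm>
#include <cstdio>

int SketchFloorLog2(int x) { int n = -1; while (x > 0) { x >>= 1; ++n; } return n; }
int SketchCeilLog2(int x) { return x <= 1 ? 0 : SketchFloorLog2(x - 1) + 1; }

int main() {
  // A 1024x1024 destination with candidate 128x128 blocks (block_size_log2 = 7)
  // and 4 tentative threads.
  const int rows = 1024, cols = 1024, block_size_log2 = 7, thread_count = 4;
  const int num_full_blocks =
      (rows >> block_size_log2) * (cols >> block_size_log2);  // 8 * 8 = 64
  const int blocks_per_thread_log2 =
      SketchFloorLog2(std::max(1, num_full_blocks)) - SketchCeilLog2(thread_count);  // 6 - 2 = 4
  // Per the branches in GetMultithreadingScore, a value this large scores +16:
  // there are plenty of blocks to hand out to each thread.
  std::printf("blocks_per_thread_log2 = %d\n", blocks_per_thread_log2);
  return 0;
}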
+ if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { + return 0; + } + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int total_read_bytes = + (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; + const int total_read_bytes_log2 = ceil_log2(total_read_bytes); + const int nonlocality_log2 = + total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size); + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (nonlocality_log2 < -1) { + return 64; + } else if (nonlocality_log2 == -1) { + return 56; + } else if (nonlocality_log2 == 0) { + return 48; + } else if (nonlocality_log2 == 1) { + return 32; + } else if (nonlocality_log2 == 2) { + return 16; + } else if (nonlocality_log2 == 3) { + return 0; + } else { + return -64; + } +} + +// Compute a 'kernel amortization score'. This is the notion that very small +// tiles result in more overhead outside of kernels, more complex memory +// access patterns and less benefits from ruy's fat kernels, so we reward +// larger blocks more than smaller ones. +int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, + int kernel_rows_log2, int kernel_cols_log2) { + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int kernels_per_block_log2 = + floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; + RUY_DCHECK_GE(kernels_per_block_log2, 0); + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (kernels_per_block_log2 == 0) { + return 0; + } else if (kernels_per_block_log2 == 1) { + return 8; + } else if (kernels_per_block_log2 == 2) { + return 16; + } else if (kernels_per_block_log2 == 3) { + return 24; + } else if (kernels_per_block_log2 == 4) { + return 32; + } else if (kernels_per_block_log2 == 5) { + return 40; + } else if (kernels_per_block_log2 == 6) { + return 48; + } else if (kernels_per_block_log2 == 7) { + return 56; + } else { + return 64; + } +} + +} // namespace + +bool IsObviouslyLinearTraversal(int rows, int cols, int depth, + int lhs_scalar_size, int rhs_scalar_size, + const CpuCacheParams& cpu_cache_params) { + if (rows == 1 || cols == 1) { + return true; + } + // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided + // by the rectangularness factors, since any non-linear traversal order will + // be local to each subdivision. In the present function, we don't know the + // rectangularness factors yet, and we can't just call GetRectangularness + // as that requires knowing the kernel block layout. 
Since we just want + // a coarse estimate with only the guarantee that if we return `true` then + // linear traversal will be used, it is OK here to over-estimate `rows` and + // `cols`, by omitting to divide them by the rectangularness factors.ß + return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, + cpu_cache_params) == BlockMapTraversalOrder::kLinear; +} + +void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, + const CpuCacheParams& cpu_cache_params, BlockMap* block_map) { + RUY_TRACE_SCOPE; + profiler::ScopeLabel label("MakeBlockMap"); + + RUY_DCHECK_GE(rows, kernel_rows); + RUY_DCHECK_GE(cols, kernel_cols); + RUY_DCHECK_EQ(rows % kernel_rows, 0); + RUY_DCHECK_EQ(cols % kernel_cols, 0); + + // Estimate the 'rectangularness', the first level of subdivision bringing + // the shape to within 2x of a square shape. + int rows_rectangularness_log2 = 0; + int cols_rectangularness_log2 = 0; + GetRectangularness(rows, cols, kernel_rows, kernel_cols, + &rows_rectangularness_log2, &cols_rectangularness_log2); + + const int kernel_rows_log2 = pot_log2(kernel_rows); + const int kernel_cols_log2 = pot_log2(kernel_cols); + const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); + + const int size = std::min(rows, cols); + const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); + + RUY_DCHECK_GE(size_log2, kernel_size_log2); + + // Heuristic selecting the power-of-two grid subdivision insider of each + // square-ish region (past the above subdivision by 'rectangularness'). + // Note that it is the number of subdivisions, not the resulting block size, + // that will be a power of two. But inside of that heuristic, it simplifies + // code to talk in terms of 'block_size_log2', as if it were the block size + // that were a power of two. This 'block_size_log2' is to be interpreted as + // "log2 rounded below", e.g. when block_size_log2=8 we might have a block + // size in [256, 511]. When the shape is non-square, rows!=cols, this + // refers to the smaller of the two, so the other might be as large as + // 1021 (can't be 1022 because following the above 'rectangularness' + // subdivision, the aspect ratio is already < 2). + + // We are going to try candidate values for block_size_log2 ranging from + // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). + // For each of them we will compute a 'score' by adding individual scores + // for a few different considerations, all of which is entirely empirical. + // The values (and possibly the logic) around here are all subject to tuning + // based on benchmarks on different hardware. The current values are based + // on benchmarking on Qualcomm S855 (big and little cores), arm64, + // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead + // and tune this as needed to achieve good performance elsewhere. Use + // the unit test, block_map_test, to encode values that should be preserved + // on specific architectures. Use RUY_TRACE to debug the current heuristics + // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a + // different block_size_log2 choice, to empirically find the optimal value + // before getting to updating the heuristic so that it produces that value. 
+ static constexpr int kMaxKernelsPerBlockLog2 = 6; + const int max_block_size_log2 = + std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); + int best_score = std::numeric_limits<int>::min(); + int best_score_block_size_log2 = -1; + RUY_TRACE_INFO(MAKE_BLOCK_MAP_START); + for (int block_size_log2 = kernel_size_log2; + block_size_log2 <= max_block_size_log2; block_size_log2++) { + const int multithreading_score = GetMultithreadingScore( + block_size_log2, rows, cols, tentative_thread_count); + const int cache_locality_score = GetCacheLocalityScore( + block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, + lhs_scalar_size, rhs_scalar_size, cpu_cache_params); + const int kernel_amortization_score = GetKernelAmortizationScore( + block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); + const int score = + multithreading_score + cache_locality_score + kernel_amortization_score; + if (score >= best_score) { + best_score = score; + best_score_block_size_log2 = block_size_log2; + } + RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE); + } + +#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 + // Useful for tuning. + best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2; +#endif + + // As explained in the above comment, phrasing the above code in terms of + // block_size_log2 was only convenience inside of that heuristic. Now we + // revert to talking in terms of grid subdivision. That is what will actually + // be powers of two. + int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; + RUY_DCHECK_GE(num_blocks_base_log2, 0); + const int num_blocks_of_rows_log2 = + num_blocks_base_log2 + rows_rectangularness_log2; + const int num_blocks_of_cols_log2 = + num_blocks_base_log2 + cols_rectangularness_log2; + + // Now that we know the grid subdivision, we can pinpoint the exact block + // sizes. They can't be powers of two in general; they can't even be all + // equal in general; so the following few parameters will govern how blocks + // of slightly different shapes are put together in the block map. + const int small_block_rows = + round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); + const int small_block_cols = + round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); + const int rows_of_large_blocks = + round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2), + kernel_rows) >> + pot_log2(kernel_rows); + const int cols_of_large_blocks = + round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2), + kernel_cols) >> + pot_log2(kernel_cols); + + // We have everything! Write out to the destination block_map. + block_map->dims[Side::kLhs] = rows; + block_map->dims[Side::kRhs] = cols; + block_map->kernel_dims[Side::kLhs] = kernel_rows; + block_map->kernel_dims[Side::kRhs] = kernel_cols; + block_map->num_blocks_base_log2 = num_blocks_base_log2; + block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2; + block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; + block_map->small_block_dims[Side::kLhs] = small_block_rows; + block_map->small_block_dims[Side::kRhs] = small_block_cols; + block_map->large_blocks[Side::kLhs] = rows_of_large_blocks; + block_map->large_blocks[Side::kRhs] = cols_of_large_blocks; + // See the comment on GetTraversalOrder for why we are dividing `rows` and + // `cols` by the rectangularness subdivision parameters here. 
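+  // For instance (illustrative): with rows=256, cols=16 and 8x8 kernels,
+  // rows_rectangularness_log2 is 3 (see MakeBlockMapTuningTestRectangular in
+  // block_map_test.cc), so a non-linear traversal order would be local to each
+  // of the 8 stacked 32x16 sub-grids; GetTraversalOrder is therefore given the
+  // 32x16 sub-shape rather than the full 256x16 shape.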
+ block_map->traversal_order = GetTraversalOrder( + rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2, + depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params); + // Done last: NumBlocks needs some of the block_map fields to be already set. + block_map->thread_count = + std::min(tentative_thread_count, NumBlocks(*block_map)); + RUY_TRACE_INFO(MAKE_BLOCK_MAP_END); +} + +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end) { + profiler::ScopeLabel label("GetBlockMatrixCoords"); + *start = block * block_map.small_block_dims[side] + + std::min(block, block_map.large_blocks[side]) * + block_map.kernel_dims[side]; + *end = + *start + block_map.small_block_dims[side] + + (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); + + RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); + RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); + RUY_DCHECK_LE(*end, block_map.dims[side]); + RUY_DCHECK_LT(*start, *end); + RUY_DCHECK_GE(*start, 0); +} + +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block, + SidePair<int>* start, SidePair<int>* end) { + for (Side side : {Side::kLhs, Side::kRhs}) { + GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], + &(*end)[side]); + } +} + +} // namespace ruy diff --git a/ruy/block_map.h b/ruy/block_map.h new file mode 100644 index 0000000..0057509 --- /dev/null +++ b/ruy/block_map.h @@ -0,0 +1,162 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_BLOCK_MAP_H_ +#define RUY_RUY_BLOCK_MAP_H_ + +#include "ruy/cpu_cache_params.h" +#include "ruy/side_pair.h" + +namespace ruy { + +enum class BlockMapTraversalOrder { + // Plain old row-by-row or column-by-column traversal. + kLinear, + // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve + kFractalZ, + // Variant of Z-order doing a U instead of a Z. + kFractalU, + // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve + kFractalHilbert +}; + +// A BlockMap describes a tiling of a matrix, typically the destination matrix +// of a matrix multiplication computation. As is standard in matrix +// multiplication, a tile is called a "block". +// +// Ruy subdivides work by blocks of the destination matrix: each thread fully +// computes a block at once, then moves on to another block; each block is +// produced by a single thread. +// +// This ensures that the workloads for each block are mutually independent, +// which reduces synchronization requirements. +// +// Typically, a matrix multiplication will early on create a BlockMap by +// calling MakeBlockMap. It will then query the number of blocks in that +// BlockMap by calling NumBlocks. 
It will then create a single atomic integer +// counter indexing these blocks, called the 'index', and will distribute +// work to its N threads by ensuring that each thread works on disjoint sets +// of index values. For a given index value, the thread will call +// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords +// to find the actual row and column numbers of this block. +// +// There are two nested levels of subdivision. On a local level, the matrix is +// tiled into a square NxN grid where N is a power of two, specifically: +// N = 2^num_blocks_base_log2. +// +// At a larger scale, around these blocks, there may be one further +// level of subdivision, in only one dimension: either along rows or along +// columns. That is used to handle arbitrarily rectangular matrices. The +// aforementioned high-level block grid is square, so on its own it does not +// fit very rectangular matrices well. +// +// Taking these two nested levels of subdivision together, the effective +// tiling is by +// 2^(num_blocks_base_log2 + rows_rectangularness_log2) +// blocks in the row dimension, and by +// 2^(num_blocks_base_log2 + cols_rectangularness_log2) +// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. +// +// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. +// +// Finally, this BlockMap is designed to operate under alignment constraints: +// two fields, kernel_rows and kernel_cols, describe the requested alignment +// of the effective grid in both dimensions. The idea is to feed matrix +// multiplication kernels with tiles that fit their width as much as possible. +// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. +// kernel_cols) then some tile will have to have unaligned size. BlockMap +// will only allow that to happen in the last position along each axis, so +// as to minimize the overhead imposed on the matrix multiplication kernels. +struct BlockMap { + // The number of threads to use (to distribute the blocks to). + int thread_count; + // The order in which to traverse the matrix of which this BlockMap represents + // a tiling (hereafter "the matrix"). + BlockMapTraversalOrder traversal_order; + // The dimensions of the block_map, that is, of the destination + // matrix rounded up to the next multiples of kernel_dims. + SidePair<int> dims; + // Log2 of the minimum number of subdivisions of the grid along either axis. + int num_blocks_base_log2; + // Log2 of the additional subdivision of the rows/columns axis. + SidePair<int> rectangularness_log2; + // Requested alignment of the subdivisions of the grid along the rows/columns + // axis. + SidePair<int> kernel_dims; + // Internal helper. Minimum number of rows/columns in each block. + SidePair<int> small_block_dims; + // Internal helper. Number of blocks along each dimension that need to have + // their size in that dimension be given by (small_block_dims + kernel_dims) + // instead of just small_block_dims. + SidePair<int> large_blocks; +}; + +// This function produces a coarse estimate of whether linear traversal will +// be used for this matmul. It offers a one-way guarantee: if this function +// returns true then linear traversal will be used. +// +// The purpose of this function is to allow TrMul to make a cheap, early +// decision to enter a "simple loop" code path for simple cases.
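+//
+// For example (illustrative): a matrix*vector or vector*matrix product
+// (cols == 1 or rows == 1) is reported as linear immediately, without
+// consulting the cache parameters.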
+bool IsObviouslyLinearTraversal(int rows, int cols, int depth, + int lhs_scalar_size, int rhs_scalar_size, + const CpuCacheParams& cpu_cache_params); + +// Create a BlockMap suitable for tiling the destination matrix in a +// matrix multiplication with the given parameters. +void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, + const CpuCacheParams& cpu_cache_params, BlockMap* block_map); + +// Maps an integer index to a block position in the grid. +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair<int>* block); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in the dimension +// referred to by `side`: along rows if side==kLhs, along columns if +// side==kRhs. +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in terms of +// actual row/column indices. +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block, + SidePair<int>* start, SidePair<int>* end); + +// Returns the number of grid subdivisions along the rows dimension (if +// side == kLhs) or columns dimension (if side == kRhs). +inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { + return 1 << (block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[side]); +} + +// Returns the overall number of blocks in +// the BlockMap. The valid index values to pass to GetBlockByIndex are the +// integers from 0 to N-1 where N is the value returned here. +// +// Note that it is always true that +// NumBlocks == NumBlocksOfRows * NumBlocksOfCols +// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. +inline int NumBlocks(const BlockMap& block_map) { + return 1 << (2 * block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[Side::kLhs] + + block_map.rectangularness_log2[Side::kRhs]); +} + +} // namespace ruy + +#endif // RUY_RUY_BLOCK_MAP_H_ diff --git a/ruy/block_map_test.cc b/ruy/block_map_test.cc new file mode 100644 index 0000000..8245a5c --- /dev/null +++ b/ruy/block_map_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/block_map.h" + +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <limits> +#include <vector> + +#include "ruy/cpu_cache_params.h" +#include "ruy/gtest_wrapper.h" +#include "ruy/platform.h" +#include "ruy/side_pair.h" + +namespace ruy { +namespace { + +#if RUY_PLATFORM_NEON_64 + +// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. 
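+
+// Illustrative sketch (not one of the original ruy tests): the typical flow
+// described in block_map.h, spelled out as code. The shape, kernel and cache
+// parameter values below are arbitrary examples.
+TEST(BlockMapTest, UsageSketch) {
+  CpuCacheParams cpu_cache_params;
+  cpu_cache_params.local_cache_size = 128 * 1024;
+  cpu_cache_params.last_level_cache_size = 1024 * 1024;
+  BlockMap block_map;
+  MakeBlockMap(/*rows=*/256, /*cols=*/256, /*depth=*/256, /*kernel_rows=*/8,
+               /*kernel_cols=*/8, /*lhs_scalar_size=*/1, /*rhs_scalar_size=*/1,
+               /*tentative_thread_count=*/4, cpu_cache_params, &block_map);
+  for (int index = 0; index < NumBlocks(block_map); index++) {
+    SidePair<int> block, start, end;
+    GetBlockByIndex(block_map, index, &block);
+    GetBlockMatrixCoords(block_map, block, &start, &end);
+    for (Side side : {Side::kLhs, Side::kRhs}) {
+      // Block edges are kernel-aligned and stay within the destination matrix.
+      EXPECT_EQ(start[side] % 8, 0);
+      EXPECT_EQ(end[side] % 8, 0);
+      EXPECT_LT(start[side], end[side]);
+      EXPECT_LE(end[side], 256);
+    }
+    // A real caller would compute the destination block spanning rows
+    // [start[Side::kLhs], end[Side::kLhs]) and columns
+    // [start[Side::kRhs], end[Side::kRhs]) here.
+  }
+}
+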
+void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, + int rhs_scalar_size, int tentative_thread_count, + int expected_num_blocks_base_log2, + int expected_rectangularness_log2) { + // Plausible Cortex-A55 cache sizes. + CpuCacheParams cpu_cache_params; + cpu_cache_params.local_cache_size = 128 * 1024; + cpu_cache_params.last_level_cache_size = 1024 * 1024; + BlockMap block_map; + MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, + rhs_scalar_size, tentative_thread_count, cpu_cache_params, + &block_map); + EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); + EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + 0); + EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + expected_rectangularness_log2); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, + MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 4, + /* 
expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, + /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 3, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, + /* tentative_thread_count */ 4, + /* expected_num_blocks_base_log2 */ 7, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { + MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 3); + MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 6); +} + +#endif + +int L1Distance(const SidePair<int>& a, const SidePair<int>& b) { + return std::abs(a[Side::kLhs] - b[Side::kLhs]) + + std::abs(a[Side::kRhs] - b[Side::kRhs]); +} + +void GetBlockByIndexSquareTest(int num_blocks_base_log2, + BlockMapTraversalOrder traversal_order) { + // Arbitrary, does not affect this test. 3 is just a typical value. + constexpr int kKernelSizeLog2 = 3; + + const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; + BlockMap block_map; + block_map.thread_count = 1; + block_map.traversal_order = traversal_order; + block_map.num_blocks_base_log2 = num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + block_map.dims[side] = 1 << size_log2; + block_map.rectangularness_log2[side] = 0; + block_map.kernel_dims[side] = 1 << kKernelSizeLog2; + block_map.small_block_dims[side] = block_map.kernel_dims[side]; + block_map.large_blocks[side] = 0; + } + + const int num_blocks_per_side = 1 << num_blocks_base_log2; + const int num_blocks = num_blocks_per_side * num_blocks_per_side; + EXPECT_EQ(num_blocks, NumBlocks(block_map)); + + // Perform a full traversal of all blocks, as if computing a whole matrix + // multiplication. + // + // Used to record how many times each block was hit by the traversal. + std::vector<int> block_hit_counts(num_blocks); + // Here we guard an assumption that all traversal orders start at (0, 0). + SidePair<int> previous_block_coords(0, 0); + // Sum of L1 norm of the coordinate change at every step of the traversal. + std::int64_t total_l1_distance = 0; + // Number of jumps i.e. traversal steps with a L1 norm greater than 1. + int discontinuity_count = 0; + for (int block_index = 0; block_index < num_blocks; block_index++) { + SidePair<int> block_coords; + GetBlockByIndex(block_map, block_index, &block_coords); + ++block_hit_counts[block_coords[Side::kLhs] + + num_blocks_per_side * block_coords[Side::kRhs]]; + int distance = L1Distance(block_coords, previous_block_coords); + total_l1_distance += distance; + discontinuity_count += (distance > 1); + previous_block_coords = block_coords; + } + + // Verify that each block was traversed exactly once. + for (int l = 0; l < num_blocks_per_side; l++) { + for (int r = 0; r < num_blocks_per_side; r++) { + EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); + } + } + + // Verify that the discontinuity_count and total_l1_distance are as expected + // for the given traversal_order. 
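+  // For instance, for kLinear (illustrative derivation): each of the
+  // num_blocks_per_side grid lines contributes num_blocks_per_side - 1 unit
+  // steps, and each of the num_blocks_per_side - 1 wrap-arounds between lines
+  // is a jump of L1 length num_blocks_per_side, giving the expected
+  // 2 * num_blocks_per_side * (num_blocks_per_side - 1) total distance and
+  // num_blocks_per_side - 1 discontinuities.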
+ switch (traversal_order) { + case BlockMapTraversalOrder::kFractalHilbert: + // No discontinuity at all with this space-filling continuous curve! + EXPECT_EQ(discontinuity_count, 0); + // Therefore, total_l1_distance has to be the number of blocks minus one. + EXPECT_EQ(total_l1_distance, num_blocks - 1); + break; + case BlockMapTraversalOrder::kLinear: + EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalZ: + EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalU: { + if (num_blocks_base_log2 == 0) { + EXPECT_EQ(discontinuity_count, 0); + EXPECT_EQ(total_l1_distance, 0); + } else { + int expected_discontinuity_count = 0; + int expected_total_l1_distance = 3; + for (int i = 2; i <= num_blocks_base_log2; i++) { + expected_discontinuity_count = 4 * expected_discontinuity_count + 2; + expected_total_l1_distance = + 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; + } + EXPECT_EQ(discontinuity_count, expected_discontinuity_count); + EXPECT_EQ(total_l1_distance, expected_total_l1_distance); + } + break; + } + default: + abort(); + } +} + +TEST(BlockMapTest, GetBlockByIndexSquare) { + for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; + num_blocks_base_log2++) { + for (BlockMapTraversalOrder traversal_order : + {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, + BlockMapTraversalOrder::kFractalU, + BlockMapTraversalOrder::kFractalHilbert}) { + GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); + } + } +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/blocking_counter.cc b/ruy/blocking_counter.cc new file mode 100644 index 0000000..b62eca9 --- /dev/null +++ b/ruy/blocking_counter.cc @@ -0,0 +1,49 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/blocking_counter.h" + +#include "ruy/check_macros.h" +#include "ruy/wait.h" + +namespace ruy { + +void BlockingCounter::Reset(int initial_count) { + int old_count_value = count_.load(std::memory_order_relaxed); + RUY_DCHECK_EQ(old_count_value, 0); + (void)old_count_value; + count_.store(initial_count, std::memory_order_release); +} + +bool BlockingCounter::DecrementCount() { + int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); + RUY_DCHECK_GT(old_count_value, 0); + int count_value = old_count_value - 1; + bool hit_zero = (count_value == 0); + if (hit_zero) { + std::lock_guard<std::mutex> lock(count_mutex_); + count_cond_.notify_all(); + } + return hit_zero; +} + +void BlockingCounter::Wait(const Duration spin_duration) { + const auto& condition = [this]() { + return count_.load(std::memory_order_acquire) == 0; + }; + ruy::Wait(condition, spin_duration, &count_cond_, &count_mutex_); +} + +} // namespace ruy diff --git a/ruy/blocking_counter.h b/ruy/blocking_counter.h new file mode 100644 index 0000000..806909b --- /dev/null +++ b/ruy/blocking_counter.h @@ -0,0 +1,66 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_BLOCKING_COUNTER_H_ +#define RUY_RUY_BLOCKING_COUNTER_H_ + +#include <atomic> +#include <condition_variable> // NOLINT(build/c++11) // IWYU pragma: keep +#include <mutex> // NOLINT(build/c++11) // IWYU pragma: keep + +#include "ruy/time.h" + +namespace ruy { + +// A BlockingCounter lets one thread wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. +// The waiting is done using a naive spinlock waiting for the atomic +// count_ to hit the value 0. This is acceptable because in our usage +// pattern, BlockingCounter is used only to synchronize threads after +// short-lived tasks (performing parts of the same GEMM). It is not used +// for synchronizing longer waits (resuming work on the next GEMM). +class BlockingCounter { + public: + BlockingCounter() : count_(0) {} + + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(int initial_count); + + // Decrements the counter; if the counter hits zero, signals + // the threads that were waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. + bool DecrementCount(); + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + // + // Will first spin-wait for `spin_duration` before reverting to passive wait. + void Wait(const Duration spin_duration); + + private: + std::atomic<int> count_; + + // The condition variable and mutex used to passively wait for count_ + // to reach the value zero, in the case of longer waits.
+ std::condition_variable count_cond_; + std::mutex count_mutex_; +}; + +} // namespace ruy + +#endif // RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl new file mode 100644 index 0000000..836f47a --- /dev/null +++ b/ruy/build_defs.bzl @@ -0,0 +1,79 @@ +"""Build definitions for Ruy.""" + +# Helper for ruy_copts(). +# Returns warnings flags to use for all ruy code. +def ruy_copts_warnings(): + return select({ + "@bazel_tools//src/conditions:windows": [ + # We run into trouble on Windows toolchains with warning flags, + # as mentioned in the comments below on each flag. + # We could be more aggressive in enabling supported warnings on each + # Windows toolchain, but we compromise with keeping BUILD files simple + # by limiting the number of config_setting's. + ], + "//conditions:default": [ + "-Wall", + # Some clang-based Windows toolchains have more warnings in -Wextra. + "-Wextra", + # TensorFlow is C++14 at the moment. This flag ensures that we warn + # on any code that isn't C++14, but MSVC does not support it. + "-Wc++14-compat", + # Warn on preprocessor expansion of an undefined token, e.g. catching + # typos such as `#ifdef __linus__` instead of `#ifdef __linux__`. + # Not supported by MSVC. + "-Wundef", + ], + }) + +# Helper for ruy_copts(). +# Returns flags to use to enable NEON if applicable, for all ruy code. +def ruy_copts_neon(): + return select({ + # OK to crash old devices that lack full NEON support. + # No need to pass -mfloat-abi=softfp, that is already on. + "//ruy:arm32_assuming_neon": [ + "-mfpu=neon", + ], + "//conditions:default": [], + }) + +# Helper for ruy_copts(). +# Returns optimization flags to use for all ruy code. +def ruy_copts_optimize(): + return select({ + # On some toolchains, typically mobile, "-c opt" is interpreted by + # default as "optimize for size, not for speed". For Ruy code, + # optimizing for speed is the better compromise, so we override that. + # Careful to keep debug builds debuggable, whence the select based + # on the compilation mode. + "//ruy:do_not_want_O3": [], + "//conditions:default": ["-O3"], + }) + +# Returns compiler flags to use for all ruy code. +def ruy_copts(): + return ruy_copts_warnings() + ruy_copts_neon() + ruy_copts_optimize() + +def ruy_copts_avx(): + return select({ + "//ruy:x86_64_and_not_msvc": ["-mavx"], + "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX"], + "//conditions:default": [], + }) + +def ruy_copts_avx2_fma(): + return select({ + "//ruy:x86_64_and_not_msvc": ["-mavx2", "-mfma"], + "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX2"], + "//conditions:default": [], + }) + +def ruy_copts_avx512(): + # In some clang-based toolchains, in the default compilation mode (not -c opt), + # heavy spillage in the AVX512 kernels results in stack frames > 50k. This issue does not exist + # in optimized builds (-c opt). + return select({ + "//ruy:x86_64_and_not_msvc": ["$(STACK_FRAME_UNLIMITED)", "-mavx512f", "-mavx512vl", "-mavx512cd", "-mavx512bw", "-mavx512dq"], + "@bazel_tools//src/conditions:windows_msvc": ["/arch:AVX512"], + "//conditions:default": [], + }) diff --git a/ruy/build_defs.oss.bzl b/ruy/build_defs.oss.bzl new file mode 100644 index 0000000..e405b41 --- /dev/null +++ b/ruy/build_defs.oss.bzl @@ -0,0 +1,15 @@ +"""Build definitions for Ruy that are specific to the open-source build.""" + +# Used for targets that #include <thread> +def ruy_linkopts_thread_standard_library(): + # In open source builds, GCC is a common occurence. 
It requires "-pthread" + # to use the C++11 <thread> standard library header. This breaks the + # opensource build on Windows and probably some other platforms, so that + # will need to be fixed as needed. Ideally we would like to do this based + # on GCC being the compiler, but that does not seem to be easy to achieve + # with Bazel. Instead we do the following, which is copied from + # https://github.com/abseil/abseil-cpp/blob/1112609635037a32435de7aa70a9188dcb591458/absl/base/BUILD.bazel#L155 + return select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": ["-pthread"], + }) diff --git a/ruy/check_macros.h b/ruy/check_macros.h new file mode 100644 index 0000000..13d27e2 --- /dev/null +++ b/ruy/check_macros.h @@ -0,0 +1,147 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// self-contained, minimal, CHECK/DCHECK macros similar to glog. + +#ifndef RUY_RUY_CHECK_MACROS_H_ +#define RUY_RUY_CHECK_MACROS_H_ + +#include <cstdio> +#include <cstdlib> +#include <functional> +#include <type_traits> + +namespace ruy { +namespace check_macros { + +constexpr int kValueBufSize = 32; + +template <typename T, typename Enable = void> +struct ToString { + static void Run(const T&, char* buf) { snprintf(buf, kValueBufSize, "(?)"); } +}; + +template <> +struct ToString<float, void> { + static void Run(float value, char* buf) { + snprintf(buf, kValueBufSize, "%.9g", static_cast<double>(value)); + } +}; + +template <> +struct ToString<double, void> { + static void Run(double value, char* buf) { + snprintf(buf, kValueBufSize, "%.16g", value); + } +}; + +template <typename T> +struct ToString<T, typename std::enable_if<std::is_integral<T>::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "%lld", static_cast<long long>(value)); + } +}; + +template <typename T> +struct ToString<T*, void> { + static void Run(T* value, char* buf) { + snprintf(buf, kValueBufSize, "%p", value); + } +}; + +template <typename T> +struct ToString<T, typename std::enable_if<std::is_enum<T>::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "(enum value %d)", static_cast<int>(value)); + } +}; + +inline void CheckImpl(bool condition, const char* file, int line, + const char* macro, const char* condition_str) { + if (!condition) { + fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, + macro, condition_str); + abort(); + } +} + +template <template <typename T> class Comparison, typename LhsType, + typename RhsType> +inline void CheckImpl(const char* file, int line, const char* macro, + const char* lhs, const LhsType& lhs_value, + const char* op_symbol, const char* rhs, + const RhsType& rhs_value) { + using CommonType = typename std::common_type<LhsType, RhsType>::type; + if (!Comparison<CommonType>()(lhs_value, rhs_value)) { + char lhs_value_buf[kValueBufSize]; + ToString<LhsType>::Run(lhs_value, 
lhs_value_buf); + char rhs_value_buf[kValueBufSize]; + ToString<RhsType>::Run(rhs_value, rhs_value_buf); + fprintf( + stderr, + "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " + "%s %s %s ].\n", + file, line, macro, lhs, op_symbol, rhs, lhs_value_buf, op_symbol, + rhs_value_buf); + abort(); + } +} + +#define RUY_CHECK_IMPL(macro, condition) \ + ruy::check_macros::CheckImpl(condition, __FILE__, __LINE__, #macro, \ + #condition) + +#define RUY_CHECK_OP_IMPL(macro, lhs, op_symbol, op_comparison, rhs) \ + ruy::check_macros::CheckImpl<op_comparison>( \ + __FILE__, __LINE__, #macro, #lhs, lhs, #op_symbol, #rhs, rhs) + +#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition) +#define RUY_CHECK_EQ(x, y) \ + RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, std::equal_to, y) +#define RUY_CHECK_NE(x, y) \ + RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, std::not_equal_to, y) +#define RUY_CHECK_GE(x, y) \ + RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, std::greater_equal, y) +#define RUY_CHECK_GT(x, y) \ + RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, std::greater, y) +#define RUY_CHECK_LE(x, y) \ + RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, std::less_equal, y) +#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, std::less, y) + +#ifdef NDEBUG +#define RUY_DCHECK_IS_ENABLED false +#else +#define RUY_DCHECK_IS_ENABLED true +#endif + +#define RUY_DCHECK(condition) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK(condition) +#define RUY_DCHECK_EQ(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_EQ(x, y) +#define RUY_DCHECK_NE(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_NE(x, y) +#define RUY_DCHECK_GE(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_GE(x, y) +#define RUY_DCHECK_GT(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_GT(x, y) +#define RUY_DCHECK_LE(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_LE(x, y) +#define RUY_DCHECK_LT(x, y) \ + if (RUY_DCHECK_IS_ENABLED) RUY_CHECK_LT(x, y) + +} // end namespace check_macros +} // end namespace ruy + +#endif // RUY_RUY_CHECK_MACROS_H_ diff --git a/ruy/check_macros_test.cc b/ruy/check_macros_test.cc new file mode 100644 index 0000000..4e7508f --- /dev/null +++ b/ruy/check_macros_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/check_macros.h" + +#include "ruy/gtest_wrapper.h" + +namespace { + +#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ + do { \ + if (vacuously_succeeds || (condition)) { \ + RUY_##family(condition); \ + } \ + } while (false) + +#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ + y) \ + do { \ + if (vacuously_succeeds || ((x)op(y))) { \ + RUY_##family##_##op_name(x, y); \ + } \ + } while (false) + +#ifdef NDEBUG +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + } while (false) +#else +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ + } while (false) + +#endif + +template <typename LhsType, typename RhsType> +void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { + RUY_CHECK_EQ(lhs, lhs); + TEST_COMPARISON(EQ, lhs, ==, lhs); + RUY_CHECK_EQ(lhs, lhs); + RUY_CHECK_EQ(lhs, lhs); + if (lhs == rhs) { + RUY_CHECK_EQ(lhs, rhs); + } + if (lhs != rhs) { + RUY_CHECK_NE(lhs, rhs); + } +} + +template <typename LhsType, typename RhsType> +void TestComparisons(const LhsType& lhs, const RhsType& rhs) { + TestEqualityComparisons(lhs, rhs); + if (lhs > rhs) { + RUY_CHECK_GT(lhs, rhs); + } + if (lhs >= rhs) { + RUY_CHECK_GE(lhs, rhs); + } + if (lhs < rhs) { + RUY_CHECK_LT(lhs, rhs); + } + if (lhs <= rhs) { + RUY_CHECK_LE(lhs, rhs); + } +} + +TEST(CheckMacrosTest, IntInt) { + TestComparisons(0, 0); + TestComparisons(0, 1); + TestComparisons(1, -1); + TestComparisons(-1, 0); + TestComparisons(123, -456); + TestComparisons(std::numeric_limits<int>::min(), + std::numeric_limits<int>::max()); + TestComparisons(123, std::numeric_limits<int>::max()); + TestComparisons(123, std::numeric_limits<int>::min()); +} + +TEST(CheckMacrosTest, Uint8Uint8) { + TestComparisons<std::uint8_t, std::uint8_t>(0, 0); + TestComparisons<std::uint8_t, std::uint8_t>(255, 0); + TestComparisons<std::uint8_t, std::uint8_t>(0, 255); + TestComparisons<std::uint8_t, std::uint8_t>(12, 34); +} + +TEST(CheckMacrosTest, Uint8Int) { + TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::min()); + TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::min()); + TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::max()); + TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::max()); +} + +TEST(CheckMacrosTest, FloatFloat) { + TestComparisons(0.f, 0.f); + TestComparisons(0.f, 1.f); + TestComparisons(1.f, -1.f); + TestComparisons(-1.f, 0.f); + TestComparisons(123.f, -456.f); + TestComparisons(std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::max()); + TestComparisons(123.f, std::numeric_limits<float>::max()); + TestComparisons(123.f, std::numeric_limits<float>::lowest()); +} + +TEST(CheckMacrosTest, IntFloat) { + TestComparisons(0, 0.f); + TestComparisons(0, 1.f); + TestComparisons(1, -1.f); + TestComparisons(-1, 0.f); + TestComparisons(123, -456.f); + TestComparisons(std::numeric_limits<int>::lowest(), + 
std::numeric_limits<float>::max()); + TestComparisons(123, std::numeric_limits<float>::max()); + TestComparisons(123, std::numeric_limits<float>::lowest()); +} + +TEST(CheckMacrosTest, EnumClass) { + enum class SomeEnumClass { kA, kB, kC }; + TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); + TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB); + TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/context.cc b/ruy/context.cc new file mode 100644 index 0000000..4661738 --- /dev/null +++ b/ruy/context.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/context.h" + +#include "ruy/ctx.h" +#include "ruy/ctx_impl.h" +#include "ruy/path.h" +#include "ruy/performance_advisory.h" +#include "ruy/prepacked_cache.h" +#include "ruy/thread_pool.h" +#include "ruy/tune.h" + +namespace ruy { + +Context::Context() : impl_(new CtxImpl) {} +Context::~Context() { delete impl_; } + +const Ctx& Context::ctx() const { return static_cast<const Ctx&>(*impl_); } +Ctx* Context::mutable_ctx() { return static_cast<Ctx*>(impl_); } + +Path Context::last_used_path() const { return ctx().last_used_path(); } +Tuning Context::explicit_tuning() const { return ctx().explicit_tuning(); } +void Context::set_explicit_tuning(Tuning value) { + mutable_ctx()->set_explicit_tuning(value); +} +const ThreadPool& Context::thread_pool() const { return ctx().thread_pool(); } +ThreadPool* Context::mutable_thread_pool() { + return mutable_ctx()->mutable_thread_pool(); +} +int Context::max_num_threads() const { return ctx().max_num_threads(); } +void Context::set_max_num_threads(int value) { + mutable_ctx()->set_max_num_threads(value); +} + +void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); } + +bool Context::performance_advisory(PerformanceAdvisory advisory) const { + return ctx().performance_advisory(advisory); +} + +void Context::set_runtime_enabled_paths(Path paths) { + mutable_ctx()->SetRuntimeEnabledPaths(paths); +} + +} // namespace ruy diff --git a/ruy/context.h b/ruy/context.h new file mode 100644 index 0000000..79a4b5c --- /dev/null +++ b/ruy/context.h @@ -0,0 +1,108 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Context is the user-facing context class. + +#ifndef RUY_RUY_CONTEXT_H_ +#define RUY_RUY_CONTEXT_H_ + +#include <cstdint> + +namespace ruy { + +class Ctx; +class CtxImpl; +class ThreadPool; +enum class Path : std::uint8_t; +enum class Tuning; +enum class PerformanceAdvisory; + +// A Context holds runtime information used by Ruy. It holds runtime resources +// such as the worker thread pool and the allocator (which holds buffers for +// temporary data), as well as runtime options controlling which Paths are +// enabled (typically based on which instruction sets are detected) and how +// many threads to use. +class Context final { + public: + Context(); + ~Context(); + + // Returns the Path enum value that corresponds to the code path used by + // the last ruy::Mul with this Context. + Path last_used_path() const; + + // Control of whether to use kernels tuned for in-order or out-of-order CPU + // cores. The default is auto-detection, so these methods should only be used + // to override that auto-detection if it's not working as intended or for + // testing. + Tuning explicit_tuning() const; + void set_explicit_tuning(Tuning value); + + // The thread pool held by this context to dispatch a ruy::Mul to worker + // threads. + // + // By default, threads may spin-wait for a few milliseconds before reverting + // to passive wait. This can be controlled by + // `mutable_thread_pool()->set_spin_milliseconds(value)`. + const ThreadPool& thread_pool() const; + ThreadPool* mutable_thread_pool(); + + // Controls the maximum number of threads to be used by ruy::Mul with this + // Context. The number of threads in the pool will be that value minus one, + // as the remaining portion of the work is done directly on the calling + // thread. + // + // This defaults to 1. Multi-threading in ruy is always opt-in. There is + // no auto-detection of hardware concurrency. That is on purpose: ruy focuses + // on mobile applications where such concepts are difficult to define + // (e.g. ARM big.LITTLE). + int max_num_threads() const; + void set_max_num_threads(int value); + + // Returns true if the last ruy::Mul using this Context flagged the specified + // `advisory`. This is reset by each ruy::Mul call. + bool performance_advisory(PerformanceAdvisory advisory) const; + + // When using Matrix::set_cache_policy(), this Context will keep a cache of + // pre-packed matrix data. This function clears that cache. + void ClearPrepackedCache(); + + // Override auto-detection of supported code paths. + // + // Passing `paths == Path::kNone` means reverting to the default behavior. + // This will trigger auto-detection on the next use. + // + // Other values will override auto-detection with the explicitly provided set + // of paths. + // + // Paths in kNonArchPaths are always implicitly supported. + void set_runtime_enabled_paths(Path paths); + + private: + CtxImpl* const impl_; + + const Ctx& ctx() const; + Ctx* mutable_ctx(); + + friend const Ctx* get_ctx(const Context*); + friend Ctx* get_ctx(Context*); + + // Disallow copy + Context(const Context&) = delete; +}; + +} // end namespace ruy + +#endif // RUY_RUY_CONTEXT_H_ diff --git a/ruy/context_get_ctx.cc b/ruy/context_get_ctx.cc new file mode 100644 index 0000000..dce5858 --- /dev/null +++ b/ruy/context_get_ctx.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/context_get_ctx.h" + +#include "ruy/ctx_impl.h" + +namespace ruy { + +const Ctx* get_ctx(const Context* context) { + return static_cast<const Ctx*>(context->impl_); +} +Ctx* get_ctx(Context* context) { return static_cast<Ctx*>(context->impl_); } + +} // namespace ruy diff --git a/ruy/context_get_ctx.h b/ruy/context_get_ctx.h new file mode 100644 index 0000000..cc599fe --- /dev/null +++ b/ruy/context_get_ctx.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Gateway to access the Ctx (internal context interface for ruy code) from +// a Context (public-facing class). Befriended by Context. + +#ifndef THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_ +#define THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_ + +#include "ruy/context.h" +#include "ruy/ctx.h" + +namespace ruy { + +const Ctx* get_ctx(const Context* context); +Ctx* get_ctx(Context*); + +} // namespace ruy + +#endif // THIRD_PARTY_RUY_RUY_CONTEXT_GET_CTX_H_ diff --git a/ruy/context_test.cc b/ruy/context_test.cc new file mode 100644 index 0000000..4e69e65 --- /dev/null +++ b/ruy/context_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/context.h" + +#include "ruy/gtest_wrapper.h" +#include "ruy/path.h" +#include "ruy/prepacked_cache.h" +#include "ruy/tune.h" + +namespace ruy { +namespace { + +TEST(ContextTest, ContextClassSanity) { + Context context; + EXPECT_EQ(context.last_used_path(), Path::kNone); + EXPECT_EQ(context.explicit_tuning(), Tuning::kAuto); + EXPECT_EQ(&context.thread_pool(), context.mutable_thread_pool()); + EXPECT_NE(context.mutable_thread_pool(), nullptr); + EXPECT_EQ(context.max_num_threads(), 1); + context.set_explicit_tuning(Tuning::kGeneric); + context.set_max_num_threads(2); + EXPECT_EQ(context.explicit_tuning(), Tuning::kGeneric); + EXPECT_EQ(context.max_num_threads(), 2); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/cpu_cache_params.h b/ruy/cpu_cache_params.h new file mode 100644 index 0000000..202cd54 --- /dev/null +++ b/ruy/cpu_cache_params.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_CPU_CACHE_PARAMS_H_ +#define RUY_RUY_CPU_CACHE_PARAMS_H_ + +namespace ruy { + +// Holds some information about a CPU's data caches. +// +// Meaning of 'local': a 'local' cache means a cache that is used by only one +// CPU core, not shared with other cores. It might still be used by multiple +// 'processors' in case of SMT as in Intel HyperThreading. CPUs often have +// multiple levels of local cache, e.g. L1 and L2. We typically return the +// larger one, the assumption being that even the larger one has substantially +// lower latency than any higher (non-local) cache, however as noted below (*) +// the implementation may choose to ignore a cache level. +// +// Meaning of 'last level': this refers to some higher cache level, typically +// shared among multiple CPU cores, so we considered using the terminology +// 'shared' instead of 'last_level'. However that created some confusion of its +// own, as the meaning of 'shared' varies between CPUs, with some CPUs not +// having any level of cache shared among all cores. That is why we stick with +// the 'last_level' terminology, however with the following caveats: +// 1. As noted below (*) the implementation may choose to ignore a cache +// level, which could cause the 'last level' cache according to ruy not to be +// the actual last level. +// 2. On some systems-on-chip there is a 'last level' cache outside of the +// last level cache in the CPU complex. Ruy is not currently doing anything +// specific regarding such caches. +// 3. We haven't figured out how to amend our terminology to be meaningful +// on NUMA architectures. NUMA hasn't been part of ruy's scope so far. 
+// +// (*) Note on ignoring certain cache levels: +// The implementation may choose to ignore a cache if it's suspected not to +// have compelling performance. This is true about all cache levels, but more +// likely regarding the 'last level' cache. For example, an L4 cache may be +// ignored if we believe that it's not the right latency/size compromise for us, +// so on such a CPU, the L3 cache may be used as the 'last level' cache instead. +// +// (**) Note on CPUs with heterogeneous cores: +// Some CPUs have multiple cores with different local caches. For example, some +// ARM big.LITTLE CPUs have some CPU cores with L1=32k and L2=128k, and some +// other CPU cores with L1=64k and L2=256k or even 512k. On such CPUs, the +// fields in this struct refer to the minimum value over all cores. In other +// words, we use conservative values that do not risk over-estimating local +// cache sizes in case of a migration of our threads to smaller cores. +// +// Example: +// On a Qualcomm S855 SoC, there are 8 CPU cores. Each core has L1 and L2 data +// caches local to it: +// - 4 cores have L1=32k, L2=128k. +// - 3 cores have L1=64k, L2=256k. +// - 1 core has L1=64k, L2=512k. +// All 8 cores share an L3 cache of size 2M, and there is beyond that a SoC-level +// cache of size 3M. +// On such a system, we should have: +// - local_cache_size=128k, the smallest L2 size. +// - last_level_cache_size=2M, the L3 cache size, ignoring the SoC-level cache. +struct CpuCacheParams final { + // Minimum value (see (**)), over all cores, of the size in bytes of its local + // cache (see "Meaning of 'local'"). + int local_cache_size = 0; + // Minimum value (see (**)), over all cores, of the size in bytes of its last + // level cache (see "Meaning of 'last level'"). + int last_level_cache_size = 0; +}; + +} // namespace ruy + +#endif // RUY_RUY_CPU_CACHE_PARAMS_H_ diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc new file mode 100644 index 0000000..b1f54bc --- /dev/null +++ b/ruy/cpuinfo.cc @@ -0,0 +1,163 @@ +#include "ruy/cpuinfo.h" + +#include <algorithm> +#include <cstdint> +#include <limits> + +#include "ruy/check_macros.h" +#include "ruy/cpu_cache_params.h" +#include "ruy/platform.h" + +#ifdef RUY_HAVE_CPUINFO +#include <cpuinfo.h> +#endif + +namespace ruy { + +namespace { +void MakeDummyCacheParams(CpuCacheParams* result) { + // Reasonable dummy values + result->local_cache_size = 32 * 1024; + result->last_level_cache_size = 512 * 1024; +} +} // end namespace + +#ifdef RUY_HAVE_CPUINFO + +CpuInfo::~CpuInfo() { + if (init_status_ == InitStatus::kInitialized) { + cpuinfo_deinitialize(); + } +} + +bool CpuInfo::EnsureInitialized() { + if (init_status_ == InitStatus::kNotYetAttempted) { + init_status_ = Initialize(); + RUY_DCHECK_NE(init_status_, InitStatus::kNotYetAttempted); + } + return init_status_ == InitStatus::kInitialized; +} + +namespace { +void QueryCacheParams(CpuCacheParams* cache_params) { + const int processors_count = cpuinfo_get_processors_count(); + RUY_DCHECK_GT(processors_count, 0); + int overall_local_cache_size = std::numeric_limits<int>::max(); + int overall_last_level_cache_size = std::numeric_limits<int>::max(); + for (int i = 0; i < processors_count; i++) { + int local_cache_size = 0; + int last_level_cache_size = 0; + const cpuinfo_processor* processor = cpuinfo_get_processor(i); + // Loop over cache levels. Ignoring L4 for now: it seems that in CPUs that + // have L4, we would still prefer to stay in lower-latency L3.
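+    // Illustrative walk-through (matching the example in cpu_cache_params.h):
+    // on a core with a local 32k L1, a local 128k L2 and a shared 2M L3, the
+    // loop below leaves local_cache_size at 128k (the largest level that is
+    // still local to the core) and last_level_cache_size at 2M (the last
+    // level seen).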
+ for (const cpuinfo_cache* cache : + {processor->cache.l1d, processor->cache.l2, processor->cache.l3}) { + if (!cache) { + continue; // continue, not break, it is possible to have L1+L3 but no + // L2. + } + const bool is_local = + cpuinfo_get_processor(cache->processor_start)->core == + cpuinfo_get_processor(cache->processor_start + + cache->processor_count - 1) + ->core; + if (is_local) { + local_cache_size = cache->size; + } + last_level_cache_size = cache->size; + } + // If no local cache was found, use the last-level cache. + if (!local_cache_size) { + local_cache_size = last_level_cache_size; + } + RUY_DCHECK_GT(local_cache_size, 0); + RUY_DCHECK_GT(last_level_cache_size, 0); + RUY_DCHECK_GE(last_level_cache_size, local_cache_size); + overall_local_cache_size = + std::min(overall_local_cache_size, local_cache_size); + overall_last_level_cache_size = + std::min(overall_last_level_cache_size, last_level_cache_size); + } + cache_params->local_cache_size = overall_local_cache_size; + cache_params->last_level_cache_size = overall_last_level_cache_size; +} +} // end namespace + +CpuInfo::InitStatus CpuInfo::Initialize() { + RUY_DCHECK_EQ(init_status_, InitStatus::kNotYetAttempted); + if (!cpuinfo_initialize()) { + MakeDummyCacheParams(&cache_params_); + return InitStatus::kFailed; + } + QueryCacheParams(&cache_params_); + return InitStatus::kInitialized; +} + +bool CpuInfo::NeonDotprod() { + return EnsureInitialized() && cpuinfo_has_arm_neon_dot(); +} + +bool CpuInfo::Sse42() { + return EnsureInitialized() && cpuinfo_has_x86_sse4_2(); +} + +bool CpuInfo::Avx2Fma() { + return EnsureInitialized() && cpuinfo_has_x86_avx2() && + cpuinfo_has_x86_fma3(); +} + +bool CpuInfo::Avx() { return EnsureInitialized() && cpuinfo_has_x86_avx(); } + +bool CpuInfo::Avx512() { + return EnsureInitialized() && cpuinfo_has_x86_avx512f() && + cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() && + cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512vl(); +} + +bool CpuInfo::AvxVnni() { + return EnsureInitialized() && cpuinfo_has_x86_avx512vnni(); +} + +bool CpuInfo::CurrentCpuIsA55ish() { + if (!EnsureInitialized()) { + return false; + } + + switch (cpuinfo_get_uarch(cpuinfo_get_current_uarch_index())->uarch) { + case cpuinfo_uarch_cortex_a53: + case cpuinfo_uarch_cortex_a55r0: + case cpuinfo_uarch_cortex_a55: + return true; + default: + return false; + } +} + +#else // not defined RUY_HAVE_CPUINFO + +CpuInfo::~CpuInfo() {} +bool CpuInfo::EnsureInitialized() { + if (init_status_ == InitStatus::kNotYetAttempted) { + MakeDummyCacheParams(&cache_params_); + init_status_ = InitStatus::kInitialized; + } + RUY_DCHECK_EQ(init_status_, InitStatus::kInitialized); + return true; +} +bool CpuInfo::NeonDotprod() { return false; } +bool CpuInfo::Sse42() { return false; } +bool CpuInfo::Avx() { return false; } +bool CpuInfo::Avx2Fma() { return false; } +bool CpuInfo::Avx512() { return false; } +bool CpuInfo::AvxVnni() { return false; } +bool CpuInfo::CurrentCpuIsA55ish() { return false; } + +#endif + +const CpuCacheParams& CpuInfo::CacheParams() { + EnsureInitialized(); + // On failure, EnsureInitialized leaves dummy values in cache_params_. + return cache_params_; +} + +} // namespace ruy diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h new file mode 100644 index 0000000..e45fa51 --- /dev/null +++ b/ruy/cpuinfo.h @@ -0,0 +1,61 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_CPUINFO_H_
+#define RUY_RUY_CPUINFO_H_
+
+#include "ruy/cpu_cache_params.h"
+
+namespace ruy {
+
+// Wraps the functionality that ruy needs from the cpuinfo library.
+class CpuInfo final {
+ public:
+ CpuInfo() {}
+ ~CpuInfo();
+
+ // ARM features
+ bool NeonDotprod();
+
+ // X86 features
+ bool Sse42();
+ bool Avx();
+ bool Avx2Fma();
+ bool Avx512();
+ bool AvxVnni();
+
+ // Common features
+ const CpuCacheParams& CacheParams();
+ bool CurrentCpuIsA55ish();
+
+ private:
+ enum class InitStatus {
+ kNotYetAttempted,
+ kInitialized,
+ kFailed,
+ };
+
+ InitStatus init_status_ = InitStatus::kNotYetAttempted;
+ CpuCacheParams cache_params_;
+
+ bool EnsureInitialized();
+ InitStatus Initialize();
+
+ CpuInfo(const CpuInfo&) = delete;
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_CPUINFO_H_
diff --git a/ruy/create_trmul_params.h b/ruy/create_trmul_params.h
new file mode 100644
index 0000000..531e066
--- /dev/null
+++ b/ruy/create_trmul_params.h
@@ -0,0 +1,484 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implementation of CreateTrMulParams, see function comment.
+
+#ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_
+#define RUY_RUY_CREATE_TRMUL_PARAMS_H_
+
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+#include "ruy/allocator.h"
+#include "ruy/ctx.h"
+#include "ruy/kernel.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/trace.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+// While the only entry point to this file is CreateTrMulParams, its templatized
+// nature requires putting more code in this header than we would like. This
+// internal implementation code is enclosed in namespace 'detail'.
+namespace detail {
+
+inline void CreatePackedLayout(const MatLayout& src,
+ const KernelLayout& kernel_layout,
+ PMatLayout* packed_layout) {
+ // Packed matrices are always column-major, because in TrMul that is always
+ // the dimension of traversal of the kernel's inner loop.
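  // Hypothetical worked example (dimensions made up, not from the original
  // source) of the rounding performed below: with a kernel block of 8 rows x
  // 4 columns and a source layout of 50 rows x 70 columns,
  //   packed_layout->rows   = round_up_pot(50, 8) = 56
  //   packed_layout->cols   = round_up_pot(70, 4) = 72
  //   packed_layout->stride = 56
  // so the packed matrix is padded up to whole kernel blocks.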
+ packed_layout->order = Order::kColMajor; + packed_layout->rows = round_up_pot(src.rows, kernel_layout.rows); + packed_layout->cols = round_up_pot(src.cols, kernel_layout.cols); + packed_layout->stride = packed_layout->rows; + packed_layout->kernel = kernel_layout; +} + +template <typename Scalar, typename PackedScalar> +void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, + TrMulParams* params) { + // Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication, so we would like to always use std::int32_t + // unconditionally for SumsType. + // However, for floating point types, we still need a reasonable type here to + // avoid tripping assertions elsewhere in the code. + using SumsType = + typename std::conditional<std::is_floating_point<Scalar>::value, Scalar, + std::int32_t>::type; + + const EMat& src = params->src[side]; + PEMat* packed_matrix = ¶ms->packed_matrix[side]; + packed_matrix->data_type = Type::Create<PackedScalar>(); + packed_matrix->sums_type = Type::Create<SumsType>(); + CreatePackedLayout(src.layout, kernel_layout, &packed_matrix->layout); + packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point); +} + +template <typename KernelType> +struct CheckKernelPathImpl { + static void Run(Path) { + // Do nothing. + // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL). + // That is to say that one may instantiate ruy::Mul with a weird combination + // of types, such as LhsScalar==float and RhsScalar==double, and have it + // work by silently falling back to Path::kStandardCpp. Only in specific + // cases do we have dedicated kernels overriding that fallback, and that is + // what partial specializations of this template will check. + } +}; + +#if RUY_DCHECK_IS_ENABLED +template <Path ThePath, typename SrcScalar, typename AccumScalar, + typename DstScalar> +struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar, + MulParams<AccumScalar, DstScalar>>> + final { + using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar, + MulParams<AccumScalar, DstScalar>>; + static void Run(Path expected_path) { + // We want to assert that we are using a dedicated Kernel specialization and + // not a fallback when we know we are in a case where such a kernel + // specialization exists. At the moment in the current state of ruy's + // architecture support for ARM and x86, that is when LhsScalar==RhsScalar + // (already implied in this partial specialization) and when that type is + // either float, int8, or uint8. Indeed, we have kernels supporting float + // and int8, and we have the packing code converting uint8 to int8 (see + // PackedTypeImpl). 
+ static constexpr bool kSrcScalarTypeSupportsFastKernels = + std::is_same<SrcScalar, float>::value || + std::is_same<SrcScalar, std::int8_t>::value || + std::is_same<SrcScalar, std::uint8_t>::value; + if (kSrcScalarTypeSupportsFastKernels) { + RUY_DCHECK_EQ(expected_path, KernelType::kPath); + } + } +}; +#endif + +template <typename KernelType> +void CheckKernelPath(Path expected_path) { + CheckKernelPathImpl<KernelType>::Run(expected_path); +} + +template <Path ThePath, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void PopulateTrMulParams(TrMulParams* params) { + RUY_TRACE_SCOPE; + using PackedLhsScalar = PackedType<ThePath, LhsScalar>; + using PackedRhsScalar = PackedType<ThePath, RhsScalar>; + using Kernel = + Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>; + using LhsKernelLayout = typename Kernel::LhsLayout; + using RhsKernelLayout = typename Kernel::RhsLayout; + + params->path = ThePath; + + CreatePackedMatrix<LhsScalar, PackedLhsScalar>( + Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params); + CreatePackedMatrix<RhsScalar, PackedRhsScalar>( + Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params); + params->run_pack[Side::kLhs] = + &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>; + params->run_pack[Side::kRhs] = + &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>; + params->run_kernel = &RunKernel<Kernel>::Run; + CheckKernelPath<Kernel>(ThePath); + RUY_TRACE_INFO(POPULATE_TRMUL_PARAMS); +} + +// PopulateTrMulParamsAllCompiledPaths calls into one of multiple +// instantiations of PopulateTrMulParams. For each bit that is set in +// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path +// corresponding to that single bit. The call to PopulateTrMulParams is +// guarded by a runtime check that it is in fact the dynamically selected path. +// +// PopulateTrMulParamsAllCompiledPaths is implemented with template +// metaprogramming by mutual recursion between PathSearchCountdown and +// PathSearchCompiledPaths. +// +// PopulateTrMulParamsAllCompiledPaths is logically implementing the following +// computation: +// +// template <Path CompiledPaths> +// void PopulateTrMulParamsAllCompiledPaths(Path the_path, +// TrMulParams* params) { +// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] +// Path current_path = static_cast<Path>(1 << bit); +// if ((CompiledPaths & current_path) != Path::kNone) { // [2] +// if (current_path == the_path) { // [3] +// PopulateTrMulParams<current_path, ...>(the_path, params); +// return; +// } +// } +// } +// } +// +// +// +// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is +// done in the recursion of PathSearchOnlyCompiledPaths. +// [2] - Done by PathSearchOnlyCompiledPaths's partial template +// specialization on InCompiledPaths. This is the check which necessitates +// doing the whole computation at C++ compile time. +// [3] - Done by the `if` in the main definition of +// PathSearchOnlyCompiledPaths. +// +// The template metaprogramming is necessary because: +// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++ +// compile-time constant. +// - PopulateTrMulParamsAllCompiledPaths must not instantiate +// inner loops for paths that are not in CompiledPaths, since that can result in +// bogus instantiations which cause a compile time failure. 
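// Illustration with hypothetical template arguments (not from the original
// source): if ruy was built with
//   CompiledPaths = Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod
// and runtime selection produced the_path = Path::kNeon, then
//   PopulateTrMulParamsAllCompiledPaths<CompiledPaths, float, float, float,
//                                       float>(the_path, &params);
// only instantiates PopulateTrMulParams for those three compiled paths, and
// at runtime ends up calling
//   PopulateTrMulParams<Path::kNeon, float, float, float, float>(&params);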
+template <Path CompiledPaths, int BitNumber, typename LhsScalar, + typename RhsScalar, typename AccumScalar, typename DstScalar> +struct PathSearchCountdown; + +template <Path CompiledPaths, bool InCompiledPaths, int BitNumber, + typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +struct PathSearchOnlyCompiledPaths { + static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + if (kCurrentPath == the_path) { + PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar, + DstScalar>(params); + return; + } + PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, + AccumScalar, DstScalar>::Search(the_path, params); + } +}; + +// Skip this iteration if CompiledPaths doesn't contain the specified path. +template <Path CompiledPaths, int BitNumber, typename LhsScalar, + typename RhsScalar, typename AccumScalar, typename DstScalar> +struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar, + RhsScalar, AccumScalar, DstScalar> { + static void Search(Path the_path, TrMulParams* params) { + PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, + AccumScalar, DstScalar>::Search(the_path, params); + } +}; + +template <Path CompiledPaths, int BitNumber, typename LhsScalar, + typename RhsScalar, typename AccumScalar, typename DstScalar> +struct PathSearchCountdown { + static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + PathSearchOnlyCompiledPaths< + CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, + LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params); + } +}; + +// Termination of the countdown. If the counter reaches -1, then we haven't +// found the specified path. 
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar, + DstScalar> { + static void Search(Path, TrMulParams*) { RUY_DCHECK(false); } +}; + +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { + RUY_TRACE_SCOPE; + return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar, + RhsScalar, AccumScalar, + DstScalar>::Search(the_path, params); +} + +template <typename AccumScalar, typename DstScalar> +void AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized( + const MulParams<AccumScalar, DstScalar>& mul_params, int user_size, + int user_capacity) { +#if RUY_DCHECK_IS_ENABLED + if (mul_params.bias()) { + for (int i = user_size; i < user_capacity; i++) { + RUY_DCHECK_EQ(mul_params.bias()[i], 0); + } + } + if (mul_params.multiplier_fixedpoint_perchannel()) { + for (int i = user_size; i < user_capacity; i++) { + RUY_DCHECK_EQ(mul_params.multiplier_fixedpoint_perchannel()[i], 0); + } + } + if (mul_params.multiplier_exponent_perchannel()) { + for (int i = user_size; i < user_capacity; i++) { + RUY_DCHECK_EQ(mul_params.multiplier_exponent_perchannel()[i], 0); + } + } +#else + (void)mul_params; + (void)user_size; + (void)user_capacity; +#endif +} + +template <typename AccumScalar, typename DstScalar, + bool HaveQuantizedMultipliers = + std::is_same<AccumScalar, std::int32_t>::value && + !std::is_same<DstScalar, std::int32_t>::value> +struct EnsurePerChannelBuffersLargeEnoughImpl { + static void Run(const TrMulParams& params, Allocator* allocator, + MulParams<AccumScalar, DstScalar>* mul_params) { + const Side channel_side = + mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs + : Side::kRhs; + const int required_capacity = + params.packed_matrix[channel_side].layout.cols; + const int user_size = params.src[channel_side].layout.cols; + const int user_capacity = round_up_pot( + user_size, mul_params->perchannel_buffers_capacity_rounding()); + // We should have already checked earlier for the case where + // user_capacity >= required_capacity. 
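    // Hypothetical worked example (sizes made up) of the reallocation done
    // below: with user_size = 10 channels, a capacity rounding of 1 (so
    // user_capacity = 10) and a packed layout rounded up to
    // required_capacity = 16 columns, a fresh 16-entry buffer is allocated,
    // the 10 user-provided entries are copied in, and the remaining 6 entries
    // are zero-filled so the kernel can safely read the padded region.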
+ RUY_DCHECK_GT(required_capacity, user_capacity); + if (mul_params->bias()) { + AccumScalar* new_data = + allocator->Allocate<AccumScalar>(required_capacity); + std::memcpy(new_data, mul_params->bias(), + user_size * sizeof(AccumScalar)); + std::memset(new_data + user_size, 0, + (required_capacity - user_size) * sizeof(AccumScalar)); + mul_params->set_bias(new_data); + } + if (mul_params->multiplier_fixedpoint_perchannel()) { + AccumScalar* new_data = + allocator->Allocate<AccumScalar>(required_capacity); + std::memcpy(new_data, mul_params->multiplier_fixedpoint_perchannel(), + user_size * sizeof(AccumScalar)); + std::memset(new_data + user_size, 0, + (required_capacity - user_size) * sizeof(AccumScalar)); + mul_params->set_multiplier_fixedpoint_perchannel(new_data); + } + if (mul_params->multiplier_exponent_perchannel()) { + int* new_data = allocator->Allocate<int>(required_capacity); + std::memcpy(new_data, mul_params->multiplier_exponent_perchannel(), + user_size * sizeof(int)); + std::memset(new_data + user_size, 0, + (required_capacity - user_size) * sizeof(int)); + mul_params->set_multiplier_exponent_perchannel(new_data); + } + } +}; + +template <typename AccumScalar, typename DstScalar> +struct EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false> { + static void Run(const TrMulParams& params, Allocator* allocator, + MulParams<AccumScalar, DstScalar>* mul_params) { + const Side channel_side = + mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs + : Side::kRhs; + const int required_capacity = + params.packed_matrix[channel_side].layout.cols; + const int user_size = params.src[channel_side].layout.cols; + const int user_capacity = round_up_pot( + user_size, mul_params->perchannel_buffers_capacity_rounding()); + // We should have already checked earlier for the case where + // user_capacity >= required_capacity. + RUY_DCHECK_GT(required_capacity, user_capacity); + if (mul_params->bias()) { + AccumScalar* new_data = + allocator->Allocate<AccumScalar>(required_capacity); + std::memcpy(new_data, mul_params->bias(), + user_size * sizeof(AccumScalar)); + std::memset(new_data + user_size, 0, + (required_capacity - user_size) * sizeof(AccumScalar)); + mul_params->set_bias(new_data); + } + } +}; + +template <typename AccumScalar, typename DstScalar> +void EnsurePerChannelBuffersLargeEnough( + const TrMulParams& params, Ctx* ctx, + MulParams<AccumScalar, DstScalar>* mul_params) { + // Early exit in the common case where the packed matrix size matches the + // number of channels (as opposed to having been rounded up to a slightly + // larger value). + const Side channel_side = + mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs + : Side::kRhs; + const int required_capacity = params.packed_matrix[channel_side].layout.cols; + const int user_size = params.src[channel_side].layout.cols; + const int user_capacity = round_up_pot( + user_size, mul_params->perchannel_buffers_capacity_rounding()); + AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized( + *mul_params, user_size, user_capacity); + if (required_capacity <= user_capacity) { + return; + } + ctx->set_performance_advisory( + PerformanceAdvisory::kReallocatedPerChannelBuffer); + EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar>::Run( + params, ctx->GetMainAllocator(), mul_params); +} + +// Ensures that `params->mul_params_bytes` contains MulParams data that's ready +// to be consumed by the kernel. 
As a first-order approximation, that is simply +// copying the user-provided `mul_params`, however there are a few changes. +// +// 1. The specified `channel_dimension` value overrides the channel_dimension +// member in `mul_params`. The reason why `channel_dimension` is being +// special-cased among MulParams members is that we will need to transpose +// MulParams, and that consists just in toggling channel_dimension. +// 2. Per-channel buffers may be reallocated, see +// EnsurePerChannelBuffersLargeEnough. +template <typename AccumScalar, typename DstScalar> +void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params, + ChannelDimension channel_dimension, Ctx* ctx, + TrMulParams* params) { + using MulParamsType = MulParams<AccumScalar, DstScalar>; + static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, ""); + static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, ""); + static_assert(std::is_trivially_copyable<MulParamsType>::value, ""); + auto* dst_mul_params = + reinterpret_cast<MulParamsType*>(params->mul_params_bytes); + std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType)); + dst_mul_params->set_channel_dimension(channel_dimension); + EnsurePerChannelBuffersLargeEnough(*params, ctx, dst_mul_params); +} + +// In this function, the `channel_dimension` parameter overrides the value +// of the channel_dimension member in the `mul_params` parameter. See the +// FinalizeMulParams comment. +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void CreateTrMulParamsAssumingColMajorDst( + const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, + const Mat<DstScalar>& dst, + const MulParams<AccumScalar, DstScalar>& mul_params, + ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) { + RUY_TRACE_SCOPE; + RUY_DCHECK(IsColMajor(dst.layout)); + + // Fill in the fields we already know. + params->src[Side::kLhs] = EraseType(lhs); + params->src[Side::kRhs] = EraseType(rhs); + params->dst = EraseType(dst); + + // Determine which exact Path we're going to take in this Mul call. + // This is cheap because it's cached in `ctx`. In user scenarios this always + // evaluates to the same value on a given machine with given `CompiledPaths`, + // but could be invalidated by a call to Ctx::SetRuntimeEnabledPaths(), which + // might be exposed publicly in Context in the future. + const Path the_path = ctx->SelectPath(CompiledPaths); + + RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST); + + // If we ever need again to fall back to Path::kStandardCpp, this is a good + // place to do it -- just pass Path::kStandardCpp as both the template and + // runtime parameters in this function call. + // In the past we did that here (as version control history remembers). + // A typical reason why we might need to resurrect that is if we implement + // a new Path (i.e. port to a new ISA) and need to subdivide that work into + // a series of incremental changes. + PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar, + AccumScalar, DstScalar>(the_path, params); + + // This must be done last, as it depends on the specific choice of kernel. + // Specifically, the EnsurePerChannelBuffersLargeEnough part of this will read + // the packed matrix layouts that are written to `params` by the above + // PopulateTrMulParams* call. 
+ FinalizeMulParams(mul_params, channel_dimension, ctx, params);
+}
+
+} // namespace detail
+
+inline ChannelDimension Transpose(ChannelDimension channel_dimension) {
+ return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow
+ : ChannelDimension::kCol;
+}
+
+// CreateTrMulParams's output is a TrMulParams object that encodes
+// all of the input information required by the middle-end, that is,
+// the TrMul function.
+//
+// CreateTrMulParams performs the following tasks:
+// 1. Reduce to the case of column-major destination, by transposing the
+// whole problem as needed.
+// 2. Select the single code path to be taken, out of the set of paths
+// described by the `CompiledPaths` template parameter, based on the
+// runtime input parameter `the_path`.
+// 3. Perform type-erasure, converting templatized typed input parameters
+// to the un-typed data stored in TrMulParams.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename AccumScalar, typename DstScalar>
+void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const Mat<DstScalar>& dst,
+ const MulParams<AccumScalar, DstScalar>& mul_params,
+ Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ ChannelDimension channel_dimension = mul_params.channel_dimension();
+ if (IsColMajor(dst.layout)) {
+ detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
+ lhs, rhs, dst, mul_params, channel_dimension, ctx, params);
+ } else {
+ RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_TRANSPOSING);
+ detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
+ rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), ctx,
+ params);
+ }
+}
+
+} // namespace ruy
+
+#endif // RUY_RUY_CREATE_TRMUL_PARAMS_H_
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
new file mode 100644
index 0000000..0ef098d
--- /dev/null
+++ b/ruy/ctx.cc
@@ -0,0 +1,216 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/ctx.h"
+
+#include <cstdlib>
+#include <functional>
+#include <string>
+
+#include "ruy/check_macros.h"
+#include "ruy/cpuinfo.h"
+#include "ruy/ctx_impl.h"
+#include "ruy/have_built_path_for.h"
+#include "ruy/path.h"
+#include "ruy/performance_advisory.h"
+#include "ruy/platform.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/trace.h"
+
+namespace ruy {
+
+const CtxImpl& Ctx::impl() const { return static_cast<const CtxImpl&>(*this); }
+CtxImpl* Ctx::mutable_impl() { return static_cast<CtxImpl*>(this); }
+
+Path Ctx::last_used_path() const { return impl().last_used_path_; }
+Tuning Ctx::explicit_tuning() const { return impl().explicit_tuning_; }
+void Ctx::set_explicit_tuning(Tuning value) {
+ mutable_impl()->explicit_tuning_ = value;
+}
+const ThreadPool& Ctx::thread_pool() const { return impl().thread_pool_; }
+ThreadPool* Ctx::mutable_thread_pool() { return &mutable_impl()->thread_pool_; }
+int Ctx::max_num_threads() const { return impl().max_num_threads_; }
+void Ctx::set_max_num_threads(int value) {
+ mutable_impl()->max_num_threads_ = value;
+}
+void Ctx::clear_performance_advisories() {
+ mutable_impl()->performance_advisory_ = PerformanceAdvisory::kNone;
+}
+void Ctx::set_performance_advisory(PerformanceAdvisory advisory) {
+ mutable_impl()->performance_advisory_ =
+ mutable_impl()->performance_advisory_ | advisory;
+}
+bool Ctx::performance_advisory(PerformanceAdvisory advisory) const {
+ return (impl().performance_advisory_ & advisory) !=
+ PerformanceAdvisory::kNone;
+}
+
+void Ctx::SetRuntimeEnabledPaths(Path paths) {
+ if (paths == Path::kNone) {
+ // Revert to default behavior using runtime detection.
+ mutable_impl()->runtime_enabled_paths_ = Path::kNone;
+ } else {
+ // Explicitly set enabled paths. Ensure that non-arch are always enabled
+ // (needed for fallbacks).
+ mutable_impl()->runtime_enabled_paths_ = paths | kNonArchPaths;
+ }
+}
+
+CpuInfo* Ctx::mutable_cpuinfo() { return &mutable_impl()->cpuinfo_; }
+
+namespace {
+
+int GetHexIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val, nullptr, 16);
+}
+
+// For each Path bit set in `paths_to_detect`, performs runtime detection and
+// sets the corresponding bit in the return value if and only if it is
+// supported. Path bits that are not set in the input
+// `paths_to_detect` value are also left not set in the return value.
+Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) {
+ // Paths in kNonArchPathsIncludingInternalVariants are always implicitly
+ // supported. Further logic below may add more bits to `result`.
+ Path result = kNonArchPathsIncludingInternalVariants;
+
+ // Conditionally sets the `path` bit in `result`, if reported as supported
+ // by the `is_supported` predicate.
+ auto maybe_add = [&](Path path, std::function<bool(void)> is_supported) {
+ if ((paths_to_detect & path) != Path::kNone) {
+ if (is_supported()) {
+ result = result | path;
+ }
+ }
+ };
+
+#if RUY_PLATFORM_ARM
+ // NEON is unconditionally available on ARM64.
+ // On ARM32 it's technically possible for it to be unavailable, but we've
+ // always chosen to just crash on such devices. We could reevaluate that,
+ // however for non-NEON devices to be actually supported, we would need to
+ // address also compiler-generated NEON code.
That would mean to remove + // -mfpu=neon from ruy_copts and only use this flag in select NEON translation + // units, and implement have_built_path_for_neon, similar to the x86 SIMD + // paths. + maybe_add(Path::kNeon, []() { return true; }); + + // NEON dotprod requires runtime detection, however unlike the x86 SIMD paths + // it still does not require have_built_path_for because we unconditionally + // build it at the moment. That is largely because we have had to machine + // encode dotprod instructions, so we don't actually rely on toolchain support + // for them. + maybe_add(Path::kNeonDotprod, [=]() { return cpuinfo->NeonDotprod(); }); +#elif RUY_PLATFORM_X86 + // x86 SIMD paths currently require both runtime detection, and detection of + // whether we're building the path at all. + maybe_add(Path::kAvx, + [=]() { return HaveBuiltPathForAvx() && cpuinfo->Avx(); }); + maybe_add(Path::kAvx2Fma, + [=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2Fma(); }); + maybe_add(Path::kAvx512, + [=]() { return HaveBuiltPathForAvx512() && cpuinfo->Avx512(); }); +#else + (void)maybe_add; + (void)cpuinfo; +#endif + + // Sanity checks + RUY_DCHECK_EQ(kNonArchPaths & ~result, Path::kNone); + RUY_DCHECK_EQ( + result & ~(kNonArchPathsIncludingInternalVariants | paths_to_detect), + Path::kNone); + return result; +} + +} // namespace + +Path Ctx::GetRuntimeEnabledPaths() { + RUY_TRACE_SCOPE; + // Just a shorthand alias. Using a pointer to make it clear we're mutating + // this value in-place. + Path* paths = &mutable_impl()->runtime_enabled_paths_; + + // The value Path::kNone indicates the initial state before detection has been + // performed. + if (*paths != Path::kNone) { + RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE); + return *paths; + } + // User may have set path explicitly in env var. + Path paths_bitfield = static_cast<Path>(GetHexIntEnvVarOrZero("RUY_PATHS")); + if (paths_bitfield != Path::kNone) { + *paths = paths_bitfield; + RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR); + return *paths; + } + // Finally, use runtime detection. 
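  // Hypothetical outcome (illustrative only): on an x86 machine with AVX2 but
  // not AVX-512, and with the corresponding paths built, the detection below
  // would return
  //   kNonArchPathsIncludingInternalVariants | Path::kAvx | Path::kAvx2Fma
  // and cache it in runtime_enabled_paths_, so subsequent calls take the
  // early return above.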
+ *paths = DetectRuntimeSupportedPaths(kAllPaths, mutable_cpuinfo()); + RUY_TRACE_INFO(GET_RUNTIME_ENABLED_PATHS_USING_DETECTION); + return *paths; +} + +Path Ctx::SelectPath(Path compiled_paths) { + return mutable_impl()->last_used_path_ = + GetMostSignificantPath(compiled_paths & GetRuntimeEnabledPaths()); +} + +void Ctx::EnsureThreadSpecificResources(int thread_count) { + auto& resources = mutable_impl()->thread_specific_resources_; + while (thread_count > static_cast<int>(resources.size())) { + resources.emplace_back(new ThreadSpecificResource); + } + RUY_DCHECK_LE(thread_count, static_cast<int>(resources.size())); +} + +TuningResolver* Ctx::GetThreadSpecificTuningResolver(int thread_index) const { + const auto& resources = impl().thread_specific_resources_; + RUY_DCHECK_LT(thread_index, static_cast<int>(resources.size())); + return &resources[thread_index]->tuning_resolver; +} + +Allocator* Ctx::GetThreadSpecificAllocator(int thread_index) const { + const auto& resources = impl().thread_specific_resources_; + RUY_DCHECK_LT(thread_index, static_cast<int>(resources.size())); + return &resources[thread_index]->allocator; +} + +Allocator* Ctx::GetMainAllocator() { + if (!impl().main_allocator_) { + mutable_impl()->main_allocator_.reset(new Allocator); + } + return impl().main_allocator_.get(); +} + +PrepackedCache* Ctx::GetPrepackedCache() { + if (!impl().prepacked_cache_) { + mutable_impl()->prepacked_cache_.reset(new PrepackedCache); + } + return impl().prepacked_cache_.get(); +} + +Tuning Ctx::GetMainThreadTuning() { + EnsureThreadSpecificResources(1); + TuningResolver* tuning_resolver = GetThreadSpecificTuningResolver(0); + tuning_resolver->SetTuning(explicit_tuning()); + return tuning_resolver->Resolve(mutable_cpuinfo()); +} + +void Ctx::ClearPrepackedCache() { mutable_impl()->prepacked_cache_ = nullptr; } + +} // namespace ruy diff --git a/ruy/ctx.h b/ruy/ctx.h new file mode 100644 index 0000000..df9dee2 --- /dev/null +++ b/ruy/ctx.h @@ -0,0 +1,91 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ctx is the internal context interface class used by most of ruy's own code. +// It is subclassed by CtxImpl which provides the actual data members. + +#ifndef RUY_RUY_CTX_H_ +#define RUY_RUY_CTX_H_ + +#include <cstdint> + +namespace ruy { + +class CtxImpl; +class ThreadPool; +class Allocator; +class TuningResolver; +class PrepackedCache; +class CpuInfo; +enum class Path : std::uint8_t; +enum class Tuning; +enum class PerformanceAdvisory; + +// Ctx is the internal context class used throughout ruy code. Whereas Context +// is exposed to users, Ctx is internal to ruy. As many of ruy's internal +// headers, included by ruy public headers, need to use Ctx, it is important +// that it does not include definition of all the actual data members. 
This is +// solved by a variant of the 'pimpl' idiom, where instead of being implemented +// in the usual way with a pointer member, it is implemented in a subclass, +// CtxImpl. +class Ctx /* not final, subclassed by CtxImpl */ { + public: + Path last_used_path() const; + Tuning explicit_tuning() const; + void set_explicit_tuning(Tuning value); + const ThreadPool& thread_pool() const; + ThreadPool* mutable_thread_pool(); + int max_num_threads() const; + void set_max_num_threads(int value); + CpuInfo* mutable_cpuinfo(); + void clear_performance_advisories(); + void set_performance_advisory(PerformanceAdvisory advisory); + bool performance_advisory(PerformanceAdvisory advisory) const; + + // Returns the set of Path's that are available. By default, this is based on + // runtime detection of CPU features, as well as on which code paths were + // built. Detection results are stored on the context object so that + // subsequent calls are fast. This is overridden by SetRuntimeEnabledPaths. + Path GetRuntimeEnabledPaths(); + + // Override auto-detection of supported code paths. + // + // Passing `paths == Path::kNone` means reverting to the default behavior. + // This will trigger auto-detection on the next use. + // + // Other values will override auto-detection with the explicitly provided set + // of paths. + // + // Paths in kNonArchPaths are always implicitly supported. + void SetRuntimeEnabledPaths(Path paths); + + Path SelectPath(Path compiled_paths); + void EnsureThreadSpecificResources(int thread_count); + TuningResolver* GetThreadSpecificTuningResolver(int thread_index) const; + Allocator* GetThreadSpecificAllocator(int thread_index) const; + Allocator* GetMainAllocator(); + PrepackedCache* GetPrepackedCache(); + Tuning GetMainThreadTuning(); + void ClearPrepackedCache(); + + private: + // Downcast helpers. + const CtxImpl& impl() const; + CtxImpl* mutable_impl(); +}; + +} // namespace ruy + +#endif // RUY_RUY_CTX_H_ diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h new file mode 100644 index 0000000..0a07ef6 --- /dev/null +++ b/ruy/ctx_impl.h @@ -0,0 +1,84 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Internal implementation details for Ctx. Drags in the entire world. Avoid +// #including this, use "ctx.h" instead. + +#ifndef RUY_RUY_CTX_IMPL_H_ +#define RUY_RUY_CTX_IMPL_H_ + +#include <cstddef> +#include <memory> +#include <vector> + +#include "ruy/allocator.h" +#include "ruy/cpuinfo.h" +#include "ruy/ctx.h" +#include "ruy/path.h" +#include "ruy/performance_advisory.h" +#include "ruy/prepacked_cache.h" +#include "ruy/thread_pool.h" +#include "ruy/tune.h" + +namespace ruy { + +// The resources private to each Ruy thread. +struct ThreadSpecificResource final { + // Each thread may be running on a different microarchitecture. For example, + // some threads may be on big cores, while others are on little cores. Thus, + // it's best for the tuning to be per-thread. 
+ TuningResolver tuning_resolver; + // Each thread has its own local allocator. + Allocator allocator; +}; + +// CtxImpl is what actually holds all the data members in a context. +// It is a subclass of Ctx, which provides the interface that is what most +// of ruy's code needs. +// +// A key requirement is that since many ruy files, including public headers, +// need a definition of Ctx, the "ctx.h" header defining it must minimize how +// many other ruy internal headers it includes. That is achieved by putting data +// members in the CtxImpl subclass, and ensuring that only a few .cc files, not +// header files, need a definition of CtxImpl. +class CtxImpl final : public Ctx { + private: + friend class Ctx; + + // Single Path bit indicating which Path was used last. + Path last_used_path_ = Path::kNone; + PerformanceAdvisory performance_advisory_ = PerformanceAdvisory::kNone; + Tuning explicit_tuning_ = Tuning::kAuto; + ThreadPool thread_pool_; + int max_num_threads_ = 1; + // Allocator for main thread work before invoking the threadpool. + // Our simple Allocator does not allow reserving/allocating more blocks + // while it's already in committed state, so the main thread needs both + // this allocator, and its per-thread allocator. + std::unique_ptr<Allocator> main_allocator_; + std::unique_ptr<PrepackedCache> prepacked_cache_; + // Set of Paths enabled at runtime. By default, that is based on runtime + // detection, but may be overridden. The initial value kNone + // means that detection has not yet been performed. + Path runtime_enabled_paths_ = Path::kNone; + CpuInfo cpuinfo_; + // State for each thread in the thread pool. Entry 0 is the main thread. + std::vector<std::unique_ptr<ThreadSpecificResource>> + thread_specific_resources_; +}; + +} // namespace ruy + +#endif // RUY_RUY_CTX_IMPL_H_ diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc new file mode 100644 index 0000000..e55dcfc --- /dev/null +++ b/ruy/ctx_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "ruy/ctx_impl.h"
+#include "ruy/gtest_wrapper.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+TEST(ContextInternalTest, EnabledPathsGeneral) {
+ CtxImpl ctx;
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ const auto ruy_paths_repeat = ctx.GetRuntimeEnabledPaths();
+ ASSERT_EQ(ruy_paths, ruy_paths_repeat);
+ EXPECT_NE(ruy_paths, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp);
+}
+
+#if RUY_PLATFORM_X86
+TEST(ContextInternalTest, EnabledPathsX86Explicit) {
+ CtxImpl ctx;
+ ctx.SetRuntimeEnabledPaths(Path::kAvx2Fma);
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2Fma);
+}
+#endif // RUY_PLATFORM_X86
+
+#if RUY_PLATFORM_ARM
TEST(ContextInternalTest, EnabledPathsArmExplicit) {
+ CtxImpl ctx;
+ ctx.SetRuntimeEnabledPaths(Path::kNeonDotprod);
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kNeonDotprod);
+}
+
+TEST(ContextInternalTest, EnabledPathsArmDefault) {
+ CtxImpl ctx;
+ const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp);
+ // NEON is always assumed to be supported at the moment.
+ EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon);
+}
+#endif // RUY_PLATFORM_ARM
+
+TEST(ContextInternalTest, ThreadSpecificResources) {
+ CtxImpl ctx;
+ for (int i = 1; i <= 4; i++) {
+ ctx.EnsureThreadSpecificResources(i);
+ for (int j = 0; j < i; j++) {
+ EXPECT_NE(ctx.GetThreadSpecificAllocator(j), nullptr);
+ EXPECT_NE(ctx.GetThreadSpecificTuningResolver(j), nullptr);
+ }
+ }
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/frontend.cc b/ruy/frontend.cc
new file mode 100644
index 0000000..01ee474
--- /dev/null
+++ b/ruy/frontend.cc
@@ -0,0 +1,37 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/frontend.h"
+
+#include "ruy/allocator.h"
+#include "ruy/prepare_packed_matrices.h"
+#include "ruy/trmul.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params) {
+ RUY_TRACE_SCOPE;
+ // Handle Matrix::cache_policy, possibly retrieving existing packed matrices
+ // or packing and caching now.
+ PreparePackedMatrices(ctx, params);
+
+ // We're done with the front-end work, now enter the middle-end.
+ TrMul(ctx, params);
+
+ ctx->GetMainAllocator()->FreeAll();
+}
+
+} // namespace ruy
diff --git a/ruy/frontend.h b/ruy/frontend.h
new file mode 100644
index 0000000..a79f590
--- /dev/null
+++ b/ruy/frontend.h
@@ -0,0 +1,99 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of MulFrontEnd, the front-end part of ruy. +// This is what the ruy::Mul entry point calls, and this ends in a call to +// TrMul, at which point we enter the middle-end. +// The front-end work includes parameter validation (Validate), detemplatization +// and resolution of the specific code path to take (CreateTrMulParams), and +// any additional logic best done upfront before entering the middle-end +// (e.g. HandlePrepackedCaching). +// The call to CreateTrMulParams is an important watershed in this code's +// structure: code before it needs to be templatized like the ruy::Mul entry +// point, code after it is un-templatized. + +#ifndef RUY_RUY_FRONTEND_H_ +#define RUY_RUY_FRONTEND_H_ + +#include "ruy/create_trmul_params.h" +#include "ruy/ctx.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/trace.h" +#include "ruy/trmul_params.h" +#include "ruy/validate.h" + +namespace ruy { + +// The first half of front-end work, up to the point where we have TrMulParams. +// In other words, this is the part of the front-end work that needs to be +// templatized like the entry point, and that performs the initial work that +// requires this templatization, and the de-templatization. The output of this +// function is the TrMulParams, which contain enough information to allow the +// un-templatized code to take over from there. +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void MulFrontEndUpToCreateTrMulParams( + const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, + const Mat<DstScalar>& dst, + const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx, + TrMulParams* params) { + RUY_TRACE_SCOPE; + static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); + static_assert( + (CompiledPaths & ~kAllPathsIncludingInternalVariants) == Path::kNone, + "CompiledPaths must be a subset of " + "ruy::kAllPathsIncludingInternalVariants"); + + // Perform validation of parameters early so that failures are easier to map + // to user errors. In particular, perform this validation before the + // transposition. + Validate(lhs, rhs, dst); + + // De-templatize this Mul call by creating a TrMulParams structure. + // This is also where the specific kernel and pack code paths corresponding to + // `the_path` are selected, among all the code paths in `CompiledPaths`, and + // recorded as function pointers in the TrMulParams. + // The Transpose(lhs) here is where we switch from 'Mul' to 'TrMul'. + CreateTrMulParams<CompiledPaths>(Transpose(lhs), rhs, dst, mul_params, ctx, + params); +} + +// The second part of the front-end work, starting from where we have freshly +// created TrMulParams, performing any remaining front-end work and entering the +// middle-end. 
+void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params); + +// Top-level function orchestrating the two halves of front-end work: +// before and after we have detemplatized the call by creating TrMulParams. +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void MulFrontEnd(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, + const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx, + Mat<DstScalar>* dst) { + RUY_TRACE_SCOPE; + profiler::ScopeLabel mul_label("Mul"); + profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", + lhs.layout.rows, lhs.layout.cols, + rhs.layout.cols); + ctx->clear_performance_advisories(); + TrMulParams params; + MulFrontEndUpToCreateTrMulParams<CompiledPaths>(lhs, rhs, *dst, mul_params, + ctx, ¶ms); + MulFrontEndFromTrMulParams(ctx, ¶ms); +} + +} // namespace ruy + +#endif // RUY_RUY_FRONTEND_H_ diff --git a/ruy/gtest_wrapper.h b/ruy/gtest_wrapper.h new file mode 100644 index 0000000..690cea4 --- /dev/null +++ b/ruy/gtest_wrapper.h @@ -0,0 +1,33 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wrapper around GTest that works around warnings and inconsistencies. + +#ifndef THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_ +#define THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "gtest/gtest.h" // IWYU pragma: export +#pragma GCC diagnostic pop + +// When building for WAsm, ASSERT_DEATH is not defined. +#ifdef ASSERT_DEATH +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) +#else +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) +#endif + +#endif // THIRD_PARTY_RUY_RUY_GTEST_WRAPPER_H_ diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h new file mode 100644 index 0000000..23cb028 --- /dev/null +++ b/ruy/have_built_path_for.h @@ -0,0 +1,31 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +#define RUY_RUY_HAVE_BUILT_PATH_FOR_H_ + +#include "ruy/platform.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 +bool HaveBuiltPathForAvx(); +bool HaveBuiltPathForAvx2Fma(); +bool HaveBuiltPathForAvx512(); +#endif // RUY_PLATFORM_X86 + +} // namespace ruy + +#endif // RUY_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/ruy/have_built_path_for_avx.cc b/ruy/have_built_path_for_avx.cc new file mode 100644 index 0000000..948c7a5 --- /dev/null +++ b/ruy/have_built_path_for_avx.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) + +bool HaveBuiltPathForAvx() { return false; } + +#else // RUY_PLATFORM_AVX && RUY_OPT(ASM) + +bool HaveBuiltPathForAvx() { return true; } + +#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM) +#endif // RUY_PLATFORM_X86 + +} // namespace ruy diff --git a/ruy/have_built_path_for_avx2_fma.cc b/ruy/have_built_path_for_avx2_fma.cc new file mode 100644 index 0000000..03e8f8d --- /dev/null +++ b/ruy/have_built_path_for_avx2_fma.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) + +bool HaveBuiltPathForAvx2Fma() { return false; } + +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) + +bool HaveBuiltPathForAvx2Fma() { return true; } + +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) +#endif // RUY_PLATFORM_X86 + +} // namespace ruy diff --git a/ruy/have_built_path_for_avx512.cc b/ruy/have_built_path_for_avx512.cc new file mode 100644 index 0000000..29d56a8 --- /dev/null +++ b/ruy/have_built_path_for_avx512.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) + +bool HaveBuiltPathForAvx512() { return false; } + +#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) + +bool HaveBuiltPathForAvx512() { return true; } + +#endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) +#endif // RUY_PLATFORM_X86 + +} // namespace ruy diff --git a/ruy/kernel.h b/ruy/kernel.h new file mode 100644 index 0000000..6bfeb4a --- /dev/null +++ b/ruy/kernel.h @@ -0,0 +1,245 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_KERNEL_H_ +#define RUY_RUY_KERNEL_H_ + +#include "ruy/kernel_common.h" +#include "ruy/mul_params.h" +#include "ruy/platform.h" +#include "ruy/trace.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM_NEON +#include "ruy/kernel_arm.h" +#elif RUY_PLATFORM_X86 +#include "ruy/kernel_x86.h" +#endif +// IWYU pragma: end_exports + +namespace ruy { + +// KernelArgs is a helper to access the template parameter values from a Kernel +// template instantiation. +template <typename KernelType> +struct KernelArgs {}; + +template <Path tPath, typename tLhsScalar, typename tRhsScalar, + typename tAccumScalar, typename tDstScalar> +struct KernelArgs< + Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> { + static constexpr Path kPath = tPath; + using LhsScalar = tLhsScalar; + using RhsScalar = tRhsScalar; + using AccumScalar = tAccumScalar; + using DstScalar = tDstScalar; +}; + +// RunKernel::Run() is the only place that directly invokes Kernel::Run(). +// It performs the types un-erasure, and factoring all Kernel::Run() calls +// through this function also gives a single place where to conditionally +// implement RUY_OPT(FAT_KERNEL). This should be a function but is a class to +// hide and share some boilerplate (see the member types, and the RunTyped +// method also using them). 
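// For instance (abbreviated sketch, not from the original source),
// PopulateTrMulParams records
//   params->run_kernel = &RunKernel<Kernel<ThePath, ...>>::Run;
// so that the un-templatized middle-end can later invoke the kernel through a
// plain function pointer taking type-erased PEMat/EMat arguments, as in the
// Run() method below.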
+template <typename KernelType>
+class RunKernel final {
+ public:
+  static void Run(Tuning tuning, const SidePair<PEMat>& src,
+                  const void* mul_params, const SidePair<int>& start,
+                  const SidePair<int>& end, EMat* dst) {
+    RUY_TRACE_SCOPE_NAME("RunKernel");
+    const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
+    const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
+    auto unerased_dst = UneraseType<DstScalar>(*dst);
+    RUY_TRACE_INFO(RUN_KERNEL);
+    RunTyped(tuning, unerased_lhs, unerased_rhs,
+             *static_cast<const MulParamsType*>(mul_params), start, end,
+             &unerased_dst);
+  }
+
+ private:
+  using Args = KernelArgs<KernelType>;
+  using LhsScalar = typename Args::LhsScalar;
+  using RhsScalar = typename Args::RhsScalar;
+  using AccumScalar = typename Args::AccumScalar;
+  using DstScalar = typename Args::DstScalar;
+  using MulParamsType = MulParams<AccumScalar, DstScalar>;
+  static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
+                       const PMat<RhsScalar>& rhs,
+                       const MulParamsType& mul_params,
+                       const SidePair<int>& start, const SidePair<int>& end,
+                       Mat<DstScalar>* dst) {
+    const int start_row = start[Side::kLhs];
+    const int start_col = start[Side::kRhs];
+    const int end_row = end[Side::kLhs];
+    const int end_col = end[Side::kRhs];
+    KernelType kernel(tuning);
+    using LhsLayout = typename KernelType::LhsLayout;
+    using RhsLayout = typename KernelType::RhsLayout;
+    // This is a good place to validate kernel layouts. The Kernel class
+    // template itself isn't a good place to do that because it has
+    // specializations.
+    // The kRows of both sides have to match: in TrMul, kRows is the depth
+    // dimension, on which LHS and RHS have to agree for the matrix
+    // multiplication to be defined at all, so requiring the corresponding
+    // dimension of the kernel layouts to also match is reasonable. If it didn't
+    // match, then the packed matrices could have mismatching depth dimensions
+    // even with the source matrices agreeing.
+    static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
+    // The kernel layout dimensions have to be powers of two. This simplifies
+    // BlockMap logic considerably. It also avoids leaking fine performance
+    // optimization details up the stack. For instance, if one of the dimensions
+    // were 6, then users might notice that optimal performance is achieved with
+    // matrix dimensions that are multiples of 6, and might start contorting
+    // their own application code to match that requirement, in a way that would
+    // not be future-proof.
+    static_assert(is_pot(LhsLayout::kRows), "");
+    static_assert(is_pot(LhsLayout::kCols), "");
+    static_assert(is_pot(RhsLayout::kRows), "");
+    static_assert(is_pot(RhsLayout::kCols), "");
+    // end_row and end_col may be larger than the dst dimensions. That is
+    // because kernels write directly to the destination matrix, whose
+    // dimensions may not be a multiple of the kernel dimensions, and we try to
+    // keep this annoyance localized as an implementation detail in kernels,
+    // by allowing rounded-up values to be passed down as far as possible.
+    // These assertions encode the contract.
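+    // For example (illustrative numbers only): with an 8x4 kernel block
+    // (LhsLayout::kCols == 8, RhsLayout::kCols == 4) and a 10x6 destination,
+    // the caller may pass end_row == 16 and end_col == 8; the checks below
+    // only require the overshoot to stay under one kernel block in each
+    // dimension, and the spans to be multiples of the block size.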
+    RUY_DCHECK_LE(0, start_row);
+    RUY_DCHECK_LE(start_row, end_row);
+    RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
+    RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
+    RUY_DCHECK_LE(0, start_col);
+    RUY_DCHECK_LE(start_col, end_col);
+    RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
+    RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
+#if RUY_OPT(FAT_KERNEL)
+    kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst);
+#else
+    for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
+      int block_end_col = std::min(col + RhsLayout::kCols, end_col);
+      for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
+        int block_end_row = std::min(row + LhsLayout::kCols, end_row);
+        kernel.Run(lhs, rhs, mul_params, row, col, block_end_row, block_end_col,
+                   dst);
+      }
+    }
+#endif
+  }
+};
+
+template <Path ThePath>
+struct StandardCppKernelLayout {};
+
+template <>
+struct StandardCppKernelLayout<Path::kStandardCpp> {
+  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
+  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
+};
+
+// A variant exercising RowMajor square blocks.
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
+  using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
+  using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
+};
+
+// A variant with a rectangular layout: 4x8.
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
+  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
+  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
+};
+
+// A variant with different block orders in LHS vs RHS.
+template <>
+struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
+  using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
+  using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
+};
+
+// General implementation of the Kernel template, overridden by template
+// specializations for specific SIMD code paths. This general implementation
+// covers Path::kStandardCpp and its internal test-only variants.
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename AccumScalar, typename DstScalar>
+struct Kernel {
+  // Each Kernel specialization defines kPath as the ground-truth path that it
+  // implements. This is used in assertions. As we support fallbacks between
+  // paths (see RUY_INHERIT_KERNEL), unless a specialization for a specific set
+  // of template parameters was defined, it is normal for template
+  // instantiations of the form Kernel<SomePath, ...> to have kPath!=SomePath.
+  // Assertions that kPath==SomePath are used in places where we know that we
+  // should be using a template specialization for a specific path rather than
+  // a fallback.
+  static constexpr Path kPath = ThePath;
+  using MulParamsType = MulParams<AccumScalar, DstScalar>;
+  using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
+  using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
+  explicit Kernel(Tuning) {}
+  void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
+           const MulParamsType& mul_params, int start_row, int start_col,
+           int end_row, int end_col, Mat<DstScalar>* dst) const {
+    // See the comment in RunKernel::RunTyped() above. end_row may be larger
+    // than dst->layout.rows. It's the responsibility of the kernel to avoid
+    // overrunning dst boundaries, which we do here by computing
+    // clamped_end_row.
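+    // The zero-point adjustments in the accumulation loop below implement the
+    // expansion
+    //   sum_k (lhs[k][i] - lhs_zp) * (rhs[k][j] - rhs_zp)
+    //     = sum_k lhs[k][i] * rhs[k][j]
+    //       - lhs_zp * rhs.sums[j] - rhs_zp * lhs.sums[i]
+    //       + depth * lhs_zp * rhs_zp,
+    // where lhs.sums and rhs.sums are sums taken over the depth dimension.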
+ int clamped_end_row = std::min(end_row, dst->layout.rows); + int clamped_end_col = std::min(end_col, dst->layout.cols); + RUY_DCHECK_LE(0, start_row); + RUY_DCHECK_LE(start_row, clamped_end_row); + RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); + RUY_DCHECK_LE(clamped_end_row, end_row); + RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); + RUY_DCHECK_LE(0, start_col); + RUY_DCHECK_LE(start_col, clamped_end_col); + RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); + RUY_DCHECK_LE(clamped_end_col, end_col); + RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); + profiler::ScopeLabel label("Kernel (Standard Cpp)"); + const int depth = lhs.layout.rows; + for (int i = start_row; i < clamped_end_row; i++) { + for (int j = start_col; j < clamped_end_col; j++) { + AccumScalar accum = 0; + for (int k = 0; k < depth; k++) { + AccumScalar lhs_val = Element(lhs, k, i); + AccumScalar rhs_val = Element(rhs, k, j); + accum += lhs_val * rhs_val; + } + int channel = + mul_params.channel_dimension() == ChannelDimension::kRow ? i : j; + if (mul_params.bias()) { + accum += mul_params.bias()[channel]; + } + if (lhs.zero_point) { + accum -= lhs.zero_point * rhs.sums[j]; + } + if (rhs.zero_point) { + accum -= rhs.zero_point * lhs.sums[i]; + } + if (lhs.zero_point && rhs.zero_point) { + accum += lhs.zero_point * rhs.zero_point * depth; + } + ApplyMultiplier(mul_params, channel, &accum); + accum += dst->zero_point; + accum = std::min<AccumScalar>(accum, mul_params.clamp_max()); + accum = std::max<AccumScalar>(accum, mul_params.clamp_min()); + *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum); + } + } + } +}; + +} // namespace ruy + +#endif // RUY_RUY_KERNEL_H_ diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h new file mode 100644 index 0000000..76cfc82 --- /dev/null +++ b/ruy/kernel_arm.h @@ -0,0 +1,212 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_KERNEL_ARM_H_ +#define RUY_RUY_KERNEL_ARM_H_ + +#include <cstddef> +#include <cstdint> + +#include "ruy/asm_helpers.h" +#include "ruy/kernel_common.h" +#include "ruy/mat.h" +#include "ruy/mul_params.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM_NEON && RUY_OPT(ASM) + +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) + +#if RUY_PLATFORM_NEON_64 +void Kernel8bitNeon(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params); +#elif RUY_PLATFORM_NEON_32 +void Kernel8bitNeon(const KernelParams8bit<4, 2>& params); +void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params); +#endif +void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params); + +#if RUY_PLATFORM_NEON_64 +template <typename DstScalar> +struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kNeon; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitNeon1Col(params); + return; + } + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + Kernel8bitNeonA55ish(params); + } else { + Kernel8bitNeon(params); + } + } +}; +#endif + +#if RUY_PLATFORM_NEON_32 +template <typename DstScalar> +struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kNeon; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitNeon1Col(params); + return; + } + Kernel8bitNeon(params); + } +}; +#endif + +#if RUY_PLATFORM_NEON_64 +template <typename DstScalar> +struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kNeonDotprod; + Tuning tuning = Tuning::kAuto; + using LhsLayout = 
FixedKernelLayout<Order::kColMajor, 4, 8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitNeonDotprod1Col(params); + } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + Kernel8bitNeonDotprodA55ish(params); + } else { + Kernel8bitNeonDotprod(params); + } + } +}; +#endif + +void KernelFloatNeon(const KernelParamsFloat<8, 8>& params); +void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params); +void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params); +void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params); + +#if RUY_PLATFORM_NEON_64 +// A Float kernel for ARM64 Neon. +template <> +struct Kernel<Path::kNeon, float, float, float, float> { + static constexpr Path kPath = Path::kNeon; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + KernelFloatNeonA55ish(params); + } else { + KernelFloatNeon(params); + } + } +}; +#endif + +#if RUY_PLATFORM_NEON_32 +// A Float kernel for ARM32 Neon. +template <> +struct Kernel<Path::kNeon, float, float, float, float> { + static constexpr Path kPath = Path::kNeon; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<8, 4> params; + + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + + KernelFloat32Neon(params); + } +}; +#endif + +// While the dotprod NEON extension does not concern floating-point arithmetic, +// its presence allows us to distinguish, in the in-order tuning case, between +// A53 and A55r1. TODO: should this be folded into tuning? 
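+// Concretely, the specialization below differs from the ARM64 kNeon float
+// kernel above only in which tuned code path it selects: when tuning resolves
+// to an in-order core (Tuning::kA55ish) it calls KernelFloatNeonDotprodA55ish
+// rather than KernelFloatNeonA55ish, and otherwise it falls back to the same
+// KernelFloatNeon.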
+template <> +struct Kernel<Path::kNeonDotprod, float, float, float, float> { + static constexpr Path kPath = Path::kNeonDotprod; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using Base = + Kernel<Path::kNeon, float, float, float, float>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + KernelFloatNeonDotprodA55ish(params); + } else { + KernelFloatNeon(params); + } + } +}; + +#endif // RUY_PLATFORM_NEON && RUY_OPT(ASM) + +} // namespace ruy + +#endif // RUY_RUY_KERNEL_ARM_H_ diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc new file mode 100644 index 0000000..b20f668 --- /dev/null +++ b/ruy/kernel_arm32.cc @@ -0,0 +1,2515 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/kernel_arm.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) + +#define RUY_ASM_LABEL_STORE_UINT8 91 +#define RUY_ASM_LABEL_STORE_INT8 92 +#define RUY_ASM_LABEL_STORE_INT16 93 +#define RUY_ASM_LABEL_STORE_INT32 94 +#define RUY_ASM_LABEL_AFTER_STORE 99 + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 4 +#define RUY_OFFSET_DST_BASE_PTR 8 +#define RUY_OFFSET_BIAS 12 +#define RUY_OFFSET_START_ROW 16 +#define RUY_OFFSET_START_COL 20 +#define RUY_OFFSET_LAST_ROW 24 +#define RUY_OFFSET_LAST_COL 28 +#define RUY_OFFSET_DST_ROWS 32 +#define RUY_OFFSET_DST_COLS 36 +#define RUY_OFFSET_LHS_STRIDE 40 +#define RUY_OFFSET_RHS_STRIDE 44 +#define RUY_OFFSET_DST_STRIDE 48 +#define RUY_OFFSET_DEPTH 52 +#define RUY_OFFSET_CLAMP_MIN 56 +#define RUY_OFFSET_CLAMP_MAX 60 +#define RUY_OFFSET_FLAGS 64 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template <typename Params> +void CheckOffsetsInKernelParamsFloat32(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == 
RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Float kernel for ARM32 out-of-order cores. +// Just like Float 64 version, except accumulate in to 8x4 block to only +// use 16 128-bit NEON registers. This is a "first pass" kernel and not +// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. +void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) { + CheckOffsetsInKernelParamsFloat32(params); + profiler::ScopeLabel label("Kernel (kNeon)"); + + const float* lhs_ptr = params.lhs_base_ptr; + const float* rhs_ptr = params.rhs_base_ptr; + // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are + // each composed of two 64-bit "d" registers. The asm kernel below has the + // following NEON register allocation: + // Registers q3 -- q10 are accumulators. During accumulation, + // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1 + // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block + // of RHS, like this: + + // Register layout in "q" registers: + // RHS 1x4 block + // /--------------------------| + // |q2.s[0] ... q2.s[3] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------------| + // | q0.s[0] | | q3.s[0] ... q9.s[0] | + // | ... | | ... ... | + // | q0.s[3] | | q3.s[3] q9.s[3] | + // | q1.s[0] | | q4.s[0] q10.s[0] | + // | ... | | ... ... ... | + // | q1.s[3] | | q4.s[3] .. q10.s[3] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + // q11, q14, q15 currently unused. q12 and q13 are used to load + // parameters used for the post-accumulation part of the kernel. + // For completeness, here is the register layout in "d" registers: + // RHS 1x4 block + // /--------------------------| + // |d4[0] ... d5[1] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------------| + // | d0[0] | | d6[0] ... d18[0] | + // | ... | | ... ... | + // | d1[1] | | d7[1] d19[1] | + // | d2[0] | | d8[0] d20[0] | + // | ... | | ... ... ... | + // | d3[1] | | d9[1] ... d21[1] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" + + // clang-format off + + // Load the first 32 bytes of LHS and RHS data. 
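+        // (That is: 8 LHS floats into q0-q1 and 4 RHS floats into q2, i.e. one
+        // level of depth of this 8x4 kernel.)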
+ // Load q0, q1 + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + // Clear accumulators. + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Accumulation loop + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r2\n" + "beq 79f\n" + + "2:\n" + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "add r1, r1, #1\n" + "cmp r1, r2\n" + + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + + // End of accumulation. The registers q3 -- q10 contain the final + // float32 accumulator values of the current 8x8 destination block. + // We now have to compute the final values from these accumulators + // and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. 
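+        // (The shifted adds below scale a stride register: "r1, lsl #3" adds
+        // 8 * lhs_stride, covering the 8 destination rows just computed, and
+        // "r1, lsl #2" adds 4 * rhs_stride, covering the 4 destination
+        // columns, matching this kernel's 8x4 shape.)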
+ + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #3\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #2\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Load some parameters needed for the end work on current block. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Let r8 be stack offset of the row or column variable, whichever + // is the channel index. + "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "ite eq\n" + "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" + "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" + // Let r8 be the channel index. + "ldr r8, [sp, r8]\n" + // Compute the bias pointer, by conditionally using the channel index + // (r8) as offset into bias buffer (r1). + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "addne r1, r1, r8, lsl #2\n" + + // Load 4 bias values. When the channel dimension is rows, we will load + // another 4 bias values just before performing the bias addition below, + // as this kernel has a 8x4 rectangular shape. + "vld1.32 {d24, d25}, [r1]!\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore + // in the rest of the work on the current block. + // Load q0, q1 + "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Perform the bias-addition. + // Jump based on channel dimension. + "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows. 
+ // Load the remaining 4 bias values, since we're on the width-8 side + // of this 8x4 kernel. + "vld1.32 {d26, d27}, [r1]\n" + "vadd.f32 q3, q3, q12\n" + "vadd.f32 q5, q5, q12\n" + "vadd.f32 q7, q7, q12\n" + "vadd.f32 q9, q9, q12\n" + "vadd.f32 q4, q4, q13\n" + "vadd.f32 q6, q6, q13\n" + "vadd.f32 q8, q8, q13\n" + "vadd.f32 q10, q10, q13\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns. + "vdup.32 q11, d24[0]\n" + "vdup.32 q13, d24[1]\n" + "vdup.32 q14, d25[0]\n" + "vdup.32 q15, d25[1]\n" + "vadd.f32 q3, q3, q11\n" + "vadd.f32 q4, q4, q11\n" + "vadd.f32 q5, q5, q13\n" + "vadd.f32 q6, q6, q13\n" + "vadd.f32 q7, q7, q14\n" + "vadd.f32 q8, q8, q14\n" + "vadd.f32 q9, q9, q15\n" + "vadd.f32 q10, q10, q15\n" + "7:\n" + + // Load the clamp_min, clamp_max bounds + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.32 q12, r2\n" // clamp_min + "vdup.32 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.f32 q3, q3, q12\n" + "vmax.f32 q4, q4, q12\n" + "vmax.f32 q5, q5, q12\n" + "vmax.f32 q6, q6, q12\n" + "vmax.f32 q7, q7, q12\n" + "vmax.f32 q8, q8, q12\n" + "vmax.f32 q9, q9, q12\n" + "vmax.f32 q10, q10, q12\n" + + // Apply the clamp_max bound + "vmin.f32 q3, q3, q13\n" + "vmin.f32 q4, q4, q13\n" + "vmin.f32 q5, q5, q13\n" + "vmin.f32 q6, q6, q13\n" + "vmin.f32 q7, q7, q13\n" + "vmin.f32 q8, q8, q13\n" + "vmin.f32 q9, q9, q13\n" + "vmin.f32 q10, q10, q13\n" + + // Compute how much of the 8x4 block of destination values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x4, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #8\n" + "mov r5, #4\n" + "cmp r1, #8\n" + // Compute r1 = how many rows of the 8x4 block fit + "it gt\n" + "movgt r1, r3\n" + "cmp r2, #4\n" + // Compute r2 = how many cols of the 8x4 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 8x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write directly to destination matrix. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + "31:\n" + + // Write our float values to the destination described by + // (r3 address, r4 stride) + "vst1.32 {d6, d7, d8, d9}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + "vst1.32 {d10, d11, d12, d13}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + "vst1.32 {d14, d15, d16, d17}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + "vst1.32 {d18, d19, d20, d21}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // If all of the 8x4 block fits, we just finished writing it to the + // destination, so we skip the next part. 
+ "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #32\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #32\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used r3, r5, r10 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #8\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns) + "add r1, r1, r8, lsl #2\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. 
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13"); +} + +#undef RUY_MAKE_ZERO +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_BIAS 0 +#define RUY_OFFSET_LHS_SUMS 4 +#define RUY_OFFSET_RHS_SUMS 8 +#define RUY_OFFSET_LHS_BASE_PTR 12 +#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 +#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 +#define RUY_OFFSET_RHS_BASE_PTR 24 +#define RUY_OFFSET_DST_BASE_PTR 28 +#define RUY_OFFSET_LHS_ZERO_POINT 32 +#define RUY_OFFSET_RHS_ZERO_POINT 36 +#define RUY_OFFSET_DST_ZERO_POINT 40 +#define RUY_OFFSET_PROD_ZP_DEPTH 44 +#define RUY_OFFSET_START_ROW 48 +#define RUY_OFFSET_START_COL 52 +#define RUY_OFFSET_LAST_ROW 56 +#define RUY_OFFSET_LAST_COL 60 +#define RUY_OFFSET_DST_ROWS 64 +#define RUY_OFFSET_DST_COLS 68 +#define RUY_OFFSET_LHS_STRIDE 72 +#define RUY_OFFSET_RHS_STRIDE 76 +#define RUY_OFFSET_DST_STRIDE 80 +#define RUY_OFFSET_DEPTH 84 +#define RUY_OFFSET_CLAMP_MIN 88 +#define RUY_OFFSET_CLAMP_MAX 92 +#define RUY_OFFSET_FLAGS 96 +#define RUY_OFFSET_DST_TYPE_ID 97 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template <typename Params> +void CheckOffsetsInKernelParams8bit(const Params&) { + static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, + ""); + static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, + ""); + static_assert(offsetof(Params, multiplier_fixedpoint) == + RUY_OFFSET_MULTIPLIER_FIXEDPOINT, + ""); + static_assert( + offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, + ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); + static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == 
RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); +} + +// Fast-int8 kernel, ported from ARM 64 version. +// Relevant target CPUs for this kernel include Krait 400 and A9, +// since these are 32-bit, out-of-order CPUs. +void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { + profiler::ScopeLabel label("Kernel (kNeon)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + + // The asm kernel below has the following NEON register allocation: + // + // q6 - q13 are 128-bit (4x32b) accumulators. + // During accumulation, d0 -- d7 are used to load int8 data from LHS and + // d8 -- d11 from RHS: + // int8 RHS 16x2 block + // /-----------------------------| + // |d8.b[0-7] ..... d10.b[0-7]| + // | ... ... | + // |d9.b[0-7] ..... d11.b[0-7]| + // \-----------------------------/ + // int8 LHS 4x16 block + // /------------------------\ /-----------------------------| + // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | + // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | + // (Reload d0, d1, d2, d3) + // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | + // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | + // \------------------------/ \-----------------------------/ + // 128-bit accumulators 4x2 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" + + // clang-format off + + // Load the first 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + RUY_MAKE_ZERO(q11) + "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q12) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q13) + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(q14) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + RUY_MAKE_ZERO(q15) + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. 
+ "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + // Prefetch the next 64 bytes of LHS and RHS data. + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. + "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + + // All accumulation over depth done. q6 - q13 contain the 4x32b + // accumulators for the 4x2 final matrix. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q13 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + "vpadd.i32 d4, d20, d21\n" + "vpadd.i32 d5, d22, d23\n" + "vpadd.i32 d6, d24, d25\n" + "vpadd.i32 d7, d26, d27\n" + + // d0-d7 each contain 2 x 32b accumulators. 
+ // Need to add pairwise to get 1 x 32b for each of the 4x2 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + "vpadd.i32 d30, d4, d5\n" + "vpadd.i32 d31, d6, d7\n" + + //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Let r8 be stack offset of the row or column variable, whichever + // is the channel index. + "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "ite eq\n" + "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" + "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" + // Let r8 be the channel index. + "ldr r8, [sp, r8]\n" + // Compute the bias pointer, by conditionally using the channel index + // (r8) as offset into bias buffer (r1). + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "addne r1, r1, r8, lsl #2\n" + + // Load 2 bias values. 
When the channel dimension is rows, we will load + // another 2 bias values just before performing the bias addition below, + // as this kernel has a 4x2 rectangular shape. + "vld1.32 {d24}, [r1]!\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 d24, d24, d18\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. + "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows. + // Load the remaining 2 bias values, since we're on the width-4 side + // of this 4x2 kernel. + "vld1.32 {d25}, [r1]\n" + "vadd.i32 d25, d25, d19\n" + "vadd.i32 q14, q14, q12\n" + "vadd.i32 q15, q15, q12\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns. + "vdup.32 q10, d24[0]\n" + "vdup.32 q11, d24[1]\n" + "vadd.i32 q14, q14, q10\n" + "vadd.i32 q15, q15, q11\n" + "7:\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "vmls.i32 q15, q10, d12[1]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + "vsub.s32 q15, q15, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. 
+ + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + // Compute the data pointers for the multiplier data + // r1 = exponent part + // r2 = fixedpoint part + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r8 has channel index + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "addne r1, r1, r8, lsl #2\n" + "it ne\n" + "addne r2, r2, r8, lsl #2\n" + + // Load the first 2 values of multiplier exponent and fixedpoint data + // Since this kernel is rectangular 4x2, we will only conditionally load + // 2 more values below. + "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent + "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint + + "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "vmvn.i32 q8, #0\n" + "bne 8f\n" + // Case where channels are rows. + // Load the remaining 2 bias values, since we're on the width-4 side + // of this 4x2 kernel. + "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent + "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint + "vmin.s32 q11, q10, q8\n" + "vsub.s32 q10, q10, q11\n" + + // Apply the positive exponent part of the multiplier. + "vshl.s32 q14, q14, q10\n" + "vshl.s32 q15, q15, q10\n" + + // Apply the fixed-point part of the multiplier. + "vqdmulh.s32 q14, q14, q6\n" + "vqdmulh.s32 q15, q15, q6\n" + + // Apply the negative exponent part of the multiplier. + "vrshl.s32 q14, q14, q11\n" + "vrshl.s32 q15, q15, q11\n" + "b 9f\n" + + "8:\n" + // Case where channels are columns. + "vmin.s32 d22, d20, d16\n" + "vsub.s32 d20, d20, d22\n" + + // Apply the positive exponent part of the multiplier. + "vdup.32 q12, d20[0]\n" + "vdup.32 q13, d20[1]\n" + "vshl.s32 q14, q14, q12\n" + "vshl.s32 q15, q15, q13\n" + + // Apply the fixed-point part of the multiplier. + "vqdmulh.s32 q14, q14, d12[0]\n" + "vqdmulh.s32 q15, q15, d12[1]\n" + + // Apply the negative exponent part of the multiplier. + "vdup.32 q12, d22[0]\n" + "vdup.32 q13, d22[1]\n" + "vrshl.s32 q14, q14, q12\n" + "vrshl.s32 q15, q15, q13\n" + + "9:\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + // Now all 8 1-byte values are in d30. 
+ "vqmovun.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + // Now all 8 1-byte values are in d30. 
+ "vqmovn.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + "vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 q12, r2\n" // clamp_min + "vdup.16 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 q14, q14, q12\n" + // Apply the clamp_max bound + "vmin.s16 q14, q14, q13\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {q14}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #8\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "add r4, r4, r5\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "mov r3, r4\n" + "vst1.16 {d29[0]}, [r3], r6\n" + "vst1.16 {d29[1]}, [r3], r6\n" + "vst1.16 {d29[2]}, [r3], r6\n" + "vst1.16 {d29[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. 
+ RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x2 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + "add r3, r3, r4\n" + "vst1.32 {d30, d31}, [r3]\n" + + // If all of the 4x2 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x2 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #16\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. 
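+        // (The row index advances by 4 because each iteration of the main
+        // loop produces a 4x2 block of the destination matrix.)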
+        "add r8, r8, #4\n"
+        // Store new value of row
+        "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+        "b 21f\n"
+        "20:\n"
+        // Was already at end row.
+        // Move back to first row.
+        "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+        // Move to the next column.
+        "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+        "add r4, r4, #2\n"
+        "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+        "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+        // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
+        "add r1, r1, r8, lsl #1\n"
+        // Store dst_col_ptr
+        "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+        // Store dst_ptr
+        "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+        "21:\n"
+
+        // Main loop exit condition: have we hit the end column?
+        "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+        "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+        "cmp r8, r4\n"
+
+        // r1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 16.
+        "mov r1, #16\n"
+
+        "ble 1b\n"
+
+        // Restore stack pointer.
+        "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+        // clang-format on
+
+        : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+        : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+        : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+          // Clobber list must specify q registers (and not their constituent
+          // d registers). There is a (currently unexplained) slowdown if
+          // d registers are listed in the clobbers list.
+          "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+          "q9", "q10", "q12", "q13", "q14", "q15");
+}
+
+// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
+// is still packed as if it has two columns.
+void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
+  profiler::ScopeLabel label("Kernel (kNeon)");
+
+  CheckOffsetsInKernelParams8bit(params);
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* lhs_ptr = lhs_col_ptr;
+  const std::int8_t* rhs_ptr = rhs_col_ptr;
+
+  RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
+
+  // The asm kernel below has the following NEON register allocation:
+  //
+  // q6 -- q9 are 128-bit (4x32b) accumulators.
+  // During accumulation, d0 -- d7 are used to load int8 data from LHS and
+  // d8 -- d9 from RHS:
+  //                                      int8 RHS 16x1 block
+  //                                           /------------|
+  //                                           | d8.b[0]    |
+  //                                           | ...        |
+  //                                           | d8.b[7]    |
+  //                                           | d9.b[0]    |
+  //                                           | ...        |
+  //                                           | d9.b[7]    |
+  //                                           \------------/
+  //    int8 LHS 4x16 block
+  //  /------------------------------------------\  /------------|
+  //  |d0.b[0] ... d0.b[7]   d1.b[0] ... d1.b[7]  |  | q6         |
+  //  |d2.b[0] ... d2.b[7]   d3.b[0] ... d3.b[7]  |  | q7         |
+  //  |d4.b[0] ... d4.b[7]   d5.b[0] ... d5.b[7]  |  | q8         |
+  //  |d6.b[0] ... d6.b[7]   d7.b[0] ... d7.b[7]  |  | q9         |
+  //  \------------------------------------------/  \------------/
+  //                              128-bit accumulators 4x1 block
+  //
+  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+  // optimization for this kernel.
+  asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
+
+        // clang-format off
+
+        // Load the first 64 bytes of LHS data and 16 bytes of RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. + "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + + // Load the next 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. + "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + // All accumulation over depth done. q6 - q9 contain the 4x32b + // accumulators for the 4x1 final matrix. 
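+        // As a rough scalar sketch (not the code ruy runs), the accumulation
+        // above amounts to, for each destination row i of this 4x1 block:
+        //
+        //   std::int32_t acc = 0;
+        //   for (int k = 0; k < depth; ++k) {
+        //     acc += std::int32_t(lhs(i, k)) * std::int32_t(rhs(k));
+        //   }
+        //
+        // where lhs(i, k) and rhs(k) are hypothetical accessors into the
+        // packed int8 data. The NEON version gets there by summing two
+        // int8*int8 products into each 16-bit lane (vmull.s8 + vmlal.s8),
+        // letting vpadal.s16 widen and accumulate those lanes into the
+        // 32-bit accumulators, and finally collapsing the lanes with the
+        // vpadd reductions below.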
+ // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q9 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + + // d0-d4 each contain 2 x 32b accumulators. + // Need to add pairwise to get 1 x 32b for each of the 4x1 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + + // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. 
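+        // (When the HAS_BIAS flag is set, the bias pointer is advanced by
+        // row * sizeof(std::int32_t); roughly, with hypothetical scalar
+        // names: bias_ptr = params.bias + (has_bias ? row : 0).)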
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "addne r1, r1, r8, lsl #2\n" + + // Load 4 bias values. + "vld1.32 {d24, d25}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 q12, q12, q9\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "vadd.i32 q14, q14, q12\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. 
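+        // A rough scalar sketch of the down-quantization applied below
+        // (hypothetical helper names, gemmlowp-style fixed-point semantics):
+        //
+        //   acc = acc << left_shift;                                 // vshl
+        //   acc = saturating_doubling_high_mul(acc, mult_fixedpoint); // vqdmulh
+        //   acc = rounding_right_shift(acc, right_shift);            // vrshl
+        //
+        // The vmin/vsub pair below splits multiplier_exponent into the
+        // non-negative left-shift amount and the right shift (vrshl is given
+        // a negative shift count, which shifts right with rounding).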
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "it ne\n" + "addne r1, r1, r4, lsl #2\n" + + "vld1.32 {q10}, [r1]\n" + + "vmvn.i32 q8, #0\n" + "vmin.s32 q13, q10, q8\n" + "vsub.s32 q12, q10, q13\n" + + // Apply the positive exponent part of the multiplier. + "vshl.s32 q14, q14, q12\n" + + // Load fixed point part of the multiplier + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r4 has row + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "addne r1, r1, r4, lsl #2\n" + "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "vqdmulh.s32 q14, q14, q10\n" + + // Apply the negative exponent part of the multiplier. + "vrshl.s32 q14, q14, q13\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + "vqmovun.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. 
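+        // In the partial-block case, the whole vector result is first written
+        // to the dst_tmp_buf scratch buffer; the scalar loop below then
+        // copies only the r1 rows that actually fit into the destination.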
+ // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + "vqmovn.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4 i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. 
+ "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + //"vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 d24, r2\n" // clamp_min + "vdup.16 d26, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 d28, d28, d24\n" + // Apply the clamp_max bound + "vmin.s16 d28, d28, d26\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {d28}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. 
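+        // On this fast path the four int16 results are stored lane by lane;
+        // the post-increment register (r6 = 2 bytes) keeps them contiguous
+        // down the destination column.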
+ // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x1 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + + // If all of the 4x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. 
+        "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+        // Move to the next block of the destination matrix, for the next iter
+        // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
+        // been updated earlier.
+        // Have we reached the end row?
+        "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+        "cmp r8, r3\n"
+
+        "beq 20f\n"  // yes, end row.
+        // Not end row. Move to the next row.
+        "add r8, r8, #4\n"
+        // Store new value of row
+        "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+        "b 21f\n"
+        "20:\n"
+        // Was already at end row.
+        // Move back to first row.
+        "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+        // Move to the next column.
+        "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+        "add r4, r4, #2\n"
+        "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+        "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+        // Increment dst_col_ptr by dst_stride (i.e. 1 column)
+        "add r1, r1, r8\n"
+        // Store dst_col_ptr
+        "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+        // Store dst_ptr
+        "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+        "21:\n"
+
+        // Main loop exit condition: have we hit the end column?
+        "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+        "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+        "cmp r8, r4\n"
+
+        // r1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 16.
+        "mov r1, #16\n"
+
+        "ble 1b\n"
+
+        // Restore stack pointer.
+        "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+        // clang-format on
+
+        : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+        : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+        : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+          // Clobber list must specify q registers (and not their constituent
+          // d registers). There is a (currently unexplained) slowdown if
+          // d registers are listed in the clobbers list.
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13", "q14", "q15"); +} + +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_DST_TYPE_ID + +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#endif // RUY_PLATFORM_NEON_32 && (RUY_OPT(ASM) +} // namespace ruy diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc new file mode 100644 index 0000000..fe65d9c --- /dev/null +++ b/ruy/kernel_arm64.cc @@ -0,0 +1,8075 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <cstdint> + +#include "ruy/asm_helpers.h" +#include "ruy/check_macros.h" +#include "ruy/kernel_arm.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +#define RUY_ASM_LABEL_STORE_UINT8 91 +#define RUY_ASM_LABEL_STORE_INT8 92 +#define RUY_ASM_LABEL_STORE_INT16 93 +#define RUY_ASM_LABEL_STORE_INT32 94 +#define RUY_ASM_LABEL_AFTER_STORE 99 + +#define RUY_OFFSET_BIAS 0 +#define RUY_OFFSET_LHS_SUMS 8 +#define RUY_OFFSET_RHS_SUMS 16 +#define RUY_OFFSET_LHS_BASE_PTR 24 +#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32 +#define RUY_OFFSET_MULTIPLIER_EXPONENT 40 +#define RUY_OFFSET_RHS_BASE_PTR 48 +#define RUY_OFFSET_DST_BASE_PTR 56 +#define RUY_OFFSET_LHS_ZERO_POINT 64 +#define RUY_OFFSET_RHS_ZERO_POINT 68 +#define RUY_OFFSET_DST_ZERO_POINT 72 +#define RUY_OFFSET_PROD_ZP_DEPTH 76 +#define RUY_OFFSET_START_ROW 80 +#define RUY_OFFSET_START_COL 84 +#define RUY_OFFSET_LAST_ROW 88 +#define RUY_OFFSET_LAST_COL 92 +#define RUY_OFFSET_DST_ROWS 96 +#define RUY_OFFSET_DST_COLS 100 +#define RUY_OFFSET_LHS_STRIDE 104 +#define RUY_OFFSET_RHS_STRIDE 108 +#define RUY_OFFSET_DST_STRIDE 112 +#define RUY_OFFSET_DEPTH 116 +#define RUY_OFFSET_CLAMP_MIN 120 +#define RUY_OFFSET_CLAMP_MAX 124 +#define RUY_OFFSET_FLAGS 128 + +template <typename Params> +void CheckOffsetsInKernelParams8bit(const Params&) { + static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, + ""); + static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, + ""); + static_assert(offsetof(Params, multiplier_fixedpoint) == + RUY_OFFSET_MULTIPLIER_FIXEDPOINT, + ""); + static_assert( + offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, + ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); + static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); +} + +// Fast-int8-trick kernel, similar to this production gemmlowp kernel: +// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296 +// +// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, +// since these are 64-bit, out-of-order and without dotprod support. 
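+// In scalar terms, the 4x4 destination block handled by one iteration of the
+// main loop below amounts to the following rough reference sketch (not code
+// that ruy runs; lhs(r, k) and rhs(k, c) stand for hypothetical accessors
+// into the packed int8 data):
+//
+//   for (int i = 0; i < 4; ++i) {
+//     for (int j = 0; j < 4; ++j) {
+//       std::int32_t acc = 0;
+//       for (int k = 0; k < depth; ++k) {
+//         acc += std::int32_t(lhs(row + i, k)) * std::int32_t(rhs(k, col + j));
+//       }
+//       // The store paths then add the bias, apply the zero-point
+//       // corrections of equation (7) in
+//       // https://arxiv.org/pdf/1712.05877.pdf, apply the fixed-point
+//       // multiplier, clamp, and narrow to the destination type.
+//     }
+//   }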
+void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label("Kernel (kNeon)"); + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 -- v7 from RHS: + // + // int8 RHS 16x4 block + // /-----------------------------------------| + // |v4.b[0] ... v7.b[0] | + // | ... ... | + // |v4.b[15] ... v7.b[15] | + // \-----------------------------------------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------------------------------------| + // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. 
This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v24.4s, v8.8h\n" + "smull v8.8h, v0.8b, v4.8b\n" + "sadalp v25.4s, v9.8h\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v4.8b\n" + "sadalp v26.4s, v10.8h\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + + // Each iteration of this loop advances by 16 levels of depth. 
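+        // (16 levels per iteration because the eight 16-byte ld1 loads above
+        // bring in 16 consecutive depth values for each of the 4 LHS rows and
+        // each of the 4 RHS columns, i.e. 128 bytes of int8 data in total.)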
+ "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "sadalp v25.4s, v9.8h\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v28.4s, v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Load the multiplier_fixedpoint values. + "add x5, x4, x3, lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. 
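+        // When channels are rows (the default), the same 4-element bias
+        // vector is added to each of the four column accumulators below.
+        // When RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL is set, each column
+        // accumulator instead receives its own bias value, broadcast across
+        // all four rows (the '6:' branch).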
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v14.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v14.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v20.4s, v14.s[0]\n" + "dup v21.4s, v14.s[1]\n" + "dup v22.4s, v14.s[2]\n" + "dup v23.4s, v14.s[3]\n" + "add v16.4s, v16.4s, v20.4s\n" + "add v17.4s, v17.4s, v21.4s\n" + "add v18.4s, v18.4s, v22.4s\n" + "add v19.4s, v19.4s, v23.4s\n" + "7:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, x3, lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smin v11.4s, v8.4s, v14.4s\n" + "sub v12.4s, v14.4s, v11.4s\n" + + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 8f\n" + // Case where channels are rows + + // Apply the positive exponent part of the multiplier. + "sshl v16.4s, v16.4s, v12.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v15.4s\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + "sqdmulh v18.4s, v18.4s, v15.4s\n" + "sqdmulh v19.4s, v19.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. 
+ "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v11.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v11.4s\n" + "b 9f\n" + + "8:\n" + // Case where channels are columns + + // Apply the positive exponent part of the multiplier. + "dup v20.4s, v12.s[0]\n" + "dup v21.4s, v12.s[1]\n" + "dup v22.4s, v12.s[2]\n" + "dup v23.4s, v12.s[3]\n" + "sshl v16.4s, v16.4s, v20.4s\n" + "sshl v17.4s, v17.4s, v21.4s\n" + "sshl v18.4s, v18.4s, v22.4s\n" + "sshl v19.4s, v19.4s, v23.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v15.s[0]\n" + "sqdmulh v17.4s, v17.4s, v15.s[1]\n" + "sqdmulh v18.4s, v18.4s, v15.s[2]\n" + "sqdmulh v19.4s, v19.4s, v15.s[3]\n" + + // Apply the negative exponent part of the multiplier. + "dup v20.4s, v11.s[0]\n" + "dup v21.4s, v11.s[1]\n" + "dup v22.4s, v11.s[2]\n" + "dup v23.4s, v11.s[3]\n" + "srshl v16.4s, v16.4s, v20.4s\n" + "srshl v17.4s, v17.4s, v21.4s\n" + "srshl v18.4s, v18.4s, v22.4s\n" + "srshl v19.4s, v19.4s, v23.4s\n" + "9:\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. 
+ // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. 
+ "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Similar to existing Kernel8bitNeon but specialized for the case of +// RHS cols == 1. +// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, +// since these are 64-bit, out-of-order and without dotprod support. +void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label("Kernel (kNeon)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v19 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 from RHS: + // + // int8 RHS 16x1 block + // /-----------| + // |v4.b[0] | + // | ... | + // |v4.b[15] | + // \-----------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------| + // |v0.b[0] ... v0.b[15] | |v16.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s | + // \---------------------/ \-----------/ + // int32 accumulators 4x1 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. 
+ "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Each iteration of this loop advances by 16 levels of depth. + "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + // End of accumulation. The registers v16 -- v19 contain the final + // int32 accumulator values of the current 4x1 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x1 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. 
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + // (still multiply column stride by 4 due to packing) + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
+ // (all four 32-bit accumulators are in v16 at this point) + "add v16.4s, v16.4s, v14.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smin v11.4s, v8.4s, v14.4s\n" + "sub v12.4s, v14.4s, v11.4s\n" + + // Apply the positive exponent part of the multiplier. + "sshl v16.4s, v16.4s, v12.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v11.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half (64-bits) of v16 + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + // Now all data is in the first 32-bits of v16 + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in the lower half (64 bits) of v16. + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + // At this point, we only need 4 lowest 8-bit values in v16. + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half of v16. + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. 
Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19"); +} + +// Variant of the above Kernel8bitNeon, tuned for A55-ish CPUs. +// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and +// the original Cortex-A55, since these are 64-bit and do not support dotprod. +// +// While this kernel does not have a direct equivalent in gemmlowp, it was +// developed based on insights that David Mansell at ARM shared with their +// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A53: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 +void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 -- v7 from RHS: + // + // int8 RHS 16x4 block + // /-----------------------------------------| + // |v4.b[0] ... v7.b[0] | + // | ... ... | + // |v4.b[15] ... v7.b[15] | + // \-----------------------------------------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------------------------------------| + // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + RUY_MAKE_ZERO(v23) + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v24) + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v25) + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v26) + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v27) + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v28) + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v29) + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v30) + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v31) + + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "smull v8.8h, v0.8b, v6.8b\n" + "ldr x7, [%[rhs_ptr], #8]\n" + "sadalp v17.4s, v9.8h\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "smull v9.8h, v1.8b, v6.8b\n" + "ldr x8, [%[rhs_ptr], #24]\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "smull v11.8h, v3.8b, v6.8b\n" + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "sadalp v20.4s, v12.8h\n" + // Each iteration of this loop advances by 16 levels of depth. 
+ "add w1, w1, #16\n" + "smull v12.8h, v0.8b, v7.8b\n" + // Loop termination condition + "cmp w1, w12\n" + "sadalp v21.4s, v13.8h\n" + "ldr x3, [%[lhs_ptr], #-56]\n" + "smull v13.8h, v1.8b, v7.8b\n" + "ldr x4, [%[lhs_ptr], #-40]\n" + "sadalp v22.4s, v14.8h\n" + "ldr x5, [%[lhs_ptr], #-24]\n" + "smull v14.8h, v2.8b, v7.8b\n" + "ldr x6, [%[lhs_ptr], #-8]\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "ldr x9, [%[rhs_ptr], #-24]\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + "ldr d6, [%[rhs_ptr], #-32]\n" + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ldr d0, [%[lhs_ptr], #-64]\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ldr d1, [%[lhs_ptr], #-48]\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ins v4.d[1], x7\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ins v5.d[1], x8\n" + + "ldr d2, [%[lhs_ptr], #-32]\n" + "ins v0.d[1], x3\n" + "sadalp v24.4s, v8.8h\n" + "ldr d3, [%[lhs_ptr], #-16]\n" + "ins v1.d[1], x4\n" + "smull v8.8h, v0.8b, v4.8b\n" + "ins v2.d[1], x5\n" + "sadalp v25.4s, v9.8h\n" + "ins v3.d[1], x6\n" + "smull v9.8h, v1.8b, v4.8b\n" + "ldr d7, [%[rhs_ptr], #-16]\n" + "sadalp v26.4s, v10.8h\n" + "ldr x10, [%[rhs_ptr], #-8]\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "ins v6.d[1], x9\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "ins v7.d[1], x10\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. 
+ "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "sadalp v25.4s, v9.8h\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "sadalp v26.4s, v10.8h\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "sadalp v27.4s, v11.8h\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "sadalp v28.4s, v12.8h\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "sadalp v29.4s, v13.8h\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. 
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Load the multiplier_fixedpoint values. + "add x5, x4, x3, lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "ldr d0, [%[lhs_ptr], #0]\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. 
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + + "add v16.4s, v16.4s, v14.4s\n" + "ldr d1, [%[lhs_ptr], #16]\n" + "add v17.4s, v17.4s, v14.4s\n" + "ldr d2, [%[lhs_ptr], #32]\n" + "add v18.4s, v18.4s, v14.4s\n" + "ldr d3, [%[lhs_ptr], #48]\n" + "add v19.4s, v19.4s, v14.4s\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "ldr d6, [%[rhs_ptr], #32]\n" + "ldr d7, [%[rhs_ptr], #48]\n" + + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v20.4s, v14.s[0]\n" + "ldr d1, [%[lhs_ptr], #16]\n" + "dup v21.4s, v14.s[1]\n" + "ldr d2, [%[lhs_ptr], #32]\n" + "dup v22.4s, v14.s[2]\n" + "ldr d3, [%[lhs_ptr], #48]\n" + "dup v23.4s, v14.s[3]\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "add v16.4s, v16.4s, v20.4s\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "add v17.4s, v17.4s, v21.4s\n" + "ldr d6, [%[rhs_ptr], #32]\n" + "add v18.4s, v18.4s, v22.4s\n" + "ldr d7, [%[rhs_ptr], #48]\n" + "add v19.4s, v19.4s, v23.4s\n" + "7:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, x3, lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smin v11.4s, v8.4s, v14.4s\n" + "ldr x1, [%[lhs_ptr], #8]\n" + "sub v12.4s, v14.4s, v11.4s\n" + + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 8f\n" + // Case where channels are rows + + + // Apply the positive exponent part of the multiplier. 
+ "sshl v16.4s, v16.4s, v12.4s\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "ldr x4, [%[lhs_ptr], #56]\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + + // Apply the fixed-point part of the multiplier. + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "sqdmulh v16.4s, v16.4s, v15.4s\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "sqdmulh v18.4s, v18.4s, v15.4s\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "sqdmulh v19.4s, v19.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v11.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v11.4s\n" + + "b 9f\n" + + "8:\n" + // Case where channels are columns + + // Apply the positive exponent part of the multiplier. + "dup v20.4s, v12.s[0]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "dup v21.4s, v12.s[1]\n" + "ldr x4, [%[lhs_ptr], #56]\n" + "dup v22.4s, v12.s[2]\n" + "ins v0.d[1], x1\n" + "dup v23.4s, v12.s[3]\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "sshl v16.4s, v16.4s, v20.4s\n" + "ins v1.d[1], x2\n" + "sshl v17.4s, v17.4s, v21.4s\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "sshl v18.4s, v18.4s, v22.4s\n" + "ins v2.d[1], x3\n" + "sshl v19.4s, v19.4s, v23.4s\n" + "ldr x3, [%[rhs_ptr], #40]\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v15.s[0]\n" + "ins v3.d[1], x4\n" + "sqdmulh v17.4s, v17.4s, v15.s[1]\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "sqdmulh v18.4s, v18.4s, v15.s[2]\n" + "dup v20.4s, v11.s[0]\n" + "sqdmulh v19.4s, v19.4s, v15.s[3]\n" + + // Apply the negative exponent part of the multiplier. + "dup v21.4s, v11.s[1]\n" + "srshl v16.4s, v16.4s, v20.4s\n" + "dup v22.4s, v11.s[2]\n" + "srshl v17.4s, v17.4s, v21.4s\n" + "dup v23.4s, v11.s[3]\n" + "srshl v18.4s, v18.4s, v22.4s\n" + "srshl v19.4s, v19.4s, v23.4s\n" + + "9:\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtun2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtn2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.8h, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.8h, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ldr x1, [%[lhs_ptr], #8]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "ldr x4, [%[lhs_ptr], #56]\n" + + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "ins v4.d[1], x1\n" + "ins v5.d[1], x2\n" + "ins v6.d[1], x3\n" + "ins v7.d[1], x4\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + + RUY_MAKE_ZERO(v20) + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "smull v11.8h, v3.8b, v4.8b\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "smull v12.8h, v0.8b, v5.8b\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. 
Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on +
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +
+// Kernel taking advantage of the optional dotprod instruction. +// This is very similar to (and directly inspired by) this gemmlowp kernel +// which was contributed by David Mansell at ARM: +// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 +//
+// Besides the ruy-ification, the main difference here is that we use an 8x8 +// instead of 12x8 width, so as to stick to power-of-two widths. This slightly +// narrower kernel layout is still wide enough to achieve high performance +// although we haven't actually performed a real comparison to know exactly +// how this compares to ARM's aforementioned kernel. +//
+// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeonDotprod)"); + + CheckOffsetsInKernelParams8bit(params); +
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; +
+ // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + //
+ // int8 RHS 4x8 block + // /-----------------------------------------| + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block
+ // /---------------------\ /-----------------------------------------| + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
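As a rough scalar model of the elementary accumulation step in the diagram above (an illustrative aside, not part of the change being merged; the function and array names below are invented for the sketch), each sdot instruction adds a 4-deep dot product into each int32 lane of one accumulator column:

#include <cstdint>

// Scalar sketch of one "sdot vd.4s, vn.16b, vm.4b[col]" step, using the
// layout sketched above:
//   lhs_block[4 * row + d] = LHS(row, d)   -- one 4x4 int8 tile, e.g. v0
//   rhs_block[4 * col + d] = RHS(d, col)   -- one 4x4 int8 tile, e.g. v2
//   acc_col[row]           = one int32 accumulator lane, e.g. v16.s[row]
void SdotColumnStep(const std::int8_t lhs_block[16],
                    const std::int8_t rhs_block[16], int col,
                    std::int32_t acc_col[4]) {
  for (int row = 0; row < 4; ++row) {
    std::int32_t sum = 0;
    for (int d = 0; d < 4; ++d) {
      sum += static_cast<std::int32_t>(lhs_block[4 * row + d]) *
             static_cast<std::int32_t>(rhs_block[4 * col + d]);
    }
    acc_col[row] += sum;  // each sdot consumes 4 levels of depth at once
  }
}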
+ // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Optional, maximally-streaming, partial-unrolling (4x unrolled) + // optimization of the kernel inner loop (over depth). For more + // comments, see the non-unrolled loop below after the #endif. 
+#if RUY_OPT(MAX_STREAMING) + "cmp w12, #32\n" + "blt 78f\n" + + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" + "mov w1, #16\n" + + "and w3, w12, #-16\n" + "81:\n" + "add w1, w1, #16\n" + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr q0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr q2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr q1, [%[lhs_ptr], #16]\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + "ldr q3, [%[rhs_ptr], #16]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + "ldr q5, [%[lhs_ptr], #48]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + "ldr q7, [%[rhs_ptr], #48]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + "ldr q4, [%[lhs_ptr], #32]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + "ldr q6, [%[rhs_ptr], #32]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + "ldr q9, [%[lhs_ptr], #80]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + "ldr q11, [%[rhs_ptr], #80]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + "ldr q8, [%[lhs_ptr], #64]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, v12.16b, 
v15.4b[1]\n" + "ldr q10, [%[rhs_ptr], #64]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + "add %[lhs_ptr], %[lhs_ptr], #128\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" + ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + "cmp w1, w3\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + "blt 81b\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" + 
".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + + "78:\n" + +#endif // #if RUY_OPT(MAX_STREAMING) + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + // Loop termination condition. + "cmp w1, w12\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. 
We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. 
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v10.4s, v14.s[0]\n" + "dup v11.4s, v14.s[1]\n" + "dup v12.4s, v14.s[2]\n" + "dup v13.4s, v14.s[3]\n" + "add v16.4s, v16.4s, v10.4s\n" + "add v17.4s, v17.4s, v10.4s\n" + "add v18.4s, v18.4s, v11.4s\n" + "add v19.4s, v19.4s, v11.4s\n" + "add v20.4s, v20.4s, v12.4s\n" + "add v21.4s, v21.4s, v12.4s\n" + "add v22.4s, v22.4s, v13.4s\n" + "add v23.4s, v23.4s, v13.4s\n" + "dup v10.4s, v15.s[0]\n" + "dup v11.4s, v15.s[1]\n" + "dup v12.4s, v15.s[2]\n" + "dup v13.4s, v15.s[3]\n" + "add v24.4s, v24.4s, v10.4s\n" + "add v25.4s, v25.4s, v10.4s\n" + "add v26.4s, v26.4s, v11.4s\n" + "add v27.4s, v27.4s, v11.4s\n" + "add v28.4s, v28.4s, v12.4s\n" + "add v29.4s, v29.4s, v12.4s\n" + "add v30.4s, v30.4s, v13.4s\n" + "add v31.4s, v31.4s, v13.4s\n" + "7:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + // Compute the multiplier_exponent pointer + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, x3, lsl #2\n" + "csel x1, x1, x5, eq\n" + // Load multiplier_exponent + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + // Separate positive and negative exponents + "smin v11.4s, v8.4s, v9.4s\n" + "smin v12.4s, v8.4s, v10.4s\n" + "sub v9.4s, v9.4s, v11.4s\n" + "sub v10.4s, v10.4s, v12.4s\n" + + // Compute the multiplier_fixedpoint pointer + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "add x5, x4, x3, lsl #2\n" + "csel x4, x4, x5, eq\n" + // Load multiplier_fixedpoint + "ldr q14, [x4]\n" + "ldr q15, [x4, #16]\n" + + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 8f\n" + // Case where channels are rows + + // Apply the positive exponent part of the multiplier. + "sshl v16.4s, v16.4s, v9.4s\n" + "sshl v17.4s, v17.4s, v10.4s\n" + "sshl v18.4s, v18.4s, v9.4s\n" + "sshl v19.4s, v19.4s, v10.4s\n" + "sshl v20.4s, v20.4s, v9.4s\n" + "sshl v21.4s, v21.4s, v10.4s\n" + "sshl v22.4s, v22.4s, v9.4s\n" + "sshl v23.4s, v23.4s, v10.4s\n" + "sshl v24.4s, v24.4s, v9.4s\n" + "sshl v25.4s, v25.4s, v10.4s\n" + "sshl v26.4s, v26.4s, v9.4s\n" + "sshl v27.4s, v27.4s, v10.4s\n" + "sshl v28.4s, v28.4s, v9.4s\n" + "sshl v29.4s, v29.4s, v10.4s\n" + "sshl v30.4s, v30.4s, v9.4s\n" + "sshl v31.4s, v31.4s, v10.4s\n" + "10:\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v14.4s\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + "sqdmulh v18.4s, v18.4s, v14.4s\n" + "sqdmulh v19.4s, v19.4s, v15.4s\n" + "sqdmulh v20.4s, v20.4s, v14.4s\n" + "sqdmulh v21.4s, v21.4s, v15.4s\n" + "sqdmulh v22.4s, v22.4s, v14.4s\n" + "sqdmulh v23.4s, v23.4s, v15.4s\n" + "sqdmulh v24.4s, v24.4s, v14.4s\n" + "sqdmulh v25.4s, v25.4s, v15.4s\n" + "sqdmulh v26.4s, v26.4s, v14.4s\n" + "sqdmulh v27.4s, v27.4s, v15.4s\n" + "sqdmulh v28.4s, v28.4s, v14.4s\n" + "sqdmulh v29.4s, v29.4s, v15.4s\n" + "sqdmulh v30.4s, v30.4s, v14.4s\n" + "sqdmulh v31.4s, v31.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. 
+ "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "srshl v31.4s, v31.4s, v12.4s\n" + "b 9f\n" + + "8:\n" + // Case where channels are columns + + // Apply the positive exponent part of the multiplier. + "dup v4.4s, v9.s[0]\n" + "dup v5.4s, v9.s[1]\n" + "dup v6.4s, v9.s[2]\n" + "dup v7.4s, v9.s[3]\n" + "sshl v16.4s, v16.4s, v4.4s\n" + "sshl v17.4s, v17.4s, v4.4s\n" + "sshl v18.4s, v18.4s, v5.4s\n" + "sshl v19.4s, v19.4s, v5.4s\n" + "sshl v20.4s, v20.4s, v6.4s\n" + "sshl v21.4s, v21.4s, v6.4s\n" + "sshl v22.4s, v22.4s, v7.4s\n" + "sshl v23.4s, v23.4s, v7.4s\n" + "dup v4.4s, v10.s[0]\n" + "dup v5.4s, v10.s[1]\n" + "dup v6.4s, v10.s[2]\n" + "dup v7.4s, v10.s[3]\n" + "sshl v24.4s, v24.4s, v4.4s\n" + "sshl v25.4s, v25.4s, v4.4s\n" + "sshl v26.4s, v26.4s, v5.4s\n" + "sshl v27.4s, v27.4s, v5.4s\n" + "sshl v28.4s, v28.4s, v6.4s\n" + "sshl v29.4s, v29.4s, v6.4s\n" + "sshl v30.4s, v30.4s, v7.4s\n" + "sshl v31.4s, v31.4s, v7.4s\n" + "11:\n" + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v14.s[0]\n" + "sqdmulh v17.4s, v17.4s, v14.s[0]\n" + "sqdmulh v18.4s, v18.4s, v14.s[1]\n" + "sqdmulh v19.4s, v19.4s, v14.s[1]\n" + "sqdmulh v20.4s, v20.4s, v14.s[2]\n" + "sqdmulh v21.4s, v21.4s, v14.s[2]\n" + "sqdmulh v22.4s, v22.4s, v14.s[3]\n" + "sqdmulh v23.4s, v23.4s, v14.s[3]\n" + "sqdmulh v24.4s, v24.4s, v15.s[0]\n" + "sqdmulh v25.4s, v25.4s, v15.s[0]\n" + "sqdmulh v26.4s, v26.4s, v15.s[1]\n" + "sqdmulh v27.4s, v27.4s, v15.s[1]\n" + "sqdmulh v28.4s, v28.4s, v15.s[2]\n" + "sqdmulh v29.4s, v29.4s, v15.s[2]\n" + "sqdmulh v30.4s, v30.4s, v15.s[3]\n" + "sqdmulh v31.4s, v31.4s, v15.s[3]\n" + + // Apply the negative exponent part of the multiplier. 
+ "dup v4.4s, v11.s[0]\n" + "dup v5.4s, v11.s[1]\n" + "dup v6.4s, v11.s[2]\n" + "dup v7.4s, v11.s[3]\n" + "srshl v16.4s, v16.4s, v4.4s\n" + "srshl v17.4s, v17.4s, v4.4s\n" + "srshl v18.4s, v18.4s, v5.4s\n" + "srshl v19.4s, v19.4s, v5.4s\n" + "srshl v20.4s, v20.4s, v6.4s\n" + "srshl v21.4s, v21.4s, v6.4s\n" + "srshl v22.4s, v22.4s, v7.4s\n" + "srshl v23.4s, v23.4s, v7.4s\n" + "dup v4.4s, v12.s[0]\n" + "dup v5.4s, v12.s[1]\n" + "dup v6.4s, v12.s[2]\n" + "dup v7.4s, v12.s[3]\n" + "srshl v24.4s, v24.4s, v4.4s\n" + "srshl v25.4s, v25.4s, v4.4s\n" + "srshl v26.4s, v26.4s, v5.4s\n" + "srshl v27.4s, v27.4s, v5.4s\n" + "srshl v28.4s, v28.4s, v6.4s\n" + "srshl v29.4s, v29.4s, v6.4s\n" + "srshl v30.4s, v30.4s, v7.4s\n" + "srshl v31.4s, v31.4s, v7.4s\n" + "9:\n" + + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "umax v17.16b, v17.16b, v14.16b\n" + "umax v18.16b, v18.16b, v14.16b\n" + "umax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "umin v17.16b, v17.16b, v15.16b\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. 
+ "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
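For readers following the end-work above, here is a minimal scalar sketch of the requantization that the bias and zero-point corrections, the sshl/sqdmulh/srshl sequence, and the final clamping implement for one accumulator. This is an illustrative sketch only, not code from this change: the parameter and function names are invented, and the rounding details differ slightly from the exact instruction sequence used in the kernel.

#include <algorithm>
#include <cstdint>
#include <limits>

// Truncating doubling high multiply, roughly what sqdmulh computes per lane:
// (2 * a * b) >> 32, saturating the single overflow case.
inline std::int32_t SaturatingDoublingHighMul(std::int32_t a, std::int32_t b) {
  if (a == std::numeric_limits<std::int32_t>::min() && a == b) {
    return std::numeric_limits<std::int32_t>::max();
  }
  return static_cast<std::int32_t>(
      (static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b)) >> 31);
}

// Rounding right shift, roughly what srshl does with a negative shift amount.
inline std::int32_t RoundingRightShift(std::int32_t x, int amount) {
  if (amount <= 0) return x;
  const std::int64_t rounding = std::int64_t{1} << (amount - 1);
  return static_cast<std::int32_t>((x + rounding) >> amount);
}

// End-work for one int32 accumulator of the block, following equation (7) in
// https://arxiv.org/pdf/1712.05877.pdf and the multiplier scheme used above.
std::uint8_t RequantizeToUint8(std::int32_t acc, std::int32_t bias,
                               std::int32_t prod_zp_depth,  // depth * lhs_zp * rhs_zp
                               std::int32_t lhs_sum, std::int32_t rhs_sum,
                               std::int32_t lhs_zero_point,
                               std::int32_t rhs_zero_point,
                               std::int32_t multiplier_fixedpoint,
                               int multiplier_exponent,
                               std::int32_t dst_zero_point,
                               std::uint8_t clamp_min, std::uint8_t clamp_max) {
  // Zero-point corrections, folded into the bias addition as in the kernel.
  acc += bias + prod_zp_depth;
  acc -= rhs_sum * lhs_zero_point;
  acc -= lhs_sum * rhs_zero_point;
  // Fixed-point multiplier: a positive exponent acts as a left shift before
  // the high multiply, a negative exponent as a rounding right shift after it.
  const int left_shift = std::max(multiplier_exponent, 0);
  const int right_shift = std::max(-multiplier_exponent, 0);
  const std::int32_t shifted =
      static_cast<std::int32_t>(static_cast<std::int64_t>(acc) << left_shift);
  acc = SaturatingDoublingHighMul(shifted, multiplier_fixedpoint);
  acc = RoundingRightShift(acc, right_shift);
  // Destination zero point, clamping, and narrowing to 8 bits.
  acc += dst_zero_point;
  acc = std::min<std::int32_t>(std::max<std::int32_t>(acc, clamp_min), clamp_max);
  return static_cast<std::uint8_t>(acc);
}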
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "smax v17.16b, v17.16b, v14.16b\n" + "smax v18.16b, v18.16b, v14.16b\n" + "smax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "smin v17.16b, v17.16b, v15.16b\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. 
+ "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.4s, v19.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v20.4s, v21.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v22.4s, v23.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v24.4s, v25.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v26.4s, v27.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v28.4s, v29.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v30.4s, v31.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Similar to the above 8-bit dotprod kernel, but specialized for the case of +// RHS cols == 1. +// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeonDotprod)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x1 block + // /-------| + // |v2.b[0]| + // | ... | + // |v2.b[3]| + // \-------/ + // int8 LHS 8x4 block + // /---------------------\ /--------| + // |v0.b[0] ... v0.b[3] | |v16.s[0]| + // | ... ... | | ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0]| + // | ... ... | | ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3]| + // \---------------------/ \--------/ + // int32 accumulators 8x1 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + // Loop termination condition. + "cmp w1, w12\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. 
If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
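The zero-point corrections applied here all come from equation (7) of the paper cited in the comments: expanding sum_d (lhs - z_lhs)(rhs - z_rhs) gives the raw int8xint8 accumulator plus three correction terms. A scalar sketch for one output value, with hypothetical variable names:

#include <cstdint>

std::int32_t CorrectAccumulator(std::int32_t acc_raw, std::int32_t bias,
                                std::int32_t depth, std::int32_t lhs_zero_point,
                                std::int32_t rhs_zero_point,
                                std::int32_t lhs_sum, std::int32_t rhs_sum) {
  std::int32_t acc = acc_raw + bias;
  acc += depth * lhs_zero_point * rhs_zero_point;  // prod_zp_depth, folded into the bias
  acc -= lhs_zero_point * rhs_sum;                 // the "mls" step (RUY_ASM_FLAG_HAS_RHS_SUMS)
  acc -= rhs_zero_point * lhs_sum;                 // the "mul"/"sub" step (RUY_ASM_FLAG_HAS_LHS_SUMS)
  return acc;
}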
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "smin v11.4s, v8.4s, v9.4s\n" + "smin v12.4s, v8.4s, v10.4s\n" + "sub v9.4s, v9.4s, v11.4s\n" + "sub v10.4s, v10.4s, v12.4s\n" + + // Apply the positive exponent part of the multiplier. + "sshl v16.4s, v16.4s, v9.4s\n" + "sshl v17.4s, v17.4s, v10.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "sqdmulh v16.4s, v16.4s, v14.4s\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + // All data in v16 at this point. + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8, leaving all data in the + // lower half of v16. + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + + // Compute how much of the 8x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x1 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x1 block fits. 
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8b}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+
+ // Compute how much of the 8x1 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8b}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 8x1 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8h}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x4, %[dst_ptr]\n" + "mov x3, x4\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17"); +} + +// Variant of the above Kernel8bitNeonDotprod, tuned for in-order +// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, +// since these are 64-bit and support dotprod. +// +// While this kernel does not have a direct equivalent in gemmlowp, it was +// developed based on insights that David Mansell at ARM shared with their +// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // RHS. + // + // int8 RHS 4x8 block + // /-----------------------------------------| + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------| + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... 
v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for + // the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_MAKE_ZERO(v28) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_MAKE_ZERO(v29) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_MAKE_ZERO(v30) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_MAKE_ZERO(v31) + + + "1:\n" + + "add x5, %[lhs_ptr], x12, lsl #3\n" + "sub x5, x5, #32\n" + "cmp %[lhs_ptr], x5\n" + + "beq 79f\n" + + // Main accumulation loop + "2:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x1, [%[lhs_ptr], #8]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "ldr x3, [%[rhs_ptr], #8]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + "ldr x4, [%[rhs_ptr], #24]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr d0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + "ins v0.d[1], x1\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + "add %[lhs_ptr], %[lhs_ptr], #32\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr d2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + "ins v2.d[1], x3\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + "cmp %[lhs_ptr], x5\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr d3, [%[rhs_ptr], #-16]\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ins v3.d[1], x4\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + "ins v1.d[1], x2\n" + ".word 0x4fa2e816 // sdot 
v22.4s, v0.16b, v2.4b[3]\n" + "blt 2b\n" + + // Last accumulation steps, nothing left to load. + "79:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "cmp %w[row], w7\n" // Have we finished the last row? + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + // Load some parameters needed for the end work on current block. + "mvni v8.4s, #0\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + // Determine the channel index. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. 
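The RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL test above picks whether the bias (and, further down, the per-channel multipliers) is indexed by the block's row or by its column; the two addition sequences that follow then broadcast the bias accordingly. A compact sketch, with hypothetical names and an acc[col][row] layout:

#include <cstdint>

// Mirrors "csel w3, %w[row], %w[col], eq" above.
int ChannelIndex(bool channel_dimension_is_col, int row, int col) {
  return channel_dimension_is_col ? col : row;
}

// bias holds the 8 values for this block's channels; they are broadcast down
// columns (one value per row, the common case) or across rows (one per column).
void AddBias8x8(std::int32_t acc[8][8], const std::int32_t bias[8],
                bool channel_dimension_is_col) {
  for (int c = 0; c < 8; ++c) {
    for (int r = 0; r < 8; ++r) {
      acc[c][r] += bias[channel_dimension_is_col ? c : r];
    }
  }
}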
+ "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v14.d[1], x5\n" + "ld1 {v15.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v15.d[1], x5\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v10.4s, v14.s[0]\n" + "dup v11.4s, v14.s[1]\n" + "add v16.4s, v16.4s, v10.4s\n" + "dup v12.4s, v14.s[2]\n" + "add v17.4s, v17.4s, v10.4s\n" + "dup v13.4s, v14.s[3]\n" + "add v18.4s, v18.4s, v11.4s\n" + "dup v10.4s, v15.s[0]\n" + "add v19.4s, v19.4s, v11.4s\n" + "dup v11.4s, v15.s[1]\n" + "add v20.4s, v20.4s, v12.4s\n" + "add v21.4s, v21.4s, v12.4s\n" + "dup v12.4s, v15.s[2]\n" + "add v22.4s, v22.4s, v13.4s\n" + "add v23.4s, v23.4s, v13.4s\n" + "dup v13.4s, v15.s[3]\n" + "add v24.4s, v24.4s, v10.4s\n" + "add v25.4s, v25.4s, v10.4s\n" + "add v26.4s, v26.4s, v11.4s\n" + "add v27.4s, v27.4s, v11.4s\n" + "add v28.4s, v28.4s, v12.4s\n" + "add v29.4s, v29.4s, v12.4s\n" + "add v30.4s, v30.4s, v13.4s\n" + "add v31.4s, v31.4s, v13.4s\n" + "7:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x5, x5, %x[col], lsl #2\n" + // Load 8 rhs_sums values. 
+ "ld1 {v14.2s}, [x5], #8\n" + "ldr x7, [x5], #8\n" + "ld1 {v15.2s}, [x5], #8\n" + "ins v14.d[1], x7\n" + "ldr x7, [x5], #8\n" + "ins v15.d[1], x7\n" + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Load 8 lhs_sums values. + "ld1 {v11.2s}, [x2], #8\n" + "ldr x4, [x2], #8\n" + "ins v11.d[1], x4\n" + "ld1 {v12.2s}, [x2], #8\n" + "ldr x4, [x2], #8\n" + "ins v12.d[1], x4\n" + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + // Compute the multiplier_exponent pointer + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, x3, lsl #2\n" + "csel x1, x1, x5, eq\n" + // Load multiplier_exponent + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + // Separate positive and negative exponents + "smin v11.4s, v8.4s, v9.4s\n" + "smin v12.4s, v8.4s, v10.4s\n" + "sub v9.4s, v9.4s, v11.4s\n" + "sub v10.4s, v10.4s, v12.4s\n" + + // Compute the multiplier_fixedpoint pointer + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "add x5, x4, x3, lsl #2\n" + "csel x4, x4, x5, eq\n" + // Load multiplier_fixedpoint + "ldr q14, [x4]\n" + "ldr q15, [x4, #16]\n" + + // Jump based on channel dimension. + "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 8f\n" + // Case where channels are rows + + // Apply the positive exponent part of the multiplier. 
+ "sshl v16.4s, v16.4s, v9.4s\n" + "sshl v17.4s, v17.4s, v10.4s\n" + "sshl v18.4s, v18.4s, v9.4s\n" + "sshl v19.4s, v19.4s, v10.4s\n" + "sshl v20.4s, v20.4s, v9.4s\n" + "sshl v21.4s, v21.4s, v10.4s\n" + "sshl v22.4s, v22.4s, v9.4s\n" + "sshl v23.4s, v23.4s, v10.4s\n" + "sshl v24.4s, v24.4s, v9.4s\n" + "sshl v25.4s, v25.4s, v10.4s\n" + "sshl v26.4s, v26.4s, v9.4s\n" + "sshl v27.4s, v27.4s, v10.4s\n" + "sshl v28.4s, v28.4s, v9.4s\n" + "sshl v29.4s, v29.4s, v10.4s\n" + "sshl v30.4s, v30.4s, v9.4s\n" + "sshl v31.4s, v31.4s, v10.4s\n" + "10:\n" + + // Apply the fixed-point part of the multiplier. + // + // ... and, interleaved into that: + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "sqdmulh v16.4s, v16.4s, v14.4s\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "sqdmulh v17.4s, v17.4s, v15.4s\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "sqdmulh v18.4s, v18.4s, v14.4s\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "sqdmulh v19.4s, v19.4s, v15.4s\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "sqdmulh v20.4s, v20.4s, v14.4s\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "sqdmulh v21.4s, v21.4s, v15.4s\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "sqdmulh v22.4s, v22.4s, v14.4s\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "sqdmulh v23.4s, v23.4s, v15.4s\n" + "sqdmulh v24.4s, v24.4s, v14.4s\n" + "sqdmulh v25.4s, v25.4s, v15.4s\n" + "sqdmulh v26.4s, v26.4s, v14.4s\n" + "sqdmulh v27.4s, v27.4s, v15.4s\n" + "sqdmulh v28.4s, v28.4s, v14.4s\n" + "sqdmulh v29.4s, v29.4s, v15.4s\n" + "sqdmulh v30.4s, v30.4s, v14.4s\n" + "sqdmulh v31.4s, v31.4s, v15.4s\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "ins v13.h[4], w4\n" // dst_zero_point + "srshl v27.4s, v27.4s, v12.4s\n" + "ins v0.d[1], x1\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "ins v1.d[1], x2\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "ins v2.d[1], x5\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "ins v3.d[1], x6\n" + "srshl v31.4s, v31.4s, v12.4s\n" + "b 9f\n" + + "8:\n" + // Case where channels are columns + + // Apply the positive exponent part of the multiplier. + "dup v4.4s, v9.s[0]\n" + "dup v5.4s, v9.s[1]\n" + "sshl v16.4s, v16.4s, v4.4s\n" + "dup v6.4s, v9.s[2]\n" + "sshl v17.4s, v17.4s, v4.4s\n" + "dup v7.4s, v9.s[3]\n" + "sshl v18.4s, v18.4s, v5.4s\n" + "dup v4.4s, v10.s[0]\n" + "sshl v19.4s, v19.4s, v5.4s\n" + "dup v5.4s, v10.s[1]\n" + "sshl v20.4s, v20.4s, v6.4s\n" + "sshl v21.4s, v21.4s, v6.4s\n" + "dup v6.4s, v10.s[2]\n" + "sshl v22.4s, v22.4s, v7.4s\n" + "sshl v23.4s, v23.4s, v7.4s\n" + "dup v7.4s, v10.s[3]\n" + "sshl v24.4s, v24.4s, v4.4s\n" + "sshl v25.4s, v25.4s, v4.4s\n" + "sshl v26.4s, v26.4s, v5.4s\n" + "sshl v27.4s, v27.4s, v5.4s\n" + "sshl v28.4s, v28.4s, v6.4s\n" + "sshl v29.4s, v29.4s, v6.4s\n" + "sshl v30.4s, v30.4s, v7.4s\n" + "sshl v31.4s, v31.4s, v7.4s\n" + "11:\n" + + // Apply the fixed-point part of the multiplier. + // + // ... 
and, interleaved into that: + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "sqdmulh v16.4s, v16.4s, v14.s[0]\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "sqdmulh v17.4s, v17.4s, v14.s[0]\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "sqdmulh v18.4s, v18.4s, v14.s[1]\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "sqdmulh v19.4s, v19.4s, v14.s[1]\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "sqdmulh v20.4s, v20.4s, v14.s[2]\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "sqdmulh v21.4s, v21.4s, v14.s[2]\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "sqdmulh v22.4s, v22.4s, v14.s[3]\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "sqdmulh v23.4s, v23.4s, v14.s[3]\n" + "dup v4.4s, v11.s[0]\n" + "sqdmulh v24.4s, v24.4s, v15.s[0]\n" + "dup v5.4s, v11.s[1]\n" + "sqdmulh v25.4s, v25.4s, v15.s[0]\n" + "dup v6.4s, v11.s[2]\n" + "sqdmulh v26.4s, v26.4s, v15.s[1]\n" + "dup v7.4s, v11.s[3]\n" + "sqdmulh v27.4s, v27.4s, v15.s[1]\n" + "sqdmulh v28.4s, v28.4s, v15.s[2]\n" + "sqdmulh v29.4s, v29.4s, v15.s[2]\n" + "sqdmulh v30.4s, v30.4s, v15.s[3]\n" + "sqdmulh v31.4s, v31.4s, v15.s[3]\n" + + // Apply the negative exponent part of the multiplier. + "srshl v16.4s, v16.4s, v4.4s\n" + "srshl v17.4s, v17.4s, v4.4s\n" + "dup v4.4s, v12.s[0]\n" + "srshl v18.4s, v18.4s, v5.4s\n" + "srshl v19.4s, v19.4s, v5.4s\n" + "dup v5.4s, v12.s[1]\n" + "srshl v20.4s, v20.4s, v6.4s\n" + "srshl v21.4s, v21.4s, v6.4s\n" + "dup v6.4s, v12.s[2]\n" + "srshl v22.4s, v22.4s, v7.4s\n" + "srshl v23.4s, v23.4s, v7.4s\n" + "dup v7.4s, v12.s[3]\n" + "srshl v24.4s, v24.4s, v4.4s\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "srshl v25.4s, v25.4s, v4.4s\n" + "ins v13.h[4], w4\n" // dst_zero_point + "srshl v26.4s, v26.4s, v5.4s\n" + "ins v0.d[1], x1\n" + "srshl v27.4s, v27.4s, v5.4s\n" + "ins v1.d[1], x2\n" + "srshl v28.4s, v28.4s, v6.4s\n" + "ins v2.d[1], x5\n" + "srshl v29.4s, v29.4s, v6.4s\n" + "ins v3.d[1], x6\n" + "srshl v30.4s, v30.4s, v7.4s\n" + "srshl v31.4s, v31.4s, v7.4s\n" + "9:\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
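+ // (For reference: each sqxtn/sqxtn2 lane above is a saturating narrowing
+ // cast int32 -> int16, and the sqxtun/sqxtun2 instructions below further
+ // narrow int16 -> uint8 with unsigned saturation, i.e. clamping to
+ // [0, 255], before the explicit clamp_min/clamp_max bounds are applied.)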
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "umax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "umax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "umax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "umin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
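+ // (The ".word" directives below are the raw encodings of the sdot
+ // instructions given in their trailing comments, presumably spelled out as
+ // opcodes so that this file also assembles with toolchains that do not
+ // accept the dotprod extension syntax. As a scalar sketch of one of them,
+ // "sdot v16.4s, v0.16b, v2.4b[0]" does:
+ //   for i in 0..3: v16[i] += sum over j in 0..3 of
+ //       (int8)v0.b[4*i + j] * (int8)v2.b[4*index + j], with index = 0 here.)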
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "smax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "smax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "smax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "smin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 250b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "ins v0.d[1], x1\n" + "ins v1.d[1], x2\n" + "ins v2.d[1], x5\n" + "ins v3.d[1], x6\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x8 block of destination 32it values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x8 block fits. + // Write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "st1 {v16.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v16) + "st1 {v17.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v17) + "st1 {v18.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v18) + "st1 {v19.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v19) + "st1 {v20.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v20) + "st1 {v21.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v21) + "st1 {v22.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v22) + "st1 {v23.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v23) + "st1 {v24.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v24) + "st1 {v25.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v25) + "st1 {v26.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v26) + "st1 {v27.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v27) + "st1 {v28.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v28) + "st1 {v29.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v29) + "st1 {v30.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v30) + "st1 {v31.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v31) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x8 block fits. 
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.4s, v17.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v18.4s, v19.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v20.4s, v21.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v22.4s, v23.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v24.4s, v25.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v26.4s, v27.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v28.4s, v29.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v30.4s, v31.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 8 +#define RUY_OFFSET_DST_BASE_PTR 16 +#define RUY_OFFSET_BIAS 24 +#define RUY_OFFSET_START_ROW 32 +#define RUY_OFFSET_START_COL 36 +#define RUY_OFFSET_LAST_ROW 40 +#define RUY_OFFSET_LAST_COL 44 +#define RUY_OFFSET_LHS_STRIDE 56 +#define RUY_OFFSET_RHS_STRIDE 60 +#define RUY_OFFSET_DST_STRIDE 64 +#define RUY_OFFSET_DEPTH 68 +#define RUY_OFFSET_CLAMP_MIN 72 +#define RUY_OFFSET_CLAMP_MAX 76 +#define RUY_OFFSET_FLAGS 80 + +template <typename Params> +void CheckOffsetsInKernelParamsFloat(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Just a plain float kernel; good enough for out-of-order cores. 
+// The closest to it in the gemmlowp collection would be +// NEON_64bit_GEMM_Float32_WithScalar, +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 +// +// Besides ruy-ification, the main nuance here is that we stick to a 8x8 +// width instead of the wider 12x8 that the register space permits and that +// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now +// and we don't have evidence that going beyond 8x8 is needed. +void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) { + CheckOffsetsInKernelParamsFloat(params); + profiler::ScopeLabel label("Kernel (kNeon)"); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v15 are used to load data from LHS and RHS. + // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and + // v3 are used to load a 1x8 block of RHS, like this: + // + // RHS 1x8 block + // /-----------------------------------------| + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------| + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for floading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. 
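+ // (Each "fmla vN.4s, vL.4s, vR.s[i]" in the main loop below is, in scalar
+ // terms, vN[j] += vL[j] * vR[i] for j in 0..3, i.e. one 4x1 slice of the
+ // outer product of the 8x1 LHS block and the 1x8 RHS block diagrammed
+ // above; 16 such fmla instructions update the whole 8x8 accumulator block
+ // for one level of depth.)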
+ RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov w1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + +#if RUY_OPT(MAX_STREAMING) + "cmp w12, #8\n" + "blt 78f\n" + "and w2, w12, #-4\n" + + "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" + "mov w1, #4\n" + + "80:\n" + + "add %[lhs_ptr], %[lhs_ptr], #128\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr q0, [%[lhs_ptr], #-128]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr q3, [%[rhs_ptr], #-112]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ldr q1, [%[lhs_ptr], #-112]\n" + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "ldr q2, [%[rhs_ptr], #-128]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "ldr q4, [%[lhs_ptr], #-96]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "ldr q7, [%[rhs_ptr], #-80]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + "ldr q5, [%[lhs_ptr], #-80]\n" + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "ldr q6, [%[rhs_ptr], #-96]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "ldr q8, [%[lhs_ptr], #-64]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "ldr q11, [%[rhs_ptr], #-48]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + "ldr q9, [%[lhs_ptr], #-48]\n" + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "ldr q10, [%[rhs_ptr], #-64]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, v12.4s, v15.s[1]\n" 
+ "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + "add w1, w1, #4\n" + "cmp w1, w2\n" + "blt 80b\n" + + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, v12.4s, v15.s[1]\n" + "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + + "78:\n" +#endif + + // Accumulation loop + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "add w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "cmp w1, w12\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "mov v2.16b, v4.16b\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition. + // Jump based on channel dimension. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v8.4s, v14.s[0]\n" + "dup v9.4s, v14.s[1]\n" + "dup v10.4s, v14.s[2]\n" + "dup v11.4s, v14.s[3]\n" + "dup v12.4s, v15.s[0]\n" + "dup v13.4s, v15.s[1]\n" + "dup v14.4s, v15.s[2]\n" + "dup v15.4s, v15.s[3]\n" + "fadd v16.4s, v16.4s, v8.4s\n" + "fadd v17.4s, v17.4s, v8.4s\n" + "fadd v18.4s, v18.4s, v9.4s\n" + "fadd v19.4s, v19.4s, v9.4s\n" + "fadd v20.4s, v20.4s, v10.4s\n" + "fadd v21.4s, v21.4s, v10.4s\n" + "fadd v22.4s, v22.4s, v11.4s\n" + "fadd v23.4s, v23.4s, v11.4s\n" + "fadd v24.4s, v24.4s, v12.4s\n" + "fadd v25.4s, v25.4s, v12.4s\n" + "fadd v26.4s, v26.4s, v13.4s\n" + "fadd v27.4s, v27.4s, v13.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v14.4s\n" + "fadd v30.4s, v30.4s, v15.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "7:\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. 
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? 
+ "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov w1, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeon tuned for in-order CPUs that do not +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A53 or the original +// Cortex-A55. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A53. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A53: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 +void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------| + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------| + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. 
+ // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "subs w1, w1, #1\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ins v0.d[1], x2\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ins v3.d[1], x5\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ins v4.d[1], x4\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "ins v1.d[1], x3\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "mov v2.16b, v4.16b\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. 
Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. 
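                // As an illustrative C++ sketch (not part of the kernel), the pointer
                // selection performed by the tst/csel pair above amounts to:
                //   const float* bias_ptr = (params.flags & RUY_ASM_FLAG_HAS_BIAS)
                //                               ? params.bias + channel  // real bias values
                //                               : params.bias;           // points at zero data
                // where 'channel' stands for the index computed into w3. The two ld1
                // instructions below then read 8 consecutive floats from that pointer
                // into v14 and v15.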
+ "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition. + // Jump based on channel dimension. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v8.4s, v14.s[0]\n" + "dup v9.4s, v14.s[1]\n" + "fadd v16.4s, v16.4s, v8.4s\n" + "dup v10.4s, v14.s[2]\n" + "fadd v17.4s, v17.4s, v8.4s\n" + "dup v11.4s, v14.s[3]\n" + "fadd v18.4s, v18.4s, v9.4s\n" + "dup v12.4s, v15.s[0]\n" + "fadd v19.4s, v19.4s, v9.4s\n" + "dup v13.4s, v15.s[1]\n" + "fadd v20.4s, v20.4s, v10.4s\n" + "dup v14.4s, v15.s[2]\n" + "fadd v21.4s, v21.4s, v10.4s\n" + "dup v15.4s, v15.s[3]\n" + "fadd v22.4s, v22.4s, v11.4s\n" + "fadd v23.4s, v23.4s, v11.4s\n" + "fadd v24.4s, v24.4s, v12.4s\n" + "fadd v25.4s, v25.4s, v12.4s\n" + "fadd v26.4s, v26.4s, v13.4s\n" + "fadd v27.4s, v27.4s, v13.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v14.4s\n" + "fadd v30.4s, v30.4s, v15.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "7:\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. 
Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. 
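        // For reference, a scalar C++ sketch (illustrative only) of the partial-block
        // copy implemented by the loop at labels 50/51 above:
        //   const float* src = dst_tmp_buf;   // full 8x8 block, 8 floats per column
        //   float* dst = dst_ptr;
        //   for (int c = 0; c < cols_that_fit; ++c) {     // w2 in the asm
        //     for (int r = 0; r < rows_that_fit; ++r) {   // w1 in the asm
        //       dst[r] = src[r];
        //     }
        //     src += 8;                        // tmp-buffer column stride (32 bytes)
        //     dst += dst_stride / sizeof(float);  // x11 holds the byte stride
        //   }
        // The ldr instructions below simply restore x5 -- x7, which were reused as
        // scratch by this loop.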
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeonA55ish tuned for in-order CPUs that do +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A55r1. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A55r1. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------| + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------| + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... 
v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. 
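        // As a rough C++ sketch (illustrative only), the structure of this fused
        // loop is equivalent to:
        //   for (col = start_col; col <= last_col; col += 8)
        //     for (row = start_row; row <= last_row; row += 8) {
        //       float acc[8][8] = {/* v16 -- v31 */};
        //       for (d = 0; d < depth; ++d)
        //         for (j = 0; j < 8; ++j)
        //           for (i = 0; i < 8; ++i)
        //             acc[j][i] += lhs[row + i][d] * rhs[d][col + j];
        //       // then bias, clamp, and store of the 8x8 block (see below)
        //     }
        // with the row/col bookkeeping folded into the tail of each iteration.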
+ "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "subs w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "ins v0.d[1], x2\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "ins v3.d[1], x5\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "ins v4.d[1], x4\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "ins v1.d[1], x3\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "mov v2.16b, v4.16b\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Determine the channel index. + "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "csel w3, %w[row], %w[col], eq\n" + + // Offset the bias pointer as needed given the current row, col. + "add x5, x1, x3, lsl #2\n" + + // If there is no bias, use no offset, just address the passed zero + // data. + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition. + // Jump based on channel dimension. 
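        // In C++ terms (illustrative sketch), the two cases below are:
        //   if (channels are rows)      acc[c][r] += bias[row + r];  // v14/v15 shared
        //                                                            // by all columns
        //   else /* channels are cols */ acc[c][r] += bias[col + c];
        // In the columns case each bias value must be broadcast across a whole
        // accumulator register, hence the dup instructions under label 6.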
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" + "bne 6f\n" + // Case where channels are rows + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "b 7f\n" + + "6:\n" + // Case where channels are columns + "dup v8.4s, v14.s[0]\n" + "dup v9.4s, v14.s[1]\n" + "fadd v16.4s, v16.4s, v8.4s\n" + "dup v10.4s, v14.s[2]\n" + "fadd v17.4s, v17.4s, v8.4s\n" + "dup v11.4s, v14.s[3]\n" + "fadd v18.4s, v18.4s, v9.4s\n" + "dup v12.4s, v15.s[0]\n" + "fadd v19.4s, v19.4s, v9.4s\n" + "dup v13.4s, v15.s[1]\n" + "fadd v20.4s, v20.4s, v10.4s\n" + "dup v14.4s, v15.s[2]\n" + "fadd v21.4s, v21.4s, v10.4s\n" + "dup v15.4s, v15.s[3]\n" + "fadd v22.4s, v22.4s, v11.4s\n" + "fadd v23.4s, v23.4s, v11.4s\n" + "fadd v24.4s, v24.4s, v12.4s\n" + "fadd v25.4s, v25.4s, v12.4s\n" + "fadd v26.4s, v26.4s, v13.4s\n" + "fadd v27.4s, v27.4s, v13.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v14.4s\n" + "fadd v30.4s, v30.4s, v15.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + "7:\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. 
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR + +#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc new file mode 100644 index 0000000..2405735 --- /dev/null +++ b/ruy/kernel_avx.cc @@ -0,0 +1,1476 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <algorithm> +#include <cstdint> +#include <cstring> + +#include "ruy/check_macros.h" +#include "ruy/kernel_common.h" +#include "ruy/kernel_x86.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX && RUY_OPT(ASM) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) + +void Kernel8bitAvx(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx(const KernelParamsFloat<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. 
+ RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX && RUY_OPT(ASM) + +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +namespace { +namespace intrin_utils { + +template <> +inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a, + const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo); + __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi); + return _mm256_set_m128i(dst_hi, dst_lo); +} + +template <> +inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a, + const int imm) { + switch (imm) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + default: + RUY_DCHECK_LT(imm, 2); + return _mm_setzero_si128(); + } +} + +template <Path path> +inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { + // Take the upper 64 bits of a and put in the first 64 bits of 'hi' + __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); + return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); +} + +template <Path path> +inline __m256i mm256_cvtepi32_epi64(const __m128i& a) { + // sign extend the 32-bit values in the lower 64 bits of a. + __m128i lo = _mm_cvtepi32_epi64(a); + __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128())); + return _mm256_set_m128i(hi, lo); +} + +inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, + const int imm) { + __m128i tmp = _mm_setzero_si128(); + if (!(imm & 8)) { + switch (imm & 3) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + case 2: + return _mm256_extractf128_si256(b, 0); + case 3: + return _mm256_extractf128_si256(b, 1); + } + } + return tmp; +} + +template <Path path> +inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, + const int imm) { + const int lo_imm = imm & 15; + __m128i lo = mm_permute_helper(a, b, lo_imm); + const int hi_imm = (imm >> 4) & 15; + __m128i hi = mm_permute_helper(a, b, hi_imm); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_max_epi32(a_lo, b_lo); + __m128i hi = _mm_max_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_min_epi32(a_lo, b_lo); + __m128i hi = _mm_min_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi32(a_lo, b_lo); + __m128i hi = _mm_add_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_add_epi64(const __m256i& a, 
const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi64(a_lo, b_lo); + __m128i hi = _mm_add_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_slli_epi64(const __m256i& a, int imm) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i lo = _mm_slli_epi64(a_lo, imm); + __m128i hi = _mm_slli_epi64(a_hi, imm); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_mullo_epi32(a_lo, b_lo); + __m128i hi = _mm_mullo_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +// Defined as a macro since `imm` must be an immediate. +#define BlendM128_epi32(a, b, imm) \ + _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm)) + +// Defined as a macro since `imm` must be an immediate. +#define BlendM128_epi64(a, b, imm) \ + _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm)) + +// Defined as a macro since `imm` must be an immediate. +#define mm256_blend_epi32(ans, a, b, imm) \ + __m128i a_lo = _mm256_extractf128_si256(a, 0); \ + __m128i a_hi = _mm256_extractf128_si256(a, 1); \ + __m128i b_lo = _mm256_extractf128_si256(b, 0); \ + __m128i b_hi = _mm256_extractf128_si256(b, 1); \ + __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \ + __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \ + ans = _mm256_set_m128i(hi, lo); + +#define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm) \ + a_lo = _mm256_extractf128_si256(a, 0); \ + a_hi = _mm256_extractf128_si256(a, 1); \ + ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \ + _mm_shuffle_epi32(a_lo, imm)); + +template <Path path> +inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_madd_epi16(a_lo, b_lo); + __m128i hi = _mm_madd_epi16(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) { + // shift both elements of a by lower 64bits of b. + __m128i res_lo = _mm_srl_epi64(a, b); + // shift both elements of a by upper 64bits of b. + __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); + __m128i res_hi = _mm_srl_epi64(a, hi_count); + // Take the lower 64 bits of res_lo and upper 64 bits of res hi + // 1. Swap the upper and lower 64 bits of res_hi + __m128i tmp_hi = + _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. 
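  // Net effect, as a scalar sketch: for each of the two 64-bit lanes,
  //   result[i] = a[i] >> b[i]   (logical shift)
  // i.e. this emulates _mm_srlv_epi64, which is only available with AVX2,
  // using plain-AVX instructions.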
+ return _mm_unpacklo_epi64(res_lo, tmp_hi); +} + +template <Path path> +inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_srlv_epi64(a_lo, b_lo); + __m128i hi = mm_srlv_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) { + // shift both elements of a by lower 64bits of b. + __m128i res_lo = _mm_sll_epi64(a, b); + // shift both elements of a by upper 64bits of b. + __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); + __m128i res_hi = _mm_sll_epi64(a, hi_count); + // Take the lower 64 bits of res_lo and upper 64 bits of res hi + // 1. Swap the upper and lower 64 bits of res_hi + __m128i tmp_hi = + _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. + return _mm_unpacklo_epi64(res_lo, tmp_hi); +} + +template <Path path> +inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo); + __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +#define PermuteM128_epi32(a, imm) \ + _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm)); + +inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) { + // shift all elements of a by first 32bits of b. + __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1)); + + // put bits 32-63 of b in the first slot. + __m128i tmp1 = PermuteM128_epi32(b, 1); + // put bits 32-63 of a in the first slot. + __m128i a1 = PermuteM128_epi32(a, 1); + // shift all elements of a by second 32bits of b. + __m128i res1 = + _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1)); + + // put bits 64-95 of b in the first slot. + __m128i tmp2 = PermuteM128_epi32(b, 2); + // shift all elements of a by third 32bits of b. + __m128i res2 = + _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1)); + + // put bits 96-127 of b in the first slot. + __m128i tmp3 = PermuteM128_epi32(b, 3); + // put bits 96-127 of a in the third slot. + __m128i a3 = PermuteM128_epi32(a, 48); + // shift all elements of a3 by fourth 32bits of b. + __m128i res3 = + _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1)); + + // Take bits 0-31 of res0, bits 0-31 of res1, + // bits 64-95 of res2, and bits 64-95 of res3. + // res0 _ _ _ 0 + // res1 _ _ _ 1 + // res2 _ 2 _ _ + // res3 _ 3 _ _ + // f_01 _ _ 1 0 + // f_23 _ _ 3 2 + __m128i f_01 = _mm_unpacklo_epi32(res0, res1); + __m128i f_23 = _mm_unpackhi_epi32(res2, res3); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. 
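  // Net effect, as a scalar sketch: for each of the four 32-bit lanes,
  //   result[i] = a[i] << b[i]
  // i.e. this emulates the AVX2 instruction _mm_sllv_epi32 with plain-AVX
  // shuffles and shifts.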
+ return _mm_unpacklo_epi64(f_01, f_23); +} + +template <Path path> +inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_sllv_epi32(a_lo, b_lo); + __m128i hi = mm_sllv_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_sub_epi32(a_lo, b_lo); + __m128i hi = _mm_sub_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_mul_epi32(a_lo, b_lo); + __m128i hi = _mm_mul_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +// Perform the equivalent of mm256_permutevar8x32 with +// a second argument of {7, 5, 3, 1, 6, 4, 2, 0} +template <Path path> +inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { + // a_lo = 3 2 1 0 + __m128i a_lo = _mm256_extractf128_si256(a, 0); + // a_hi = 7 6 5 4 + __m128i a_hi = _mm256_extractf128_si256(a, 1); + // shuffle a_lo to get 3 1 2 0 + __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); + // shuffle a_hi to get 7 5 6 4 + __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); + // unpack lo 64 of res_lo and res hi to get 6 4 2 0 + __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); + // unpack hi 64 of res_lo and res hi to get 7 5 1 3 + __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); + return _mm256_set_m128i(res_hi, res_lo); +} + +template <Path path> +inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) { + const __m256i bias0 = _mm256_set1_epi32(*(bias + offset)); + return mm256_add_epi32<path>(a, bias0); +} + +__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b, + const __m256i& mask) { + __m256 result = + _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), + _mm256_castsi256_ps(mask)); + return _mm256_castps_si256(result); +} + +template <Path path> +inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo); + __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + + __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0)); + __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1)); + __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2)); + __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3)); + __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4)); + __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5)); + __m128i r6 = _mm_srai_epi32(a_hi, 
_mm256_extract_epi32(b, 6)); + __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7)); + + // get element 0 from r0, element 1 from r1 + __m128i r01 = BlendM128_epi32(r0, r1, 2); + // get element 2 from r2, element 3 from r3 + __m128i r23 = BlendM128_epi32(r2, r3, 8); + // get element 0 from r4, element 1 from r5 + __m128i r45 = BlendM128_epi32(r4, r5, 2); + // get element 2 from r6, element 3 from r7 + __m128i r67 = BlendM128_epi32(r6, r7, 8); + // get lower 64 bits of r01, upper 64 bits of r23 + __m128i r0123 = BlendM128_epi64(r01, r23, 2); + // get lower 64 bits of r45, upper 64 bits of r67 + __m128i r4567 = BlendM128_epi64(r45, r67, 2); + return _mm256_set_m128i(r4567, r0123); +} + +// AVX doesn't have fused multiply-add so we define an inline function to be +// used in the common code following. +template <> +inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, + const __m256& c) { + const __m256 prod = _mm256_mul_ps(a, b); + return _mm256_add_ps(prod, c); +} + +} // namespace intrin_utils +} // namespace + +template <Path path> +void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx 8-bit"); + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + std::int32_t dst_stride = 0; + if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || + (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + int channel = + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; + int multiplier_channel = + (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + const __m256i splitter_idx = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + __m256i accum_data_v1; + __m256i accum_data_v2; + __m256i accum_data_v3; + __m256i accum_data_v4; + __m256i accum_data_v5; + __m256i accum_data_v6; + __m256i accum_data_v7; + + // initial_accum_data will be the initialize of each of the + // accum_data_* accumulator registers. We compute into it terms that are + // identical across columns. 
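      // Concretely, the value assembled here and in the adjustments below is
      // (illustrative scalar sketch of the usual zero-point decomposition):
      //   acc[r][c] = prod_zp_depth
      //             + bias[channel]                         // if HAS_BIAS
      //             - rhs_zero_point * lhs_sums[row + r]    // if HAS_LHS_SUMS
      //             - lhs_zero_point * rhs_sums[col + c]    // if HAS_RHS_SUMS
      //             + sum over depth of lhs[row + r][d] * rhs[d][col + c]
      // where prod_zp_depth = lhs_zero_point * rhs_zero_point * depth; only the
      // terms that are identical across columns go into initial_accum_data.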
+ __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth); + __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth); + + // In the channels-are-rows case, we can load bias here. + if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + initial_accum_data_lo = _mm_add_epi32( + initial_accum_data_lo, + _mm_loadu_si128( + reinterpret_cast<const __m128i*>(params.bias + row))); + initial_accum_data_hi = _mm_add_epi32( + initial_accum_data_hi, + _mm_loadu_si128( + reinterpret_cast<const __m128i*>(params.bias + row + 4))); + } + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point); + const __m128i lhs_sums_offset_lo = _mm_mullo_epi32( + rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( + ¶ms.lhs_sums[row]))); + const __m128i lhs_sums_offset_hi = _mm_mullo_epi32( + rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( + ¶ms.lhs_sums[row + 4]))); + + initial_accum_data_lo = + _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo); + initial_accum_data_hi = + _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + __m256i initial_accum_data = + _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); + + accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); + } else { + __m256i initial_accum_data = + _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + // Finally, in the channels-are-columns case, load bias data here. 
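      // (AddBiasEpi32 broadcasts a single bias value across all 8 lanes, so the
      // effect of the block below is, per accumulator column i:
      //   accum_data_vi[lane] += bias[col + i] for every lane.)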
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0, + params.bias + col, 0); + accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1, + params.bias + col, 1); + accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2, + params.bias + col, 2); + accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3, + params.bias + col, 3); + accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4, + params.bias + col, 4); + accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5, + params.bias + col, 5); + accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6, + params.bias + col, 6); + accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7, + params.bias + col, 7); + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); + const __m256i rhs_data_8bit = + _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + std::int32_t rhs_data[16]; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extractf128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), + rhs_16_bit_dup_high); + + // NOTE: There may be opportunities for permuting the data in the + // packing code instead of here. + const __m256i lhs_data_split = + intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. 
+ const __m256i lhs_16_bit_high = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + + __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>( + rhs_data)); // Load [0 1 2 3 4 5 6 7] + __m256i rhs1 = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15] + __m256i rhs0_3 = + _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3] + __m256i rhs4_7 = + _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7] + __m256i rhs8_11 = + _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11] + __m256i rhs12_15 = + _mm256_permute2f128_si256(rhs1, rhs1, 17); // [12 - 15, 12 - 15] + + auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi, + __m256i& accum) { + // Perform mul-adds on low and high components of accum separately. + __m128i accum_lo = _mm256_extractf128_si256(accum, 0); + __m128i accum_hi = _mm256_extractf128_si256(accum, 1); + + __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0); + __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1); + __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0); + __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1); + __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0); + __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1); + + accum_lo = _mm_add_epi32(accum_lo, lo_0); + accum_hi = _mm_add_epi32(accum_hi, lo_1); + + __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0); + __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1); + __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0); + __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1); + __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0); + __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1); + + accum_lo = _mm_add_epi32(accum_lo, hi_0); + accum_hi = _mm_add_epi32(accum_hi, hi_1); + accum = _mm256_set_m128i(accum_hi, accum_lo); + }; + __m256i tmp0, tmp1, tmp2, tmp3; + __m128i lo0, lo1, hi0, hi1; + mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0); + mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55); + process_column(tmp0, tmp1, accum_data_v0); + mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa); + mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff); + process_column(tmp2, tmp3, accum_data_v1); + + mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0); + mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55); + process_column(tmp0, tmp1, accum_data_v2); + mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa); + mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff); + process_column(tmp2, tmp3, accum_data_v3); + + mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0); + mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55); + process_column(tmp0, tmp1, accum_data_v4); + mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa); + mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff); + process_column(tmp2, tmp3, accum_data_v5); + + mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0); + mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55); + process_column(tmp0, tmp1, accum_data_v6); + mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa); + mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff); + process_column(tmp2, tmp3, accum_data_v7); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. 
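        // Per lane, the requantization implemented below is, as an illustrative
        // sketch (exponent > 0 meaning a left shift, < 0 a right shift):
        //   shifted = acc << max(exponent, 0);
        //   scaled  = (shifted * multiplier_fixedpoint + (1 << 30)) >> 31;  // 64-bit
        //   result  = rounding_right_shift(scaled, max(-exponent, 0));
        // The 64-bit intermediate needs variable 64-bit shifts, which plain AVX
        // lacks, hence the emulation helpers defined earlier in this file.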
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + multiplier_channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + multiplier_channel)); + + const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 0)); + const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = + intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); + const __m256i neg_e_vector = + intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); + const __m256i right_shift = + intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); + const __m256i final_right_shift = _mm256_set1_epi32(31); + const __m256i final_right_shift_low = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 1)); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_setzero_si256(); + + // A "half" added for rounding prior to truncation of 64-bit value. + const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); + } + + // We cannot do + // + // scaled_v_low = + // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); + // scaled_v_high = + // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); + // + // since this instruction is not in AVX2. Instead we use + // _mm256_srlv_epi64, but this is an unsigned shift, so we applied + // offsets before (convert_to_unsigned_64) and after + // (convert_to_signed_halved). + // + // The overall process is, for 64-bit scaled accumulator: + // unsigned_accum = signed_accum + 1 << 63; + // unsigned_accum = (unsigned_accum >> right_shift) >> 31; + // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; + + // There are various ways to repack the results, in the absence of + // _mm256_cvtepi64_epi32() or anything like it. + // A. + // accum_data_v[j] = + // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), + // _mm256_extract_epi32(scaled_v_high, 4), + // _mm256_extract_epi32(scaled_v_high, 2), + // _mm256_extract_epi32(scaled_v_high, 0), + // _mm256_extract_epi32(scaled_v_low, 6), + // _mm256_extract_epi32(scaled_v_low, 4), + // _mm256_extract_epi32(scaled_v_low, 2), + // _mm256_extract_epi32(scaled_v_low, 0)); + // B. + // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); + // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); + // accum_data_v[j] = + // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), + // _mm256_extract_epi64(scaled_v_high, 0), + // _mm256_extract_epi64(scaled_v_low, 2), + // _mm256_extract_epi64(scaled_v_low, 0)); + // C. + // scaled_v_low = + // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); + // scaled_v_high = + // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); + // accum_data_v[j] = + // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); + // + // However, we choose the following because it uses two lighter + // instructions. 
The permutation does have a longer latency, but this
+          // loop can be unrolled.
+          // D.
+          // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+          // __m256i results =
+          //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+          // results = _mm256_permutevar8x32_epi32(results, repack_perm);
+          // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
+          // post_scaling_offset);
+
+          // This multiplier code is complex and expensive enough on x86, that
+          // we prefer to implement the channels-are-columns case by transposing
+          // around it, rather than duplicate it (which would also require
+          // duplicating the above code computing the multiplier constants).
+          // This is one instance where channels-are-columns has lower performance
+          // than channels-are-rows.
+          const bool transpose_around_multiplier =
+              (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+              (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+          if (transpose_around_multiplier) {
+            // Transpose the 8x8 accumulators block. Will be un-transposed below
+            // after the multiplier implementation.
+            intrin_utils::mm256_transpose8x8_epi32<path>(
+                &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+                &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+          }
+
+          auto rounding_right_shift = [=](__m256i& results,
+                                          const __m256i& exponent) {
+            // Construct the "nudge" value for each lane if the exponent is
+            // greater than 0. Otherwise, the nudge is 0.
+            const __m256i zeros = _mm256_setzero_si256();
+            const __m256i mask_rightshift_gtz =
+                intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
+            const __m256i one_shift_exp_minus1 =
+                intrin_utils::mm256_sllv_epi32<path>(
+                    _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                              exponent, _mm256_set1_epi32(1)));
+            __m256i nudge = intrin_utils::mm256_blendv_epi32(
+                zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+            // Calculate the shifted sum (results + nudge) >> exp.
+            const __m256i r_plus_nudge =
+                intrin_utils::mm256_add_epi32<path>(results, nudge);
+            const __m256i shifted_sum =
+                intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
+
+            // Identify overflow in each lane and create mask.
+            const __m256i one_shift_31minus_exp =
+                intrin_utils::mm256_sllv_epi32<path>(
+                    _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                              _mm256_set1_epi32(31), exponent));
+            const __m256i mask_num_plus_nudge_overflow =
+                intrin_utils::mm256_cmpgt_epi32<path>(
+                    results, intrin_utils::mm256_sub_epi32<path>(
+                                 _mm256_set1_epi32(0x7fffffff), nudge));
+            // Fill results with either (results + nudge) >> exponent or
+            // 1 << (31 - exp) in the case of overflow.
+            results = intrin_utils::mm256_blendv_epi32(
+                shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+          };
+
+          auto apply_multiplier = [=](__m256i& accum) {
+            __m256i shifted_accum =
+                intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
+            // Apply the fixed-point part of the multiplier.
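+            // Each 32-bit accumulator lane is sign-extended to 64 bits and
+            // multiplied by the widened mantissa; a rounding offset is added
+            // and the 64-bit products are brought back to 32 bits via the
+            // unsigned right-shift-by-31 sequence described above.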
+ __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 1)), + m_64bit_high); + scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, + offset_vector); + scaled_v_high = intrin_utils::mm256_add_epi64<path>( + scaled_v_high, offset_vector); + + scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_low, final_right_shift_low); + scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_high, final_right_shift_high); + scaled_v_high = + intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); + __m256i results; + mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); + // Permute results to this ordering of int32 elements + // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); + results = intrin_utils::PermuteEpi32EvenOdds<path>(results); + + rounding_right_shift(results, right_shift); + accum = + intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); + }; + apply_multiplier(accum_data_v0); + apply_multiplier(accum_data_v1); + apply_multiplier(accum_data_v2); + apply_multiplier(accum_data_v3); + apply_multiplier(accum_data_v4); + apply_multiplier(accum_data_v5); + apply_multiplier(accum_data_v6); + apply_multiplier(accum_data_v7); + // See above comment: here we transpose again to undo the transposition + // of the 8x8 block of accumulators used to implement the + // channels-are-columns case. + if (transpose_around_multiplier) { + intrin_utils::mm256_transpose8x8_epi32<path>( + &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, + &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + __m256i accum_data_v[kAvx8bitBlockSize]; + if (!store_full_block) { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + } + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + 
intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[0 * dst_stride], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[1 * dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], 
accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + if (store_full_block) { + std::int32_t* 
tmp_ptr = static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + intrin_utils::mm256_n_storeu_epi32<path>( + dst_block_ptr, residual_rows, accum_data_v[j]); + dst_block_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvxImpl<Path::kAvx>(params); +} + +template <Path path> +void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + + const __m256i splitter_idx = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + + // Initialize with bias. + __m256i initial_accum_data = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr)); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. 
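+    // With a nonzero RHS zero point, rhs_zero_point * lhs_sums[row .. row+7]
+    // is subtracted from the accumulators; prod_zp_depth then folds in the
+    // constant correction term involving both zero points over the depth
+    // dimension.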
+ const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = intrin_utils::mm256_add_epi32<path>( + initial_accum_data, _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + } else { + accum_data_v0 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); + const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + // For simplicity we load 4x the data that we need and process twice the + // data that we need and store only the data we need. + std::int32_t rhs_data[2]; + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + + // NOTE: There may be opportunities for permuting the data in the packing + // code instead of here. + const __m256i lhs_data_split = + intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = intrin_utils::mm256_add_epi32<path>( + accum_data_v0, intrin_utils::mm256_madd_epi16<path>( + lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = intrin_utils::mm256_add_epi32<path>( + accum_data_v0, intrin_utils::mm256_madd_epi16<path>( + lhs_16_bit_high, rhs_16_bit_dup_high)); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? 
row : 0; + m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + channel)); + + const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 0)); + const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = + intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); + const __m256i neg_e_vector = + intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); + const __m256i right_shift = + intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); + const __m256i final_right_shift = _mm256_set1_epi32(31); + const __m256i final_right_shift_low = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 1)); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_setzero_si256(); + + // A "half" added for rounding prior to truncation of 64-bit value. + const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); + } + + // See GEMM version for details of this process. + { + __m256i shifted_accum = + intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, + offset_vector); + scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high, + offset_vector); + + scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_low, final_right_shift_low); + scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_high, final_right_shift_high); + + scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); + __m256i results; + mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); + // Permute results to this ordering of int32 elements + // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); + results = intrin_utils::PermuteEpi32EvenOdds<path>(results); + + // Now perform the Rounding Right Shift. + // First, construct the "nudge" value for each lane if the exponent is + // greater than 0. Otherwise, the nudge is 0. + const __m256i zeros = _mm256_setzero_si256(); + const __m256i mask_rightshift_gtz = + intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros); + const __m256i one_shift_exp_minus1 = + intrin_utils::mm256_sllv_epi32<path>( + _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( + right_shift, _mm256_set1_epi32(1))); + __m256i nudge = intrin_utils::mm256_blendv_epi32( + zeros, one_shift_exp_minus1, mask_rightshift_gtz); + // Calculate the shifted sum (results + nudge) >> exp. 
+ const __m256i r_plus_nudge = + intrin_utils::mm256_add_epi32<path>(results, nudge); + const __m256i shifted_sum = + intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift); + + // Identify overflow in each lane and create mask. + const __m256i one_shift_31minus_exp = + intrin_utils::mm256_sllv_epi32<path>( + _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( + _mm256_set1_epi32(31), right_shift)); + const __m256i mask_num_plus_nudge_overflow = + intrin_utils::mm256_cmpgt_epi32<path>( + results, intrin_utils::mm256_sub_epi32<path>( + _mm256_set1_epi32(0x7fffffff), nudge)); + // Fill results with either (results + nudge) >> exponent or + // 1 << (31 - exp) in the case of overflow. + results = intrin_utils::mm256_blendv_epi32( + shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + accum_data_v0 = + intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, + accum_data_v0); + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. 
+ + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; +} // NOLINT(readability/fn_size) + +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvxSingleColImpl<Path::kAvx>(params); +} + +void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx float"); + KernelFloatAvxCommon<Path::kAvx>(params); +} + +void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx float GEMV"); + KernelFloatAvxCommonSingleCol<Path::kAvx>(params); +} + +#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc new file mode 100644 index 0000000..eae333c --- /dev/null +++ b/ruy/kernel_avx2_fma.cc @@ -0,0 +1,1011 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <algorithm> +#include <cstdint> +#include <cstring> + +#include "ruy/check_macros.h" +#include "ruy/kernel_common.h" +#include "ruy/kernel_x86.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) + +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +namespace { +namespace intrin_utils { + +template <> +inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a, + const __m256i& b) { + return _mm256_shuffle_epi8(a, b); +} + +// Make an inline function for FMA so we can share the float kernels +// with non-FMA code. 
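+// On the kAvx2Fma path this lowers to a single fused multiply-add; paths
+// without FMA are expected to provide their own specialization built from a
+// separate multiply and add.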
+template <> +inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, + const __m256& c) { + return _mm256_fmadd_ps(a, b, c); +} + +template <> +inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a, + const int imm) { + switch (imm) { + case 0: + return _mm256_extracti128_si256(a, 0); + case 1: + return _mm256_extracti128_si256(a, 1); + default: + RUY_DCHECK_LT(imm, 2); + return _mm_setzero_si128(); + } +} + +__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b, + const __m256i& mask) { + __m256 result = + _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), + _mm256_castsi256_ps(mask)); + return _mm256_castps_si256(result); +} + +} // namespace intrin_utils +} // namespace + +template <Path path> +void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit"); + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + std::int32_t dst_stride = 0; + if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || + (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + int channel = + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; + int multiplier_channel = + (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + const __m256i splitter_idx = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + __m256i accum_data_v1; + __m256i accum_data_v2; + __m256i accum_data_v3; + __m256i accum_data_v4; + __m256i accum_data_v5; + __m256i accum_data_v6; + __m256i accum_data_v7; + + // initial_accum_data will be the initialize of each of the + // accum_data_* accumulator registers. We compute into it terms that are + // identical across columns. + __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth); + + // In the channels-are-rows case, we can load bias here. 
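+      // (Bias is then indexed by row, so the same eight values apply to every
+      // column and can be folded into initial_accum_data. The
+      // channels-are-columns case is handled further below, after the
+      // per-column adjustments, by broadcasting one bias value per column.)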
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + initial_accum_data = _mm256_add_epi32( + initial_accum_data, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(params.bias + row))); + } + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = _mm256_mullo_epi32( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = + _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); + } else { + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + // Finally, in the channels-are-columns case, load bias data here. + if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + const __m256i bias_data = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(params.bias + col)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0))); + accum_data_v1 = _mm256_add_epi32( + accum_data_v1, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1))); + accum_data_v2 = _mm256_add_epi32( + accum_data_v2, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2))); + accum_data_v3 = _mm256_add_epi32( + accum_data_v3, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3))); + accum_data_v4 = _mm256_add_epi32( + accum_data_v4, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4))); + accum_data_v5 = _mm256_add_epi32( + accum_data_v5, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5))); + accum_data_v6 = _mm256_add_epi32( + accum_data_v6, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6))); + accum_data_v7 = _mm256_add_epi32( + accum_data_v7, + _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7))); + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); + const __m256i rhs_data_8bit = + _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
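+        // After widening, each int32 effectively holds two adjacent depth
+        // levels of one destination column; the shuffles below broadcast
+        // those pairs so that each column's accumulator sees its own RHS
+        // values in every lane.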
+ std::int32_t rhs_data[16]; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extracti128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + _mm256_cvtepi8_epi16(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + _mm256_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), + rhs_16_bit_dup_high); + + const __m256i lhs_data_split = + _mm256_shuffle_epi8(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + + __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>( + rhs_data)); // Load [0 1 2 3 4 5 6 7] + __m256i rhs1 = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15] + __m256i rhs0_3 = + _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3] + __m256i rhs4_7 = + _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7] + __m256i rhs8_11 = + _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11] + __m256i rhs12_15 = + _mm256_permute2f128_si256(rhs1, rhs1, 17); // [12 - 15, 12 - 15] + + auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi, + __m256i& accum) { + accum = _mm256_add_epi32( + accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo)); + accum = _mm256_add_epi32( + accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi)); + }; + __m256i tmp0, tmp1, tmp2, tmp3; + tmp0 = _mm256_shuffle_epi32(rhs0_3, 0); + tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55); + process_column(tmp0, tmp1, accum_data_v0); + tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa); + tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff); + process_column(tmp2, tmp3, accum_data_v1); + + tmp0 = _mm256_shuffle_epi32(rhs4_7, 0); + tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55); + process_column(tmp0, tmp1, accum_data_v2); + tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa); + tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff); + process_column(tmp2, tmp3, accum_data_v3); + + tmp0 = _mm256_shuffle_epi32(rhs8_11, 0); + tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55); + process_column(tmp0, tmp1, accum_data_v4); + tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa); + tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff); + process_column(tmp2, tmp3, accum_data_v5); + + tmp0 = _mm256_shuffle_epi32(rhs12_15, 0); + tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55); + process_column(tmp0, tmp1, accum_data_v6); + tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa); + tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff); + process_column(tmp2, tmp3, accum_data_v7); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of 
RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + multiplier_channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + multiplier_channel)); + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = _mm256_set1_epi32(31); + const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 1)); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_setzero_si256(); + // A "half" added for rounding prior to truncation of 64-bit value. + const __m256i offset_vector = _mm256_add_epi64( + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); + } + + const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + + // We cannot do + // + // scaled_v_low = + // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); + // scaled_v_high = + // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); + // + // since this instruction is not in AVX2. Instead we use + // _mm256_srlv_epi64, but this is an unsigned shift, so we applied + // offsets before (convert_to_unsigned_64) and after + // (convert_to_signed_halved). + // + // The overall process is, for 64-bit scaled accumulator: + // unsigned_accum = signed_accum + 1 << 63; + // unsigned_accum = (unsigned_accum >> right_shift) >> 31; + // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; + + // There are various ways to repack the results, in the absence of + // _mm256_cvtepi64_epi32() or anything like it. + // A. + // accum_data_v[j] = + // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), + // _mm256_extract_epi32(scaled_v_high, 4), + // _mm256_extract_epi32(scaled_v_high, 2), + // _mm256_extract_epi32(scaled_v_high, 0), + // _mm256_extract_epi32(scaled_v_low, 6), + // _mm256_extract_epi32(scaled_v_low, 4), + // _mm256_extract_epi32(scaled_v_low, 2), + // _mm256_extract_epi32(scaled_v_low, 0)); + // B. + // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); + // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); + // accum_data_v[j] = + // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), + // _mm256_extract_epi64(scaled_v_high, 0), + // _mm256_extract_epi64(scaled_v_low, 2), + // _mm256_extract_epi64(scaled_v_low, 0)); + // C. + // scaled_v_low = + // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); + // scaled_v_high = + // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); + // accum_data_v[j] = + // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); + // + // However, we choose the following because it uses two lighter + // instructions. The permutation does have a longer latency, but this + // loop can be unrolled. + // D. 
+          // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+          // __m256i results =
+          //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+          // results = _mm256_permutevar8x32_epi32(results, repack_perm);
+          // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);
+
+          // This multiplier code is complex and expensive enough on x86, that
+          // we prefer to implement the channels-are-columns case by transposing
+          // around it, rather than duplicate it (which would also require
+          // duplicating the above code computing the multiplier constants).
+          // This is one instance where channels-are-columns has lower performance
+          // than channels-are-rows.
+          const bool transpose_around_multiplier =
+              (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+              (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+          if (transpose_around_multiplier) {
+            // Transpose the 8x8 accumulators block. Will be un-transposed below
+            // after the multiplier implementation.
+            intrin_utils::mm256_transpose8x8_epi32<path>(
+                &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
+                &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
+          }
+
+          auto rounding_right_shift = [=](__m256i& results,
+                                          const __m256i& exponent) {
+            // Construct the "nudge" value for each lane if the exponent is
+            // greater than 0. Otherwise, the nudge is 0.
+            const __m256i zeros = _mm256_setzero_si256();
+            const __m256i mask_rightshift_gtz =
+                _mm256_cmpgt_epi32(exponent, zeros);
+            const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+                _mm256_set1_epi32(1),
+                _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
+            __m256i nudge = intrin_utils::mm256_blendv_epi32(
+                zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+            // Calculate the shifted sum (results + nudge) >> exp.
+            const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+            const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);
+
+            // Identify overflow in each lane and create mask.
+            const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+                _mm256_set1_epi32(1),
+                _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
+            const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+                results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+            // Fill results with either (results + nudge) >> exponent or
+            // 1 << (31 - exp) in the case of overflow.
+            results = intrin_utils::mm256_blendv_epi32(
+                shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+          };
+
+          auto apply_multiplier = [=](__m256i& accum) {
+            __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
+            // Apply the fixed-point part of the multiplier.
+            __m256i scaled_v_low = _mm256_mul_epi32(
+                _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+                m_64bit_low);
+            __m256i scaled_v_high = _mm256_mul_epi32(
+                _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+                m_64bit_high);
+
+            scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+            scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
+
+            scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+            scaled_v_high =
+                _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+            scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+            __m256i results =
+                _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+            results = _mm256_permutevar8x32_epi32(results, repack_perm);
+            // Now do a Rounding Right Shift.
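+            // In scalar terms, rounding_right_shift computes roughly
+            //   (x + (1 << (exp - 1))) >> exp   for exp > 0 (a no-op for 0),
+            // saturating to 1 << (31 - exp) when adding the nudge would
+            // overflow int32.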
+ rounding_right_shift(results, right_shift); + accum = _mm256_add_epi32(results, post_scaling_offset); + }; + apply_multiplier(accum_data_v0); + apply_multiplier(accum_data_v1); + apply_multiplier(accum_data_v2); + apply_multiplier(accum_data_v3); + apply_multiplier(accum_data_v4); + apply_multiplier(accum_data_v5); + apply_multiplier(accum_data_v6); + apply_multiplier(accum_data_v7); + // See above comment: here we transpose again to undo the transposition + // of the 8x8 block of accumulators used to implement the + // channels-are-columns case. + if (transpose_around_multiplier) { + intrin_utils::mm256_transpose8x8_epi32<path>( + &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, + &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + __m256i accum_data_v[kAvx8bitBlockSize]; + if (!store_full_block) { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + } + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[0 * dst_stride], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[1 * dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + 
tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = 
_mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + intrin_utils::mm256_n_storeu_epi32<path>( + dst_block_ptr, residual_rows, accum_data_v[j]); + dst_block_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvx2Impl<Path::kAvx2Fma>(params); +} + +template <Path path> +void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx =
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+
+ // Initialize with bias.
+ __m256i initial_accum_data =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
+ bias_ptr += bias_ptr_block_increment;
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data =
+ _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth) {
+ initial_accum_data = _mm256_add_epi32(initial_accum_data,
+ _mm256_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
+ _mm256_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ // For simplicity we load 4x the data that we need and process twice the
+ // data that we need and store only the data we need.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+ // NOTE: There may be opportunities for permuting the data in the packing
+ // code instead of here.
+ const __m256i lhs_data_split =
+ _mm256_shuffle_epi8(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
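// --- [Editorial sketch: illustrative only, not part of this patch] ---
// With the pair packing above, one _mm256_madd_epi16 against a broadcast RHS
// pair advances each row's dot product by two depth levels at once. A scalar
// model of what a single 32-bit accumulator lane receives per madd (the
// helper name is hypothetical, not a ruy API):
#include <cstdint>
inline std::int32_t MaddPairModel(std::int16_t lhs_d0, std::int16_t lhs_d1,
                                  std::int16_t rhs_d0, std::int16_t rhs_d1) {
  // _mm256_madd_epi16: multiply 16-bit lanes, add adjacent products into 32 bits.
  return static_cast<std::int32_t>(lhs_d0) * rhs_d0 +
         static_cast<std::int32_t>(lhs_d1) * rhs_d1;
}
// The second madd below handles depths d+2 and d+3, so each loop iteration
// consumes kAvx8bitInnerSize (= 4) levels of depth.
// --- [End editorial sketch] ---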
+ const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0; + m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + channel)); + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = _mm256_set1_epi32(31); + const __m256i final_right_shift_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_setzero_si256(); + // A "half" added for rounding prior to truncation of 64-bit value. + const __m256i offset_vector = _mm256_add_epi64( + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); + } + + const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + + // See GEMM version for details of this process. + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
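// --- [Editorial sketch: illustrative only, not part of this patch] ---
// The comment above introduces the vector sequence that follows; a scalar
// model of that fixed-point step, assuming multiplier_fixedpoint holds a Q31
// multiplier m applied to the already-left-shifted accumulator a (arithmetic
// >> on negative values assumed):
#include <cstdint>
inline std::int32_t FixedPointScaleModel(std::int32_t a, std::int32_t m) {
  const std::int64_t prod = static_cast<std::int64_t>(a) * m;
  // The (1 << 30) term in offset_vector above is the rounding "half" for the
  // final 31-bit shift; the extra 2^63 term appears to compensate for using a
  // logical 64-bit shift (srlv), since AVX2 has no 64-bit arithmetic variable
  // shift.
  return static_cast<std::int32_t>((prod + (1LL << 30)) >> 31);
}
// The variable right_shift (the negative part of multiplier_exponent) is then
// applied by the rounding right shift described below.
// --- [End editorial sketch] ---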
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + // Now do a Rounding Right Shift. + // First, construct the nudge value for each lane. + const __m256i zeros = _mm256_setzero_si256(); + const __m256i mask_rightshift_gtz = + _mm256_cmpgt_epi32(right_shift, zeros); + const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32( + _mm256_set1_epi32(1), + _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1))); + __m256i nudge = intrin_utils::mm256_blendv_epi32( + zeros, one_shift_exp_minus1, mask_rightshift_gtz); + // Calculate the shifted sum (results + nudge) >> exp. + const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge); + const __m256i shifted_sum = + _mm256_srav_epi32(r_plus_nudge, right_shift); + + // Identify overflow in each lane and create mask. + const __m256i one_shift_31minus_exp = _mm256_sllv_epi32( + _mm256_set1_epi32(1), + _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift)); + const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32( + results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge)); + // Fill results with either (results + nudge) >> exponent or + // 1 << (31 - exp) in the case of overflow. 
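// --- [Editorial sketch: illustrative only, not part of this patch] ---
// Scalar model of the rounding right shift that the surrounding vector ops
// implement (exp is one lane of right_shift, already clamped to >= 0;
// arithmetic >> on negative values assumed):
#include <cstdint>
#include <limits>
inline std::int32_t RoundingRightShiftModel(std::int32_t x, int exp) {
  const std::int32_t nudge = exp > 0 ? (1 << (exp - 1)) : 0;
  if (x > std::numeric_limits<std::int32_t>::max() - nudge) {
    // x + nudge would wrap around; substitute 1 << (31 - exp), which is the
    // value the un-wrapped (x + nudge) >> exp would have produced.
    return 1 << (31 - exp);
  }
  return (x + nudge) >> exp;
}
// --- [End editorial sketch] ---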
+ results = intrin_utils::mm256_blendv_epi32( + shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + + accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, + accum_data_v0); + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; +} // NOLINT(readability/fn_size) + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params); +} + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2Fma float"); + KernelFloatAvxCommon<Path::kAvx2Fma>(params); +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV"); + KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params); +} + +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc new file mode 100644 index 0000000..fddb482 --- /dev/null +++ b/ruy/kernel_avx512.cc @@ -0,0 +1,1550 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <algorithm> +#include <cstdint> + +#include "ruy/check_macros.h" +#include "ruy/kernel_x86.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX512 && RUY_OPT(ASM) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) + +void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) + +namespace { +namespace intrin_utils { + +__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b, + const __m256i& mask) { + __m256d result = + _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b), + _mm256_castsi256_pd(mask)); + return _mm256_castpd_si256(result); +} + +__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b, + const __m512i& mask) { + __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); + __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); + __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); + __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); + __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0); + __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1); + __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo); + __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi); + __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0); + return _mm512_inserti64x4(result, hi, 1); +} + +__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) { + __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); + __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); + __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); + __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); + __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo); + __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi); + __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0); + return _mm512_inserti64x4(result, hi, 1); +} + +} // namespace intrin_utils +} // namespace + +void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 8-bit"); + + std::int32_t dst_stride = 0; + if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || + (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + + for (int col = params.start_col; col <= params.last_col; col += 16) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* 
dst_ptr = dst_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_si512(&params.rhs_sums[col]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ int channel =
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
+ int multiplier_channel =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
+
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+ const int residual_cols = std::min(params.dst_cols - col, 16);
+
+ __m512i accum_data_v0;
+ __m512i accum_data_v1;
+ __m512i accum_data_v2;
+ __m512i accum_data_v3;
+ __m512i accum_data_v4;
+ __m512i accum_data_v5;
+ __m512i accum_data_v6;
+ __m512i accum_data_v7;
+ __m512i accum_data_v8;
+ __m512i accum_data_v9;
+ __m512i accum_data_va;
+ __m512i accum_data_vb;
+ __m512i accum_data_vc;
+ __m512i accum_data_vd;
+ __m512i accum_data_ve;
+ __m512i accum_data_vf;
+
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+
+ // initial_accum_data will be used to initialize each of the
+ // accum_data_* accumulator registers. We compute into it terms that are
+ // identical across columns.
+ __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);
+
+ // In the channels-are-rows case, we can load bias here.
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
+ !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
+ initial_accum_data = _mm512_add_epi32(
+ initial_accum_data,
+ _mm512_loadu_si512(
+ reinterpret_cast<const __m512i*>(params.bias + row)));
+ }
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_si512(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ // Adjustments differing across columns.
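// --- [Editorial note: illustrative only, not part of this patch] ---
// The offsets folded into initial_accum_data above, and the per-column term
// subtracted just below, come from expanding the zero-point-corrected product
//   sum_d (lhs[row][d] - lhs_zp) * (rhs[d][col] - rhs_zp)
//     = sum_d lhs[row][d] * rhs[d][col]
//       - rhs_zp * lhs_sums[row] - lhs_zp * rhs_sums[col] + depth * lhs_zp * rhs_zp,
// so the inner loop only has to accumulate raw int8 products. A scalar sketch
// of the per-(row, col) starting value, assuming prod_zp_depth carries the
// depth * lhs_zp * rhs_zp term:
#include <cstdint>
inline std::int32_t InitialAccumModel(std::int32_t bias, std::int32_t prod_zp_depth,
                                      std::int32_t rhs_zp, std::int32_t lhs_sum_row,
                                      std::int32_t lhs_zp, std::int32_t rhs_sum_col) {
  return bias + prod_zp_depth - rhs_zp * lhs_sum_row - lhs_zp * rhs_sum_col;
}
// --- [End editorial note] ---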
+ if (has_rhs_sums_offsets) { + accum_data_v0 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); + accum_data_v8 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); + accum_data_v9 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); + accum_data_va = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); + accum_data_vb = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); + accum_data_vc = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); + accum_data_vd = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); + accum_data_ve = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); + accum_data_vf = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); + } else { + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + accum_data_v8 = initial_accum_data; + accum_data_v9 = initial_accum_data; + accum_data_va = initial_accum_data; + accum_data_vb = initial_accum_data; + accum_data_vc = initial_accum_data; + accum_data_vd = initial_accum_data; + accum_data_ve = initial_accum_data; + accum_data_vf = initial_accum_data; + } + + // Finally, in the channels-are-columns case, load bias data here. 
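// --- [Editorial note: illustrative only, not part of this patch] ---
// The idiom used heavily below, _mm512_permutexvar_epi32(_mm512_set1_epi32(j), v),
// broadcasts lane j of v to all 16 lanes. For the bias_data vector loaded from
// params.bias + col it is therefore equivalent to splatting the scalar:
#include <immintrin.h>
#include <cstdint>
inline __m512i BroadcastLaneModel(const std::int32_t* bias, int col, int j) {
  return _mm512_set1_epi32(bias[col + j]);  // same result, via a scalar reload
}
// Keeping bias_data in a register presumably just avoids 16 scalar reloads.
// --- [End editorial note] ---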
+ if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + const __m512i bias_data = _mm512_loadu_si512( + reinterpret_cast<const __m512i*>(params.bias + col)); + accum_data_v0 = _mm512_add_epi32( + accum_data_v0, + _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data)); + accum_data_v1 = _mm512_add_epi32( + accum_data_v1, + _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data)); + accum_data_v2 = _mm512_add_epi32( + accum_data_v2, + _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data)); + accum_data_v3 = _mm512_add_epi32( + accum_data_v3, + _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data)); + accum_data_v4 = _mm512_add_epi32( + accum_data_v4, + _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data)); + accum_data_v5 = _mm512_add_epi32( + accum_data_v5, + _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data)); + accum_data_v6 = _mm512_add_epi32( + accum_data_v6, + _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data)); + accum_data_v7 = _mm512_add_epi32( + accum_data_v7, + _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data)); + accum_data_v8 = _mm512_add_epi32( + accum_data_v8, + _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data)); + accum_data_v9 = _mm512_add_epi32( + accum_data_v9, + _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data)); + accum_data_va = _mm512_add_epi32( + accum_data_va, + _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data)); + accum_data_vb = _mm512_add_epi32( + accum_data_vb, + _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data)); + accum_data_vc = _mm512_add_epi32( + accum_data_vc, + _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data)); + accum_data_vd = _mm512_add_epi32( + accum_data_vd, + _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data)); + accum_data_ve = _mm512_add_epi32( + accum_data_ve, + _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data)); + accum_data_vf = _mm512_add_epi32( + accum_data_vf, + _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data)); + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += 4) { + const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); + __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + std::int32_t rhs_data[32]; + const __m256i rhs_data_bottom_lane = + _mm512_castsi512_si256(rhs_data_8bit); + const __m256i rhs_data_top_lane = + _mm512_extracti32x8_epi32(rhs_data_8bit, 1); + const __m512i rhs_16_bit_dup_low = + _mm512_cvtepi8_epi16(rhs_data_bottom_lane); + const __m512i rhs_16_bit_dup_high = + _mm512_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), + rhs_16_bit_dup_high); + + // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. + const __m512i lhs_16_bit_low = + _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); + // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. 
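// --- [Editorial sketch: illustrative only, not part of this patch] ---
// Assuming the packed LHS places, for every row, four consecutive depth values
// in one 32-bit group (byte k of the group holding depth d + k), the "bytes
// 0, 1" extraction above and the "bytes 2, 3" extraction just below pick out
// the depth pairs that feed the pairwise madds:
#include <cstdint>
inline void SplitPackedGroupModel(std::uint32_t group, std::int8_t pair01[2],
                                  std::int8_t pair23[2]) {
  pair01[0] = static_cast<std::int8_t>(group & 0xff);          // depth d + 0
  pair01[1] = static_cast<std::int8_t>((group >> 8) & 0xff);   // depth d + 1
  pair23[0] = static_cast<std::int8_t>((group >> 16) & 0xff);  // depth d + 2
  pair23[1] = static_cast<std::int8_t>((group >> 24) & 0xff);  // depth d + 3
}
// lhs_16_bit_low / lhs_16_bit_high hold these pairs sign-extended to 16 bits,
// ready for the _mm512_madd_epi16 accumulation in process_column below.
// --- [End editorial sketch] ---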
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( + _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); + + auto process_column = [=](int col, __m512i& accum) { + const __m512i rhs_16_bit_dup_low = + _mm512_set1_epi32(rhs_data[2 * col]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[2 * col + 1]); + + accum = _mm512_add_epi32( + accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum = _mm512_add_epi32( + accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + }; + process_column(0, accum_data_v0); + process_column(1, accum_data_v1); + process_column(2, accum_data_v2); + process_column(3, accum_data_v3); + process_column(4, accum_data_v4); + process_column(5, accum_data_v5); + process_column(6, accum_data_v6); + process_column(7, accum_data_v7); + process_column(8, accum_data_v8); + process_column(9, accum_data_v9); + process_column(10, accum_data_va); + process_column(11, accum_data_vb); + process_column(12, accum_data_vc); + process_column(13, accum_data_vd); + process_column(14, accum_data_ve); + process_column(15, accum_data_vf); + + lhs_ptr += 16 * 4; + rhs_ptr += 16 * 4; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + // The non-per-channel case could equivalently be handled in the per-row + // or per-column code path. The per-row code path is slightly more + // efficient so we handle it there. + const bool per_column_multiplier = + (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL); + + __m512i m_vector; + __m512i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( + params.multiplier_fixedpoint + multiplier_channel)); + e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( + params.multiplier_exponent + multiplier_channel)); + + const __m512i m_64bit_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); + const __m512i m_64bit_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); + + const __m512i zero_vector = _mm512_setzero_epi32(); + const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); + const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); + const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); + const __m512i final_right_shift = _mm512_set1_epi32(31); + const __m512i right_shift_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)); + const __m512i right_shift_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)); + const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 0)); + const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 1)); + + // A "half" added for rounding prior to truncation of 64-bit value. + const __m512i offset_vector = + _mm512_slli_epi64(_mm512_set1_epi64(1), 30); + + auto rounding_right_shift = [=](__m512i& results, + const __m512i& exponent) { + // Construct the "nudge" value for each lane if the exponent is + // greater than 0. Otherwise, the nudge is 0. 
+ const __m512i zeros = _mm512_setzero_si512(); + const __m512i mask_rightshift_gtz = + intrin_utils::mm512_cmpgt_epi64(exponent, zeros); + const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64( + _mm512_set1_epi64(1), + _mm512_sub_epi64(exponent, _mm512_set1_epi64(1))); + __m512i nudge = intrin_utils::mm512_blendv_epi64( + zeros, one_shift_exp_minus1, mask_rightshift_gtz); + // Calculate the shifted sum (results + nudge) >> exp. + const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge); + const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent); + + // Identify overflow in each lane and create mask. + const __m512i one_shift_31minus_exp = _mm512_sllv_epi64( + _mm512_set1_epi64(1), + _mm512_sub_epi64(_mm512_set1_epi64(31), exponent)); + const __m512i mask_num_plus_nudge_overflow = + intrin_utils::mm512_cmpgt_epi64( + results, + _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); + // Fill results with either (results + nudge) >> exponent or + // 1 << (31 - exp) in the case of overflow. + results = intrin_utils::mm512_blendv_epi64( + shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + }; + + if (per_column_multiplier) { + auto apply_multiplier = [=](__m512i& accum, int col) { + __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8); + // Apply the fixed-point part of the multiplier. + __m512i left_shift_val = + _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift); + __m512i m_64bit_val = _mm512_permutexvar_epi64( + perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high); + __m512i offset_vector_val = _mm512_permutexvar_epi64( + perm_64bit_vals, offset_vector); + __m512i final_right_shift_val = _mm512_permutexvar_epi64( + perm_64bit_vals, + col < 8 ? final_right_shift_low : final_right_shift_high); + __m512i right_shift_val = _mm512_permutexvar_epi64( + perm_64bit_vals, col < 8 ? 
right_shift_low : right_shift_high); + + accum = _mm512_sllv_epi32(accum, left_shift_val); + __m512i scaled_v_low = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)), + m_64bit_val); + __m512i scaled_v_high = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)), + m_64bit_val); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val); + + scaled_v_low = + _mm512_srav_epi64(scaled_v_low, final_right_shift_val); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_val); + + rounding_right_shift(scaled_v_low, right_shift_val); + rounding_right_shift(scaled_v_high, right_shift_val); + + accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum = _mm512_inserti32x8(accum, + _mm512_cvtepi64_epi32(scaled_v_high), 1); + }; + apply_multiplier(accum_data_v0, 0); + apply_multiplier(accum_data_v1, 1); + apply_multiplier(accum_data_v2, 2); + apply_multiplier(accum_data_v3, 3); + apply_multiplier(accum_data_v4, 4); + apply_multiplier(accum_data_v5, 5); + apply_multiplier(accum_data_v6, 6); + apply_multiplier(accum_data_v7, 7); + apply_multiplier(accum_data_v8, 8); + apply_multiplier(accum_data_v9, 9); + apply_multiplier(accum_data_va, 10); + apply_multiplier(accum_data_vb, 11); + apply_multiplier(accum_data_vc, 12); + apply_multiplier(accum_data_vd, 13); + apply_multiplier(accum_data_ve, 14); + apply_multiplier(accum_data_vf, 15); + } else { // not per-column, so per-row + auto apply_multiplier = [=](__m512i& accum) { + accum = _mm512_sllv_epi32(accum, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)), + m_64bit_low); + __m512i scaled_v_high = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector); + + scaled_v_low = + _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + rounding_right_shift(scaled_v_low, right_shift_low); + rounding_right_shift(scaled_v_high, right_shift_high); + accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum = _mm512_inserti32x8(accum, + _mm512_cvtepi64_epi32(scaled_v_high), 1); + }; + apply_multiplier(accum_data_v0); + apply_multiplier(accum_data_v1); + apply_multiplier(accum_data_v2); + apply_multiplier(accum_data_v3); + apply_multiplier(accum_data_v4); + apply_multiplier(accum_data_v5); + apply_multiplier(accum_data_v6); + apply_multiplier(accum_data_v7); + apply_multiplier(accum_data_v8); + apply_multiplier(accum_data_v9); + apply_multiplier(accum_data_va); + apply_multiplier(accum_data_vb); + apply_multiplier(accum_data_vc); + apply_multiplier(accum_data_vd); + apply_multiplier(accum_data_ve); + apply_multiplier(accum_data_vf); + } + + if (params.dst_zero_point != 0) { + __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); + accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); + accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); + accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); + accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); + accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); + accum_data_v5 = 
_mm512_add_epi32(accum_data_v5, dst_zero_point); + accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); + accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); + accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); + accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); + accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); + accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); + accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); + accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); + accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); + accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); + } + } + + const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); + const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); + + const bool store_full_block = + (residual_rows == 16) && (residual_cols == 16); + + __m512i accum_data_v[16]; + + // In most cases we would make this conditional on (!store_full_block) and + // unwind the clamp-and-store loop, but the benefit appears small. + { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + accum_data_v[8] = accum_data_v8; + accum_data_v[9] = accum_data_v9; + accum_data_v[10] = accum_data_va; + accum_data_v[11] = accum_data_vb; + accum_data_v[12] = accum_data_vc; + accum_data_v[13] = accum_data_vd; + accum_data_v[14] = accum_data_ve; + accum_data_v[15] = accum_data_vf; + } + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + const int block_col_offset = dst_stride; + if (store_full_block) { + for (int j = 0; j < 16; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset), + _mm512_cvtepi32_epi8(result)); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, + _mm512_cvtepi32_epi8(result)); + } + } + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + const int block_col_offset = dst_stride; + if (store_full_block) { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset), + _mm512_cvtepi32_epi8(result)); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, + _mm512_cvtepi32_epi8(result)); + } + } + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = 
static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+ for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
+ _mm512_cvtepi32_epi16(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi16(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < 16; ++j) {
+ _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
+ }
+ } else {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
+ accum_data_v[j]);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += 16 * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ 16 * params.dst_stride);
+ rhs_col_ptr += 16 * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_si512(&params.rhs_sums[0]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+
+ __m512i accum_data_v0;
+
+ // Initialize with bias.
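// --- [Editorial sketch: illustrative only, not part of this patch] ---
// The row_mask constructed just below has one bit per valid output row, so a
// partial (residual) block can be written with a single masked store instead
// of a scalar tail loop; e.g. residual_rows == 5 gives row_mask == 0x001f.
// A minimal model of the same idea for an int32 destination:
#include <immintrin.h>
#include <cstdint>
inline void MaskedStoreRowsModel(std::int32_t* dst, __m512i data,
                                 int residual_rows) {
  const __mmask16 mask =
      (static_cast<std::uint32_t>(1) << residual_rows) - 1;  // residual_rows <= 16
  _mm512_mask_storeu_epi32(dst, mask, data);  // lanes >= residual_rows left untouched
}
// --- [End editorial sketch] ---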
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ __m512i initial_accum_data =
+ _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
+ bias_ptr += bias_ptr_block_increment;
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_si512(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth != 0) {
+ initial_accum_data = _mm512_add_epi32(initial_accum_data,
+ _mm512_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
+ _mm512_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += 4) {
+ const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
+ const __m128i rhs_data_8bit =
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ // For simplicity we load 4x the data that we need and process twice the
+ // data that we need and store only the data we need.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_low =
+ _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+ // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+ _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+ // Process column 0.
+ __m512i accum_v = accum_data_v0;
+ constexpr int index = 0;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v0 = accum_v;
+
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 4;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m512i m_vector;
+ __m512i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ?
row : 0; + m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( + params.multiplier_fixedpoint + channel)); + e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( + params.multiplier_exponent + channel)); + + const __m512i m_64bit_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); + const __m512i m_64bit_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); + + const __m512i zero_vector = _mm512_setzero_epi32(); + const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); + const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); + const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); + const __m512i final_right_shift = _mm512_set1_epi32(31); + const __m512i right_shift_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)); + const __m512i right_shift_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)); + const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 0)); + const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 1)); + + // A "half" added for rounding prior to truncation of 64-bit value. + const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); + + auto rounding_right_shift = [=](__m512i& results, + const __m512i& exponent) { + // Construct the "nudge" value for each lane if the exponent is + // greater than 0. Otherwise, the nudge is 0. + const __m512i zeros = _mm512_setzero_si512(); + const __m512i mask_rightshift_gtz = + intrin_utils::mm512_cmpgt_epi64(exponent, zeros); + const __m512i one_shift_exp_minus1 = + _mm512_sllv_epi64(_mm512_set1_epi64(1), + _mm512_sub_epi64(exponent, _mm512_set1_epi64(1))); + __m512i nudge = intrin_utils::mm512_blendv_epi64( + zeros, one_shift_exp_minus1, mask_rightshift_gtz); + // Calculate the shifted sum (results + nudge) >> exp. + const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge); + const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent); + + // Identify overflow in each lane and create mask. + const __m512i one_shift_31minus_exp = _mm512_sllv_epi64( + _mm512_set1_epi64(1), + _mm512_sub_epi64(_mm512_set1_epi64(31), exponent)); + const __m512i mask_num_plus_nudge_overflow = + intrin_utils::mm512_cmpgt_epi64( + results, + _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); + // Fill results with either (results + nudge) >> exponent or + // 1 << (31 - exp) in the case of overflow. + results = intrin_utils::mm512_blendv_epi64( + shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); + }; + + // Shift and round column 0. + accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), + m_64bit_low); + __m512i scaled_v_high = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + rounding_right_shift(scaled_v_low, right_shift_low); + rounding_right_shift(scaled_v_high, right_shift_high); + + accum_data_v0 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v0 = _mm512_inserti32x8( + accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); + + if (params.dst_zero_point != 0) { + __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); + accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); + } + } + + const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); + const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm256_mask_storeu_epi16(tmp_ptr, row_mask, + _mm512_cvtepi32_epi16(result)); + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); + _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += 16 * params.lhs_stride; + } // End row-block loop. +} // NOLINT(readability/fn_size) + +void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 float"); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + const std::int64_t dst_stride = params.dst_stride >> 2; + const std::int64_t rhs_stride = params.rhs_stride >> 2; + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
1 : 0; + const int end_row = std::min(params.dst_rows, params.last_row + 16); + const int end_col = std::min(params.dst_cols, params.last_col + 16); + + const float* adj_rhs_col_ptr = + params.rhs_base_ptr - params.start_col * rhs_stride; + float* adj_dst_col_ptr = + params.dst_base_ptr - params.start_col * dst_stride - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_ptr = params.bias; + + const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); + const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); + const bool channel_dimension_is_col = + params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; + + int col = params.start_col; + for (; col <= end_col - 16; col += 16) { + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + int row = params.start_row; + for (; row <= end_row - 16; row += 16) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + + // Process block in two halves, split by columns. + { + constexpr int mmm = 0; + + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + // In this version RHS values are loaded individually rather than + // first loading together and then extract with broadcasting. This is + // because AVX flavours and instrinsics and compilers in combination + // do not handle this pattern of extraction very well. + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + + { + // Load 8 float32 values. 
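// --- [Editorial note: illustrative only, not part of this patch] ---
// The shuffle/permute sequence that follows builds, for each j in 0..7, a
// vector with rhs_data[j] replicated across all 16 lanes; functionally each
// dup_rhs_element_j matches _mm512_set1_ps(rhs_data[j]), built from registers
// because (per the comment above) per-element extraction tends to compile
// poorly. Each FMA then advances one output column by one depth step. A
// scalar-broadcast model of one such step (hypothetical helper, not a ruy API):
#include <immintrin.h>
inline __m512 ColumnFmaStepModel(__m512 accum, __m512 lhs_16_rows, float rhs) {
  // One depth step for one output column: accum += lhs * broadcast(rhs).
  return _mm512_fmadd_ps(lhs_16_rows, _mm512_set1_ps(rhs), accum);
}
// --- [End editorial note] ---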
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + // Load 8 float32 values. + __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * 
dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, first iteration. + { + constexpr int mmm = 1; + + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + // Load 8 float32 values. 
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + // Load 8 float32 values. + __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * 
dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, second iteration. + } // End row-block loop. + + // The unrolling within this conditional may be somewhat pointless. It + // depends on the kinds of models. + if (row < end_row) { + const int residual_rows = end_row - row; + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + + const __mmask16 row_mask = + (static_cast<std::uint32_t>(1) << residual_rows) - 1; + + // Process block in two halves, split by columns. + for (int mmm = 0; mmm < 2; ++mmm) { + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + // Load 8 float32 values. 
+ __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + // Load 8 float32 values. + __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data)); + __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4 + __m512 rhs4_7 = + _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4 + const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, + accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask, + accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + 
_mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, + accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, + accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, + accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, + accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, + accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, + accum_data_v7); + } + } + } // Inner half-block loop. + } // Residual rows, main col-block loop. + } // End col-block loop. + + if (col < end_col) { + RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 16); + + __m512 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + for (int row = params.start_row; row < end_row; row += 16) { + const int residual_rows = std::min(end_row - row, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + + const __mmask16 row_mask = + (static_cast<std::uint32_t>(1) << residual_rows) - 1; + + // Process block in two halves, split by columns. + for (int mmm = 0; mmm < 2; ++mmm) { + // Initialize with bias. 
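+ // When the channel dimension is the column dimension, each of the 8
+ // accumulator columns in this half-block gets its own scalar bias broadcast
+ // across 16 rows; otherwise a single 16-row bias vector is shared by all 8
+ // columns.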
+ if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]); + } + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 16; + rhs_ptr += 16; + } + + const int residual_cols = std::min(end_col - col - 8 * mmm, 8); + + if (residual_rows == 16) { + if (residual_cols == 8) { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); + } + } + } // Inner half-block loop. + } // End row-block loop. + } // Residual cols. +} + +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + const int end_row = std::min(params.dst_rows, params.last_row + 16); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); + const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); + + __m512 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 16; row += 16) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. 
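+ // If RUY_ASM_FLAG_HAS_BIAS is not set, bias_ptr_block_increment is 0, so
+ // bias_ptr stays at params.bias, which then points at the zero_data array.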
+ accum_data_v = _mm512_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. + + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast<std::uint32_t>(1) << residual_rows) - 1; + accum_data_v = _mm512_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); + } // End handling of residual rows. +} + +#endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) + +} // namespace ruy diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h new file mode 100644 index 0000000..9509b8f --- /dev/null +++ b/ruy/kernel_common.h @@ -0,0 +1,287 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_KERNEL_COMMON_H_ +#define RUY_RUY_KERNEL_COMMON_H_ + +#include <algorithm> +#include <cstdint> +#include <type_traits> + +#include "ruy/apply_multiplier.h" +#include "ruy/check_macros.h" +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/mul_params.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/tune.h" + +namespace ruy { + +template <Path ThePath, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +struct Kernel; + +#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ + template <typename LhsScalar, typename RhsScalar, typename DstScalar, \ + typename AccumScalar> \ + struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar> \ + : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \ + explicit Kernel(Tuning tuning) \ + : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \ + tuning) {} \ + }; + +// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. +// +// In other cases, we still define (empty) versions, so that dummy kernels +// can use the classes in function signatures. +#if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \ + RUY_PLATFORM_X86 + +#define RUY_ASM_FLAG_HAS_BIAS 0x1 +#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 +#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 +#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 +#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 +#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20 + +#define RUY_ASM_TYPE_ID_UINT8 1 +#define RUY_ASM_TYPE_ID_INT8 2 +#define RUY_ASM_TYPE_ID_INT16 3 +#define RUY_ASM_TYPE_ID_INT32 4 + +template <typename DstScalar> +struct DstTypeId {}; + +template <> +struct DstTypeId<std::uint8_t> { + static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; +}; + +template <> +struct DstTypeId<std::int8_t> { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; +}; + +template <> +struct DstTypeId<std::int16_t> { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; +}; + +template <> +struct DstTypeId<std::int32_t> { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; +}; + +template <int LhsCols, int RhsCols> +struct KernelParams8bit { + static constexpr int kMaxDstTypeSize = 4; + + const std::int32_t* bias; + const std::int32_t* lhs_sums; + const std::int32_t* rhs_sums; + const std::int8_t* lhs_base_ptr; + const std::int32_t* multiplier_fixedpoint; + const std::int32_t* multiplier_exponent; + const std::int8_t* rhs_base_ptr; + void* dst_base_ptr; + std::int32_t lhs_zero_point; + std::int32_t rhs_zero_point; + std::int32_t dst_zero_point; + std::int32_t prod_zp_depth; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + std::int32_t clamp_min; + std::int32_t clamp_max; + std::uint8_t flags; + std::uint8_t dst_type_id; + const std::int32_t zero_data[LhsCols] = {0}; + std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; + std::int32_t multiplier_fixedpoint_buf[LhsCols]; + std::int32_t multiplier_exponent_buf[LhsCols]; +}; + +template <typename DstScalar, int LhsCols, int RhsCols> +void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, + const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, 
DstScalar>& mul_params, + int start_row, int start_col, int end_row, + int end_col, Mat<DstScalar>* dst, + KernelParams8bit<LhsCols, RhsCols>* params) { + using Params = KernelParams8bit<LhsCols, RhsCols>; + + static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); + + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->flags = 0; + params->bias = params->zero_data; + if (mul_params.bias()) { + params->bias = mul_params.bias(); + params->flags |= RUY_ASM_FLAG_HAS_BIAS; + } + if (lhs.sums) { + params->lhs_sums = lhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; + } + if (rhs.sums) { + params->rhs_sums = rhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; + } + if (mul_params.channel_dimension() == ChannelDimension::kCol) { + params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; + } + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = lhs.layout.stride; + params->rhs_stride = rhs.layout.stride; + params->dst_stride = sizeof(DstScalar) * dst->layout.stride; + params->lhs_zero_point = lhs.zero_point; + params->rhs_zero_point = rhs.zero_point; + params->dst_zero_point = dst->zero_point; + params->depth = depth; + params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; + params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; + if (mul_params.multiplier_fixedpoint_perchannel()) { + params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; + params->multiplier_fixedpoint = + mul_params.multiplier_fixedpoint_perchannel(); + params->multiplier_exponent = mul_params.multiplier_exponent_perchannel(); + } else { + params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; + params->multiplier_exponent = params->multiplier_exponent_buf; + for (int i = 0; i < LhsCols; i++) { + params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint(); + params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent(); + } + } + params->clamp_min = mul_params.clamp_min(); + params->clamp_max = mul_params.clamp_max(); + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); + + params->dst_type_id = DstTypeId<DstScalar>::kValue; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; +} + +template <int LhsCols, int RhsCols> +struct KernelParamsFloat { + const float* lhs_base_ptr; + const float* rhs_base_ptr; + float* dst_base_ptr; + const float* bias; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + float clamp_min; + float clamp_max; + std::uint8_t flags; + const float zero_data[LhsCols] = {0}; + float dst_tmp_buf[LhsCols * RhsCols]; +}; + +template <int LhsCols, int RhsCols> +inline void MakeKernelParamsFloat(const PMat<float>& lhs, + const PMat<float>& rhs, + const MulParams<float, float>& mul_params, + int start_row, int start_col, int end_row, + int end_col, Mat<float>* dst, + KernelParamsFloat<LhsCols, RhsCols>* 
params) { + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; + + std::uint8_t flags = 0; + params->bias = params->zero_data; + if (mul_params.bias()) { + params->bias = mul_params.bias(); + flags |= RUY_ASM_FLAG_HAS_BIAS; + } + if (mul_params.channel_dimension() == ChannelDimension::kCol) { + flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; + } + params->flags = flags; + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = sizeof(float) * lhs.layout.stride; + params->rhs_stride = sizeof(float) * rhs.layout.stride; + params->dst_stride = sizeof(float) * dst->layout.stride; + params->depth = depth; + params->clamp_min = mul_params.clamp_min(); + params->clamp_max = mul_params.clamp_max(); + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); +} + +#else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && + // RUY_OPT(ASM)) || RUY_PLATFORM_X86 + +template <int LhsCols, int RhsCols> +struct KernelParams8bit {}; + +template <int LhsCols, int RhsCols> +struct KernelParamsFloat {}; + +#endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && + // RUY_OPT(ASM)) || RUY_PLATFORM_X86 + +} // namespace ruy + +#endif // RUY_RUY_KERNEL_COMMON_H_ diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h new file mode 100644 index 0000000..2f8fe19 --- /dev/null +++ b/ruy/kernel_x86.h @@ -0,0 +1,874 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_KERNEL_X86_H_ +#define RUY_RUY_KERNEL_X86_H_ + +#include <cstdint> +#include <cstring> + +#include "ruy/kernel_common.h" +#include "ruy/mat.h" +#include "ruy/mul_params.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 + +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx) +RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512) + +void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); +void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); + +template <typename DstScalar> +struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kAvx512; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx512SingleCol(params); + } else { + Kernel8bitAvx512(params); + } + } +}; + +void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); + +template <> +struct Kernel<Path::kAvx512, float, float, float, float> { + static constexpr Path kPath = Path::kAvx512; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + KernelFloatAvx512SingleCol(params); + } else { + KernelFloatAvx512(params); + } + } +}; + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); + +template <typename DstScalar> +struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, + DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if 
(dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx2SingleCol(params); + } else { + Kernel8bitAvx2(params); + } + } +}; + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); + +template <> +struct Kernel<Path::kAvx2Fma, float, float, float, float> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + KernelFloatAvx2SingleCol(params); + } else { + KernelFloatAvx2(params); + } + } +}; + +void KernelFloatAvx(const KernelParamsFloat<8, 8>& params); +void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params); + +template <> +struct Kernel<Path::kAvx, float, float, float, float> { + static constexpr Path kPath = Path::kAvx; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<float>& lhs, const PMat<float>& rhs, + const MulParams<float, float>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<float>* dst) const { + KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + KernelFloatAvxSingleCol(params); + } else { + KernelFloatAvx(params); + } + } +}; + +void Kernel8bitAvx(const KernelParams8bit<8, 8>& params); +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params); + +template <typename DstScalar> +struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvxSingleCol(params); + } else { + Kernel8bitAvx(params); + } + } +}; + +#endif // RUY_PLATFORM_X86 +} // namespace ruy + +#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)) + +#include <immintrin.h> // IWYU pragma: keep + +namespace ruy { +namespace { +namespace intrin_utils { + +// Defined as a template so clang won't detect it as an uneeded +// definition. 
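+// mm256_get1_ps below extracts element i of a float vector by reinterpreting
+// it as integers, extracting the lane with _mm256_extract_epi32, and copying
+// the bits back into a float with std::memcpy, since AVX has no direct
+// per-element float extract intrinsic.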
+template <Path path> +inline float mm256_get1_ps(const __m256 a, int i) { + __m256i ai = _mm256_castps_si256(a); + int float_val_as_int; + switch (i) { + case 0: + float_val_as_int = _mm256_extract_epi32(ai, 0); + break; + case 1: + float_val_as_int = _mm256_extract_epi32(ai, 1); + break; + case 2: + float_val_as_int = _mm256_extract_epi32(ai, 2); + break; + case 3: + float_val_as_int = _mm256_extract_epi32(ai, 3); + break; + case 4: + float_val_as_int = _mm256_extract_epi32(ai, 4); + break; + case 5: + float_val_as_int = _mm256_extract_epi32(ai, 5); + break; + case 6: + float_val_as_int = _mm256_extract_epi32(ai, 6); + break; + case 7: + float_val_as_int = _mm256_extract_epi32(ai, 7); + break; + default: + RUY_DCHECK_LT(i, 8); + return .0f; + } + float float_val; + std::memcpy(&float_val, &float_val_as_int, sizeof(float_val)); + return float_val; +} + +// Defined as a template so clang won't detect it as an uneeded +// definition. +template <Path path> +inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { + for (int i = 0; i < residual_rows; ++i) { + dst[i] = intrin_utils::mm256_get1_ps<path>(v, i); + } +} + +template <Path path> +inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) { + // Specializations added for AVX and AVX2FMA paths in their respective kernel + // files. + RUY_DCHECK(false); + return _mm256_set1_ps(0); +} + +template <Path path> +inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) { + // Specializations added for AVX and AVX2FMA paths in their respective kernel + // files. + RUY_DCHECK(false); + return _mm256_set1_epi32(0); +} + +// Polyfill for _mm_storeu_si16(dst, v). +template <Path path> +inline void mm_storeu_si16(void* dst, __m128i v) { +#if (defined __clang__) || (defined _MSC_VER) + _mm_storeu_si16(dst, v); +#else + // GCC 9 lacks support for __mm_storeu_si16. + *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0); +#endif +} + +// Polyfill for _mm_storeu_si32(dst, v). +template <Path path> +inline void mm_storeu_si32(void* dst, __m128i v) { +#if (defined __clang__) || (defined _MSC_VER) + _mm_storeu_si32(dst, v); +#else + // GCC 9 lacks support for __mm_storeu_si32. + *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0); +#endif +} + +// Polyfill for _mm_loadu_si32(src). +template <Path path> +inline __m128i mm_loadu_si32(const void* src) { +#if (defined __clang__) || (defined _MSC_VER) + return _mm_loadu_si32(src); +#else + // GCC 9 lacks support for _mm_loadu_si32. + __m128i res; + asm("movss %[src], %[res]" + : [res] "=x"(res) + : [src] "m"(*static_cast<const int*>(src))); + return res; +#endif +} + +template <Path path> +inline __m128i mm256_extracti128_si256(const __m256i&, const int) { + RUY_DCHECK(false); + return _mm_setzero_si128(); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, + const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + __m256i shuffled_v; + if (residual_rows > 1) { + // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 + // in each 128-bit lane. 
+ shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm); + } + switch (residual_rows) { + case 0: + break; + case 1: + dst[0] = _mm256_extract_epi8(v, 0); + break; + case 2: + mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + break; + case 3: { + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0); + mm_storeu_si16<path>(dst, trailing_packed); + dst[2] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 4: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + break; + case 5: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + dst[4] = _mm256_extract_epi8(shuffled_v, 16); + break; + case 6: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si16<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + case 7: { + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); + mm_storeu_si16<path>(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 8: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, + const __m256i v) { + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + reinterpret_cast<std::uint8_t*>(dst), residual_rows, v); +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, + const __m256i v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. 
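+ // Reading the 64-bit shuffle constant below from its low byte upward gives
+ // exactly those byte indices (0x00, 0x01, 0x04, 0x05, 0x08, 0x09, 0x0c,
+ // 0x0d), i.e. the low 16 bits of each 32-bit element.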
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + __m256i shuffled_v; + __m128i shuffled_v_low; + if (residual_rows > 1) { + shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0); + } else { + shuffled_v_low = mm256_extracti128_si256<path>(v, 0); + } + switch (residual_rows) { + case 0: + break; + case 1: + mm_storeu_si16<path>(dst, shuffled_v_low); + break; + case 2: + mm_storeu_si32<path>(dst, shuffled_v_low); + break; + case 3: { + mm_storeu_si32<path>(dst, shuffled_v_low); + dst[2] = _mm_extract_epi16(shuffled_v_low, 2); + break; + } + case 4: + _mm_storeu_si64(dst, shuffled_v_low); + break; + case 5: + _mm_storeu_si64(dst, shuffled_v_low); + dst[4] = _mm256_extract_epi16(shuffled_v, 8); + break; + case 6: + _mm_storeu_si64(dst, shuffled_v_low); + mm_storeu_si32<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si64(dst, shuffled_v_low); + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); + mm_storeu_si32<path>(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi16(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. + const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, + const __m256i v) { + const __m128i v_low = mm256_extracti128_si256<path>(v, 0); + switch (residual_rows) { + case 0: + break; + case 1: + mm_storeu_si32<path>(dst, v_low); + break; + case 2: + _mm_storeu_si64(dst, v_low); + break; + case 3: { + __m128i trailing_packed = v_low; + _mm_storeu_si64(dst, trailing_packed); + dst[2] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + break; + case 5: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + dst[4] = _mm256_extract_epi32(v, 4); + break; + case 6: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1)); + break; + case 7: { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1); + _mm_storeu_si64(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 8: + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} + +// Transpose a 8x8 matrix of floats. 
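+// The transpose below proceeds in three stages: unpacklo/unpackhi interleave
+// adjacent row pairs, shuffle_ps assembles transposed 4x4 blocks within each
+// 128-bit lane, and permute2f128 finally exchanges the 128-bit halves between
+// lanes.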
+template <Path path> +void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3, + __m256* v4, __m256* v5, __m256* v6, __m256* v7) { + __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1); + __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1); + __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3); + __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3); + __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5); + __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5); + __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7); + __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7); + __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2)); + *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20); + *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20); + *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20); + *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20); + *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31); + *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31); + *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31); + *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31); +} + +// Transpose a 8x8 matrix of int32's. +template <Path path> +void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2, + __m256i* v3, __m256i* v4, __m256i* v5, + __m256i* v6, __m256i* v7) { + mm256_transpose8x8_ps<path>( + reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1), + reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3), + reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5), + reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7)); +} + +} // namespace intrin_utils +} // namespace + +template <Path path> +inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) { + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + const std::int64_t dst_stride = params.dst_stride >> 2; + const std::int64_t rhs_stride = params.rhs_stride >> 2; + // + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + // AVX2 float block size = 8. 
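+ // params.last_row / params.last_col are the start coordinates of the last
+ // (possibly partial) blocks, so adding the block size and clamping against
+ // dst_rows / dst_cols gives the extent this kernel call has to cover.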
+ const int end_row = std::min(params.dst_rows, params.last_row + 8); + const int end_col = std::min(params.dst_cols, params.last_col + 8); + // + const float* adj_rhs_col_ptr = + params.rhs_base_ptr - params.start_col * rhs_stride; + float* adj_dst_col_ptr = + params.dst_base_ptr - params.start_col * dst_stride - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_ptr = params.bias; + + const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); + const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); + const bool channel_dimension_is_col = + params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; + + int col = params.start_col; + // Loop through cols by float block size, leaving incomplete remainder + for (; col <= end_col - 8; col += 8) { + __m256 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + for (int row = params.start_row; row < end_row; row += 8) { + const int residual_rows = std::min(end_row - row, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment; + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j); + } + } else { + const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment; + const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr); + + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + // Load 8 RHS values, then use permute instructions to + // broadcast each value to a register. 
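+ // _mm256_permute2f128_ps with control 0 duplicates the low 128-bit half and
+ // with 17 (0x11) the high half; _mm256_permute_ps with controls 0, 85, 170
+ // and 255 then broadcasts one of the four floats within each half.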
+ __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7] + __m256 rhs0_3 = + _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3] + __m256 rhs4_7 = + _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7] + + const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); + accum_data_v[0] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_0, accum_data_v[0]); + + const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); + accum_data_v[1] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_1, accum_data_v[1]); + + const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); + accum_data_v[2] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_2, accum_data_v[2]); + + const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); + accum_data_v[3] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_3, accum_data_v[3]); + + const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); + accum_data_v[4] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_4, accum_data_v[4]); + + const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); + accum_data_v[5] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_5, accum_data_v[5]); + + const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); + accum_data_v[6] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_6, accum_data_v[6]); + + const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); + accum_data_v[7] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_7, accum_data_v[7]); + + lhs_ptr += 8; + rhs_ptr += 8; + } + + if (residual_rows == 8) { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + _mm256_storeu_ps(block_ptr, accum_data_v[j]); + } + } else { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows, + accum_data_v[j]); + } + } + } // End row-block loop. + } // End col-block loop. + + if (col < end_col) { + // Remaining cols in [0, float block size). + RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 8); + + __m256 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + const int residual_cols = std::min(end_col - col, 8); + + for (int row = params.start_row; row < end_row; row += 8) { + const int residual_rows = std::min(end_row - row, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + + // Initialize with bias. 
+ if (channel_dimension_is_col) { + const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment; + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j); + } + } else { + const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment; + const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr); + + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7] + __m256 rhs0_3 = + _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3] + __m256 rhs4_7 = + _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7] + + const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); + accum_data_v[0] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_0, accum_data_v[0]); + + const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); + accum_data_v[1] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_1, accum_data_v[1]); + + const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); + accum_data_v[2] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_2, accum_data_v[2]); + + const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); + accum_data_v[3] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_3, accum_data_v[3]); + + const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); + accum_data_v[4] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_4, accum_data_v[4]); + + const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); + accum_data_v[5] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_5, accum_data_v[5]); + + const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); + accum_data_v[6] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_6, accum_data_v[6]); + + const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); + accum_data_v[7] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_7, accum_data_v[7]); + + lhs_ptr += 8; + rhs_ptr += 8; + } + + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows, + accum_data_v[j]); + } + } // End row-block loop. + } // End col-block terminal conditional. +} + +template <Path path> +inline void KernelFloatAvxCommonSingleCol( + const KernelParamsFloat<8, 8>& params) { + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + // + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + // AVX2 float block size = 8. 
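+ // GEMV case: there is a single destination column (checked above), so no
+ // column loop is needed and only the row extent has to be computed.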
+ const int end_row = std::min(params.dst_rows, params.last_row + 8); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); + const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); + + __m256 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 8; row += 8) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = _mm256_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + int d = 0; + for (; d <= params.depth - 4; d += 4) { + const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); + const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); + accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0, + accum_data_v); + const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); + const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); + accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1, + accum_data_v); + + const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); + const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); + accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2, + accum_data_v); + const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); + const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); + accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3, + accum_data_v); + lhs_ptr += 32; // Loaded 8 * 4 floats. + rhs_ptr += 32; + } + for (; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = + intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + _mm256_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. + + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = _mm256_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = + intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v); + } // End handling of residual rows. 
+} +} // namespace ruy +#endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM) + +#endif // RUY_RUY_KERNEL_X86_H_ diff --git a/ruy/mat.h b/ruy/mat.h new file mode 100644 index 0000000..587b208 --- /dev/null +++ b/ruy/mat.h @@ -0,0 +1,492 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Internal types and helpers for matrices. +// "Mat" is the name we use to refer to our internal matrix classes; it can be +// thought of as a shorter version of "InternalMatrix"` +// +// Ruy has four internal matrix classes, besides the +// Matrix<T> class that we expose to the user-facing API. +// +// TODO(silvasean): Put parts of this architecture description somewhere more +// prominent. +// +// The 4 internal matrix classes are named Mat, EMat, PMat, PEMat, where: +// - "E" indicates a type-erased class, storing a void* pointer and a runtime +// enum value to track the scalar type, as opposed to being templatized +// on a Scalar type and storing a Scalar* pointer. +// - "P" indicates a packed matrix class, the output of the packing code and +// input of the kernel code. See comments in pack.h regarding packing. +// +// In other words: +// +// Plain matrices Packed matrices +// +---------------------------------- +// Templated | Mat, Matrix PMat +// Type-erased | EMat PEMat +// +// Note that Matrix<T> is *not* implemented in terms of the internal types. It +// is an independent, simple, and user-facing type. Matrix<T> is functionally +// equivalent to Mat, but we keep it separate to insulate internals from +// interface and to be able to make different compromises in internals +// vs interface: in internals we prefer Mat to be a C-style struct with +// raw data member access and to be similar to the other PMat/EMat/PEMat +// classes for consistency. +// +// The use of type-erasure might seem surprising for a library like Ruy with a +// heavily-templated entry point, but it is motivated by the desire for most of +// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 +// main parts: +// - "entry-point" (ruy.h) - this is the highly templated ruy::Mul entry +// point. +// - "front-end" (frontend.*, validate.*, create_trmul_params.*, +// prepare_packed_matrices.*) - the work to handle the entry-point call down to +// the point where it can be handed off to the middle/back ends below. That +// includes routines that select RunKernel and RunPack +// implementations statically based on those template parameters. +// - "back-end" (kernel_*.*, pack_*.*)- this consists of the implementations of +// RunKernel and RunPack, often in assembly code, which are the building blocks +// that Ruy calls to perform matrix multiplication. These are templated so that +// only the requested types/Path's are actually emitted by the compiler. +// - "middle-end" (trmul.*) - this is the part of Ruy that orchestrates the +// calls to the "back-end" optimized building blocks. 
This layer has to deal +// with issues like cache locality and low-overhead multi-threading. +// +// There is a desire for the "middle-end" to be non-templated in order to +// simplify the implementation and reduce code-size. We type-erase when going +// from the "front-end" to the "middle-end", and un-type-erase going from the +// "middle-end" to the "back-end". The un-type-erasure is possible because the +// "front-end" is responsible for instantiating the needed "back-end" templates, +// and thus the static type information is still present. +// +// Each layer of Ruy uses matrix types: +// - "entry-point": Matrix<T> +// - "front-end": Mat +// - "middle-end": EMat, PEMat +// - "back-end": Mat, PMat +// +// The use of separate types for packed matrices is not essential, but makes it +// obvious at a glance whether a matrix is a packed matrix or not. We would +// reconsider this decision if there was significant duplication between packed +// and unpacked matrices, but that doesn't seem to be the case at the moment. +// +// Another goal is to keep the user-facing Matrix<T> as simple and +// understandable as possible. Ideally, a user should be able to read the struct +// definition for Matrix<T> and see a very simple definition with no internal +// details like sums and kernel block layout. + +#ifndef RUY_RUY_INTERNAL_MATRIX_H_ +#define RUY_RUY_INTERNAL_MATRIX_H_ + +#include <cstddef> +#include <cstdint> +#include <type_traits> +#include <utility> + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/size_util.h" + +namespace ruy { + +// Internal counterpart of Layout, used by Mat. +struct MatLayout final { + std::int32_t rows = 0; + std::int32_t cols = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. + std::int32_t stride = 0; + Order order = Order::kColMajor; +}; + +inline MatLayout ToInternal(const Layout& src) { + MatLayout ret; + ret.rows = src.rows(); + ret.cols = src.cols(); + ret.stride = src.stride(); + ret.order = src.order(); + return ret; +} + +// Internal counterpart of Matrix +template <typename Scalar> +struct Mat final { + detail::ConstCheckingPtr<Scalar> data; + MatLayout layout; + Scalar zero_point = 0; + CachePolicy cache_policy = CachePolicy::kNeverCache; +}; + +template <typename Scalar> +inline Mat<Scalar> ToInternal(const Matrix<Scalar>& src) { + Mat<Scalar> ret; + ret.data.set(src.data()); + ret.layout = ToInternal(src.layout()); + ret.zero_point = src.zero_point(); + ret.cache_policy = src.cache_policy(); + return ret; +} + +template <typename Scalar> +inline Mat<Scalar> ToInternal(Matrix<Scalar>& src) { + Mat<Scalar> ret; + ret.data.set(src.data()); + ret.layout = ToInternal(src.layout()); + ret.zero_point = src.zero_point(); + ret.cache_policy = src.cache_policy(); + return ret; +} + +// KernelLayout describes small-scale block structure in a packed matrix layout. +// It's a runtime (as opposed to compile-time-constant) version of the +// FixedKernelLayout struct used to declare kernel layouts. +// +// This is is sometimes known as "tiling" in other contexts. +// +// For example, consider a packed matrix in column-major format with a +// column-major KernelLayout. The matrix logically has a shape of +// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array +// of shape `[cols / kcols, rows / krows, kcols, krows]`. 
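For concreteness, here is a small sketch of the flat offset that this 4D view implies for an element (row, col), under the assumptions that both the matrix order and the KernelLayout are column-major, that the matrix is unstrided, and that rows is a multiple of krows and cols a multiple of kcols; the helper name is illustrative and not part of ruy:

// Offset of element (row, col) in a column-major packed matrix with a
// column-major krows x kcols KernelLayout, viewed as the 4D array
// [cols / kcols][rows / krows][kcols][krows] described above.
inline int TiledOffset(int row, int col, int rows, int krows, int kcols) {
  const int block_row = row / krows;   // index into the rows/krows dimension
  const int block_col = col / kcols;   // index into the cols/kcols dimension
  const int inner_row = row % krows;   // position inside the small block
  const int inner_col = col % kcols;
  return ((block_col * (rows / krows) + block_row) * kcols + inner_col) * krows +
         inner_row;
}

For that special case this agrees with the runtime Offset(const PMatLayout&, int, int) helper defined further down in this file; with krows == kcols == 1 it reduces to ordinary column-major indexing, col * rows + row.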
+// +// Note that in the case of kcols=1, krows=1, this degenerates to +// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block +// structure. +struct KernelLayout final { + Order order = Order::kColMajor; + std::uint8_t rows = 1; + std::uint8_t cols = 1; +}; + +// A packed matrix has a small-scale block structure that is not present in in +// the input matrices. This block structure is necessary for the kernels to +// process data efficiently. +// +// This struct is very similar to MatLayout, but has the extra KernelLayout +// field. +struct PMatLayout final { + std::int32_t rows = 0; + std::int32_t cols = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. + std::int32_t stride = 0; + Order order = Order::kColMajor; + // Small scale layout shuffling, potentially departing from + // linear row-major or column-major storage. See KernelLayout. + KernelLayout kernel; +}; + +inline bool operator==(const PMatLayout& a, const PMatLayout& b) { + return a.cols == b.cols && a.rows == b.rows && a.stride == b.stride && + a.order == b.order && a.kernel.rows == b.kernel.rows && + a.kernel.cols == b.kernel.cols && a.kernel.order == b.kernel.order; +} + +// Dynamic representation for a type. +// +// The most important field in this struct is the size, which Ruy uses to know +// how much memory to allocate without having to be templated on a type. +// Signed-ness and floating-point-ness are mainly present as debugging checks. +// +// Note: Ruy does not use this struct to to dynamically dispatch between +// different typed implementations. As described in the comment at the top of +// this file, Ruy's "front-end", which is templated, instantiates all the +// necessary "back-end" routines with complete static knowledge of all the +// types. +struct Type final { + template <typename T> + static Type Create() { + Type ret; + ret.is_signed = std::is_signed<T>::value; + ret.is_floating_point = std::is_floating_point<T>::value; + ret.size = sizeof(T); + return ret; + } + + template <typename T> + void AssertIs() const { + RUY_DCHECK_EQ(is_signed, Create<T>().is_signed); + RUY_DCHECK_EQ(is_floating_point, Create<T>().is_floating_point); + RUY_DCHECK_EQ(size, Create<T>().size); + } + + bool is_signed = false; + bool is_floating_point = false; + std::uint8_t size = 0; +}; + +inline bool operator==(const Type& type1, const Type& type2) { + return type1.is_signed == type2.is_signed && + type1.is_floating_point == type2.is_floating_point && + type1.size == type2.size; +} + +// Type-erased matrix. +struct EMat final { + Type data_type; + void* data = nullptr; + MatLayout layout; + std::int32_t zero_point = 0; + CachePolicy cache_policy = CachePolicy::kNeverCache; +}; + +// Type-erased packed matrix. +struct PEMat final { + Type data_type; + void* data = nullptr; + Type sums_type; + void* sums = nullptr; + PMatLayout layout; + std::int32_t zero_point = 0; +}; + +// Convenient typed helper for packed matrices. +template <typename Scalar> +struct PMat final { + // The row/column sums needed for quantized matrix multiplication when + // the opposite operand of the multiplication uses a non-symmetric zero + // point. + // This member is only relevant for packed matrices. + // Additionally, Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication. + // For floating point types, there is no quantization, so this pointer + // will always be null. 
We still need code referencing it to compile + // though, even if it is always branched around. Hence we use Scalar* + // itself as the type in that case. + using SumsType = + typename std::conditional<std::is_floating_point<Scalar>::value, Scalar, + std::int32_t>::type; + + Scalar* data = nullptr; + SumsType* sums = nullptr; + PMatLayout layout; + std::int32_t zero_point = 0; +}; + +template <typename T> +EMat EraseType(const Mat<T>& matrix) { + EMat ret; + ret.data_type = Type::Create<T>(); + ret.data = const_cast<void*>(static_cast<const void*>(matrix.data.get())); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + ret.cache_policy = matrix.cache_policy; + return ret; +} + +template <typename T> +Mat<T> UneraseType(const EMat& matrix) { + matrix.data_type.AssertIs<T>(); + Mat<T> ret; + ret.data.set(static_cast<T*>(matrix.data)); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + ret.cache_policy = matrix.cache_policy; + return ret; +} + +template <typename T> +PMat<T> UneraseType(const PEMat& matrix) { + using SumsType = typename PMat<T>::SumsType; + matrix.data_type.AssertIs<T>(); + matrix.sums_type.AssertIs<SumsType>(); + PMat<T> ret; + ret.data = static_cast<T*>(matrix.data); + ret.sums = static_cast<SumsType*>(matrix.sums); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + return ret; +} + +// Helpers for MatLayout / PMatLayout. + +inline bool IsUnstrided(const MatLayout& layout) { + if (layout.order == Order::kColMajor) { + return layout.stride == layout.rows; + } else { + return layout.stride == layout.cols; + } +} + +inline bool IsRowMajor(const MatLayout& layout) { + return layout.order == Order::kRowMajor; +} + +inline bool IsColMajor(const MatLayout& layout) { + return layout.order == Order::kColMajor; +} + +inline int FlatSize(const MatLayout& layout) { + const int outerdim = + layout.order == Order::kColMajor ? layout.cols : layout.rows; + return layout.stride * outerdim; +} + +inline bool IsUnstrided(const PMatLayout& layout) { + if (layout.order == Order::kColMajor) { + return layout.stride == layout.rows; + } else { + return layout.stride == layout.cols; + } +} + +inline bool IsRowMajor(const PMatLayout& layout) { + return layout.order == Order::kRowMajor; +} + +inline bool IsColMajor(const PMatLayout& layout) { + return layout.order == Order::kColMajor; +} + +inline int FlatSize(const PMatLayout& layout) { + const int outerdim = + layout.order == Order::kColMajor ? layout.cols : layout.rows; + return layout.stride * outerdim; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const MatLayout& layout, int row, int col) { + // TODO(benoitjacob) - should check this but this make the _slow tests take + // 5x longer. Find a mitigation like in Eigen with an 'internal' variant + // bypassing the check? + // RUY_DCHECK_GE(row, 0); + // RUY_DCHECK_GE(col, 0); + // RUY_DCHECK_LT(row, layout.rows); + // RUY_DCHECK_LT(col, layout.cols); + int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; + int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; + return row * row_stride + col * col_stride; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const PMatLayout& layout, int row, int col) { + RUY_DCHECK(is_pot(layout.kernel.rows)); + RUY_DCHECK(is_pot(layout.kernel.cols)); + int row_outer = row & ~(layout.kernel.rows - 1); + int col_outer = col & ~(layout.kernel.cols - 1); + int row_stride_outer = + layout.order == Order::kColMajor ? 
layout.kernel.cols : layout.stride; + int col_stride_outer = + layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride; + int offset_outer = + row_outer * row_stride_outer + col_outer * col_stride_outer; + int row_inner = row - row_outer; + int col_inner = col - col_outer; + int row_stride_inner = + layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; + int col_stride_inner = + layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; + int offset_inner = + row_inner * row_stride_inner + col_inner * col_stride_inner; + return offset_outer + offset_inner; +} + +// Helpers for Mat<T>. + +template <typename Scalar> +const Scalar* ElementPtr(const Mat<Scalar>& mat, int row, int col) { + return mat.data.get() + Offset(mat.layout, row, col); +} + +template <typename Scalar> +Scalar* ElementPtr(Mat<Scalar>* mat, int row, int col) { + return mat->data.get() + Offset(mat->layout, row, col); +} + +template <typename Scalar> +Scalar Element(const Mat<Scalar>& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PMat<T>. +// Duplicated from Matrix<T>, but the duplication seems acceptable. + +template <typename Scalar> +const Scalar* ElementPtr(const PMat<Scalar>& mat, int row, int col) { + return mat.data + Offset(mat.layout, row, col); +} + +template <typename Scalar> +Scalar* ElementPtr(PMat<Scalar>* mat, int row, int col) { + return mat->data + Offset(mat->layout, row, col); +} + +template <typename Scalar> +Scalar Element(const PMat<Scalar>& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PEMat. + +inline int DataBytes(const PEMat& packed) { + return FlatSize(packed.layout) * packed.data_type.size; +} + +inline int SumsBytes(const PEMat& packed) { + // Packed matrices are only relevant for Ruy's TrMul implementations. For + // TrMul, the number of sums is always equal to the number of columns. + return packed.layout.cols * packed.sums_type.size; +} + +// Transpose helpers. + +inline Order Transpose(Order order) { + return order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; +} + +inline MatLayout Transpose(const MatLayout& layout) { + MatLayout result(layout); + result.order = Transpose(result.order); + std::swap(result.rows, result.cols); + return result; +} + +template <typename Scalar> +Mat<Scalar> Transpose(const Mat<Scalar>& matrix) { + Mat<Scalar> result(matrix); + result.layout = Transpose(result.layout); + return result; +} + +// Compile-time version of KernelLayout, used to declare kernel layouts in a +// way that can be consumed by compile-time logic. +template <Order tOrder, int tRows, int tCols> +struct FixedKernelLayout { + static constexpr Order kOrder = tOrder; + static constexpr int kRows = tRows; + static constexpr int kCols = tCols; +}; + +template <typename FixedKernelLayout> +KernelLayout ToKernelLayout() { + KernelLayout ret; + ret.order = FixedKernelLayout::kOrder; + ret.rows = FixedKernelLayout::kRows; + ret.cols = FixedKernelLayout::kCols; + return ret; +} + +#if (__cplusplus < 201703L) +// A static constexpr data member is automatically inline and should not require +// redeclaration without an initializer. This is actually deprecated from C++17 +// onwards. Clang with -O0 without this can fail to link. 
+template <Order tOrder, int tRows, int tCols> +constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kCols; +template <Order tOrder, int tRows, int tCols> +constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows; +#endif + +} // namespace ruy + +#endif // RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/ruy/matrix.h b/ruy/matrix.h new file mode 100644 index 0000000..c9353e6 --- /dev/null +++ b/ruy/matrix.h @@ -0,0 +1,218 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_MATRIX_H_ +#define RUY_RUY_MATRIX_H_ + +#include <cstddef> +#include <cstdint> // IWYU pragma: keep +#include <type_traits> + +#include "ruy/check_macros.h" + +namespace ruy { + +// Layout storage order. Here and elsewhere, 'col' is short for 'column'. +// 'column-major' means that each column is contiguous in memory. +enum class Order : std::uint8_t { kColMajor, kRowMajor }; + +// Describes the shape and storage layout of a matrix. +class Layout final { + public: + int rows() const { return rows_; } + void set_rows(int val) { rows_ = val; } + int cols() const { return cols_; } + void set_cols(int val) { cols_ = val; } + int stride() const { return stride_; } + void set_stride(int val) { stride_ = val; } + Order order() const { return order_; } + void set_order(Order val) { order_ = val; } + + private: + int rows_ = 0; + int cols_ = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. + int stride_ = 0; + Order order_ = Order::kColMajor; +}; + +namespace detail { + +// Thin wrapper around a pointer with a constness model that works for the +// purposes of the Matrix class. +// +// A typical conundrum of any C++ container class is what type constness should +// encode at compile time constancy of the contained data? +// `Matrix<const T>` or `const Matrix<T>`? +// With either approach it is very difficult to achieve perfect +// const-correctness that that can only be done with some combination of +// inconvenient interface and c++ complexity/abstraction. +// +// Here we opt for something that's entirely tailored to the needs of the Ruy +// interface. The only purpose of the Matrix class is to pass matrix data +// pointers to ruy. There is an asymmetry here: the caller of ruy::Mul only +// needs to `set` the data; ruy itself only needs to `get` the data. In the +// caller's code, it's convenient to be able to just deal with `Matrix<T>` +// without having to sprinkle `const` keywords in the right places, so we want +// to track whether the data is constant in a way that's decoupled from the +// constness of `this`, and we never want to have Matrix<const T>. Inside ruy +// code, as input matrices are passed by const-reference and output matrices are +// passed by pointer (to non-const), the constness of `this` is telling whether +// the data is constant. 
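To make the intended usage concrete, here is a minimal sketch at the Matrix<T> level (illustrative function, not part of this change), relying on the debug-only mutability tracking implemented below:

#include "ruy/matrix.h"

void ConstCheckingExample() {
  ruy::Matrix<float> m;
  const float constant_weights[4] = {1, 2, 3, 4};
  // The const T* overload of set_data records the data as constant.
  m.set_data(constant_weights);
  // Const access is always allowed:
  const float* p = static_cast<const ruy::Matrix<float>&>(m).data();
  (void)p;
  // Mutable access would trip a RUY_DCHECK in debug builds, because the data
  // was provided through a const pointer:
  // float* q = m.data();
}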
See the `get` and `set` methods below and the comment +// explaining the core logic that they encapsulate. +template <typename T> +class ConstCheckingPtr final { + public: + using element_type = T; + + // Core accessors. These encapsulate the main logic: + // - for `set`, the constness of the argument determines whether internal + // pointer should be tracked as const/mutable. + // - for `get`, the constness of `this` determines whether the call + // counts as a const or mutable use of the internal pointer. + void set(T* ptr) { + ptr_ = ptr; + set_mutable(true); + } + void set(const T* ptr) { + ptr_ = ptr; + set_mutable(false); + } + void set(std::nullptr_t) { ptr_ = nullptr; } + T* get() /* NOT const */ { + assert_mutable(); + return const_cast<T*>(ptr_); + } + const T* get() const { return ptr_; } + + private: + // There's never a need for Matrix<const T>. + static_assert(!std::is_const<T>::value, ""); + const T* ptr_ = nullptr; +#ifndef NDEBUG + bool is_mutable_ = true; + void set_mutable(bool val) { is_mutable_ = val; } + void assert_mutable() { RUY_DCHECK(is_mutable_); } +#else + void set_mutable(bool) {} + void assert_mutable() {} +#endif +}; + +} // namespace detail + +enum class CachePolicy : std::uint8_t { + kNeverCache, + kCacheIfLargeSpeedup, + kCacheIfSignificantSpeedup, + kAlwaysCache, +}; + +// A Matrix merely wraps existing data as a matrix. It doesn't own any buffer. +// The purpose of Matrix is only to be used in ruy's interface -- it's just +// a structured way for the user to pass to ruy::Mul the matrix data pointers +// together with other matrix parameters. +// Scalar may be any floating-point or integral type. When integral, it may be +// signed or unsigned. It's never const: use Matrix<T> for both input and output +// matrices, never use Matrix<const T>. +// See the comments on detail::ConstCheckingPointer. +template <typename Scalar> +class Matrix final { + public: + static_assert(!std::is_const<Scalar>::value, + "Never use Matrix<const T>. Just use Matrix<T>. Constness of " + "the data is guarded by debug-only runtime assertions. See " + "detail::ConstCheckingPtr."); + + Scalar* data() { return data_.get(); } + const Scalar* data() const { return data_.get(); } + void set_data(Scalar* ptr) { data_.set(ptr); } + void set_data(const Scalar* ptr) { data_.set(ptr); } + void set_data(std::nullptr_t) { data_.set(nullptr); } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + Scalar zero_point() const { return zero_point_; } + void set_zero_point(Scalar value) { zero_point_ = value; } + CachePolicy cache_policy() const { return cache_policy_; } + void set_cache_policy(CachePolicy value) { cache_policy_ = value; } + + private: + // The underlying buffer wrapped by this matrix. + detail::ConstCheckingPtr<Scalar> data_; + // The shape and data layout of this matrix. + Layout layout_; + // The zero_point, i.e. which Scalar value is to be interpreted as zero. + // When Scalar is floating-point, this must be 0. + Scalar zero_point_ = 0; + // When the data pointed to by this matrix is constant data, so that it is + // valid to assume that equality of pointers implies equality of data, + // a CachePolicy may be used instead of the default kNeverCache, + // which will enable ruy to take advantage of this constancy of the data to + // cache the packing work, which can be a large speedup in matrix*vector + // and other narrow shapes. 
+ CachePolicy cache_policy_ = CachePolicy::kNeverCache; +}; + +inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { + layout->set_rows(rows); + layout->set_cols(cols); + layout->set_order(order); + layout->set_stride(order == Order::kColMajor ? rows : cols); +} + +template <typename StreamType, typename Scalar> +StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) { + for (int row = 0; row < mat.layout().rows(); row++) { + for (int col = 0; col < mat.layout().cols(); col++) { + stream << static_cast<double>(Element(mat, row, col)) << " "; + } + stream << "\n"; + } + return stream; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const Layout& layout, int row, int col) { + // TODO(benoitjacob) - should check this but this make the _slow tests take + // 5x longer. Find a mitigation like in Eigen with an 'internal' variant + // bypassing the check? + // RUY_DCHECK_GE(row, 0); + // RUY_DCHECK_GE(col, 0); + // RUY_DCHECK_LT(row, layout.rows()); + // RUY_DCHECK_LT(col, layout.cols()); + int row_stride = layout.order() == Order::kColMajor ? 1 : layout.stride(); + int col_stride = layout.order() == Order::kRowMajor ? 1 : layout.stride(); + return row * row_stride + col * col_stride; +} + +template <typename Scalar> +const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) { + return mat.data() + Offset(mat.layout(), row, col); +} + +template <typename Scalar> +Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) { + return mat->data() + Offset(mat->layout(), row, col); +} + +template <typename Scalar> +Scalar Element(const Matrix<Scalar>& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +} // namespace ruy + +#endif // RUY_RUY_MATRIX_H_ diff --git a/ruy/matrix_test.cc b/ruy/matrix_test.cc new file mode 100644 index 0000000..0f3fd13 --- /dev/null +++ b/ruy/matrix_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/matrix.h" + +#include "ruy/gtest_wrapper.h" + +namespace ruy { +namespace { + +TEST(MatrixTest, LayoutClassSanity) { + Layout layout; + EXPECT_EQ(layout.rows(), 0); + EXPECT_EQ(layout.cols(), 0); + EXPECT_EQ(layout.stride(), 0); + EXPECT_EQ(layout.order(), Order::kColMajor); + layout.set_rows(123); + layout.set_cols(456); + layout.set_stride(789); + layout.set_order(Order::kRowMajor); + EXPECT_EQ(layout.rows(), 123); + EXPECT_EQ(layout.cols(), 456); + EXPECT_EQ(layout.stride(), 789); + EXPECT_EQ(layout.order(), Order::kRowMajor); +} + +TEST(MatrixTest, MakeSimpleLayout) { + Layout layout; + MakeSimpleLayout(123, 456, Order::kColMajor, &layout); + EXPECT_EQ(layout.rows(), 123); + EXPECT_EQ(layout.cols(), 456); + EXPECT_EQ(layout.stride(), 123); + EXPECT_EQ(layout.order(), Order::kColMajor); + MakeSimpleLayout(321, 654, Order::kRowMajor, &layout); + EXPECT_EQ(layout.rows(), 321); + EXPECT_EQ(layout.cols(), 654); + EXPECT_EQ(layout.stride(), 654); + EXPECT_EQ(layout.order(), Order::kRowMajor); +} + +TEST(MatrixTest, ConstCheckingPtrSanity) { + using PtrType = detail::ConstCheckingPtr<int>; + PtrType ptr; + int some_nonconst; + const int some_const = 0; + EXPECT_EQ(ptr.get(), nullptr); + ptr.set(&some_nonconst); + EXPECT_EQ(static_cast<const PtrType&>(ptr).get(), &some_nonconst); + EXPECT_EQ(ptr.get(), &some_nonconst); + ptr.set(&some_const); + EXPECT_EQ(static_cast<const PtrType&>(ptr).get(), &some_const); +#ifndef NDEBUG + RUY_ASSERT_DEATH(ptr.get(), ""); +#endif +} + +TEST(MatrixTest, MatrixClassSanity) { + Matrix<int> matrix; + EXPECT_EQ(matrix.data(), nullptr); + EXPECT_EQ(matrix.zero_point(), 0); + EXPECT_EQ(matrix.cache_policy(), CachePolicy::kNeverCache); + EXPECT_EQ(matrix.layout().rows(), 0); + EXPECT_EQ(matrix.layout().cols(), 0); + EXPECT_EQ(matrix.layout().stride(), 0); + EXPECT_EQ(matrix.layout().order(), Order::kColMajor); + const int some_const = 0; + matrix.set_data(&some_const); + matrix.set_zero_point(123); + matrix.set_cache_policy(CachePolicy::kAlwaysCache); + MakeSimpleLayout(12, 34, Order::kRowMajor, matrix.mutable_layout()); + EXPECT_EQ(static_cast<const Matrix<int>&>(matrix).data(), &some_const); +#ifndef NDEBUG + RUY_ASSERT_DEATH(matrix.data(), ""); +#endif + EXPECT_EQ(matrix.zero_point(), 123); + EXPECT_EQ(matrix.cache_policy(), CachePolicy::kAlwaysCache); + EXPECT_EQ(matrix.layout().rows(), 12); + EXPECT_EQ(matrix.layout().cols(), 34); + EXPECT_EQ(matrix.layout().stride(), 34); + EXPECT_EQ(matrix.layout().order(), Order::kRowMajor); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/mul_params.h b/ruy/mul_params.h new file mode 100644 index 0000000..d5aa27b --- /dev/null +++ b/ruy/mul_params.h @@ -0,0 +1,299 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_MUL_PARAMS_H_ +#define RUY_RUY_MUL_PARAMS_H_ + +#include <cstdint> +#include <limits> +#include <type_traits> + +#include "ruy/check_macros.h" +#include "ruy/size_util.h" + +namespace ruy { + +// Enumeration to designate which dimension is the 'channels', for MulParams +// features that are 'per-channel', namely the bias-vector and the quantized +// multiplier. +enum class ChannelDimension : std::int8_t { + // kRow means that 'per-channel' means 'per row of the destination matrix' + kRow, + // kCol means that 'per-channel' means 'per column of the destination matrix' + kCol +}; + +namespace detail { +template <typename tAccumScalar, typename tDstScalar> +struct MulParamsStorage; +} + +// MulParams describes all about a matrix multiplication that +// isn't encoded in the LHS, RHS and destination matrices. Some of that +// information is encoded as compile-time constants and types (for instance, the +// choice of accumulator type, AccumScalar). Some of that information is encoded +// as runtime values (for instance, the optional bias vector). +// +// Template parameters: +// AccumScalar: Accumulator type. The type of accumulators used to compute the +// dot-products before being ultimately casted to the destination type. +// DstScalar: The destination scalar type. +// +// Constraints on these template parameters (see also the ruy::Mul comment): +// * If DstScalar is floating-point then AccumScalar must also be. +// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover +// in that integral case, there is a mode switch: +// - If DstScalar is std::int32_t then the multiplier_* fields are all +// disabled, and ruy::Mul will just return raw (unscaled) accumulators. +// - If DstScalar is not std::int32_t then the multiplier_* fields are +// enabled, and ruy::Mul will use them to scale internal std::int32_t +// accumulators before casting them to the DstScalar type. The default +// values are such that the effective multiplier is 1 (no scaling). +// +// For the latter case (DstScalar integral and narrower than std::int32_t), +// reference code can be found in the implementation of ruy::ApplyMultiplier. +// If you look there, you'll find warnings like this: +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// The explanation of this warning is that as of early 2021, we still don't know +// whether it is advisable to let this code as-is have normative value, or +// whether that would become advisable after some specific final change. +// +// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform +// bit-exactly to this reference, but we also know that x86 could be faster if +// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't +// know that this particular reference code is inherently better than other +// forms that could perform better on these architectures --- in fact, the +// alternative that was proposed in [2] as better performing on ARM Cortex-M +// is also inherently more accurate thanks to rounding only once, but it would +// perform worse on both ARM NEON, and x86. 
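As a usage sketch of the mode where DstScalar is narrower than std::int32_t (names are illustrative; the setters used are the ones declared below in this class):

#include <cstdint>
#include "ruy/mul_params.h"

// bias_data is assumed to point to one std::int32_t per destination channel.
void SetUpQuantizedMulParams(const std::int32_t* bias_data,
                             ruy::MulParams<std::int32_t, std::int8_t>* mul_params) {
  mul_params->set_bias(bias_data);
  // Effective multiplier = (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent.
  mul_params->set_multiplier_fixedpoint(1 << 30);  // 0.5 in Q0.31 ...
  mul_params->set_multiplier_exponent(1);          // ... scaled by 2^1: overall 1.
  mul_params->set_clamp_min(-128);
  mul_params->set_clamp_max(127);
}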
+// +// In fact, if we look at other hardware architectures beyond current Ruy +// targets, namely "hardware accelerators", it becomes clear that there is no +// hope for any form of this to be efficiently implementable simultaneously on +// all current relevant hardware. Indeed, some accelerators prefer to perform +// the multiplication in IEEE float32, others in IEEE float16, others in +// bfloat16, others in 16-bit fixed-point... +// +// See: +// [1] https://github.com/google/ruy/pull/227 +// [2] https://github.com/tensorflow/tensorflow/issues/25087 +template <typename tAccumScalar, typename tDstScalar> +class MulParams final { + public: + using AccumScalar = tAccumScalar; + using DstScalar = tDstScalar; + + // The bias vector data, if not null. + const AccumScalar* bias() const { return storage_.bias; } + void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } + // Only for non-floating-point cases. The fixed-point part of the multiplier + // by which accumulators are multiplied before being casted to the destination + // type. This is a fixed-point quantity with 0 integer bits. Since + // (as explained in the class comment) AccumScalar must be std::int32_t, + // that means that the fixed-point format is Q0.31. For example, + // a multiplier_fixedpoint value of 2^30 has the effect of multiplying + // by one half (1/2). More generally, the effect is to multiply by + // (multiplier_fixedpoint / (2^31)). + AccumScalar multiplier_fixedpoint() const { + return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; + } + void set_multiplier_fixedpoint(const AccumScalar value) { + set_perchannel(false); + storage_.multiplier_fixedpoint = value; + } + // Only for non-floating-point cases. The exponent part of the aforementioned + // multiplier. + int multiplier_exponent() const { + return storage_.perchannel ? 0 : storage_.multiplier_exponent; + } + void set_multiplier_exponent(const int value) { + set_perchannel(false); + storage_.multiplier_exponent = value; + } + // Per-channel variant of multiplier_fixedpoint. Setting this switches + // to per-channel mode, where `multiplier_fixedpoint` and + // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` + // and `multiplier_exponent_perchannel` are used instead. + // + // This must point to a buffer of as many values as there are rows or columns + // in the destination matrix, whichever is the channels dimension. Each + // channel of the destination matrix will use the corresponding buffer element + // instead of multiplier_fixedpoint. + const AccumScalar* multiplier_fixedpoint_perchannel() const { + return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel + : nullptr; + } + void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) { + set_perchannel(true); + storage_.multiplier_fixedpoint_perchannel = ptr; + } + // Per-channel variant of multiplier_exponent. Same comments as for + // multiplier_fixedpoint_perchannel. + const int* multiplier_exponent_perchannel() const { + return storage_.perchannel ? storage_.multiplier_exponent_perchannel + : nullptr; + } + void set_multiplier_exponent_perchannel(const int* ptr) { + set_perchannel(true); + storage_.multiplier_exponent_perchannel = ptr; + } + // min clamp bound of destination values. + DstScalar clamp_min() const { return storage_.clamp_min; } + void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; } + // max clamp bound of destination values. 
+ DstScalar clamp_max() const { return storage_.clamp_max; } + void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; } + // Designates which dimension is the 'channels', for per-channel features + // such as bias-addition and per-channel quantization multipliers. + ChannelDimension channel_dimension() const { + return storage_.channel_dimension; + } + void set_channel_dimension(ChannelDimension value) { + storage_.channel_dimension = value; + } + // Specifies the upward rounding of the allocated capacity of per-channel + // buffers such as bias vectors and per-channel quantization multipliers. + // The unit is matrix entries, not bytes. + // + // This value must be a power of two. + // + // The default value, 1, means no upward rounding, meaning that the buffers + // are not required to have a capacity greater than the size of the + // corresponding matrix dimension, i.e. the number of rows (respectively + // columns) of the destination matrix if `channel_dimension()` is kRow + // (respectively kCol). + // + // Higher values allow the implementation to assume that it is OK to access + // these buffers a little past this boundary, which is useful in SIMD + // optimized kernels. In practice, when this value is lower than what the + // kernel requires, ruy has to internally reallocate and copy per-channel + // buffers. When this value is high enough, this reallocation and copy is + // avoided. + // + // When a value greater than 1 is specified, the tail region of the buffer + // (past the end of the values actually corresponding to channels) is required + // to be zero-initialized. + // + // As of 2020, values as high as 16 may be useful on some CPU architectures + // (corresponding to the widest kernels used on any CPU architecture). + int perchannel_buffers_capacity_rounding() const { + return 1 << storage_.perchannel_buffers_capacity_rounding_log2; + } + void set_perchannel_buffers_capacity_rounding(int value) { + // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two. + storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value); + } + + private: + detail::MulParamsStorage<AccumScalar, DstScalar> storage_; + + void set_perchannel(bool perchannel) { + storage_.perchannel = perchannel; + } +}; + +namespace detail { + +// Floating-point case. +template <typename AccumScalar, typename DstScalar> +struct MulParamsStorage final { + static_assert(std::is_floating_point<AccumScalar>::value, ""); + static_assert(std::is_floating_point<DstScalar>::value, ""); + static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), ""); + + const AccumScalar* bias = nullptr; + DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity(); + DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity(); + ChannelDimension channel_dimension = ChannelDimension::kRow; + std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; + + // Data members that are disabled in this case are left as `static constexpr` + // so that one can write some generic code. + static constexpr const AccumScalar* multiplier_fixedpoint_perchannel = + nullptr; + static constexpr const int* multiplier_exponent_perchannel = nullptr; + static constexpr AccumScalar multiplier_fixedpoint = 0; + static constexpr int multiplier_exponent = 0; + static constexpr bool perchannel = false; +}; + +// Specialization for the integer-quantized type, with down-quantization of +// int32 accumulators to a narrower destination scalar type. 
+template <typename DstScalar> +struct MulParamsStorage<std::int32_t, DstScalar> final { + using AccumScalar = std::int32_t; + static_assert(std::is_integral<DstScalar>::value, ""); + static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, ""); + + const AccumScalar* bias = nullptr; + union { + const AccumScalar* multiplier_fixedpoint_perchannel; + // Let the default multiplier be effecively a multiplication by 1, so that + // the matmul behaves as a (saturating) plain integer matmul. Unfortunately + // 1 is not exactly representable in fixedpoint with 0 integer bits, but + // using the highest representable value is a sufficiently good + // approximation: since this specialization of MulParams is for the case + // where DstScalar is at least 2x narrower than MulScalar, the values + // for which there would be a difference will get saturated anyway. + AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max(); + }; + union { + const int* multiplier_exponent_perchannel; + // See the above comment about the default value of multiplier_fixedpoint. + int multiplier_exponent = 0; + }; + DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest(); + DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); + ChannelDimension channel_dimension = ChannelDimension::kRow; + bool perchannel = false; + std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; +}; + +// Specialization used in the integer case when outputting raw int32 +// accumulators, without down-quantization to a narrower destination scalar +// type. In this case, the feature of clamping destination values is not +// available. +template <> +struct MulParamsStorage<std::int32_t, std::int32_t> final { + using AccumScalar = std::int32_t; + using DstScalar = std::int32_t; + + const AccumScalar* bias = nullptr; + ChannelDimension channel_dimension = ChannelDimension::kRow; + std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; + + // Data members that are disabled in this case are left as `static constexpr` + // so that one can write some generic code. + static constexpr const AccumScalar* multiplier_fixedpoint_perchannel = + nullptr; + static constexpr const int* multiplier_exponent_perchannel = nullptr; + static constexpr AccumScalar multiplier_fixedpoint = 0; + static constexpr int multiplier_exponent = 0; + static constexpr DstScalar clamp_min = + std::numeric_limits<DstScalar>::lowest(); + static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); + static constexpr bool perchannel = false; +}; + +} // namespace detail + +} // namespace ruy + +#endif // RUY_RUY_MUL_PARAMS_H_ diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc new file mode 100644 index 0000000..4bc9f87 --- /dev/null +++ b/ruy/mul_params_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/mul_params.h" + +#include <cstdint> +#include <type_traits> + +#include "ruy/gtest_wrapper.h" + +namespace ruy { +namespace { + +TEST(MulParamsTest, SpecClassSanity) { + using MulParamsType = MulParams<std::int32_t, std::int8_t>; + static_assert(std::is_same<MulParamsType::AccumScalar, std::int32_t>::value, + ""); + static_assert(std::is_same<MulParamsType::DstScalar, std::int8_t>::value, ""); + + MulParamsType mul_params; + EXPECT_EQ(mul_params.bias(), nullptr); + EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max()); + EXPECT_EQ(mul_params.multiplier_exponent(), 0); + EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr); + EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr); + EXPECT_EQ(mul_params.clamp_min(), -128); + EXPECT_EQ(mul_params.clamp_max(), 127); + EXPECT_EQ(mul_params.channel_dimension(), ChannelDimension::kRow); + EXPECT_EQ(mul_params.perchannel_buffers_capacity_rounding(), 1); + std::int32_t bias_data[1]; + mul_params.set_bias(bias_data); + mul_params.set_multiplier_fixedpoint(123); + mul_params.set_multiplier_exponent(4); + mul_params.set_channel_dimension(ChannelDimension::kCol); + mul_params.set_perchannel_buffers_capacity_rounding(8); + EXPECT_EQ(mul_params.bias(), bias_data); + EXPECT_EQ(mul_params.multiplier_fixedpoint(), 123); + EXPECT_EQ(mul_params.multiplier_exponent(), 4); + EXPECT_EQ(mul_params.channel_dimension(), ChannelDimension::kCol); + EXPECT_EQ(mul_params.perchannel_buffers_capacity_rounding(), 8); + mul_params.set_multiplier_fixedpoint(0); + mul_params.set_multiplier_exponent(0); + std::int32_t multiplier_fixedpoint_perchannel_data[1]; + int multiplier_exponent_perchannel_data[1]; + mul_params.set_multiplier_fixedpoint_perchannel( + multiplier_fixedpoint_perchannel_data); + mul_params.set_multiplier_exponent_perchannel( + multiplier_exponent_perchannel_data); + mul_params.set_clamp_min(-10); + mul_params.set_clamp_max(10); + EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0); + EXPECT_EQ(mul_params.multiplier_exponent(), 0); + EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), + multiplier_fixedpoint_perchannel_data); + EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), + multiplier_exponent_perchannel_data); + EXPECT_EQ(mul_params.clamp_min(), -10); + EXPECT_EQ(mul_params.clamp_max(), 10); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/opt_set.h b/ruy/opt_set.h new file mode 100644 index 0000000..244d9da --- /dev/null +++ b/ruy/opt_set.h @@ -0,0 +1,51 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_OPT_SET_H_ +#define RUY_RUY_OPT_SET_H_ + +// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling +// certain optimizations. 
It should be used by defining that macro on the +// compiler command line. +// +// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy. +#define RUY_OPT_BIT_INTRINSICS 0x1 +#define RUY_OPT_BIT_ASM 0x2 +#define RUY_OPT_BIT_TUNING 0x4 +#define RUY_OPT_BIT_FAT_KERNEL 0x8 +// 0x10 used to be RUY_OPT_BIT_NATIVE_ROUNDING +#define RUY_OPT_BIT_AVOID_ALIASING 0x20 +#define RUY_OPT_BIT_MAX_STREAMING 0x40 +#define RUY_OPT_BIT_PACK_AHEAD 0x80 +#define RUY_OPT_BIT_PREFETCH_LOAD 0x100 +#define RUY_OPT_BIT_PREFETCH_STORE 0x200 +#define RUY_OPT_BIT_FRACTAL_Z 0x400 +#define RUY_OPT_BIT_FRACTAL_U 0x800 +#define RUY_OPT_BIT_FRACTAL_HILBERT 0x1000 + +#if !defined(RUY_OPT_SET) +#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK +// Load prefetching is detrimental in matrix multiplication benchmarks. +// Store prefetching is not. +#define RUY_OPT_SET (~RUY_OPT_BIT_PREFETCH_LOAD) +#else +// Default to all optimizations. +#define RUY_OPT_SET (~0) +#endif +#endif + +#define RUY_OPT(X) ((RUY_OPT_SET & RUY_OPT_BIT_##X) != 0) + +#endif // RUY_RUY_OPT_SET_H_ diff --git a/ruy/pack.h b/ruy/pack.h new file mode 100644 index 0000000..744f9bc --- /dev/null +++ b/ruy/pack.h @@ -0,0 +1,155 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". 
+// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. +// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy has an optional feature, accessed by Matrix::set_cache_policy(), to +// cache the packed forms of constant matrices. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef RUY_RUY_PACK_H_ +#define RUY_RUY_PACK_H_ + +#include "ruy/mat.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/trace.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM_NEON +#include "ruy/pack_arm.h" +#elif RUY_PLATFORM_X86 +#include "ruy/pack_x86.h" +#endif +// IWYU pragma: end_exports + +namespace ruy { + +// General implementation of the PackImpl template, overridden by template +// specializations for specific SIMD code paths. This general implementation +// is used with Path::kStandardCpp and its internal test-only variants. +template <Path ThePath, typename FixedKernelLayout, typename Scalar, + typename PackedScalar, typename SumsType, Order SrcOrder> +struct PackImpl { + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<PackedScalar>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (generic)"); + RUY_DCHECK_EQ(SrcOrder, src_matrix.layout.order); + RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); + SumsType* sums = packed_matrix->sums; + for (int col = start_col; col < end_col; col++) { + SumsType accum = 0; + for (int row = 0; row < packed_matrix->layout.rows; row++) { + PackedScalar packed_val; + if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { + packed_val = Pack<PackedScalar>(Element(src_matrix, row, col)); + } else { + packed_val = packed_matrix->zero_point; + } + accum += packed_val; + *ElementPtr(packed_matrix, row, col) = packed_val; + } + if (sums) { + sums[col] = accum; + } + } + } +}; + +// Main entry point for packing. 
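Returning to the packed-matrix caching mentioned in the header comment above, a minimal sketch of how a caller can opt a constant matrix into that caching for gemv-like workloads (illustrative function name; the setter and enum come from ruy/matrix.h):

#include "ruy/matrix.h"

// 'weights' wraps constant data reused across many ruy::Mul calls, e.g. the
// weight matrix of a fully-connected layer multiplied against small batches,
// where re-packing on every call would be a significant cost; let ruy cache
// the packed form instead.
void EnableWeightsPackingCache(ruy::Matrix<float>* weights) {
  weights->set_cache_policy(ruy::CachePolicy::kCacheIfSignificantSpeedup);
}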
+template <Path ThePath, typename FixedKernelLayout, typename Scalar, + typename PackedScalar> +void RunPack(Tuning tuning, const EMat& src_matrix, PEMat* packed_matrix, + int start_col, int end_col) { + RUY_TRACE_SCOPE; + using SumsType = typename PMat<PackedScalar>::SumsType; + Mat<Scalar> src = UneraseType<Scalar>(src_matrix); + PMat<PackedScalar> packed = UneraseType<PackedScalar>(*packed_matrix); + RUY_TRACE_INFO(RUN_PACK); + if (src.layout.order == Order::kColMajor) { + PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType, + Order::kColMajor>::Run(tuning, src, &packed, start_col, end_col); + } else { + PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType, + Order::kRowMajor>::Run(tuning, src, &packed, start_col, end_col); + } +} + +} // namespace ruy + +#endif // RUY_RUY_PACK_H_ diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc new file mode 100644 index 0000000..c337986 --- /dev/null +++ b/ruy/pack_arm.cc @@ -0,0 +1,2480 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/pack_arm.h" + +#include <cstdint> + +#include "ruy/asm_helpers.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_NEON +#include <arm_neon.h> +#endif + +namespace ruy { + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + int input_xor) { + profiler::ScopeLabel label("Pack (kNeon)"); + asm volatile( + // clang-format off + // v26 will be the vector to XOR input values with to perform + // any input data type conversion (e.g. uint8 to int8). + "dup v26.16b, %w[input_xor]\n" + // w1 will be the number of rows already loaded. + "mov w1, #0\n" + // v28--v32 will be used to accumulate the sums + "movi v28.4s, #0\n" + "movi v29.4s, #0\n" + "movi v30.4s, #0\n" + "movi v31.4s, #0\n" + // Let w2 be `rows` rounded down to multiple of 16. + "ands w2, %w[rows], #-16\n" + // If there are no full blocks of 16 rows to process, jump to the + // code handling the last < 16 rows. + "beq 3f\n" + // Load the first block of 16 rows. + "add w1, w1, #16\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + // Check if these were the only full block of 16 rows to load. + "cmp w1, w2\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + // In that case, jump to the code handling the last loaded block of + // 16 rows. + "beq 2f\n" + // Main loop processing blocks of 16 rows. + "1:\n" + // Load the next 16 rows, interleaved with the XOR input type + // conversion (e.g. uint8->int8) on the already loaded inputs. 
+ "add w1, w1, #16\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + // Compute the sums, interleaved with storing to the packed matrix. + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + // Was this the last block of 16 rows to load? + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + // End of main loop on blocks of 16 rows. + "bne 1b\n" + + // Code handling the last already-loaded block of 16 rows. + "2:\n" + + // Process the last loaded full 16x4 block. + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + // End of code handling full blocks of 16 rows. + // Now we handle any remaining rows. + "3:\n" + // Let w2 be the number of rows left to handle. + "ands w2, %w[rows], #15\n" + // If w2==0, there are no remaining rows, jump to the end. + "beq 4f\n" + // Zero out a 16x4 block in registers, which we'll partially overwrite + // with any remaining rows. + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. +#undef RUY_LOAD_ONE_ROW + "5:\n" + + // Process the last zero-padded 16x4 block. + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + // Horizontal reduction of the registers used to accumulate sums. 
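+      // Each of v28..v31 holds four 32-bit partial sums for one source
+      // column; the three addp instructions below collapse them so that v28
+      // ends up holding the four per-column sums.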
+ "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Store the sums. + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), + [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3), + [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) + : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)), + [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), + [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)), + [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), + [rows] "r"(src_rows), [src_zero_point] "r"(src_zero_point), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) + +#define RUY_OFFSET_SRC_PTR0 0 +#define RUY_OFFSET_SRC_PTR1 4 +#define RUY_OFFSET_SRC_PTR2 8 +#define RUY_OFFSET_SRC_PTR3 12 +#define RUY_OFFSET_SUMS_PTR 16 +#define RUY_OFFSET_PACKED_PTR 20 +#define RUY_OFFSET_SRC_INC0 24 +#define RUY_OFFSET_SRC_INC1 28 +#define RUY_OFFSET_SRC_INC2 32 +#define RUY_OFFSET_SRC_INC3 36 +#define RUY_OFFSET_SRC_ROWS 40 +#define RUY_OFFSET_SRC_ZERO_POINT 44 +#define RUY_OFFSET_INPUT_XOR 48 + +template <typename Params> +void CheckOffsetsInPackParams8bit(const Params&) { + static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); + static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); + static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); + static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); + static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); + static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); + static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); + static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); + static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); + static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); + static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); + static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, + ""); + static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); +} + +// No attempt made at making this code efficient on A55-ish cores yet. +void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const void* src_ptr2 = params.src_ptr2; + const void* src_ptr3 = params.src_ptr3; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const int src_inc2 = params.src_inc2; + const int src_inc3 = params.src_inc3; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + "vmov.i32 q14, #0\n" + "vmov.i32 q15, #0\n" + + // Round down src_rows to nearest multiple of 16. 
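+      // src_rows is read from the params struct at its RUY_OFFSET_* byte
+      // offset; the offsets were checked against the struct layout above by
+      // CheckOffsetsInPackParams8bit.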
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + // Now do the same for src_ptr2 and src_ptr3. + "vld1.8 {d0, d1}, [%[src_ptr2]]\n" + "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") + + "vld1.8 {d2, d3}, [%[src_ptr3]]\n" + "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q14 and q15 contain 4x32b accumulators + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "cmp r1, r2\n" + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// First, read/accumulate/write for src_ptr0 and src_ptr1. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Reset to src_zero for src_ptr2 and src_ptr3. + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Next, read/accumulate/write for src_ptr2 and src_ptr3. 
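+// As above, each RUY_LOAD_ONE_ROW* step loads one leftover row into one byte
+// lane and branches to label 5 as soon as the remaining-row count in r2 has
+// been consumed.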
+#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + "vpadd.i32 d28, d28, d29\n" + "vpadd.i32 d30, d30, d31\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + "vpadd.i32 d27, d28, d30\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "vst1.32 {d27}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. +// No attempt made at making this code efficient on in-order cores yet. +// This version differs from the above in that we only handle two columns +// at a time. +void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + + // Round down src_rows to nearest multiple of 16. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. 
+ // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "cmp r1, r2\n" + + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Read/accumulate/write for src_ptr0 and src_ptr1. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +#undef RUY_OFFSET_SRC_PTR0 +#undef RUY_OFFSET_SRC_PTR1 +#undef RUY_OFFSET_SRC_PTR2 +#undef RUY_OFFSET_SRC_PTR32 +#undef RUY_OFFSET_SUMS_PTR +#undef RUY_OFFSET_PACKED_PTR0 +#undef RUY_OFFSET_SRC_INC0 +#undef RUY_OFFSET_SRC_INC1 +#undef RUY_OFFSET_SRC_INC2 +#undef RUY_OFFSET_SRC_INC3 +#undef RUY_OFFSET_SRC_ROWS +#undef RUY_OFFSET_SRC_ZERO_POINT +#undef RUY_OFFSET_INPUT_XOR + +#endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + asm volatile( + // clang-format off + // v26 will be the vector to XOR input values with to perform + // any input data type conversion (e.g. uint8 to int8). + "dup v26.16b, %w[input_xor]\n" + // w1 will be the number of rows already loaded. + "mov w1, #0\n" + // v28--v32 will be used to accumulate the sums + "movi v28.4s, #0\n" + "movi v29.4s, #0\n" + "movi v30.4s, #0\n" + "movi v31.4s, #0\n" + // Let w2 be `rows` rounded down to multiple of 16. 
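+      // (ands both masks and sets the condition flags, so the beq below
+      // catches the fewer-than-16-rows case without a separate cmp.)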
+ "ands w2, %w[rows], #-16\n" + // If there are no full blocks of 16 rows to process, jump to the + // code handling the last < 16 rows. + "beq 3f\n" + // Load the first block of 16 rows. + "add w1, w1, #16\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + // Check if these were the only full block of 16 rows to load. + "cmp w1, w2\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + // In that case, jump to the code handling the last loaded block of + // 16 rows. + "beq 2f\n" + // Main loop processing blocks of 16 rows. + "1:\n" + // Load the next 16 rows, interleaved with the XOR input type + // conversion (e.g. uint8->int8) on the already loaded inputs. + "add w1, w1, #16\n" + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + // Compute the sums, interleaved with storing to the packed matrix. + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + // Was this the last block of 16 rows to load? + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "sadalp v31.4s, v19.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + // End of main loop on blocks of 16 rows. + "bne 1b\n" + + // Code handling the last already-loaded block of 16 rows. + "2:\n" + // Process the last loaded full 16x4 block. 
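+      // The upper 8 bytes of each column chunk were loaded into x10..x13 with
+      // plain ldr instructions; ins merges them back into the upper halves of
+      // v0..v3. Splitting each 16-byte load this way tends to schedule better
+      // on in-order cores such as the Cortex-A55.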
+ "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], x12\n" + "ins v3.d[1], x13\n" + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + // End of code handling full blocks of 16 rows. + // Now we handle any remaining rows. + "3:\n" + // Let w2 be the number of rows left to handle. + "ands w2, %w[rows], #15\n" + // If w2==0, there are no remaining rows, jump to the end. + "beq 4f\n" + // Zero out a 16x4 block in registers, which we'll partially overwrite + // with any remaining rows. + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. +#undef RUY_LOAD_ONE_ROW + "5:\n" + + // Process the last zero-padded 16x4 block. + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + // Horizontal reduction of the registers used to accumulate sums. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Store the sums. 
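+      // sums_ptr may be null, in which case the sums are computed but not
+      // stored.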
+ "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)), + [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)), + [ rows ] "r"(src_rows), + [ src_zero_point ] "r"(src_zero_point), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void Pack8bitColMajorForNeonDotprodA55ish( + const void* src_ptr0, const void* src_ptr1, const void* src_ptr2, + const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label( + "Pack (kNeonDotprod, optimized for in-order cores)"); + asm volatile( + // clang-format off + // v26 will be the vector to XOR input values with to perform + // any input data type conversion (e.g. uint8 to int8). + "dup v26.16b, %w[input_xor]\n" + // v27 will be filled with 1's. It will be used as an operand + // to SDOT to compute the sums. + "mov w1, #1\n" + "dup v27.16b, w1\n" + // w1 will be the number of rows already loaded. + "mov w1, #0\n" + // v28--v32 will be used to accumulate the sums + "movi v28.4s, #0\n" + "movi v29.4s, #0\n" + "movi v30.4s, #0\n" + "movi v31.4s, #0\n" + + // Let w2 be `rows` rounded down to multiple of 16. + "ands w2, %w[rows], #-16\n" + // If there are no full blocks of 16 rows to process, jump to the + // code handling the last < 16 rows. + "beq 3f\n" + // Load the first block of 16 rows. + "add w1, w1, #16\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + // Check if these were the only full block of 16 rows to load. + "cmp w1, w2\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + // In that case, jump to the code handling the last loaded block of + // 16 rows. + "beq 2f\n" + + // Main loop processing blocks of 16 rows. 
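+      // Each iteration transposes the 16x4 block so that every stored q
+      // register holds 4 consecutive depth values from each of the 4 columns,
+      // which is the layout consumed by the dotprod kernel.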
+ "1:\n" + "add w1, w1, #16\n" + // Prepare the already-loaded 16 rows by inserting the parts + // loaded into general purpose registers x10--x13 into the + // NEON registers v0--v3 where the other parts had already been + // loaded. + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + + // Load the next 16 rows and, interleaved with that, + // perform the input type conversion (e.g. uint8->int8) on the + // current 16 rows. + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + + // Transposition of 4x4 blocks, part 1 + "trn1 v16.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "trn2 v17.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "trn1 v18.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "trn2 v19.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + // Transposition of 4x4 blocks, part 2 + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + // Store the block to the packed matrix and, interleaved with + // that, compute sums using sdot instructions. + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + // End of main loop on blocks of 16 rows. + "bne 1b\n" + + // Code handling the last already-loaded block of 16 rows. + "2:\n" + // Process the last loaded full 16x4 block. + "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], x12\n" + "ins v3.d[1], x13\n" + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // End of code handling full blocks of 16 rows. + // Now we handle any remaining rows. + "3:\n" + // Let w2 be the number of rows left to handle. + "ands w2, %w[rows], #15\n" + // If w2==0, there are no remaining rows, jump to the end. + "beq 4f\n" + // Zero out a 16x4 block in registers, which we'll partially overwrite + // with any remaining rows. 
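+      // Padding with the source zero point rather than with 0 keeps the
+      // padded entries numerically neutral: they cancel out after zero-point
+      // subtraction, so the padding does not affect the result.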
+ "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. +#undef RUY_LOAD_ONE_ROW + + "5:\n" + // Process the last zero-padded 16x4 block. + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + // Reduction of the registers used to accumulate sums. + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + // Store the sums. + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)), + [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)), + [rows] "r"(src_rows), + [src_zero_point] "r"(static_cast<int>(src_zero_point)), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeonDotprod)"); + asm volatile( + // clang-format off + // v26 will be the vector to XOR input values with to perform + // any input data type conversion (e.g. uint8 to int8). + "dup v26.16b, %w[input_xor]\n" + // v27 will be filled with 1's. It will be used as an operand + // to SDOT to compute the sums. 
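+    // (Each sdot against v27 adds, into the i-th 32-bit lane of the
+    // accumulator, the sum of the 4 bytes in the i-th lane of the packed
+    // block, so the column sums are obtained as a by-product of packing.)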
+ "mov w1, #1\n" + "dup v27.16b, w1\n" + // w1 will be the number of rows already loaded. + "mov w1, #0\n" + // v28--v32 will be used to accumulate the sums + "movi v28.4s, #0\n" + "movi v29.4s, #0\n" + "movi v30.4s, #0\n" + "movi v31.4s, #0\n" + + // 4x partially unrolled code processing blocks of 64 rows. + // Read the original loop below first, it has more comments. +#if RUY_OPT(MAX_STREAMING) + // Let w2 be `rows` rounded down to multiple of 64. + // Each iteration of this 4x partially unrolled loop handles + // 64 rows. + "ands w2, %w[rows], #-64\n" + // If there are no full blocks of 64 rows to process, jump to + // the main loop below handling 16 rows per iteration. + "beq 9f\n" + // Load the first block of 64 rows. + "add w1, w1, #64\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + // Was that the last full block of 64 rows to load? + "cmp w1, w2\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + // Then jump to the end of the 64-rows-at-a-time code. + "beq 8f\n" + + // Start of the main 4x partially unrolled loop. + "7:\n" + // Process rows 0 -- 15 out of 64. + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 16 -- 31 out of 64. 
+ "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 32 -- 47 out of 64. + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 48 -- 63 out of 64. + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "cmp w1, w2\n" + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // End of main 4x partially unrolled loop. + "bne 7b\n" + + // Last part of the 4x partially unrolled code: + // handle the last already-loaded 64 rows. + "8:\n" + + // Process rows 0 -- 15 out of 64. 
+ "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 16 -- 31 out of 64. + "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 32 -- 47 out of 64. + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // Process rows 48 -- 63 out of 64. + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "9:\n" +#endif // #if RUY_OPT(MAX_STREAMING) + // End of 4x partially unrolled code processing blocks of 64 rows. 
+ + // Main part of the code, processing blocks of 16 rows. + + // Let w2 be `rows` rounded down to multiple of 16. + "and w2, %w[rows], #-16\n" + // If there are no full blocks of 16 rows to process, jump to the + // code handling the last < 16 rows. + "cmp w1, w2\n" + "beq 3f\n" + + // Load the first block of 16 rows. + "add w1, w1, #16\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + // Check if these were the only full block of 16 rows to load. + "cmp w1, w2\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + // In that case, jump to the code handling the last loaded block of + // 16 rows. + "beq 2f\n" + // Main loop processing blocks of 16 rows. + "1:\n" + // Input type conversion (e.g. uint8->int8). + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + // Transposition of 4x4 blocks, part 1 + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + // Load the next 16 rows + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + // Transposition of 4x4 blocks, part 2 + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + // Compute sums using sdot instructions. + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + // Store the block to the packed matrix. + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w1, w2\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + // End of main loop on blocks of 16 rows. + "bne 1b\n" + + // Code handling the last already-loaded block of 16 rows. + "2:\n" + + // Process the last loaded full 16x4 block. + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // End of code handling full blocks of 16 rows. + // Now we handle any remaining rows. + "3:\n" + // Let w2 be the number of rows left to handle. + "ands w2, %w[rows], #15\n" + // If w2==0, there are no remaining rows, jump to the end. + "beq 4f\n" + // Zero out a 16x4 block in registers, which we'll partially overwrite + // with any remaining rows. 
+ "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. +#undef RUY_LOAD_ONE_ROW + + "5:\n" + // Process the last zero-padded 16x4 block. + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + // Reduction of the registers used to accumulate sums. + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + // Store the sums. + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), + [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3), + [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) + : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)), + [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), + [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)), + [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), + [rows] "r"(src_rows), + [src_zero_point] "r"(static_cast<int>(src_zero_point)), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} + +void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_cols, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, std::int32_t* sums_ptr, + int input_xor) { + profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)"); + asm( + // clang-format off + // Prefetch data. This was tuned on Cortex-A55-rev1 cores. 
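+      // Unlike the column-major variants above, each src_ptrN here points to
+      // one row of the row-major source. Every loop iteration consumes 8
+      // columns from these 4 rows, zips them into column-major 4x8 blocks,
+      // and accumulates partial column sums into the sums buffer with a
+      // read-modify-write, since other calls handle the remaining rows of the
+      // same columns.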
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1]]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2]]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3]]\n") + // Let w0 = (number of columns to compute) - 8. + "subs w0, %w[src_cols], 8\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 64]\n") + // Let v26 duplicate the input_xor value in all lanes. + "dup v26.16b, %w[input_xor]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 64]\n") + // Let v27 be 1 in all lanes. Used with sdot to compute sums. + "movi v27.16b, 1\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 64]\n") + // If there isn't a full block of 8 columns to load from, jump to the + // code after the loop handling leftovers. + "blt 2f\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 64]\n") + // Main loop, each iteration handles a full block of 8 cols. + "1:\n" + // Load the 4x8 block from the source matrix, or zero if we're + // past the bottom of the source matrix. + "ld1 {v0.8b}, [%[src_ptr0]]\n" + "ld1 {v1.8b}, [%[src_ptr1]]\n" + "ld1 {v2.8b}, [%[src_ptr2]]\n" + "ld1 {v3.8b}, [%[src_ptr3]]\n" + // Load values from the sums buffer, and start the reordering + // of the loaded 4x8 block by interleaving 8bit values. + "zip1 v0.16b, v0.16b, v1.16b\n" + "ldr q8, [%[sums_ptr], 0]\n" + "zip1 v1.16b, v2.16b, v3.16b\n" + "ldr q9, [%[sums_ptr], 16]\n" + // Finish the reordering of the 4x8 block, putting it into + // column-major order. + "zip1 v2.8h, v0.8h, v1.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 128]\n") + "zip2 v3.8h, v0.8h, v1.8h\n" + // Apply input_xor, i.e. convert source values from uint8 to int8 + // if needed. + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 128]\n") + "eor v2.16b, v2.16b, v26.16b\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 128]\n") + "eor v3.16b, v3.16b, v26.16b\n" + // Update the sums. + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 128]\n") + ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n" + ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n" + // Store the column-major 4x8 block to the packed matrix, and + // increment some source pointers. + "str q2, [%[packed_ptr], 0]\n" + "add %[src_ptr0], %[src_ptr0], %w[src_inc0], sxtw\n" + "str q3, [%[packed_ptr], 16]\n" + "add %[src_ptr1], %[src_ptr1], %w[src_inc1], sxtw\n" + // Store the updated sums, and increment the remaining pointers + // and the block_col loop index. + "st1 {v8.4s}, [%[sums_ptr]], 16\n" + "add %[packed_ptr], %[packed_ptr], %[packed_stride], lsl 3\n" + "st1 {v9.4s}, [%[sums_ptr]], 16\n" + // Advance by 8 columns and set the condition code. + "subs w0, w0, 8\n" + "add %[src_ptr2], %[src_ptr2], %w[src_inc2], sxtw\n" + "add %[src_ptr3], %[src_ptr3], %w[src_inc3], sxtw\n" + // End of the main loop. + "bge 1b\n" + + "2:\n" + // We add back 8 to w0 so that w0 is the number of columns remaining + // to handle. + "adds w0, w0, 8\n" + // Nothing left? Then jump to the end. + "beq 3f\n" + // Here w0 is between 1 and 7. We zero-initialize v0--v3 ... + "dup v0.8b, %w[src_zero_point]\n" + "dup v1.8b, %w[src_zero_point]\n" + "dup v2.8b, %w[src_zero_point]\n" + "dup v3.8b, %w[src_zero_point]\n" + // ... and now we fill lanes one by one with leftover columns. 
+#define RUY_LOAD_ONE_COL(C)\ + "cmp w0, " #C "\n" \ + "beq 4f\n" \ + "ld1 { v0.b }[" #C "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #C "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #C "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #C "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_COL(0) + RUY_LOAD_ONE_COL(1) + RUY_LOAD_ONE_COL(2) + RUY_LOAD_ONE_COL(3) + RUY_LOAD_ONE_COL(4) + RUY_LOAD_ONE_COL(5) + RUY_LOAD_ONE_COL(6) + // Here we know that w0==7, so RUY_LOAD_ONE_COL(7) would be a no-op. +#undef RUY_LOAD_ONE_COL + + "4:\n" + // The leftovers source data is loaded, now we can perform the + // computation as usual. + // Load values from the sums buffer, and start the reordering + // of the loaded 4x8 block by interleaving 8bit values. + "zip1 v0.16b, v0.16b, v1.16b\n" + "ldr q8, [%[sums_ptr], 0]\n" + "zip1 v1.16b, v2.16b, v3.16b\n" + "ldr q9, [%[sums_ptr], 16]\n" + // Finish the reordering of the 4x8 block, putting it into + // column-major order. + "zip1 v2.8h, v0.8h, v1.8h\n" + "zip2 v3.8h, v0.8h, v1.8h\n" + // Apply input_xor, i.e. convert source values from uint8 to int8 + // if needed. + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + // Update the sums. + ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n" + ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n" + // Store the column-major 4x8 block to the packed matrix, and + // increment some source pointers. + "str q2, [%[packed_ptr], 0]\n" + "str q3, [%[packed_ptr], 16]\n" + // Store the updated sums, and increment the remaining pointers + // and the block_col loop index. + "st1 {v8.4s}, [%[sums_ptr]], 16\n" + "st1 {v9.4s}, [%[sums_ptr]], 16\n" + + // End label. + "3:\n" + // clang-format on + : [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr), + [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), + [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3) + : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1), + [src_inc2] "r"(src_inc2), [src_inc3] "r"(src_inc3), + [input_xor] "r"(input_xor), [src_zero_point] "r"(src_zero_point), + [packed_stride] "r"(static_cast<std::int64_t>(packed_stride)), + [src_cols] "r"(src_cols) + : "cc", "memory", "x0", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} + +void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack (kNeon)"); + asm volatile( + // clang-format off + // w1 will be the number of rows already loaded. + "mov w1, #0\n" + // Let w2 be `rows` rounded down to multiple of 4. + "ands w2, %w[rows], #-4\n" + // If there are no full blocks of 4 rows to process, jump to the + // code handling the last < 4 rows. + "beq 3f\n" + // Load the first block of 16 rows. + "add w1, w1, #4\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + // Check if these were the only full block of 4 rows to load. + "cmp w1, w2\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + // In that case, jump to the code handling the last loaded block of + // 4 rows. + "beq 2f\n" + // Main loop processing blocks of 4 rows. + "1:\n" + // Advance by 4 rows. + "add w1, w1, #4\n" + // Transposition of the already-loaded 4x4 block, part 1. 
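+      // (For floats there are no sums to compute, so packing reduces to this
+      // transposition: after the two trn passes, each stored q register holds
+      // one source row across the four columns.)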
+ "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + // Load the next 4x4 block. + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + // Transposition of the already-loaded 4x4 block, part 2. + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + // Was this the last full 4x4 block to load? + "cmp w1, w2\n" + // Store the transposed 4x4 block. + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + // End of main loop on 4x4 blocks. + "bne 1b\n" + + // Code handling the last already-loaded 4x4 block. + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + // End of code handling full 4x4 blocks. + // Now we handle any remaining rows. + "3:\n" + // Let w2 be the number of rows left to handle. + "ands w2, %w[rows], #3\n" + // If w2==0, there are no remaining rows, jump to the end. + "beq 4f\n" + // Zero out a 4x4 block in registers, which we'll partially overwrite + // with any remaining rows. + "movi v0.16b, #0\n" + "movi v1.16b, #0\n" + "movi v2.16b, #0\n" + "movi v3.16b, #0\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + // Here we know that w2==3, so RUY_LOAD_ONE_ROW(3) would be a no-op. +#undef RUY_LOAD_ONE_ROW + "5:\n" + + // Transpose that last zero-padded 4x4 block. + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + // Store that last zero-padded block to the packed matrix. 
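+      // Only rows that actually exist are stored: RUY_STORE_ONE_ROW branches
+      // to label 4 once ROW reaches the number of leftover rows, and x1 holds
+      // the 32-byte stride between consecutive packed rows.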
+ "mov x1, #32\n" +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), + [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3), + [packed_ptr] "+r"(packed_ptr) + : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)), + [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), + [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)), + [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), + [rows] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) +void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, float* packed_ptr, + int output_stride) { + profiler::ScopeLabel label("Pack (kNeon)"); + asm volatile( + // clang-format off + "mov r1, #0\n" + "and r2, %[rows], #-4\n" + "cmp r1, r2\n" + "beq 3f\n" +#define RUY_LOAD_FOUR_BY_FOUR() \ + /* Load q0 */ \ + "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ + /* if src_inc0 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #1\n" \ + "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ + /* Load q1 */ \ + "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ + /* if src_inc1 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #2\n" \ + "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ + /* Load q2 */ \ + "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ + /* if src_inc2 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #4\n" \ + "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ + /* Load q3 */ \ + "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ + /* if src_inc3 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #8\n" \ + "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ + + RUY_LOAD_FOUR_BY_FOUR() + "add r1, r1, #4\n" + "cmp r1, r2\n" + + "beq 2f\n" + + "1:\n" + "add r1, r1, #4\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_LOAD_FOUR_BY_FOUR() +#undef RUY_LOAD_FOUR_BY_FOUR + +#define RUY_STORE_FOUR_BY_FOUR() \ + /* Store q8, q10, q9, q11 */ \ + /* q8 = d16, d17 */ \ + "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ + /* q10 = d20, d21 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ + /* q9 = d18, d19 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ + /* q11 = d22, d23 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + + RUY_STORE_FOUR_BY_FOUR() + "cmp r1, r2\n" + + "bne 1b\n" + + "2:\n" + + // Transpose 4x4 matrix. 
+ "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_STORE_FOUR_BY_FOUR() +#undef RUY_STORE_FOUR_BY_FOUR + "3:\n" + + "ands r2, %[rows], #3\n" + "beq 4f\n" + "mov r0, #0\n" + // Zero out q0 - q3 + "vdup.32 q0, r0\n" + "vdup.32 q1, r0\n" + "vdup.32 q2, r0\n" + "vdup.32 q3, r0\n" +#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ + "cmp r2, #" #R "\n" \ + "beq 5f\n" \ + "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" + +#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ + "cmp r2, #" #R "\n" \ + "beq 5f\n" \ + "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" + + RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) + RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) + RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) + RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) +#undef RUY_LOAD_ONE_ROW_SECOND_HALF +#undef RUY_LOAD_ONE_ROW_FIRST_HALF + "5:\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + "mov r1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp r2, #" #ROW "\n" \ + "beq 4f\n" \ + "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" + + // Store q8 + RUY_STORE_ONE_ROW(0, q8) + // Store q10 + RUY_STORE_ONE_ROW(1, q10) + // Store q9 + RUY_STORE_ONE_ROW(2, q9) + // Store q11 + RUY_STORE_ONE_ROW(3, q11) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr) + : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)), + [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) + : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); +} + +#endif // (RUY_PLATFORM_NEON_32 + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) +void PackFloatColMajorForNeonA55ish(const float* src_ptr0, + const float* src_ptr1, + const float* src_ptr2, + const float* src_ptr3, int src_inc0, + int src_inc1, int src_inc2, int src_inc3, + int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + + asm volatile( + // clang-format off + "mov w1, #0\n" + + "and w2, %w[rows], #-4\n" + "cmp w1, w2\n" + "beq 3f\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, 
[%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #4\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #4\n" + + "ldr x10, [%[src_ptr0], #8]\n" + "trn1 v16.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "ldr x11, [%[src_ptr1], #8]\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "ldr x12, [%[src_ptr2], #8]\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "ldr x13, [%[src_ptr3], #8]\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" + "trn1 v20.2d, v16.2d, v18.2d\n" + "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + "ins v0.d[1], x10\n" + "str q20, [%[packed_ptr], #0]\n" + "ins v1.d[1], x11\n" + "str q21, [%[packed_ptr], #32]\n" + "ins v2.d[1], x12\n" + "str q22, [%[packed_ptr], #64]\n" + "ins v3.d[1], x13\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #3\n" + "beq 4f\n" + "movi v0.16b, #0\n" + "movi v1.16b, #0\n" + "movi v2.16b, #0\n" + "movi v3.16b, #0\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "mov x1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) + : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)), + [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), [rows] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", 
"v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON + +namespace { +// transpose_*bit_vals are wrappers around ARM TRN1 instructions, allowing +// to use these instructions like we would in assembly --- this is one instance +// where assembly is more idiomatic than intrinsics. +// +// The way that TRN1 is exposed by vtrn_* intrinsics makes its usage very +// cumbersome. The issue is that transposing grouped of values has been exposed +// only as transposing values of a wider type, so this requires many +// vreinterpret's, and to make it worse, vtrn_* return NEON array types like +// int8x8x2_t for which vreinterpret's are not defined! +void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) { + int8x8x2_t t = vtrn_s8(a, b); + a = t.val[0]; + b = t.val[1]; +} + +void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) { + int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b)); + a = vreinterpret_s8_s16(t.val[0]); + b = vreinterpret_s8_s16(t.val[1]); +} + +void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) { + int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b)); + a = vreinterpret_s8_s32(t.val[0]); + b = vreinterpret_s8_s32(t.val[1]); +} +} // namespace + +void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride, + int src_rows, int src_cols, int block_row, + int start_col, int end_col, + std::int8_t* packed_ptr, int packed_stride, + int packed_zero_point, std::int32_t* sums, + int input_xor, int kernel_cols) { + profiler::ScopeLabel label("Pack (kNeon, from row-major)"); + + int src_end_col = std::min(end_col, src_cols); + int col = start_col; + for (; col <= src_end_col - 8; col += 8) { + // Each iteration of this loop handles 8 columns, and the kernel format + // has 16 rows, so each iteration handles a 16x8 block. + // + // Since the source is row-major, handling 8 columns at a time means + // loading only 8 bytes i.e. 64bit from each row. This may seem surprising + // on 128bit SIMD like NEON. While we could handle 16 columns at a time, + // we prefer to stick with 8 for the following reasons: + // 1. The arithmetic (computing sums and transposing data) done on these + // values is such that even though we initially start from 64bit vectors, + // most of our NEON instructions are full 128bit instructions. For the + // sums computation, that is because summing 8bit values requires + // expansion to 16bit anyway. For the matrix transposition code, that is + // because the ARM ZIP instructions take 64bit of data from two input + // registers and zip it into a 128bit output. If we had 128bit of data + // in each input registers, we would need 2x more ARM NEON instructions + // to zip it. + // 2. The main optimization target for this (ARM, 8bit, non-dotprod) + // code path is in-order ARM cores such as the Cortex-A53, which prefer + // 64bit loads anyway. + // 3. Handling only 8 columns at a time limits the size of the final + // leftover columns handled with slow scalar code. + // + // This code is not very optimized anyway, as evidenced from the facts that + // (1) it's written in intrinsics, (2) it's not using separate versions + // tuned for different types of CPU cores. At the level of optimization that + // it's working at, this seems like a fair compromise. 
If one wanted to + // maximize performance at the cost of more code complexity/size, one could + // have code handling 16 columns at a time (maybe limited to + // Tuning::kGeneric), then 8, then 4 to minimize the amount of slow + // leftovers. + // + // Load 8 sums in sums0, sums1. + int32x4_t sums0 = vld1q_s32(sums + col); + int32x4_t sums1 = vld1q_s32(sums + col + 4); + // Load the 8x16 block from the source matrix. + // Each val* here is the data from one row. + int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10, + val11, val12, val13, val14, val15; + // Even though this function takes a uint8_t* src_ptr, that's only a + // type-erased pointer (using uint8_t* so that pointer arithmetic is + // allowed). The actual type may be either uint8_t or int8_t. The only + // difference it makes is that if it's uint8_t then we need to flip the + // sign bit. This is specified by the input_xor value (which is 0x80 if the + // input data is uint8_t, and 0x0 otherwise). + auto load_and_convert = [=](const std::uint8_t* from) { + return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from))); + }; + if (block_row <= src_rows - 16) { + // Load data in the regular case: there are still 16 rows to be read from + // the source matrix. + val0 = load_and_convert(src_ptr + 0 * src_stride); + val1 = load_and_convert(src_ptr + 1 * src_stride); + val2 = load_and_convert(src_ptr + 2 * src_stride); + val3 = load_and_convert(src_ptr + 3 * src_stride); + val4 = load_and_convert(src_ptr + 4 * src_stride); + val5 = load_and_convert(src_ptr + 5 * src_stride); + val6 = load_and_convert(src_ptr + 6 * src_stride); + val7 = load_and_convert(src_ptr + 7 * src_stride); + val8 = load_and_convert(src_ptr + 8 * src_stride); + val9 = load_and_convert(src_ptr + 9 * src_stride); + val10 = load_and_convert(src_ptr + 10 * src_stride); + val11 = load_and_convert(src_ptr + 11 * src_stride); + val12 = load_and_convert(src_ptr + 12 * src_stride); + val13 = load_and_convert(src_ptr + 13 * src_stride); + val14 = load_and_convert(src_ptr + 14 * src_stride); + val15 = load_and_convert(src_ptr + 15 * src_stride); + } else { + // Boundary case: there are fewer than 16 rows to be read from the source + // matrix. We pad by the zero_point. 
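The input_xor handling in load_and_convert above rests on a simple identity: XORing a uint8 value with 0x80 and reinterpreting the result as int8 subtracts 128, which is what lets uint8 sources be packed into the int8 format that the kernels consume. A standalone check of that identity, offered only as a sketch (not ruy code):

// Check the input_xor identity: for any uint8 value v,
// (v ^ 0x80) reinterpreted as int8 equals v - 128.
#include <cassert>
#include <cstdint>

int main() {
  for (int v = 0; v < 256; ++v) {
    const std::uint8_t u = static_cast<std::uint8_t>(v);
    const std::int8_t flipped = static_cast<std::int8_t>(u ^ 0x80);
    assert(static_cast<int>(flipped) == v - 128);
  }
  return 0;
}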
+ val0 = vdup_n_s8(packed_zero_point); + val1 = val0; + val2 = val0; + val3 = val0; + val4 = val0; + val5 = val0; + val6 = val0; + val7 = val0; + val8 = val0; + val9 = val0; + val10 = val0; + val11 = val0; + val12 = val0; + val13 = val0; + val14 = val0; + val15 = val0; + if (block_row + 0 < src_rows) + val0 = load_and_convert(src_ptr + 0 * src_stride); + if (block_row + 1 < src_rows) + val1 = load_and_convert(src_ptr + 1 * src_stride); + if (block_row + 2 < src_rows) + val2 = load_and_convert(src_ptr + 2 * src_stride); + if (block_row + 3 < src_rows) + val3 = load_and_convert(src_ptr + 3 * src_stride); + if (block_row + 4 < src_rows) + val4 = load_and_convert(src_ptr + 4 * src_stride); + if (block_row + 5 < src_rows) + val5 = load_and_convert(src_ptr + 5 * src_stride); + if (block_row + 6 < src_rows) + val6 = load_and_convert(src_ptr + 6 * src_stride); + if (block_row + 7 < src_rows) + val7 = load_and_convert(src_ptr + 7 * src_stride); + if (block_row + 8 < src_rows) + val8 = load_and_convert(src_ptr + 8 * src_stride); + if (block_row + 9 < src_rows) + val9 = load_and_convert(src_ptr + 9 * src_stride); + if (block_row + 10 < src_rows) + val10 = load_and_convert(src_ptr + 10 * src_stride); + if (block_row + 11 < src_rows) + val11 = load_and_convert(src_ptr + 11 * src_stride); + if (block_row + 12 < src_rows) + val12 = load_and_convert(src_ptr + 12 * src_stride); + if (block_row + 13 < src_rows) + val13 = load_and_convert(src_ptr + 13 * src_stride); + if (block_row + 14 < src_rows) + val14 = load_and_convert(src_ptr + 14 * src_stride); + if (block_row + 15 < src_rows) + val15 = load_and_convert(src_ptr + 15 * src_stride); + } + src_ptr += 8; + // Compute sums. + int16x8_t sums16_0 = vaddl_s8(val0, val1); + int16x8_t sums16_1 = vaddl_s8(val2, val3); + sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5)); + sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7)); + sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9)); + sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11)); + sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13)); + sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15)); + int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1); + sums0 = vaddw_s16(sums0, vget_low_s16(sums16)); + sums1 = vaddw_s16(sums1, vget_high_s16(sums16)); + // Store sums. + vst1q_s32(sums + col, sums0); + vst1q_s32(sums + col + 4, sums1); + + // Transpose the data, i.e. change the storage order of the + // 16x8 block, to convert from the row-major source to the + // column-major packed format. + // + // Before, for i in [0, 15], val<i> is the i-th row. + // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column. + transpose_8bit_vals(val0, val1); + transpose_8bit_vals(val2, val3); + transpose_8bit_vals(val4, val5); + transpose_8bit_vals(val6, val7); + transpose_8bit_vals(val8, val9); + transpose_8bit_vals(val10, val11); + transpose_8bit_vals(val12, val13); + transpose_8bit_vals(val14, val15); + transpose_16bit_vals(val0, val2); + transpose_16bit_vals(val1, val3); + transpose_16bit_vals(val4, val6); + transpose_16bit_vals(val5, val7); + transpose_16bit_vals(val8, val10); + transpose_16bit_vals(val9, val11); + transpose_16bit_vals(val12, val14); + transpose_16bit_vals(val13, val15); + transpose_32bit_vals(val0, val4); + transpose_32bit_vals(val1, val5); + transpose_32bit_vals(val2, val6); + transpose_32bit_vals(val3, val7); + transpose_32bit_vals(val8, val12); + transpose_32bit_vals(val9, val13); + transpose_32bit_vals(val10, val14); + transpose_32bit_vals(val11, val15); + // Store to the packed_matrix. 
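One detail of the sums computation above is worth spelling out before the block is stored below: the vaddl_s8/vaddq_s16 chain accumulates at most 16 int8 values per 16-bit lane before the result is widened into the 32-bit sums, so the 16-bit intermediates cannot overflow. A back-of-the-envelope bound, as a sketch only (not ruy code):

// Worst-case magnitude of a per-column sum over one 16-row chunk, held in a
// 16-bit lane before widening to 32 bits.
constexpr int kChunkRows = 16;
constexpr int kMaxAbsInt8 = 128;
static_assert(kChunkRows * kMaxAbsInt8 == 2048, "worst-case |sum| per chunk");
static_assert(kChunkRows * kMaxAbsInt8 <= 32767, "fits comfortably in int16");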
+ std::int8_t* dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
+ packed_ptr += 4 * packed_stride;
+ dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
+ packed_ptr += 4 * packed_stride;
+ }
+ // Handle remaining columns that do not fit in a full block of 8 columns
+ // but are still actual columns from the source matrix (as opposed to the
+ // final padding columns handled below).
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ for (int r = 0; r < 16; r++) {
+ std::int8_t packed_val = (block_row + r < src_rows)
+ ? (src_ptr[r * src_stride] ^ input_xor)
+ : packed_zero_point;
+ accum += packed_val;
+ dst_ptr[r] = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+ // Handle the final columns of the packed matrix, beyond the last column of
+ // the source matrix. The values here don't matter; we just want to avoid
+ // leaving uninitialized data. Since the sums are already initialized above,
+ // we don't need to do anything about them here.
+ for (; col < end_col; col++) {
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ std::memset(dst_ptr, 0, 16);
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+}
+
+#endif
+
+} // namespace ruy
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
new file mode 100644
index 0000000..ba8964d
--- /dev/null
+++ b/ruy/pack_arm.h
@@ -0,0 +1,613 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_PACK_ARM_H_ +#define RUY_RUY_PACK_ARM_H_ + +#include <algorithm> +#include <cstdint> +#include <type_traits> + +#include "ruy/asm_helpers.h" +#include "ruy/check_macros.h" +#include "ruy/mat.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM_NEON +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) + +RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8) +#if RUY_PLATFORM_NEON_32 +RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4) +#endif + +template <> +struct PackedTypeImpl<Path::kNeon, std::uint8_t> { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> { + using Type = std::int8_t; +}; +#endif + +#if RUY_PLATFORM_NEON +void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride, + int src_rows, int src_cols, int block_row, + int start_col, int end_col, + std::int8_t* packed_ptr, int packed_stride, + int packed_zero_point, std::int32_t* sums_ptr, + int input_xor, int kernel_cols); +#endif + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + int input_xor); +void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitColMajorForNeonDotprodA55ish( + const void* src_ptr0, const void* src_ptr1, const void* src_ptr2, + const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_cols, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, std::int32_t* sums_ptr, + int input_xor); +#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) + +struct PackParams8bit { + const void* src_ptr0; + const void* src_ptr1; + const void* src_ptr2; + const void* src_ptr3; + const std::int32_t* sums_ptr; + const std::int8_t* packed_ptr; + int src_inc0; + int src_inc1; + int src_inc2; + int src_inc3; + int src_rows; + int src_zero_point; + int input_xor; +}; + +inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + const std::int32_t* sums_ptr, + const std::int8_t* packed_ptr, int src_inc0, + int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, int input_xor, + PackParams8bit* params) { + params->src_ptr0 = src_ptr0; + params->src_ptr1 = src_ptr1; + params->src_ptr2 = 
src_ptr2; + params->src_ptr3 = src_ptr3; + params->sums_ptr = sums_ptr; + params->packed_ptr = packed_ptr; + params->src_inc0 = src_inc0; + params->src_inc1 = src_inc1; + params->src_inc2 = src_inc2; + params->src_inc3 = src_inc3; + params->src_rows = src_rows; + params->src_zero_point = src_zero_point; + params->input_xor = input_xor; +} + +void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params); +void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params); + +#endif // (RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) + +#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM) + +template <typename Scalar> +struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar, + std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 4, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + int src_inc2 = 16; + int src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; +#if RUY_PLATFORM_NEON_64 + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + Pack8bitColMajorForNeonA55ish( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, sums_ptr, kInputXor); + } else { + Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, + src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, sums_ptr, kInputXor); + } +#else + (void)tuning; + // We have a more limited set of general purpose registers in ARMv7, so + // we use the "params" struct technique from the kernel code to save + // registers. 
+ PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, + packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitColMajorForNeon4Cols(params); +#endif // RUY_PLATFORM_NEON_64 + } + } +}; + +#endif // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && + // RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) +// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional +// partial specialization for the RHS, which has a FixedKernelLayout with 2 +// columns. +template <typename Scalar> +struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar, + std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 2, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 2) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + if (block_col >= src_matrix.layout.cols - 2) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, + packed_ptr, src_inc0, src_inc1, -1, -1, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitColMajorForNeon2Cols(params); + } + } +}; +#endif // (RUY_PLATFORM_NEON_32) && RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) +template <typename Scalar> +struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>, + Scalar, std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 
0 : 0x80; + + static void Run(Tuning tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4) * 4); + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + Pack8bitColMajorForNeonDotprodA55ish( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, sums_ptr, kInputXor); + } else { + Pack8bitColMajorForNeonDotprod( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, sums_ptr, kInputXor); + } + } + } +}; +#endif // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) +void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, float* packed_ptr); +void PackFloatColMajorForNeonA55ish(const float* src_ptr0, + const float* src_ptr1, + const float* src_ptr2, + const float* src_ptr3, int src_inc0, + int src_inc1, int src_inc2, int src_inc3, + int src_rows, float* packed_ptr); + +#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) +void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, float* packed_ptr, + int stride); +#endif // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM) + +#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM) + +template <> +struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, + float, float, Order::kColMajor> { + static void Run(Tuning tuning, const Mat<float>& src_matrix, + PMat<float>* packed_matrix, int start_col, int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + const float zerobuf[4] = {0}; + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + 
src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + float* packed_ptr = packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4)); +#if RUY_PLATFORM_NEON_64 + if (__builtin_expect(tuning == Tuning::kA55ish, true)) { + PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3, + src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, packed_ptr); + } else { + PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, + src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, packed_ptr); + } +#else + (void)tuning; + // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc + // to save on registers (we have fewer general purpose registers in + // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four + // values that are each either 16 or 0 and use them directly. For the + // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should + // use the value 16 (bit is set) or 0 (bit is not set) for the + // respective increment value. + std::int64_t src_inc = 0; + src_inc += src_inc0 == 16 ? 1 : 0; + src_inc += src_inc1 == 16 ? 2 : 0; + src_inc += src_inc2 == 16 ? 4 : 0; + src_inc += src_inc3 == 16 ? 8 : 0; + const int kOutputStride = 32; + PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, + src_matrix.layout.rows, packed_ptr, + kOutputStride); +#endif // RUY_PLATFORM_NEON_64 + } + } +}; + +#if RUY_PLATFORM_NEON_32 +// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional +// specialization for a FixedKernelLayout with 4 columns. 
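The bit-encoding of src_inc described above, and used again by the 4-column specialization below, can be checked in isolation: bit i carries the value 1 << i, and the assembly rebuilds each increment with an AND plus an ADD of that bit shifted left by 4 - i, so a set bit becomes 16 and a clear bit stays 0. A minimal sketch with hypothetical helper names (not ruy code):

// Scalar sketch of the src_inc bit-encoding used by the 32-bit NEON float
// packers. EncodeSrcInc/DecodeSrcInc are illustrative names only.
#include <cassert>
#include <cstdint>

// Encode: each per-pointer increment is either 16 (advance) or 0 (zerobuf).
std::uint32_t EncodeSrcInc(int inc0, int inc1, int inc2, int inc3) {
  return (inc0 == 16 ? 1 : 0) | (inc1 == 16 ? 2 : 0) |
         (inc2 == 16 ? 4 : 0) | (inc3 == 16 ? 8 : 0);
}

// Decode the way the assembly does: mask out bit i (value 1 << i), then
// shift left by (4 - i) so a set bit becomes 16.
int DecodeSrcInc(std::uint32_t src_inc, int i) {
  return static_cast<int>((src_inc & (1u << i)) << (4 - i));
}

int main() {
  const std::uint32_t enc = EncodeSrcInc(16, 0, 16, 16);
  assert(DecodeSrcInc(enc, 0) == 16);
  assert(DecodeSrcInc(enc, 1) == 0);
  assert(DecodeSrcInc(enc, 2) == 16);
  assert(DecodeSrcInc(enc, 3) == 16);
  return 0;
}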
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
+ float, float, Order::kColMajor> {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 4, 0);
+ const float zerobuf[4] = {0};
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ float* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * (block_col);
+ // Encode each of src_inc0, ..., src_inc3 in the lowest 4 bits of src_inc
+ // to save registers.
+ std::int64_t src_inc = 0;
+ src_inc += src_inc0 == 16 ? 1 : 0;
+ src_inc += src_inc1 == 16 ? 2 : 0;
+ src_inc += src_inc2 == 16 ? 4 : 0;
+ src_inc += src_inc3 == 16 ? 8 : 0;
+ const int kOutputStride = 16;
+ PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+ src_matrix.layout.rows, packed_ptr,
+ kOutputStride);
+ }
+ }
+};
+#endif // (RUY_PLATFORM_NEON_32)
+#endif // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \
+ // RUY_OPT(ASM)
+
+#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+
+template <typename Scalar>
+struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsRowMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ Scalar zerobuf[8];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ int src_stride = src_matrix.layout.stride;
+ // As the source matrix is row-major and the destination packed matrix is
+ // column-major, there is no traversal order that will be optimal for both,
+ // so we choose to favor the source matrix with a row-major traversal order.
+ // Loop over groups of 4 rows.
+ for (int block_row = 0; block_row < packed_matrix->layout.rows;
+ block_row += 4) {
+ // src_ptr[0-3] shall point to the positions in the 4 rows of the source
+ // matrix that we are loading from, and will be incremented by
+ // src_inc[0-3] after each 4x8 block is loaded.
+ // First we compute these src_ptr and src_inc values for the case where
+ // there are 4 rows left to be loaded from in the source matrix ... 
+ const Scalar* src_ptr0 = + src_matrix.data.get() + src_stride * block_row + start_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + // ... and now we adjust these values in case there are fewer than 4 rows + // left to load from in the source matrix. In that case, we set the + // corresponding src_ptr pointer to load from `zerobuf` and set src_inc + // to 0 to avoid overrunning that small buffer. + if (block_row >= src_matrix.layout.rows - 3) { + if (block_row >= src_matrix.layout.rows - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_row >= src_matrix.layout.rows - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_row >= src_matrix.layout.rows - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_row >= src_matrix.layout.rows - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + // Let src_cols be the number of source matrix columns to handle. + int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col; + std::int8_t* packed_ptr = packed_matrix->data + + packed_matrix->layout.stride * start_col + + 8 * block_row; + std::int32_t* sums_ptr = sums + start_col; + Pack8bitRowMajorForNeonDotprod( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2, + src_inc3, src_cols, src_matrix.zero_point, packed_ptr, + packed_matrix->layout.stride, sums_ptr, kInputXor); + } + } +}; + +#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) + +#if RUY_PLATFORM_NEON + +template <typename Scalar, int KernelCols> +struct PackImpl<Path::kNeon, + FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar, + std::int8_t, std::int32_t, Order::kRowMajor> { + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (KNeon, from row-major source)"); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); + RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0); + std::int32_t* sums = packed_matrix->sums; + std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); + int block_row = 0; + for (; block_row < packed_matrix->layout.rows; block_row += 16) { + int src_stride = src_matrix.layout.stride; + int packed_stride = packed_matrix->layout.stride; + const Scalar* src_ptr = + src_matrix.data.get() + block_row * src_stride + start_col; + std::int8_t* packed_ptr = packed_matrix->data + + start_col * packed_stride + + block_row * KernelCols; + + Pack8bitRowMajorForNeon( + reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride, + src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col, + end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums, + kInputXor, KernelCols); + } + } +}; +#endif + +} // namespace ruy + +#endif // RUY_RUY_PACK_ARM_H_ diff --git a/ruy/pack_avx.cc b/ruy/pack_avx.cc new file mode 100644 index 0000000..2b929e7 --- /dev/null +++ b/ruy/pack_avx.cc @@ -0,0 +1,831 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <cstring> + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/pack_x86.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) + +void Pack8bitColMajorForAvx(const std::int8_t*, std::int8_t, const std::int8_t*, + int, int, int, std::int8_t*, std::int32_t*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatColMajorForAvx(const float*, const float*, int, int, int, + float*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Pack8bitRowMajorForAvx(const std::uint8_t*, int, int, std::int8_t*, int, + int, int, int, int, int, int, std::int32_t*) { + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX && RUY_OPT(ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitAvx = + PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t, + std::int8_t, std::int32_t, Order::kColMajor>; + +using PackImplFloatAvx = + PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, + float, float, Order::kColMajor>; + +namespace { + +// Perform the equivalent of mm256_permutevar8x32 with +// a second argument of {7, 5, 3, 1, 6, 4, 2, 0} +inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { + // a_lo = 3 2 1 0 + __m128i a_lo = _mm256_extractf128_si256(a, 0); + // a_hi = 7 6 5 4 + __m128i a_hi = _mm256_extractf128_si256(a, 1); + // shuffle a_lo to get 3 1 2 0 + __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); + // shuffle a_hi to get 7 5 6 4 + __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); + // unpack lo 64 of res_lo and res hi to get 6 4 2 0 + __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); + // unpack hi 64 of res_lo and res hi to get 7 5 3 1 + __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); + return _mm256_set_m128i(res_hi, res_lo); +} + +inline __m128i mm256_extracti128_si256(const __m256i& a, const int imm) { + switch (imm) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + default: + RUY_DCHECK_LT(imm, 2); + return _mm_setzero_si128(); + } +} + +inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { + // Take the upper 64 bits of a and put in the first 64 bits of 'hi' + __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); + return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); +} + +inline __m256i mm256_cvtepi16_epi32(const __m128i& a) { + // Take the upper 64 bits of a and put in the first 64 bits of 'hi' + __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); + return _mm256_set_m128i(_mm_cvtepi16_epi32(hi), _mm_cvtepi16_epi32(a)); +} + +inline __m256i mm256_xor_si256(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 
0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_xor_si128(a_lo, b_lo); + __m128i hi = _mm_xor_si128(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_unpacklo_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_unpacklo_epi32(a_lo, b_lo); + __m128i hi = _mm_unpacklo_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_unpacklo_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_unpacklo_epi64(a_lo, b_lo); + __m128i hi = _mm_unpacklo_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_unpackhi_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_unpackhi_epi32(a_lo, b_lo); + __m128i hi = _mm_unpackhi_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_unpackhi_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_unpackhi_epi64(a_lo, b_lo); + __m128i hi = _mm_unpackhi_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi32(a_lo, b_lo); + __m128i hi = _mm_add_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_add_epi16(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi16(a_lo, b_lo); + __m128i hi = _mm_add_epi16(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_madd_epi16(a_lo, b_lo); + __m128i hi = _mm_madd_epi16(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, + const int imm) { + __m128i tmp = _mm_setzero_si128(); + if (!(imm & 8)) { + switch (imm & 3) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + case 2: + return _mm256_extractf128_si256(b, 0); + case 3: + return _mm256_extractf128_si256(b, 1); + } + } + return tmp; +} + +inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, + const int 
imm) { + const int lo_imm = imm & 15; + __m128i lo = mm_permute_helper(a, b, lo_imm); + const int hi_imm = (imm >> 4) & 15; + __m128i hi = mm_permute_helper(a, b, hi_imm); + return _mm256_set_m128i(hi, lo); +} + +inline void Pack8bitColMajorForAvxPacker(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m256i ones_16bit = _mm256_set1_epi16(1); + __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); + __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. 
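Of the helper functions defined above, PermuteEpi32EvenOdds is the least obvious: it gathers the even-indexed 32-bit lanes of its argument into the low 128-bit half of the result and the odd-indexed lanes into the high half. A scalar model of that lane order, purely as a sketch (not ruy code):

// Scalar model of the lane ordering produced by PermuteEpi32EvenOdds above.
#include <array>
#include <cassert>
#include <cstdint>

std::array<std::int32_t, 8> PermuteEvenOddsModel(
    const std::array<std::int32_t, 8>& a) {
  // Even-indexed lanes go to the low half, odd-indexed lanes to the high half.
  return {a[0], a[2], a[4], a[6], a[1], a[3], a[5], a[7]};
}

int main() {
  const std::array<std::int32_t, 8> in = {10, 11, 12, 13, 14, 15, 16, 17};
  const std::array<std::int32_t, 8> out = PermuteEvenOddsModel(in);
  assert((out == std::array<std::int32_t, 8>{10, 12, 14, 16, 11, 13, 15, 17}));
  return 0;
}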
+ if (available_src_rows >= kNumChunkedSrcRows) { + if (sums_ptr) { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); + + r0 = mm256_unpacklo_epi32(t0, t1); + r4 = mm256_unpacklo_epi32(t4, t5); + r2 = mm256_unpackhi_epi32(t0, t1); + r6 = mm256_unpackhi_epi32(t4, t5); + r1 = mm256_unpacklo_epi32(t2, t3); + r5 = mm256_unpacklo_epi32(t6, t7); + r3 = mm256_unpackhi_epi32(t2, t3); + r7 = mm256_unpackhi_epi32(t6, t7); + + t0 = mm256_unpacklo_epi64(r0, r1); + t4 = mm256_unpacklo_epi64(r4, r5); + t2 = mm256_unpackhi_epi64(r0, r1); + t6 = mm256_unpackhi_epi64(r4, r5); + t1 = mm256_unpacklo_epi64(r2, r3); + t5 = mm256_unpacklo_epi64(r6, r7); + t3 = mm256_unpackhi_epi64(r2, r3); + t7 = mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = mm256_permute2x128_si256(t0, t4, 0x20); + r4 = mm256_permute2x128_si256(t1, t5, 0x20); + r1 = mm256_permute2x128_si256(t0, t4, 0x31); + r5 = mm256_permute2x128_si256(t1, t5, 0x31); + r2 = mm256_permute2x128_si256(t2, t6, 0x20); + r6 = mm256_permute2x128_si256(t3, t7, 0x20); + r3 = mm256_permute2x128_si256(t2, t6, 0x31); + r7 = mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = mm256_xor_si256(r0, input_xor_v); + r1 = mm256_xor_si256(r1, input_xor_v); + r2 = mm256_xor_si256(r2, input_xor_v); + r3 = mm256_xor_si256(r3, input_xor_v); + r4 = mm256_xor_si256(r4, input_xor_v); + r5 = mm256_xor_si256(r5, input_xor_v); + r6 = mm256_xor_si256(r6, input_xor_v); + r7 = mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
+ const __m256i sums_4x2_32bit_lo_new = + mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = mm256_add_epi16( + sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } else { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); + + r0 = mm256_unpacklo_epi32(t0, t1); + r4 = mm256_unpacklo_epi32(t4, t5); + r2 = mm256_unpackhi_epi32(t0, t1); + r6 = mm256_unpackhi_epi32(t4, t5); + r1 = mm256_unpacklo_epi32(t2, t3); + r5 = mm256_unpacklo_epi32(t6, t7); + r3 = mm256_unpackhi_epi32(t2, t3); + r7 = mm256_unpackhi_epi32(t6, t7); + + t0 = mm256_unpacklo_epi64(r0, r1); + t4 = mm256_unpacklo_epi64(r4, r5); + t2 = mm256_unpackhi_epi64(r0, r1); + t6 = mm256_unpackhi_epi64(r4, r5); + t1 = mm256_unpacklo_epi64(r2, r3); + t5 = mm256_unpacklo_epi64(r6, r7); + t3 = mm256_unpackhi_epi64(r2, r3); + t7 = mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). 
This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = mm256_permute2x128_si256(t0, t4, 0x20); + r4 = mm256_permute2x128_si256(t1, t5, 0x20); + r1 = mm256_permute2x128_si256(t0, t4, 0x31); + r5 = mm256_permute2x128_si256(t1, t5, 0x31); + r2 = mm256_permute2x128_si256(t2, t6, 0x20); + r6 = mm256_permute2x128_si256(t3, t7, 0x20); + r3 = mm256_permute2x128_si256(t2, t6, 0x31); + r7 = mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = mm256_xor_si256(r0, input_xor_v); + r1 = mm256_xor_si256(r1, input_xor_v); + r2 = mm256_xor_si256(r2, input_xor_v); + r3 = mm256_xor_si256(r3, input_xor_v); + r4 = mm256_xor_si256(r4, input_xor_v); + r5 = mm256_xor_si256(r5, input_xor_v); + r6 = mm256_xor_si256(r6, input_xor_v); + r7 = mm256_xor_si256(r7, input_xor_v); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. + sums_adjustment += + -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); + + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr0); + t4 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr4); + t1 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr1); + t5 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr5); + t2 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr2); + t6 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr6); + t3 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr3); + t7 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr7); + + r0 = mm256_unpacklo_epi32(t0, t1); + r4 = mm256_unpacklo_epi32(t4, t5); + r2 = mm256_unpackhi_epi32(t0, t1); + r6 = mm256_unpackhi_epi32(t4, t5); + r1 = mm256_unpacklo_epi32(t2, t3); + r5 = mm256_unpacklo_epi32(t6, t7); + r3 = mm256_unpackhi_epi32(t2, t3); + r7 = mm256_unpackhi_epi32(t6, t7); + + t0 = mm256_unpacklo_epi64(r0, r1); + t4 = mm256_unpacklo_epi64(r4, r5); + t2 = mm256_unpackhi_epi64(r0, r1); + t6 = mm256_unpackhi_epi64(r4, r5); + t1 = mm256_unpacklo_epi64(r2, r3); + t5 = mm256_unpacklo_epi64(r6, r7); + t3 = mm256_unpackhi_epi64(r2, r3); + t7 = mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. 
The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = mm256_permute2x128_si256(t0, t4, 0x20); + r4 = mm256_permute2x128_si256(t1, t5, 0x20); + r1 = mm256_permute2x128_si256(t0, t4, 0x31); + r5 = mm256_permute2x128_si256(t1, t5, 0x31); + r2 = mm256_permute2x128_si256(t2, t6, 0x20); + r6 = mm256_permute2x128_si256(t3, t7, 0x20); + r3 = mm256_permute2x128_si256(t2, t6, 0x31); + r7 = mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = mm256_xor_si256(r0, input_xor_v); + r1 = mm256_xor_si256(r1, input_xor_v); + r2 = mm256_xor_si256(r2, input_xor_v); + r3 = mm256_xor_si256(r3, input_xor_v); + r4 = mm256_xor_si256(r4, input_xor_v); + r5 = mm256_xor_si256(r5, input_xor_v); + r6 = mm256_xor_si256(r6, input_xor_v); + r7 = mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = mm256_add_epi16( + sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
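// Worked bound (illustrative, hypothetical helper) for why the 16-bit
// accumulation above cannot overflow: each 16-bit lane receives at most 8
// sign-extended int8 values, so |sum| <= 8 * 128 = 1024, well inside the
// int16 range; the accumulation only widens to 32 bits at the madd below.
inline short SumEightInt8Model(const signed char v[8]) {
  short acc = 0;  // |acc| never exceeds 1024
  for (int i = 0; i < 8; ++i) acc = static_cast<short>(acc + v[i]);
  return acc;
}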
+ const __m256i sums_4x2_32bit_lo_new = + mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = + mm256_add_epi16(sums_4x4_16bit_hi, + mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), + r7); + } + + packed_ptr += 8 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. + const __m256i sums_2x4_32bit_lo = PermuteEpi32EvenOdds(sums_4x2_32bit_lo); + const __m256i sums_2x4_32bit_hi = PermuteEpi32EvenOdds(sums_4x2_32bit_hi); + const __m256i sums_2x4_32bit_a = + mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); + const __m256i sums_2x4_32bit_b = + mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); + sums = mm256_add_epi32(sums, sums_adjustment_v); + sums = mm256_add_epi32(sums, sums_2x4_32bit_a); + sums = mm256_add_epi32(sums, sums_2x4_32bit_b); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + +// Use a generic AVX intrinsic for greater-than comparison. +template <> +inline __m256i CompareGreaterThan<Path::kAvx>(const __m256i& a, + const __m256i& b) { + constexpr int kGreaterThanSignalling = 14; + const __m256 v = _mm256_cmp_ps(_mm256_cvtepi32_ps(a), _mm256_cvtepi32_ps(b), + kGreaterThanSignalling); + return _mm256_cvtps_epi32(v); +} + +} // namespace. 
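// Lane-wise sketch (illustrative, hypothetical helper) of the AVX fallback
// above: plain AVX has no 256-bit integer compare, so the operands are
// converted to float and compared with predicate 14 (ordered, signalling
// greater-than). For values that convert to float exactly, each output lane
// is nonzero precisely when a > b and zero otherwise.
inline bool CompareGreaterThanLaneModel(int a, int b) { return a > b; }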
+ +void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx 8bit"); + + using Layout = PackImpl8bitAvx::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitColMajorForAvxPacker(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, + sums_ptr, trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx float"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + float trailing_buf[(kPackRows - 1) * kPackCols]; + if (remaining_src_cols < 8) { + memset(trailing_buf, 0, sizeof(trailing_buf)); + } + PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx, Path::kAvx>( + src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, + trailing_buf); + + const int trailing_rows = src_rows & (kPackRows - 1); + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~(kPackRows - 1); + memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, + kPackCols * trailing_rows * sizeof(float)); + } +} + +void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums) { + int col = start_col; + int src_end_col = std::min(end_col, src_cols); + + for (; col <= src_end_col - 8; col += 8) { + std::int8_t* dst_ptr = packed_ptr; + __m128i val0, val1, val2, val3; + __m128i input_xor_dup = _mm_set1_epi8(input_xor); + // Load a 4x8 block. 
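// Scalar equivalent (illustrative, hypothetical helper) of the 8-column block
// handled below: four source rows of an 8-column strip are XORed with
// input_xor, written out as eight column-major 4-byte cells, and their
// per-column totals are added into sums. Rows past src_rows are filled with
// src_zero_point, matching the zero_point-filled loads below.
inline void PackRowMajor4x8BlockModel(const unsigned char* src, int src_stride,
                                      int rows_in_block, int src_zero_point,
                                      int input_xor, signed char* dst,
                                      int* sums /* 8 entries */) {
  for (int c = 0; c < 8; ++c) {
    for (int r = 0; r < 4; ++r) {
      const int raw =
          (r < rows_in_block) ? src[r * src_stride + c] : src_zero_point;
      const signed char packed_val = static_cast<signed char>(input_xor ^ raw);
      dst[4 * c + r] = packed_val;  // column-major 4-deep cells
      sums[c] += packed_val;
    }
  }
}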
+ if (block_row + 4 <= src_rows) { + val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); + val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); + val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); + val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); + } else { + val0 = _mm_set1_epi8(src_zero_point); + val1 = val0; + val2 = val0; + val3 = val0; + if (block_row + 0 < src_rows) + val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); + if (block_row + 1 < src_rows) + val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); + if (block_row + 2 < src_rows) + val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); + if (block_row + 3 < src_rows) + val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); + } + // Maybe xor the sign bit to convert from uint8 to int8. + val0 = _mm_xor_si128(val0, input_xor_dup); + val1 = _mm_xor_si128(val1, input_xor_dup); + val2 = _mm_xor_si128(val2, input_xor_dup); + val3 = _mm_xor_si128(val3, input_xor_dup); + // Update the sums. + __m128i val16_0 = _mm_cvtepi8_epi16(val0); + __m128i val16_1 = _mm_cvtepi8_epi16(val1); + __m128i val16_2 = _mm_cvtepi8_epi16(val2); + __m128i val16_3 = _mm_cvtepi8_epi16(val3); + __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1), + _mm_add_epi16(val16_2, val16_3)); + __m256i sum = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col)); + sum = mm256_add_epi32(sum, mm256_cvtepi16_epi32(new_sum16)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum); + // Perform the transposition of 4x4 blocks + __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1); + __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3); + __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1); + __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1); + src_ptr += 8; + packed_ptr += packed_stride * 8; + } + for (; col < src_end_col; col++) { + std::int32_t accum = 0; + for (int r = 0; r < 4; r++) { + std::int8_t packed_val; + if (block_row + r < src_rows) { + packed_val = input_xor ^ src_ptr[r * src_stride]; + } else { + packed_val = input_xor ^ src_zero_point; + } + accum += packed_val; + *packed_ptr++ = packed_val; + } + if (sums) { + sums[col] += accum; + } + src_ptr++; + } + for (; col < end_col; col++) { + std::memset(packed_ptr, 0, 4); + packed_ptr += 4; + } +} + +#endif // RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_avx2_fma.cc b/ruy/pack_avx2_fma.cc new file mode 100644 index 0000000..2564b72 --- /dev/null +++ b/ruy/pack_avx2_fma.cc @@ -0,0 +1,689 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <cstdint> +#include <cstring> + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/pack_x86.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) + +void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t, + const std::int8_t*, int, int, int, std::int8_t*, + std::int32_t*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatColMajorForAvx2(const float*, const float*, int, int, int, + float*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int, + int, int, int, int, int, int, std::int32_t*) { + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitAvx2 = + PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, + std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>; + +using PackImplFloatAvx2 = + PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, + float, float, Order::kColMajor>; + +namespace { + +inline void Pack8bitColMajorForAvx2Packer( + const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, int remaining_src_cols, + int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. 
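// Scalar sketch (illustrative, hypothetical helper) of the column-padding
// scheme that follows: when fewer than 8 source columns remain, the pointers
// for the missing columns are redirected to zerobuf (a buffer holding the
// zero point) and their per-chunk increments are set to 0, so the 8-column
// kernel runs unchanged and padded columns just re-read the same buffer.
struct ColSrcModel { const signed char* ptr; long long inc; };
inline ColSrcModel SelectColSrcModel(int col, int remaining_src_cols,
                                     const signed char* src_col_ptr,
                                     const signed char* zerobuf,
                                     long long chunk_rows) {
  if (col < remaining_src_cols) return {src_col_ptr, chunk_rows};
  return {zerobuf, 0};  // padded column: constant data, never advanced
}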
+ if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m256i ones_16bit = _mm256_set1_epi16(1); + __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); + __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + if (sums_ptr) { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. 
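// Sketch (illustrative, hypothetical types) of the two immediates used below:
// treating a 256-bit register as two 128-bit lanes, control 0x20 yields
// (a.lane0, b.lane0) and control 0x31 yields (a.lane1, b.lane1), which is
// what lets this transpose cross the 128-bit lane boundary.
struct TwoLanesModel { unsigned long long lane0[2]; unsigned long long lane1[2]; };
inline TwoLanesModel Permute2x128Model(const TwoLanesModel& a,
                                       const TwoLanesModel& b, int imm) {
  // Only the two controls used in this file are modelled here.
  if (imm == 0x20) return {{a.lane0[0], a.lane0[1]}, {b.lane0[0], b.lane0[1]}};
  /* imm == 0x31 */ return {{a.lane1[0], a.lane1[1]}, {b.lane1[0], b.lane1[1]}};
}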
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. + const __m256i sums_4x2_32bit_lo_new = + _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } else { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. 
+ // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. + sums_adjustment += + -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); + + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr0); + t4 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr4); + t1 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr1); + t5 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr5); + t2 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr2); + t6 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr6); + t3 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr3); + t7 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr7); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. 
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. + const __m256i sums_4x2_32bit_lo_new = + _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), + r6); + 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), + r7); + } + + packed_ptr += 8 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); + const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. + const __m256i sums_2x4_32bit_lo = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); + const __m256i sums_2x4_32bit_hi = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); + const __m256i sums_2x4_32bit_a = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); + const __m256i sums_2x4_32bit_b = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + +// Use AVX2 specific intrinsic for greater than comparison. +template <> +inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a, + const __m256i& b) { + return _mm256_cmpgt_epi32(a, b); +} + +} // namespace. + +void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx2Fma 8bit"); + + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitColMajorForAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, + sums_ptr, trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. 
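// Worked example (illustrative, hypothetical helper) of the copy that
// follows: with src_rows = 70, the packer above covered the first 64 rows
// (70 & ~31), the destination is padded up to 72 rows (next multiple of 4),
// and the remaining 72 - 64 = 8 destination rows (6 real plus 2 padding) of
// all 8 columns are taken from trailing_buf.
inline int TrailingRowsToCopyModel(int src_rows) {
  const int non_trailing_rows = src_rows & ~31;  // rows already in packed_ptr
  const int dst_rows = (src_rows + 3) & ~3;      // pad up to a multiple of 4
  return dst_rows - non_trailing_rows;           // rows copied from trailing_buf
}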
+ const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx2Fma float"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + float trailing_buf[(kPackRows - 1) * kPackCols]; + if (remaining_src_cols < 8) { + memset(trailing_buf, 0, sizeof(trailing_buf)); + } + PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx2, Path::kAvx2Fma>( + src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, + trailing_buf); + + const int trailing_rows = src_rows & (kPackRows - 1); + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~(kPackRows - 1); + memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, + kPackCols * trailing_rows * sizeof(float)); + } +} + +void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums) { + int col = start_col; + int src_end_col = std::min(end_col, src_cols); + + for (; col <= src_end_col - 8; col += 8) { + std::int8_t* dst_ptr = packed_ptr; + __m128i val0, val1, val2, val3; + __m128i input_xor_dup = _mm_set1_epi8(input_xor); + // Load a 4x8 block. + if (block_row + 4 <= src_rows) { + val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); + val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); + val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); + val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); + } else { + val0 = _mm_set1_epi8(src_zero_point); + val1 = val0; + val2 = val0; + val3 = val0; + if (block_row + 0 < src_rows) + val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); + if (block_row + 1 < src_rows) + val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); + if (block_row + 2 < src_rows) + val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); + if (block_row + 3 < src_rows) + val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); + } + // Maybe xor the sign bit to convert from uint8 to int8. + val0 = _mm_xor_si128(val0, input_xor_dup); + val1 = _mm_xor_si128(val1, input_xor_dup); + val2 = _mm_xor_si128(val2, input_xor_dup); + val3 = _mm_xor_si128(val3, input_xor_dup); + // Update the sums. 
+ __m128i val16_0 = _mm_cvtepi8_epi16(val0); + __m128i val16_1 = _mm_cvtepi8_epi16(val1); + __m128i val16_2 = _mm_cvtepi8_epi16(val2); + __m128i val16_3 = _mm_cvtepi8_epi16(val3); + __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1), + _mm_add_epi16(val16_2, val16_3)); + __m256i sum = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col)); + sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum); + // Perform the transposition of 4x4 blocks + __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1); + __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3); + __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1); + __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1); + src_ptr += 8; + packed_ptr += packed_stride * 8; + } + for (; col < src_end_col; col++) { + std::int32_t accum = 0; + for (int r = 0; r < 4; r++) { + std::int8_t packed_val; + if (block_row + r < src_rows) { + packed_val = input_xor ^ src_ptr[r * src_stride]; + } else { + packed_val = input_xor ^ src_zero_point; + } + accum += packed_val; + *packed_ptr++ = packed_val; + } + if (sums) { + sums[col] += accum; + } + src_ptr++; + } + for (; col < end_col; col++) { + std::memset(packed_ptr, 0, 4); + packed_ptr += 4; + } +} + +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc new file mode 100644 index 0000000..5281fa8 --- /dev/null +++ b/ruy/pack_avx512.cc @@ -0,0 +1,828 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <cstring> + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/pack_x86.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS) +#include <immintrin.h> // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) + +void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t, + const std::int8_t*, int, int, int, std::int8_t*, + std::int32_t*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatColMajorForAvx512(const float*, const float*, int, int, int, + float*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int, + int, int, int, int, int, int, std::int32_t*) { + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. 
+using PackImpl8bitAvx512 = + PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, + std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>; + +namespace { + +inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, + std::int8_t* packed_ptr) { + using Layout = PackImpl8bitAvx512::Layout; + static constexpr int kHalfLayoutCols = + PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. + RUY_DCHECK_EQ(kHalfLayoutCols, 8); + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + + const int non_trailing_blocks = (src_rows & ~31) >> 2; + // This routine fills half blocks, and typically fills the second halves. + // Thus packed_ptr is already offset by 8 * 4. + for (int k = 0; k < non_trailing_blocks; ++k) { + for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { + packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; + } + } +} + +inline __m512i LoaduTwo(const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + __m512i lower_filled = _mm512_castsi256_si512( + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo))); + return _mm512_inserti32x8( + lower_filled, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1); +} + +inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, + const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + const __m512i lower_filled = _mm512_castsi256_si512( + _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); + return _mm512_inserti32x8( + lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), + 1); +} + +inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx512::Layout; + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have kHalfLayoutCols (8) columns. 
+ if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: kHalfLayoutCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m512i ones_16bit = _mm512_set1_epi16(1); + __m512i sums_8x2_32bit = _mm512_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { + // m: {0, 1} for 2 chunks of rows. + for (int m = 0; m < 2; ++m) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + // i: chunks, s: Layout::Rows. 
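// Sketch (illustrative, hypothetical types) of LoaduTwo, defined above and
// used just below: one 512-bit register carries the 32-byte chunk of one
// column in its lower half and the chunk of a second column in its upper
// half, so four registers cover all eight columns of this half-block.
struct Bytes64Model { unsigned char lo[32]; unsigned char hi[32]; };
inline Bytes64Model LoaduTwoModel(const unsigned char* addr_lo,
                                  const unsigned char* addr_hi) {
  Bytes64Model out;
  for (int i = 0; i < 32; ++i) {
    out.lo[i] = addr_lo[i];  // addr_lo column -> lower 256 bits
    out.hi[i] = addr_hi[i];  // addr_hi column -> upper 256 bits
  }
  return out;
}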
+ if (sums_ptr) { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1); + } else { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); + const __mmask32 row_mask = + (static_cast<std::uint64_t>(1) << available_src_rows) - 1; + + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. 
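// Worked form of the compensation just described (illustrative, hypothetical
// helper): the vector sums count every 4-byte cell of the padded 32-row
// chunk, including cells lying wholly beyond the rounded-up end of the real
// data that are never copied into the packed matrix. Each such byte was
// padded with zero_point and then XORed, so it contributed
// (zero_point ^ input_xor); the adjustment cancels exactly those cells.
inline int SumsAdjustmentModel(int zero_point_xor_input_xor,
                               int available_src_rows) {
  const int kept_cells = (available_src_rows + 3) >> 2;  // ceil(rows / 4)
  const int discarded_cells = 8 - kept_cells;            // cells never copied
  return -zero_point_xor_input_xor * 4 * discarded_cells;
}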
+ sums_adjustment += -(zero_point ^ input_xor) * 4 * + (8 - ((available_src_rows + 3) >> 2)); + + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + const __m256i zero_point_v = _mm256_set1_epi8(zero_point); + + t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1); + } + + packed_ptr += 16 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); + const __m512i idx = + _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. + const __m512i sums_2x8_32bit = + _mm512_permutexvar_epi32(idx, sums_8x2_32bit); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); + sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + +inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { + const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); + return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); +} + +inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, + const float* addr_hi) { + const __m512 lower_filled = + _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); + return _mm512_insertf32x8(lower_filled, + _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); +} + +inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + if (remaining_src_cols 
< 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += 16) { + for (int m = 0; m < 2; ++m) { + const int available_src_rows = src_rows - k - 8 * m; + // Effectively, + // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. + if (available_src_rows > 7) { + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); + _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); + } else if (available_src_rows > 0) { + const __mmask8 row_mask = + (static_cast<std::uint32_t>(1) << available_src_rows) - 1; + + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); + // Do not store _mm512_extractf32x8_ps(r3, 1). 
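+        // (No store of _mm512_extractf32x8_ps(r3, 1): trailing_buf holds only
+        // 7 * 16 floats, and with at most 7 available source rows the 8th
+        // packed row of this block is never copied into the packed matrix.)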
+ } + + packed_ptr += 16 * 8; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { + const int non_trailing_rows = src_rows & ~7; + for (int k = 0; k < non_trailing_rows; ++k) { + for (int j = 0; j < 8; ++j) { + packed_ptr[j] = 0.0f; + } + packed_ptr += 16; + } +} + +} // namespace. + +void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, + std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx512 8bit"); + + using Layout = PackImpl8bitAvx512::Layout; + constexpr int kHalfBlockOffset = 32; + RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); + static constexpr int kHalfLayoutCols = + PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. + RUY_DCHECK_EQ(kHalfLayoutCols, 8); + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + std::int32_t* second_sums_ptr = + sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; + if (remaining_src_cols > kHalfLayoutCols) { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, + zerobuf, src_stride, + remaining_src_cols - kHalfLayoutCols, src_rows, + packed_ptr + kHalfBlockOffset, second_sums_ptr, + trailing_buf + kHalfBlockOffset); + } else { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, + packed_ptr + kHalfBlockOffset); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < kHalfLayoutCols; ++i) { + second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); + } + } + } + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. 
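+    // For example (hypothetical sizes): src_rows = 37 gives
+    // non_trailing_rows = 37 & ~31 = 32, dst_rows = (37 + 3) & ~3 = 40 and
+    // trailing_rows = 8, so 16 * 8 packed values are copied below out of
+    // trailing_buf.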
+ const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx512 float"); + float trailing_buf[7 * 16]; + if (remaining_src_cols > 8) { + HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, + remaining_src_cols - 8, src_rows, packed_ptr + 8, + trailing_buf + 8); + } else { + memset(trailing_buf, 0, sizeof(trailing_buf)); + HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); + } + const int trailing_rows = src_rows & 7; + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~7; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(float)); + } +} + +void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums) { + int col = start_col; + int src_end_col = std::min(end_col, src_cols); + + for (; col <= src_end_col - 16; col += 16) { + std::int8_t* dst_ptr = packed_ptr; + __m128i val0, val1, val2, val3; + __m128i input_xor_dup = _mm_set1_epi8(input_xor); + // Load a 4x16 block. + if (block_row + 4 <= src_rows) { + val0 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride)); + val1 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride)); + val2 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride)); + val3 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride)); + } else { + val0 = _mm_set1_epi8(src_zero_point); + val1 = val0; + val2 = val0; + val3 = val0; + if (block_row + 0 < src_rows) + val0 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride)); + if (block_row + 1 < src_rows) + val1 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride)); + if (block_row + 2 < src_rows) + val2 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride)); + if (block_row + 3 < src_rows) + val3 = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride)); + } + // Maybe xor the sign bit to convert from uint8 to int8. + val0 = _mm_xor_si128(val0, input_xor_dup); + val1 = _mm_xor_si128(val1, input_xor_dup); + val2 = _mm_xor_si128(val2, input_xor_dup); + val3 = _mm_xor_si128(val3, input_xor_dup); + // Update the sums. 
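+    // Each val holds 16 int8 values from one of the 4 source rows of this
+    // 4x16 block: widen them to 16-bit, add the four rows, then widen to
+    // 32-bit and accumulate into sums[col] .. sums[col + 15].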
+ __m256i val16_0 = _mm256_cvtepi8_epi16(val0); + __m256i val16_1 = _mm256_cvtepi8_epi16(val1); + __m256i val16_2 = _mm256_cvtepi8_epi16(val2); + __m256i val16_3 = _mm256_cvtepi8_epi16(val3); + __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1), + _mm256_add_epi16(val16_2, val16_3)); + __m512i sum = + _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col)); + sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16)); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum); + auto zip = [](__m128i x, __m128i y) { + auto perm_64_0_64_0 = [](__m128i x) { + return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3), + _mm256_castsi128_si256(x)); + }; + return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y)); + }; + __m256i t2_val0 = zip(val0, val1); + __m256i t2_val1 = zip(val2, val3); + __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1); + __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), + _mm256_extractf128_si256(t4_val0, 0)); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), + _mm256_extractf128_si256(t4_val1, 0)); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32), + _mm256_extractf128_si256(t4_val0, 1)); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48), + _mm256_extractf128_si256(t4_val1, 1)); + src_ptr += 16; + packed_ptr += packed_stride * 16; + } + for (; col < src_end_col; col++) { + std::int32_t accum = 0; + for (int r = 0; r < 4; r++) { + std::int8_t packed_val; + if (block_row + r < src_rows) { + packed_val = input_xor ^ src_ptr[r * src_stride]; + } else { + packed_val = input_xor ^ src_zero_point; + } + accum += packed_val; + *packed_ptr++ = packed_val; + } + if (sums) { + sums[col] += accum; + } + src_ptr++; + } + for (; col < end_col; col++) { + std::memset(packed_ptr, 0, 4); + packed_ptr += 4; + } +} + +#endif // RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_common.h b/ruy/pack_common.h new file mode 100644 index 0000000..8f07658 --- /dev/null +++ b/ruy/pack_common.h @@ -0,0 +1,143 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_PACK_COMMON_H_ +#define RUY_RUY_PACK_COMMON_H_ + +#include <algorithm> +#include <cstdint> +#include <cstring> +#include <limits> +#include <type_traits> + +#include "ruy/check_macros.h" +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +template <typename Scalar> +Scalar SymmetricZeroPoint() { + if (std::is_floating_point<Scalar>::value) { + return 0; + } + if (std::is_signed<Scalar>::value) { + return 0; + } + return std::numeric_limits<Scalar>::max() / 2 + 1; +} + +template <Path ThePath, typename Scalar> +struct PackedTypeImpl { + using Type = Scalar; +}; + +template <Path ThePath, typename Scalar> +using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type; + +template <typename PackedScalar, typename Scalar> +PackedScalar Pack(Scalar x) { + return x - SymmetricZeroPoint<Scalar>() + SymmetricZeroPoint<PackedScalar>(); +} + +template <Path ThePath, typename FixedKernelLayout, typename Scalar, + typename PackedScalar, typename SumsType, Order SrcOrder> +struct PackImpl; + +#define RUY_INHERIT_PACK(PARENT, CHILD) \ + template <typename FixedKernelLayout, typename Scalar, \ + typename PackedScalar, typename SumsType, Order SrcOrder> \ + struct PackImpl<CHILD, FixedKernelLayout, Scalar, PackedScalar, SumsType, \ + SrcOrder> : PackImpl<PARENT, FixedKernelLayout, Scalar, \ + PackedScalar, SumsType, SrcOrder> {}; + +// A generic yet fairly fast implementation of +// +// PackImpl<ThePath, FixedKernelLayout<Order::kRowMajor, 1, KernelCols>, +// float, float, float, Order::kRowMajor> +// +// that is, a packing code path for the case of floating-point, row-major +// source matrix, targeting typical float kernel layouts consisting of a +// single row. +// +// The only reason why this isn't a partial specialization of PackImpl is that +// this leads to ambiguous partial specializations as this conflicts with +// the ones defined by RUY_INHERIT_PACK. +// +// What's special about floating-point kernels is that they tend to use +// FixedKernelLayout<Order::kRowMajor, 1, KernelCols> for some value of +// KernelCols, making it easy to implement the packing code as essentially +// a bunch of memcpy's with compile-time-fixed size +// (KernelCols * sizeof(float)), typically 16, 32 or 64 bytes. Unlike the +// quantized case, there are no sums to compute, and the float kernels tend +// to use this kind of simple layout on multiple architectures, unlike the +// heavily architecture-specific layouts of quantized kernels. 
+// +// Here are the current instantiations of this template (as of 2020): +// Path | KernelCols +// --------------+--------------------------------- +// kNeon (ARM32) | 8 and 4 (for LHS and RHS sides) +// kNeon (ARM64) | 8 +// kAvxFma | 8 +// kAvx512 | 16 +template <Path ThePath, int KernelCols> +struct MemcpyRowMajorFloatPackImpl { + static void Run(Tuning, const Mat<float>& src_matrix, + PMat<float>* packed_matrix, int start_col, int end_col) { + RUY_DCHECK(IsRowMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % KernelCols, 0); + int src_stride = src_matrix.layout.stride; + // As the source matrix is row-major and the destination packed matrix is + // column-major, there is no traversal order that will be optimal for both + // so we choose to favor the source matrix with a row-major traversal order. + for (int block_row = 0; block_row < src_matrix.layout.rows; + block_row += 1) { + const float* src_ptr = + src_matrix.data.get() + src_stride * block_row + start_col; + float* packed_ptr = packed_matrix->data + + packed_matrix->layout.stride * start_col + + KernelCols * block_row; + int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col; + int col = 0; + for (; col <= src_cols - KernelCols; col += KernelCols) { + memcpy(packed_ptr, src_ptr, KernelCols * sizeof(float)); + packed_ptr += KernelCols * packed_matrix->layout.stride; + src_ptr += KernelCols; + } + int remaining_cols = src_cols - col; + if (remaining_cols > 0) { + memcpy(packed_ptr, src_ptr, remaining_cols * sizeof(float)); + memset(packed_ptr + remaining_cols, 0, + (KernelCols - remaining_cols) * sizeof(float)); + } + } + } +}; + +#define RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(ThePath, KernelCols) \ + template <> \ + struct PackImpl<ThePath, FixedKernelLayout<Order::kRowMajor, 1, KernelCols>, \ + float, float, float, Order::kRowMajor> \ + : MemcpyRowMajorFloatPackImpl<ThePath, KernelCols> {}; + +} // namespace ruy + +#endif // RUY_RUY_PACK_COMMON_H_ diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h new file mode 100644 index 0000000..f3ea54e --- /dev/null +++ b/ruy/pack_x86.h @@ -0,0 +1,659 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_PACK_X86_H_ +#define RUY_RUY_PACK_X86_H_ + +#include <cstdint> +#include <cstring> +#include <type_traits> + +#include "ruy/check_macros.h" +#include "ruy/mat.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM_X86 + +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx) +RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma) +RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512) + +RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8) +RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16) + +template <> +struct PackedTypeImpl<Path::kAvx, std::uint8_t> { + using Type = std::int8_t; +}; + +template <> +struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl<Path::kAvx512, std::uint8_t> { + using Type = std::int8_t; +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template <typename Scalar> +struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, + Scalar, std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>; + static constexpr std::int8_t kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX2 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
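+      // (block_col & block_col_mask) rounds block_col down to a multiple of
+      // Layout::kCols before computing this block's offset into the packed
+      // matrix.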
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitColMajorForAvx2( + reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, + reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); + } + } +}; + +void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template <typename Scalar> +struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar, + std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>; + static constexpr std::int8_t kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitColMajorForAvx( + reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, + reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); + } + } +}; + +void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr); + +template <> +struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, + float, float, Order::kColMajor> { + using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + static void Run(Tuning, const Mat<float>& src_matrix, + PMat<float>* packed_matrix, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (AVX float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
+ float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr); + +template <> +struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, + float, float, float, Order::kColMajor> { + using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>; + static void Run(Tuning, const Mat<float>& src_matrix, + PMat<float>* packed_matrix, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (AVX2 float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template <typename Scalar> +struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, + Scalar, std::int8_t, std::int32_t, Order::kColMajor> { + static_assert(std::is_same<Scalar, std::int8_t>::value || + std::is_same<Scalar, std::uint8_t>::value, + ""); + using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + static constexpr std::int8_t kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? 
sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitColMajorForAvx512( + reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, + reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); + } + } +}; + +void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr); + +template <> +struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>, + float, float, float, Order::kColMajor> { + static void Run(Tuning, const Mat<float>& src_matrix, + PMat<float>* packed_matrix, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 float)"); + using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>; + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride, + remaining_src_cols, src_matrix.layout.rows, + packed_ptr); + } + } +}; + +void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums); + +template <typename Scalar> +struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, + Scalar, std::int8_t, std::int32_t, Order::kRowMajor> { + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)"); + RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 
0 : 0x80; + std::int32_t* sums = packed_matrix->sums; + std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); + int block_row = 0; + for (; block_row < packed_matrix->layout.rows; block_row += 4) { + int src_stride = src_matrix.layout.stride; + int packed_stride = packed_matrix->layout.stride; + const Scalar* src_ptr = + src_matrix.data.get() + block_row * src_stride + start_col; + std::int8_t* packed_ptr = + packed_matrix->data + start_col * packed_stride + block_row * 8; + Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr), + src_stride, src_matrix.zero_point, packed_ptr, + packed_stride, start_col, end_col, + src_matrix.layout.cols, block_row, + src_matrix.layout.rows, kInputXor, sums); + } + } +}; + +void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums); + +template <typename Scalar> +struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar, + std::int8_t, std::int32_t, Order::kRowMajor> { + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX 8bit row-major)"); + RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; + std::int32_t* sums = packed_matrix->sums; + std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); + int block_row = 0; + for (; block_row < packed_matrix->layout.rows; block_row += 4) { + int src_stride = src_matrix.layout.stride; + int packed_stride = packed_matrix->layout.stride; + const Scalar* src_ptr = + src_matrix.data.get() + block_row * src_stride + start_col; + std::int8_t* packed_ptr = + packed_matrix->data + start_col * packed_stride + block_row * 8; + Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr), + src_stride, src_matrix.zero_point, packed_ptr, + packed_stride, start_col, end_col, + src_matrix.layout.cols, block_row, + src_matrix.layout.rows, kInputXor, sums); + } + } +}; + +void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride, + int src_zero_point, std::int8_t* packed_ptr, + int packed_stride, int start_col, int end_col, + int src_cols, int block_row, int src_rows, + int input_xor, std::int32_t* sums); + +template <typename Scalar> +struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, + Scalar, std::int8_t, std::int32_t, Order::kRowMajor> { + static void Run(Tuning, const Mat<Scalar>& src_matrix, + PMat<std::int8_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)"); + RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); + static constexpr int kInputXor = + std::is_same<Scalar, std::int8_t>::value ? 
0 : 0x80; + std::int32_t* sums = packed_matrix->sums; + std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); + int block_row = 0; + for (; block_row < packed_matrix->layout.rows; block_row += 4) { + int src_stride = src_matrix.layout.stride; + int packed_stride = packed_matrix->layout.stride; + const Scalar* src_ptr = + src_matrix.data.get() + block_row * src_stride + start_col; + std::int8_t* packed_ptr = + packed_matrix->data + start_col * packed_stride + block_row * 16; + Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr), + src_stride, src_matrix.zero_point, packed_ptr, + packed_stride, start_col, end_col, + src_matrix.layout.cols, block_row, + src_matrix.layout.rows, kInputXor, sums); + } + } +}; +#endif // RUY_PLATFORM_X86 + +} // namespace ruy + +#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)) + +#include <immintrin.h> // IWYU pragma: keep + +namespace ruy { +namespace { + +template <Path path> +inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +template <Path path> +inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +template <Path path> +inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) { + RUY_DCHECK(false); + return _mm256_set1_epi32(0); +} + +// Shared between AVX and AVX2+FMA. +template <Path path> +inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, + const std::int8_t* addr) { + RUY_DCHECK_LT(available_src_rows, 32); + __m256i padded_data; + + if (available_src_rows >= 16) { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr)); + memcpy(&load_hi, addr + 16, available_src_rows - 16); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } else { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = load_hi; + memcpy(&load_lo, addr, available_src_rows); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } + return padded_data; +} + +} // namespace. + +template <typename PackImpl, Path path> +inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr, + const float* zerobuf, + int src_stride, + int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8); + RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1); + + // This packing amounts to transposition of 8x8 blocks. + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + // Handle cases where source does not have kPackDim (8) columns. 
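+  // For example, with remaining_src_cols = 3, src_ptr3 .. src_ptr7 below all
+  // point at zerobuf and their increments become 0, so the missing columns
+  // are packed as zeros.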
+ if (remaining_src_cols < kPackCols) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += kPackRows) { + const int available_src_rows = src_rows - k; + // Effectively, + // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); + // but treat each case separately. + if (available_src_rows >= kPackRows) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_loadu_ps(src_ptr0); + t4 = _mm256_loadu_ps(src_ptr4); + t1 = _mm256_loadu_ps(src_ptr1); + t5 = _mm256_loadu_ps(src_ptr5); + t2 = _mm256_loadu_ps(src_ptr2); + t6 = _mm256_loadu_ps(src_ptr6); + t3 = _mm256_loadu_ps(src_ptr3); + t7 = _mm256_loadu_ps(src_ptr7); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2<path>(r0, r1); + t4 = Mm256UnpackloPsx2<path>(r4, r5); + t2 = Mm256UnpackhiPsx2<path>(r0, r1); + t6 = Mm256UnpackhiPsx2<path>(r4, r5); + t1 = Mm256UnpackloPsx2<path>(r2, r3); + t5 = Mm256UnpackloPsx2<path>(r6, r7); + t3 = Mm256UnpackhiPsx2<path>(r2, r3); + t7 = Mm256UnpackhiPsx2<path>(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. 
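+      // In _mm256_permute2f128_ps the immediate selects one 128-bit lane per
+      // output half: 0x20 concatenates the two low lanes and 0x31 the two
+      // high lanes, e.g. r0 = (t0.lane0, t4.lane0) and r1 = (t0.lane1,
+      // t4.lane1).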
+ r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + r7 = _mm256_permute2f128_ps(t3, t7, 0x31); + + _mm256_storeu_ps(packed_ptr + 0 * 8, r0); + _mm256_storeu_ps(packed_ptr + 2 * 8, r4); + _mm256_storeu_ps(packed_ptr + 4 * 8, r1); + _mm256_storeu_ps(packed_ptr + 6 * 8, r5); + _mm256_storeu_ps(packed_ptr + 1 * 8, r2); + _mm256_storeu_ps(packed_ptr + 3 * 8, r6); + _mm256_storeu_ps(packed_ptr + 5 * 8, r3); + _mm256_storeu_ps(packed_ptr + 7 * 8, r7); + } else if (available_src_rows > 0) { + const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + const __m256i row_mask_v = CompareGreaterThan<path>( + _mm256_set1_epi32(available_src_rows), series); + + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); + t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); + t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); + t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); + t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); + t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); + t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); + t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2<path>(r0, r1); + t4 = Mm256UnpackloPsx2<path>(r4, r5); + t2 = Mm256UnpackhiPsx2<path>(r0, r1); + t6 = Mm256UnpackhiPsx2<path>(r4, r5); + t1 = Mm256UnpackloPsx2<path>(r2, r3); + t5 = Mm256UnpackloPsx2<path>(r6, r7); + t3 = Mm256UnpackhiPsx2<path>(r2, r3); + t7 = Mm256UnpackhiPsx2<path>(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + // r7 no longer needed. + + _mm256_storeu_ps(trailing_buf + 0 * 8, r0); + _mm256_storeu_ps(trailing_buf + 2 * 8, r4); + _mm256_storeu_ps(trailing_buf + 4 * 8, r1); + _mm256_storeu_ps(trailing_buf + 6 * 8, r5); + _mm256_storeu_ps(trailing_buf + 1 * 8, r2); + _mm256_storeu_ps(trailing_buf + 3 * 8, r6); + _mm256_storeu_ps(trailing_buf + 5 * 8, r3); + // No store to (trailing_buf + 7 * 8), space not allocated. 
+ } + + packed_ptr += kPackRows * kPackCols; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} +} // namespace ruy +#endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM) + +#endif // RUY_RUY_PACK_X86_H_ diff --git a/ruy/path.h b/ruy/path.h new file mode 100644 index 0000000..d3c5b06 --- /dev/null +++ b/ruy/path.h @@ -0,0 +1,203 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PATH_H_ +#define RUY_RUY_PATH_H_ + +#include <cstdint> + +#include "ruy/platform.h" +#include "ruy/size_util.h" + +namespace ruy { + +// A Path is an implementation path, typically corresponding to a SIMD +// instruction set being targetted. For example, on the ARM architecture, +// Path::kNeon means using NEON instructions, and Path::kNeonDotprod means +// also using the newer NEON dot-product instructions. +// +// Different Path enum values are defined on different CPU architectures, +// corresponding to different SIMD ISA extensions available there. +// +// Path::kStandardCpp is the one Path that is always available. +// +// Path enum values are bits and may be OR-ed to form "sets of Paths". +// Ruy entry points such as ruy::Mul either implicitly use such a set of Paths, +// or allow passing an explicit one as a template parameter. The meaning of such +// an OR-ed Path combination is "compile all of +// these paths; which path is used will be determined at runtime". This is why +// for most users, it is enough to call ruy::Mul(...), which will compile a +// reasonable selection of paths for the target CPU architecture's various +// SIMD ISA extensions, and let ruy determine at runtime which one to use. +// Internally, after the actual path has been resolved, ruy's internal functions +// templatized on a Path tend to require that to be a single bit. +// +// An element of ruy's internal design was to allow for code compiled for +// multiple such paths to coexist without violating the C++ One Definition Rule +// (ODR). This is achieved by having all ruy internal functions, whose +// definition depends on a choice of Path, be templatized on a Path, so that +// each path-specific specialization is a separate symbol. There is never +// a need to compile ruy code with different compilation flags to enable +// different SIMD extensions and dispatch at runtime between them, as this is +// taken care of internally by ruy in an ODR-correct way. +enum class Path : std::uint8_t { + // This is a special null value, representing the absence of any path. + kNone = 0, + // Standard C++ implementation of Ruy's architecture-specific parts. + // + // This is intended for testing/development, and as a fallback for when + // the SIMD ISA extensions required by other paths are unavailable at runtime. 
+ kStandardCpp = 0x1, + // Internal, test-only variants of StandardCpp used to exercise more corners + // of internal ruy logic. + // They are intentionally omitted from ruy::kAllPaths and ruy::kNonArchPaths, + // and are only ever used in dedicated ruy tests explicitly referencing them. + kInternalStandardCppVariant1 = 0x2, + kInternalStandardCppVariant2 = 0x4, + kInternalStandardCppVariant3 = 0x8, + +#if RUY_PLATFORM_ARM + // Optimized path using a widely available subset of ARM NEON instructions. + kNeon = 0x10, + // Optimized path making use of ARM NEON dot product instructions that are + // available on newer ARM cores. + kNeonDotprod = 0x20, +#endif // RUY_PLATFORM_ARM + +#if RUY_PLATFORM_X86 + // Optimized for AVX + // Compiled with -mavx + kAvx = 0x10, + // Optimized for AVX2+FMA. + // Compiled with -mavx2 -mfma. + kAvx2Fma = 0x20, + // Optimized for AVX-512. + // Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq. + kAvx512 = 0x40, +#endif // RUY_PLATFORM_X86 +}; + +inline constexpr Path operator|(Path p, Path q) { + return static_cast<Path>(static_cast<std::uint32_t>(p) | + static_cast<std::uint32_t>(q)); +} + +inline constexpr Path operator&(Path p, Path q) { + return static_cast<Path>(static_cast<std::uint32_t>(p) & + static_cast<std::uint32_t>(q)); +} + +inline constexpr Path operator^(Path p, Path q) { + return static_cast<Path>(static_cast<std::uint32_t>(p) ^ + static_cast<std::uint32_t>(q)); +} + +inline constexpr Path operator~(Path p) { + return static_cast<Path>(~static_cast<std::uint32_t>(p)); +} + +inline constexpr bool Disjoint(Path p, Path q) { + return (p & q) == Path::kNone; +} + +inline Path GetMostSignificantPath(Path path_mask) { + return static_cast<Path>(round_down_pot(static_cast<int>(path_mask))); +} + +// We define three disjoint sets of paths. +// +// kNonArchPaths is the set of paths that are defined regardless of +// the CPU architecture (excluding some internal test-only paths). +// These paths are slow, but portable. At the moment, +// that is only kStandardCpp. In the past, that used to also include a +// kReference path providing an even more basic implementation, but that has +// been split out into a separate library, see the ReferenceMul function. +constexpr Path kNonArchPaths = Path::kStandardCpp; + +// The other two are specific to each CPU architecture. Note that these sets +// do NOT include a fallback for when none of these architecture paths are +// supported at runtime by the CPU. For that, see the other constants defined +// further below. +// +// kDefaultArchPaths is the set of architecture-specific paths that +// we recommend for most users. It is part of kDefaultPaths defined +// below. +// +// kExtraArchPaths is the set of all other architecture-specific paths +// that for whatever reason we're not recommending to most users at the moment. +// Typically that would include work-in-progress paths, or paths targeting +// minority hardware that isn't the best compromise of code size to performance +// for most users. 
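+//
+// For instance, on a 64-bit ARM build kDefaultArchPaths below is
+// Path::kNeon | Path::kNeonDotprod, so plain
+//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);
+// compiles those two paths plus Path::kStandardCpp and picks among them at
+// runtime, while
+//   ruy::Mul<ruy::Path::kStandardCpp>(lhs, rhs, mul_params, &context, &dst);
+// pins the portable fallback (illustrative call shapes; see ruy.h for the
+// exact signatures).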
+
+#if RUY_PLATFORM_NEON_64
+constexpr Path kDefaultArchPaths = Path::kNeon | Path::kNeonDotprod;
+constexpr Path kExtraArchPaths = Path::kNone;
+#elif RUY_PLATFORM_NEON_32
+constexpr Path kDefaultArchPaths = Path::kNeon;
+constexpr Path kExtraArchPaths = Path::kNone;
+#elif RUY_PLATFORM_X86
+constexpr Path kDefaultArchPaths = Path::kAvx | Path::kAvx2Fma | Path::kAvx512;
+constexpr Path kExtraArchPaths = Path::kNone;
+#else
+constexpr Path kDefaultArchPaths = Path::kNone;
+constexpr Path kExtraArchPaths = Path::kNone;
+#endif
+
+// kNonArchPathsIncludingInternalVariants is the set of all
+// non-architecture-specific paths without exception. This includes some paths
+// that are internal-only and test-only and not useful to any user.
+static constexpr Path kNonArchPathsIncludingInternalVariants =
+    kNonArchPaths | Path::kInternalStandardCppVariant1 |
+    Path::kInternalStandardCppVariant2 | Path::kInternalStandardCppVariant3;
+
+// Enforce that kDefaultArchPaths, kExtraArchPaths and
+// kNonArchPathsIncludingInternalVariants are mutually disjoint,
+// and that kNonArchPaths is a subset of kNonArchPathsIncludingInternalVariants.
+static_assert(Disjoint(kDefaultArchPaths, kExtraArchPaths), "");
+static_assert(Disjoint(kDefaultArchPaths,
+                       kNonArchPathsIncludingInternalVariants),
+              "");
+static_assert(Disjoint(kExtraArchPaths, kNonArchPathsIncludingInternalVariants),
+              "");
+static_assert(Disjoint(kNonArchPaths, ~kNonArchPathsIncludingInternalVariants),
+              "");
+
+// We now define two aggregate sets of paths for convenience, including
+// both architecture-specific paths and some portable fallbacks.
+//
+// kDefaultPaths is the set of paths that we recommend most users to use.
+// It is what ruy::Mul(...), the entry point not taking an explicit Path value,
+// uses.
+constexpr Path kDefaultPaths = Path::kStandardCpp | kDefaultArchPaths;
+
+// kAllPaths is the set of all paths that are available to compile, except
+// some internal test-only paths that no user would ever want to use.
+// In addition to the Default paths, it also includes the extra
+// architecture paths, as well as any other non-arch path besides kStandardCpp
+// (there is none at the moment).
+constexpr Path kAllPaths = kNonArchPaths | kDefaultArchPaths | kExtraArchPaths;
+
+// kAllPathsIncludingInternalVariants is the set of all paths without exception.
+// This includes some paths that are internal-only and test-only and not useful
+// to any user.
+static constexpr Path kAllPathsIncludingInternalVariants = + kAllPaths | kNonArchPathsIncludingInternalVariants; + +static_assert(Disjoint(kDefaultPaths, ~kAllPaths), ""); +static_assert(Disjoint(kAllPaths, ~kAllPathsIncludingInternalVariants), ""); + +} // namespace ruy + +#endif // RUY_RUY_PATH_H_ diff --git a/ruy/perchannel_buffers_reallocation_test.cc b/ruy/perchannel_buffers_reallocation_test.cc new file mode 100644 index 0000000..0754aef --- /dev/null +++ b/ruy/perchannel_buffers_reallocation_test.cc @@ -0,0 +1,120 @@ +#include "ruy/context.h" +#include "ruy/gtest_wrapper.h" +#include "ruy/kernel.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/performance_advisory.h" +#include "ruy/ruy.h" + +namespace ruy { +namespace { + +constexpr Path kPath = Path::kInternalStandardCppVariant3; +constexpr int kBufferSize = 64; + +template <typename AccumScalar, typename DstScalar, + bool HaveQuantizedMultipliers = + std::is_same<AccumScalar, std::int32_t>::value && + !std::is_same<DstScalar, std::int32_t>::value> +struct PopulatePerChannelBuffersImpl { + static void Run(MulParams<AccumScalar, DstScalar>* mul_params) { + static const AccumScalar bias_buf[kBufferSize] = {0}; + static const AccumScalar multiplier_fixedpoint_buf[kBufferSize] = {0}; + static const int multiplier_exponent_buf[kBufferSize] = {0}; + mul_params->set_bias(bias_buf); + mul_params->set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint_buf); + mul_params->set_multiplier_exponent_perchannel(multiplier_exponent_buf); + } +}; + +template <typename AccumScalar, typename DstScalar> +struct PopulatePerChannelBuffersImpl<AccumScalar, DstScalar, false> { + static void Run(MulParams<AccumScalar, DstScalar>* mul_params) { + static const AccumScalar bias_buf[kBufferSize] = {0}; + mul_params->set_bias(bias_buf); + } +}; + +template <typename AccumScalar, typename DstScalar> +void PopulatePerChannelBuffers(MulParams<AccumScalar, DstScalar>* mul_params) { + PopulatePerChannelBuffersImpl<AccumScalar, DstScalar>::Run(mul_params); +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestPerChannelBuffersReallocation() { + using KernelType = Kernel<kPath, float, float, float, float>; + + MulParams<AccumScalar, DstScalar> mul_params; + PopulatePerChannelBuffers(&mul_params); + + const int kMatrixSize = 3; + ruy::Matrix<LhsScalar> lhs; + ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kRowMajor, + lhs.mutable_layout()); + const LhsScalar lhs_data[kMatrixSize * kMatrixSize] = {0}; + lhs.set_data(lhs_data); + ruy::Matrix<RhsScalar> rhs; + ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor, + rhs.mutable_layout()); + const RhsScalar rhs_data[kMatrixSize * kMatrixSize] = {0}; + rhs.set_data(rhs_data); + DstScalar dst_data[kMatrixSize * kMatrixSize] = {0}; + ruy::Matrix<DstScalar> dst; + ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor, + dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::Context context; + + auto test_advisory = [&](bool expect_advisory, + ChannelDimension channel_dimension, + int capacity_rounding) { + mul_params.set_channel_dimension(channel_dimension); + mul_params.set_perchannel_buffers_capacity_rounding(capacity_rounding); + ruy::Mul<kPath>(lhs, rhs, mul_params, &context, &dst); + EXPECT_EQ(context.performance_advisory( + PerformanceAdvisory::kReallocatedPerChannelBuffer), + expect_advisory); + }; + + static_assert(KernelType::LhsLayout::kCols == 16, ""); + test_advisory(true, 
ChannelDimension::kRow, 1); + test_advisory(true, ChannelDimension::kRow, 2); + test_advisory(true, ChannelDimension::kRow, 4); + test_advisory(true, ChannelDimension::kRow, 8); + test_advisory(false, ChannelDimension::kRow, 16); + test_advisory(false, ChannelDimension::kRow, 32); + test_advisory(false, ChannelDimension::kRow, 64); + + static_assert(KernelType::RhsLayout::kCols == 8, ""); + test_advisory(true, ChannelDimension::kCol, 1); + test_advisory(true, ChannelDimension::kCol, 2); + test_advisory(true, ChannelDimension::kCol, 4); + test_advisory(false, ChannelDimension::kCol, 8); + test_advisory(false, ChannelDimension::kCol, 16); + test_advisory(false, ChannelDimension::kCol, 32); + test_advisory(false, ChannelDimension::kCol, 64); +} + +TEST(PerChannelBuffersReallocationTest, Float) { + TestPerChannelBuffersReallocation<float, float, float, float>(); +} + +TEST(PerChannelBuffersReallocationTest, Quantized) { + TestPerChannelBuffersReallocation<std::int8_t, std::int8_t, std::int32_t, + std::int8_t>(); +} + +TEST(PerChannelBuffersReallocationTest, RawInt32) { + TestPerChannelBuffersReallocation<std::int8_t, std::int8_t, std::int32_t, + std::int32_t>(); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/performance_advisory.h b/ruy/performance_advisory.h new file mode 100644 index 0000000..02dd5d3 --- /dev/null +++ b/ruy/performance_advisory.h @@ -0,0 +1,40 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PERFORMANCE_ADVISORY_H_ +#define RUY_RUY_PERFORMANCE_ADVISORY_H_ + +namespace ruy { + +enum class PerformanceAdvisory { + kNone = 0, + kReallocatedPerChannelBuffer = 0x1 +}; + +inline constexpr PerformanceAdvisory operator|(PerformanceAdvisory p, + PerformanceAdvisory q) { + return static_cast<PerformanceAdvisory>(static_cast<int>(p) | + static_cast<int>(q)); +} + +inline constexpr PerformanceAdvisory operator&(PerformanceAdvisory p, + PerformanceAdvisory q) { + return static_cast<PerformanceAdvisory>(static_cast<int>(p) & + static_cast<int>(q)); +} + +} // namespace ruy + +#endif // RUY_RUY_PERFORMANCE_ADVISORY_H_ diff --git a/ruy/platform.h b/ruy/platform.h new file mode 100644 index 0000000..eb51931 --- /dev/null +++ b/ruy/platform.h @@ -0,0 +1,159 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+// Preprocessor platform check macros.
+// Note that ruy_copts contains '-Wundef', which ensures that we get a compile
+// error if these macros are mistyped or if they are used without having
+// #included this header.
+
+#ifndef RUY_RUY_PLATFORM_H_
+#define RUY_RUY_PLATFORM_H_
+
+#ifdef __ANDROID_NDK__
+#include <android/ndk-version.h>
+#endif
+
+// Detect APPLE.
+#ifdef __APPLE__
+#define RUY_PLATFORM_APPLE 1
+#else
+#define RUY_PLATFORM_APPLE 0
+#endif
+
+// Detect PPC.
+#if defined(__ppc__) || defined(__powerpc__)
+#define RUY_PLATFORM_PPC 1
+#else
+#define RUY_PLATFORM_PPC 0
+#endif
+
+// Detect Fuchsia.
+#ifdef __Fuchsia__
+#define RUY_PLATFORM_FUCHSIA 1
+#else
+#define RUY_PLATFORM_FUCHSIA 0
+#endif
+
+// Architecture-level platform detection.
+//
+// Ruy requires these to be mutually exclusive.
+
+// Detect x86.
+#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \
+    defined(__x86__) || defined(__X86__) || defined(_X86_) ||       \
+    defined(_M_IX86) || defined(_M_X64)
+#define RUY_PLATFORM_X86 1
+#else
+#define RUY_PLATFORM_X86 0
+#endif
+
+// Detect ARM 32-bit.
+#ifdef __arm__
+#define RUY_PLATFORM_ARM_32 1
+#else
+#define RUY_PLATFORM_ARM_32 0
+#endif
+
+// Detect ARM 64-bit.
+#ifdef __aarch64__
+#define RUY_PLATFORM_ARM_64 1
+#else
+#define RUY_PLATFORM_ARM_64 0
+#endif
+
+// Combined ARM.
+#define RUY_PLATFORM_ARM (RUY_PLATFORM_ARM_64 || RUY_PLATFORM_ARM_32)
+
+// Feature and capability platform detection.
+//
+// These are mostly sub-selections of architectures.
+
+// Detect NEON. Explicitly avoid emulation, or anything like it, on x86.
+#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM_X86
+#define RUY_PLATFORM_NEON 1
+#else
+#define RUY_PLATFORM_NEON 0
+#endif
+
+// Define ARM 32-bit NEON.
+#define RUY_PLATFORM_NEON_32 (RUY_PLATFORM_NEON && RUY_PLATFORM_ARM_32)
+
+// Define ARM 64-bit NEON.
+// Note: NEON is implied by ARM64, so this define is redundant.
+// It still allows some conveyance of intent.
+#define RUY_PLATFORM_NEON_64 (RUY_PLATFORM_NEON && RUY_PLATFORM_ARM_64)
+
+// Determine whether to enable X86 non-portable performance improvements,
+// typically x86 SIMD paths (AVX, etc).
+#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS)
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__EMSCRIPTEN__)
+// We use some x86 asm e.g. in runtime CPU detection and to implement missing
+// intrinsics. This doesn't build under Emscripten.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 0
+#elif defined(__ANDROID_NDK__) && defined(__NDK_MAJOR__) && \
+    (__NDK_MAJOR__ >= 20)
+// Enable on sufficiently recent Android NDK. Earlier versions had broken
+// intrinsics headers.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__linux__) && defined(__clang__) && (__clang_major__ >= 8)
+// Enable on recent versions of Clang on Linux. Might be possible
+// to relax this version requirement.
+// Not enabling on Apple at the moment because of b/138922878 (see comment #8);
+// we may only need to disable this on Xcode <= 10.2.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+#elif defined(__GNUC__) && (__GNUC__ >= 9)
+// Enable on recent versions of GCC. Might be possible
+// to relax this version requirement.
+#define RUY_PLATFORM_X86_ENHANCEMENTS 1
+// Things are working on MSVC 2019. This should also enable on sufficiently
+// recent Clang-CL.
+#elif defined(_MSC_VER) && (_MSC_VER >= 1920) +#define RUY_PLATFORM_X86_ENHANCEMENTS 1 +#else +#define RUY_PLATFORM_X86_ENHANCEMENTS 0 +#endif + +// These CPU capabilities will all be true when Skylake, etc, are enabled during +// compilation. +#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && \ + defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ + defined(__AVX512BW__) && defined(__AVX512VL__) +#define RUY_PLATFORM_AVX512 1 +#else +#define RUY_PLATFORM_AVX512 0 +#endif + +#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__) && \ + (defined(__FMA__) || defined(_MSC_VER)) +#define RUY_PLATFORM_AVX2_FMA 1 +#else +#define RUY_PLATFORM_AVX2_FMA 0 +#endif + +#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX__) +#define RUY_PLATFORM_AVX 1 +#else +#define RUY_PLATFORM_AVX 0 +#endif + +// Detect Emscripten, typically Wasm. +#ifdef __EMSCRIPTEN__ +#define RUY_PLATFORM_EMSCRIPTEN 1 +#else +#define RUY_PLATFORM_EMSCRIPTEN 0 +#endif + +#endif // RUY_RUY_PLATFORM_H_ diff --git a/ruy/pmu.cc b/ruy/pmu.cc new file mode 100644 index 0000000..d1a60be --- /dev/null +++ b/ruy/pmu.cc @@ -0,0 +1,297 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/pmu.h" + +#include "ruy/check_macros.h" + +#ifdef __linux__ +#include <asm/unistd.h> +#include <linux/perf_event.h> +#include <sys/ioctl.h> +#include <syscall.h> +#include <unistd.h> + +#include <cstdio> +#endif + +#include <algorithm> +#include <cstdint> +#include <cstdlib> +#include <cstring> + +namespace ruy { + +// Linux-specific. Not ARM-specific. +#ifdef __linux__ +class PerfEvent { + public: + PerfEvent(std::uint32_t type, std::uint64_t config) { + perf_event_attr pe; + memset(&pe, 0, sizeof(pe)); + pe.size = sizeof(pe); + pe.type = type; + pe.config = config; + pe.disabled = 1; + pe.exclude_kernel = 1; + pe.exclude_hv = 1; + pe.inherit = 1; + fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); + if (fd_ == -1) { + fprintf(stderr, "perf_event_open failed for config 0x%lx\n", + static_cast<unsigned long>(config)); + // abort(); + } + } + + ~PerfEvent() { + RUY_CHECK(!started_); + close(fd_); + } + + void Start() { + RUY_CHECK(!started_); + started_ = true; + ioctl(fd_, PERF_EVENT_IOC_RESET, 0); + ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); + count_at_start_ = Read(); + } + + void Stop() { + RUY_CHECK(started_); + started_ = false; + ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); + count_at_stop_ = Read(); + } + + std::int64_t Count() const { + RUY_CHECK(!started_); + return count_at_stop_ - count_at_start_; + } + + private: + std::int64_t Read() const { + std::int64_t count; + RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); + return count; + } + std::int64_t count_at_start_ = -1; + std::int64_t count_at_stop_ = -1; + bool started_ = false; + int fd_ = -1; +}; +#else +// Placeholder implementation to at least compile outside of linux. 
+#define PERF_TYPE_RAW 0 +class PerfEvent { + public: + PerfEvent(std::uint32_t, std::uint64_t) {} + ~PerfEvent() {} + void Start() {} + void Stop() {} + std::int64_t Count() const { return 0; } +}; +#endif + +// ARM-specific. Query ARM PMU counters as Linux perf events using +// PERF_TYPE_RAW. +namespace arm_pmuv3 { + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-const-variable" + +// These event numbers are listed in the ARMv8 architecture reference manual. +constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; +constexpr std::uint16_t L1I_TLB_REFILL = 0x02; +constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; +constexpr std::uint16_t L1D_CACHE = 0x04; +constexpr std::uint16_t L1D_TLB_REFILL = 0x05; +constexpr std::uint16_t LD_RETIRED = 0x06; +constexpr std::uint16_t ST_RETIRED = 0x07; +constexpr std::uint16_t INST_RETIRED = 0x08; +constexpr std::uint16_t EXC_TAKEN = 0x09; +constexpr std::uint16_t EXC_RETURN = 0x0A; +constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; +constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; +constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; +constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; +constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; +constexpr std::uint16_t BR_MIS_PRED = 0x10; +constexpr std::uint16_t CPU_CYCLES = 0x11; +constexpr std::uint16_t BR_PRED = 0x12; +constexpr std::uint16_t MEM_ACCESS = 0x13; +constexpr std::uint16_t L1I_CACHE = 0x14; +constexpr std::uint16_t L1D_CACHE_WB = 0x15; +constexpr std::uint16_t L2D_CACHE = 0x16; +constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; +constexpr std::uint16_t L2D_CACHE_WB = 0x18; +constexpr std::uint16_t BUS_ACCESS = 0x19; +constexpr std::uint16_t MEMORY_ERROR = 0x1A; +constexpr std::uint16_t INST_SPEC = 0x1B; +constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; +constexpr std::uint16_t BUS_CYCLES = 0x1D; +constexpr std::uint16_t CHAIN = 0x1E; +constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; +constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; +constexpr std::uint16_t BR_RETIRED = 0x21; +constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; +constexpr std::uint16_t STALL_FRONTEND = 0x23; +constexpr std::uint16_t STALL_BACKEND = 0x24; +constexpr std::uint16_t L1D_TLB = 0x25; +constexpr std::uint16_t L1I_TLB = 0x26; +constexpr std::uint16_t L2I_CACHE = 0x27; +constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; +constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; +constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; +constexpr std::uint16_t L3D_CACHE = 0x2B; +constexpr std::uint16_t L3D_CACHE_WB = 0x2C; +constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; +constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; +constexpr std::uint16_t L2D_TLB = 0x2F; +constexpr std::uint16_t L2I_TLB = 0x30; +constexpr std::uint16_t LL_CACHE = 0x32; +constexpr std::uint16_t LL_CACHE_MISS = 0x33; +constexpr std::uint16_t DTLB_WALK = 0x34; +constexpr std::uint16_t LL_CACHE_RD = 0x36; +constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; + +// Additional implementation-defined events found by googling around. 
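+// (The _RD suffix denotes the read-access variants of the corresponding
+// architectural events listed above.)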
+constexpr std::uint16_t L1D_CACHE_RD = 0x40; +constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; +constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; +constexpr std::uint16_t L1D_TLB_RD = 0x4E; +constexpr std::uint16_t L2D_CACHE_RD = 0x50; +constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; +constexpr std::uint16_t BUS_ACCESS_RD = 0x60; +constexpr std::uint16_t MEM_ACCESS_RD = 0x66; +constexpr std::uint16_t L3D_CACHE_RD = 0xA0; +constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; + +#pragma GCC diagnostic pop + +} // namespace arm_pmuv3 + +class PmuEventsPrivate { + public: + PmuEventsPrivate() + : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), + l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), + l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), + ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), + l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), + l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), + stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), + stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), + br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED), + l1d_cache_writeback(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_WB), + l2d_cache_writeback(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_WB) {} + + private: + friend class PmuEvents; + PerfEvent l1d_cache_refill; + PerfEvent l2d_cache_refill; + PerfEvent l3d_cache_refill; + PerfEvent ll_cache_miss; + PerfEvent l1d_tlb_refill; + PerfEvent l2d_tlb_refill; + PerfEvent stall_frontend; + PerfEvent stall_backend; + PerfEvent br_mis_pred; + PerfEvent l1d_cache_writeback; + PerfEvent l2d_cache_writeback; +}; + +PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} +PmuEvents::~PmuEvents() { delete priv; } + +void PmuEvents::StartRecording() { + priv->l1d_cache_refill.Start(); + priv->l2d_cache_refill.Start(); + priv->l3d_cache_refill.Start(); + priv->ll_cache_miss.Start(); + priv->l1d_tlb_refill.Start(); + priv->l2d_tlb_refill.Start(); + priv->stall_frontend.Start(); + priv->stall_backend.Start(); + priv->br_mis_pred.Start(); + priv->l1d_cache_writeback.Start(); + priv->l2d_cache_writeback.Start(); +} + +void PmuEvents::StopRecording() { + priv->l1d_cache_refill.Stop(); + priv->l2d_cache_refill.Stop(); + priv->l3d_cache_refill.Stop(); + priv->ll_cache_miss.Stop(); + priv->l1d_tlb_refill.Stop(); + priv->l2d_tlb_refill.Stop(); + priv->stall_frontend.Stop(); + priv->stall_backend.Stop(); + priv->br_mis_pred.Stop(); + priv->l1d_cache_writeback.Stop(); + priv->l2d_cache_writeback.Stop(); +} + +float PmuEvents::BranchMispredictionCount() const { + return static_cast<float>(priv->br_mis_pred.Count()); +} + +float PmuEvents::FrontendStallCount() const { + return static_cast<float>(priv->stall_frontend.Count()); +} + +float PmuEvents::BackendStallCount() const { + return static_cast<float>(priv->stall_backend.Count()); +} + +float PmuEvents::L1RefillCount() const { + return static_cast<float>(priv->l1d_cache_refill.Count()); +} + +float PmuEvents::L2RefillCount() const { + return static_cast<float>(priv->l2d_cache_refill.Count()); +} + +float PmuEvents::L3RefillCount() const { + // Important: this was discovered in the context of the above experiments, + // which also tested the _RD variants of these counters. So it's possible that + // it's just not needed here with the default (non _RD) counters. + // + // Some CPUs implement LL_CACHE_MISS[_RD], some implement + // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is + // zero, or they roughly both agree with each other. 
Therefore, taking the max + // of them is a reasonable way to get something more portable across various + // CPUs. + return static_cast<float>( + std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); +} + +float PmuEvents::L1TLBRefillCount() const { + return static_cast<float>(priv->l1d_tlb_refill.Count()); +} + +float PmuEvents::L2TLBRefillCount() const { + return static_cast<float>(priv->l2d_tlb_refill.Count()); +} + +float PmuEvents::L1WritebackCount() const { + return static_cast<float>(priv->l1d_cache_writeback.Count()); +} + +float PmuEvents::L2WritebackCount() const { + return static_cast<float>(priv->l2d_cache_writeback.Count()); +} + +} // namespace ruy diff --git a/ruy/pmu.h b/ruy/pmu.h new file mode 100644 index 0000000..2d769e2 --- /dev/null +++ b/ruy/pmu.h @@ -0,0 +1,46 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PMU_H_ +#define RUY_RUY_PMU_H_ + +namespace ruy { + +class PmuEventsPrivate; + +class PmuEvents { + public: + PmuEvents(); + ~PmuEvents(); + void StartRecording(); + void StopRecording(); + float L1RefillCount() const; + float L2RefillCount() const; + float L3RefillCount() const; + float BranchMispredictionCount() const; + float FrontendStallCount() const; + float BackendStallCount() const; + float L1TLBRefillCount() const; + float L2TLBRefillCount() const; + float L1WritebackCount() const; + float L2WritebackCount() const; + + private: + PmuEventsPrivate* priv = nullptr; +}; + +} // namespace ruy + +#endif // RUY_RUY_PMU_H_ diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc new file mode 100644 index 0000000..ee891cb --- /dev/null +++ b/ruy/prepacked_cache.cc @@ -0,0 +1,129 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include "ruy/mat.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/system_aligned_alloc.h" + +namespace ruy { + +namespace { + +// Allocates the `data` and `sums` buffers, and sets the corresponding +// pointer fields, in a PEMat whose other fields, particularly `layout` +// and the runtime data types, are already populated. 
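+// For example (matching the accounting done in prepacked_cache_test.cc), a
+// 10x20 uint8 packed matrix with int32 sums needs 10*20 = 200 data bytes plus
+// 20*4 = 80 sums bytes, so the returned total is 280 bytes.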
+int AllocateBuffers(PEMat* packed_matrix) { + const int data_bytes = DataBytes(*packed_matrix); + packed_matrix->data = detail::SystemAlignedAlloc(data_bytes); + int sums_bytes = 0; + if (!packed_matrix->sums_type.is_floating_point) { + // Integer quantized matrices also need the `sums` buffer. + sums_bytes = SumsBytes(*packed_matrix); + packed_matrix->sums = detail::SystemAlignedAlloc(sums_bytes); + } + return data_bytes + sums_bytes; +} + +// Frees the `data` and `sums` buffers held by a PEMat. +void FreeBuffers(const PEMat& packed_matrix) { + detail::SystemAlignedFree(packed_matrix.data); + detail::SystemAlignedFree(packed_matrix.sums); +} + +} // end anonymous namespace + +std::size_t PrepackedCache::KeyHash::operator()( + const PrepackedCache::Key& key) const { + std::size_t src_data_hash = reinterpret_cast<std::size_t>(key.src_data); + // Naive hash of the layout. Based on some heuristic reasoning, not any + // benchmarking. + // A choice of hash function here is just an optimization matter + // anyway, since a hash collision only results in some Key::operator== calls + // to disambiguate, and even just returning src_data_hash, ignoring the layout + // altogether, would probably be good enough, as the case of multiple entries + // with the same data pointer will be uncommon. + // Here we multiply-add the layout fields using some small constant prime + // numbers as multipliers. The conventional approach of xor-ing bit-rotations + // would result in some hash collisions because these values are typically + // small positive integers, so bit-rotations are essentially bit-shifts, + // and powers of two are common. + std::size_t packed_layout_hash = + static_cast<int>(key.packed_layout.order) + + static_cast<int>(key.packed_layout.kernel.order) * 2 + + key.packed_layout.stride * 3 + key.packed_layout.kernel.rows * 5 + + key.packed_layout.kernel.cols * 7 + key.packed_layout.rows * 11 + + key.packed_layout.cols * 13; + return src_data_hash ^ packed_layout_hash; +} + +PrepackedCache::~PrepackedCache() { + for (const auto& pair : cache_) { + FreeBuffers(pair.second.packed_matrix); + } +} + +PrepackedCache::Action PrepackedCache::Get(const void* src_data, + PEMat* packed_matrix) { + // Construct a Key and look up the cache. + Key key; + key.src_data = src_data; + key.packed_layout = packed_matrix->layout; + key.zero_point = packed_matrix->zero_point; + const auto& itr = cache_.find(key); + + if (itr != cache_.end()) { + // Found existing entry. Update its timestamp and return it. + itr->second.timestamp = timestamp_++; + *packed_matrix = itr->second.packed_matrix; + return Action::kGotExistingEntry; + } + + // No existing entry found. Allocate new buffers now and insert in the cache. + const int new_bytes = AllocateBuffers(packed_matrix); + EjectUntilRoomFor(new_bytes); + Entry entry{*packed_matrix, timestamp_++}; + cache_.emplace(key, entry); + buffers_bytes_ += new_bytes; + return Action::kInsertedNewEntry; +} + +void PrepackedCache::EjectUntilRoomFor(int new_bytes) { + profiler::ScopeLabel label("PrepackedCacheEjection"); + // While we are above the threshold of ejection, eject the LRU entry. 
+ while (!cache_.empty() && buffers_bytes_ + new_bytes > max_buffers_bytes_) { + EjectOne(); + } +} + +void PrepackedCache::EjectOne() { + auto oldest = cache_.begin(); + Timestamp oldest_timestamp = oldest->second.timestamp; + { + for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { + if (itr->second.timestamp < oldest_timestamp) { + oldest = itr; + oldest_timestamp = itr->second.timestamp; + } + } + } + const PEMat& packed_matrix = oldest->second.packed_matrix; + buffers_bytes_ -= DataBytes(packed_matrix) + SumsBytes(packed_matrix); + FreeBuffers(packed_matrix); + cache_.erase(oldest); +} + +} // namespace ruy diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h new file mode 100644 index 0000000..cb3a113 --- /dev/null +++ b/ruy/prepacked_cache.h @@ -0,0 +1,141 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PREPACKED_CACHE_H_ +#define RUY_RUY_PREPACKED_CACHE_H_ + +#include <cstddef> +#include <cstdint> +#include <unordered_map> + +#include "ruy/mat.h" + +namespace ruy { + +// "Low effort" Least Recently Used Cache for Prepacked Matrices +// A cache mechanism for prepacked matrices that ejects oldest entries. +// The implementation is "low effort" in the following ways: +// - we just linearly search for the oldest entry when doing an ejection +// - the ejection policy is very simple: if the new size would be above the +// . threshold, we will eject entries until the size is below the threshold. +// Current use cases (RNNs with GEMV operations) indicate that ejection is rare +// and memory constraints are tight, so we devote no additional storage to the +// LRU mechanism and accept O(n) search to eject oldest entry. In practice, +// the number of total entries has not been shown to be large. +// +// An instance of PrepackedCache is always owned by a Context. Just like +// Context, no "thread safety" consideration is applicable to this class: +// no two threads may simulaneously be accessing the same instance. +class PrepackedCache final { + public: + enum class Action { kGotExistingEntry, kInsertedNewEntry }; + + static constexpr int kDefaultMaxBuffersBytes = 1 << 28; + + // A key in this key-value cache. Equality of keys implies interchangeable + // packed matrices, so we must be careful to make this Key type specific + // enough, and its equality comparison operator strict enough. + // + // These keys need to be used before the packed matrix buffers are allocated + // (since they are used to determine whether to allocate a new buffer). + // So they instead use the source matrix's data pointer. On the other hand, + // the packed matrix's layout structure is already available by the time we + // need to handle Keys, and that's fortunate because it is more specific + // than the source matrix's layout: it also encodes details about the kernel's + // small-scale block layout. 
In the past, we made the kernel function pointer + // part of the cache key, but all that is relevant here is the packed layout. + // + // The only other field that needs to be involved is the zero_point, for + // quantized matrices, although it seems far-fetched that the same matrix + // data would be reused with different zero_point values. + // + // The data types (PEMat::data_type and PEMat::sums_type) are omitted based on + // the "strict aliasing" model: each memory location should contain data of + // only one type. This could be relaxed in the future, by adding data types + // to this Key type, if a use case arises. + struct Key { + // The source matrix's data pointer. + const void* src_data; + // The packed matrix's layout, see PEMat::layout. + PMatLayout packed_layout; + // The packed matrix's zero point (for integer-quantized matrices only). + std::int32_t zero_point; + }; + + friend bool operator==(const Key& a, const Key& b) { + return a.src_data == b.src_data && a.packed_layout == b.packed_layout && + a.zero_point == b.zero_point; + } + + struct KeyHash { + std::size_t operator()(const Key&) const; + }; + + // A dummy timestamp to associate to each entry (see struct Entry) to + // determine which entry is "least recently used" when ejecting entries. + // This is just an integer counter, not related to physical time. + // It does not need to be atomic because only one thread uses an instance + // of PrepackedCache at a time (see class comment). + using Timestamp = std::uint64_t; + + struct Entry { + PEMat packed_matrix; + Timestamp timestamp; + }; + + explicit PrepackedCache(int max_buffers_bytes = kDefaultMaxBuffersBytes) + : max_buffers_bytes_(max_buffers_bytes) {} + + ~PrepackedCache(); + + // Returns the total size in bytes of buffers held in this cache. + int BuffersBytes() const { return buffers_bytes_; } + + // Returns the number of packed matrices held in this cache. + int MatrixCount() const { return cache_.size(); } + + // This is the method by which new matrices are cached, and existing cache + // entries are queried. + // `src_data` is the source matrix data pointer. + // `packed_matrix` is a packed matrix structure where all fields have already + // been populated, except for the `data` and `sums` pointers which have not + // yet been allocated. + // + // This method: + // 1. Queries the cache for an entry matching the given `src_data` pointer and + // the relevant fields of `packed_matrix`, particularly its `layout`. + // 2. If a matching cache entry does not exist, it is created and inserted + // into the cache, and its `data` and `sums` buffers are allocated. + // 3. The `packed_matrix` has its `data` and `sums` pointers set to point + // to the allocated buffers. + // 4. The cache entry's timestamp is updated so it's the most recently used + // entry. + // 5. The return value is Action::kInsertedNewEntry if at step 2 a new + // entry was created. Otherwise it is Action::kGotExistingEntry. + Action Get(const void* src_data, PEMat* packed_matrix); + + private: + void EjectOne(); + void EjectUntilRoomFor(int new_bytes); + + std::unordered_map<Key, Entry, KeyHash> cache_; + const int max_buffers_bytes_; + int buffers_bytes_ = 0; + Timestamp timestamp_ = 0; +}; + +} // namespace ruy + +#endif // RUY_RUY_PREPACKED_CACHE_H_ diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc new file mode 100644 index 0000000..9625931 --- /dev/null +++ b/ruy/prepacked_cache_test.cc @@ -0,0 +1,309 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include <thread> // NOLINT(build/c++11) + +#include "ruy/context.h" +#include "ruy/context_get_ctx.h" +#include "ruy/gtest_wrapper.h" +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/ruy.h" +#include "ruy/time.h" + +namespace ruy { +namespace { + +PEMat MakeDummyPEMat(Type data_type, int rows, int cols) { + PEMat ret; + ret.data_type = data_type; + if (!data_type.is_floating_point) { + ret.sums_type = Type::Create<std::int32_t>(); + } + ret.layout.rows = rows; + ret.layout.cols = cols; + ret.layout.stride = rows; + ret.layout.order = Order::kColMajor; + // The kernel block layout is not relevant to this test, so we leave it + // trivial 1x1. + ret.layout.kernel.rows = 1; + ret.layout.kernel.cols = 1; + return ret; +} + +template <typename T> +void DummyPack(const std::vector<T>& data, PEMat* packed_matrix) { + EXPECT_EQ(data.size(), FlatSize(packed_matrix->layout)); + memcpy(packed_matrix->data, data.data(), data.size() * sizeof(T)); +} + +TEST(PrepackedCacheTest, TestCacheBasic) { + PrepackedCache prepacked_cache(307); + // Allocate the prepacked matrix. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data1(10 * 20); + PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kInsertedNewEntry); + DummyPack(data1, &mat1); + + // DataBytes=15, SumsBytes=3*4=12, Total: 27 bytes + std::vector<std::uint8_t> data2(5 * 3); + PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 5, 3); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kInsertedNewEntry); + DummyPack(data2, &mat2); + + // Both should now be in cache. + EXPECT_EQ(prepacked_cache.MatrixCount(), 2); + EXPECT_EQ(prepacked_cache.BuffersBytes(), 307); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kGotExistingEntry); +} + +TEST(PrepackedCacheTest, TestCacheBasicFloat) { + PrepackedCache prepacked_cache(860); + // Allocate the prepacked matrix. + // DataBytes=200*4, SumsBytes=0 because float, Total: 800 bytes + std::vector<float> data1(10 * 20); + PEMat mat1 = MakeDummyPEMat(Type::Create<float>(), 10, 20); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kInsertedNewEntry); + DummyPack(data1, &mat1); + + // DataBytes=15*4, SumsBytes=0 because float, Total: 60 bytes + std::vector<float> data2(5 * 3); + PEMat mat2 = MakeDummyPEMat(Type::Create<float>(), 5, 3); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kInsertedNewEntry); + DummyPack(data2, &mat2); + + // Both should now be in cache. 
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 2); + EXPECT_EQ(prepacked_cache.BuffersBytes(), 860); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kGotExistingEntry); +} + +TEST(PrepackedCacheTest, TestCacheEjection) { + PrepackedCache prepacked_cache(306); + // Allocate the prepacked matrix. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data1(10 * 20); + PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + prepacked_cache.Get(data1.data(), &mat1); + DummyPack(data1, &mat1); + + // DataBytes=15, SumsBytes=3*4=12, Total: 27 bytes + std::vector<std::uint8_t> data2(5 * 3); + PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 5, 3); + prepacked_cache.Get(data2.data(), &mat2); + DummyPack(data2, &mat2); + + // The first matrix should have been ejected from the cache. + // Only the second matrix should now be in cache. + EXPECT_EQ(prepacked_cache.MatrixCount(), 1); + EXPECT_EQ(prepacked_cache.BuffersBytes(), 27); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kInsertedNewEntry); + + // The second matrix should have been ejected from the cache. + // Only the first matrix should now be in cache. + EXPECT_EQ(prepacked_cache.MatrixCount(), 1); + EXPECT_EQ(prepacked_cache.BuffersBytes(), 280); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kInsertedNewEntry); +} + +TEST(PrepackedCacheTest, TestCacheEjection2) { + PrepackedCache prepacked_cache(1000); + // Allocate the prepacked matrix 1. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data1(10 * 20); + PEMat mat1 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + prepacked_cache.Get(data1.data(), &mat1); + DummyPack(data1, &mat1); + + // Allocate the prepacked matrix 2. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data2(10 * 20); + PEMat mat2 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + prepacked_cache.Get(data2.data(), &mat2); + DummyPack(data2, &mat2); + + // Allocate the prepacked matrix 3. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data3(10 * 20); + PEMat mat3 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + prepacked_cache.Get(data3.data(), &mat3); + DummyPack(data3, &mat3); + + // The next insertion will cause the cache size to go over the ejection + // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data3.data(), &mat3) == + PrepackedCache::Action::kGotExistingEntry); + + // Allocate the prepacked matrix 4. + // DataBytes=200, SumsBytes=20*4=80, Total: 280 bytes + std::vector<std::uint8_t> data4(10 * 20); + PEMat mat4 = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + prepacked_cache.Get(data4.data(), &mat4); + DummyPack(data4, &mat4); + + // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. 
+ EXPECT_EQ(prepacked_cache.MatrixCount(), 3); + EXPECT_TRUE(prepacked_cache.Get(data1.data(), &mat1) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data3.data(), &mat3) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data4.data(), &mat4) == + PrepackedCache::Action::kGotExistingEntry); + EXPECT_TRUE(prepacked_cache.Get(data2.data(), &mat2) == + PrepackedCache::Action::kInsertedNewEntry); +} + +TEST(PrepackedCacheTest, TestDistinguishSubtlyDifferentMatrices) { + PrepackedCache prepacked_cache; + + std::vector<std::uint8_t> data(10 * 20); + PEMat mat = MakeDummyPEMat(Type::Create<std::uint8_t>(), 10, 20); + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + // Same layout, different source data pointer + EXPECT_EQ(prepacked_cache.Get(data.data() + 1, &mat), + PrepackedCache::Action::kInsertedNewEntry); + + // Layout tweaks + mat.layout.rows = 9; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + mat.layout.cols = 19; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + mat.layout.order = Order::kRowMajor; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + mat.layout.kernel.rows = 2; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + mat.layout.kernel.cols = 2; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + mat.layout.kernel.order = Order::kRowMajor; + EXPECT_EQ(prepacked_cache.Get(data.data(), &mat), + PrepackedCache::Action::kInsertedNewEntry); + + EXPECT_EQ(prepacked_cache.MatrixCount(), 8); +} + +void TestCachePolicies(CachePolicy cache_policy, bool expected_cached) { + ruy::Context context; + ruy::Ctx* ctx = get_ctx(&context); + PrepackedCache* cache = ctx->GetPrepackedCache(); + EXPECT_EQ(cache->MatrixCount(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix<float> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<float> rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<float> dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::MulParams<float, float> mul_params; + // Perform the multiplication and confirm no caching occurred. + ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); + EXPECT_EQ(cache->MatrixCount(), 0); + + // Set cache policy for the LHS, repeat the multiplication, and see + // that caching did occur. 
+ lhs.set_cache_policy(cache_policy); + ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); + const bool actual_cached = cache->MatrixCount() == 1; + EXPECT_EQ(actual_cached, expected_cached); +} + +TEST(PrepackedCacheTest, TestCachePolicies) { + for (CachePolicy cache_policy : + {CachePolicy::kNeverCache, CachePolicy::kCacheIfLargeSpeedup, + CachePolicy::kCacheIfSignificantSpeedup, CachePolicy::kAlwaysCache}) { + TestCachePolicies(cache_policy, + cache_policy != CachePolicy::kNeverCache); + } +} + +TEST(PrepackedCacheTest, TestClearCache) { + ruy::Context context; + PrepackedCache* cache = get_ctx(&context)->GetPrepackedCache(); + EXPECT_EQ(cache->MatrixCount(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix<float> lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout()); + lhs.set_data(lhs_data); + ruy::Matrix<float> rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout()); + rhs.set_data(rhs_data); + ruy::Matrix<float> dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout()); + dst.set_data(dst_data); + + ruy::MulParams<float, float> mul_params; + // Set cache policy for the LHS and see that caching occurs. + lhs.set_cache_policy(CachePolicy::kAlwaysCache); + ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); + EXPECT_NE(cache->MatrixCount(), 0); + + // Clear the cache via the Context. + context.ClearPrepackedCache(); + // Verify that the cache is now empty. + cache = get_ctx(&context)->GetPrepackedCache(); + EXPECT_EQ(cache->MatrixCount(), 0); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/prepare_packed_matrices.cc b/ruy/prepare_packed_matrices.cc new file mode 100644 index 0000000..5a01af7 --- /dev/null +++ b/ruy/prepare_packed_matrices.cc @@ -0,0 +1,94 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/prepare_packed_matrices.h" + +#include "ruy/allocator.h" +#include "ruy/ctx.h" +#include "ruy/matrix.h" +#include "ruy/prepacked_cache.h" +#include "ruy/side_pair.h" +#include "ruy/trace.h" +#include "ruy/trmul_params.h" + +namespace ruy { +namespace { + +// Returns true if the operand on the given side should use caching of the +// packed form. This may either be explicitly dictated by its cache_policy +// (if it is kNeverCache, the default, or kAlwaysCache), or it may depend +// on a heuristic decision based on the other operand's width. For example, +// in a matrix*vector product, for the LHS matrix operand, the other side is +// the RHS vector, with a width of 1, causing the packing of the LHS to be +// a large fraction of the overall work, so a heuristic would typically +// decide in favor of caching, if permitted at all by the cache_policy. 
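+// As a worked example (the kernel width of 8 is only an illustrative figure,
+// borrowed from the static_assert in perchannel_buffers_reallocation_test.cc):
+// in a matrix*vector product, the RHS has width 1, so for the LHS side
+// other_width == 1 <= other_kernel_width == 8 and the LHS is cached under both
+// kCacheIfLargeSpeedup and kCacheIfSignificantSpeedup; with an RHS of width
+// 100, neither heuristic caches and only kAlwaysCache would.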
+bool ShouldCache(const TrMulParams& params, Side side) { + const CachePolicy cache_policy = params.src[side].cache_policy; + // The width that matters is that of the other side, it is what determines + // the amortization of the packing work done on the present side. + const Side other_side = OtherSide(side); + const int other_width = params.src[other_side].layout.cols; + const int other_kernel_width = + params.packed_matrix[other_side].layout.kernel.cols; + switch (cache_policy) { + case CachePolicy::kNeverCache: + return false; + case CachePolicy::kAlwaysCache: + return true; + case CachePolicy::kCacheIfLargeSpeedup: + // The condition (other_width <= other_kernel_width) means that the kernel + // will traverse each value of the present side only once, meaning that + // the overhead of the packing work will be maximal, hence maximally + // worth caching. + return (other_width <= other_kernel_width); + case CachePolicy::kCacheIfSignificantSpeedup: + // Variant of the heuristic used in the kCacheIfLargeSpeedup case. The + // kernel will run on each value of the present side only a few times, + // so packing overhead will be significant. + return (other_width <= 4 * other_kernel_width); + default: + RUY_DCHECK(false); + return false; + } +} + +} // namespace + +void PreparePackedMatrices(Ctx* ctx, TrMulParams* params) { + RUY_TRACE_SCOPE; + for (Side side : {Side::kLhs, Side::kRhs}) { + PEMat& packed_matrix = params->packed_matrix[side]; + if (ShouldCache(*params, side)) { + // Use a cached packed matrix (possibly packing and caching now). + auto* cache = ctx->GetPrepackedCache(); + auto action = cache->Get(params->src[side].data, &packed_matrix); + RUY_TRACE_INFO(PREPARE_PACKED_MATRICES_SHOULD_CACHE); + if (action == PrepackedCache::Action::kInsertedNewEntry) { + params->RunPack(side, ctx->GetMainThreadTuning(), 0, + packed_matrix.layout.cols); + } + params->is_prepacked[side] = true; + } else { + RUY_TRACE_INFO(PREPARE_PACKED_MATRICES_NO_CACHE); + // Do not use a cached packed matrix. Only need to allocate buffers now. + Allocator* allocator = ctx->GetMainAllocator(); + packed_matrix.data = allocator->AllocateBytesAvoidingAliasingWith( + DataBytes(packed_matrix), params->src[side].data); + packed_matrix.sums = allocator->AllocateBytes(SumsBytes(packed_matrix)); + } + } +} + +} // namespace ruy diff --git a/ruy/prepare_packed_matrices.h b/ruy/prepare_packed_matrices.h new file mode 100644 index 0000000..1092dc9 --- /dev/null +++ b/ruy/prepare_packed_matrices.h @@ -0,0 +1,42 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PREPARE_PACKED_MATRICES_H_ +#define RUY_RUY_PREPARE_PACKED_MATRICES_H_ + +#include "ruy/ctx.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +// Ensures that the packed matrices are ready for TrMul's work. In the generic +// case, this is merely allocating their buffers. 
+// +// In the non-default case where +// a matrix has a cache_policy allowing caching, this is where we implement +// this caching feature: determining whether to cache each matrix, performing +// the cache lookup, and possibly performing the packing and cache update if +// not already cached. +// +// Assumes that the packed matrices have previously been created, with their +// fields already set except for the buffer allocations, as part of +// CreateTrMulParams. The reason for separating this preparation from the +// creation is that the creation needs to be templatized and this preparation +// does not. +void PreparePackedMatrices(Ctx* ctx, TrMulParams* params); + +} // namespace ruy + +#endif // RUY_RUY_PREPARE_PACKED_MATRICES_H_ diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD new file mode 100644 index 0000000..64754bf --- /dev/null +++ b/ruy/profiler/BUILD @@ -0,0 +1,66 @@ +# A minimalistic profiler sampling pseudo-stacks + +load("//ruy:build_defs.oss.bzl", "ruy_linkopts_thread_standard_library") + +package( + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "ruy_profiler", + define_values = {"ruy_profiler": "true"}, +) + +# Used to build TFLite Micro RUY dependency for embedded targets outside of the +# RUY source tree. +filegroup( + name = "ruy_instrumentation_header", + srcs = ["instrumentation.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "instrumentation", + srcs = ["instrumentation.cc"], + hdrs = ["instrumentation.h"], + defines = select({ + ":ruy_profiler": ["RUY_PROFILER"], + "//conditions:default": [], + }), + linkopts = ruy_linkopts_thread_standard_library(), + visibility = ["//visibility:public"], +) + +cc_library( + name = "profiler", + srcs = [ + "profiler.cc", + "treeview.cc", + ], + hdrs = [ + "profiler.h", + "treeview.h", + ], + linkopts = ruy_linkopts_thread_standard_library(), + visibility = ["//visibility:public"], + deps = [":instrumentation"], +) + +cc_library( + name = "test_instrumented_library", + testonly = True, + srcs = ["test_instrumented_library.cc"], + hdrs = ["test_instrumented_library.h"], + deps = [":instrumentation"], +) + +cc_test( + name = "test", + srcs = ["test.cc"], + linkopts = ruy_linkopts_thread_standard_library(), + deps = [ + ":profiler", + ":test_instrumented_library", + "//ruy:gtest_wrapper", + ], +) diff --git a/ruy/profiler/CMakeLists.txt b/ruy/profiler/CMakeLists.txt new file mode 100644 index 0000000..df4b30a --- /dev/null +++ b/ruy/profiler/CMakeLists.txt @@ -0,0 +1,72 @@ +# This file is generated (whence no license header). Do not edit! 
+# To regenerate, run: +# cmake/bazel_to_cmake.sh + +if(${RUY_PROFILER}) + set(ruy_profiler_0_RUY_PROFILER "RUY_PROFILER") +else() + set(ruy_profiler_0_RUY_PROFILER "") +endif() + +if(CMAKE_SYSTEM_NAME STREQUAL Windows) + set(ruy_profiler_1_pthread "") +else() + set(ruy_profiler_1_pthread "-pthread") +endif() + +ruy_cc_library( + NAME + ruy_profiler_instrumentation + SRCS + instrumentation.cc + HDRS + instrumentation.h + DEFINES + ${ruy_profiler_0_RUY_PROFILER} + LINKOPTS + ${ruy_profiler_1_pthread} + PUBLIC +) + +ruy_cc_library( + NAME + ruy_profiler_profiler + SRCS + profiler.cc + treeview.cc + HDRS + profiler.h + treeview.h + LINKOPTS + ${ruy_profiler_1_pthread} + PUBLIC + DEPS + ruy_profiler_instrumentation +) + +ruy_cc_library( + NAME + ruy_profiler_test_instrumented_library + TESTONLY + SRCS + test_instrumented_library.cc + HDRS + test_instrumented_library.h + DEPS + ruy_profiler_instrumentation +) + +ruy_cc_test( + NAME + ruy_profiler_test + SRCS + test.cc + LINKOPTS + ${ruy_profiler_1_pthread} + DEPS + ruy_profiler_profiler + ruy_profiler_test_instrumented_library + ruy_gtest_wrapper +) + +ruy_add_all_subdirs() diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md new file mode 100644 index 0000000..4ff344d --- /dev/null +++ b/ruy/profiler/README.md @@ -0,0 +1,149 @@ +# A minimalistic profiler sampling pseudo-stacks + +## Overview + +The present directory is the "ruy profiler". As a time profiler, it allows to +measure where code is spending time. + +Contrary to most typical profilers, what it samples is not real call stacks, but +"pseudo-stacks" which are just simple data structures constructed from within +the program being profiled. Using this profiler requires manually instrumenting +code to construct such pseudo-stack information. + +Another unusual characteristic of this profiler is that it uses only the C++11 +standard library. It does not use any non-portable feature, in particular it +does not rely on signal handlers. The sampling is performed by a thread, the +"profiler thread". + +A discussion of pros/cons of this approach is appended below. + +## How to use this profiler + +### How to instrument code + +An example of instrumented code is given in `test_instrumented_library.cc`. + +Code is instrumented by constructing `ScopeLabel` objects. These are RAII +helpers, ensuring that the thread pseudo-stack contains the label during their +lifetime. In the most common use case, one would construct such an object at the +start of a function, so that its scope is the function scope and it allows to +measure how much time is spent in this function. + +```c++ +#include "ruy/profiler/instrumentation.h" + +... + +void SomeFunction() { + ruy::profiler::ScopeLabel function_label("SomeFunction"); + ... do something ... +} +``` + +A `ScopeLabel` may however have any scope, for instance: + +```c++ +if (some_case) { + ruy::profiler::ScopeLabel extra_work_label("Some more work"); + ... do some more work ... +} +``` + +The string passed to the `ScopeLabel` constructor must be just a pointer to a +literal string (a `char*` pointer). The profiler will assume that these pointers +stay valid until the profile is finalized. + +However, that literal string may be a `printf` format string, and labels may +have up to 4 parameters, of type `int`. 
For example:
+
+```c++
+void SomeFunction(int size) {
+  ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size);
+  ... do something ...
+}
+```
+
+### How to run the profiler
+
+Profiling instrumentation is a no-op unless the preprocessor token
+`RUY_PROFILER` is defined, so defining it is the first step when actually
+profiling. When building with Bazel, the preferred way to enable it is to pass
+this flag on the Bazel command line:
+
+```
+--define=ruy_profiler=true
+```
+
+To actually profile a code scope, it is enough to construct a `ScopeProfile`
+object, also a RAII helper. It will start the profiler on construction, and on
+destruction it will terminate the profiler and report the profile treeview on
+standard output by default. Example:
+
+```c++
+void SomeProfiledBenchmark() {
+  ruy::profiler::ScopeProfile profile;
+
+  CallSomeInstrumentedCode();
+}
+```
+
+An example is provided by the `:test` target in the present directory. Run it
+with `--define=ruy_profiler=true` as explained above:
+
+```
+bazel run -c opt \
+   --define=ruy_profiler=true \
+  //tensorflow/lite/experimental/ruy/profiler:test
+```
+
+The default behavior of dumping the treeview to standard output may be
+overridden by passing a pointer to a `TreeView` object to the `ScopeProfile`
+constructor. This causes the treeview to be stored in that `TreeView` object,
+where it may be accessed and manipulated using the functions declared in
+`treeview.h`. The aforementioned `:test` provides examples of doing so.
+
+## Advantages and disadvantages
+
+Compared to a traditional profiler, e.g. Linux's "perf", the present kind of
+profiler has the following disadvantages:
+
+*   Requires manual instrumentation of the code being profiled.
+*   Substantial overhead, modifying the performance characteristics of the
+    code being measured.
+*   Questionable accuracy.
+
+But also the following advantages:
+
+*   Profiling can be driven from within a benchmark program, allowing the
+    entire profiling procedure to be a single command line.
+*   Not relying on symbol information removes exposure to toolchain details
+    and means less hassle in some build environments, especially
+    embedded/mobile (a single command line to run and profile, no symbol files
+    required).
+*   Fully portable (all of this is standard C++11).
+*   Fully testable (see `:test`). Profiling becomes just another feature of
+    the code like any other.
+*   Customized instrumentation can result in easier-to-read treeviews (only
+    relevant functions, and custom labels may be more readable than function
+    names).
+*   Parametrized/formatted labels make it possible to do things that aren't
+    possible with call-stack-sampling profilers. For example, a profile where
+    much time is spent in matrix multiplications can be broken down by the
+    matrix multiplication shapes involved (see the sketch below).
+
+The philosophy underlying this profiler is that software performance depends on
+software engineers profiling often, and that a key factor limiting that in
+practice is the difficulty and cumbersomeness of profiling with more serious
+profilers such as Linux's "perf", especially in embedded/mobile development:
+multiple command lines are involved to copy symbol files to devices, to
+retrieve profile data from the device, etc. In that context, it is useful to
+make profiling as easy as benchmarking, even on embedded targets, even if the
+price to pay for that is lower accuracy, higher overhead, and some intrusive
+instrumentation requirement.
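+Here is the sketch referred to above: a minimal example combining a
+`ScopeProfile` with a parametrized `ScopeLabel` to break a profile down by
+matrix multiplication shape. The function and label names are made up for
+illustration; only the `ScopeLabel` and `ScopeProfile` usage follows the API
+described above.
+
+```c++
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/profiler.h"
+
+// Hypothetical instrumented function: the label is parametrized by the matrix
+// dimensions, so the resulting treeview breaks time down by shape.
+void MultiplyMatrices(int rows, int depth, int cols) {
+  ruy::profiler::ScopeLabel label("MultiplyMatrices (%dx%dx%d)", rows, depth,
+                                  cols);
+  // ... do the actual work ...
+}
+
+void ProfiledBenchmark() {
+  // Starts the profiler; on destruction, dumps the treeview to standard
+  // output.
+  ruy::profiler::ScopeProfile profile;
+  for (int size : {64, 128, 256}) {
+    MultiplyMatrices(size, size, size);
+  }
+}
+```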
+ +Another key aspect determining what profiling approach is suitable for a given +context, is whether one already has a-priori knowledge of where much of the time +is likely being spent. When one has such a-priori knowledge, it is feasible to +instrument the known possibly-critical code as per the present approach. On the +other hand, in situations where one doesn't have such a-priori knowledge, a real +profiler such as Linux's "perf" allows to right away get a profile of real +stacks, from just symbol information generated by the toolchain. diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc new file mode 100644 index 0000000..cc8122b --- /dev/null +++ b/ruy/profiler/instrumentation.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/profiler/instrumentation.h" + +#ifdef RUY_PROFILER + +#include <cstring> + +namespace ruy { +namespace profiler { + +void Label::operator=(const Label& other) { + format_ = other.format_; + args_count_ = other.args_count_; + for (int i = 0; i < args_count_; i++) { + args_[i] = other.args_[i]; + } +} + +bool Label::operator==(const Label& other) const { + if (std::string(format_) != std::string(other.format_)) { + return false; + } + if (args_count_ != other.args_count_) { + return false; + } + for (int i = 0; i < args_count_; i++) { + if (args_[i] != other.args_[i]) { + return false; + } + } + return true; +} + +std::string Label::Formatted() const { + static constexpr int kBufSize = 256; + char buf[kBufSize]; + if (args_count_ == 0) { + return format_; + } + if (args_count_ == 1) { + snprintf(buf, kBufSize, format_, args_[0]); + } else if (args_count_ == 2) { + snprintf(buf, kBufSize, format_, args_[0], args_[1]); + } else if (args_count_ == 3) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); + } else if (args_count_ == 4) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); + } else { + abort(); + } + return buf; +} + +namespace detail { + +std::mutex* GlobalsMutex() { + static std::mutex mutex; + return &mutex; +} + +bool& GlobalIsProfilerRunning() { + static bool b; + return b; +} + +std::vector<ThreadStack*>* GlobalAllThreadStacks() { + static std::vector<ThreadStack*> all_stacks; + return &all_stacks; +} + +ThreadStack* ThreadLocalThreadStack() { + thread_local static ThreadStack thread_stack; + return &thread_stack; +} + +ThreadStack::ThreadStack() { + std::lock_guard<std::mutex> lock(*GlobalsMutex()); + static std::uint32_t global_next_thread_stack_id = 0; + stack_.id = global_next_thread_stack_id++; + GlobalAllThreadStacks()->push_back(this); +} + +ThreadStack::~ThreadStack() { + std::lock_guard<std::mutex> lock(*GlobalsMutex()); + std::vector<ThreadStack*>* all_stacks = GlobalAllThreadStacks(); + for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { + if (*it == this) { + all_stacks->erase(it); + return; + } + } +} +int 
GetBufferSize(const Stack& stack) { + return sizeof(stack.id) + sizeof(stack.size) + + stack.size * sizeof(stack.labels[0]); +} + +void CopyToBuffer(const Stack& stack, char* dst) { + memcpy(dst, &stack.id, sizeof(stack.id)); + dst += sizeof(stack.id); + memcpy(dst, &stack.size, sizeof(stack.size)); + dst += sizeof(stack.size); + memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); +} + +void ReadFromBuffer(const char* src, Stack* stack) { + memcpy(&stack->id, src, sizeof(stack->id)); + src += sizeof(stack->id); + memcpy(&stack->size, src, sizeof(stack->size)); + src += sizeof(stack->size); + memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); +} + +} // namespace detail +} // namespace profiler +} // namespace ruy + +#endif diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h new file mode 100644 index 0000000..c4df1e6 --- /dev/null +++ b/ruy/profiler/instrumentation.h @@ -0,0 +1,203 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PROFILER_INSTRUMENTATION_H_ +#define RUY_RUY_PROFILER_INSTRUMENTATION_H_ + +#ifdef RUY_PROFILER +#include <cstdio> +#include <mutex> +#include <vector> +#endif + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// A label is how a code scope is annotated to appear in profiles. +// The stacks that are sampled by the profiler are stacks of such labels. +// A label consists of a literal string, plus optional integer arguments. +class Label { + public: + Label() {} + template <typename... Args> + explicit Label(Args... args) { + Set(args...); + } + void Set(const char* format) { + format_ = format; + args_count_ = 0; + } + template <typename... Args> + void Set(const char* format, Args... args) { + format_ = format; + args_count_ = sizeof...(args); + SetArgs(0, args...); + } + + void operator=(const Label& other); + + bool operator==(const Label& other) const; + + std::string Formatted() const; + const char* format() const { return format_; } + + private: + void SetArgs(int position, int arg0) { args_[position] = arg0; } + + template <typename... Args> + void SetArgs(int position, int arg0, Args... args) { + SetArgs(position, arg0); + SetArgs(position + 1, args...); + } + + static constexpr int kMaxArgs = 4; + const char* format_ = nullptr; + int args_count_ = 0; + int args_[kMaxArgs]; +}; + +namespace detail { + +// Forward-declaration, see class ThreadStack below. +class ThreadStack; + +bool& GlobalIsProfilerRunning(); + +// Returns the global vector of pointers to all stacks, there being one stack +// per thread executing instrumented code. +std::vector<ThreadStack*>* GlobalAllThreadStacks(); + +// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). +std::mutex* GlobalsMutex(); + +// Returns the thread-local stack, specific to the current thread. 
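As an aside on the sample-buffer helpers defined earlier in `instrumentation.cc` (`GetBufferSize`, `CopyToBuffer`, `ReadFromBuffer`): they serialize one `Stack` into a flat byte buffer, which is what the profiler thread records as a "sample". A minimal round-trip sketch, assuming `RUY_PROFILER` is defined; these are internal `detail` helpers that user code does not normally call directly:

```c++
#include <vector>

#include "ruy/profiler/instrumentation.h"

void RoundTripOneSample() {
  ruy::profiler::detail::Stack stack;
  stack.id = 123;  // arbitrary thread-stack id, for illustration only
  stack.size = 2;
  stack.labels[0].Set("SomeFunction (size=%d)", 64);
  stack.labels[1].Set("Merging sorted halves");

  // Serialize into a flat "sample" buffer...
  std::vector<char> sample(ruy::profiler::detail::GetBufferSize(stack));
  ruy::profiler::detail::CopyToBuffer(stack, sample.data());

  // ...and read it back.
  ruy::profiler::detail::Stack decoded;
  ruy::profiler::detail::ReadFromBuffer(sample.data(), &decoded);
  // decoded now compares equal to stack, field by field.
}
```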
+ThreadStack* ThreadLocalThreadStack(); + +// This 'stack' is what may be more appropriately called a 'pseudostack': +// It contains Label entries that are 'manually' entered by instrumentation +// code. It's unrelated to real call stacks. +struct Stack { + std::uint32_t id = 0; + static constexpr int kMaxSize = 64; + int size = 0; + Label labels[kMaxSize]; +}; + +// Returns the buffer byte size required by CopyToSample. +int GetBufferSize(const Stack& stack); + +// Copies this Stack into a byte buffer, called a 'sample'. +void CopyToBuffer(const Stack& stack, char* dst); + +// Populates this Stack from an existing sample buffer, typically +// produced by CopyToSample. +void ReadFromBuffer(const char* src, Stack* stack); + +// ThreadStack is meant to be used as a thread-local singleton, assigning to +// each thread a Stack object holding its pseudo-stack of profile labels, +// plus a mutex allowing to synchronize accesses to this pseudo-stack between +// this thread and a possible profiler thread sampling it. +class ThreadStack { + public: + ThreadStack(); + ~ThreadStack(); + + const Stack& stack() const { return stack_; } + + // Returns the mutex to lock around any access to this stack. Each stack is + // accessed by potentially two threads: the thread that it belongs to + // (which calls Push and Pop) and the profiler thread during profiling + // (which calls CopyToSample). + std::mutex& Mutex() const { return mutex_; } + + // Pushes a new label on the top of this Stack. + template <typename... Args> + void Push(Args... args) { + // This mutex locking is needed to guard against race conditions as both + // the current thread and the profiler thread may be concurrently accessing + // this stack. In addition to that, this mutex locking also serves the other + // purpose of acting as a barrier (of compiler code reordering, of runtime + // CPU instruction reordering, and of memory access reordering), which + // gives a measure of correctness to this profiler. The downside is some + // latency. As this lock will be uncontended most of the times, the cost + // should be roughly that of an sequentially-consistent atomic access, + // comparable to an access to the level of CPU data cache that is shared + // among all cores, typically 60 cycles on current ARM CPUs, plus side + // effects from barrier instructions. + std::lock_guard<std::mutex> lock(mutex_); + // Avoid overrunning the stack, even in 'release' builds. This profiling + // instrumentation code should not ship in release builds anyway, the + // overhead of this check is negligible, and overrunning a stack array would + // be bad. + if (stack_.size >= Stack::kMaxSize) { + abort(); + } + stack_.labels[stack_.size++].Set(args...); + } + + // Pops the top-most label from this Stack. + void Pop() { + // See the comment in Push about this lock. While it would be tempting to + // try to remove this lock and just atomically decrement size_ with a + // store-release, that would not necessarily be a substitute for all of the + // purposes that this lock serves, or if it was done carefully to serve all + // of the same purposes, then that wouldn't be faster than this (mostly + // uncontended) lock. + std::lock_guard<std::mutex> lock(mutex_); + stack_.size--; + } + + private: + mutable std::mutex mutex_; + Stack stack_; +}; + +} // namespace detail + +// RAII user-facing way to construct Labels associated with their life scope +// and get them pushed to / popped from the current thread stack. +class ScopeLabel { + public: + template <typename... 
Args> + ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) { + thread_stack_->Push(args...); + } + + ~ScopeLabel() { thread_stack_->Pop(); } + + private: + detail::ThreadStack* thread_stack_; +}; + +#else // no RUY_PROFILER + +class ScopeLabel { + public: + template <typename... Args> + explicit ScopeLabel(Args...) {} + + // This destructor is needed to consistently silence clang's -Wunused-variable + // which seems to trigger semi-randomly. + ~ScopeLabel() {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc new file mode 100644 index 0000000..ae3a2e2 --- /dev/null +++ b/ruy/profiler/profiler.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/profiler/profiler.h" + +#ifdef RUY_PROFILER +#include <atomic> +#include <chrono> // NOLINT +#include <cstdio> +#include <cstdlib> +#include <thread> // NOLINT +#include <vector> +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +ScopeProfile::ScopeProfile() { Start(); } +ScopeProfile::ScopeProfile(bool enable) { + if (enable) { + Start(); + } +} +ScopeProfile::~ScopeProfile() { + if (!thread_) { + return; + } + finishing_.store(true); + thread_->join(); + Finish(); +} + +void ScopeProfile::Start() { + { + std::lock_guard<std::mutex> lock(*detail::GlobalsMutex()); + if (detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler already running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = true; + } + finishing_ = false; + thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); +} + +void ScopeProfile::ThreadFunc() { + while (!finishing_.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard<std::mutex> lock(*detail::GlobalsMutex()); + auto* thread_stacks = detail::GlobalAllThreadStacks(); + for (detail::ThreadStack* thread_stack : *thread_stacks) { + Sample(*thread_stack); + } + } +} + +void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { + std::lock_guard<std::mutex> lock(thread_stack.Mutex()); + // Drop empty stacks. + // This ensures that profiles aren't polluted by uninteresting threads. 
+ if (thread_stack.stack().size == 0) { + return; + } + int sample_size = detail::GetBufferSize(thread_stack.stack()); + int old_buf_size = samples_buf_.size(); + samples_buf_.resize(old_buf_size + sample_size); + detail::CopyToBuffer(thread_stack.stack(), + samples_buf_.data() + old_buf_size); +} + +void ScopeProfile::Finish() { + { + std::lock_guard<std::mutex> lock(*detail::GlobalsMutex()); + if (!detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler is not running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = false; + } + if (user_treeview_) { + user_treeview_->Populate(samples_buf_); + } else { + TreeView treeview; + treeview.Populate(samples_buf_); + Print(treeview); + } +} + +#endif // RUY_PROFILER + +} // namespace profiler +} // namespace ruy diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h new file mode 100644 index 0000000..10db8df --- /dev/null +++ b/ruy/profiler/profiler.h @@ -0,0 +1,106 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PROFILER_PROFILER_H_ +#define RUY_RUY_PROFILER_PROFILER_H_ + +#include <cstdio> + +#ifdef RUY_PROFILER +#include <atomic> +#include <chrono> +#include <thread> +#include <vector> +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// RAII user-facing way to create a profiler and let it profile a code scope, +// and print out an ASCII/MarkDown treeview upon leaving the scope. +class ScopeProfile { + public: + // Default constructor, unconditionally profiling. + ScopeProfile(); + + // Constructor allowing to choose at runtime whether to profile. + explicit ScopeProfile(bool enable); + + // Destructor. It's where the profile is reported. + ~ScopeProfile(); + + // See treeview_ member. + void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } + + private: + void Start(); + + // Thread entry point function for the profiler thread. This thread is + // created on construction. + void ThreadFunc(); + + // Record a stack as a sample. + void Sample(const detail::ThreadStack& stack); + + // Finalize the profile. Called on destruction. + // If user_treeview_ is non-null, it will receive the treeview. + // Otherwise the treeview will just be printed. + void Finish(); + + // Buffer where samples are recorded during profiling. + std::vector<char> samples_buf_; + + // Used to synchronize thread termination. + std::atomic<bool> finishing_; + + // Underlying profiler thread, which will perform the sampling. + // This profiler approach relies on a thread rather than on signals. + std::unique_ptr<std::thread> thread_; + + // TreeView to populate upon destruction. If left null (the default), + // a temporary treeview will be used and dumped on stdout. 
The user + // may override that by passing their own TreeView object for other + // output options or to directly inspect the TreeView. + TreeView* user_treeview_ = nullptr; +}; + +#else // no RUY_PROFILER + +struct ScopeProfile { + ScopeProfile() { +#ifdef GEMMLOWP_PROFILING + fprintf( + stderr, + "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " + "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " + "profiler replacing the old gemmlowp profiler. You should now be " + "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " + "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); +#endif + } + explicit ScopeProfile(bool) {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // RUY_RUY_PROFILER_PROFILER_H_ diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc new file mode 100644 index 0000000..0405ac7 --- /dev/null +++ b/ruy/profiler/test.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <chrono> +#include <random> +#include <thread> + +#include "ruy/gtest_wrapper.h" +#include "ruy/profiler/profiler.h" +#include "ruy/profiler/test_instrumented_library.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { +namespace { + +void DoSomeMergeSort(int size) { + std::vector<int> data(size); + + std::default_random_engine engine; + for (auto& val : data) { + val = engine(); + } + + MergeSort(size, data.data()); +} + +// The purpose of this basic test is to cover the basic path that will be taken +// by a majority of users, not inspecting treeviews but just implicitly printing +// them on stdout, and to have this test enabled even when RUY_PROFILER is not +// defined, so that we have coverage for the non-RUY_PROFILER case. 
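Because both `ScopeLabel` and `ScopeProfile` have no-op fallbacks when `RUY_PROFILER` is not defined, instrumentation and profiling calls can be left in benchmark code unconditionally. A minimal sketch using the runtime-enable constructor shown above, where `enable_profiling` and `CallSomeInstrumentedCode()` are hypothetical:

```c++
#include "ruy/profiler/profiler.h"

void CallSomeInstrumentedCode();  // hypothetical workload containing ScopeLabels

void RunBenchmark(bool enable_profiling) {
  // With RUY_PROFILER defined, this profiles only when enable_profiling is
  // true; without RUY_PROFILER, it compiles to the no-op stub either way.
  ruy::profiler::ScopeProfile profile(enable_profiling);
  CallSomeInstrumentedCode();
}
```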
+TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { + { + ScopeProfile profile; + DoSomeMergeSort(1 << 20); + } +} + +#ifdef RUY_PROFILER + +TEST(ProfilerTest, MergeSortSingleThread) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + DoSomeMergeSort(1 << 20); + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 1); + const auto& thread_root = *treeview.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); + EXPECT_GE( + WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), + 0.1 * thread_root.weight); + EXPECT_GE(WeightBelowNodeMatchingFormatted( + thread_root, "MergeSortRecurse (level=20, size=1)"), + 0.01 * thread_root.weight); + + TreeView treeview_collapsed; + CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", + &treeview_collapsed); + Print(treeview_collapsed); + const auto& collapsed_thread_root = + *treeview_collapsed.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); + EXPECT_EQ( + WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), + WeightBelowNodeMatchingUnformatted(collapsed_thread_root, + "MergeSort (size=%d)")); +} + +TEST(ProfilerTest, MemcpyFourThreads) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + std::vector<std::unique_ptr<std::thread>> threads; + for (int i = 0; i < 4; i++) { + threads.emplace_back(new std::thread([i]() { + ScopeLabel thread_label("worker thread #%d", i); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + ScopeLabel some_more_work_label("some more work"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + })); + } + for (int i = 0; i < 4; i++) { + threads[i]->join(); + } + } + Print(treeview); + // Since we cleared GlobalAllThreadStacks and the current thread hasn't + // created any ScopeLabel, only the 4 worker threads should be recorded. + EXPECT_EQ(treeview.thread_roots().size(), 4); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root_node = *thread_root.second; + // The root node may have 1 or 2 children depending on whether there is + // an "[other]" child. + EXPECT_GE(root_node.children.size(), 1); + EXPECT_LE(root_node.children.size(), 2); + const TreeView::Node& child_node = *root_node.children[0]; + EXPECT_EQ(child_node.label.format(), "worker thread #%d"); + // There must be 2 children, since roughly half the time will be in + // "some more work" leaving the other half in "[other]". + EXPECT_EQ(child_node.children.size(), 2); + const TreeView::Node& child_child_node = *child_node.children[0]; + // Since we sample every millisecond and the threads run for >= 2000 + // milliseconds, the "thread func" label should get roughly 2000 samples. + // Not very rigorous, as we're depending on the profiler thread getting + // scheduled, so to avoid this test being flaky, we use a much more + // conservative value of 500, one quarter of that normal value 2000. + EXPECT_GE(child_node.weight, 500); + // Likewise, allow up to four times more than the normal value 2000. + EXPECT_LE(child_node.weight, 8000); + // Roughly half of time should be spent under the "some more work" label. 
+ float some_more_work_percentage = + 100.f * child_child_node.weight / child_node.weight; + EXPECT_GE(some_more_work_percentage, 40.0f); + EXPECT_LE(some_more_work_percentage, 60.0f); + } +} + +TEST(ProfilerTest, OneThreadAfterAnother) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + { + std::thread thread([]() { + ScopeLabel thread_label("thread 0"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + { + std::thread thread([]() { + ScopeLabel thread_label("thread 1"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 2); +} + +#endif // RUY_PROFILER + +} // namespace +} // namespace profiler +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc new file mode 100644 index 0000000..b017ea9 --- /dev/null +++ b/ruy/profiler/test_instrumented_library.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <vector> + +#include "ruy/profiler/instrumentation.h" + +namespace { + +void MergeSortRecurse(int level, int size, int* data, int* workspace) { + ruy::profiler::ScopeLabel function_label( + "MergeSortRecurse (level=%d, size=%d)", level, size); + if (size <= 1) { + return; + } + int half_size = size / 2; + MergeSortRecurse(level + 1, half_size, data, workspace); + MergeSortRecurse(level + 1, size - half_size, data + half_size, + workspace + half_size); + + ruy::profiler::ScopeLabel merging_sorted_halves_label( + "Merging sorted halves"); + int dst_index = 0; + int left_index = 0; + int right_index = half_size; + while (dst_index < size) { + int val; + if (left_index < half_size && + ((right_index >= size) || data[left_index] < data[right_index])) { + val = data[left_index++]; + } else { + val = data[right_index++]; + } + workspace[dst_index++] = val; + } + for (int i = 0; i < size; i++) { + data[i] = workspace[i]; + } +} + +} // namespace + +void MergeSort(int size, int* data) { + ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); + std::vector<int> workspace(size); + MergeSortRecurse(0, size, data, workspace.data()); +} diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h new file mode 100644 index 0000000..4882962 --- /dev/null +++ b/ruy/profiler/test_instrumented_library.h @@ -0,0 +1,23 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#define RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ + +#include "ruy/profiler/instrumentation.h" + +void MergeSort(int size, int* data); + +#endif // RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc new file mode 100644 index 0000000..be0a944 --- /dev/null +++ b/ruy/profiler/treeview.cc @@ -0,0 +1,252 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef RUY_PROFILER + +#include "ruy/profiler/treeview.h" + +#include <algorithm> +#include <cassert> +#include <cstdio> +#include <functional> +#include <memory> +#include <vector> + +namespace ruy { +namespace profiler { + +namespace { + +void SortNode(TreeView::Node* node) { + using NodePtr = std::unique_ptr<TreeView::Node>; + std::sort(node->children.begin(), node->children.end(), + [](const NodePtr& n1, const NodePtr& n2) { + return n1->weight > n2->weight; + }); + for (const auto& child : node->children) { + SortNode(child.get()); + } +} + +// Records a stack i.e. a sample in a treeview, by incrementing the weights +// of matching existing nodes and/or by creating new nodes as needed, +// recursively, below the given node. +void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { + node->weight++; + if (stack.size == level) { + return; + } + TreeView::Node* child_to_add_to = nullptr; + for (const auto& child : node->children) { + if (child->label == stack.labels[level]) { + child_to_add_to = child.get(); + break; + } + } + if (!child_to_add_to) { + child_to_add_to = new TreeView::Node; + child_to_add_to->label = stack.labels[level]; + node->children.emplace_back(child_to_add_to); + } + AddStack(stack, child_to_add_to, level + 1); +} + +// Recursively populates the treeview below the given node with 'other' +// entries documenting for each node the difference between its weight and the +// sum of its children's weight. 
+void AddOther(TreeView::Node* node) { + int top_level_children_weight = 0; + for (const auto& child : node->children) { + AddOther(child.get()); + top_level_children_weight += child->weight; + } + if (top_level_children_weight != 0 && + top_level_children_weight != node->weight) { + auto* new_child = new TreeView::Node; + new_child->label = Label("[other]"); + new_child->weight = node->weight - top_level_children_weight; + node->children.emplace_back(new_child); + } +} + +} // namespace + +void TreeView::Populate(const std::vector<char>& samples_buf_) { + thread_roots_.clear(); + // Populate the treeview with regular nodes coming from samples. + const char* buf_ptr = samples_buf_.data(); + const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); + while (buf_ptr < buf_ptr_end) { + detail::Stack stack; + detail::ReadFromBuffer(buf_ptr, &stack); + // Empty stacks should have been dropped during sampling. + assert(stack.size > 0); + buf_ptr += GetBufferSize(stack); + const int id = stack.id; + if (!thread_roots_[id]) { + thread_roots_[id].reset(new Node); + } + AddStack(stack, thread_roots_[id].get(), 0); + } + // Populate the treeview with additional 'other' nodes, sort, and set + // root labels. + for (const auto& thread_root : thread_roots_) { + std::uint32_t id = thread_root.first; + Node* root = thread_root.second.get(); + AddOther(root); + SortNode(root); + root->label.Set("Thread %x (%d samples)", id, root->weight); + } +} + +// Recursively prints the treeview below the given node. The 'root' node +// argument is only needed to compute weights ratios, with the root ratio +// as denominator. +void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, + int level) { + if (&node == &root) { + printf("%s\n\n", node.label.Formatted().c_str()); + } else { + for (int i = 1; i < level; i++) { + printf(" "); + } + printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, + node.label.Formatted().c_str()); + } + for (const auto& child : node.children) { + PrintTreeBelow(*child, root, level + 1); + } +} + +void Print(const TreeView& treeview) { + printf("\n"); + printf("Profile (%d threads):\n\n", + static_cast<int>(treeview.thread_roots().size())); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root = *thread_root.second; + PrintTreeBelow(root, root, 0); + printf("\n"); + } +} + +int DepthOfTreeBelow(const TreeView::Node& node) { + if (node.children.empty()) { + return 0; + } else { + int max_child_depth = 0; + for (const auto& child : node.children) { + max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); + } + return 1 + max_child_depth; + } +} + +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, + const std::function<bool(const Label&)>& match) { + int weight = 0; + if (match(node.label)) { + weight += node.weight; + } + for (const auto& child : node.children) { + weight += WeightBelowNodeMatchingFunction(*child, match); + } + return weight; +} + +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format) { + return WeightBelowNodeMatchingFunction( + node, [&format](const Label& label) { return label.format() == format; }); +} + +int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted) { + return WeightBelowNodeMatchingFunction( + node, [&formatted](const Label& label) { + return label.Formatted() == formatted; + }); +} + +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out) { + 
node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + if (depth > 0) { + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseNode(*child_in, depth - 1, child_out); + } + } +} + +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function<bool(const Label&)>& match, TreeView::Node* node_out) { + if (match(node_in.label)) { + CollapseNode(node_in, depth, node_out); + } else { + node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); + } + } +} + +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function<bool(const Label&)>& match, TreeView* treeview_out) { + treeview_out->mutable_thread_roots()->clear(); + for (const auto& thread_root_in : treeview_in.thread_roots()) { + std::uint32_t id = thread_root_in.first; + const auto& root_in = *thread_root_in.second; + auto* root_out = new TreeView::Node; + treeview_out->mutable_thread_roots()->emplace( + id, std::unique_ptr<TreeView::Node>(root_out)); + CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); + } +} + +void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&format](const Label& label) { return label.format() == format; }, + treeview_out); +} + +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&formatted](const Label& label) { + return label.Formatted() == formatted; + }, + treeview_out); +} + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h new file mode 100644 index 0000000..de3313a --- /dev/null +++ b/ruy/profiler/treeview.h @@ -0,0 +1,130 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_PROFILER_TREEVIEW_H_ +#define RUY_RUY_PROFILER_TREEVIEW_H_ + +#ifdef RUY_PROFILER + +#include <functional> +#include <map> +#include <memory> +#include <vector> + +#include "ruy/profiler/instrumentation.h" + +namespace ruy { +namespace profiler { + +// A tree view of a profile. 
+class TreeView { + public: + struct Node { + std::vector<std::unique_ptr<Node>> children; + Label label; + int weight = 0; + }; + + void Populate(const std::vector<char>& samples_buf_); + + // Intentionally an *ordered* map so that threads are enumerated + // in an order that's consistent and typically putting the 'main thread' + // first. + using ThreadRootsMap = std::map<std::uint32_t, std::unique_ptr<Node>>; + + const ThreadRootsMap& thread_roots() const { return thread_roots_; } + ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } + + private: + ThreadRootsMap thread_roots_; +}; + +/* Below are API functions for manipulating and printing treeviews. */ + +// Prints the treeview to stdout. +void Print(const TreeView& treeview); + +// Prints the treeview below the given node on stdout. +void PrintTreeBelow(const TreeView::Node& node); + +// Returns the tree depth below the given node. +int DepthOfTreeBelow(const TreeView::Node& node); + +// Returns the sum of weights of nodes below the given node and filtered by +// the `match` predicate. +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, const std::function<bool(const Label&)>& match); + +// Returns the sum of weights of nodes below the given node and whose +// unformatted label (i.e. raw format string) matches the given `format` string. +// +// This allows to aggregate nodes whose labels differ only by parameter values. +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format); + +// Returns the sum of weights of nodes below the given node and whose formatted +// label matches the `formatted` string. +// +// In the case of nodes with parametrized labels, this allows to count only +// nodes with specific parameter values. For that purpose, one may also instead +// use WeightBelowNodeMatchingFunction directly, with a `match` predicate +// comparing raw integer parameter values directly, instead of going through +// formatted strings. +int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted); + +// Produces a `node_out` that is a copy of `node_in` but with tree depth below +// it clamped at `depth`, with further subtrees aggregated into single leaf +// nodes. +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every subnode filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function<bool(const Label&)>& match, TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every node filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function<bool(const Label&)>& match, TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching unformatted labels, +// i.e. raw format strings. +// See the comment on WeightBelowNodeMatchingUnformatted. 
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching formatted labels. +// See the comment on WeightBelowNodeMatchingFormatted. +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out); + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER + +#endif // RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/ruy/reference_mul.h b/ruy/reference_mul.h new file mode 100644 index 0000000..6c00dc7 --- /dev/null +++ b/ruy/reference_mul.h @@ -0,0 +1,56 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_REFERENCE_MUL_H_ +#define RUY_RUY_REFERENCE_MUL_H_ + +#include <algorithm> + +#include "ruy/apply_multiplier.h" +#include "ruy/matrix.h" +#include "ruy/mul_params.h" + +namespace ruy { + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParams<AccumScalar, DstScalar>& mul_params, + Matrix<DstScalar>* dst) { + for (int i = 0; i < lhs.layout().rows(); i++) { + for (int j = 0; j < rhs.layout().cols(); j++) { + AccumScalar accum = 0; + for (int k = 0; k < lhs.layout().cols(); k++) { + AccumScalar lhs_val = Element(lhs, i, k); + AccumScalar rhs_val = Element(rhs, k, j); + accum += (lhs_val - lhs.zero_point()) * (rhs_val - rhs.zero_point()); + } + int channel = + mul_params.channel_dimension() == ChannelDimension::kRow ? i : j; + if (mul_params.bias()) { + accum += mul_params.bias()[channel]; + } + ApplyMultiplier(mul_params, channel, &accum); + accum += dst->zero_point(); + accum = std::min<AccumScalar>(accum, mul_params.clamp_max()); + accum = std::max<AccumScalar>(accum, mul_params.clamp_min()); + *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum); + } + } +} + +} // namespace ruy + +#endif // RUY_RUY_REFERENCE_MUL_H_ diff --git a/ruy/ruy.h b/ruy/ruy.h new file mode 100644 index 0000000..ddbe192 --- /dev/null +++ b/ruy/ruy.h @@ -0,0 +1,114 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the main Ruy public header. 
+ +#ifndef RUY_RUY_RUY_H_ +#define RUY_RUY_RUY_H_ + +#include "ruy/context.h" +#include "ruy/context_get_ctx.h" +#include "ruy/frontend.h" +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/mul_params.h" +#include "ruy/path.h" +#include "ruy/trace.h" + +namespace ruy { + +// Entry point allowing to specify a custom OR-ed set of Path's to +// compile. See the comments in path.h for more details about that. +// Most users should use the other ruy::Mul overload not taking a Path template +// parameter, and the main documentation comment is on that overload. +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename AccumScalar, typename DstScalar> +void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParams<AccumScalar, DstScalar>& mul_params, Context* context, + Matrix<DstScalar>* dst) { + RUY_TRACE_SCOPE; + RUY_TRACE_INFO(MUL); + Mat<LhsScalar> internal_lhs = ToInternal(lhs); + Mat<RhsScalar> internal_rhs = ToInternal(rhs); + Mat<DstScalar> internal_dst = ToInternal(*dst); + MulFrontEnd<CompiledPaths>(internal_lhs, internal_rhs, mul_params, + get_ctx(context), &internal_dst); +} + +// Performs a multiplication of matrices, with some extra features for +// neural network applications. The basic operation is: +// +// dst = lhs * rhs // matrix multiplication +// +// The `mul_params` argument conveys additional parameters that are not +// naturally associated with lhs, rhs, dst. That includes typical neural network +// application domain specific features such as a bias-vector and clamp bounds, +// as well as integer quantization parameters. +// +// A simple reference implementation of the operation performed by ruy::Mul +// is provided by the ruy::ReferenceMul function in reference_mul.h. +// +// The `context` argument can be any ruy::Context object as long as no other +// thread is going to concurrently access that ruy::Context. The simplest +// correct (but not efficient) calling pattern is +// +// ruy::Context context; +// ruy::Mul(lhs, rhs, mul_params, &context, dst); +// +// However, creating and destroying a new context everytime is inefficient +// because it doesn't allow for resources to persist across ruy calls. Such +// resources may include heap allocations, a thread pool, and hardware detection +// results, and can be expensive to obtain. So the recommended usage pattern is +// more like this: +// +// // Once during initialization: +// ruy::Context* context = new ruy::Context; +// +// // Many times +// ruy::Mul(lhs, rhs, mul_params, context, dst); +// +// If multiple threads may concurrently be calling ruy::Mul, they must either +// use separate Contexts, or use a lock to ensure that no two threads are +// concurrently accessing the Context object. There is no lock inside Context, +// nothing is done to ensure reentrancy with shared Context objects. +// +// Ruy defaults to using only 1 thread. Multi-threading is always opted in to, +// by calling Context::set_max_num_threads() with an explicit thread count. +// If multiple threads may concurrently be calling ruy::Mul, it is advisable +// to set up their respective Context objects with set_max_num_threads so that +// the overall number of threads doesn't exceed the overall number of threads +// that the system can usefully execute concurrently +// (e.g. the number of CPU cores in typical scenarios). At least ruy forces +// each invocation to make an explicit decision here, there is no automatic +// detection of the best number of threads to use in ruy. 
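To make the calling pattern described above concrete, here is a hedged sketch of a small float multiplication with a reused `Context`, cross-checked against `ruy::ReferenceMul` from `reference_mul.h`; the `Matrix` setup helpers (`MakeSimpleLayout`, `set_data`, `Order::kRowMajor`) are assumed to come from `ruy/matrix.h`, which is not among the files excerpted here:

```c++
#include "ruy/context.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/reference_mul.h"
#include "ruy/ruy.h"

void SmallFloatMulExample(ruy::Context* context) {
  const float lhs_data[] = {1, 2, 3, 4};  // 2x2, row-major
  const float rhs_data[] = {1, 0, 0, 1};  // 2x2 identity, row-major
  float dst_data[4];
  float dst_reference_data[4];

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
  lhs.set_data(lhs_data);
  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, rhs.mutable_layout());
  rhs.set_data(rhs_data);
  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, dst.mutable_layout());
  dst.set_data(dst_data);

  // Default MulParams<float, float>: no bias, no clamping, plain matmul.
  ruy::MulParams<float, float> mul_params;
  ruy::Mul(lhs, rhs, mul_params, context, &dst);

  // Cross-check against the reference implementation shown earlier.
  ruy::Matrix<float> dst_reference;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor,
                        dst_reference.mutable_layout());
  dst_reference.set_data(dst_reference_data);
  ruy::ReferenceMul(lhs, rhs, mul_params, &dst_reference);
  // dst_data and dst_reference_data now hold the same 2x2 product.
}
```

The `context` argument is reused across calls, as the surrounding comment recommends; callers must still ensure it is not accessed from multiple threads concurrently.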
+// +// Constraints on the template parameters: +// * If DstScalar is floating-point then AccumScalar must also be. +// * If DstScalar is integral then AccumScalar must be std::int32_t. +// Please refer to MulParams' class comment for more information. When +// DstScalar is integral and is narrower than AccumScalar, additional +// MulParams fields must be set to control the scaling of internal accumulators +// before the final saturating cast to the DstScalar type. +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParams<AccumScalar, DstScalar>& mul_params, Context* context, + Matrix<DstScalar>* dst) { + Mul<kDefaultPaths>(lhs, rhs, mul_params, context, dst); +} + +} // namespace ruy + +#endif // RUY_RUY_RUY_H_ diff --git a/ruy/ruy_test.bzl b/ruy/ruy_test.bzl new file mode 100644 index 0000000..ef7e8b1 --- /dev/null +++ b/ruy/ruy_test.bzl @@ -0,0 +1,34 @@ +# Provides the ruy_test macro for type-parametrized tests. +"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination.""" + +def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None): + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_test( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) + +def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None): + tags = ["req_dep=//third_party/gemmlowp:profiler"] + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_binary( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + testonly = True, + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) diff --git a/ruy/ruy_test_ext.oss.bzl b/ruy/ruy_test_ext.oss.bzl new file mode 100644 index 0000000..5701fff --- /dev/null +++ b/ruy/ruy_test_ext.oss.bzl @@ -0,0 +1,7 @@ +"""Allows to specialize the ruy BUILD to availability of external libraries""" + +def ruy_test_ext_defines(): + return [] + +def ruy_test_ext_deps(): + return [] diff --git a/ruy/side_pair.h b/ruy/side_pair.h new file mode 100644 index 0000000..2f40277 --- /dev/null +++ b/ruy/side_pair.h @@ -0,0 +1,68 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_SIDE_PAIR_H_ +#define RUY_RUY_SIDE_PAIR_H_ + +#include "ruy/check_macros.h" + +namespace ruy { + +// Enumeration of the sides, i.e. the operands 'slots', in a matrix +// multiplication. 
The numerical values of these enumeration constants matter +// because these will be used as indices into the array underlying a SidePair. +enum class Side { + // Left-hand side + kLhs = 0, + // Right-hand side + kRhs = 1 +}; + +inline Side OtherSide(Side side) { + return side == Side::kLhs ? Side::kRhs : Side::kLhs; +} + +// SidePair is a pair container where the two elements are indexed by a Side +// enum. +template <typename T> +class SidePair final { + public: + SidePair() {} + SidePair(const T& a, const T& b) : elem_{a, b} {} + const T& operator[](Side side) const { + const int index = static_cast<int>(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + T& operator[](Side side) { + const int index = static_cast<int>(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + private: + static_assert(static_cast<int>(Side::kLhs) == 0, ""); + static_assert(static_cast<int>(Side::kRhs) == 1, ""); + T elem_[2]; +}; + +} // namespace ruy + +#endif // RUY_RUY_SIDE_PAIR_H_ diff --git a/ruy/size_util.h b/ruy/size_util.h new file mode 100644 index 0000000..a144522 --- /dev/null +++ b/ruy/size_util.h @@ -0,0 +1,105 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_SIZE_UTIL_H_ +#define RUY_RUY_SIZE_UTIL_H_ + +#include <type_traits> + +#include "ruy/check_macros.h" + +#ifdef _WIN32 +#include <intrin.h> +#endif + +namespace ruy { + +template <typename Integer> +inline Integer floor_log2(Integer n) { + static_assert(std::is_integral<Integer>::value, ""); + static_assert(std::is_signed<Integer>::value, ""); + static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); + + RUY_DCHECK_GE(n, 1); +#ifdef _MSC_VER + unsigned long result; // NOLINT[runtime/int] + if (sizeof(Integer) == 4) { + _BitScanReverse(&result, n); + } else { +#if defined(_M_X64) || defined(_M_ARM64) + // _BitScanReverse64 is supported only on 64-bit MSVC platforms + _BitScanReverse64(&result, static_cast<unsigned __int64>(n)); +#else + // Emulate using 32-bit _BitScanReverse + const uint32_t n_hi = uint64_t(n) >> 32; + if (n_hi == 0) { + _BitScanReverse(&result, static_cast<unsigned long>(n)); + } else { + _BitScanReverse(&result, static_cast<unsigned long>(n_hi)); + result += 32; + } +#endif // defined(_M_X64) || defined(_M_ARM64) + } + return result; +#else + if (sizeof(Integer) == 4) { + return 31 - __builtin_clz(n); + } else { + return 63 - __builtin_clzll(n); + } +#endif +} + +template <typename Integer> +Integer ceil_log2(Integer n) { + RUY_DCHECK_GE(n, 1); + return n == 1 ? 
0 : floor_log2(n - 1) + 1; +} + +template <typename Integer> +constexpr bool is_pot(Integer value) { + return (value > 0) && ((value & (value - 1)) == 0); +} + +template <typename Integer> +Integer pot_log2(Integer n) { + RUY_DCHECK(is_pot(n)); + return floor_log2(n); +} + +template <typename Integer> +Integer round_down_pot(Integer value) { + return static_cast<Integer>(1) << floor_log2(value); +} + +template <typename Integer> +Integer round_up_pot(Integer value) { + return static_cast<Integer>(1) << ceil_log2(value); +} + +template <typename Integer, typename Modulo> +Integer round_down_pot(Integer value, Modulo modulo) { + RUY_DCHECK_EQ(modulo & (modulo - 1), 0); + return value & ~(modulo - 1); +} + +template <typename Integer, typename Modulo> +Integer round_up_pot(Integer value, Modulo modulo) { + return round_down_pot(value + modulo - 1, modulo); +} + +} // namespace ruy + +#endif // RUY_RUY_SIZE_UTIL_H_ diff --git a/ruy/size_util_test.cc b/ruy/size_util_test.cc new file mode 100644 index 0000000..eb4c9ca --- /dev/null +++ b/ruy/size_util_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/size_util.h" + +#include <cstddef> +#include <cstdint> +#include <limits> + +#include "ruy/gtest_wrapper.h" + +namespace ruy { +namespace { + +template <typename Integer> +void SizeUtilTestValue(Integer value) { + if (value == 0) { + return; + } + + EXPECT_LE(0, floor_log2(value)); + EXPECT_LE(floor_log2(value), ceil_log2(value)); + EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); + + if (is_pot(value)) { + EXPECT_EQ(floor_log2(value), ceil_log2(value)); + EXPECT_EQ(floor_log2(value), pot_log2(value)); + } else { + EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); + } + EXPECT_EQ(value >> floor_log2(value), 1); + EXPECT_EQ(round_down_pot(value), static_cast<Integer>(1) + << floor_log2(value)); + EXPECT_LE(round_down_pot(value), value); + EXPECT_GE(round_down_pot(value), value >> 1); + EXPECT_TRUE(is_pot(round_down_pot(value))); + + if (ceil_log2(value) < static_cast<int>(8 * sizeof(Integer) - 1)) { + EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 
1 : 0); + EXPECT_EQ(round_up_pot(value), static_cast<Integer>(1) << ceil_log2(value)); + EXPECT_GE(round_up_pot(value), value); + EXPECT_LE(round_up_pot(value) >> 1, value); + EXPECT_TRUE(is_pot(round_up_pot(value))); + } + + for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { + EXPECT_GE(value, round_down_pot(value, modulo)); + EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); + + if (value <= std::numeric_limits<Integer>::max() - modulo) { + EXPECT_LE(value, round_up_pot(value, modulo)); + EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); + } + } +} + +template <typename Integer> +void SizeUtilTest() { + for (unsigned exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { + const Integer pot = static_cast<Integer>(1) << exponent; + SizeUtilTestValue(pot - 1); + SizeUtilTestValue(pot); + SizeUtilTestValue(pot + 1); + SizeUtilTestValue(pot + 12); + SizeUtilTestValue(pot + 123); + } + SizeUtilTestValue(std::numeric_limits<Integer>::max() - 1); + SizeUtilTestValue(std::numeric_limits<Integer>::max()); +} + +TEST(SizeUtilTest, Int) { SizeUtilTest<int>(); } + +TEST(SizeUtilTest, Long) { SizeUtilTest<long int>(); } // NOLINT + +TEST(SizeUtilTest, LongLong) { SizeUtilTest<long long int>(); } // NOLINT + +TEST(SizeUtilTest, Int32) { SizeUtilTest<std::int32_t>(); } + +TEST(SizeUtilTest, Int64) { SizeUtilTest<std::int64_t>(); } + +TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest<std::ptrdiff_t>(); } + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/system_aligned_alloc.cc b/ruy/system_aligned_alloc.cc new file mode 100644 index 0000000..7c86691 --- /dev/null +++ b/ruy/system_aligned_alloc.cc @@ -0,0 +1,51 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/system_aligned_alloc.h" + +#include <cstddef> +#include <cstdlib> + +#ifdef _WIN32 +#include <malloc.h> +#endif + +namespace ruy { + +namespace detail { + +void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { +#ifdef _WIN32 + return _aligned_malloc(num_bytes, kMinimumBlockAlignment); +#else + void *ptr; + if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { + return nullptr; + } + return ptr; +#endif +} + +void SystemAlignedFree(void *ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +} // namespace detail + +} // namespace ruy diff --git a/ruy/system_aligned_alloc.h b/ruy/system_aligned_alloc.h new file mode 100644 index 0000000..cdf73af --- /dev/null +++ b/ruy/system_aligned_alloc.h @@ -0,0 +1,53 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_ +#define RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_ + +#include <cstddef> + +namespace ruy { + +namespace detail { + +// Minimum alignment for blocks. +// +// Considerations: +// - This needs to be at least the alignment of any usual data type. +// - It's useful that this is at least the size of a cache line to limit +// possible cache side effects (if only on performance behavior). +// - It's useful that this is at least the size of SIMD registers, as +// some SIMD instruction sets have at least performance behavior +// differences (e.g. NEON) or even different requirements (e.g. SSE) +// based on that. +// - It's useful that this is at least the size of an "exclusive reservation +// granule" on ARM, meaning that if we use this Allocator to allocate +// an atomic variable, there will be no side effects from other things +// contending for exclusive/atomic memory accesses to it. While the +// ARM reference manual mentions that this granule size may be as large +// as 2048 bytes, in practice we observe it to be 64 bytes. It can +// be queried cheaply, at runtime, from userspace, if needed. +constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; + +// Primitive allocation functions obtaining aligned memory from the +// operating system. +void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); +void SystemAlignedFree(void* ptr); + +} // namespace detail + +} // namespace ruy + +#endif // RUY_RUY_SYSTEM_ALIGNED_ALLOC_H_ diff --git a/ruy/test.h b/ruy/test.h new file mode 100644 index 0000000..5aa4c41 --- /dev/null +++ b/ruy/test.h @@ -0,0 +1,2308 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_TEST_H_ +#define RUY_RUY_TEST_H_ + +#include <math.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <ctime> +#include <iostream> +#include <iterator> +#include <limits> +#include <memory> +#include <numeric> +#include <random> +#include <set> +#include <sstream> +#include <string> +#include <tuple> +#include <type_traits> +#include <vector> + +#include "ruy/allocator.h" +#include "ruy/context.h" +#include "ruy/context_get_ctx.h" +#include "ruy/ctx.h" +#include "ruy/gtest_wrapper.h" // IWYU pragma: export +#include "ruy/matrix.h" // IWYU pragma: export +#include "ruy/mul_params.h" // IWYU pragma: export +#include "ruy/pack_common.h" +#include "ruy/platform.h" +#include "ruy/pmu.h" +#include "ruy/reference_mul.h" +#include "ruy/ruy.h" +#include "ruy/size_util.h" +#include "ruy/time.h" + +#ifdef RUY_TEST_EXTERNAL_PATHS +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#pragma GCC diagnostic pop +#include "third_party/gemmlowp/public/gemmlowp.h" +#include "third_party/lapack/blas.h" +#endif + +#ifdef RUY_PROFILER +#include "ruy/profiler/profiler.h" +#endif + +#ifdef __linux__ +#include <sys/mman.h> +#include <unistd.h> +#endif + +namespace ruy { + +const float kClampRatio = 0.1f; + +enum class ExternalPath { + kNone, + kReference, + kGemmlowp, + kEigen, + kEigenTensor, + kOpenBlas +}; + +inline std::vector<std::string>* CoveredPaths() { + static std::vector<std::string> covered_paths; + return &covered_paths; +} + +inline const char* PathName(Path path) { +#define RUY_PATHNAME_CASE(NAME) \ + case Path::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kStandardCpp) + RUY_PATHNAME_CASE(kInternalStandardCppVariant1) + RUY_PATHNAME_CASE(kInternalStandardCppVariant2) + RUY_PATHNAME_CASE(kInternalStandardCppVariant3) + +#if RUY_PLATFORM_NEON + RUY_PATHNAME_CASE(kNeon) + RUY_PATHNAME_CASE(kNeonDotprod) +#elif RUY_PLATFORM_X86 + RUY_PATHNAME_CASE(kAvx2Fma) + RUY_PATHNAME_CASE(kAvx512) + RUY_PATHNAME_CASE(kAvx) +#endif + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline const char* TuningName(Tuning tuning) { +#define RUY_SUBPATHNAME_CASE(NAME) \ + case Tuning::NAME: \ + return #NAME; + switch (tuning) { + RUY_SUBPATHNAME_CASE(kA55ish) + RUY_SUBPATHNAME_CASE(kGeneric) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_SUBPATHNAME_CASE +} + +inline const char* PathName(ExternalPath path) { +#define RUY_PATHNAME_CASE(NAME) \ + case ExternalPath::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kReference) + RUY_PATHNAME_CASE(kGemmlowp) + RUY_PATHNAME_CASE(kEigen) + RUY_PATHNAME_CASE(kEigenTensor) + RUY_PATHNAME_CASE(kOpenBlas) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline std::ostream& operator<<(std::ostream& stream, Path path) { + return stream << PathName(path); +} + +inline std::ostream& operator<<(std::ostream& stream, + ExternalPath external_path) { + return stream << PathName(external_path); +} + +template <typename ContainerType> +std::string Join(const ContainerType& container) { + if (container.empty()) { + return "<empty>"; + } + std::ostringstream stream; + auto it = container.begin(); + stream << *it++; + 
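+  // The first element has already been streamed and `it` advanced past it;
+  // the loop below appends each remaining element preceded by ", ".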
for (; it != container.end(); ++it) { + stream << ", "; + stream << *it; + } + return stream.str(); +} + +struct LogCoveredPathsOnDestruction final { + ~LogCoveredPathsOnDestruction() { + std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; + + // When we know that it would be abnormal for some path not to be covered, + // we check it here. Accidentally disabling SIMD paths has occurred in the + // past and is one of the biggest performance regressions imaginable. + // + // TODO: we should be able to require some x86 paths as well, at least + // SSE4.2. + +#if RUY_PLATFORM_ARM + // When testing on ARM32 or ARM64, make sure that we covered the NEON path. + // NEON is always available on ARM64, and we treat it as always available + // also on ARM32. + bool found_neon = false; + for (const std::string& covered_path : *CoveredPaths()) { + if (covered_path == "kNeon") { + found_neon = true; + } + } + if (!found_neon) { + std::cerr + << "Error: we haven't tested the kNeon path as we should have.\n" + << std::endl; + abort(); + } +#endif + + // When testing on ARM64 ChromiumOS emulator, make sure that we covered + // the dotprod path. We're getting such coverage at the moment thanks to + // using a sufficiently recent emulator, and we don't want to regress that. +#if RUY_PLATFORM_ARM_64 && defined RUY_TESTING_ON_CHROMIUMOS + bool found_dotprod = false; + for (const std::string& covered_path : *CoveredPaths()) { + if (covered_path == "kNeonDotprod") { + found_dotprod = true; + } + } + if (!found_dotprod) { + std::cerr + << "Error: we haven't tested the kNeonDotprod path as we should " + "have. At the moment, this is required on ChromiumOS as this is " + "what we run emulator tests in, that currently supports " + "dot-product " + "instructions, and we care very much about not regressing that. " + "If this test was run in an emulator, please upgrade to a newer " + "emulator version. 
If this test was run on an actual device, and " + "you need to be able to run ruy tests on devices not supporting " + "dot-product instructions, get in touch with us.\n" + << std::endl; + abort(); + } +#endif + } + static void Singleton() { static LogCoveredPathsOnDestruction singleton; } +}; + +enum class RandomRange { + kGeneral, + kAvoidMinValue, + kOffCenterAvoidMinValue, + kReasonableSrcZeroPoint, + kReasonableDstZeroPoint, + kBias +}; + +template <typename Scalar, + bool IsFloatingPoint = std::is_floating_point<Scalar>::value> +struct RandomRangeBounds {}; + +template <typename Scalar> +struct RandomRangeBounds<Scalar, true> { + static Scalar GetMinBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return -1; + case RandomRange::kAvoidMinValue: + return -1; + case RandomRange::kOffCenterAvoidMinValue: + return -1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return -1; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return 1; + case RandomRange::kAvoidMinValue: + return 1; + case RandomRange::kOffCenterAvoidMinValue: + return 1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return 1; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +template <typename Scalar> +Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { + float sum = s1 * weight1 + s2 * weight2; + float clamped = std::min<float>( + std::numeric_limits<Scalar>::max(), + std::max<float>(std::numeric_limits<Scalar>::lowest(), sum)); + return static_cast<Scalar>(clamped); +} + +template <typename Scalar> +Scalar Parametrized(float param) { + return WeightedSum(std::numeric_limits<Scalar>::max(), param, + std::numeric_limits<Scalar>::lowest(), 1 - param); +} + +template <typename Scalar> +struct RandomRangeBounds<Scalar, false> { + static Scalar GetMinBound(RandomRange range) { + static constexpr double offcenteredness = + 0.02; // Shift lower limit by about 5 for range of 255. + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits<Scalar>::lowest(); + case RandomRange::kAvoidMinValue: + return 1 + std::numeric_limits<Scalar>::lowest(); + case RandomRange::kOffCenterAvoidMinValue: + return 1 + std::numeric_limits<Scalar>::lowest() + + static_cast<Scalar>( + offcenteredness * std::numeric_limits<Scalar>::max() - + offcenteredness * + (std::numeric_limits<Scalar>::lowest() + 1)); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits<Scalar>::lowest(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized<Scalar>(0.4); + case RandomRange::kBias: + return std::is_same<Scalar, std::int32_t>::value + ? 
static_cast<Scalar>(-10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits<Scalar>::max(); + case RandomRange::kAvoidMinValue: + return std::numeric_limits<Scalar>::max(); + case RandomRange::kOffCenterAvoidMinValue: + return std::numeric_limits<Scalar>::max(); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits<Scalar>::max(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized<Scalar>(0.6); + case RandomRange::kBias: + return std::is_same<Scalar, std::int32_t>::value + ? static_cast<Scalar>(10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +inline std::default_random_engine& global_random_engine() { + static std::default_random_engine engine; + return engine; +} + +template <typename Scalar> +struct UniformRandomDistribution { + UniformRandomDistribution(RandomRange range) + : dist(RandomRangeBounds<Scalar>::GetMinBound(range), + RandomRangeBounds<Scalar>::GetMaxBound(range)) {} + Scalar Get() { return dist(global_random_engine()); } + // std::uniform_int_distribution is specified not to support char types, + // only short and wider types. MSVC actually generates an error on + // std::uniform_int_distribution<std::int8_t>. + using StdDistType = typename std::conditional< + std::is_floating_point<Scalar>::value, + std::uniform_real_distribution<Scalar>, + std::uniform_int_distribution<std::int32_t>>::type; + StdDistType dist; +}; + +template <typename Scalar> +void MakeRandomScalar(UniformRandomDistribution<Scalar>* uniform_dist, + Scalar* dst) { + *dst = uniform_dist->Get(); +} + +#if defined(__has_feature) +#if __has_feature(address_sanitizer) +#define RUY_TEST_BUILT_WITH_ASAN +#endif +#endif + +// Don't use separate mappings when building with Address Sanitizer, as the +// manual mappings could hide actual address errors from ASan (ASan can't see +// the actual buffer valid address range inside the manual mapping). +#if defined __linux__ && !defined RUY_TEST_BUILT_WITH_ASAN +#define RUY_TEST_USE_SEPARATE_MAPPINGS +#endif + +template <typename T> +struct SeparateMappingAllocator { + using value_type = T; + + T* allocate(std::size_t n) { +#ifdef RUY_TEST_USE_SEPARATE_MAPPINGS + const std::size_t page_size = getpagesize(); + std::size_t buffer_size = n * sizeof(T); + std::size_t rounded_buffer_size = round_up_pot(buffer_size, page_size); + // We are going to map an additional page at the end of our buffer, then + // unmap it, to ensure that our buffer's end is the last mapped byte, so as + // to catch any overrun. + void* mapping = + mmap(nullptr, rounded_buffer_size + page_size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + RUY_CHECK_NE(mapping, MAP_FAILED); + int unmap_result = + munmap(static_cast<char*>(mapping) + rounded_buffer_size, page_size); + RUY_CHECK_EQ(unmap_result, 0); + // Clearing bytes should be redundant since mmap has do do it already, but + // it does not hurt and acts as an assertion checking that we got the above + // mapping and unmapping right. + std::memset(mapping, 0, rounded_buffer_size); + // Compute the offset to make our buffer end at the last mapped byte. 
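+    // For example, with an illustrative 4096-byte page size: a request of
+    // buffer_size = 100 bytes gives rounded_buffer_size = 4096, so
+    // buffer_offset = 3996 and the buffer occupies the last 100 bytes of the
+    // mapping; any access past its end lands on the just-unmapped guard page
+    // and faults immediately.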
+ std::size_t buffer_offset = rounded_buffer_size - buffer_size; + void* buffer = + static_cast<void*>(static_cast<char*>(mapping) + buffer_offset); + return static_cast<T*>(buffer); +#else + T* ret = new T[n]; + std::memset(ret, 0, n * sizeof(T)); + return ret; +#endif + } + void deallocate(T* p, std::size_t n) { +#ifdef RUY_TEST_USE_SEPARATE_MAPPINGS + // The mapped bytes are the buffer address range, rounded on both ends + // to page boundary. + const std::size_t page_size = getpagesize(); + std::size_t buffer_size = n * sizeof(T); + std::size_t rounded_buffer_size = round_up_pot(buffer_size, page_size); + std::uintptr_t p_addr = reinterpret_cast<std::uintptr_t>(p); + void* mapping = reinterpret_cast<void*>(p_addr & ~(page_size - 1)); + int ret = munmap(mapping, rounded_buffer_size); + RUY_CHECK_EQ(ret, 0); +#else + (void)n; + delete[] p; +#endif + } +}; + +template <typename T> +using SeparateMappingVector = std::vector<T, SeparateMappingAllocator<T>>; + +template <typename Scalar, typename Allocator> +void MakeRandomVector(UniformRandomDistribution<Scalar>* uniform_dist, int size, + std::vector<Scalar, Allocator>* dst) { + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(uniform_dist, &x); + } +} + +template <typename Scalar> +void MakeRandomScalar(RandomRange range, Scalar* dst) { + UniformRandomDistribution<Scalar> dist(range); + *dst = dist.Get(); + if (range == RandomRange::kReasonableDstZeroPoint || + range == RandomRange::kReasonableSrcZeroPoint) { + if (global_random_engine()() & 1) { + *dst = SymmetricZeroPoint<Scalar>(); + } + } +} + +template <typename Scalar, typename Allocator> +void MakeRandomVector(RandomRange range, int size, + std::vector<Scalar, Allocator>* dst) { + UniformRandomDistribution<Scalar> dist(range); + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(&dist, &x); + } +} + +enum class LayoutStyle { kUnstridedLinear, kLinear }; + +inline void MakeLayout(int rows, int cols, Order order, + LayoutStyle layout_style, Layout* layout) { + layout->set_rows(rows); + layout->set_cols(cols); + layout->set_order(order); + + const int min_stride = order == Order::kColMajor ? rows : cols; + + RUY_CHECK(layout_style == LayoutStyle::kUnstridedLinear || + layout_style == LayoutStyle::kLinear); + if (layout_style == LayoutStyle::kUnstridedLinear) { + layout->set_stride(min_stride); + } else { + layout->set_stride(min_stride + 1); + } +} + +template <typename Scalar> +struct StorageMatrix { + StorageMatrix() = default; + StorageMatrix(const StorageMatrix&) = delete; + SeparateMappingVector<Scalar> data; + Matrix<Scalar> matrix; +}; + +inline bool IsUnstrided(const Layout& layout) { + if (layout.order() == Order::kColMajor) { + return layout.stride() == layout.rows(); + } else { + return layout.stride() == layout.cols(); + } +} + +inline bool IsRowMajor(const Layout& layout) { + return layout.order() == Order::kRowMajor; +} + +inline bool IsColMajor(const Layout& layout) { + return layout.order() == Order::kColMajor; +} + +inline int FlatSize(const Layout& layout) { + const int outerdim = + layout.order() == Order::kColMajor ? 
layout.cols() : layout.rows(); + return layout.stride() * outerdim; +} + +template <typename Scalar> +void VerifyConsistentFields(const StorageMatrix<Scalar>& storage_matrix) { + if (storage_matrix.data.empty()) { + RUY_CHECK_EQ(storage_matrix.matrix.data(), nullptr); + RUY_CHECK_EQ(storage_matrix.matrix.layout().rows(), 0); + RUY_CHECK_EQ(storage_matrix.matrix.layout().cols(), 0); + } else { + RUY_CHECK_EQ(storage_matrix.matrix.data(), storage_matrix.data.data()); + RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout()), + static_cast<int>(storage_matrix.data.size())); + } +} + +template <typename Scalar> +void MakeRandom(int rows, int cols, Order order, Scalar zero_point, + LayoutStyle layout_style, RandomRange range, + StorageMatrix<Scalar>* storage_matrix) { + MakeLayout(rows, cols, order, layout_style, + storage_matrix->matrix.mutable_layout()); + storage_matrix->matrix.set_zero_point(zero_point); + UniformRandomDistribution<Scalar> data_dist(range); + MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout()), + &storage_matrix->data); + storage_matrix->matrix.set_data(storage_matrix->data.data()); + VerifyConsistentFields(*storage_matrix); +} + +template <typename Scalar> +struct TestResult { + void operator=(const TestResult&) = delete; + void operator=(const TestResult&&) = delete; + StorageMatrix<Scalar> storage_matrix; + Path path = Path::kNone; + Tuning tuning = Tuning::kAuto; + ExternalPath external_path = ExternalPath::kNone; + float latency; + float l1_refill_rate; + float l2_refill_rate; + float l3_refill_rate; + float l1tlb_refill_rate; + float l2tlb_refill_rate; + float mispred_rate; + float frontend_stall_rate; + float backend_stall_rate; +}; + +template <typename Scalar> +std::string PathName(const TestResult<Scalar>& result) { + std::string pathname; + if (result.path != Path::kNone) { + pathname.assign(PathName(result.path)); + } else if (result.external_path != ExternalPath::kNone) { + pathname.assign(PathName(result.external_path)); + } else { + RUY_CHECK(false); + } + if (result.tuning != Tuning::kAuto) { + pathname.append("/"); + pathname.append(TuningName(result.tuning)); + } + return pathname; +} + +template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar, + typename tDstScalar> +struct TestSet final { + using LhsScalar = tLhsScalar; + using RhsScalar = tRhsScalar; + using AccumScalar = tAccumScalar; + using DstScalar = tDstScalar; + using MulParamsType = MulParams<AccumScalar, DstScalar>; + using TestResultType = TestResult<DstScalar>; + + void Run() { + MakeZeroPoints(); + MakeLhsRhs(); + MakeMulParams(); + MakeOtherParams(); + MakeResultPaths(); + Eval(); + Verify(); + } + + private: + void MakeZeroPoints(); + void MakeLhsRhs(); + void MakeMulParams(); + void MakeResultPaths(); + void MakeOtherParams(); + void EvalAndVerify(); + void Eval(); + void Verify(); + + void EvalResult(TestResultType* result); + void EvalRuy(TestResultType* result); + void DoMul(TestResultType* result); + void Benchmark(TestResultType* result); + void VerifyTestResults() const; + + public: + enum class LifeStage { + kInitial, + kHasZeroPoints, + kHasLhsRhs, + kHasMulParams, + kHasOtherParams, + kHasResultPaths, + kEvaluated, + kFinal + }; + + ~TestSet(); + + LifeStage life_stage = LifeStage::kInitial; + + int rows = 0; + int cols = 0; + int depth = 0; + Order lhs_order = Order::kRowMajor; + Order rhs_order = Order::kColMajor; + Order dst_order = Order::kColMajor; + LayoutStyle layout_style = LayoutStyle::kUnstridedLinear; + + bool use_specified_zero_points 
= false; + LhsScalar lhs_zero_point = 0; + RhsScalar rhs_zero_point = 0; + DstScalar dst_zero_point = 0; + + SeparateMappingVector<AccumScalar> per_channel_multiplier_fixedpoint; + SeparateMappingVector<int> per_channel_multiplier_exponent; + + StorageMatrix<LhsScalar> lhs; + StorageMatrix<RhsScalar> rhs; + MulParamsType mul_params; + SeparateMappingVector<AccumScalar> bias_data; + std::vector<std::unique_ptr<TestResultType>> results; + + std::vector<Path> paths; + std::vector<ExternalPath> external_paths; + + bool benchmark = false; + bool perchannel = false; + int max_num_threads = 0; + + bool cache_lhs = false; + bool cache_rhs = false; +}; + +inline PmuEvents& GlobalPmuEvents() { + static PmuEvents pmu; + return pmu; +} + +inline Context& GlobalContext() { + // Ensure that GlobalPmuEvents is constructed before we create any context. + // This ensures that pmu counters are opened before we create any worker + // thread, which is necessary to count events from worker threads. + GlobalPmuEvents(); + + static Context context; + return context; +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::~TestSet() { + RUY_CHECK_EQ(life_stage, LifeStage::kFinal); + LogCoveredPathsOnDestruction::Singleton(); + GlobalContext().ClearPrepackedCache(); +} + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define RUY_TSAN +#endif +#if __has_feature(address_sanitizer) +#define RUY_ASAN +#endif +#endif // defined(__has_feature) + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::DoMul( + TestResultType* result) { + Mul<kAllPathsIncludingInternalVariants>(lhs.matrix, rhs.matrix, mul_params, + &GlobalContext(), + &result->storage_matrix.matrix); +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::EvalRuy( + TestResultType* result) { + GlobalContext().set_explicit_tuning(result->tuning); + if (max_num_threads) { + GlobalContext().set_max_num_threads(max_num_threads); + } else if (benchmark) { + GlobalContext().set_max_num_threads(1); + } else { + GlobalContext().set_max_num_threads(1 + global_random_engine()() % 8); + } + get_ctx(&GlobalContext())->SetRuntimeEnabledPaths(result->path); + DoMul(result); + // If enabling caching, Mul is stateful, so we run it a second time to get + // coverage of these aspects. + if (!benchmark && (cache_lhs || cache_rhs)) { + DoMul(result); + } + RUY_CHECK_EQ(GlobalContext().last_used_path(), result->path); + GlobalContext().set_explicit_tuning(Tuning::kAuto); + GlobalContext().set_max_num_threads(1); +} + +#ifdef RUY_TEST_EXTERNAL_PATHS + +template <typename Scalar, gemmlowp::MapOrder tOrder> +void WrapGemmlowp(const Matrix<Scalar>& src, + gemmlowp::MatrixMap<const Scalar, tOrder>* dst) { + RUY_CHECK(src.layout().order() == (tOrder == gemmlowp::MapOrder::ColMajor + ? Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap<const Scalar, tOrder>( + src.data(), src.layout().rows(), src.layout().cols(), + src.layout().stride()); +} + +template <typename Scalar, gemmlowp::MapOrder tOrder> +void WrapGemmlowpMutable(Matrix<Scalar>* src, + gemmlowp::MatrixMap<Scalar, tOrder>* dst) { + RUY_CHECK(src->layout().order() == (tOrder == gemmlowp::MapOrder::ColMajor + ? 
Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap<Scalar, tOrder>(src->data(), src->layout().rows(), + src->layout().cols(), + src->layout().stride()); +} + +template <Order tOrder> +struct GemmlowpOrder {}; + +template <> +struct GemmlowpOrder<Order::kColMajor> { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; +}; + +template <> +struct GemmlowpOrder<Order::kRowMajor> { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; +}; + +inline gemmlowp::GemmContext& GlobalGemmlowpContext() { + static gemmlowp::GemmContext context; + return context; +} + +template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar, + typename RhsScalar, typename DstScalar, typename MulParamsType> +void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<DstScalar>* dst) { + static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = + GemmlowpOrder<LhsOrder>::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = + GemmlowpOrder<RhsOrder>::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = + GemmlowpOrder<DstOrder>::kValue; + gemmlowp::MatrixMap<const LhsScalar, kGemmlowpLhsOrder> gemmlowp_lhs; + gemmlowp::MatrixMap<const RhsScalar, kGemmlowpRhsOrder> gemmlowp_rhs; + gemmlowp::MatrixMap<DstScalar, kGemmlowpDstOrder> gemmlowp_dst; + WrapGemmlowp(lhs, &gemmlowp_lhs); + WrapGemmlowp(rhs, &gemmlowp_rhs); + WrapGemmlowpMutable(dst, &gemmlowp_dst); + + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; + quantize_down_stage.result_offset_after_shift = dst->zero_point(); + quantize_down_stage.result_fixedpoint_multiplier = + mul_params.multiplier_fixedpoint(); + quantize_down_stage.result_exponent = mul_params.multiplier_exponent(); + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< + gemmlowp::VectorShape::Col> + quantize_down_stage_pc; + quantize_down_stage_pc.result_offset_after_shift = dst->zero_point(); + using ColVectorMap = + gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>; + quantize_down_stage_pc.result_fixedpoint_multiplier = ColVectorMap( + mul_params.multiplier_fixedpoint_perchannel(), lhs.layout().rows()); + quantize_down_stage_pc.result_exponent = ColVectorMap( + mul_params.multiplier_exponent_perchannel(), lhs.layout().rows()); + + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = mul_params.clamp_min(); + clamp_stage.max = mul_params.clamp_max(); + using OutputStageSaturatingCast = typename std::conditional< + std::is_same<DstScalar, std::uint8_t>::value, + gemmlowp::OutputStageSaturatingCastToUint8, + gemmlowp::OutputStageSaturatingCastToInt16>::type; + OutputStageSaturatingCast saturating_cast_stage; + + GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? 
max_num_threads + : 1); + if (mul_params.bias()) { + using ColVectorMap = + gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>; + gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_add_stage; + bias_add_stage.bias_vector = + ColVectorMap(mul_params.bias(), dst->layout().rows()); +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (mul_params.multiplier_exponent_perchannel()) { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point(), -rhs.zero_point(), output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point(), -rhs.zero_point(), output_pipeline); + } + } else { +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (mul_params.multiplier_exponent_perchannel()) { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage_pc, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point(), -rhs.zero_point(), output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point(), -rhs.zero_point(), output_pipeline); + } + } +} + +inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { + return (LhsOrder == Order::kRowMajor ? 4 : 0) + + (RhsOrder == Order::kRowMajor ? 2 : 0) + + (DstOrder == Order::kRowMajor ? 
1 : 0); +} + +template <typename LhsScalar, typename RhsScalar, typename DstScalar, + typename MulParamsType> +void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<DstScalar>* dst) { + int index = + Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order()); + switch (index) { +#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalGemmlowp<LHS, RHS, DST>(lhs, rhs, mul_params, max_num_threads, \ + dst); +#define EVALGEMMLOWP_CASE2(LHS, RHS) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALGEMMLOWP_CASE1(LHS) \ + EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) + + EVALGEMMLOWP_CASE1(Order::kColMajor) + EVALGEMMLOWP_CASE1(Order::kRowMajor) + +#undef EVALGEMMLOWP_CASE1 +#undef EVALGEMMLOWP_CASE2 +#undef EVALGEMMLOWP_CASE3 + + default: + RUY_CHECK(false); + } +} + +template <Order tOrder> +struct EigenOrder {}; + +template <> +struct EigenOrder<Order::kColMajor> { + static constexpr int kValue = Eigen::ColMajor; +}; + +template <> +struct EigenOrder<Order::kRowMajor> { + static constexpr int kValue = Eigen::RowMajor; +}; + +template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar, + typename RhsScalar, typename DstScalar, typename MulParamsType> +void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<DstScalar>* dst) { + RUY_CHECK_EQ(lhs.zero_point(), 0); + RUY_CHECK_EQ(rhs.zero_point(), 0); + RUY_CHECK_EQ(dst->zero_point(), 0); + RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0); + RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0); + + static constexpr int kEigenLhsOrder = EigenOrder<LhsOrder>::kValue; + static constexpr int kEigenRhsOrder = EigenOrder<RhsOrder>::kValue; + static constexpr int kEigenDstOrder = EigenOrder<DstOrder>::kValue; + + using EigenLhsType = typename Eigen::Matrix<LhsScalar, Eigen::Dynamic, + Eigen::Dynamic, kEigenLhsOrder>:: + template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type; + using EigenRhsType = typename Eigen::Matrix<RhsScalar, Eigen::Dynamic, + Eigen::Dynamic, kEigenRhsOrder>:: + template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type; + using EigenDstType = typename Eigen::Matrix<DstScalar, Eigen::Dynamic, + Eigen::Dynamic, kEigenDstOrder>:: + template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type; + using EigenBiasType = + typename Eigen::Matrix<DstScalar, Eigen::Dynamic, 1>::ConstMapType; + + EigenLhsType eigen_lhs( + lhs.data(), lhs.layout().rows(), lhs.layout().cols(), + Eigen::OuterStride<Eigen::Dynamic>(lhs.layout().stride())); + EigenRhsType eigen_rhs( + rhs.data(), rhs.layout().rows(), rhs.layout().cols(), + Eigen::OuterStride<Eigen::Dynamic>(rhs.layout().stride())); + EigenDstType eigen_dst( + dst->data(), dst->layout().rows(), dst->layout().cols(), + Eigen::OuterStride<Eigen::Dynamic>(dst->layout().stride())); + Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); + + if (mul_params.bias()) { + EigenBiasType eigen_bias(mul_params.bias(), dst->layout().rows()); + if (mul_params.clamp_max() == std::numeric_limits<DstScalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<DstScalar>::infinity()) { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; + } else { + eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) + .cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } else { + if (mul_params.clamp_max() == std::numeric_limits<DstScalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<DstScalar>::infinity()) { + eigen_dst.noalias() = eigen_lhs * eigen_rhs; + } else { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs) + .cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } +} + +template <typename LhsScalar, typename RhsScalar, typename DstScalar, + typename MulParamsType> +void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<DstScalar>* dst) { + int index = + Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order()); + switch (index) { +#define EVALEIGEN_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigen<LHS, RHS, DST>(lhs, rhs, mul_params, max_num_threads, dst); +#define EVALEIGEN_CASE2(LHS, RHS) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGEN_CASE1(LHS) \ + EVALEIGEN_CASE2(LHS, Order::kColMajor) \ + EVALEIGEN_CASE2(LHS, Order::kRowMajor) + + EVALEIGEN_CASE1(Order::kColMajor) + EVALEIGEN_CASE1(Order::kRowMajor) + +#undef EVALEIGEN_CASE1 +#undef EVALEIGEN_CASE2 +#undef EVALEIGEN_CASE3 + + default: + RUY_CHECK(false); + } +} + +template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename Scalar, + typename MulParamsType> +void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<Scalar>* dst) { + RUY_CHECK_EQ(lhs.zero_point(), 0); + RUY_CHECK_EQ(rhs.zero_point(), 0); + RUY_CHECK_EQ(dst->zero_point(), 0); + RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0); + RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0); + + // Eigen::TensorMap only supports unstrided layouts + RUY_CHECK(IsUnstrided(lhs.layout())); + RUY_CHECK(IsUnstrided(rhs.layout())); + RUY_CHECK(IsUnstrided(dst->layout())); + + using TensorLhsType = + Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>; + using TensorRhsType = + Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>; + using TensorDstType = + Eigen::TensorMap<Eigen::Tensor<Scalar, 2, Eigen::ColMajor>>; + using TensorBiasType = + Eigen::TensorMap<Eigen::Tensor<const Scalar, 1, Eigen::ColMajor>>; + + const bool tr = DstOrder == Order::kRowMajor; + const auto& contract_lhs = tr ? rhs : lhs; + const auto& contract_rhs = tr ? lhs : rhs; + + TensorLhsType tensor_lhs( + contract_lhs.data(), + LhsOrder == Order::kColMajor ? contract_lhs.layout().rows() + : contract_lhs.layout().cols(), + LhsOrder == Order::kColMajor ? contract_lhs.layout().cols() + : contract_lhs.layout().rows()); + TensorRhsType tensor_rhs( + contract_rhs.data(), + RhsOrder == Order::kColMajor ? contract_rhs.layout().rows() + : contract_rhs.layout().cols(), + RhsOrder == Order::kColMajor ? contract_rhs.layout().cols() + : contract_rhs.layout().rows()); + TensorDstType tensor_dst(dst->data(), + DstOrder == Order::kColMajor ? 
dst->layout().rows() + : dst->layout().cols(), + DstOrder == Order::kColMajor ? dst->layout().cols() + : dst->layout().rows()); + using DimPair = + typename Eigen::Tensor<Scalar, 1, 0, Eigen::Index>::DimensionPair; + Eigen::array<DimPair, 1> contract_dims( + {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, + (RhsOrder == Order::kColMajor) ? 0 : 1)}); + Eigen::array<int, 2> shuffle(DstOrder == Order::kColMajor ? 0 : 1, + DstOrder == Order::kColMajor ? 1 : 0); + static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); + static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + if (mul_params.bias()) { + TensorBiasType tensor_bias(mul_params.bias(), dst->layout().rows()); + Eigen::array<int, 2> bias_2d_shape(tr ? 1 : dst->layout().rows(), + tr ? dst->layout().rows() : 1); + Eigen::array<int, 2> bcast(tr ? dst->layout().cols() : 1, + tr ? 1 : dst->layout().cols()); + if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = + (tensor_lhs.contract(tensor_rhs, contract_dims) + + tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) + .cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } else { + if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) + .cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } +} + +template <typename Scalar, typename MulParamsType> +void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<Scalar>* dst) { + int index = + Mash(lhs.layout().order(), rhs.layout().order(), dst->layout().order()); + switch (index) { +#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigenTensor<LHS, RHS, DST>(lhs, rhs, mul_params, \ + max_num_threads, dst); +#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGENTENSOR_CASE1(LHS) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) + + EVALEIGENTENSOR_CASE1(Order::kColMajor) + EVALEIGENTENSOR_CASE1(Order::kRowMajor) + +#undef EVALEIGENTENSOR_CASE1 +#undef EVALEIGENTENSOR_CASE2 +#undef EVALEIGENTENSOR_CASE3 + + default: + RUY_CHECK(false); + } +} + +template <typename Scalar> +struct GenericBlasGemm {}; + +template <> +struct GenericBlasGemm<lapack::doublereal> { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, + lapack::doublereal* alpha, lapack::doublereal* a, + lapack::integer* lda, lapack::doublereal* b, + lapack::integer* ldb, lapack::doublereal* beta, + lapack::doublereal* c, lapack::integer* ldc) { + dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +template <> +struct GenericBlasGemm<lapack::real> { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, lapack::real* alpha, + lapack::real* a, lapack::integer* lda, lapack::real* b, + lapack::integer* ldb, lapack::real* beta, lapack::real* c, + 
lapack::integer* ldc) { + sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +inline void TransposeLayout(Layout* layout) { + layout->set_order((layout->order() == Order::kRowMajor) ? Order::kColMajor + : Order::kRowMajor); + int tmp_rows = layout->rows(); + layout->set_rows(layout->cols()); + layout->set_cols(tmp_rows); +} + +template <typename Scalar> +void Transpose(Matrix<Scalar>* matrix) { + TransposeLayout(matrix->mutable_layout()); +} + +template <typename Scalar, typename MulParamsType> +void EvalOpenBlas(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs, + const MulParamsType& mul_params, int max_num_threads, + Matrix<Scalar>* dst) { + RUY_CHECK_EQ(lhs.zero_point(), 0); + RUY_CHECK_EQ(rhs.zero_point(), 0); + RUY_CHECK_EQ(dst->zero_point(), 0); + RUY_CHECK_EQ(mul_params.multiplier_fixedpoint(), 0); + RUY_CHECK_EQ(mul_params.multiplier_exponent(), 0); + + Matrix<Scalar> gemm_lhs; + Matrix<Scalar> gemm_rhs; + Matrix<Scalar> gemm_dst; + gemm_dst = *dst; + + // Use Transpose to reduce to the all-column-major case. + // Notice that ruy::Matrix merely holds a pointer, does not own data, + // so Transpose is cheap -- no actual matrix data is being transposed here. + if (dst->layout().order() == Order::kColMajor) { + gemm_lhs = lhs; + gemm_rhs = rhs; + } else { + gemm_lhs = rhs; + gemm_rhs = lhs; + Transpose(&gemm_lhs); + Transpose(&gemm_rhs); + Transpose(&gemm_dst); + } + bool transposed_lhs = false; + bool transposed_rhs = false; + + if (gemm_lhs.layout().order() == Order::kRowMajor) { + Transpose(&gemm_lhs); + transposed_lhs = true; + } + if (gemm_rhs.layout().order() == Order::kRowMajor) { + Transpose(&gemm_rhs); + transposed_rhs = true; + } + + RUY_CHECK_EQ(gemm_lhs.layout().order(), Order::kColMajor); + RUY_CHECK_EQ(gemm_rhs.layout().order(), Order::kColMajor); + RUY_CHECK_EQ(gemm_dst.layout().order(), Order::kColMajor); + + char transa = transposed_lhs ? 'T' : 'N'; + char transb = transposed_rhs ? 'T' : 'N'; + int m = gemm_lhs.layout().rows(); + int n = gemm_rhs.layout().cols(); + int k = gemm_lhs.layout().cols(); + float alpha = 1; + Scalar* a = gemm_lhs.data(); + int lda = gemm_lhs.layout().stride(); + Scalar* b = gemm_rhs.data(); + int ldb = gemm_rhs.layout().stride(); + float beta = 0; + Scalar* c = gemm_dst.data(); + int ldc = gemm_dst.layout().stride(); + GenericBlasGemm<Scalar>::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, + &ldb, &beta, c, &ldc); + + // BLAS does not allow us to express the bias-addition and clamping, so + // we use Eigen for that. + + using EigenDstType = + typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>:: + template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type; + using EigenBiasType = + typename Eigen::Matrix<Scalar, Eigen::Dynamic, 1>::ConstMapType; + + EigenDstType eigen_dst( + gemm_dst.data(), gemm_dst.layout().rows(), gemm_dst.layout().cols(), + Eigen::OuterStride<Eigen::Dynamic>(gemm_dst.layout().stride())); + Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); + + if (mul_params.bias()) { + EigenBiasType eigen_bias(mul_params.bias(), dst->layout().rows()); + if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) { + eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; + } else { + eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) + .cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } else { + if (mul_params.clamp_max() == std::numeric_limits<Scalar>::infinity() && + mul_params.clamp_min() == -std::numeric_limits<Scalar>::infinity()) { + } else { + eigen_dst.noalias() = eigen_dst.cwiseMin(mul_params.clamp_max()) + .cwiseMax(mul_params.clamp_min()); + } + } +} + +template <typename TestSetType> +struct SupportsGemmlowp { + static constexpr bool kValue = + std::is_same<typename TestSetType::LhsScalar, std::uint8_t>::value && + std::is_same<typename TestSetType::RhsScalar, std::uint8_t>::value; +}; + +template <typename TestSetType> +struct UsesSingleScalarType { + static constexpr bool kValue = + std::is_same<typename TestSetType::DstScalar, + typename TestSetType::LhsScalar>::value && + std::is_same<typename TestSetType::DstScalar, + typename TestSetType::RhsScalar>::value && + std::is_same<typename TestSetType::DstScalar, + typename TestSetType::AccumScalar>::value; +}; + +template <typename TestSetType, + bool IsFloatingPoint = + std::is_floating_point<typename TestSetType::AccumScalar>::value, + bool EnableGemmlowp = SupportsGemmlowp<TestSetType>::kValue, + bool SingleScalarType = UsesSingleScalarType<TestSetType>::kValue> +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType*, TestResult<DstScalar>*) { RUY_CHECK(false); } +}; + +template <typename TestSetType> +struct EvalExternalPathImpl<TestSetType, true, false, true> { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) { + if (test_result->external_path == ExternalPath::kEigen) { + EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->mul_params, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kEigenTensor) { + EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->mul_params, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kOpenBlas) { + EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->mul_params, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +template <typename TestSetType, bool SingleScalarType> +struct EvalExternalPathImpl<TestSetType, false, true, SingleScalarType> { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) { + if (test_result->external_path == ExternalPath::kGemmlowp) { + EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->mul_params, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +#endif // RUY_TEST_EXTERNAL_PATHS + +template <typename TestSetType> +void EvalExternalPath( + TestSetType* test_set, + TestResult<typename TestSetType::DstScalar>* test_result) { + if (test_result->external_path == ExternalPath::kReference) { + // kReference is special because it's always 
available (the implementation + // is provided by ruy) and supports all cases (quantized and float). + ruy::ReferenceMul(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->mul_params, + &test_result->storage_matrix.matrix); + } else { +#ifdef RUY_TEST_EXTERNAL_PATHS + EvalExternalPathImpl<TestSetType>::Run(test_set, test_result); +#endif // RUY_TEST_EXTERNAL_PATHS + } +} + +template <typename Scalar> +bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1, + ExternalPath external_path2, const Matrix<Scalar>& matrix2, + int depth) { + RUY_CHECK_EQ(matrix1.layout().rows(), matrix2.layout().rows()); + RUY_CHECK_EQ(matrix1.layout().cols(), matrix2.layout().cols()); + RUY_CHECK_EQ(matrix1.zero_point(), matrix2.zero_point()); + const int size = matrix1.layout().rows() * matrix1.layout().cols(); + double tolerated_max_diff = 0; + double tolerated_mean_diff = 0; + const float kSmallestAllowedDifference = + 4. * std::numeric_limits<Scalar>::epsilon(); + if (std::is_floating_point<Scalar>::value) { + double max_abs_val = 0; + for (int row = 0; row < matrix1.layout().rows(); row++) { + for (int col = 0; col < matrix1.layout().cols(); col++) { + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast<double>(Element(matrix1, row, col)))); + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast<double>(Element(matrix2, row, col)))); + } + } + tolerated_max_diff = max_abs_val * std::numeric_limits<Scalar>::epsilon() * + 64 * std::sqrt(static_cast<float>(depth)); + if (tolerated_max_diff < kSmallestAllowedDifference) { + // Clamp the tolerated max diff to be a bit above machine epsilon if the + // calculated value is too small. + tolerated_max_diff = kSmallestAllowedDifference; + if (external_path1 == ExternalPath::kEigen || + external_path2 == ExternalPath::kEigen || + external_path1 == ExternalPath::kEigenTensor || + external_path2 == ExternalPath::kEigenTensor) { + // Make additional allowance for Eigen differences. + tolerated_max_diff *= 10.0f; + } + } + tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); + } else if (std::is_same<Scalar, std::int32_t>::value) { + // raw integer case, no rounding, so we can require exactness + tolerated_max_diff = 0; + tolerated_mean_diff = 0; + } else { + // quantized case, with rounding errors in downscaling int32 accumulators + // to final 8bit or 16bit values. + + if (external_path1 != ExternalPath::kNone || + external_path2 != ExternalPath::kNone) { + // In this case, we are comparing against some other library than ruy. + // + // We may have to tolerate an error of +/- 1 from using different + // rounding in fixed-point multiplication, and then again an error of +/- + // 1 from using different rounding in right shifts, so the tolerance on + // the difference may have to be as large as 2. + tolerated_max_diff = 2; + } else if (RUY_PLATFORM_ARM) { + // All our code paths on ARM (32 and 64-bit) are bit-exact + // with the reference code (by design of the reference code). + tolerated_max_diff = 0; + } else if (RUY_PLATFORM_X86) { + // Our reference and ARM paths have diverged from x86 paths in PR #227. + // TODO: update the x86 path to adapt to that and reset that tolerance + // to 0. + tolerated_max_diff = 1; + } else { + // Other architectures, which we don't have dedicated code paths for + // at the moment. + // TODO: try resetting that tolerance to 0, since by + // definition we're only using non-optimized code paths here. 
+ tolerated_max_diff = 1; + } + + // totally empirical + tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.18)); + } + double sum_diff = 0; + for (int row = 0; row < matrix1.layout().rows(); row++) { + for (int col = 0; col < matrix1.layout().cols(); col++) { + double elem1 = Element(matrix1, row, col); + double elem2 = Element(matrix2, row, col); + double diff = elem1 - elem2; + + sum_diff += diff; + // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is + // NaN. + if (!(std::abs(diff) <= tolerated_max_diff)) { + return false; + } + } + } + double mean_diff = sum_diff / size; + if (std::abs(mean_diff) > tolerated_mean_diff) { + return false; + } + return true; +} + +template <typename Scalar> +bool Agree(ExternalPath external_path1, + const StorageMatrix<Scalar>& storage_matrix1, + ExternalPath external_path2, + const StorageMatrix<Scalar>& storage_matrix2, int depth) { + VerifyConsistentFields(storage_matrix1); + VerifyConsistentFields(storage_matrix2); + return Agree(external_path1, storage_matrix1.matrix, external_path2, + storage_matrix2.matrix, depth); +} + +template <typename Scalar> +bool Agree(const TestResult<Scalar>& result1, const TestResult<Scalar>& result2, + int depth) { + return Agree(result1.external_path, result1.storage_matrix, + result2.external_path, result2.storage_matrix, depth); +} + +template <typename Scalar> +void AddTestResultToCluster( + TestResult<Scalar>** result, + std::vector<std::vector<TestResult<Scalar>*>>& clusters, int depth) { + bool inserted = false; + for (auto& cluster : clusters) { + bool agreement = true; + // Test for agreement with every result in the cluster. + for (auto& other_result : cluster) { + agreement &= Agree(**result, *other_result, depth); + } + if (agreement) { + cluster.push_back(*result); + inserted = true; + } + } + if (!inserted) { + std::vector<TestResult<Scalar>*> new_results; + new_results.push_back(*result); + clusters.push_back(new_results); + } +} + +template <typename Scalar> +void PrintPathsInAgreement( + const std::vector<std::unique_ptr<TestResult<Scalar>>>& results, + int depth) { + // A container holding vectors of TestResults, where membership indicates + // that all TestResults agree with each other. + std::vector<std::vector<TestResult<Scalar>*>> clusters; + for (const auto& result : results) { + TestResult<Scalar>* test_result = result.get(); + AddTestResultToCluster(&test_result, clusters, depth); + } + + std::cerr << "Error: Not all paths agree. 
\n"; + for (auto& cluster : clusters) { + std::cerr << "These paths all agree with each other: "; + for (auto& result : cluster) { + std::cerr << PathName(*result) << ", "; + } + std::cerr << "but disagree with the rest.\n"; + } +} + +struct Stats { + double median; + double mean; + double min; + double max; +}; + +inline std::string StatsAsString(const Stats& stats) { + char buf[256]; + snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", + stats.median, stats.mean, stats.min, stats.max); + return std::string(buf); +} + +template <typename Scalar> +void GetMatrixStats(const Matrix<Scalar>& matrix, Stats* stats) { + double min = std::numeric_limits<double>::infinity(); + double max = -std::numeric_limits<double>::infinity(); + double sum = 0; + std::vector<double> allvals; + for (int row = 0; row < matrix.layout().rows(); row++) { + for (int col = 0; col < matrix.layout().cols(); col++) { + double val = Element(matrix, row, col); + min = std::min(min, val); + max = std::max(max, val); + sum += val; + allvals.push_back(val); + } + } + std::sort(allvals.begin(), allvals.end()); + stats->min = min; + stats->max = max; + stats->mean = sum / allvals.size(); + stats->median = allvals[allvals.size() / 2]; +} + +struct ErrorAnalysis { + Stats stats_good; + Stats stats_bad; + // The below is to help document departure from bit exactness. It's probably + // not going to be relevant to floating-point. + std::set<int> error_rows; + std::set<int> error_cols; + int row_of_first_error = 0; + int col_of_first_error = 0; + double first_error_good_value = 0; + double first_error_bad_value = 0; +}; + +template <typename TestSetType> +void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, + ErrorAnalysis* error_analysis) { + const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; + const auto& bad_matrix = + test_set.results[first_bad_result_index]->storage_matrix.matrix; + GetMatrixStats(good_matrix, &error_analysis->stats_good); + GetMatrixStats(bad_matrix, &error_analysis->stats_bad); + bool found_first_error = false; + for (int row = 0; row < good_matrix.layout().rows(); row++) { + for (int col = 0; col < good_matrix.layout().cols(); col++) { + if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { + if (!found_first_error) { + found_first_error = true; + error_analysis->row_of_first_error = row; + error_analysis->col_of_first_error = col; + error_analysis->first_error_good_value = + Element(good_matrix, row, col); + error_analysis->first_error_bad_value = Element(bad_matrix, row, col); + } + error_analysis->error_rows.insert(row); + error_analysis->error_cols.insert(col); + } + } + } +} + +template <typename TestSetType> +void ComputeReasonableMultiplier( + const Matrix<typename TestSetType::LhsScalar>& lhs, + const Matrix<typename TestSetType::RhsScalar>&, double* multiplier) { + using LhsScalar = typename TestSetType::LhsScalar; + using RhsScalar = typename TestSetType::RhsScalar; + using DstScalar = typename TestSetType::DstScalar; + if (std::is_floating_point<DstScalar>::value || + std::is_same<DstScalar, std::int32_t>::value) { + *multiplier = 0; + return; + } + *multiplier = static_cast<double>(std::numeric_limits<DstScalar>::max()) / + (static_cast<double>(lhs.layout().cols()) * + std::numeric_limits<LhsScalar>::max() * + std::numeric_limits<RhsScalar>::max()); +} + +inline void QuantizeMultiplier(double multiplier_double, + std::int32_t* multiplier_fixedpoint, + int* multiplier_exponent) { + 
RUY_CHECK_GT(multiplier_double, 0); + if (multiplier_double == 0.) { + *multiplier_fixedpoint = 0; + *multiplier_exponent = 0; + return; + } + const double q = std::frexp(multiplier_double, multiplier_exponent); + auto q_fixed = static_cast<std::int64_t>(std::round(q * (1ll << 31))); + RUY_CHECK_LE(q_fixed, (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*multiplier_exponent; + } + RUY_CHECK_LE(q_fixed, std::numeric_limits<std::int32_t>::max()); + *multiplier_fixedpoint = static_cast<std::int32_t>(q_fixed); +} + +template <typename TestSetType> +void SwitchMultiplierToPerChannel(TestSetType* test_set) { + ChannelDimension channel_dimension = + (test_set->benchmark || global_random_engine()() % 2) + ? ChannelDimension::kRow + : ChannelDimension::kCol; + test_set->mul_params.set_channel_dimension(channel_dimension); + const int num_channels = channel_dimension == ChannelDimension::kRow + ? test_set->rows + : test_set->cols; + test_set->per_channel_multiplier_fixedpoint.resize(num_channels); + test_set->per_channel_multiplier_exponent.resize(num_channels); + for (int i = 0; i < num_channels; i++) { + // multipliers typically range in [2^30 ; 2^31 - 1]. + // Values in [0, 2^30 - 1] are normally unused, but harmless. + // Thus a good way to randomize multipliers is to subtract from them + // a random value smaller than 2^30 but still significant compared to it. + std::int32_t nudged_multiplier = + test_set->mul_params.multiplier_fixedpoint() - + (global_random_engine()() % (1 << 26)); + int nudged_exponent = test_set->mul_params.multiplier_exponent() - 1 + + (global_random_engine()() % 4); + test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; + test_set->per_channel_multiplier_exponent[i] = nudged_exponent; + } + test_set->mul_params.set_multiplier_fixedpoint(0); + test_set->mul_params.set_multiplier_exponent(0); + test_set->mul_params.set_multiplier_fixedpoint_perchannel( + test_set->per_channel_multiplier_fixedpoint.data()); + test_set->mul_params.set_multiplier_exponent_perchannel( + test_set->per_channel_multiplier_exponent.data()); +} + +template < + typename TestSetType, + bool IsApplicable = + std::is_same<typename TestSetType::AccumScalar, std::int32_t>::value && + !std::is_same<typename TestSetType::DstScalar, std::int32_t>::value> +struct MakeSpecMultiplierFieldsImpl {}; + +template <typename TestSetType> +struct MakeSpecMultiplierFieldsImpl<TestSetType, true> { + static void Run(TestSetType* test_set) { + double multiplier; + ComputeReasonableMultiplier<TestSetType>(test_set->lhs.matrix, + test_set->rhs.matrix, &multiplier); + typename TestSetType::AccumScalar multiplier_fixedpoint; + int multiplier_exponent; + QuantizeMultiplier(multiplier, &multiplier_fixedpoint, + &multiplier_exponent); + test_set->mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint); + test_set->mul_params.set_multiplier_exponent(multiplier_exponent); + + if (!test_set->benchmark) { + test_set->perchannel = global_random_engine()() & 1; + } + if (test_set->perchannel) { + SwitchMultiplierToPerChannel(test_set); + } + } +}; + +template <typename TestSetType> +struct MakeSpecMultiplierFieldsImpl<TestSetType, false> { + static void Run(TestSetType*) {} +}; + +template <typename MulParamsType> +void MakeSpecClampFields(MulParamsType* mul_params) { + using DstScalar = typename MulParamsType::DstScalar; + + if (getenv("BENCHMARK_ONLY_MATMUL")) { + if (std::is_floating_point<DstScalar>::value) { + mul_params->set_clamp_min(-std::numeric_limits<DstScalar>::infinity()); + 
mul_params->set_clamp_max(std::numeric_limits<DstScalar>::infinity()); + } else { + mul_params->set_clamp_min(std::numeric_limits<DstScalar>::lowest()); + mul_params->set_clamp_max(std::numeric_limits<DstScalar>::max()); + } + return; + } + + mul_params->set_clamp_min(std::numeric_limits<DstScalar>::lowest() + 1); + mul_params->set_clamp_max(std::numeric_limits<DstScalar>::max() - 1); +} + +void MakeSpecClampFields(MulParams<std::int32_t, std::int32_t>*) { + // Returning raw accumulators, clamping is not supported. +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeZeroPoints() { + RUY_CHECK_EQ(life_stage, LifeStage::kInitial); + if (!benchmark && !use_specified_zero_points) { + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); + // If destination is std::int32_t, no dst_zero_point is necessary. + if (std::is_same<DstScalar, std::int32_t>::value) { + dst_zero_point = 0; + } else { + MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); + } + } + life_stage = LifeStage::kHasZeroPoints; +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeLhsRhs() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); + MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, + RandomRange::kOffCenterAvoidMinValue, &lhs); + MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, + RandomRange::kGeneral, &rhs); + if (!benchmark) { + cache_lhs = (global_random_engine()() & 0xf) == 0; + cache_rhs = (global_random_engine()() & 0xf) == 0; + } + if (cache_lhs) { + lhs.matrix.set_cache_policy(CachePolicy::kAlwaysCache); + } + if (cache_rhs) { + rhs.matrix.set_cache_policy(CachePolicy::kAlwaysCache); + } + life_stage = LifeStage::kHasLhsRhs; +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeMulParams() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); + + if (!getenv("BENCHMARK_ONLY_MATMUL") && + (benchmark || (global_random_engine()() & 1))) { + MakeRandomVector(RandomRange::kBias, std::max(rows, cols), &bias_data); + mul_params.set_bias(bias_data.data()); + } + if (lhs.matrix.zero_point() == std::numeric_limits<LhsScalar>::lowest() && + rhs.matrix.zero_point() == std::numeric_limits<RhsScalar>::lowest()) { + lhs.matrix.set_zero_point(lhs.matrix.zero_point() + 1); + } + MakeSpecMultiplierFieldsImpl<TestSet>::Run(this); + MakeSpecClampFields(&mul_params); + life_stage = LifeStage::kHasMulParams; +} + +inline int GetIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val); +} + +inline float GetFloatEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stof(val); +} + +inline int GetHexIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val, nullptr, 16); +} + +inline bool GetBoolEnvVarOrFalse(const char* name) { + return static_cast<bool>(GetIntEnvVarOrZero(name)); +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeOtherParams() { + 
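+  // Currently the only 'other param' resolved here is max_num_threads, which
+  // falls back to the THREADS environment variable when it was not set
+  // explicitly by the caller.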
RUY_CHECK_EQ(life_stage, LifeStage::kHasMulParams);
+  if (max_num_threads == 0) {
+    max_num_threads = GetIntEnvVarOrZero("THREADS");
+  }
+  life_stage = LifeStage::kHasOtherParams;
+}
+
+inline std::vector<Path> PathsBitfieldAsVector(Path paths_bitfield) {
+  std::vector<Path> result;
+  std::uint32_t remaining_paths = static_cast<std::uint32_t>(paths_bitfield);
+  std::uint32_t test_bit = 1;
+  while (remaining_paths) {
+    if (remaining_paths & test_bit) {
+      result.push_back(static_cast<Path>(test_bit));
+    }
+    remaining_paths &= ~test_bit;
+    test_bit <<= 1;
+  }
+  return result;
+}
+
+inline std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
+  if (benchmark) {
+    return {Tuning::kAuto};
+  }
+#if RUY_PLATFORM_ARM
+  if (path == Path::kNeon || path == Path::kNeonDotprod) {
+    return {Tuning::kA55ish, Tuning::kGeneric, Tuning::kAuto};
+  }
+#endif
+  (void)path;
+  return {Tuning::kAuto};
+}
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar>
+void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::MakeResultPaths() {
+  RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams);
+
+  Path paths_bitfield = static_cast<Path>(GetHexIntEnvVarOrZero("PATHS"));
+
+  if (paths_bitfield == Path::kNone) {
+    // Use a dummy Context just to perform the resolution of specific runtime
+    // enabled paths.
+    Context context;
+    paths_bitfield = get_ctx(&context)->GetRuntimeEnabledPaths();
+  }
+
+  // Disable the internal test-only variants of the StandardCpp path on large
+  // tests.
+  // This constant must be large enough to exercise some interesting BlockMap
+  // logic, yet small enough to avoid large test latency increases from running
+  // these slow code paths on large matrix multiplications.
+  const int kMaxSizeToTestInternalStandardCppVariants = 300;
+  if (rows > kMaxSizeToTestInternalStandardCppVariants ||
+      cols > kMaxSizeToTestInternalStandardCppVariants ||
+      depth > kMaxSizeToTestInternalStandardCppVariants) {
+    paths_bitfield = paths_bitfield & kAllPaths;
+  }
+
+  // Trim bits that don't correspond to a compiled path,
+  // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all
+  // those bits exist as actual paths.
+  paths_bitfield = paths_bitfield & kAllPathsIncludingInternalVariants;
+  RUY_CHECK_NE(paths_bitfield, Path::kNone);
+  paths = PathsBitfieldAsVector(paths_bitfield);
+
+  // kReference is a special 'external path' that's always available.
+  // It can still be disabled by NOEXT.
+  if (!GetBoolEnvVarOrFalse("NOEXT")) {
+    external_paths.push_back(ExternalPath::kReference);
+  }
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+
+  if (!GetBoolEnvVarOrFalse("NOEXT")) {
+    if (SupportsGemmlowp<TestSet>::kValue) {
+#ifdef GEMMLOWP_SSE4
+      const bool gemmlowp_supported =
+          !mul_params.multiplier_fixedpoint_perchannel() &&
+          mul_params.channel_dimension() == ChannelDimension::kRow;
+#else
+      const bool gemmlowp_supported =
+          mul_params.channel_dimension() == ChannelDimension::kRow;
+#endif
+      if (gemmlowp_supported) {
+        external_paths.push_back(ExternalPath::kGemmlowp);
+      }
+    }
+    if (UsesSingleScalarType<TestSet>::kValue &&
+        std::is_floating_point<AccumScalar>::value) {
+      external_paths.push_back(ExternalPath::kEigen);
+      if (layout_style == LayoutStyle::kUnstridedLinear) {
+        external_paths.push_back(ExternalPath::kEigenTensor);
+      }
+// We link against a generic BLAS target that only maps to OpenBLAS on specific
+// architectures.
+#if RUY_PLATFORM_ARM_32 || RUY_PLATFORM_ARM_64 + // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded + // and multi-threaded benchmark results. + if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { + external_paths.push_back(ExternalPath::kOpenBlas); + } +#endif + } + } + +#endif // RUY_TEST_EXTERNAL_PATHS + + for (Path path : paths) { + for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.path = path; + result.tuning = tuning; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + } + + for (ExternalPath external_path : external_paths) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.external_path = external_path; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + + life_stage = LifeStage::kHasResultPaths; +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::EvalResult( + TestResult<DstScalar>* result) { + RUY_CHECK(result->path != Path::kNone || + result->external_path != ExternalPath::kNone); + if (result->path != Path::kNone) { + EvalRuy(result); + } else { + EvalExternalPath(this, result); + } + const std::string& pathname = PathName(*result); + if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == + CoveredPaths()->end()) { + CoveredPaths()->push_back(pathname); + } +} + +using f32 = float; +using f64 = double; +using u8 = std::uint8_t; +using i8 = std::int8_t; +using u16 = std::uint16_t; +using i16 = std::int16_t; +using u32 = std::uint32_t; +using i32 = std::int32_t; +using u64 = std::uint64_t; +using i64 = std::int64_t; + +template <typename Scalar> +const char* TypeName() { + return nullptr; +} + +#define RUY_TYPENAME(TYPE) \ + template <> \ + inline const char* TypeName<TYPE>() { \ + return #TYPE; \ + } + +RUY_TYPENAME(f32) +RUY_TYPENAME(f64) +RUY_TYPENAME(u8) +RUY_TYPENAME(i8) +RUY_TYPENAME(u16) +RUY_TYPENAME(i16) +RUY_TYPENAME(u32) +RUY_TYPENAME(i32) +RUY_TYPENAME(u64) +RUY_TYPENAME(i64) + +#undef RUY_TYPENAME + +template <typename Scalar> +const char* SymmetryName(const Matrix<Scalar>& matrix) { + if (matrix.zero_point() == SymmetricZeroPoint<Scalar>()) { + return "symm"; + } else { + return "asymm"; + } +} + +template <typename Scalar> +int StorageSize(const Matrix<Scalar>& matrix) { + return sizeof(Scalar) * FlatSize(matrix.layout()); +} + +// Helper that replicates a buffer and gives out pointers to the replicas. +// This is useful when one wants to traverse data so that it is cold in cache. +// By having a sufficiently large value of num_repeats, one can ensure that the +// working set covered by the replicas is greater than the cache size. 
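+//
+// For illustration only (this snippet is not part of ruy; `src_data`,
+// `num_iters` and `ConsumeBuffer` are hypothetical names), the intended
+// usage pattern looks roughly like:
+//
+//   RepeatedBuffer<float> cold_buf;
+//   cold_buf.Init(src_data.data(), src_data.size(), /*num_repeats=*/64);
+//   for (int iter = 0; iter < num_iters; iter++) {
+//     // Each call returns the next replica, so consecutive iterations touch
+//     // memory that is unlikely to still be hot in cache.
+//     ConsumeBuffer(cold_buf.Next(), src_data.size());
+//   }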
+template <typename T> +class RepeatedBuffer { + public: + RepeatedBuffer() = default; + void Init(const T* elems, int num_elems, int num_repeats) { + buffers_.clear(); + allocator_.FreeAll(); + for (int i = 0; i < num_repeats; i++) { + T* p; + allocator_.Allocate(num_elems, &p); + memcpy(p, elems, num_elems * sizeof(T)); + buffers_.push_back(p); + } + } + T* Next() { + T* ret = buffers_[current_]; + current_ = (current_ + 1) % buffers_.size(); + return ret; + } + + private: + Allocator allocator_; + std::vector<T*> buffers_; + int current_ = 0; +}; + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Benchmark( + TestResult<DstScalar>* result) { + const bool cold = getenv("RUY_BENCHMARK_COLD"); + LhsScalar* orig_lhs_data = lhs.matrix.data(); + RhsScalar* orig_rhs_data = rhs.matrix.data(); + DstScalar* orig_dst_data = result->storage_matrix.matrix.data(); + + int num_matmul_sets = 0; + + RepeatedBuffer<LhsScalar> cold_lhs; + RepeatedBuffer<RhsScalar> cold_rhs; + RepeatedBuffer<DstScalar> cold_dst; + + if (cold) { + const int kWorkingSetSize = 100 << 20; + const int each_matmul_set_size = StorageSize(lhs.matrix) + + StorageSize(rhs.matrix) + + StorageSize(result->storage_matrix.matrix); + num_matmul_sets = + (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; + + cold_lhs.Init(lhs.matrix.data(), FlatSize(lhs.matrix.layout()), + num_matmul_sets); + cold_rhs.Init(rhs.matrix.data(), FlatSize(rhs.matrix.layout()), + num_matmul_sets); + cold_dst.Init(result->storage_matrix.matrix.data(), + FlatSize(result->storage_matrix.matrix.layout()), + num_matmul_sets); + } + const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); + int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); + if (!repeats) { + repeats = 4; + } + float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); + if (!benchmark_min_secs) { + benchmark_min_secs = 0.5; + } +#ifdef RUY_PROFILER + { + const char* lhstype = TypeName<LhsScalar>(); + const char* lhssymm = SymmetryName(lhs.matrix); + const char* rhstype = TypeName<RhsScalar>(); + const char* rhssymm = SymmetryName(rhs.matrix); + + printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", + PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, + rhstype, rhssymm); + ruy::profiler::ScopeProfile profile; +#endif + + float latency = std::numeric_limits<float>::infinity(); + float l1_refill_rate = std::numeric_limits<float>::infinity(); + float l2_refill_rate = std::numeric_limits<float>::infinity(); + float l3_refill_rate = std::numeric_limits<float>::infinity(); + float l1tlb_refill_rate = std::numeric_limits<float>::infinity(); + float l2tlb_refill_rate = std::numeric_limits<float>::infinity(); + float mispred_rate = std::numeric_limits<float>::infinity(); + float frontend_stall_rate = std::numeric_limits<float>::infinity(); + float backend_stall_rate = std::numeric_limits<float>::infinity(); + + for (int repeat = 0; repeat < repeats; repeat++) { + auto& pmu_events = GlobalPmuEvents(); + if (record_pmu) { + pmu_events.StartRecording(); + } + TimePoint time_start = Now(); + TimePoint t = time_start; + int iters = 0; + int iters_at_a_time = 1; + while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { + for (int i = 0; i < iters_at_a_time; i++) { + if (cold) { + lhs.matrix.set_data(cold_lhs.Next()); + rhs.matrix.set_data(cold_rhs.Next()); + result->storage_matrix.matrix.set_data(cold_dst.Next()); + } + 
EvalResult(result); + iters++; + } + iters_at_a_time *= 2; + t = Now(); + } + latency = std::min( + latency, static_cast<float>(ToFloatSeconds(t - time_start) / iters)); + if (record_pmu) { + pmu_events.StopRecording(); + const float normalization_factor = + 1.0f / (static_cast<float>(iters) * rows * cols * depth); + l1_refill_rate = std::min( + l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); + l2_refill_rate = std::min( + l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); + l3_refill_rate = std::min( + l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); + l1tlb_refill_rate = + std::min(l1tlb_refill_rate, + pmu_events.L1TLBRefillCount() * normalization_factor); + l2tlb_refill_rate = + std::min(l2tlb_refill_rate, + pmu_events.L2TLBRefillCount() * normalization_factor); + mispred_rate = + std::min(mispred_rate, pmu_events.BranchMispredictionCount() * + normalization_factor); + frontend_stall_rate = + std::min(frontend_stall_rate, + pmu_events.FrontendStallCount() * normalization_factor); + backend_stall_rate = + std::min(backend_stall_rate, + pmu_events.BackendStallCount() * normalization_factor); + } + } + result->latency = latency; + if (record_pmu) { + result->l1_refill_rate = l1_refill_rate; + result->l2_refill_rate = l2_refill_rate; + result->l3_refill_rate = l3_refill_rate; + result->l1tlb_refill_rate = l1tlb_refill_rate; + result->l2tlb_refill_rate = l2tlb_refill_rate; + result->mispred_rate = mispred_rate; + result->frontend_stall_rate = frontend_stall_rate; + result->backend_stall_rate = backend_stall_rate; + } + +#ifdef RUY_PROFILER + } + fflush(stdout); +#endif + + if (cold) { + lhs.matrix.set_data(orig_lhs_data); + rhs.matrix.set_data(orig_rhs_data); + memcpy(orig_dst_data, result->storage_matrix.matrix.data(), + StorageSize(result->storage_matrix.matrix)); + result->storage_matrix.matrix.set_data(orig_dst_data); + } +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Eval() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); + for (auto& result : results) { + if (benchmark) { + Benchmark(result.get()); + } else { + EvalResult(result.get()); + } + } + life_stage = LifeStage::kEvaluated; +} + +template <typename Scalar> +std::string DumpRegion(const Matrix<Scalar>& matrix, int center_row, + int center_col) { + static constexpr int kRadius = 20; + int first_row = std::max(0, center_row - kRadius); + int last_row = std::min(matrix.layout().rows() - 1, center_row + kRadius); + int first_col = std::max(0, center_col - kRadius); + int last_col = std::min(matrix.layout().cols() - 1, center_col + kRadius); + std::ostringstream stream; + for (int row = first_row; row <= last_row; row++) { + for (int col = first_col; col <= last_col; col++) { + stream << static_cast<double>(Element(matrix, row, col)) << " "; + } + stream << "\n"; + } + return stream.str(); +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::VerifyTestResults() + const { + const int depth = lhs.matrix.layout().cols(); + for (int i = 0; i < static_cast<int>(results.size()) - 1; i++) { + if (!Agree(*results[i], *results[i + 1], depth)) { + PrintPathsInAgreement(results, depth); + ErrorAnalysis error_analysis; + AnalyzeTestError(*this, i + 1, &error_analysis); + std::cerr << "Shape: rows = " << rows << ", cols = " << cols + << ", depth = " << depth << 
std::endl; + std::cerr << "Stats of the good result matrix: " + << StatsAsString(error_analysis.stats_good) << std::endl; + std::cerr << "Stats of the bad result matrix: " + << StatsAsString(error_analysis.stats_bad) << std::endl; + if (static_cast<int>(error_analysis.error_rows.size()) < rows) { + std::cerr << "Rows containing errors: " + << Join(error_analysis.error_rows) << std::endl; + } else { + std::cerr << "Errors found in ALL rows." << std::endl; + } + if (static_cast<int>(error_analysis.error_cols.size()) < cols) { + std::cerr << "Cols containing errors: " + << Join(error_analysis.error_cols) << std::endl; + } else { + std::cerr << "Errors found in ALL cols." << std::endl; + } + std::cerr << "The first error occurs at row " + << error_analysis.row_of_first_error << ", col " + << error_analysis.col_of_first_error << std::endl; + std::cerr << "Good value: " << error_analysis.first_error_good_value + << std::endl; + std::cerr << "Bad value : " << error_analysis.first_error_bad_value + << std::endl; + std::cerr << "Region of Good result matrix around first error:\n\n" + << DumpRegion(results[0]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + std::cerr << "Region of Bad result matrix around first error:\n\n" + << DumpRegion(results[i + 1]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + RUY_CHECK(false); + } + } +} + +template <typename LhsScalar, typename RhsScalar, typename AccumScalar, + typename DstScalar> +void TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Verify() { + RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); + VerifyTestResults(); + life_stage = LifeStage::kFinal; +} + +template <typename TestSetType> +void TestRCC(int rows, int depth, int cols) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kUnstridedLinear; + test_set.Run(); +} + +template <typename TestSetType> +void TestNonRCC(int rows, int depth, int cols) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kColMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kUnstridedLinear; + test_set.Run(); +} + +template <typename TestSetType> +void TestLinearAllOrders(int rows, int depth, int cols) { + const std::vector<Order> orders{Order::kColMajor, Order::kRowMajor}; + + for (Order lhs_order : orders) { + for (Order rhs_order : orders) { + for (Order dst_order : orders) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = lhs_order; + test_set.rhs_order = rhs_order; + test_set.dst_order = dst_order; + test_set.layout_style = LayoutStyle::kLinear; + test_set.Run(); + } + } + } +} + +} // namespace ruy + +#endif // RUY_RUY_TEST_H_ diff --git a/ruy/test_fast.cc b/ruy/test_fast.cc new file mode 100644 index 0000000..5598afb --- /dev/null +++ b/ruy/test_fast.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test contains cheap test cases, completes in a few seconds. + +#include <vector> + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>; + +TEST(RuyTest, TestSquareMuls) { + const std::vector<int> sizes{ + // small sizes + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + // multiplies of 16 + 16, + 32, + 48, + 64, + // pot-minus-1 sizes + 15, + 31, + 63, + // pot-plus-1 sizes + 17, + 33, + 65, + }; + + for (int size : sizes) { + TestRCC<TestSetType>(size, size, size); + TestLinearAllOrders<TestSetType>(size, size, size); + } +} + +TEST(RuyTest, TestMiscMuls) { + const int shapes[][3] = { + {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, + {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, + {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, + {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; + for (const auto& shape : shapes) { + TestLinearAllOrders<TestSetType>(shape[0], shape[1], shape[2]); + } +} + +TEST(RuyTest, TestDeepMuls) { + // TODO(b/137649322): clarify what's the max allowed matrix size. + TestRCC<TestSetType>(1, 32767, 1); + TestLinearAllOrders<TestSetType>(5, 5001, 4); + TestLinearAllOrders<TestSetType>(9, 1025, 10); +} + +TEST(RuyTest, TestShallowMuls) { + TestLinearAllOrders<TestSetType>(101, 1, 103); + TestLinearAllOrders<TestSetType>(71, 2, 53); + TestLinearAllOrders<TestSetType>(51, 3, 73); + TestLinearAllOrders<TestSetType>(51, 4, 43); +} + +TEST(RuyTest, TestNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestLinearAllOrders<TestSetType>(width, 12, 13); + TestLinearAllOrders<TestSetType>(15, 19, width); + TestLinearAllOrders<TestSetType>(width, 123, 137); + TestLinearAllOrders<TestSetType>(158, 119, width); + } +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1; size < 1024; size *= 2) { + for (int depth = 1; depth < 500; depth += 47) { + TestLinearAllOrders<TestSetType>(size, depth, 1); + } + } + TestLinearAllOrders<TestSetType>(5, 5001, 1); + TestLinearAllOrders<TestSetType>(8193, 17, 1); +} + +} // namespace ruy diff --git a/ruy/test_slow.cc b/ruy/test_slow.cc new file mode 100644 index 0000000..0ddc40b --- /dev/null +++ b/ruy/test_slow.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This test contains more expensive test cases. + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = TestSet<LhsScalar, RhsScalar, AccumScalar, DstScalar>; + +TEST(RuyTest, TestBigNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestRCC<TestSetType>(width, 401, 601); + TestRCC<TestSetType>(587, 443, width); + } + TestRCC<TestSetType>(7, 45984, + 5); // Large enough to trigger row-sum overflows. + TestRCC<TestSetType>(512, 256, 16); +} + +TEST(RuyTest, TestBigShallowMuls) { + TestLinearAllOrders<TestSetType>(501, 1, 321); + TestLinearAllOrders<TestSetType>(301, 5, 403); + TestLinearAllOrders<TestSetType>(256, 32, 512); +} + +TEST(RuyTest, TestBigMuls) { + TestRCC<TestSetType>(225, 303, 199); + TestLinearAllOrders<TestSetType>(256, 192, 128); +} + +TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { + // Important to test some power-of-two depths: that's when the + // RUY_AVOID_ALIASING optimization kicks in and makes unstrided matrices + // strided, exposing bugs in kernels mixing up size and stride. + // Moreover, it's important that the test matrices be sufficiently wide + // that they will result in multiple blocks, exposing bugs in the + // computation of the base address of each block. + TestLinearAllOrders<TestSetType>(70, 1024, 80); + TestLinearAllOrders<TestSetType>(60, 2048, 70); + TestLinearAllOrders<TestSetType>(40, 4096, 50); +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1025; size <= 1409; size += 384) { + for (int depth = 350; depth < 500; depth += 47) { + TestLinearAllOrders<TestSetType>(size, depth, 1); + } + } +} + +} // namespace ruy diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc new file mode 100644 index 0000000..100cfe3 --- /dev/null +++ b/ruy/thread_pool.cc @@ -0,0 +1,218 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/thread_pool.h" + +#include <atomic> +#include <chrono> // NOLINT(build/c++11) +#include <condition_variable> // NOLINT(build/c++11) +#include <cstdint> +#include <cstdlib> +#include <memory> +#include <mutex> // NOLINT(build/c++11) +#include <thread> // NOLINT(build/c++11) + +#include "ruy/check_macros.h" +#include "ruy/trace.h" +#include "ruy/wait.h" + +namespace ruy { + +// A worker thread. +class Thread { + public: + enum class State { + Startup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. 
+  };
+
+  explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
+                  Duration spin_duration)
+      : task_(nullptr),
+        state_(State::Startup),
+        counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
+        spin_duration_(spin_duration) {
+    thread_.reset(new std::thread(ThreadFunc, this));
+  }
+
+  ~Thread() {
+    ChangeState(State::ExitAsSoonAsPossible);
+    thread_->join();
+  }
+
+  // Changes State; may be called from either the worker thread
+  // or the master thread; however, not all state transitions are legal,
+  // as enforced by assertions.
+  //
+  // The Task argument is to be used only with new_state==HasWork.
+  // It specifies the Task being handed to this Thread.
+  void ChangeState(State new_state, Task* task = nullptr) {
+    state_mutex_.lock();
+    State old_state = state_.load(std::memory_order_relaxed);
+    RUY_DCHECK_NE(old_state, new_state);
+    switch (old_state) {
+      case State::Startup:
+        RUY_DCHECK_EQ(new_state, State::Ready);
+        break;
+      case State::Ready:
+        RUY_DCHECK(new_state == State::HasWork ||
+                   new_state == State::ExitAsSoonAsPossible);
+        break;
+      case State::HasWork:
+        RUY_DCHECK(new_state == State::Ready ||
+                   new_state == State::ExitAsSoonAsPossible);
+        break;
+      default:
+        abort();
+    }
+    switch (new_state) {
+      case State::Ready:
+        if (task_) {
+          // Doing work is part of reverting to 'ready' state.
+          task_->Run();
+          task_ = nullptr;
+        }
+        break;
+      case State::HasWork:
+        RUY_DCHECK(!task_);
+        task_ = task;
+        break;
+      default:
+        break;
+    }
+    state_.store(new_state, std::memory_order_relaxed);
+    state_cond_.notify_all();
+    state_mutex_.unlock();
+    if (new_state == State::Ready) {
+      counter_to_decrement_when_ready_->DecrementCount();
+    }
+  }
+
+  static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
+
+  // Called by the master thread to give this thread work to do.
+  void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+
+ private:
+  // Thread entry point.
+  void ThreadFuncImpl() {
+    RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
+    ChangeState(State::Ready);
+
+    // Thread main loop
+    while (true) {
+      RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
+      // In the 'Ready' state, we have nothing to do but to wait until
+      // we switch to another state.
+      const auto& condition = [this]() {
+        return state_.load(std::memory_order_acquire) != State::Ready;
+      };
+      RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
+      Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
+
+      // Act on new state.
+      switch (state_.load(std::memory_order_acquire)) {
+        case State::HasWork: {
+          RUY_TRACE_SCOPE_NAME("Worker thread task");
+          // Got work to do! So do it, and then revert to 'Ready' state.
+          ChangeState(State::Ready);
+          break;
+        }
+        case State::ExitAsSoonAsPossible:
+          return;
+        default:
+          abort();
+      }
+    }
+  }
+
+  // The underlying thread.
+  std::unique_ptr<std::thread> thread_;
+
+  // The task to be worked on.
+  Task* task_;
+
+  // The condition variable and mutex guarding state changes.
+  std::condition_variable state_cond_;
+  std::mutex state_mutex_;
+
+  // The state enum tells if we're currently working, waiting for work, etc.
+  // Its concurrent accesses by the worker and main threads are guarded by
+  // state_mutex_, and can thus use memory_order_relaxed. This still needs
+  // to be a std::atomic because we use WaitForVariableChange.
+  std::atomic<State> state_;
+
+  // Pointer to the master thread's BlockingCounter object, to notify the
+  // master thread of when this thread switches to the 'Ready' state.
+ BlockingCounter* const counter_to_decrement_when_ready_; + + // See ThreadPool::spin_duration_. + const Duration spin_duration_; +}; + +void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { + RUY_TRACE_SCOPE_NAME("ThreadPool::Execute"); + RUY_DCHECK_GE(task_count, 1); + + // Case of 1 thread: just run the single task on the current thread. + if (task_count == 1) { + (tasks + 0)->Run(); + return; + } + + // Task #0 will be run on the current thread. + CreateThreads(task_count - 1); + counter_to_decrement_when_ready_.Reset(task_count - 1); + for (int i = 1; i < task_count; i++) { + RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK); + auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride; + threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address)); + } + + RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD); + // Execute task #0 immediately on the current thread. + (tasks + 0)->Run(); + + RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS); + // Wait for the threads submitted above to finish. + counter_to_decrement_when_ready_.Wait(spin_duration_); +} + +// Ensures that the pool has at least the given count of threads. +// If any new thread has to be created, this function waits for it to +// be ready. +void ThreadPool::CreateThreads(int threads_count) { + RUY_DCHECK_GE(threads_count, 0); + unsigned int unsigned_threads_count = threads_count; + if (threads_.size() >= unsigned_threads_count) { + return; + } + counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); + while (threads_.size() < unsigned_threads_count) { + threads_.push_back( + new Thread(&counter_to_decrement_when_ready_, spin_duration_)); + } + counter_to_decrement_when_ready_.Wait(spin_duration_); +} + +ThreadPool::~ThreadPool() { + for (auto w : threads_) { + delete w; + } +} + +} // end namespace ruy diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h new file mode 100644 index 0000000..e3b6803 --- /dev/null +++ b/ruy/thread_pool.h @@ -0,0 +1,127 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 +// license. + +#ifndef RUY_RUY_THREAD_POOL_H_ +#define RUY_RUY_THREAD_POOL_H_ + +#include <vector> + +#include "ruy/blocking_counter.h" +#include "ruy/time.h" + +namespace ruy { + +// A workload for a thread. +struct Task { + virtual ~Task() {} + virtual void Run() = 0; +}; + +class Thread; + +// A simple pool of threads, that only allows the very +// specific parallelization pattern that we use here: +// One thread, which we call the 'main thread', calls Execute, distributing +// a Task each to N threads, being N-1 'worker threads' and the main thread +// itself. After the main thread has completed its own Task, it waits for +// the worker threads to have all completed. That is the only synchronization +// performed by this ThreadPool. 
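+//
+// For illustration only (this snippet is not part of ruy; `MyTask`,
+// `num_threads` and `pool` are hypothetical names), that pattern looks
+// roughly like:
+//
+//   struct MyTask final : ruy::Task {
+//     void Run() override { /* this task's slice of the work */ }
+//   };
+//   std::vector<MyTask> tasks(num_threads);
+//   // Task #0 runs on the calling thread; the rest go to worker threads.
+//   pool.Execute(static_cast<int>(tasks.size()), tasks.data());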
+// +// In particular, there is a naive 1:1 mapping of Tasks to threads. +// This ThreadPool considers it outside of its own scope to try to work +// with fewer threads than there are Tasks. The idea is that such N:M mappings +// of tasks to threads can be implemented as a higher-level feature on top of +// the present low-level 1:1 threadpool. For example, a user might have a +// Task subclass referencing a shared atomic counter indexing into a vector of +// finer-granularity subtasks. Different threads would then concurrently +// increment this atomic counter, getting each their own subtasks to work on. +// That approach is the one used in ruy's multi-thread matrix multiplication +// implementation --- see ruy's TrMulTask. +class ThreadPool { + public: + ThreadPool() {} + + ~ThreadPool(); + + // Executes task_count tasks on task_count threads. + // Grows the threadpool as needed to have at least (task_count-1) threads. + // The 0-th task is run on the thread on which Execute is called: that + // is by definition what we call the "main thread". Synchronization of all + // threads is performed before this function returns. + // + // As explained in the class comment, there is a 1:1 mapping of tasks to + // threads. If you need something smarter than that, for instance if you + // want to run an unbounded number of tasks on a bounded number of threads, + // then you need something higher-level than this ThreadPool, that can + // be layered on top of it by appropriately subclassing Tasks. + // + // TaskType must be a subclass of ruy::Task. That is implicitly guarded by + // the static_cast in this inline implementation. + template <typename TaskType> + void Execute(int task_count, TaskType* tasks) { + ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks)); + } + + void set_spin_milliseconds(float milliseconds) { + spin_duration_ = DurationFromMilliseconds(milliseconds); + } + + float spin_milliseconds() const { + return ToFloatMilliseconds(spin_duration_); + } + + private: + // Ensures that the pool has at least the given count of threads. + // If any new thread has to be created, this function waits for it to + // be ready. + void CreateThreads(int threads_count); + + // Non-templatized implementation of the public Execute method. + // See the inline implementation of Execute for how this is used. + void ExecuteImpl(int task_count, int stride, Task* tasks); + + // copy construction disallowed + ThreadPool(const ThreadPool&) = delete; + + // The threads in this pool. They are owned by the pool: + // the pool creates threads and destroys them in its destructor. + std::vector<Thread*> threads_; + + // The BlockingCounter used to wait for the threads. + BlockingCounter counter_to_decrement_when_ready_; + + // This value was empirically derived with some microbenchmark, we don't have + // high confidence in it. + // + // That this value means that we may be sleeping substantially longer + // than a scheduler timeslice's duration is not necessarily surprising. The + // idea is to pick up quickly new work after having finished the previous + // workload. When it's new work within the same GEMM as the previous work, the + // time interval that we might be busy-waiting is very small, so for that + // purpose it would be more than enough to sleep for 1 ms. + // That is all what we would observe on a GEMM benchmark. However, in a real + // application, after having finished a GEMM, we might do unrelated work for + // a little while, then start on a new GEMM. 
In that case the wait interval + // may be a little longer. There may also not be another GEMM for a long time, + // in which case we'll end up passively waiting below. + Duration spin_duration_ = DurationFromMilliseconds(2); +}; + +} // namespace ruy + +#endif // RUY_RUY_THREAD_POOL_H_ diff --git a/ruy/time.h b/ruy/time.h new file mode 100644 index 0000000..aa02245 --- /dev/null +++ b/ruy/time.h @@ -0,0 +1,87 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_TIME_H_ +#define RUY_RUY_TIME_H_ + +#include <chrono> // NOLINT(build/c++11) +#include <cstdint> // IWYU pragma: keep +#include <ratio> // NOLINT(build/c++11) + +#ifdef __linux__ +#include <sys/time.h> +// IWYU pragma: no_include <type_traits> + +#include <ctime> +#endif + +namespace ruy { + +using InternalDefaultClock = std::chrono::steady_clock; + +using TimePoint = InternalDefaultClock::time_point; +using Duration = InternalDefaultClock::duration; + +template <typename RepresentationType> +Duration DurationFromSeconds(RepresentationType representation) { + return std::chrono::duration_cast<Duration>( + std::chrono::duration<RepresentationType>(representation)); +} + +template <typename RepresentationType> +Duration DurationFromMilliseconds(RepresentationType representation) { + return std::chrono::duration_cast<Duration>( + std::chrono::duration<RepresentationType, std::milli>(representation)); +} + +template <typename RepresentationType> +Duration DurationFromNanoseconds(RepresentationType representation) { + return std::chrono::duration_cast<Duration>( + std::chrono::duration<RepresentationType, std::nano>(representation)); +} + +inline float ToFloatSeconds(const Duration& duration) { + return std::chrono::duration_cast<std::chrono::duration<float>>(duration) + .count(); +} + +inline float ToFloatMilliseconds(const Duration& duration) { + return std::chrono::duration_cast<std::chrono::duration<float, std::milli>>( + duration) + .count(); +} + +inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { + return std::chrono::duration_cast< + std::chrono::duration<std::int64_t, std::nano>>(duration) + .count(); +} + +inline TimePoint Now() { return InternalDefaultClock::now(); } + +inline TimePoint CoarseNow() { +#ifdef __linux__ + timespec t; + clock_gettime(CLOCK_MONOTONIC_COARSE, &t); + return TimePoint( + DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); +#else + return Now(); +#endif +} + +} // namespace ruy + +#endif // RUY_RUY_TIME_H_ diff --git a/ruy/trace.h b/ruy/trace.h new file mode 100644 index 0000000..4f5059e --- /dev/null +++ b/ruy/trace.h @@ -0,0 +1,836 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_TRACE_H_ +#define RUY_RUY_TRACE_H_ + +#ifdef RUY_TRACE + +#include <algorithm> +#include <cstdio> +#include <cstdlib> +#include <memory> +#include <mutex> +#include <string> +#include <thread> +#include <vector> + +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/side_pair.h" + +namespace ruy { + +// Helper for `formatted` so we don't have to put .c_str() on strings. +template <typename T> +T value_for_snprintf(T value) { + return value; +} + +inline const char* value_for_snprintf(const std::string& s) { + return s.c_str(); +} + +// A sprintf-like function returning a std::string. +// Remove this once we can rely on std::format (c++20). +template <typename... Args> +std::string formatted(const char* format, Args... args) { + char buf[1024]; +#pragma GCC diagnostic push +#pragma GCC diagnostic warning "-Wformat-security" + int size = snprintf(buf, sizeof buf, format, value_for_snprintf(args)...); +#pragma GCC diagnostic pop + if (size <= 0) { + abort(); + } + return std::string(buf); +} + +// An entry in the trace. +struct ThreadTraceEntry final { + std::string text; + int indent = 0; + const char* source_file = nullptr; + int source_line = 0; +}; + +// Trace for one thread. +class ThreadTrace final { + public: + ~ThreadTrace() {} + + void set_thread_id(int thread_id) { thread_id_ = thread_id; } + int thread_id() const { return thread_id_; } + + bool is_in_run_ahead_packing_loop() const { + return is_in_run_ahead_packing_loop_; + } + void set_is_in_run_ahead_packing_loop(bool value) { + is_in_run_ahead_packing_loop_ = value; + } + + void set_current_source_file(const char* source_file) { + current_source_file_ = source_file; + } + + void set_current_source_line(int source_line) { + current_source_line_ = source_line; + } + + const std::vector<ThreadTraceEntry>& entries() const { return entries_; } + + template <typename... Args> + void Write(const char* format, Args... args) { + ThreadTraceEntry entry; + entry.text = formatted(format, args...); + entry.indent = indent_; + entry.source_file = current_source_file_; + entry.source_line = current_source_line_; + entries_.emplace_back(std::move(entry)); + } + + template <typename... Args> + void EnterScope(const char* scope_name) { + Write("%s {", scope_name); + indent_++; + } + void LeaveScope(const char* scope_name) { + indent_--; + Write("} // end of %s", scope_name); + } + + private: + // The trace contents + std::vector<ThreadTraceEntry> entries_; + + // Current indentation level. + int indent_ = 0; + // Thread's ID as set by Ruy, e.g. [0,N-1]. Not OS TID. + int thread_id_ = -1; + // The run-ahead loop in `EnsurePacked` may run many iterations when the + // thread is waiting for a block to be packed by another thread --- it's + // a busy wait. We track whether we are in that mode to avoid generating + // many uninteresting trace entries. 
+ bool is_in_run_ahead_packing_loop_ = false; + // Last recorded value of __FILE__ and __LINE__, as a convenience so we don't + // have to pass these in every call to `Write`. + const char* current_source_file_ = nullptr; + int current_source_line_ = 0; +}; + +// Main components of ruy. Used for trace colorization. +enum class Component { kNone, kFrontEnd, kMiddleEnd, kBackEnd, kThreadPool }; + +// Output format for the trace. +enum class TraceOutputFormat { kNone, kTerminal, kHtml }; + +inline std::string IndentString(int indent) { + std::string s; + for (int i = 0; i < indent; i++) { + s += " "; + } + return s; +} + +// Returns the text to write to the trace to open a colored section. +inline const char* ColorSectionStart(TraceOutputFormat output_format, + Component component) { + switch (output_format) { + case TraceOutputFormat::kTerminal: + switch (component) { + case Component::kFrontEnd: + return "\x1b[36m"; + case Component::kMiddleEnd: + return "\x1b[32m"; + case Component::kBackEnd: + return "\x1b[31m"; + case Component::kThreadPool: + return "\x1b[33m"; + default: + abort(); + return nullptr; + } + case TraceOutputFormat::kHtml: + switch (component) { + case Component::kFrontEnd: + return "<span style=\"background-color:#B2EBF2\">"; + case Component::kMiddleEnd: + return "<span style=\"background-color:#C8E6C9\">"; + case Component::kBackEnd: + return "<span style=\"background-color:#FFCDD2\">"; + case Component::kThreadPool: + return "<span style=\"background-color:#FFF9C4\">"; + default: + abort(); + return nullptr; + } + default: + abort(); + return nullptr; + } +} + +// Returns the text to write to the trace to close a colored section. +inline const char* ColorSectionEnd(TraceOutputFormat output_format) { + switch (output_format) { + case TraceOutputFormat::kTerminal: + return "\x1b[0m"; + case TraceOutputFormat::kHtml: + return "</span>"; + default: + abort(); + return nullptr; + } +} + +// Returns the output format to use for the trace. +inline TraceOutputFormat GetOutputFormat() { + const char* html_env = getenv("RUY_TRACE_HTML"); + if (html_env && strtol(html_env, nullptr, 10) != 0) { + return TraceOutputFormat::kHtml; + } else { + return TraceOutputFormat::kTerminal; + } +} + +// A `basename` function that's good enough for ruy __FILE__'s. +// Note: `basename` is POSIX-only and annoying (takes a char*, may mutate). +inline const char* GetBaseName(const char* path) { + std::size_t len = strlen(path); + if (len == 0) { + return path; + } + const char* ptr = path + len - 1; + while (ptr != path) { + if (*ptr == '/' || *ptr == '\\') { + return ptr + 1; + } + --ptr; + } + // Path did not contain any path separator. + return path; +} + +// Determines a Component (used for colorization) by source file. 
+inline Component GetComponentBySourceFile(const char* base_name) {
+  if (!strcmp(base_name, "pack.h") || !strcmp(base_name, "kernel.h")) {
+    return Component::kBackEnd;
+  } else if (!strcmp(base_name, "trmul.cc") ||
+             !strcmp(base_name, "block_map.cc")) {
+    return Component::kMiddleEnd;
+  } else if (!strcmp(base_name, "thread_pool.cc")) {
+    return Component::kThreadPool;
+  } else {
+    return Component::kFrontEnd;
+  }
+}
+
+inline std::string EscapeText(TraceOutputFormat output_format,
+                              const std::string& text) {
+  if (output_format == TraceOutputFormat::kHtml) {
+    std::string escaped_text;
+    for (char c : text) {
+      if (c == '<') {
+        escaped_text += "&lt;";
+      } else if (c == '>') {
+        escaped_text += "&gt;";
+      } else {
+        escaped_text += c;
+      }
+    }
+    return escaped_text;
+  } else {
+    return text;
+  }
+}
+
+// Prints an entry from the trace to the destination trace file.
+inline void Print(const ThreadTraceEntry& entry,
+                  TraceOutputFormat output_format, FILE* file) {
+  const char* base_name = GetBaseName(entry.source_file);
+  Component component = GetComponentBySourceFile(base_name);
+  const std::string& source_location =
+      formatted("%s:%d", base_name, entry.source_line);
+  const std::string& escaped_text = EscapeText(output_format, entry.text);
+  fprintf(file, "%s%-32s%s%s%s\n", ColorSectionStart(output_format, component),
+          source_location.c_str(), IndentString(entry.indent).c_str(),
+          escaped_text.c_str(), ColorSectionEnd(output_format));
+}
+
+// Prints a thread's entire trace to the destination trace file.
+inline void Print(const ThreadTrace& trace, TraceOutputFormat output_format,
+                  FILE* file) {
+  if (output_format == TraceOutputFormat::kHtml) {
+    fprintf(file, "<html><body><pre>\n<span style=\"font-weight:bold\">\n");
+  }
+  fprintf(file, "Ruy trace for thread %d:\n", trace.thread_id());
+  if (output_format == TraceOutputFormat::kHtml) {
+    fprintf(file, "</span>\n");
+  }
+  for (const ThreadTraceEntry& entry : trace.entries()) {
+    Print(entry, output_format, file);
+  }
+  fprintf(file, "\n");
+  if (output_format == TraceOutputFormat::kHtml) {
+    fprintf(file, "</pre></body></html>\n");
+  }
+}
+
+// Holds all the threads' traces. This is a global singleton class.
+// On exit, when the singleton is destroyed, the destructor prints out the
+// traces.
+class AllThreadTraces final {
+ public:
+  // Add a new ThreadTrace for the current thread. Should be called only once
+  // on each thread.
+  ThreadTrace* AddCurrentThread() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    ThreadTrace* thread_trace = new ThreadTrace;
+    thread_traces_.emplace_back(thread_trace);
+    return thread_trace;
+  }
+  ~AllThreadTraces() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    // Open the destination file.
+    const char* file_env = getenv("RUY_TRACE_FILE");
+    FILE* file = stdout;
+    if (file_env) {
+      file = fopen(file_env, "w");
+      if (!file) {
+        fprintf(stderr, "Failed to open %s for write\n", file_env);
+        exit(1);
+      }
+    }
+    // Sort the threads by Ruy Thread ID (not OS TID).
+    auto output_format = GetOutputFormat();
+    std::sort(std::begin(thread_traces_), std::end(thread_traces_),
+              [](const auto& a, const auto& b) {
+                return a->thread_id() < b->thread_id();
+              });
+    // Print all the threads' traces.
+ for (const auto& trace : thread_traces_) { + Print(*trace, output_format, file); + } + if (file_env) { + fclose(file); + } + } + static AllThreadTraces* Singleton() { + static AllThreadTraces all_thread_traces; + return &all_thread_traces; + } + + private: + std::vector<std::unique_ptr<ThreadTrace>> thread_traces_; + std::mutex mutex_; +}; + +// Returns the thread-local ThreadTrace singleton, constructing it as needed. +inline ThreadTrace* ThreadLocalTrace() { + static thread_local ThreadTrace* thread_local_trace = + AllThreadTraces::Singleton()->AddCurrentThread(); + return thread_local_trace; +} + +// RAII helper to trace a scope, e.g. a function scope. +class RuyTraceScope { + const char* source_file_; + int source_line_; + const char* scope_name_; + + public: + RuyTraceScope(const char* source_file, int source_line, + const char* scope_name) + : source_file_(source_file), + source_line_(source_line), + scope_name_(scope_name) { + ThreadLocalTrace()->set_current_source_file(source_file_); + ThreadLocalTrace()->set_current_source_line(source_line_); + ThreadLocalTrace()->EnterScope(scope_name_); + } + ~RuyTraceScope() { + ThreadLocalTrace()->set_current_source_file(source_file_); + ThreadLocalTrace()->set_current_source_line(source_line_); + ThreadLocalTrace()->LeaveScope(scope_name_); + } +}; + +#define RUY_TRACE_SCOPE_NAME_IMPL(file, line, name) \ + RuyTraceScope ruy_trace_scope##line(file, line, name) +#define RUY_TRACE_SCOPE_NAME(name) \ + RUY_TRACE_SCOPE_NAME_IMPL(__FILE__, __LINE__, name) +#define RUY_TRACE_SCOPE \ + RUY_TRACE_SCOPE_NAME_IMPL(__FILE__, __LINE__, __FUNCTION__) + +// Helpers to trace Ruy objects. + +inline std::string str(Order o) { + return o == Order::kRowMajor ? "row-major" : "column-major"; +} + +inline std::string str(Side s) { return s == Side::kLhs ? "LHS" : "RHS"; } + +inline std::string str(const Layout& layout) { + std::string s = + formatted("%dx%d, %s", layout.rows(), layout.cols(), str(layout.order())); + int inner_size = + layout.order() == Order::kRowMajor ? layout.cols() : layout.rows(); + if (inner_size != layout.stride()) { + s += formatted(", stride=%d", layout.stride()); + } else { + s += formatted(", unstrided"); + } + return s; +} + +inline std::string str(const MatLayout& layout) { + std::string s = + formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order)); + int inner_size = layout.order == Order::kRowMajor ? layout.cols : layout.rows; + if (inner_size != layout.stride) { + s += formatted(", stride=%d", layout.stride); + } else { + s += formatted(", unstrided"); + } + return s; +} + +inline std::string str(const PMatLayout& layout) { + std::string s = + formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order)); + int inner_size = layout.order == Order::kRowMajor ? 
layout.cols : layout.rows; + if (inner_size != layout.stride) { + s += formatted(", stride=%d", layout.stride); + } else { + s += formatted(", unstrided"); + } + s += formatted(", kernel blocks: %dx%d %s", layout.kernel.rows, + layout.kernel.cols, str(layout.kernel.order)); + return s; +} + +template <typename T> +std::string str() { + return "<unknown type>"; +} +#define RUY_IMPL_STR_TYPE_STD(T) \ + template <> \ + inline std::string str<std::T>() { \ + return #T; \ + } +#define RUY_IMPL_STR_TYPE(T) \ + template <> \ + inline std::string str<T>() { \ + return #T; \ + } + +RUY_IMPL_STR_TYPE(float) +RUY_IMPL_STR_TYPE(double) +RUY_IMPL_STR_TYPE_STD(int8_t) +RUY_IMPL_STR_TYPE_STD(uint8_t) +RUY_IMPL_STR_TYPE_STD(int16_t) +RUY_IMPL_STR_TYPE_STD(uint16_t) +RUY_IMPL_STR_TYPE_STD(int32_t) +RUY_IMPL_STR_TYPE_STD(uint32_t) +RUY_IMPL_STR_TYPE_STD(int64_t) +RUY_IMPL_STR_TYPE_STD(uint64_t) + +template <typename T> +std::string str(const Matrix<T>& matrix) { + std::string s = formatted("Matrix<%s>, %s", str<T>(), str(matrix.layout())); + if (matrix.zero_point()) { + s += formatted(", zero_point=%d", static_cast<int>(matrix.zero_point())); + } + if (matrix.cache_policy() != CachePolicy::kNeverCache) { + s += + formatted(", cache_policy=%d", static_cast<int>(matrix.cache_policy())); + } + return s; +} + +inline std::string str(const Type& type) { + char c; + if (type.is_floating_point) { + c = 'f'; + } else if (type.is_signed) { + c = 'i'; + } else { + c = 'u'; + } + return formatted("%c%d", c, type.size * 8); +} + +inline std::string str(const EMat& mat) { + std::string s = + formatted("EMat, data_type=%s, %s", str(mat.data_type), str(mat.layout)); + if (mat.zero_point) { + s += formatted(", zero_point=%d", static_cast<int>(mat.zero_point)); + } + if (mat.cache_policy != CachePolicy::kNeverCache) { + s += formatted(", cache_policy=%d", static_cast<int>(mat.cache_policy)); + } + return s; +} + +inline std::string str(const PEMat& mat) { + std::string s = + formatted("PEMat, data_type=%s, %s", str(mat.data_type), str(mat.layout)); + if (mat.zero_point) { + s += formatted(", zero_point=%d", static_cast<int>(mat.zero_point)); + } + return s; +} + +inline std::string str(Path paths) { + bool first = true; + std::string s; + for (int bit = 0; bit < 16; bit++) { + Path cur_path = static_cast<Path>(1 << bit); + if ((paths & cur_path) != Path::kNone) { + if (!first) { + s += " | "; + } + first = false; + switch (cur_path) { + case Path::kNone: + continue; +#define RUY_HANDLE_PATH(p) \ + case Path::p: \ + s += #p; \ + break; + RUY_HANDLE_PATH(kStandardCpp) + RUY_HANDLE_PATH(kInternalStandardCppVariant1) + RUY_HANDLE_PATH(kInternalStandardCppVariant2) + RUY_HANDLE_PATH(kInternalStandardCppVariant3) +#if RUY_PLATFORM_ARM + RUY_HANDLE_PATH(kNeon) + RUY_HANDLE_PATH(kNeonDotprod) +#endif // RUY_PLATFORM_ARM +#if RUY_PLATFORM_X86 + RUY_HANDLE_PATH(kAvx) + RUY_HANDLE_PATH(kAvx2Fma) + RUY_HANDLE_PATH(kAvx512) +#endif // RUY_PLATFORM_X86 +#undef RUY_HANDLE_PATH + default: + fprintf(stderr, "Unhandled Path value 0x%x\n", + static_cast<int>(cur_path)); + abort(); + } + } + } + return s; +} + +// Implementation of RUY_TRACE_INFO(X) macros. 
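+//
+// These RUY_TRACE_INFO_<id> fragments are not used directly. Call sites go
+// through RUY_TRACE_INFO(id), defined at the end of this file, which records
+// the call site's __FILE__/__LINE__ on the thread-local trace and then expands
+// the matching fragment. For example (illustrative only), a call site such as
+//
+//   RUY_TRACE_INFO(TRMUL_SIMPLE_LOOP);
+//
+// appends "Entering the simple loop code path of TrMul" to the current
+// thread's trace, tagged with that source location.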
+ +#define RUY_TRACE_INFO_MUL \ + ThreadLocalTrace()->Write("CompiledPaths: %s", str(CompiledPaths)); \ + ThreadLocalTrace()->Write("LHS: %s", str(lhs)); \ + ThreadLocalTrace()->Write("RHS: %s", str(rhs)); \ + ThreadLocalTrace()->Write("Destination: %s", str(*dst)); + +#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_TRANSPOSING \ + ThreadLocalTrace()->Write("Canonicalizing to column-major destination:"); \ + ThreadLocalTrace()->Write( \ + "Swapping LHS<->RHS and flipping all storage orders."); + +#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST \ + ThreadLocalTrace()->Write("Runtime-selected path: %s", str(the_path)); \ + ThreadLocalTrace()->Write("LHS: %s", str(params->src[Side::kLhs])); \ + ThreadLocalTrace()->Write("RHS: %s", str(params->src[Side::kRhs])); \ + ThreadLocalTrace()->Write("Destination: %s", str(params->dst)); + +#define RUY_TRACE_INFO_POPULATE_TRMUL_PARAMS \ + ThreadLocalTrace()->Write( \ + "Here we have this Path as a template parameter: %s", str(ThePath)); \ + ThreadLocalTrace()->Write("PackedLhsScalar: %s", str<PackedLhsScalar>()); \ + ThreadLocalTrace()->Write("PackedRhsScalar: %s", str<PackedRhsScalar>()); \ + ThreadLocalTrace()->Write("Kernel function pointer: %p", \ + params->run_kernel); \ + ThreadLocalTrace()->Write("Kernel LHS block layout: %dx%d %s", \ + LhsKernelLayout::kRows, LhsKernelLayout::kCols, \ + str(LhsKernelLayout::kOrder)); \ + ThreadLocalTrace()->Write("Kernel RHS block layout: %dx%d %s", \ + RhsKernelLayout::kRows, RhsKernelLayout::kCols, \ + str(RhsKernelLayout::kOrder)); \ + ThreadLocalTrace()->Write("Created packed matrices:"); \ + ThreadLocalTrace()->Write("Packed LHS matrix: %s", \ + str(params->packed_matrix[Side::kLhs])); \ + ThreadLocalTrace()->Write("Packed RHS matrix: %s", \ + str(params->packed_matrix[Side::kRhs])); \ + ThreadLocalTrace()->Write("LHS packing function pointer: %p", \ + params->run_pack[Side::kLhs]); \ + ThreadLocalTrace()->Write("RHS packing function pointer: %p", \ + params->run_pack[Side::kRhs]); + +#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE \ + ThreadLocalTrace()->Write("SetRuntimeEnabledPaths forcing paths: %s", \ + str(*paths)); + +#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR \ + ThreadLocalTrace()->Write("Environment variable forcing paths: %s", \ + str(*paths)); + +#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_DETECTION \ + ThreadLocalTrace()->Write( \ + "Runtime-detected paths: %s", \ + str(*paths & ~kNonArchPathsIncludingInternalVariants)); + +#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_SHOULD_CACHE \ + ThreadLocalTrace()->Write( \ + "Caching the packed %s matrix. Already in cache: %s", str(side), \ + action == PrepackedCache::Action::kInsertedNewEntry ? 
"no" : "yes"); + +#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_NO_CACHE \ + ThreadLocalTrace()->Write("Not caching the packed %s matrix.", str(side)); + +#define RUY_TRACE_INFO_GET_TENTATIVE_THREAD_COUNT \ + ThreadLocalTrace()->Write( \ + "tentative_thread_count=%d (determined based on shape %dx%dx%d and " \ + "max_num_threads=%d)", \ + tentative_thread_count, rows, depth, cols, ctx->max_num_threads()); + +#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_TRUE \ + ThreadLocalTrace()->Write( \ + "Choosing to use the simple loop code path in TrMul because of the " \ + "linear traversal and single thread."); + +#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_FALSE \ + ThreadLocalTrace()->Write( \ + "Choosing to use the general case code path in TrMul because of: %s", \ + tentative_thread_count > 1 ? "multi-threading" \ + : "non-linear traversal order"); + +#define RUY_TRACE_INFO_TRMUL_SIMPLE_LOOP \ + ThreadLocalTrace()->Write("Entering the simple loop code path of TrMul"); + +#define RUY_TRACE_INFO_TRMUL_GENERAL_CASE \ + ThreadLocalTrace()->Write("Entering the general case code path of TrMul"); + +#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_START \ + ThreadLocalTrace()->Write("Kernel block: %dx%d", kernel_rows, kernel_cols); \ + ThreadLocalTrace()->Write( \ + "BlockMap shape: %dx%d (destination matrix shape rounded to next " \ + "kernel blocks)", \ + rows, cols); \ + ThreadLocalTrace()->Write( \ + "Rectangularness log2: %dx%d (powers of two factors bringing the shape " \ + "closest to square)", \ + rows_rectangularness_log2, cols_rectangularness_log2); \ + ThreadLocalTrace()->Write("Accumulation depth: %d", depth); \ + ThreadLocalTrace()->Write("LHS scalar type size: %d", lhs_scalar_size); \ + ThreadLocalTrace()->Write("RHS scalar type size: %d", rhs_scalar_size); \ + ThreadLocalTrace()->Write("Tentative thread count: %d", \ + tentative_thread_count); \ + ThreadLocalTrace()->Write( \ + "CPU cache params: local_cache_size=%d, last_level_cache_size=%d", \ + cpu_cache_params.local_cache_size, \ + cpu_cache_params.last_level_cache_size); \ + ThreadLocalTrace()->Write( \ + "For the sizes below, when rows!=cols, we always retain the min of the " \ + "two."); \ + ThreadLocalTrace()->Write("Kernel block size_log2: %d", kernel_size_log2); \ + ThreadLocalTrace()->Write( \ + "BlockMap size_log2: %d (destination matrix shape rounded to next " \ + "kernel blocks)", \ + size_log2); \ + ThreadLocalTrace()->Write( \ + "Now we will pick the optimal log2 of BlockMap block size"); + +#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE \ + ThreadLocalTrace()->Write( \ + "For BlockMap block size_log2 %d, score=%d (" \ + "multithreading_score=%d + cache_locality_score=%d + " \ + "kernel_amortization_score=%d)", \ + block_size_log2, score, multithreading_score, cache_locality_score, \ + kernel_amortization_score); + +#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_END \ + ThreadLocalTrace()->Write("Selecting BlockMap block size_log2: %d", \ + best_score_block_size_log2); \ + ThreadLocalTrace()->Write( \ + "BlockMap has %dx%d blocks, each of size between %dx%d and %dx%d.", \ + 1 << num_blocks_of_rows_log2, 1 << num_blocks_of_cols_log2, \ + block_map->small_block_dims[Side::kLhs], \ + block_map->small_block_dims[Side::kRhs], \ + block_map->small_block_dims[Side::kLhs] + \ + block_map->kernel_dims[Side::kLhs], \ + block_map->small_block_dims[Side::kRhs] + \ + block_map->kernel_dims[Side::kRhs]); \ + ThreadLocalTrace()->Write( \ + "The first %d rows of blocks have %d rows, the remaining ones have %d " \ + "rows ", \ 
+ block_map->large_blocks[Side::kLhs], \ + block_map->small_block_dims[Side::kLhs] + \ + block_map->kernel_dims[Side::kLhs], \ + block_map->small_block_dims[Side::kLhs]); \ + ThreadLocalTrace()->Write( \ + "The first %d columns of blocks have %d columns, the remaining ones " \ + "have %d columns ", \ + block_map->large_blocks[Side::kRhs], \ + block_map->small_block_dims[Side::kRhs] + \ + block_map->kernel_dims[Side::kLhs], \ + block_map->small_block_dims[Side::kRhs]); \ + ThreadLocalTrace()->Write( \ + "Traversal order: %s", \ + block_map->traversal_order == BlockMapTraversalOrder::kLinear ? "linear" \ + : block_map->traversal_order == BlockMapTraversalOrder::kFractalZ \ + ? "fractal Z-curve" \ + : block_map->traversal_order == BlockMapTraversalOrder::kFractalU \ + ? "fractal U-curve" \ + : block_map->traversal_order == BlockMapTraversalOrder::kFractalHilbert \ + ? "fractal Hilbert curve" \ + : nullptr); \ + ThreadLocalTrace()->Write("Finalized thread count: %d", \ + block_map->thread_count); + +#define RUY_TRACE_SET_THEAD_ID(thread_id) \ + ThreadLocalTrace()->set_thread_id(thread_id); + +#define RUY_TRACE_INFO_TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS \ + ThreadLocalTrace()->Write( \ + "Block #%d is at position (%d, %d) in the BlockMap.", block_id, \ + block[Side::kLhs], block[Side::kRhs]); \ + ThreadLocalTrace()->Write( \ + "Block #%d has shape %dx%d and starts at position (%d, %d) in the " \ + "destination matrix.", \ + block_id, end[Side::kLhs] - start[Side::kLhs], \ + end[Side::kRhs] - start[Side::kRhs], start[Side::kLhs], \ + start[Side::kRhs]); \ + ThreadLocalTrace()->Write( \ + "Block #%d depends on LHS panel #%d and RHS panel #%d.", block_id, \ + block[Side::kLhs], block[Side::kRhs]); + +#define RUY_TRACE_INFO_TRYPACK_PACKING \ + ThreadLocalTrace()->Write( \ + "%s panel #%d is not already packed. Packing it now.", str(side), \ + block); + +#define RUY_TRACE_INFO_TRYPACK_ANOTHER_THREAD_PACKING \ + if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \ + ThreadLocalTrace()->Write( \ + "%s panel #%d is currently being packed by another thread.", \ + str(side), block); \ + } + +#define RUY_TRACE_INFO_TRYPACK_PREVIOUSLY_PACKED \ + if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \ + ThreadLocalTrace()->Write("%s panel #%d had previously been packed.", \ + str(side), block); \ + } + +#define RUY_TRACE_INFO_TRYPACK_PACKED_BY_ANOTHER_THREAD \ + ThreadLocalTrace()->Write( \ + "%s panel #%d has just been packed by another thread.", str(side), \ + block); + +#define RUY_TRACE_INFO_ENSURE_PACKED_ENTER_RUN_AHEAD \ + if (!ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \ + ThreadLocalTrace()->set_is_in_run_ahead_packing_loop(true); \ + ThreadLocalTrace()->Write( \ + "We're blocked on other threads packing the panels that we need. 
" \ + "Packing some other panels while we wait..."); \ + } + +#define RUY_TRACE_INFO_ENSURE_PACKED_END \ + if (ThreadLocalTrace()->is_in_run_ahead_packing_loop()) { \ + ThreadLocalTrace()->set_is_in_run_ahead_packing_loop(false); \ + ThreadLocalTrace()->Write( \ + "Other threads have finished packing what we were waiting for."); \ + } + +#define RUY_TRACE_INFO_RUN_PACK \ + ThreadLocalTrace()->Write("Path: %s", str(ThePath)); \ + ThreadLocalTrace()->Write("Packing panel consisting of columns [%d, %d)", \ + start_col, end_col); \ + ThreadLocalTrace()->Write("Source: columns [%d, %d) of %s", start_col, \ + end_col, str(src_matrix)); \ + ThreadLocalTrace()->Write("Destination: columns [%d, %d) of %s", start_col, \ + end_col, str(*packed_matrix)); \ + if (end_col > src_matrix.layout.cols) { \ + ThreadLocalTrace()->Write( \ + "This runs past the last column of the source matrix. Padding as " \ + "needed."); \ + } \ + if (packed_matrix->layout.rows > src_matrix.layout.rows) { \ + ThreadLocalTrace()->Write( \ + "The packed matrix has more rows than the source matrix due to " \ + "rounding up to the kernel block size. Padding as needed."); \ + } + +#define RUY_TRACE_INFO_RUN_KERNEL \ + { \ + ThreadLocalTrace()->Write("Path: %s", str(KernelArgs<KernelType>::kPath)); \ + int lhs_cols = end[Side::kLhs] - start[Side::kLhs]; \ + int rhs_cols = end[Side::kRhs] - start[Side::kRhs]; \ + int kernel_lhs_cols = src[Side::kLhs].layout.kernel.cols; \ + int kernel_rhs_cols = src[Side::kRhs].layout.kernel.cols; \ + ThreadLocalTrace()->Write("LHS: columns [%d, %d) of %s", \ + start[Side::kLhs], end[Side::kLhs], \ + str(src[Side::kLhs])); \ + ThreadLocalTrace()->Write("RHS: columns [%d, %d) of %s", \ + start[Side::kRhs], end[Side::kRhs], \ + str(src[Side::kRhs])); \ + ThreadLocalTrace()->Write("Destination: block [%d, %d)x[%d, %d) of %s", \ + start[Side::kLhs], end[Side::kLhs], \ + start[Side::kRhs], end[Side::kRhs], str(*dst)); \ + if (end[Side::kLhs] > dst->layout.rows || \ + end[Side::kRhs] > dst->layout.cols) { \ + ThreadLocalTrace()->Write( \ + "This runs over the destination matrix boundaries. The kernel will " \ + "internally clamp stores to avoid overruns."); \ + } \ + ThreadLocalTrace()->Write( \ + "The kernel's inner loop only produces a %dx%d block, so the " \ + "kernel's outer loops will run %dx%d times.", \ + kernel_lhs_cols, kernel_rhs_cols, lhs_cols / kernel_lhs_cols, \ + rhs_cols / kernel_rhs_cols); \ + } + +#define RUY_TRACE_INFO_THREAD_FUNC_IMPL_WAITING \ + ThreadLocalTrace()->Write("Waiting for a task..."); + +#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK \ + ThreadLocalTrace()->Write("Sending task #%d to a worker thread...", i); + +#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD \ + ThreadLocalTrace()->Write("Running task #0 on the current thread..."); + +#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_WAITING_FOR_THREADS \ + ThreadLocalTrace()->Write("Waiting for worker threads to finish.."); + +#define RUY_TRACE_INFO(id) \ + [=]() { \ + ThreadLocalTrace()->set_current_source_file(__FILE__); \ + ThreadLocalTrace()->set_current_source_line(__LINE__); \ + RUY_TRACE_INFO_##id \ + }() + +} // namespace ruy + +#else + +// Vacuous implementation when RUY_TRACE is not defined. 
+#define RUY_TRACE_SCOPE_NAME(name) +#define RUY_TRACE_SCOPE +#define RUY_TRACE_SET_THEAD_ID(thread_id) +#define RUY_TRACE_INFO(id) + +#endif + +#endif // RUY_RUY_TRACE_H_ diff --git a/ruy/trmul.cc b/ruy/trmul.cc new file mode 100644 index 0000000..9345f0c --- /dev/null +++ b/ruy/trmul.cc @@ -0,0 +1,397 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The 'middle-end' in ruy. See TrMul function comment. + +#include "ruy/trmul.h" + +#include <algorithm> +#include <atomic> +#include <cstdint> +#include <cstring> +#include <memory> +#include <vector> + +#include "ruy/allocator.h" +#include "ruy/block_map.h" +#include "ruy/check_macros.h" +#include "ruy/cpu_cache_params.h" +#include "ruy/cpuinfo.h" +#include "ruy/ctx.h" +#include "ruy/mat.h" +#include "ruy/matrix.h" +#include "ruy/mul_params.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/thread_pool.h" +#include "ruy/trace.h" +#include "ruy/tune.h" + +namespace ruy { + +namespace { + +// Enum to track the packingstatus of a block of the LHS or RHS matrix. +enum class PackingStatus : std::uint8_t { + kNotStarted, // No thread has started packing this block yet. + kInProgress, // Some thread is currently packing this block. + kFinished // This block has already been packed. +}; + +// TrMulTask is the task that a ruy thread runs to perform the TrMul operation. +class TrMulTask final : public Task { + public: + TrMulTask(TrMulParams* params, const BlockMap& block_map, + std::atomic<int>* atomic_block_id, int thread_id, bool need_atomics, + SidePair<std::atomic<PackingStatus>*> packing_status, + TuningResolver* tuning_resolver, Allocator* local_allocator, + CpuInfo* cpuinfo) + : params_(params), + block_map_(block_map), + atomic_block_id_(atomic_block_id), + thread_id_(thread_id), + need_atomics_(need_atomics), + packing_status_(packing_status), + tuning_resolver_(tuning_resolver), + local_allocator_(local_allocator), + local_already_packed_{nullptr, nullptr}, + cpuinfo_(cpuinfo) {} + + // Thread main function. This is one thread's share of the TrMul work. + void Run() override { + RUY_TRACE_SCOPE_NAME("TrMulTask::Run"); + RUY_TRACE_SET_THEAD_ID(thread_id_); + // Allocate and initialize `local_packed`. + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params_->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map_); + local_allocator_->Allocate(size, &local_already_packed_[side]); + memset(local_already_packed_[side], 0, size * sizeof(bool)); + } + } + + const Tuning tuning = tuning_resolver_->Resolve(cpuinfo_); + const int num_blocks = NumBlocks(block_map_); + + // Each thread starts by initially reserving the block whose id + // is the thread id. + int block_id = thread_id_; + // Loop until all blocks have been computed. 
+ while (block_id < num_blocks) { + RUY_TRACE_SCOPE_NAME("Main loop iteration"); + // Reserve the next block to handle, hiding the latency of this atomic op. + const int next_block_id = + atomic_block_id_->fetch_add(1, std::memory_order_relaxed); + // Get coordinates of the current block to handle, in "block space". + SidePair<int> block; + GetBlockByIndex(block_map_, block_id, &block); + // Get coordinates of the current block to handle, in matrix space. + SidePair<int> start, end; + GetBlockMatrixCoords(block_map_, block, &start, &end); + RUY_TRACE_INFO(TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS); + // Maybe pack the current LHS/RHS block, if not already packed. + EnsurePacked(block, start, end, tuning); + // Actually do matrix multiplication work + params_->RunKernel(tuning, start, end); + // Move on to the next block as obtained by the atomic increment + // at the start of this while loop iteration. + block_id = next_block_id; + } + + local_allocator_->FreeAll(); + } + + private: + // Tries to pack a block, without blocking. + // If the block was already packed, returns true. + // If the block was not started packing, packs it and returns true. + // If the block was being packed by another thread, returns false. + bool TryPack(Side side, int block, int start, int end, Tuning tuning) { + if (params_->is_prepacked[side]) { + return true; + } + if (!local_already_packed_[side][block]) { + if (need_atomics_) { + // Explanation of this compare_exchange_strong operation: + // This atomically performs all of the following: + // 1. Read `status` with "acquire" memory order. + // * That this read uses "acquire" is because both memory orders + // specified have "acquire" as their read-component. + // 2. Compare (bitwise) with `exchanged_status`. + // 3. If equal, stores the value kInProgress to `status` with "release" + // memory order, and returns true, so we take this 'if' branch. + // * That this store uses "release" is because of the _rel part in + // memory_order_acq_rel passed as the first memory order argument. + // 4. If not equal, stores the loaded value of `status` to + // `exchanged_status` with "relaxed" semantics, and returns false, + // so we take the 'else' branch. + // * That this store uses "relaxed" is because the second memory + // order argument, memory_order_acquire, implies no particular + // store semantics. "relaxed" is acceptable here because this + // stores to a local stack variable. + // + // Rationale for compare_exchange_strong as opposed to + // compare_exchange_weak: + // The spurious-failure case with compare_exchange_weak will actually + // happen a lot here, because the atomic 'status' bytes are stored + // contiguously in arrays and neighboring values will be accessed + // by multiple threads concurrently. On a typical ARM CPU, an exclusives + // reservation granule is 64 bytes, so a lot of false-sharing may + // happen. Using compare_exchange_weak would thus result in often having + // TryPack return 'false' when it could instead have done the packing + // work and returned 'true'. Heuristically, that is not a good thing. + // Moreover, this changes the TryPack contract, loosening it and making + // it harder for the caller to reason about. Finally, the overhead of + // atomic operations is mitigated by the enclosing check on + // local_already_packed, so maybe the overhead of + // compare_exchange_strong isn't such a problem. But we don't really + // know for sure, that would be interesting to experiment more with. 
+        PackingStatus exchanged_status = PackingStatus::kNotStarted;
+        std::atomic<PackingStatus>& status = packing_status_[side][block];
+        if (status.compare_exchange_strong(
+                exchanged_status, PackingStatus::kInProgress,
+                std::memory_order_acq_rel, std::memory_order_acquire)) {
+          // In this branch, the status was kNotStarted and we just atomically
+          // changed it to kInProgress as we are about to handle the packing
+          // ourselves.
+          RUY_TRACE_INFO(TRYPACK_PACKING);
+          params_->RunPack(side, tuning, start, end);
+          status.store(PackingStatus::kFinished, std::memory_order_release);
+        } else if (exchanged_status == PackingStatus::kInProgress) {
+          // Another thread is currently packing this block.
+          RUY_TRACE_INFO(TRYPACK_ANOTHER_THREAD_PACKING);
+          return false;
+        } else {
+          RUY_TRACE_INFO(TRYPACK_PACKED_BY_ANOTHER_THREAD);
+        }
+        RUY_DCHECK(status.load(std::memory_order_acquire) ==
+                   PackingStatus::kFinished);
+      } else {
+        // Single-threaded case: no need for expensive atomics,
+        // local_already_packed is the truth already.
+        params_->RunPack(side, tuning, start, end);
+      }
+      local_already_packed_[side][block] = true;
+    } else {
+      RUY_TRACE_INFO(TRYPACK_PREVIOUSLY_PACKED);
+    }
+    return true;
+  }
+
+  // Ensures that both the LHS and RHS blocks required by the specified block
+  // are packed. In the event that they are already being packed by another
+  // thread, this function may perform the packing of some other block while
+  // waiting for that other thread to finish packing the requested block.
+  void EnsurePacked(const SidePair<int>& block, const SidePair<int>& start,
+                    const SidePair<int>& end, Tuning tuning) {
+#if RUY_OPT(PACK_AHEAD)
+    SidePair<int> next_runahead_block{block[Side::kLhs] + 1,
+                                      block[Side::kRhs] + 1};
+    Side next_runahead_side = Side::kLhs;
+#endif
+    while (true) {
+      bool both_sides_packed = true;
+      for (Side side : {Side::kLhs, Side::kRhs}) {
+        both_sides_packed &=
+            TryPack(side, block[side], start[side], end[side], tuning);
+      }
+      if (both_sides_packed) {
+        break;
+      }
+#if RUY_OPT(PACK_AHEAD)
+      RUY_TRACE_INFO(ENSURE_PACKED_ENTER_RUN_AHEAD);
+      const Side runahead_side = next_runahead_side;
+      const int runahead_block = next_runahead_block[runahead_side];
+      next_runahead_side = OtherSide(next_runahead_side);
+      if (runahead_block >= NumBlocksPerSide(runahead_side, block_map_)) {
+        continue;
+      }
+      int runahead_block_start, runahead_block_end;
+      GetBlockMatrixCoords(runahead_side, block_map_, runahead_block,
+                           &runahead_block_start, &runahead_block_end);
+      TryPack(runahead_side, runahead_block, runahead_block_start,
+              runahead_block_end, tuning);
+      next_runahead_block[runahead_side] = runahead_block + 1;
+#endif
+    }
+    RUY_TRACE_INFO(ENSURE_PACKED_END);
+  }
+
+  TrMulParams* params_;
+  const BlockMap& block_map_;
+  std::atomic<int>* atomic_block_id_;
+  int thread_id_;
+  bool need_atomics_;
+  SidePair<std::atomic<PackingStatus>*> packing_status_;
+  TuningResolver* tuning_resolver_;
+  Allocator* local_allocator_;
+
+  // Local indicators of packedness to avoid the overhead of atomic ops.
+  SidePair<bool*> local_already_packed_;
+
+  CpuInfo* cpuinfo_;
+};
+
+int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
+#if RUY_PLATFORM_EMSCRIPTEN
+  // b/139927184, std::thread constructor raises exception
+  return 1;
+#endif
+  RUY_TRACE_SCOPE;
+  // Empirically determined rule for reasonable number of
+  // threads to use. This is proportional to the number of arithmetic ops
+  // in this Mul (product of the 3 sizes).
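// Illustrative worked instance of the heuristic above (editorial example only,
// not from the ruy sources): with the kDivisorLog2 = 15 constant defined just
// below, a 256x256x256 Mul gives ceil_log2(256)*3 - 15 = 8+8+8-15 = 9, so the
// tentative thread count is min(1 << 9, ctx->max_num_threads()).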
+ static constexpr int kDivisorLog2 = 15; + const int guess_log2 = std::max( + 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); + int tentative_thread_count = + std::min(1 << guess_log2, ctx->max_num_threads()); + RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT); + return tentative_thread_count; +} + +bool GetUseSimpleLoop(int tentative_thread_count, int rows, int cols, int depth, + int lhs_scalar_size, int rhs_scalar_size, + const CpuCacheParams& cpu_cache_params) { + RUY_TRACE_SCOPE; + if (tentative_thread_count == 1) { + if (IsObviouslyLinearTraversal(rows, cols, depth, lhs_scalar_size, + rhs_scalar_size, cpu_cache_params)) { + RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_TRUE); + return true; + } + } + RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_FALSE); + return false; +} + +} // namespace + +// TrMul is the ruy middle-end. It contains the high-level logic to perform +// a ruy::Mul's work, down to calls to back-end Kernel and Pack functions. +// This includes determining how many threads to use, computing the BlockMap, +// executing tasks on a thread-pool. The TrMul function itself runs on the main +// thread, the code that is potentially running on worker threads is in +// TrMulTask::Run(). +void TrMul(Ctx* ctx, TrMulParams* params) { + RUY_TRACE_SCOPE; + profiler::ScopeLabel label( + "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", + static_cast<int>(params->path), ctx->max_num_threads(), + params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); + + PEMat& packed_lhs = params->packed_matrix[Side::kLhs]; + PEMat& packed_rhs = params->packed_matrix[Side::kRhs]; + EMat& lhs = params->src[Side::kLhs]; + EMat& rhs = params->src[Side::kRhs]; + + const int rows = lhs.layout.cols; + const int cols = rhs.layout.cols; + const int depth = lhs.layout.rows; + + const int tentative_thread_count = + GetTentativeThreadCount(ctx, rows, cols, depth); + const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams(); + + // Case of running this TrMul as a simple loop. + // This is a good place to start reading this function: all the rest + // of this function is just an optimized, but functionally equivalent, + // version of that. + if (GetUseSimpleLoop(tentative_thread_count, rows, cols, depth, + lhs.data_type.size, rhs.data_type.size, + cpu_cache_params)) { + profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); + Tuning tuning = ctx->GetMainThreadTuning(); + RUY_TRACE_INFO(TRMUL_SIMPLE_LOOP); + + const SidePair<int> origin{0, 0}; + const SidePair<int> rounded_dims{packed_lhs.layout.cols, + packed_rhs.layout.cols}; + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + params->RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } + params->RunKernel(tuning, origin, rounded_dims); + return; + } + + profiler::ScopeLabel label_general("TrMulImpl, general case"); + RUY_TRACE_INFO(TRMUL_GENERAL_CASE); + Allocator* main_allocator = ctx->GetMainAllocator(); + + // Initialize block map. + BlockMap block_map; + MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, + packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, + packed_lhs.data_type.size, packed_rhs.data_type.size, + tentative_thread_count, cpu_cache_params, &block_map); + + // Initialize per-thread state. 
+ const int thread_count = block_map.thread_count; + const bool need_atomics = thread_count > 1; + ctx->EnsureThreadSpecificResources(thread_count); + for (int i = 0; i < thread_count; i++) { + ctx->GetThreadSpecificTuningResolver(i)->SetTuning(ctx->explicit_tuning()); + } + + // In the need_atomics case, allocate and initialize atomic values tracking + // the packing status of blocks. + SidePair<std::atomic<PackingStatus>*> packing_status{nullptr, nullptr}; + if (need_atomics) { + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map); + main_allocator->Allocate(size, &packing_status[side]); + for (int i = 0; i < size; i++) { + packing_status[side][i].store(PackingStatus::kNotStarted, + std::memory_order_relaxed); + } + } + } + } + + // Create the atomic block id, allocate it using Allocator so that + // we get the alignment ensuring that it sits alone in its exclusives + // reservation granule. + std::atomic<int>* atomic_block_id; + main_allocator->Allocate(1, &atomic_block_id); + + // Create task objects. + TrMulTask* tasks; + main_allocator->Allocate(thread_count, &tasks); + + atomic_block_id->store(thread_count); + + for (int i = 0; i < thread_count; i++) { + auto* allocator = ctx->GetThreadSpecificAllocator(i); + auto* tuning_resolver = ctx->GetThreadSpecificTuningResolver(i); + new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, + need_atomics, packing_status, tuning_resolver, + allocator, ctx->mutable_cpuinfo()); + } + + // Do the computation. + ctx->mutable_thread_pool()->Execute(thread_count, tasks); + + // Finish up. + for (int i = 0; i < thread_count; i++) { + tasks[i].~TrMulTask(); + } +} + +} // namespace ruy diff --git a/ruy/trmul.h b/ruy/trmul.h new file mode 100644 index 0000000..b4af287 --- /dev/null +++ b/ruy/trmul.h @@ -0,0 +1,39 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// As a matrix multiplication library, Ruy offers a Mul entry point, performing +// matrix multiplication. For implementation purposes, it is much nicer to +// be dealing with the transpose-and-multiply operation, doing +// Destination = Transpose(LHS) * RHS +// Indeed, the latter is performing dot-products between the *columns* of LHS +// and the columns of RHS, whereas a plain matrix multiplication is performing +// dot-products between the *rows* of LHS and the columns of RHS. +// That is why TrMul is nicer to implement, allowing for a more symmetric +// treatment of LHS and RHS. 
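To make the header comment above concrete, here is a minimal reference sketch, not part of ruy; the function name, the use of std::vector, and the column-major layout are illustrative assumptions. It shows the transpose-and-multiply operation being described: each destination entry is a dot-product of a *column* of LHS with a *column* of RHS, which is what makes the treatment of the two sides symmetric.

#include <vector>

// Naive TrMul: Destination = Transpose(LHS) * RHS, all matrices column-major,
// i.e. element (row, col) of a matrix with `rows` rows is at data[col * rows + row].
// LHS is depth x dst_rows, RHS is depth x dst_cols, Destination is dst_rows x dst_cols.
void NaiveTrMul(const std::vector<float>& lhs, const std::vector<float>& rhs,
                int dst_rows, int dst_cols, int depth, std::vector<float>* dst) {
  for (int c = 0; c < dst_cols; ++c) {
    for (int r = 0; r < dst_rows; ++r) {
      float acc = 0.0f;
      for (int k = 0; k < depth; ++k) {
        // Column r of LHS dotted with column c of RHS.
        acc += lhs[r * depth + k] * rhs[c * depth + k];
      }
      (*dst)[c * dst_rows + r] = acc;
    }
  }
}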
+ +#ifndef RUY_RUY_TRMUL_H_ +#define RUY_RUY_TRMUL_H_ + +#include "ruy/ctx.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +struct ContextInternal; +void TrMul(Ctx* ctx, TrMulParams* params); + +} // namespace ruy + +#endif // RUY_RUY_TRMUL_H_ diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h new file mode 100644 index 0000000..e68d909 --- /dev/null +++ b/ruy/trmul_params.h @@ -0,0 +1,92 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_TRMUL_PARAMS_H_ +#define RUY_RUY_TRMUL_PARAMS_H_ + +#include <cstdint> + +#include "ruy/mat.h" +#include "ruy/mul_params.h" +#include "ruy/path.h" +#include "ruy/side_pair.h" +#include "ruy/tune.h" + +namespace ruy { + +using RunKernelFn = void(Tuning, const SidePair<PEMat>&, const void*, + const SidePair<int>&, const SidePair<int>&, EMat*); + +using RunPackFn = void(Tuning, const EMat&, PEMat*, int, int); + +// This should not be needed since we require c++14, where std::max is already +// constexpr, but TensorFlow continuous integration uses Ubuntu 16 with a +// libstdc++ that does not support that. +constexpr int constexpr_max(int a, int b) { return a > b ? a : b; } + +// Under-estimating these values would be caught by a static_assert in +// StoreMulParams. Over-estimating these values cannot easily be caught, and +// would cause unnecessary inflation of the TrMulParams data structure. +constexpr int kMaxMulParamsAlignment = + constexpr_max(alignof(void*), alignof(double)); +constexpr int kMaxMulParamsSizeFloatingPointCase = + sizeof(MulParams<double, double>); +constexpr int kMaxMulParamsSizeRawIntegerCase = + sizeof(MulParams<std::int32_t, std::int32_t>); +constexpr int kMaxMulParamsSizeQuantizedIntegerCase = + sizeof(MulParams<std::int32_t, std::int16_t>); +constexpr int kMaxMulParamsSize = + constexpr_max(kMaxMulParamsSizeFloatingPointCase, + constexpr_max(kMaxMulParamsSizeRawIntegerCase, + kMaxMulParamsSizeQuantizedIntegerCase)); + +// OK to adjust as needed, but we want to avoid unnecessarily inflating that. +static_assert(kMaxMulParamsSize <= 32, ""); + +// Type-erased data needed for implementing TrMul. +struct TrMulParams { + TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {} + // Helper functions for invoking the function pointers. + void RunPack(Side side, Tuning tuning, int start, int end) { + run_pack[side](tuning, src[side], &packed_matrix[side], start, end); + } + void RunKernel(Tuning tuning, const SidePair<int>& start, + const SidePair<int>& end) { + run_kernel(tuning, packed_matrix, mul_params_bytes, start, end, &dst); + } + + // path id, can be useful info for some fine-tuning, e.g. to guess reasonable + // cache sizes when not runtime-detectable. + Path path; + + // Function pointers to type-erased entry points for kernels and packers. + SidePair<RunPackFn*> run_pack; + RunKernelFn* run_kernel = nullptr; + + // Matrices and packed matrices. 
+ SidePair<EMat> src; + EMat dst; + SidePair<PEMat> packed_matrix; + SidePair<bool> is_prepacked; + + // Bytes underlying the MulParams, used as type-erased storage for MulParams + // data as it isn't used until we reach the kernel code, where it is casted + // back to the original MulParams type. + alignas(kMaxMulParamsAlignment) char mul_params_bytes[kMaxMulParamsSize]; +}; + +} // namespace ruy + +#endif // RUY_RUY_TRMUL_PARAMS_H_ diff --git a/ruy/tune.cc b/ruy/tune.cc new file mode 100644 index 0000000..1f615bf --- /dev/null +++ b/ruy/tune.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/tune.h" + +#include <algorithm> +#include <cstdint> + +#include "ruy/cpuinfo.h" + +namespace ruy { + +Tuning TuningResolver::ResolveNow(CpuInfo* cpuinfo) { + return cpuinfo->CurrentCpuIsA55ish() ? Tuning::kA55ish : Tuning::kGeneric; +} + +TuningResolver::TuningResolver() + : expiry_duration_(DurationFromMilliseconds(250)) {} + +Tuning TuningResolver::Resolve(CpuInfo* cpuinfo) { + if (unresolved_tuning_ != Tuning::kAuto) { + return unresolved_tuning_; + } + TimePoint new_timepoint = CoarseNow(); + if (last_resolved_tuning_ != Tuning::kAuto && + (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { + return last_resolved_tuning_; + } + last_resolved_timepoint_ = new_timepoint; + last_resolved_tuning_ = ResolveNow(cpuinfo); + return last_resolved_tuning_; +} + +} // namespace ruy diff --git a/ruy/tune.h b/ruy/tune.h new file mode 100644 index 0000000..c9beed9 --- /dev/null +++ b/ruy/tune.h @@ -0,0 +1,117 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library doing minimal CPU detection to decide what to tune asm code for. +// +// # Tuning vs Path +// +// Tunings are merely local variations of optimized code paths, that are +// drop-in replacements for each other --- the input and output data layouts +// are identical. By contrast, what ruy calls a Path dictates its own +// data layouts. For example, Path::kNeonDotprod will use different +// layouts compared to Path::kNeon; but within each, different tunings +// will share that same layout. 
+// +// # Tuning is for now only based on 1 bit: Generic / A55ish +// +// In practice, each of our asm code paths only needs one bit information to +// decide on tuning: whether the CPU is out-of-order or in-order. +// That is because out-of-order CPUs are by definition relatively insensitive +// to small-scale asm details (which is what "tuning" is about); and for each +// asm code path, there tends to be one main in-order CPU architecture that +// we focus our tuning effort on. Examples: +// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) +// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) +// +// Because having tuned code paths is a compromise of efficiency gains +// versus implementation effort and code size, we are happy to stop at just this +// single bit of information, Generic / A55ish, at least in the current CPU +// landscape. This could change in the future. +#ifndef RUY_RUY_TUNE_H_ +#define RUY_RUY_TUNE_H_ + +#include "ruy/cpuinfo.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/time.h" + +namespace ruy { + +enum class Tuning { + // kAuto means please use auto-detection. It's the default in the + // user-visible parts (see Context). It's meant to be resolved to an + // actual tuning at some point by means of TuningResolver. + kAuto, + // Use code not tuned for any particular CPU, typically performing well + // on out-of-order cores that don't require as much tuning. + kGeneric, + // Use code tuned for "Cortex-A55-ish" CPUs, by which we mean mostly: + // A53, A55r0 (pre-dotprod), A55r1 (with dotprod). These CPUs have in common + // that they are in-order CPU cores with largely similar requirements of code + // tuning. The most important such requirement is to use only 64-bit loads + // to maximize dual-issuing. + // + // A55r1 differs from A55r0 and A53 in that it dual-issues 64-bit NEON loads + // whereas A55r0 and A53 require using non-NEON ARM 64-bit loads together with + // INS instructions to insert 64bit lanes into NEON registers. However, since + // A55r1 supports dotprod unlike A55r0 and A53, they are not using the same + // kernels in practice anyway, so there was no need to distinguish them with + // separate Tuning values. + kA55ish +}; + +// Why a TuningResolver class? +// +// Ideally, this Library would offer a single function, +// Tuning GetCurrentCPUTuning(); +// +// However, determining information about the current CPU is not necessarily +// cheap, so we currently cache that and only invalidate/reevaluate after +// a fixed amount of time. This need to store state is why this library +// has to expose a class, TuningResolver, not just a function. +class TuningResolver { + public: + TuningResolver(); + + // Allows the user to specify an explicit Tuning value, bypassing auto + // detection; or to specify Tuning::kAuto, reverting to auto detection. + void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; } + + // Get an actual tuning --- that is the function that this class wanted to be. + Tuning Resolve(CpuInfo* cpuinfo); + + private: + TuningResolver(const TuningResolver&) = delete; + + // Perform the tuning resolution now. That may typically use EvalRatio and + // ThresholdRatio, but an implementation may use a different approach instead. + Tuning ResolveNow(CpuInfo* cpuinfo); + + // The tuning as specified by the user, before actual resolution happens + // i.e. before querying any specifics of the current CPU. + // The default value kAuto means try to auto-detect. 
Other values mean + // bypass auto-detect, use explicit value instead. See SetTuning(). + Tuning unresolved_tuning_ = Tuning::kAuto; + // Cached last resolved tuning. + Tuning last_resolved_tuning_ = Tuning::kAuto; + // Timepoint of cached last resolved tuning, for invalidation purposes. + TimePoint last_resolved_timepoint_; + // Cached last resolved tunings that are older than this age are invalid. + const Duration expiry_duration_; +}; + +} // namespace ruy + +#endif // RUY_RUY_TUNE_H_ diff --git a/ruy/tune_test.cc b/ruy/tune_test.cc new file mode 100644 index 0000000..c5f2342 --- /dev/null +++ b/ruy/tune_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/tune.h" + +#include <chrono> // NOLINT(build/c++11) +#include <thread> // NOLINT(build/c++11) + +#include "ruy/cpuinfo.h" +#include "ruy/gtest_wrapper.h" + +namespace ruy { +namespace { + +TEST(TuneTest, TuneTest) { + TuningResolver tuning_resolver; + CpuInfo cpuinfo; + ASSERT_FALSE(tuning_resolver.Resolve(&cpuinfo) == Tuning::kAuto); + // 1 second is likely higher than TuningResolver's internal cache expiry, + // exercising the logic invalidating earlier tuning resolutions. + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_FALSE(tuning_resolver.Resolve(&cpuinfo) == Tuning::kAuto); + + tuning_resolver.SetTuning(Tuning::kAuto); + +#ifdef RUY_IMPLEMENT_TUNING + for (auto tuning : {Tuning::kGeneric, Tuning::kA55ish}) { + tuning_resolver.SetTuning(tuning); + ASSERT_TRUE(tuning_resolver.Resolve(&cpuinfo) == tuning); + // See above comment about 1 second. + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_TRUE(tuning_resolver.Resolve(&cpuinfo) == tuning); + } +#endif +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/validate.h b/ruy/validate.h new file mode 100644 index 0000000..b164530 --- /dev/null +++ b/ruy/validate.h @@ -0,0 +1,75 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Front-end validation code, see the Validate function. 
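As a quick numeric illustration of the int8 overflow hazard that the ValidateZeroPoints function below guards against (an illustrative sketch, not part of ruy): the pre-dotprod NEON int8 kernels accumulate two int8*int8 products into an int16 lane, and the worst case, with both zero_points equal to -128, no longer fits.

#include <cstdint>
#include <iostream>

int main() {
  const int product = -128 * -128;            // 16384: fine on its own
  const int accumulated = product + product;  // 32768: one past the int16 max
  std::cout << accumulated << " > " << INT16_MAX << std::endl;  // 32768 > 32767
  return 0;
}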
+
+#ifndef RUY_RUY_VALIDATE_H_
+#define RUY_RUY_VALIDATE_H_
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/mat.h"
+#include "ruy/mul_params.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+namespace detail {
+
+template <typename Scalar>
+void CheckZeroPoint(Scalar zero_point) {
+  if (std::is_floating_point<Scalar>::value) {
+    RUY_DCHECK(!zero_point);
+  }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar>
+void ValidateZeroPoints(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
+                        DstScalar dst_zero_point) {
+  CheckZeroPoint(lhs_zero_point);
+  CheckZeroPoint(rhs_zero_point);
+  CheckZeroPoint(dst_zero_point);
+
+  // Guard against the case when both LHS and RHS zero_points are equal to
+  // the minimum representable value. In that case, padding with zero_point
+  // values will generate the bad case for fast int8 kernels on NEON
+  // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
+  // into an int16: this is safe except in the bad case -128*-128 + -128*-128.
+  // See b/131609283. This only affects the kNeon path but we ban this for all
+  // paths in order for ruy to have the same supported parameter space
+  // on all paths.
+  // We disable this check for now for the case of LhsScalar==RhsScalar==uint8
+  // for backward compatibility with gemmlowp. The issue is still relevant
+  // because we convert from uint8 to int8 for the backend kernels.
+  if (!std::is_same<LhsScalar, uint8_t>::value ||
+      !std::is_same<RhsScalar, uint8_t>::value) {
+    RUY_DCHECK(lhs_zero_point != std::numeric_limits<LhsScalar>::lowest() ||
+               rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
+  }
+}
+
+}  // namespace detail
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar>
+void Validate(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+              const Mat<DstScalar>& dst) {
+  detail::ValidateZeroPoints(lhs.zero_point, rhs.zero_point, dst.zero_point);
+}
+
+}  // namespace ruy
+
+#endif  // RUY_RUY_VALIDATE_H_
diff --git a/ruy/wait.cc b/ruy/wait.cc
new file mode 100644
index 0000000..fc33832
--- /dev/null
+++ b/ruy/wait.cc
@@ -0,0 +1,44 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <chrono>  // NOLINT(build/c++11)
+
+namespace ruy {
+
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
+          std::condition_variable* condvar, std::mutex* mutex) {
+  // First, the trivial case where the `condition` is already true.
+  if (condition()) {
+    return;
+  }
+
+  // Then, if spin_duration is nonzero, try busy-waiting.
+  if (spin_duration.count() > 0) {
+    const TimePoint wait_start = Now();
+    while (Now() - wait_start < spin_duration) {
+      if (condition()) {
+        return;
+      }
+    }
+  }
+
+  // Finally, do real passive waiting.
+  std::unique_lock<std::mutex> lock(*mutex);
+  condvar->wait(lock, condition);
+}
+
+}  // namespace ruy
diff --git a/ruy/wait.h b/ruy/wait.h
new file mode 100644
index 0000000..24f9912
--- /dev/null
+++ b/ruy/wait.h
@@ -0,0 +1,68 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_WAIT_H_
+#define RUY_RUY_WAIT_H_
+
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <functional>
+#include <mutex>  // NOLINT(build/c++11)
+
+#include "ruy/time.h"
+
+namespace ruy {
+
+// Waits until some evaluation of `condition` has returned true.
+//
+// There is no guarantee that calling `condition` again after this function
+// has returned would still return true. The only
+// contract is that at some point during the execution of that function,
+// `condition` has returned true.
+//
+// First does some spin-waiting for the specified `spin_duration`,
+// then falls back to passive waiting for the given condvar, guarded
+// by the given mutex. At that point, it acquires the mutex lock around
+// the wait on the condition variable.
+// Therefore, this function expects that the calling thread hasn't already
+// locked the mutex before calling it.
+// This function will always release the mutex lock before returning.
+//
+// The idea of doing some initial spin-waiting is to help get
+// better and more consistent multithreading benefits for small GEMM sizes.
+// Spin-waiting helps ensure that if we need to wake up soon after having
+// started waiting, then we can wake up quickly (as opposed to, say,
+// having to wait to be scheduled again by the OS). On the other hand,
+// we must still eventually revert to passive waiting for longer waits
+// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
+// so as to avoid permanently spinning.
+//
+// In situations where other threads might have more useful things to do with
+// these CPU cores than our spin-waiting, it may be best to reduce the value
+// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
+//
+// There is a risk that the std::function used here might use a heap allocation
+// to store its context. The expected usage pattern is that these functions'
+// contexts will consist of a single pointer value (typically capturing only
+// [this]), and that in this case the std::function implementation will use
+// inline storage, avoiding a heap allocation. However, we can't effectively
+// guard that assumption, and that's not a big concern anyway because the
+// latency of a small heap allocation is probably low compared to the intrinsic
+// latency of what this Wait function does.
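As a usage sketch of the Wait function declared just below (mirroring the pattern exercised in wait_test.cc further down; the Example function name and the 1-millisecond spin budget are illustrative choices, not from the ruy sources):

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

#include "ruy/time.h"
#include "ruy/wait.h"

void Example() {
  std::condition_variable condvar;
  std::mutex mutex;
  std::atomic<bool> done(false);

  std::thread worker([&] {
    // ... do some work, then publish the result and notify ...
    done.store(true);
    std::lock_guard<std::mutex> lock(mutex);
    condvar.notify_all();
  });

  // Spin for up to ~1 millisecond, then fall back to blocking on the condvar.
  ruy::Wait([&done] { return done.load(); }, ruy::DurationFromSeconds(1e-3),
            &condvar, &mutex);
  worker.join();
}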
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration, + std::condition_variable* condvar, std::mutex* mutex); + +} // namespace ruy + +#endif // RUY_RUY_WAIT_H_ diff --git a/ruy/wait_test.cc b/ruy/wait_test.cc new file mode 100644 index 0000000..fa02b0d --- /dev/null +++ b/ruy/wait_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/wait.h" + +#include <atomic> +#include <condition_variable> // NOLINT(build/c++11) +#include <mutex> // NOLINT(build/c++11) +#include <thread> // NOLINT(build/c++11) + +#include "ruy/gtest_wrapper.h" +#include "ruy/platform.h" + +namespace ruy { +namespace { + +// Thread taking a `value` atomic counter and incrementing it until it equals +// `end_value`, then notifying the condition variable as long as +// `value == end_value`. If `end_value` is increased, it will then resume +// incrementing `value`, etc. Terminates if `end_value == -1`. +class ThreadCountingUpToValue { + public: + ThreadCountingUpToValue(const std::atomic<int>& end_value, + std::atomic<int>* value, + std::condition_variable* condvar, std::mutex* mutex) + : end_value_(end_value), + value_(value), + condvar_(condvar), + mutex_(mutex) {} + void operator()() { + // end_value_==-1 is how the master thread will tell us it's OK to terminate + while (end_value_.load() != -1) { + // wait until end_value is set to a higher value + while (value_->load() == end_value_.load()) { + } + // increment value as long as it's lower than end_value + while (value_->fetch_add(1) < end_value_.load() - 1) { + } + // when value has reached end_value, notify the master thread. 
+ while (value_->load() == end_value_.load()) { + std::lock_guard<std::mutex> lock(*mutex_); + condvar_->notify_all(); + } + } + } + + private: + const std::atomic<int>& end_value_; + std::atomic<int>* value_; + std::condition_variable* condvar_; + std::mutex* mutex_; +}; + +void WaitTest(const Duration& spin_duration, const Duration& delay) { +#if RUY_PLATFORM_EMSCRIPTEN + // b/139927184, std::thread constructor raises exception + return; +#endif + std::condition_variable condvar; + std::mutex mutex; + std::atomic<int> value(0); + std::atomic<int> end_value(0); + ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex); + std::thread thread(thread_callable); + std::this_thread::sleep_for(delay); + for (int i = 1; i < 10; i++) { + end_value.store(1000 * i); + const auto& condition = [&value, &end_value]() { + return value.load() == end_value.load(); + }; + ruy::Wait(condition, spin_duration, &condvar, &mutex); + EXPECT_EQ(value.load(), end_value.load()); + } + end_value.store(-1); + thread.join(); +} + +TEST(WaitTest, WaitTestNoSpin) { + WaitTest(DurationFromSeconds(0), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneMicrosecond) { + WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneMillisecond) { + WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneSecond) { + WaitTest(DurationFromSeconds(1), DurationFromSeconds(0)); +} + +// Testcase to consistently reproduce the hang in b/139062384. +TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) { + WaitTest(DurationFromSeconds(0), DurationFromSeconds(1)); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/third_party/BUILD diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 0000000..4c82d83 --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,5 @@ +# This file is generated (whence no license header). Do not edit! +# To regenerate, run: +# cmake/bazel_to_cmake.sh + +ruy_add_all_subdirs() diff --git a/third_party/cpuinfo b/third_party/cpuinfo new file mode 160000 +Subproject 5916273f79a21551890fd3d56fc5375a78d1598 diff --git a/third_party/cpuinfo.BUILD b/third_party/cpuinfo.BUILD new file mode 100644 index 0000000..6d68cd5 --- /dev/null +++ b/third_party/cpuinfo.BUILD @@ -0,0 +1,386 @@ +# cpuinfo, a library to detect information about the host CPU +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +C99OPTS = [ + "-std=gnu99", # gnu99, not c99, because dprintf is used + "-Wno-vla", + "-D_GNU_SOURCE=1", # to use CPU_SETSIZE + "-DCPUINFO_INTERNAL=", + "-DCPUINFO_PRIVATE=", +] + +# Source code common to all platforms. +COMMON_SRCS = [ + "src/api.c", + "src/init.c", + "src/cache.c", +] + +# Architecture-specific sources and headers. 
+X86_SRCS = [ + "src/x86/cache/descriptor.c", + "src/x86/cache/deterministic.c", + "src/x86/cache/init.c", + "src/x86/info.c", + "src/x86/init.c", + "src/x86/isa.c", + "src/x86/name.c", + "src/x86/topology.c", + "src/x86/uarch.c", + "src/x86/vendor.c", +] + +ARM_SRCS = [ + "src/arm/cache.c", + "src/arm/uarch.c", +] + +# Platform-specific sources and headers +LINUX_SRCS = [ + "src/linux/cpulist.c", + "src/linux/multiline.c", + "src/linux/processors.c", + "src/linux/smallfile.c", +] + +MOCK_LINUX_SRCS = [ + "src/linux/mockfile.c", +] + +MACH_SRCS = [ + "src/mach/topology.c", +] + +EMSCRIPTEN_SRCS = [ + "src/emscripten/init.c", +] + +LINUX_X86_SRCS = [ + "src/x86/linux/cpuinfo.c", + "src/x86/linux/init.c", +] + +LINUX_ARM_SRCS = [ + "src/arm/linux/chipset.c", + "src/arm/linux/clusters.c", + "src/arm/linux/cpuinfo.c", + "src/arm/linux/hwcap.c", + "src/arm/linux/init.c", + "src/arm/linux/midr.c", +] + +LINUX_ARM32_SRCS = LINUX_ARM_SRCS + ["src/arm/linux/aarch32-isa.c"] + +LINUX_ARM64_SRCS = LINUX_ARM_SRCS + ["src/arm/linux/aarch64-isa.c"] + +ANDROID_ARM_SRCS = [ + "src/arm/android/properties.c", +] + +WINDOWS_X86_SRCS = [ + "src/x86/windows/init.c", +] + +MACH_X86_SRCS = [ + "src/x86/mach/init.c", +] + +MACH_ARM_SRCS = [ + "src/arm/mach/init.c", +] + +cc_library( + name = "clog", + srcs = [ + "deps/clog/src/clog.c", + ], + hdrs = [ + "deps/clog/include/clog.h", + ], + copts = select({ + ":windows_x86_64": [], + "//conditions:default": ["-Wno-unused-result"], + }), + linkopts = select({ + ":android": ["-llog"], + "//conditions:default": [], + }), + linkstatic = select({ + # https://github.com/bazelbuild/bazel/issues/11552 + ":macos_x86_64": False, + "//conditions:default": True, + }), + defines = select({ + # When linkstatic=False, we need default visibility + ":macos_x86_64": ["CLOG_VISIBILITY="], + "//conditions:default": [], + }), + strip_include_prefix = "deps/clog/include", +) + +cc_library( + name = "cpuinfo_impl", + srcs = select({ + ":linux_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, + ":linux_arm": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, + ":linux_armhf": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, + ":linux_armv7a": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, + ":linux_armeabi": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, + ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS, + ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":windows_x86_64": COMMON_SRCS + X86_SRCS + WINDOWS_X86_SRCS, + ":android_armv7": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS + ANDROID_ARM_SRCS, + ":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS, + ":android_x86": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, + ":android_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, + ":ios_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":ios_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":ios_armv7": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":ios_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":ios_arm64e": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":watchos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":tvos_arm64": 
COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":emscripten": COMMON_SRCS + EMSCRIPTEN_SRCS, + }), + copts = select({ + ":windows_x86_64": [], + "//conditions:default": C99OPTS, + }) + [ + "-Iexternal/cpuinfo/include", + "-Iexternal/cpuinfo/src", + ], + linkstatic = select({ + # https://github.com/bazelbuild/bazel/issues/11552 + ":macos_x86_64": False, + "//conditions:default": True, + }), + # Headers must be in textual_hdrs to allow us to set the standard to C99 + textual_hdrs = [ + "include/cpuinfo.h", + "src/linux/api.h", + "src/mach/api.h", + "src/cpuinfo/common.h", + "src/cpuinfo/internal-api.h", + "src/cpuinfo/log.h", + "src/cpuinfo/utils.h", + "src/x86/api.h", + "src/x86/cpuid.h", + "src/x86/linux/api.h", + "src/arm/android/api.h", + "src/arm/linux/api.h", + "src/arm/linux/cp.h", + "src/arm/api.h", + "src/arm/midr.h", + ], + deps = [ + ":clog", + ], +) + +cc_library( + name = "cpuinfo", + hdrs = [ + "include/cpuinfo.h", + ], + strip_include_prefix = "include", + deps = [ + ":cpuinfo_impl", + ], +) + +cc_library( + name = "cpuinfo_with_unstripped_include_path", + hdrs = [ + "include/cpuinfo.h", + ], + deps = [ + ":cpuinfo_impl", + ], +) + +############################# Build configurations ############################# + +config_setting( + name = "linux_x86_64", + values = {"cpu": "k8"}, +) + +config_setting( + name = "linux_arm", + values = {"cpu": "arm"}, +) + +config_setting( + name = "linux_armhf", + values = {"cpu": "armhf"}, +) + +config_setting( + name = "linux_armv7a", + values = {"cpu": "armv7a"}, +) + +config_setting( + name = "linux_armeabi", + values = {"cpu": "armeabi"}, +) + +config_setting( + name = "linux_aarch64", + values = {"cpu": "aarch64"}, +) + +config_setting( + name = "macos_x86_64", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, +) + +config_setting( + name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, +) + +config_setting( + name = "windows_x86_64", + values = {"cpu": "x64_windows"}, +) + +config_setting( + name = "android_armv7", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi-v7a", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_arm64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "arm64-v8a", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86_64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86_64", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "ios_armv7", + values = { + "apple_platform_type": "ios", + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "apple_platform_type": "ios", + "cpu": "ios_arm64", + }, +) + +config_setting( + name = "ios_arm64e", + values = { + "apple_platform_type": "ios", + "cpu": "ios_arm64e", + }, +) + +config_setting( + name = "ios_x86", + values = { + "apple_platform_type": "ios", + "cpu": "ios_i386", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "apple_platform_type": "ios", + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "watchos_armv7k", + values = { + "apple_platform_type": "watchos", + "cpu": "watchos_armv7k", + }, +) + +config_setting( + name = "watchos_arm64_32", + values = { + "apple_platform_type": "watchos", + "cpu": "watchos_arm64_32", + }, +) 
+ +config_setting( + name = "watchos_x86", + values = { + "apple_platform_type": "watchos", + "cpu": "watchos_i386", + }, +) + +config_setting( + name = "watchos_x86_64", + values = { + "apple_platform_type": "watchos", + "cpu": "watchos_x86_64", + }, +) + +config_setting( + name = "tvos_arm64", + values = { + "apple_platform_type": "tvos", + "cpu": "tvos_arm64", + }, +) + +config_setting( + name = "tvos_x86_64", + values = { + "apple_platform_type": "tvos", + "cpu": "tvos_x86_64", + }, +) + +config_setting( + name = "emscripten", + values = {"crosstool_top": "//toolchain:emscripten"}, +) diff --git a/third_party/googletest b/third_party/googletest new file mode 160000 +Subproject 6c58c11d5497b6ee1df3cb400ce30deb72fc28c |