aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarat Dukhan <maratek@google.com>2022-08-22 00:40:25 -0700
committerXNNPACK Team <xnnpack-github-robot@google.com>2022-08-22 00:41:31 -0700
commit5dfe80fcb697e31e7c1ffe48e3bcfc69025c28ea (patch)
treeb5efbebb19503d2ca32e136d76be34cf936b4997
parent0ef7bcf3c146c800dbe229351815472fd585d68a (diff)
downloadXNNPACK-5dfe80fcb697e31e7c1ffe48e3bcfc69025c28ea.tar.gz
U64->U32 VSQRTSHIFT microkernel
PiperOrigin-RevId: 469114060
-rw-r--r--BUILD.bazel18
-rwxr-xr-xCMakeLists.txt5
-rw-r--r--bench/u64-u32-vsqrtshift.cc72
-rwxr-xr-xscripts/generate-u64-u32-vsqrtshift.sh10
-rw-r--r--src/microparams-init.c9
-rw-r--r--src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c66
-rw-r--r--src/xnnpack/microfnptr.h12
-rw-r--r--src/xnnpack/microparams-init.h8
-rw-r--r--src/xnnpack/microparams.h8
-rw-r--r--src/xnnpack/vunary.h11
-rw-r--r--test/u64-u32-vsqrtshift.cc43
-rw-r--r--test/u64-u32-vsqrtshift.yaml8
-rw-r--r--test/vunary-microkernel-tester.h66
-rwxr-xr-xtools/generate-vunary-test.py42
14 files changed, 365 insertions, 13 deletions
diff --git a/BUILD.bazel b/BUILD.bazel
index a9284b662..d081abdb4 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1390,6 +1390,7 @@ ALL_SCALAR_MICROKERNEL_SRCS = [
"src/u32-vlog/gen/scalar-x2.c",
"src/u32-vlog/gen/scalar-x3.c",
"src/u32-vlog/gen/scalar-x4.c",
+ "src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c",
"src/xx-copy/memcpy.c",
"src/xx-fill/scalar-x16.c",
"src/xx-pad/scalar.c",
@@ -11892,6 +11893,14 @@ xnnpack_benchmark(
)
xnnpack_benchmark(
+ name = "u64_u32_vsqrtshift_bench",
+ srcs = [
+ "bench/u64-u32-vsqrtshift.cc",
+ ],
+ deps = MICROKERNEL_BENCHMARK_DEPS,
+)
+
+xnnpack_benchmark(
name = "s16_vlshift_bench",
srcs = [
"bench/s16-vlshift.cc",
@@ -14400,6 +14409,15 @@ xnnpack_unit_test(
)
xnnpack_unit_test(
+ name = "u64_u32_vsqrtshift_test",
+ srcs = [
+ "test/u64-u32-vsqrtshift.cc",
+ "test/vunary-microkernel-tester.h",
+ ],
+ deps = MICROKERNEL_TEST_DEPS,
+)
+
+xnnpack_unit_test(
name = "x8_lut_test",
srcs = [
"test/lut-microkernel-tester.h",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8d85a2a30..903c48cf5 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1378,6 +1378,7 @@ SET(ALL_SCALAR_MICROKERNEL_SRCS
src/u32-vlog/gen/scalar-x2.c
src/u32-vlog/gen/scalar-x3.c
src/u32-vlog/gen/scalar-x4.c
+ src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
src/xx-copy/memcpy.c
src/xx-fill/scalar-x16.c
src/xx-pad/scalar.c
@@ -9359,6 +9360,10 @@ IF(XNNPACK_BUILD_BENCHMARKS)
TARGET_INCLUDE_DIRECTORIES(u32-vlog-bench PRIVATE . include src)
TARGET_LINK_LIBRARIES(u32-vlog-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool)
+ ADD_EXECUTABLE(u64-u32-vsqrtshift-bench bench/f32-vsqrt.cc $<TARGET_OBJECTS:all_microkernels>)
+ TARGET_INCLUDE_DIRECTORIES(u64-u32-vsqrtshift-bench PRIVATE . include src)
+ TARGET_LINK_LIBRARIES(u64-u32-vsqrtshift-bench PRIVATE benchmark bench-utils fp16 pthreadpool microparams_init)
+
ADD_EXECUTABLE(s16-vlshift-bench bench/s16-vlshift.cc $<TARGET_OBJECTS:all_microkernels>)
TARGET_INCLUDE_DIRECTORIES(s16-vlshift-bench PRIVATE . include src)
TARGET_LINK_LIBRARIES(s16-vlshift-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool)
diff --git a/bench/u64-u32-vsqrtshift.cc b/bench/u64-u32-vsqrtshift.cc
new file mode 100644
index 000000000..70f109459
--- /dev/null
+++ b/bench/u64-u32-vsqrtshift.cc
@@ -0,0 +1,72 @@
+// Copyright 2022 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <random>
+#include <vector>
+
+#include <benchmark/benchmark.h>
+#include "bench/utils.h"
+
+#include <xnnpack.h>
+#include <xnnpack/aligned-allocator.h>
+#include <xnnpack/common.h>
+#include <xnnpack/microfnptr.h>
+#include <xnnpack/microparams-init.h>
+#include <xnnpack/vunary.h>
+
+
+static void u64_u32_vsqrtshift(
+ benchmark::State& state,
+ xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift,
+ xnn_init_u64_u32_sqrtshift_params_fn init_params,
+ benchmark::utils::IsaCheckFunction isa_check = nullptr)
+{
+ if (isa_check && !isa_check(state)) {
+ return;
+ }
+
+ const size_t num_elements = state.range(0);
+
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto u64rng = std::bind(std::uniform_int_distribution<uint64_t>(), std::ref(rng));
+
+ std::vector<uint64_t, AlignedAllocator<uint64_t, 64>> x(num_elements + XNN_EXTRA_BYTES / sizeof(uint64_t));
+ std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> y(num_elements);
+ std::generate(x.begin(), x.end(), std::ref(u64rng));
+ std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF));
+
+ xnn_u64_u32_sqrtshift_params params;
+ init_params(&params, 1 /* shift */);
+ for (auto _ : state) {
+ vsqrtshift(num_elements * sizeof(uint64_t), x.data(), y.data(), &params);
+ }
+
+ const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
+ if (cpu_frequency != 0) {
+ state.counters["cpufreq"] = cpu_frequency;
+ }
+
+ const size_t elements_per_iteration = num_elements;
+ state.counters["elements"] =
+ benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
+
+ const size_t bytes_per_iteration = num_elements * (sizeof(uint64_t) + sizeof(uint32_t));
+ state.counters["bytes"] =
+ benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
+}
+
+BENCHMARK_CAPTURE(u64_u32_vsqrtshift, scalar_cvtu32_sqrt_cvtf64u32_x1,
+ xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1,
+ xnn_init_u64_u32_sqrtshift_scalar_params)
+ ->Apply(benchmark::utils::UnaryElementwiseParameters<uint64_t, uint32_t>)
+ ->UseRealTime();
+
+#ifndef XNNPACK_BENCHMARK_NO_MAIN
+BENCHMARK_MAIN();
+#endif
diff --git a/scripts/generate-u64-u32-vsqrtshift.sh b/scripts/generate-u64-u32-vsqrtshift.sh
new file mode 100755
index 000000000..3fae8fa09
--- /dev/null
+++ b/scripts/generate-u64-u32-vsqrtshift.sh
@@ -0,0 +1,10 @@
+#!/bin/sh
+# Copyright 2022 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+################################## Unit tests #################################
+tools/generate-vunary-test.py --spec test/u64-u32-vsqrtshift.yaml --output test/u64-u32-vsqrtshift.cc &
+
+wait
diff --git a/src/microparams-init.c b/src/microparams-init.c
index 8bb5e1d54..07aa4302a 100644
--- a/src/microparams-init.c
+++ b/src/microparams-init.c
@@ -4038,6 +4038,15 @@ size_t xnn_init_f32_sqrt_avx512_params(
}
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+size_t xnn_init_u64_u32_sqrtshift_scalar_params(
+ union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)],
+ uint32_t shift)
+{
+ assert(shift < 32);
+ params->scalar.shift = shift;
+ return sizeof(params->scalar);
+}
+
size_t xnn_init_f32_chw_params(
union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
uint32_t width,
diff --git a/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
new file mode 100644
index 000000000..801e9aeeb
--- /dev/null
+++ b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
@@ -0,0 +1,66 @@
+// Copyright 2022 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <math.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vunary.h>
+
+
+void xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1(
+ size_t batch,
+ const uint64_t* input,
+ uint32_t* output,
+ const union xnn_u64_u32_sqrtshift_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+ assert(batch != 0);
+ assert(input != NULL);
+ assert(output != NULL);
+
+ const uint32_t vshift = params->scalar.shift;
+ assert(vshift < 32);
+ do {
+ const uint64_t vx = *input++;
+
+ uint64_t vy = vx;
+ const uint32_t vx_hi = (uint32_t) (vx >> 32);
+ const uint32_t vx_lo = (uint32_t) vx;
+ if XNN_LIKELY(vx != 0) {
+ const double vf_hi = (double) vx_hi;
+ const double vf_lo = (double) vx_lo;
+ double vf = vf_hi * 0x1.0p+32 + vf_lo;
+ vf = sqrt(vf);
+ vy = math_cvt_sat_u32_f64(vf);
+ #if XNN_ARCH_ARM || XNN_ARCH_X86
+ const uint64_t vsquared_y_less_x = math_mulext_u32((uint32_t) vy, (uint32_t) vy) - vx;
+ #else
+ const uint64_t vsquared_y_less_x = vy * vy - vx;
+ #endif
+ if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x + vy) < 0) {
+ vy += 1;
+ } else if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x - vy) >= 0) {
+ vy -= 1;
+ }
+ }
+
+ // Match TFLM is producing incorrect result for high 64-bit inputs
+ const uint32_t vy_lo = (uint32_t) vy;
+ const uint32_t vy_hi = (uint32_t) (vy >> 32);
+ uint32_t vout = vy_lo | -vy_hi;
+ // Match TFLM is producing incorrect result for high 32-bit inputs
+ if XNN_LIKELY(vx_hi == 0) {
+ if (vout == UINT32_C(0x00010000)) {
+ vout -= 1;
+ }
+ }
+
+ *output++ = vout >> vshift;
+
+ batch -= sizeof(uint64_t);
+ } while (batch != 0);
+}
diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h
index 71396f3bd..401fb4861 100644
--- a/src/xnnpack/microfnptr.h
+++ b/src/xnnpack/microfnptr.h
@@ -1331,6 +1331,14 @@ typedef void (*xnn_f32_vsqrt_ukernel_function)(
float* output,
const union xnn_f32_sqrt_params* params);
+// VSQRTSHIFT: Vector SQuare RooT and SHIFT elementwise
+
+typedef void (*xnn_u64_u32_vsqrtshift_ukernel_function)(
+ size_t batch,
+ const uint64_t* input,
+ uint32_t* output,
+ const union xnn_u64_u32_sqrtshift_params* params);
+
// LUT: vector LookUp Table elementwise
typedef void (*xnn_x8_lut_ukernel_function)(
@@ -2013,6 +2021,10 @@ typedef size_t (*xnn_init_f16_sqrt_params_fn)(
typedef size_t (*xnn_init_f32_sqrt_params_fn)(
union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)]);
+typedef size_t (*xnn_init_u64_u32_sqrtshift_params_fn)(
+ union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)],
+ uint32_t shift);
+
typedef void (*xnn_init_qc8_scale_params_fn)(
size_t channels,
size_t channels_tile,
diff --git a/src/xnnpack/microparams-init.h b/src/xnnpack/microparams-init.h
index ac7e2447b..054a6f8eb 100644
--- a/src/xnnpack/microparams-init.h
+++ b/src/xnnpack/microparams-init.h
@@ -668,6 +668,14 @@ DECLARE_INIT_QU8_LRELU_PARAMS_FUNCTION(xnn_init_qu8_lrelu_scalar_select_params)
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+#define DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(fn_name) \
+ XNN_INTERNAL size_t fn_name( \
+ union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)], \
+ uint32_t shift);
+
+DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(xnn_init_u64_u32_sqrtshift_scalar_params)
+
+
XNN_INTERNAL size_t xnn_init_f16_chw_params(
union xnn_f16_chw_params params[XNN_MIN_ELEMENTS(1)],
uint32_t width,
diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h
index aa40e34ed..8c7c5679a 100644
--- a/src/xnnpack/microparams.h
+++ b/src/xnnpack/microparams.h
@@ -2391,6 +2391,14 @@ union xnn_f32_sqrt_params {
};
+// SqrtShift (Square Root + Shift): used by VSQRTSHIFT microkernels.
+
+union xnn_u64_u32_sqrtshift_params {
+ struct {
+ uint32_t shift;
+ } scalar;
+};
+
// CHW: used by CONV/DWCONV microkernels in CHW layout with Min+Max parameters.
union xnn_f16_chw_params {
diff --git a/src/xnnpack/vunary.h b/src/xnnpack/vunary.h
index aee3576d9..e74d6113a 100644
--- a/src/xnnpack/vunary.h
+++ b/src/xnnpack/vunary.h
@@ -1077,6 +1077,17 @@ DECLARE_U8_VCLAMP_UKERNEL_FUNCTION(xnn_u8_vclamp_ukernel__sse2_x64)
DECLARE_U8_VCLAMP_UKERNEL_FUNCTION(xnn_u8_vclamp_ukernel__wasmsimd_x64)
+#define DECLARE_U64_U32_VSQRTSHIFT_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const uint64_t* x, \
+ uint32_t* y, \
+ const union xnn_u64_u32_sqrtshift_params* params);
+
+
+DECLARE_U64_U32_VSQRTSHIFT_UKERNEL_FUNCTION(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1)
+
+
#define DECLARE_XX_VUNARY_UKERNEL_FUNCTION(fn_name) \
XNN_INTERNAL void fn_name( \
size_t size, \
diff --git a/test/u64-u32-vsqrtshift.cc b/test/u64-u32-vsqrtshift.cc
new file mode 100644
index 000000000..1313a8950
--- /dev/null
+++ b/test/u64-u32-vsqrtshift.cc
@@ -0,0 +1,43 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Auto-generated file. Do not edit!
+// Specification: test/u64-u32-vsqrtshift.yaml
+// Generator: tools/generate-vunary-test.py
+
+
+#include <gtest/gtest.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/isa-checks.h>
+
+#include <xnnpack/vunary.h>
+#include "vunary-microkernel-tester.h"
+
+
+TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, batch_eq_1) {
+ VUnaryMicrokernelTester()
+ .batch_size(1)
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+}
+
+TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, batch_gt_1) {
+ for (size_t batch_size = 2; batch_size < 10; batch_size++) {
+ VUnaryMicrokernelTester()
+ .batch_size(batch_size)
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+ }
+}
+
+TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, shift) {
+ for (uint32_t shift = 0; shift < 32; shift++) {
+ for (size_t batch_size = 1; batch_size <= 5; batch_size += 1) {
+ VUnaryMicrokernelTester()
+ .batch_size(batch_size)
+ .shift(shift)
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+ }
+ }
+} \ No newline at end of file
diff --git a/test/u64-u32-vsqrtshift.yaml b/test/u64-u32-vsqrtshift.yaml
new file mode 100644
index 000000000..81357114b
--- /dev/null
+++ b/test/u64-u32-vsqrtshift.yaml
@@ -0,0 +1,8 @@
+# Copyright 2022 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Scalar
+- name: xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1
+ init: xnn_init_u64_u32_sqrtshift_scalar_params
diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h
index 568be767a..c915dbd7c 100644
--- a/test/vunary-microkernel-tester.h
+++ b/test/vunary-microkernel-tester.h
@@ -91,6 +91,15 @@ class VUnaryMicrokernelTester {
return this->beta_;
}
+ inline VUnaryMicrokernelTester& shift(uint32_t shift) {
+ this->shift_ = shift;
+ return *this;
+ }
+
+ inline uint32_t shift() const {
+ return this->shift_;
+ }
+
inline VUnaryMicrokernelTester& qmin(uint8_t qmin) {
this->qmin_ = qmin;
return *this;
@@ -1048,6 +1057,62 @@ class VUnaryMicrokernelTester {
}
}
+ void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift, xnn_init_u64_u32_sqrtshift_params_fn init_params) const {
+ ASSERT_FALSE(inplace());
+
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto u64rng = std::bind( std::uniform_int_distribution<uint64_t>(), std::ref(rng));
+
+ std::vector<uint64_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint64_t));
+ std::vector<uint32_t> y(batch_size());
+ std::vector<uint32_t> y_ref(batch_size());
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(x.begin(), x.end(), std::ref(u64rng));
+ std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF));
+
+ // Compute reference results.
+ for (size_t i = 0; i < batch_size(); i++) {
+ const uint64_t x_value = x[i];
+ uint32_t y_value = 0;
+ // Match TFLM semantics, including bugs
+ if (uint32_t(x_value) == x_value) {
+ y_value = (uint32_t) std::lrint(std::sqrt(double(int64_t(uint64_t(x_value)))));
+ y_value = std::min<uint32_t>(y_value, std::numeric_limits<uint16_t>::max());
+ } else if (x_value != 0) {
+ uint64_t y0 = x_value >> 1;
+ uint64_t y1 = (y0 + x_value / y0) >> 1;
+ do {
+ y0 = y1;
+ y1 = (y0 + x_value / y0) >> 1;
+ } while (y1 < y0);
+
+ // y0 is sqrt(x_value) rounded down, round up if needed
+ if (int64_t(y0 * y0 + y0 - x_value) < 0) {
+ y0 += 1;
+ }
+ y_value = static_cast<uint32_t>(std::min<uint64_t>(y0, std::numeric_limits<uint32_t>::max()));
+ }
+ y_ref[i] = y_value >> shift();
+ }
+
+ // Prepare parameters.
+ union xnn_u64_u32_sqrtshift_params params;
+ init_params(&params, shift());
+
+ // Call optimized micro-kernel.
+ vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), &params);
+
+ // Verify results.
+ for (size_t i = 0; i < batch_size(); i++) {
+ ASSERT_EQ(y_ref[i], y[i])
+ << "at " << i << " / " << batch_size()
+ << ", x[" << i << "]: " << x[i]
+ << ", shift: " << shift();
+ }
+ }
+ }
+
private:
size_t batch_size_ = 1;
bool inplace_ = false;
@@ -1055,6 +1120,7 @@ class VUnaryMicrokernelTester {
float prescale_ = 1.0f;
float alpha_ = 1.0f;
float beta_ = 1.0f;
+ uint32_t shift_ = 1;
uint8_t qmin_ = 0;
uint8_t qmax_ = 255;
size_t iterations_ = 15;
diff --git a/tools/generate-vunary-test.py b/tools/generate-vunary-test.py
index 26a9538b8..aeb83f558 100755
--- a/tools/generate-vunary-test.py
+++ b/tools/generate-vunary-test.py
@@ -27,7 +27,7 @@ parser.set_defaults(defines=list())
def split_ukernel_name(name):
- match = re.fullmatch(r"xnn_(s8|u8|f16|f32)_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt)_(fact_)?ukernel__(.+)_x(\d+)", name)
+ match = re.fullmatch(r"xnn_(s8|u8|f16|f32|u32|u64)(_(s8|u8|f16|f32|u32|u64))*_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt|sqrtshift)_(fact_)?ukernel__(.+)_x(\d+)", name)
if match is None:
raise ValueError("Unexpected microkernel name: " + name)
op_type = {
@@ -45,10 +45,11 @@ def split_ukernel_name(name):
"sigmoid": "Sigmoid",
"sqr": "Square",
"sqrt": "SquareRoot",
- }[match.group(2)]
- batch_tile = int(match.group(5))
+ "sqrtshift": "SquareRootShift",
+ }[match.group(4)]
+ batch_tile = int(match.group(7))
- arch, isa = xnncommon.parse_target_name(target_name=match.group(4))
+ arch, isa = xnncommon.parse_target_name(target_name=match.group(6))
return op_type, batch_tile, arch, isa
@@ -92,16 +93,17 @@ TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
}
}
-TEST(${TEST_NAME}, inplace) {
- $if ISA_CHECK:
- ${ISA_CHECK};
- for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
- VUnaryMicrokernelTester()
- .batch_size(batch_size)
- .inplace(true)
- .Test(${", ".join(TEST_ARGS)});
+$if OP_TYPE != "SquareRootShift":
+ TEST(${TEST_NAME}, inplace) {
+ $if ISA_CHECK:
+ ${ISA_CHECK};
+ for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
+ VUnaryMicrokernelTester()
+ .batch_size(batch_size)
+ .inplace(true)
+ .Test(${", ".join(TEST_ARGS)});
+ }
}
-}
$if OP_TYPE == "Clamp":
TEST(${TEST_NAME}, qmin) {
@@ -183,6 +185,20 @@ $if OP_TYPE == "LeakyReLU":
}
}
}
+
+$if OP_TYPE == "SquareRootShift":
+ TEST(${TEST_NAME}, shift) {
+ $if ISA_CHECK:
+ ${ISA_CHECK};
+ for (uint32_t shift = 0; shift < 32; shift++) {
+ for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
+ VUnaryMicrokernelTester()
+ .batch_size(batch_size)
+ .shift(shift)
+ .Test(${", ".join(TEST_ARGS)});
+ }
+ }
+ }
"""