diff options
author | Marat Dukhan <maratek@google.com> | 2022-08-22 00:40:25 -0700 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-08-22 00:41:31 -0700 |
commit | 5dfe80fcb697e31e7c1ffe48e3bcfc69025c28ea (patch) | |
tree | b5efbebb19503d2ca32e136d76be34cf936b4997 | |
parent | 0ef7bcf3c146c800dbe229351815472fd585d68a (diff) | |
download | XNNPACK-5dfe80fcb697e31e7c1ffe48e3bcfc69025c28ea.tar.gz |
U64->U32 VSQRTSHIFT microkernel
PiperOrigin-RevId: 469114060
-rw-r--r-- | BUILD.bazel | 18 | ||||
-rwxr-xr-x | CMakeLists.txt | 5 | ||||
-rw-r--r-- | bench/u64-u32-vsqrtshift.cc | 72 | ||||
-rwxr-xr-x | scripts/generate-u64-u32-vsqrtshift.sh | 10 | ||||
-rw-r--r-- | src/microparams-init.c | 9 | ||||
-rw-r--r-- | src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c | 66 | ||||
-rw-r--r-- | src/xnnpack/microfnptr.h | 12 | ||||
-rw-r--r-- | src/xnnpack/microparams-init.h | 8 | ||||
-rw-r--r-- | src/xnnpack/microparams.h | 8 | ||||
-rw-r--r-- | src/xnnpack/vunary.h | 11 | ||||
-rw-r--r-- | test/u64-u32-vsqrtshift.cc | 43 | ||||
-rw-r--r-- | test/u64-u32-vsqrtshift.yaml | 8 | ||||
-rw-r--r-- | test/vunary-microkernel-tester.h | 66 | ||||
-rwxr-xr-x | tools/generate-vunary-test.py | 42 |
14 files changed, 365 insertions, 13 deletions
diff --git a/BUILD.bazel b/BUILD.bazel index a9284b662..d081abdb4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1390,6 +1390,7 @@ ALL_SCALAR_MICROKERNEL_SRCS = [ "src/u32-vlog/gen/scalar-x2.c", "src/u32-vlog/gen/scalar-x3.c", "src/u32-vlog/gen/scalar-x4.c", + "src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c", "src/xx-copy/memcpy.c", "src/xx-fill/scalar-x16.c", "src/xx-pad/scalar.c", @@ -11892,6 +11893,14 @@ xnnpack_benchmark( ) xnnpack_benchmark( + name = "u64_u32_vsqrtshift_bench", + srcs = [ + "bench/u64-u32-vsqrtshift.cc", + ], + deps = MICROKERNEL_BENCHMARK_DEPS, +) + +xnnpack_benchmark( name = "s16_vlshift_bench", srcs = [ "bench/s16-vlshift.cc", @@ -14400,6 +14409,15 @@ xnnpack_unit_test( ) xnnpack_unit_test( + name = "u64_u32_vsqrtshift_test", + srcs = [ + "test/u64-u32-vsqrtshift.cc", + "test/vunary-microkernel-tester.h", + ], + deps = MICROKERNEL_TEST_DEPS, +) + +xnnpack_unit_test( name = "x8_lut_test", srcs = [ "test/lut-microkernel-tester.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d85a2a30..903c48cf5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1378,6 +1378,7 @@ SET(ALL_SCALAR_MICROKERNEL_SRCS src/u32-vlog/gen/scalar-x2.c src/u32-vlog/gen/scalar-x3.c src/u32-vlog/gen/scalar-x4.c + src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c src/xx-copy/memcpy.c src/xx-fill/scalar-x16.c src/xx-pad/scalar.c @@ -9359,6 +9360,10 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_INCLUDE_DIRECTORIES(u32-vlog-bench PRIVATE . include src) TARGET_LINK_LIBRARIES(u32-vlog-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool) + ADD_EXECUTABLE(u64-u32-vsqrtshift-bench bench/f32-vsqrt.cc $<TARGET_OBJECTS:all_microkernels>) + TARGET_INCLUDE_DIRECTORIES(u64-u32-vsqrtshift-bench PRIVATE . 
include src) + TARGET_LINK_LIBRARIES(u64-u32-vsqrtshift-bench PRIVATE benchmark bench-utils fp16 pthreadpool microparams_init) + ADD_EXECUTABLE(s16-vlshift-bench bench/s16-vlshift.cc $<TARGET_OBJECTS:all_microkernels>) TARGET_INCLUDE_DIRECTORIES(s16-vlshift-bench PRIVATE . include src) TARGET_LINK_LIBRARIES(s16-vlshift-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool) diff --git a/bench/u64-u32-vsqrtshift.cc b/bench/u64-u32-vsqrtshift.cc new file mode 100644 index 000000000..70f109459 --- /dev/null +++ b/bench/u64-u32-vsqrtshift.cc @@ -0,0 +1,72 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <algorithm> +#include <cmath> +#include <functional> +#include <random> +#include <vector> + +#include <benchmark/benchmark.h> +#include "bench/utils.h" + +#include <xnnpack.h> +#include <xnnpack/aligned-allocator.h> +#include <xnnpack/common.h> +#include <xnnpack/microfnptr.h> +#include <xnnpack/microparams-init.h> +#include <xnnpack/vunary.h> + + +static void u64_u32_vsqrtshift( + benchmark::State& state, + xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift, + xnn_init_u64_u32_sqrtshift_params_fn init_params, + benchmark::utils::IsaCheckFunction isa_check = nullptr) +{ + if (isa_check && !isa_check(state)) { + return; + } + + const size_t num_elements = state.range(0); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto u64rng = std::bind(std::uniform_int_distribution<uint64_t>(), std::ref(rng)); + + std::vector<uint64_t, AlignedAllocator<uint64_t, 64>> x(num_elements + XNN_EXTRA_BYTES / sizeof(uint64_t)); + std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> y(num_elements); + std::generate(x.begin(), x.end(), std::ref(u64rng)); + std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF)); + + xnn_u64_u32_sqrtshift_params params; + init_params(&params, 1 /* shift */); + for (auto _ : state) { + 
vsqrtshift(num_elements * sizeof(uint64_t), x.data(), y.data(), &params); + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + const size_t elements_per_iteration = num_elements; + state.counters["elements"] = + benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + + const size_t bytes_per_iteration = num_elements * (sizeof(uint64_t) + sizeof(uint32_t)); + state.counters["bytes"] = + benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); +} + +BENCHMARK_CAPTURE(u64_u32_vsqrtshift, scalar_cvtu32_sqrt_cvtf64u32_x1, + xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, + xnn_init_u64_u32_sqrtshift_scalar_params) + ->Apply(benchmark::utils::UnaryElementwiseParameters<uint64_t, uint32_t>) + ->UseRealTime(); + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/scripts/generate-u64-u32-vsqrtshift.sh b/scripts/generate-u64-u32-vsqrtshift.sh new file mode 100755 index 000000000..3fae8fa09 --- /dev/null +++ b/scripts/generate-u64-u32-vsqrtshift.sh @@ -0,0 +1,10 @@ +#!/bin/sh +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +################################## Unit tests ################################# +tools/generate-vunary-test.py --spec test/u64-u32-vsqrtshift.yaml --output test/u64-u32-vsqrtshift.cc & + +wait diff --git a/src/microparams-init.c b/src/microparams-init.c index 8bb5e1d54..07aa4302a 100644 --- a/src/microparams-init.c +++ b/src/microparams-init.c @@ -4038,6 +4038,15 @@ size_t xnn_init_f32_sqrt_avx512_params( } #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +size_t xnn_init_u64_u32_sqrtshift_scalar_params( + union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)], + uint32_t shift) +{ + assert(shift < 32); + params->scalar.shift = shift; + return sizeof(params->scalar); +} + size_t xnn_init_f32_chw_params( union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)], uint32_t width, diff --git a/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c new file mode 100644 index 000000000..801e9aeeb --- /dev/null +++ b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c @@ -0,0 +1,66 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <math.h> + +#include <xnnpack/math.h> +#include <xnnpack/vunary.h> + + +void xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1( + size_t batch, + const uint64_t* input, + uint32_t* output, + const union xnn_u64_u32_sqrtshift_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(batch != 0); + assert(input != NULL); + assert(output != NULL); + + const uint32_t vshift = params->scalar.shift; + assert(vshift < 32); + do { + const uint64_t vx = *input++; + + uint64_t vy = vx; + const uint32_t vx_hi = (uint32_t) (vx >> 32); + const uint32_t vx_lo = (uint32_t) vx; + if XNN_LIKELY(vx != 0) { + const double vf_hi = (double) vx_hi; + const double vf_lo = (double) vx_lo; + double vf = vf_hi * 0x1.0p+32 + vf_lo; + vf = sqrt(vf); + vy = math_cvt_sat_u32_f64(vf); + #if XNN_ARCH_ARM || XNN_ARCH_X86 + const uint64_t vsquared_y_less_x = math_mulext_u32((uint32_t) vy, (uint32_t) vy) - vx; + #else + const uint64_t vsquared_y_less_x = vy * vy - vx; + #endif + if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x + vy) < 0) { + vy += 1; + } else if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x - vy) >= 0) { + vy -= 1; + } + } + + // Match TFLM is producing incorrect result for high 64-bit inputs + const uint32_t vy_lo = (uint32_t) vy; + const uint32_t vy_hi = (uint32_t) (vy >> 32); + uint32_t vout = vy_lo | -vy_hi; + // Match TFLM is producing incorrect result for high 32-bit inputs + if XNN_LIKELY(vx_hi == 0) { + if (vout == UINT32_C(0x00010000)) { + vout -= 1; + } + } + + *output++ = vout >> vshift; + + batch -= sizeof(uint64_t); + } while (batch != 0); +} diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 71396f3bd..401fb4861 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -1331,6 +1331,14 @@ typedef void (*xnn_f32_vsqrt_ukernel_function)( float* output, const union xnn_f32_sqrt_params* params); +// VSQRTSHIFT: Vector SQuare RooT and SHIFT elementwise 
+ +typedef void (*xnn_u64_u32_vsqrtshift_ukernel_function)( + size_t batch, + const uint64_t* input, + uint32_t* output, + const union xnn_u64_u32_sqrtshift_params* params); + // LUT: vector LookUp Table elementwise typedef void (*xnn_x8_lut_ukernel_function)( @@ -2013,6 +2021,10 @@ typedef size_t (*xnn_init_f16_sqrt_params_fn)( typedef size_t (*xnn_init_f32_sqrt_params_fn)( union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)]); +typedef size_t (*xnn_init_u64_u32_sqrtshift_params_fn)( + union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)], + uint32_t shift); + typedef void (*xnn_init_qc8_scale_params_fn)( size_t channels, size_t channels_tile, diff --git a/src/xnnpack/microparams-init.h b/src/xnnpack/microparams-init.h index ac7e2447b..054a6f8eb 100644 --- a/src/xnnpack/microparams-init.h +++ b/src/xnnpack/microparams-init.h @@ -668,6 +668,14 @@ DECLARE_INIT_QU8_LRELU_PARAMS_FUNCTION(xnn_init_qu8_lrelu_scalar_select_params) #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +#define DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(fn_name) \ + XNN_INTERNAL size_t fn_name( \ + union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)], \ + uint32_t shift); + +DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(xnn_init_u64_u32_sqrtshift_scalar_params) + + XNN_INTERNAL size_t xnn_init_f16_chw_params( union xnn_f16_chw_params params[XNN_MIN_ELEMENTS(1)], uint32_t width, diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h index aa40e34ed..8c7c5679a 100644 --- a/src/xnnpack/microparams.h +++ b/src/xnnpack/microparams.h @@ -2391,6 +2391,14 @@ union xnn_f32_sqrt_params { }; +// SqrtShift (Square Root + Shift): used by VSQRTSHIFT microkernels. + +union xnn_u64_u32_sqrtshift_params { + struct { + uint32_t shift; + } scalar; +}; + // CHW: used by CONV/DWCONV microkernels in CHW layout with Min+Max parameters. 
union xnn_f16_chw_params { diff --git a/src/xnnpack/vunary.h b/src/xnnpack/vunary.h index aee3576d9..e74d6113a 100644 --- a/src/xnnpack/vunary.h +++ b/src/xnnpack/vunary.h @@ -1077,6 +1077,17 @@ DECLARE_U8_VCLAMP_UKERNEL_FUNCTION(xnn_u8_vclamp_ukernel__sse2_x64) DECLARE_U8_VCLAMP_UKERNEL_FUNCTION(xnn_u8_vclamp_ukernel__wasmsimd_x64) +#define DECLARE_U64_U32_VSQRTSHIFT_UKERNEL_FUNCTION(fn_name) \ + XNN_INTERNAL void fn_name( \ + size_t n, \ + const uint64_t* x, \ + uint32_t* y, \ + const union xnn_u64_u32_sqrtshift_params* params); + + +DECLARE_U64_U32_VSQRTSHIFT_UKERNEL_FUNCTION(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1) + + #define DECLARE_XX_VUNARY_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t size, \ diff --git a/test/u64-u32-vsqrtshift.cc b/test/u64-u32-vsqrtshift.cc new file mode 100644 index 000000000..1313a8950 --- /dev/null +++ b/test/u64-u32-vsqrtshift.cc @@ -0,0 +1,43 @@ +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! 
+// Specification: test/u64-u32-vsqrtshift.yaml +// Generator: tools/generate-vunary-test.py + + +#include <gtest/gtest.h> + +#include <xnnpack/common.h> +#include <xnnpack/isa-checks.h> + +#include <xnnpack/vunary.h> +#include "vunary-microkernel-tester.h" + + +TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, batch_eq_1) { + VUnaryMicrokernelTester() + .batch_size(1) + .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params); +} + +TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, batch_gt_1) { + for (size_t batch_size = 2; batch_size < 10; batch_size++) { + VUnaryMicrokernelTester() + .batch_size(batch_size) + .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params); + } +} + +TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTF64U32_X1, shift) { + for (uint32_t shift = 0; shift < 32; shift++) { + for (size_t batch_size = 1; batch_size <= 5; batch_size += 1) { + VUnaryMicrokernelTester() + .batch_size(batch_size) + .shift(shift) + .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1, xnn_init_u64_u32_sqrtshift_scalar_params); + } + } +}
\ No newline at end of file diff --git a/test/u64-u32-vsqrtshift.yaml b/test/u64-u32-vsqrtshift.yaml new file mode 100644 index 000000000..81357114b --- /dev/null +++ b/test/u64-u32-vsqrtshift.yaml @@ -0,0 +1,8 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Scalar +- name: xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtf64u32_x1 + init: xnn_init_u64_u32_sqrtshift_scalar_params diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h index 568be767a..c915dbd7c 100644 --- a/test/vunary-microkernel-tester.h +++ b/test/vunary-microkernel-tester.h @@ -91,6 +91,15 @@ class VUnaryMicrokernelTester { return this->beta_; } + inline VUnaryMicrokernelTester& shift(uint32_t shift) { + this->shift_ = shift; + return *this; + } + + inline uint32_t shift() const { + return this->shift_; + } + inline VUnaryMicrokernelTester& qmin(uint8_t qmin) { this->qmin_ = qmin; return *this; @@ -1048,6 +1057,62 @@ class VUnaryMicrokernelTester { } } + void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift, xnn_init_u64_u32_sqrtshift_params_fn init_params) const { + ASSERT_FALSE(inplace()); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto u64rng = std::bind( std::uniform_int_distribution<uint64_t>(), std::ref(rng)); + + std::vector<uint64_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint64_t)); + std::vector<uint32_t> y(batch_size()); + std::vector<uint32_t> y_ref(batch_size()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u64rng)); + std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF)); + + // Compute reference results. 
+ for (size_t i = 0; i < batch_size(); i++) { + const uint64_t x_value = x[i]; + uint32_t y_value = 0; + // Match TFLM semantics, including bugs + if (uint32_t(x_value) == x_value) { + y_value = (uint32_t) std::lrint(std::sqrt(double(int64_t(uint64_t(x_value))))); + y_value = std::min<uint32_t>(y_value, std::numeric_limits<uint16_t>::max()); + } else if (x_value != 0) { + uint64_t y0 = x_value >> 1; + uint64_t y1 = (y0 + x_value / y0) >> 1; + do { + y0 = y1; + y1 = (y0 + x_value / y0) >> 1; + } while (y1 < y0); + + // y0 is sqrt(x_value) rounded down, round up if needed + if (int64_t(y0 * y0 + y0 - x_value) < 0) { + y0 += 1; + } + y_value = static_cast<uint32_t>(std::min<uint64_t>(y0, std::numeric_limits<uint32_t>::max())); + } + y_ref[i] = y_value >> shift(); + } + + // Prepare parameters. + union xnn_u64_u32_sqrtshift_params params; + init_params(&params, shift()); + + // Call optimized micro-kernel. + vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), &params); + + // Verify results. 
+ for (size_t i = 0; i < batch_size(); i++) { + ASSERT_EQ(y_ref[i], y[i]) + << "at " << i << " / " << batch_size() + << ", x[" << i << "]: " << x[i] + << ", shift: " << shift(); + } + } + } + private: size_t batch_size_ = 1; bool inplace_ = false; @@ -1055,6 +1120,7 @@ class VUnaryMicrokernelTester { float prescale_ = 1.0f; float alpha_ = 1.0f; float beta_ = 1.0f; + uint32_t shift_ = 1; uint8_t qmin_ = 0; uint8_t qmax_ = 255; size_t iterations_ = 15; diff --git a/tools/generate-vunary-test.py b/tools/generate-vunary-test.py index 26a9538b8..aeb83f558 100755 --- a/tools/generate-vunary-test.py +++ b/tools/generate-vunary-test.py @@ -27,7 +27,7 @@ parser.set_defaults(defines=list()) def split_ukernel_name(name): - match = re.fullmatch(r"xnn_(s8|u8|f16|f32)_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt)_(fact_)?ukernel__(.+)_x(\d+)", name) + match = re.fullmatch(r"xnn_(s8|u8|f16|f32|u32|u64)(_(s8|u8|f16|f32|u32|u64))*_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt|sqrtshift)_(fact_)?ukernel__(.+)_x(\d+)", name) if match is None: raise ValueError("Unexpected microkernel name: " + name) op_type = { @@ -45,10 +45,11 @@ def split_ukernel_name(name): "sigmoid": "Sigmoid", "sqr": "Square", "sqrt": "SquareRoot", - }[match.group(2)] - batch_tile = int(match.group(5)) + "sqrtshift": "SquareRootShift", + }[match.group(4)] + batch_tile = int(match.group(7)) - arch, isa = xnncommon.parse_target_name(target_name=match.group(4)) + arch, isa = xnncommon.parse_target_name(target_name=match.group(6)) return op_type, batch_tile, arch, isa @@ -92,16 +93,17 @@ TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) { } } -TEST(${TEST_NAME}, inplace) { - $if ISA_CHECK: - ${ISA_CHECK}; - for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) { - VUnaryMicrokernelTester() - .batch_size(batch_size) - .inplace(true) - .Test(${", ".join(TEST_ARGS)}); +$if OP_TYPE != "SquareRootShift": + TEST(${TEST_NAME}, 
inplace) { + $if ISA_CHECK: + ${ISA_CHECK}; + for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) { + VUnaryMicrokernelTester() + .batch_size(batch_size) + .inplace(true) + .Test(${", ".join(TEST_ARGS)}); + } } -} $if OP_TYPE == "Clamp": TEST(${TEST_NAME}, qmin) { @@ -183,6 +185,20 @@ $if OP_TYPE == "LeakyReLU": } } } + +$if OP_TYPE == "SquareRootShift": + TEST(${TEST_NAME}, shift) { + $if ISA_CHECK: + ${ISA_CHECK}; + for (uint32_t shift = 0; shift < 32; shift++) { + for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) { + VUnaryMicrokernelTester() + .batch_size(batch_size) + .shift(shift) + .Test(${", ".join(TEST_ARGS)}); + } + } + } """ |