aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarat Dukhan <maratek@google.com>2022-08-22 11:18:15 -0700
committerXNNPACK Team <xnnpack-github-robot@google.com>2022-08-22 11:19:24 -0700
commit1ad54d1f839712bb1d0ae48d4b7bdc5ced9f2694 (patch)
tree0ba3fb65809fc9e8f0e0dae70e4bde1e09cc754a
parent8025ee59c5c8e0446bada40d565e7660199a42bc (diff)
downloadXNNPACK-1ad54d1f839712bb1d0ae48d4b7bdc5ced9f2694.tar.gz
Use explicit shift parameter in U64->U32 VSQRTSHIFT microkernel
PiperOrigin-RevId: 469237060
-rw-r--r--bench/u64-u32-vsqrtshift.cc8
-rw-r--r--src/microparams-init.c9
-rw-r--r--src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c7
-rw-r--r--src/xnnpack/microfnptr.h6
-rw-r--r--src/xnnpack/microparams-init.h8
-rw-r--r--src/xnnpack/vunary.h3
-rw-r--r--test/u64-u32-vsqrtshift.cc6
-rw-r--r--test/u64-u32-vsqrtshift.yaml1
-rw-r--r--test/vunary-microkernel-tester.h8
9 files changed, 12 insertions, 44 deletions
diff --git a/bench/u64-u32-vsqrtshift.cc b/bench/u64-u32-vsqrtshift.cc
index 239d66b47..2a0a94623 100644
--- a/bench/u64-u32-vsqrtshift.cc
+++ b/bench/u64-u32-vsqrtshift.cc
@@ -23,7 +23,6 @@
static void u64_u32_vsqrtshift(
benchmark::State& state,
xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift,
- xnn_init_u64_u32_sqrtshift_params_fn init_params,
benchmark::utils::IsaCheckFunction isa_check = nullptr)
{
if (isa_check && !isa_check(state)) {
@@ -41,10 +40,8 @@ static void u64_u32_vsqrtshift(
std::generate(x.begin(), x.end(), std::ref(u64rng));
std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF));
- xnn_u64_u32_sqrtshift_params params;
- init_params(&params, 1 /* shift */);
for (auto _ : state) {
- vsqrtshift(num_elements * sizeof(uint64_t), x.data(), y.data(), &params);
+ vsqrtshift(num_elements * sizeof(uint64_t), x.data(), y.data(), 1 /* shift */);
}
const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
@@ -62,8 +59,7 @@ static void u64_u32_vsqrtshift(
}
BENCHMARK_CAPTURE(u64_u32_vsqrtshift, scalar_cvtu32_sqrt_cvtu32f64_x1,
- xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1,
- xnn_init_u64_u32_sqrtshift_scalar_params)
+ xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1)
->Apply(benchmark::utils::UnaryElementwiseParameters<uint64_t, uint32_t>)
->UseRealTime();
diff --git a/src/microparams-init.c b/src/microparams-init.c
index 07aa4302a..8bb5e1d54 100644
--- a/src/microparams-init.c
+++ b/src/microparams-init.c
@@ -4038,15 +4038,6 @@ size_t xnn_init_f32_sqrt_avx512_params(
}
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
-size_t xnn_init_u64_u32_sqrtshift_scalar_params(
- union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)],
- uint32_t shift)
-{
- assert(shift < 32);
- params->scalar.shift = shift;
- return sizeof(params->scalar);
-}
-
size_t xnn_init_f32_chw_params(
union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
uint32_t width,
diff --git a/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
index 4a277cd0c..3d17b1543 100644
--- a/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
+++ b/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c
@@ -16,14 +16,13 @@ void xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1(
size_t batch,
const uint64_t* input,
uint32_t* output,
- const union xnn_u64_u32_sqrtshift_params params[restrict XNN_MIN_ELEMENTS(1)])
+ uint32_t shift)
{
assert(batch != 0);
assert(input != NULL);
assert(output != NULL);
+ assert(shift < 32);
- const uint32_t vshift = params->scalar.shift;
- assert(vshift < 32);
do {
const uint64_t vx = *input++;
@@ -59,7 +58,7 @@ void xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1(
}
}
- *output++ = vout >> vshift;
+ *output++ = vout >> shift;
batch -= sizeof(uint64_t);
} while (batch != 0);
diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h
index 401fb4861..9cfb43a06 100644
--- a/src/xnnpack/microfnptr.h
+++ b/src/xnnpack/microfnptr.h
@@ -1337,7 +1337,7 @@ typedef void (*xnn_u64_u32_vsqrtshift_ukernel_function)(
size_t batch,
const uint64_t* input,
uint32_t* output,
- const union xnn_u64_u32_sqrtshift_params* params);
+ uint32_t shift);
// LUT: vector LookUp Table elementwise
@@ -2021,10 +2021,6 @@ typedef size_t (*xnn_init_f16_sqrt_params_fn)(
typedef size_t (*xnn_init_f32_sqrt_params_fn)(
union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)]);
-typedef size_t (*xnn_init_u64_u32_sqrtshift_params_fn)(
- union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)],
- uint32_t shift);
-
typedef void (*xnn_init_qc8_scale_params_fn)(
size_t channels,
size_t channels_tile,
diff --git a/src/xnnpack/microparams-init.h b/src/xnnpack/microparams-init.h
index 054a6f8eb..ac7e2447b 100644
--- a/src/xnnpack/microparams-init.h
+++ b/src/xnnpack/microparams-init.h
@@ -668,14 +668,6 @@ DECLARE_INIT_QU8_LRELU_PARAMS_FUNCTION(xnn_init_qu8_lrelu_scalar_select_params)
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
-#define DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(fn_name) \
- XNN_INTERNAL size_t fn_name( \
- union xnn_u64_u32_sqrtshift_params params[XNN_MIN_ELEMENTS(1)], \
- uint32_t shift);
-
-DECLARE_INIT_U64_U32_SQRTSHIFT_PARAMS_FUNCTION(xnn_init_u64_u32_sqrtshift_scalar_params)
-
-
XNN_INTERNAL size_t xnn_init_f16_chw_params(
union xnn_f16_chw_params params[XNN_MIN_ELEMENTS(1)],
uint32_t width,
diff --git a/src/xnnpack/vunary.h b/src/xnnpack/vunary.h
index 4441cd368..21d63c12f 100644
--- a/src/xnnpack/vunary.h
+++ b/src/xnnpack/vunary.h
@@ -1082,8 +1082,7 @@ DECLARE_U8_VCLAMP_UKERNEL_FUNCTION(xnn_u8_vclamp_ukernel__wasmsimd_x64)
size_t n, \
const uint64_t* x, \
uint32_t* y, \
- const union xnn_u64_u32_sqrtshift_params* params);
-
+ uint32_t shift);
DECLARE_U64_U32_VSQRTSHIFT_UKERNEL_FUNCTION(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1)
diff --git a/test/u64-u32-vsqrtshift.cc b/test/u64-u32-vsqrtshift.cc
index 2d721931d..5cb71f4af 100644
--- a/test/u64-u32-vsqrtshift.cc
+++ b/test/u64-u32-vsqrtshift.cc
@@ -20,14 +20,14 @@
TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTU32F64_X1, batch_eq_1) {
VUnaryMicrokernelTester()
.batch_size(1)
- .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1);
}
TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTU32F64_X1, batch_gt_1) {
for (size_t batch_size = 2; batch_size < 10; batch_size++) {
VUnaryMicrokernelTester()
.batch_size(batch_size)
- .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1);
}
}
@@ -37,7 +37,7 @@ TEST(U64_U32_VSQRTSHIFT__SCALAR_CVTU32_SQRT_CVTU32F64_X1, shift) {
VUnaryMicrokernelTester()
.batch_size(batch_size)
.shift(shift)
- .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1, xnn_init_u64_u32_sqrtshift_scalar_params);
+ .Test(xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1);
}
}
} \ No newline at end of file
diff --git a/test/u64-u32-vsqrtshift.yaml b/test/u64-u32-vsqrtshift.yaml
index e13a707ff..155e3cae9 100644
--- a/test/u64-u32-vsqrtshift.yaml
+++ b/test/u64-u32-vsqrtshift.yaml
@@ -5,4 +5,3 @@
# Scalar
- name: xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1
- init: xnn_init_u64_u32_sqrtshift_scalar_params
diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h
index c915dbd7c..cee7387fd 100644
--- a/test/vunary-microkernel-tester.h
+++ b/test/vunary-microkernel-tester.h
@@ -1057,7 +1057,7 @@ class VUnaryMicrokernelTester {
}
}
- void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift, xnn_init_u64_u32_sqrtshift_params_fn init_params) const {
+ void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift) const {
ASSERT_FALSE(inplace());
std::random_device random_device;
@@ -1096,12 +1096,8 @@ class VUnaryMicrokernelTester {
y_ref[i] = y_value >> shift();
}
- // Prepare parameters.
- union xnn_u64_u32_sqrtshift_params params;
- init_params(&params, shift());
-
// Call optimized micro-kernel.
- vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), &params);
+ vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), shift());
// Verify results.
for (size_t i = 0; i < batch_size(); i++) {