diff options
author | Marat Dukhan <maratek@google.com> | 2022-02-04 00:32:09 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-02-04 00:33:22 -0800 |
commit | af1671ab277a0742f8c7251b6b4ad1a16aba79bd (patch) | |
tree | 1a04932dba77cbd5d8a5e602a2ab58d8332ea08b | |
parent | 4b90bee3268e790231431c9b7fd3e92eb0dd9dbd (diff) | |
download | XNNPACK-af1671ab277a0742f8c7251b6b4ad1a16aba79bd.tar.gz |
Support FP32 weights in F16 PReLU operator
PiperOrigin-RevId: 426333774
-rw-r--r-- | src/operators/prelu-nc.c | 17 | ||||
-rw-r--r-- | src/packing.c | 27 | ||||
-rw-r--r-- | src/xnnpack/pack.h | 22 | ||||
-rw-r--r-- | test/prelu-nc.cc | 13 | ||||
-rw-r--r-- | test/prelu-operator-tester.h | 42 |
5 files changed, 114 insertions, 7 deletions
diff --git a/src/operators/prelu-nc.c b/src/operators/prelu-nc.c index a2e46c0f9..2fc98fa3c 100644 --- a/src/operators/prelu-nc.c +++ b/src/operators/prelu-nc.c @@ -13,6 +13,7 @@ #include <xnnpack/allocator.h> #include <xnnpack/log.h> #include <xnnpack/operator.h> +#include <xnnpack/pack.h> #include <xnnpack/params-init.h> #include <xnnpack/params.h> @@ -23,9 +24,10 @@ static enum xnn_status create_prelu_nc( size_t output_stride, const void* negative_slope, uint32_t flags, + uint32_t log2_weights_element_size, + xnn_pack_prelu_w_function pack_prelu_w, uint32_t datatype_init_flags, enum xnn_operator_type operator_type, - uint32_t log2_weights_element_size, xnn_operator_t* prelu_op_out) { xnn_operator_t prelu_op = NULL; @@ -89,7 +91,7 @@ static enum xnn_status create_prelu_nc( packed_weights_size, xnn_operator_type_to_string(operator_type)); goto error; } - memcpy(prelu_op->packed_weights, negative_slope, channels << log2_weights_element_size); + pack_prelu_w(channels, negative_slope, prelu_op->packed_weights); prelu_op->channels = channels; prelu_op->input_pixel_stride = input_stride; @@ -117,11 +119,17 @@ enum xnn_status xnn_create_prelu_nc_f16( uint32_t flags, xnn_operator_t* prelu_op_out) { + xnn_pack_prelu_w_function pack_prelu_w = (xnn_pack_prelu_w_function) xnn_pack_f16_prelu_w; + if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) { + pack_prelu_w = (xnn_pack_prelu_w_function) xnn_pack_f32_to_f16_prelu_w; + } + return create_prelu_nc( channels, input_stride, output_stride, negative_slope, flags, - XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16, 1 /* log2(sizeof(uint16_t)) */, + pack_prelu_w, + XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16, prelu_op_out); } @@ -136,8 +144,9 @@ enum xnn_status xnn_create_prelu_nc_f32( return create_prelu_nc( channels, input_stride, output_stride, negative_slope, flags, - XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32, 2 /* log2(sizeof(float)) */, + (xnn_pack_prelu_w_function) xnn_pack_f32_prelu_w, + XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32, prelu_op_out); } diff --git a/src/packing.c b/src/packing.c index 61df3cba6..ca2979924 100644 --- a/src/packing.c +++ b/src/packing.c @@ -8,6 +8,7 @@ #include <stdint.h> #include <stddef.h> +#include <string.h> #include <fp16.h> @@ -2018,3 +2019,29 @@ void xnn_pack_f32_to_f16_vmulcaddc_w( packed_w += cr - cr_block_size; } } + +void xnn_pack_f32_prelu_w( + size_t c, + const float* s, + float* packed_w) +{ + memcpy(packed_w, s, c * sizeof(float)); +} + +void xnn_pack_f16_prelu_w( + size_t c, + const uint16_t* s, + uint16_t* packed_w) +{ + memcpy(packed_w, s, c * sizeof(uint16_t)); +} + +void xnn_pack_f32_to_f16_prelu_w( + size_t c, + const float* s, + uint16_t* packed_w) +{ + do { + *packed_w++ = fp16_ieee_from_fp32_value(*s++); + } while (--c != 0); +} diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 8b0125e92..76bf6deff 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -677,6 +677,28 @@ XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w( uint16_t* packed_w, const void* params); + +typedef void (*xnn_pack_prelu_w_function)( + size_t c, + const void* s, + void* packed_w); + +XNN_INTERNAL void xnn_pack_f32_prelu_w( + size_t c, + const float* s, + float* packed_w); + +XNN_INTERNAL void xnn_pack_f16_prelu_w( + size_t c, + const uint16_t* s, + uint16_t* packed_w); + +XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w( + size_t c, + const float* s, + uint16_t* packed_w); + + #ifdef __cplusplus } // extern "C" #endif diff --git a/test/prelu-nc.cc b/test/prelu-nc.cc index 351e45055..8a446f618 100644 --- a/test/prelu-nc.cc +++ b/test/prelu-nc.cc @@ -108,6 +108,19 @@ TEST(PRELU_NC_F16, large_batch_with_x_stride_and_y_stride) { } } +TEST(PRELU_NC_F16, fp32_weights) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(3 * xnn_params.f16.prelu.row_tile + 1) + .channels(channels) + .x_stride(123) + .y_stride(117) + .weights_type(PReLUOperatorTester::WeightsType::FP32) + .iterations(1) + .TestF16(); + } +} + TEST(PRELU_NC_F32, unit_batch) { for (size_t channels = 1; channels < xnn_params.f32.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f32.prelu.channel_tile - 1)) { diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h index 53565ce08..0031c616e 100644 --- a/test/prelu-operator-tester.h +++ b/test/prelu-operator-tester.h @@ -22,6 +22,11 @@ class PReLUOperatorTester { public: + enum class WeightsType { + Default, + FP32, + }; + inline PReLUOperatorTester& batch_size(size_t batch_size) { assert(batch_size != 0); this->batch_size_ = batch_size; @@ -72,6 +77,15 @@ class PReLUOperatorTester { } } + inline PReLUOperatorTester& weights_type(WeightsType weights_type) { + this->weights_type_ = weights_type; + return *this; + } + + inline WeightsType weights_type() const { + return this->weights_type_; + } + inline PReLUOperatorTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; @@ -82,6 +96,15 @@ class PReLUOperatorTester { } void TestF16() const { + switch (weights_type()) { + case WeightsType::Default: + break; + case WeightsType::FP32: + break; + default: + GTEST_FAIL() << "unexpected weights type"; + } + std::random_device random_device; auto rng = std::mt19937(random_device()); auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng); @@ -91,18 +114,20 @@ class PReLUOperatorTester { std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); std::vector<uint16_t> w(channels()); + std::vector<float> w_as_float(channels()); std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); std::vector<float> y_ref(batch_size() * channels()); for (size_t iteration = 0; iteration < iterations(); iteration++) { std::generate(x.begin(), x.end(), std::ref(f16irng)); std::generate(w.begin(), w.end(), std::ref(f16wrng)); + std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value); std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); // Compute reference results, without clamping. for (size_t i = 0; i < batch_size(); i++) { for (size_t c = 0; c < channels(); c++) { const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]); - const float w_value = fp16_ieee_to_fp32_value(w[c]); + const float w_value = w_as_float[c]; y_ref[i * channels() + c] = signbit(x_value) ? x_value * w_value : x_value; } } @@ -111,11 +136,19 @@ class PReLUOperatorTester { ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); xnn_operator_t prelu_op = nullptr; + const void* negative_slope_data = w.data(); + if (weights_type() == WeightsType::FP32) { + negative_slope_data = w_as_float.data(); + } + uint32_t flags = 0; + if (weights_type() == WeightsType::FP32) { + flags |= XNN_FLAG_FP32_STATIC_WEIGHTS; + } ASSERT_EQ(xnn_status_success, xnn_create_prelu_nc_f16( channels(), x_stride(), y_stride(), - w.data(), - 0, &prelu_op)); + negative_slope_data, + flags, &prelu_op)); ASSERT_NE(nullptr, prelu_op); // Smart pointer to automatically delete prelu_op. @@ -145,6 +178,8 @@ class PReLUOperatorTester { } void TestF32() const { + ASSERT_EQ(weights_type(), WeightsType::Default); + std::random_device random_device; auto rng = std::mt19937(random_device()); auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng); @@ -208,5 +243,6 @@ class PReLUOperatorTester { size_t channels_{1}; size_t x_stride_{0}; size_t y_stride_{0}; + WeightsType weights_type_{WeightsType::Default}; size_t iterations_{15}; }; |