aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Marat Dukhan <maratek@google.com> 2022-02-04 00:32:09 -0800
committer XNNPACK Team <xnnpack-github-robot@google.com> 2022-02-04 00:33:22 -0800
commit af1671ab277a0742f8c7251b6b4ad1a16aba79bd (patch)
tree 1a04932dba77cbd5d8a5e602a2ab58d8332ea08b
parent 4b90bee3268e790231431c9b7fd3e92eb0dd9dbd (diff)
download XNNPACK-af1671ab277a0742f8c7251b6b4ad1a16aba79bd.tar.gz
Support FP32 weights in F16 PReLU operator
PiperOrigin-RevId: 426333774
-rw-r--r-- src/operators/prelu-nc.c 17
-rw-r--r-- src/packing.c 27
-rw-r--r-- src/xnnpack/pack.h 22
-rw-r--r-- test/prelu-nc.cc 13
-rw-r--r-- test/prelu-operator-tester.h 42
5 files changed, 114 insertions, 7 deletions
diff --git a/src/operators/prelu-nc.c b/src/operators/prelu-nc.c
index a2e46c0f9..2fc98fa3c 100644
--- a/src/operators/prelu-nc.c
+++ b/src/operators/prelu-nc.c
@@ -13,6 +13,7 @@
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
+#include <xnnpack/pack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
@@ -23,9 +24,10 @@ static enum xnn_status create_prelu_nc(
size_t output_stride,
const void* negative_slope,
uint32_t flags,
+ uint32_t log2_weights_element_size,
+ xnn_pack_prelu_w_function pack_prelu_w,
uint32_t datatype_init_flags,
enum xnn_operator_type operator_type,
- uint32_t log2_weights_element_size,
xnn_operator_t* prelu_op_out)
{
xnn_operator_t prelu_op = NULL;
@@ -89,7 +91,7 @@ static enum xnn_status create_prelu_nc(
packed_weights_size, xnn_operator_type_to_string(operator_type));
goto error;
}
- memcpy(prelu_op->packed_weights, negative_slope, channels << log2_weights_element_size);
+ pack_prelu_w(channels, negative_slope, prelu_op->packed_weights);
prelu_op->channels = channels;
prelu_op->input_pixel_stride = input_stride;
@@ -117,11 +119,17 @@ enum xnn_status xnn_create_prelu_nc_f16(
uint32_t flags,
xnn_operator_t* prelu_op_out)
{
+ xnn_pack_prelu_w_function pack_prelu_w = (xnn_pack_prelu_w_function) xnn_pack_f16_prelu_w;
+ if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
+ pack_prelu_w = (xnn_pack_prelu_w_function) xnn_pack_f32_to_f16_prelu_w;
+ }
+
return create_prelu_nc(
channels, input_stride, output_stride,
negative_slope, flags,
- XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16,
1 /* log2(sizeof(uint16_t)) */,
+ pack_prelu_w,
+ XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16,
prelu_op_out);
}
@@ -136,8 +144,9 @@ enum xnn_status xnn_create_prelu_nc_f32(
return create_prelu_nc(
channels, input_stride, output_stride,
negative_slope, flags,
- XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32,
2 /* log2(sizeof(float)) */,
+ (xnn_pack_prelu_w_function) xnn_pack_f32_prelu_w,
+ XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32,
prelu_op_out);
}
diff --git a/src/packing.c b/src/packing.c
index 61df3cba6..ca2979924 100644
--- a/src/packing.c
+++ b/src/packing.c
@@ -8,6 +8,7 @@
#include <stdint.h>
#include <stddef.h>
+#include <string.h>
#include <fp16.h>
@@ -2018,3 +2019,29 @@ void xnn_pack_f32_to_f16_vmulcaddc_w(
packed_w += cr - cr_block_size;
}
}
+
+void xnn_pack_f32_prelu_w(  // Packs FP32 PReLU slopes as a verbatim copy; this is the packer xnn_create_prelu_nc_f32 passes in.
+ size_t c,  // number of float elements to copy (the caller passes the channel count)
+ const float* s,  // source slopes; c floats are read
+ float* packed_w)  // destination buffer; c floats are written
+{
+ memcpy(packed_w, s, c * sizeof(float));  // no conversion or reordering, just c * sizeof(float) bytes
+}
+
+void xnn_pack_f16_prelu_w(  // Packs 16-bit PReLU slopes as a verbatim copy; default packer chosen by xnn_create_prelu_nc_f16.
+ size_t c,  // number of 16-bit elements to copy (the caller passes the channel count)
+ const uint16_t* s,  // source slopes stored as uint16_t; c elements are read
+ uint16_t* packed_w)  // destination buffer; c elements are written
+{
+ memcpy(packed_w, s, c * sizeof(uint16_t));  // no conversion or reordering, just c * sizeof(uint16_t) bytes
+}
+
+void xnn_pack_f32_to_f16_prelu_w(  // Converts c FP32 slopes to FP16 bit patterns; selected when XNN_FLAG_FP32_STATIC_WEIGHTS is set.
+ size_t c,  // number of elements to convert; 0 is a no-op, matching the memcpy-based packers
+ const float* s,  // source FP32 slopes; c floats are read
+ uint16_t* packed_w)  // destination buffer; c converted 16-bit values are written
+{
+ for (; c != 0; c--) {  // test before the first access: c == 0 must not touch memory or wrap the unsigned counter
+ *packed_w++ = fp16_ieee_from_fp32_value(*s++);
+ }
+}
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index 8b0125e92..76bf6deff 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -677,6 +677,28 @@ XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w(
uint16_t* packed_w,
const void* params);
+
+typedef void (*xnn_pack_prelu_w_function)(  // Type-erased PReLU weight packer, passed to create_prelu_nc as a parameter.
+ size_t c,  // element count (create_prelu_nc passes the channel count)
+ const void* s,  // source weights; element type depends on the concrete packer
+ void* packed_w);  // destination buffer; element type depends on the concrete packer
+  // NOTE(review): the packers below take typed pointers and are cast to the type above before being called (prelu-nc.c); ISO C leaves a call through an incompatible function pointer type undefined, so this relies on the target ABI - confirm acceptable.
+XNN_INTERNAL void xnn_pack_f32_prelu_w(  // float source -> float destination, plain copy
+ size_t c,
+ const float* s,
+ float* packed_w);
+
+XNN_INTERNAL void xnn_pack_f16_prelu_w(  // uint16_t source -> uint16_t destination, plain copy
+ size_t c,
+ const uint16_t* s,
+ uint16_t* packed_w);
+
+XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w(  // float source -> uint16_t destination, converting each element
+ size_t c,
+ const float* s,
+ uint16_t* packed_w);
+
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/test/prelu-nc.cc b/test/prelu-nc.cc
index 351e45055..8a446f618 100644
--- a/test/prelu-nc.cc
+++ b/test/prelu-nc.cc
@@ -108,6 +108,19 @@ TEST(PRELU_NC_F16, large_batch_with_x_stride_and_y_stride) {
}
}
+TEST(PRELU_NC_F16, fp32_weights) {  // Runs the F16 PReLU operator test with the slopes supplied as FP32.
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {  // sweep channel counts below 10 tiles, stepping by channel_tile - 1 (at least 1)
+ PReLUOperatorTester()
+ .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)  // three row tiles plus one extra row
+ .channels(channels)
+ .x_stride(123)  // input and output strides differ from each other
+ .y_stride(117)
+ .weights_type(PReLUOperatorTester::WeightsType::FP32)  // makes TestF16 pass float weights and XNN_FLAG_FP32_STATIC_WEIGHTS
+ .iterations(1)
+ .TestF16();
+ }
+}
+
TEST(PRELU_NC_F32, unit_batch) {
for (size_t channels = 1; channels < xnn_params.f32.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f32.prelu.channel_tile - 1)) {
diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h
index 53565ce08..0031c616e 100644
--- a/test/prelu-operator-tester.h
+++ b/test/prelu-operator-tester.h
@@ -22,6 +22,11 @@
class PReLUOperatorTester {
public:
+ enum class WeightsType {  // Selects which weight array and flags the tester hands to the operator.
+ Default,  // TestF16 passes w.data() with flags = 0; the only value TestF32 accepts
+ FP32,  // TestF16 passes w_as_float.data() and sets XNN_FLAG_FP32_STATIC_WEIGHTS
+ };
+
inline PReLUOperatorTester& batch_size(size_t batch_size) {
assert(batch_size != 0);
this->batch_size_ = batch_size;
@@ -72,6 +77,15 @@ class PReLUOperatorTester {
}
}
+ inline PReLUOperatorTester& weights_type(WeightsType weights_type) {  // Setter: stores the weights type for later Test* calls.
+ this->weights_type_ = weights_type;
+ return *this;  // returns the tester itself so setters can be chained
+ }
+
+ inline WeightsType weights_type() const {  // Getter: the stored weights type (member initializer is WeightsType::Default).
+ return this->weights_type_;
+ }
+
inline PReLUOperatorTester& iterations(size_t iterations) {
this->iterations_ = iterations;
return *this;
@@ -82,6 +96,15 @@ class PReLUOperatorTester {
}
void TestF16() const {
+ switch (weights_type()) {
+ case WeightsType::Default:
+ break;
+ case WeightsType::FP32:
+ break;
+ default:
+ GTEST_FAIL() << "unexpected weights type";
+ }
+
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
@@ -91,18 +114,20 @@ class PReLUOperatorTester {
std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
std::vector<uint16_t> w(channels());
+ std::vector<float> w_as_float(channels());
std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
std::vector<float> y_ref(batch_size() * channels());
for (size_t iteration = 0; iteration < iterations(); iteration++) {
std::generate(x.begin(), x.end(), std::ref(f16irng));
std::generate(w.begin(), w.end(), std::ref(f16wrng));
+ std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value);
std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
// Compute reference results, without clamping.
for (size_t i = 0; i < batch_size(); i++) {
for (size_t c = 0; c < channels(); c++) {
const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
- const float w_value = fp16_ieee_to_fp32_value(w[c]);
+ const float w_value = w_as_float[c];
y_ref[i * channels() + c] = signbit(x_value) ? x_value * w_value : x_value;
}
}
@@ -111,11 +136,19 @@ class PReLUOperatorTester {
ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
xnn_operator_t prelu_op = nullptr;
+ const void* negative_slope_data = w.data();
+ if (weights_type() == WeightsType::FP32) {
+ negative_slope_data = w_as_float.data();
+ }
+ uint32_t flags = 0;
+ if (weights_type() == WeightsType::FP32) {
+ flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
+ }
ASSERT_EQ(xnn_status_success,
xnn_create_prelu_nc_f16(
channels(), x_stride(), y_stride(),
- w.data(),
- 0, &prelu_op));
+ negative_slope_data,
+ flags, &prelu_op));
ASSERT_NE(nullptr, prelu_op);
// Smart pointer to automatically delete prelu_op.
@@ -145,6 +178,8 @@ class PReLUOperatorTester {
}
void TestF32() const {
+ ASSERT_EQ(weights_type(), WeightsType::Default);
+
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
@@ -208,5 +243,6 @@ class PReLUOperatorTester {
size_t channels_{1};
size_t x_stride_{0};
size_t y_stride_{0};
+ WeightsType weights_type_{WeightsType::Default};
size_t iterations_{15};
};