about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMarat Dukhan <maratek@google.com>2022-02-04 02:18:23 -0800
committerXNNPACK Team <xnnpack-github-robot@google.com>2022-02-04 02:19:34 -0800
commit170f95ad59dff6b4c05a428af12b75517ad82648 (patch)
tree9f74abfe7d540aa77f810fae403acb43b72521c9
parent5756a927fc5044bdcfebe57d4bd84408ca0a0975 (diff)
downloadXNNPACK-170f95ad59dff6b4c05a428af12b75517ad82648.tar.gz
Support PReLU in FP16 graph rewriting
PiperOrigin-RevId: 426349832
-rw-r--r--src/subgraph.c9
-rw-r--r--src/subgraph/prelu.c55
2 files changed, 50 insertions, 14 deletions
diff --git a/src/subgraph.c b/src/subgraph.c
index 903331405..9ceea30ad 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -594,6 +594,11 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
// Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
struct xnn_node* node = &subgraph->nodes[n];
+ if (node->type == xnn_node_type_invalid) {
+ // Node was fused away, skip.
+ continue;
+ }
+
if (node->compute_type != xnn_compute_type_fp32) {
xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
return;
@@ -613,6 +618,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
case xnn_node_type_depthwise_convolution_2d:
case xnn_node_type_global_average_pooling_2d:
case xnn_node_type_hardswish:
+ case xnn_node_type_prelu:
case xnn_node_type_static_constant_pad:
break;
default:
@@ -623,13 +629,14 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
}
// Annotate Values to be converted to FP16 as FP16-compatible.
- // Note that static weights in [Depthwise] Convolution & Fully Connected Nodes remain FP32,
+ // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32,
// they will be converted to FP16 during weight repacking when the operator is created.
for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
struct xnn_node* node = &subgraph->nodes[n];
switch (node->type) {
case xnn_node_type_convolution_2d:
case xnn_node_type_depthwise_convolution_2d:
+ case xnn_node_type_prelu:
subgraph->values[node->inputs[0]].fp16_compatible = true;
subgraph->values[node->outputs[0]].fp16_compatible = true;
break;
diff --git a/src/subgraph/prelu.c b/src/subgraph/prelu.c
index b6d3b1c0a..d4be75295 100644
--- a/src/subgraph/prelu.c
+++ b/src/subgraph/prelu.c
@@ -19,8 +19,6 @@ static enum xnn_status create_prelu_operator(
size_t num_values,
struct xnn_operator_data* opdata)
{
- assert(node->compute_type == xnn_compute_type_fp32);
-
assert(node->num_inputs == 2);
const uint32_t input_id = node->inputs[0];
assert(input_id != XNN_INVALID_VALUE_ID);
@@ -37,11 +35,27 @@ static enum xnn_status create_prelu_operator(
const size_t num_input_dims = values[input_id].shape.num_dims;
const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
- const enum xnn_status status = xnn_create_prelu_nc_f32(
- channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
- values[slope_id].data /* negative slope */,
- node->flags,
- &opdata->operator_object);
+ enum xnn_status status;
+ switch (node->compute_type) {
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_compute_type_fp16:
+ status = xnn_create_prelu_nc_f16(
+ channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
+ values[slope_id].data /* negative slope */,
+ node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
+ &opdata->operator_object);
+ break;
+#endif // XNN_NO_F16_OPERATORS
+ case xnn_compute_type_fp32:
+ status = xnn_create_prelu_nc_f32(
+ channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
+ values[slope_id].data /* negative slope */,
+ node->flags,
+ &opdata->operator_object);
+ break;
+ default:
+ XNN_UNREACHABLE;
+ }
if (status == xnn_status_success) {
opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
opdata->inputs[0] = input_id;
@@ -72,12 +86,27 @@ static enum xnn_status setup_prelu_operator(
void* output_data = output_blob->data;
assert(output_data != NULL);
- return xnn_setup_prelu_nc_f32(
- opdata->operator_object,
- opdata->batch_size,
- input_data,
- output_data,
- threadpool);
+ switch (opdata->operator_object->type) {
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_operator_type_prelu_nc_f16:
+ return xnn_setup_prelu_nc_f16(
+ opdata->operator_object,
+ opdata->batch_size,
+ input_data,
+ output_data,
+ threadpool);
+#endif // XNN_NO_F16_OPERATORS
+ case xnn_operator_type_prelu_nc_f32:
+ return xnn_setup_prelu_nc_f32(
+ opdata->operator_object,
+ opdata->batch_size,
+ input_data,
+ output_data,
+ threadpool);
+ default:
+ XNN_UNREACHABLE;
+ }
+
}
enum xnn_status xnn_define_prelu(