diff options
author | Marat Dukhan <maratek@google.com> | 2022-02-04 02:18:23 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-02-04 02:19:34 -0800 |
commit | 170f95ad59dff6b4c05a428af12b75517ad82648 (patch) | |
tree | 9f74abfe7d540aa77f810fae403acb43b72521c9 | |
parent | 5756a927fc5044bdcfebe57d4bd84408ca0a0975 (diff) | |
download | XNNPACK-170f95ad59dff6b4c05a428af12b75517ad82648.tar.gz |
Support PReLU in FP16 graph rewriting
PiperOrigin-RevId: 426349832
-rw-r--r-- | src/subgraph.c | 9 | ||||
-rw-r--r-- | src/subgraph/prelu.c | 55 |
2 files changed, 50 insertions, 14 deletions
```diff
diff --git a/src/subgraph.c b/src/subgraph.c
index 903331405..9ceea30ad 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -594,6 +594,11 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
   // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
   for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
     struct xnn_node* node = &subgraph->nodes[n];
+    if (node->type == xnn_node_type_invalid) {
+      // Node was fused away, skip.
+      continue;
+    }
+
     if (node->compute_type != xnn_compute_type_fp32) {
       xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
       return;
@@ -613,6 +618,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
       case xnn_node_type_depthwise_convolution_2d:
       case xnn_node_type_global_average_pooling_2d:
       case xnn_node_type_hardswish:
+      case xnn_node_type_prelu:
       case xnn_node_type_static_constant_pad:
         break;
       default:
@@ -623,13 +629,14 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
   }
 
   // Annotate Values to be converted to FP16 as FP16-compatible.
-  // Note that static weights in [Depthwise] Convolution & Fully Connected Nodes remain FP32,
+  // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32,
   // they will be converted to FP16 during weight repacking when the operator is created.
   for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
     struct xnn_node* node = &subgraph->nodes[n];
     switch (node->type) {
       case xnn_node_type_convolution_2d:
       case xnn_node_type_depthwise_convolution_2d:
+      case xnn_node_type_prelu:
         subgraph->values[node->inputs[0]].fp16_compatible = true;
         subgraph->values[node->outputs[0]].fp16_compatible = true;
         break;
diff --git a/src/subgraph/prelu.c b/src/subgraph/prelu.c
index b6d3b1c0a..d4be75295 100644
--- a/src/subgraph/prelu.c
+++ b/src/subgraph/prelu.c
@@ -19,8 +19,6 @@ static enum xnn_status create_prelu_operator(
   size_t num_values,
   struct xnn_operator_data* opdata)
 {
-  assert(node->compute_type == xnn_compute_type_fp32);
-
   assert(node->num_inputs == 2);
   const uint32_t input_id = node->inputs[0];
   assert(input_id != XNN_INVALID_VALUE_ID);
@@ -37,11 +35,27 @@ static enum xnn_status create_prelu_operator(
   const size_t num_input_dims = values[input_id].shape.num_dims;
   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
 
-  const enum xnn_status status = xnn_create_prelu_nc_f32(
-    channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
-    values[slope_id].data /* negative slope */,
-    node->flags,
-    &opdata->operator_object);
+  enum xnn_status status;
+  switch (node->compute_type) {
+#ifndef XNN_NO_F16_OPERATORS
+    case xnn_compute_type_fp16:
+      status = xnn_create_prelu_nc_f16(
+        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
+        values[slope_id].data /* negative slope */,
+        node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
+        &opdata->operator_object);
+      break;
+#endif // XNN_NO_F16_OPERATORS
+    case xnn_compute_type_fp32:
+      status = xnn_create_prelu_nc_f32(
+        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
+        values[slope_id].data /* negative slope */,
+        node->flags,
+        &opdata->operator_object);
+      break;
+    default:
+      XNN_UNREACHABLE;
+  }
   if (status == xnn_status_success) {
     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
     opdata->inputs[0] = input_id;
@@ -72,12 +86,27 @@ static enum xnn_status setup_prelu_operator(
   void* output_data = output_blob->data;
   assert(output_data != NULL);
 
-  return xnn_setup_prelu_nc_f32(
-    opdata->operator_object,
-    opdata->batch_size,
-    input_data,
-    output_data,
-    threadpool);
+  switch (opdata->operator_object->type) {
+#ifndef XNN_NO_F16_OPERATORS
+    case xnn_operator_type_prelu_nc_f16:
+      return xnn_setup_prelu_nc_f16(
+        opdata->operator_object,
+        opdata->batch_size,
+        input_data,
+        output_data,
+        threadpool);
+#endif // XNN_NO_F16_OPERATORS
+    case xnn_operator_type_prelu_nc_f32:
+      return xnn_setup_prelu_nc_f32(
+        opdata->operator_object,
+        opdata->batch_size,
+        input_data,
+        output_data,
+        threadpool);
+    default:
+      XNN_UNREACHABLE;
+  }
+
 }
 
 enum xnn_status xnn_define_prelu(
```