diff options
author | Marat Dukhan <maratek@google.com> | 2022-02-04 00:00:18 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-02-04 00:01:19 -0800 |
commit | 4b90bee3268e790231431c9b7fd3e92eb0dd9dbd (patch) | |
tree | ec5fc59fe981fb3eceeef89967173fdb7f77d590 | |
parent | 10f2bf860d21f3c5c523b5a82d8f69f4f1b7fc9e (diff) | |
download | XNNPACK-4b90bee3268e790231431c9b7fd3e92eb0dd9dbd.tar.gz |
Support Static Constant Pad in FP16 graph rewriting
PiperOrigin-RevId: 426329126
-rw-r--r-- | src/subgraph.c | 7 | ||||
-rw-r--r-- | src/subgraph/static-constant-pad.c | 21 |
2 files changed, 28 insertions, 0 deletions
diff --git a/src/subgraph.c b/src/subgraph.c index 9ed9d271f..903331405 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -8,6 +8,8 @@ #include <stdint.h> #include <stdlib.h> +#include <fp16.h> + #include <xnnpack.h> #include <xnnpack/allocator.h> #include <xnnpack/log.h> @@ -611,6 +613,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) case xnn_node_type_depthwise_convolution_2d: case xnn_node_type_global_average_pooling_2d: case xnn_node_type_hardswish: + case xnn_node_type_static_constant_pad: break; default: xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference", @@ -684,6 +687,10 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) struct xnn_node* node = &subgraph->nodes[n]; assert(node->compute_type == xnn_compute_type_fp32); node->compute_type = xnn_compute_type_fp16; + if (node->type == xnn_node_type_static_constant_pad) { + node->params.static_pad.padding_value = + fp16_ieee_from_fp32_value(fp32_from_bits(node->params.static_pad.padding_value)); + } for (uint32_t i = 0; i < node->num_inputs; i++) { const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id; if (fp16_id != XNN_INVALID_VALUE_ID) { diff --git a/src/subgraph/static-constant-pad.c b/src/subgraph/static-constant-pad.c index fe1669000..b9ff70362 100644 --- a/src/subgraph/static-constant-pad.c +++ b/src/subgraph/static-constant-pad.c @@ -34,6 +34,14 @@ static enum xnn_status create_constant_pad_operator( enum xnn_status status; switch (node->compute_type) { +#ifndef XNN_NO_F16_OPERATORS + case xnn_compute_type_fp16: + status = xnn_create_constant_pad_nd_x16( + &node->params.static_pad.padding_value, + node->flags, + &opdata->operator_object); + break; +#endif // !defined(XNN_NO_F16_OPERATORS) case xnn_compute_type_fp32: status = xnn_create_constant_pad_nd_x32( &node->params.static_pad.padding_value, @@ -102,6 +110,19 @@ static enum xnn_status setup_constant_pad_operator( threadpool); break; #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS) +#ifndef XNN_NO_F16_OPERATORS + case xnn_operator_type_constant_pad_nd_x16: + return xnn_setup_constant_pad_nd_x16( + opdata->operator_object, + opdata->shape1.num_dims, + opdata->shape1.dim, + opdata->pre_paddings, + opdata->post_paddings, + input_data, + output_data, + threadpool); + break; +#endif // !defined(XNN_NO_F16_OPERATORS) case xnn_operator_type_constant_pad_nd_x32: return xnn_setup_constant_pad_nd_x32( opdata->operator_object, |