aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarat Dukhan <maratek@google.com>2022-02-04 00:00:18 -0800
committerXNNPACK Team <xnnpack-github-robot@google.com>2022-02-04 00:01:19 -0800
commit4b90bee3268e790231431c9b7fd3e92eb0dd9dbd (patch)
treeec5fc59fe981fb3eceeef89967173fdb7f77d590
parent10f2bf860d21f3c5c523b5a82d8f69f4f1b7fc9e (diff)
downloadXNNPACK-4b90bee3268e790231431c9b7fd3e92eb0dd9dbd.tar.gz
Support Static Constant Pad in FP16 graph rewriting
PiperOrigin-RevId: 426329126
-rw-r--r--src/subgraph.c7
-rw-r--r--src/subgraph/static-constant-pad.c21
2 files changed, 28 insertions, 0 deletions
diff --git a/src/subgraph.c b/src/subgraph.c
index 9ed9d271f..903331405 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -8,6 +8,8 @@
#include <stdint.h>
#include <stdlib.h>
+#include <fp16.h>
+
#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
@@ -611,6 +613,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
case xnn_node_type_depthwise_convolution_2d:
case xnn_node_type_global_average_pooling_2d:
case xnn_node_type_hardswish:
+ case xnn_node_type_static_constant_pad:
break;
default:
xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
@@ -684,6 +687,10 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
struct xnn_node* node = &subgraph->nodes[n];
assert(node->compute_type == xnn_compute_type_fp32);
node->compute_type = xnn_compute_type_fp16;
+ if (node->type == xnn_node_type_static_constant_pad) {
+ node->params.static_pad.padding_value =
+ fp16_ieee_from_fp32_value(fp32_from_bits(node->params.static_pad.padding_value));
+ }
for (uint32_t i = 0; i < node->num_inputs; i++) {
const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
if (fp16_id != XNN_INVALID_VALUE_ID) {
diff --git a/src/subgraph/static-constant-pad.c b/src/subgraph/static-constant-pad.c
index fe1669000..b9ff70362 100644
--- a/src/subgraph/static-constant-pad.c
+++ b/src/subgraph/static-constant-pad.c
@@ -34,6 +34,14 @@ static enum xnn_status create_constant_pad_operator(
enum xnn_status status;
switch (node->compute_type) {
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_compute_type_fp16:
+ status = xnn_create_constant_pad_nd_x16(
+ &node->params.static_pad.padding_value,
+ node->flags,
+ &opdata->operator_object);
+ break;
+#endif // !defined(XNN_NO_F16_OPERATORS)
case xnn_compute_type_fp32:
status = xnn_create_constant_pad_nd_x32(
&node->params.static_pad.padding_value,
@@ -102,6 +110,19 @@ static enum xnn_status setup_constant_pad_operator(
threadpool);
break;
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_operator_type_constant_pad_nd_x16:
+ return xnn_setup_constant_pad_nd_x16(
+ opdata->operator_object,
+ opdata->shape1.num_dims,
+ opdata->shape1.dim,
+ opdata->pre_paddings,
+ opdata->post_paddings,
+ input_data,
+ output_data,
+ threadpool);
+ break;
+#endif // !defined(XNN_NO_F16_OPERATORS)
case xnn_operator_type_constant_pad_nd_x32:
return xnn_setup_constant_pad_nd_x32(
opdata->operator_object,