diff options
author | Marat Dukhan <maratek@google.com> | 2022-02-04 04:05:35 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-02-04 04:06:55 -0800 |
commit | cb872b09e8e4655e00efa22cc4fdf433ee06acbb (patch) | |
tree | e4899e32dca99a8236e85dbf62131c04e2977c1c | |
parent | 2bd2bd2413a23903b5e34621a2c69ea0fd5b51b2 (diff) | |
download | XNNPACK-cb872b09e8e4655e00efa22cc4fdf433ee06acbb.tar.gz |
Support Static Reshape for QS8/QU8 Tensors and in FP16 graph rewriting
PiperOrigin-RevId: 426366658
-rw-r--r-- | src/subgraph.c | 1 | ||||
-rw-r--r-- | src/subgraph/static-reshape.c | 127 |
2 files changed, 115 insertions, 13 deletions
diff --git a/src/subgraph.c b/src/subgraph.c index 18bdc4b48..342628f32 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -621,6 +621,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) case xnn_node_type_max_pooling_2d: case xnn_node_type_prelu: case xnn_node_type_static_constant_pad: + case xnn_node_type_static_reshape: break; default: xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference", diff --git a/src/subgraph/static-reshape.c b/src/subgraph/static-reshape.c index c999f42b9..2437a968a 100644 --- a/src/subgraph/static-reshape.c +++ b/src/subgraph/static-reshape.c @@ -20,8 +20,6 @@ static enum xnn_status create_copy_operator( size_t num_values, struct xnn_operator_data* opdata) { - assert(node->compute_type == xnn_compute_type_fp32); - assert(node->num_inputs == 1); const uint32_t input_id = node->inputs[0]; assert(input_id != XNN_INVALID_VALUE_ID); @@ -32,10 +30,38 @@ static enum xnn_status create_copy_operator( assert(output_id != XNN_INVALID_VALUE_ID); assert(output_id < num_values); - const enum xnn_status status = xnn_create_copy_nc_x32( - 1 /* channels */, 1 /* input stride */, 1 /* output stride */, - node->flags, - &opdata->operator_object); + enum xnn_status status; + switch (node->compute_type) { +#ifndef XNN_NO_F16_OPERATORS + case xnn_compute_type_fp16: + status = xnn_create_copy_nc_x16( + 1 /* channels */, 1 /* input stride */, 1 /* output stride */, + node->flags, + &opdata->operator_object); + break; +#endif // !defined(XNN_NO_F16_OPERATORS) + case xnn_compute_type_fp32: + status = xnn_create_copy_nc_x32( + 1 /* channels */, 1 /* input stride */, 1 /* output stride */, + node->flags, + &opdata->operator_object); + break; +#ifndef XNN_NO_QS8_OPERATORS + case xnn_compute_type_qs8: +#endif // !defined(XNN_NO_QS8_OPERATORS) +#ifndef XNN_NO_QU8_OPERATORS + case xnn_compute_type_qu8: +#endif // !defined(XNN_NO_QU8_OPERATORS) +#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS) + 
status = xnn_create_copy_nc_x8( + 1 /* channels */, 1 /* input stride */, 1 /* output stride */, + node->flags, + &opdata->operator_object); + break; +#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS) + default: + XNN_UNREACHABLE; + } if (status == xnn_status_success) { opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape); opdata->inputs[0] = input_id; @@ -66,12 +92,38 @@ static enum xnn_status setup_copy_operator( void* output_data = output_blob->data; assert(output_data != NULL); - return xnn_setup_copy_nc_x32( - opdata->operator_object, - opdata->batch_size, - input_data, - output_data, - threadpool); + switch (opdata->operator_object->type) { +#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS) + case xnn_operator_type_copy_nc_x8: + return xnn_setup_copy_nc_x8( + opdata->operator_object, + opdata->batch_size, + input_data, + output_data, + threadpool); + break; +#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS) +#ifndef XNN_NO_F16_OPERATORS + case xnn_operator_type_copy_nc_x16: + return xnn_setup_copy_nc_x16( + opdata->operator_object, + opdata->batch_size, + input_data, + output_data, + threadpool); + break; +#endif // !defined(XNN_NO_F16_OPERATORS) + case xnn_operator_type_copy_nc_x32: + return xnn_setup_copy_nc_x32( + opdata->operator_object, + opdata->batch_size, + input_data, + output_data, + threadpool); + break; + default: + XNN_UNREACHABLE; + } } enum xnn_status xnn_define_static_reshape( @@ -105,6 +157,12 @@ enum xnn_status xnn_define_static_reshape( switch (input_value->datatype) { case xnn_datatype_fp32: +#ifndef XNN_NO_QS8_OPERATORS + case xnn_datatype_qint8: +#endif // !defined(XNN_NO_QS8_OPERATORS) +#ifndef XNN_NO_QU8_OPERATORS + case xnn_datatype_quint8: +#endif // !defined(XNN_NO_QU8_OPERATORS) break; default: xnn_log_error( @@ -129,9 +187,21 @@ enum xnn_status xnn_define_static_reshape( return xnn_status_invalid_parameter; } + enum xnn_compute_type 
compute_type = xnn_compute_type_invalid; switch (output_value->datatype) { case xnn_datatype_fp32: + compute_type = xnn_compute_type_fp32; + break; +#ifndef XNN_NO_QS8_OPERATORS + case xnn_datatype_qint8: + compute_type = xnn_compute_type_qs8; break; +#endif // !defined(XNN_NO_QS8_OPERATORS) +#ifndef XNN_NO_QU8_OPERATORS + case xnn_datatype_quint8: + compute_type = xnn_compute_type_qu8; + break; +#endif // !defined(XNN_NO_QU8_OPERATORS) default: xnn_log_error( "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)", @@ -140,6 +210,37 @@ enum xnn_status xnn_define_static_reshape( return xnn_status_invalid_parameter; } + if (input_value->datatype != output_value->datatype) { + xnn_log_error( + "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32 + ": mismatching datatypes across input (%s) and output (%s)", + xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id, + xnn_datatype_to_string(input_value->datatype), + xnn_datatype_to_string(output_value->datatype)); + return xnn_status_invalid_parameter; + } + +#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS) + if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) { + if (input_value->quantization.zero_point != output_value->quantization.zero_point) { + xnn_log_error( + "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32 + ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")", + xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id, + input_value->quantization.zero_point, output_value->quantization.zero_point); + return xnn_status_invalid_parameter; + } + if (input_value->quantization.scale != output_value->quantization.scale) { + xnn_log_error( + "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32 + ": mismatching scale quantization parameter 
across input (%.7g) and output (%.7g)", + xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id, + input_value->quantization.scale, output_value->quantization.scale); + return xnn_status_invalid_parameter; + } + } +#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS) + if (num_dims > XNN_MAX_TENSOR_DIMS) { xnn_log_error( "failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported", @@ -156,7 +257,7 @@ enum xnn_status xnn_define_static_reshape( memcpy(&node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t)); node->type = xnn_node_type_static_reshape; - node->compute_type = xnn_compute_type_fp32; + node->compute_type = compute_type; node->num_inputs = 1; node->inputs[0] = input_id; node->num_outputs = 1; |