author     Marat Dukhan <maratek@google.com>  2022-02-04 04:05:35 -0800
committer  XNNPACK Team <xnnpack-github-robot@google.com>  2022-02-04 04:06:55 -0800
commit     cb872b09e8e4655e00efa22cc4fdf433ee06acbb (patch)
tree       e4899e32dca99a8236e85dbf62131c04e2977c1c
parent     2bd2bd2413a23903b5e34621a2c69ea0fd5b51b2 (diff)
Support Static Reshape for QS8/QU8 Tensors and in FP16 graph rewriting
PiperOrigin-RevId: 426366658
-rw-r--r--  src/subgraph.c                     1
-rw-r--r--  src/subgraph/static-reshape.c    127
2 files changed, 115 insertions, 13 deletions
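
For context, a minimal sketch of how the newly supported path could be exercised through the subgraph API: defining a Static Reshape node over a QS8 tensor, which this commit stops rejecting. The shapes, quantization parameters (zero point 0, scale 0.5), and external value IDs below are illustrative assumptions rather than values from the commit; the calls follow the public xnnpack.h declarations of xnn_define_quantized_tensor_value and xnn_define_static_reshape.

#include <stddef.h>
#include <stdint.h>
#include <xnnpack.h>

int main(void) {
  // Hypothetical shapes: reshape a 1x2x3x4 tensor into 1x24.
  const size_t input_dims[4] = {1, 2, 3, 4};
  const size_t new_shape[2] = {1, 24};

  if (xnn_initialize(/*allocator=*/NULL) != xnn_status_success) {
    return 1;
  }

  xnn_subgraph_t subgraph = NULL;
  if (xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph) != xnn_status_success) {
    return 1;
  }

  // QS8 input value; zero point and scale are placeholders and must match the output,
  // because the reshape is lowered to a plain copy (see static-reshape.c below).
  uint32_t input_id = XNN_INVALID_VALUE_ID;
  xnn_define_quantized_tensor_value(
    subgraph, xnn_datatype_qint8, /*zero_point=*/0, /*scale=*/0.5f,
    /*num_dims=*/4, input_dims, /*data=*/NULL, /*external_id=*/0,
    XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);

  // QS8 output value with the reshaped dimensions and identical quantization parameters.
  uint32_t output_id = XNN_INVALID_VALUE_ID;
  xnn_define_quantized_tensor_value(
    subgraph, xnn_datatype_qint8, /*zero_point=*/0, /*scale=*/0.5f,
    /*num_dims=*/2, new_shape, /*data=*/NULL, /*external_id=*/1,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);

  // Static Reshape over QS8 values: rejected before this commit, accepted after it.
  xnn_define_static_reshape(subgraph, /*num_dims=*/2, new_shape, input_id, output_id, /*flags=*/0);

  xnn_delete_subgraph(subgraph);
  return 0;
}
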
diff --git a/src/subgraph.c b/src/subgraph.c
index 18bdc4b48..342628f32 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -621,6 +621,7 @@ void xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
case xnn_node_type_max_pooling_2d:
case xnn_node_type_prelu:
case xnn_node_type_static_constant_pad:
+ case xnn_node_type_static_reshape:
break;
default:
xnn_log_info("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
diff --git a/src/subgraph/static-reshape.c b/src/subgraph/static-reshape.c
index c999f42b9..2437a968a 100644
--- a/src/subgraph/static-reshape.c
+++ b/src/subgraph/static-reshape.c
@@ -20,8 +20,6 @@ static enum xnn_status create_copy_operator(
size_t num_values,
struct xnn_operator_data* opdata)
{
- assert(node->compute_type == xnn_compute_type_fp32);
-
assert(node->num_inputs == 1);
const uint32_t input_id = node->inputs[0];
assert(input_id != XNN_INVALID_VALUE_ID);
@@ -32,10 +30,38 @@ static enum xnn_status create_copy_operator(
assert(output_id != XNN_INVALID_VALUE_ID);
assert(output_id < num_values);
- const enum xnn_status status = xnn_create_copy_nc_x32(
- 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
- node->flags,
- &opdata->operator_object);
+ enum xnn_status status;
+ switch (node->compute_type) {
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_compute_type_fp16:
+ status = xnn_create_copy_nc_x16(
+ 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
+ node->flags,
+ &opdata->operator_object);
+ break;
+#endif // !defined(XNN_NO_F16_OPERATORS)
+ case xnn_compute_type_fp32:
+ status = xnn_create_copy_nc_x32(
+ 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
+ node->flags,
+ &opdata->operator_object);
+ break;
+#ifndef XNN_NO_QS8_OPERATORS
+ case xnn_compute_type_qs8:
+#endif // !defined(XNN_NO_QS8_OPERATORS)
+#ifndef XNN_NO_QU8_OPERATORS
+ case xnn_compute_type_qu8:
+#endif // !defined(XNN_NO_QU8_OPERATORS)
+#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ status = xnn_create_copy_nc_x8(
+ 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
+ node->flags,
+ &opdata->operator_object);
+ break;
+#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ default:
+ XNN_UNREACHABLE;
+ }
if (status == xnn_status_success) {
opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape);
opdata->inputs[0] = input_id;
@@ -66,12 +92,38 @@ static enum xnn_status setup_copy_operator(
void* output_data = output_blob->data;
assert(output_data != NULL);
- return xnn_setup_copy_nc_x32(
- opdata->operator_object,
- opdata->batch_size,
- input_data,
- output_data,
- threadpool);
+ switch (opdata->operator_object->type) {
+#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ case xnn_operator_type_copy_nc_x8:
+ return xnn_setup_copy_nc_x8(
+ opdata->operator_object,
+ opdata->batch_size,
+ input_data,
+ output_data,
+ threadpool);
+ break;
+#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+#ifndef XNN_NO_F16_OPERATORS
+ case xnn_operator_type_copy_nc_x16:
+ return xnn_setup_copy_nc_x16(
+ opdata->operator_object,
+ opdata->batch_size,
+ input_data,
+ output_data,
+ threadpool);
+ break;
+#endif // !defined(XNN_NO_F16_OPERATORS)
+ case xnn_operator_type_copy_nc_x32:
+ return xnn_setup_copy_nc_x32(
+ opdata->operator_object,
+ opdata->batch_size,
+ input_data,
+ output_data,
+ threadpool);
+ break;
+ default:
+ XNN_UNREACHABLE;
+ }
}
enum xnn_status xnn_define_static_reshape(
@@ -105,6 +157,12 @@ enum xnn_status xnn_define_static_reshape(
switch (input_value->datatype) {
case xnn_datatype_fp32:
+#ifndef XNN_NO_QS8_OPERATORS
+ case xnn_datatype_qint8:
+#endif // !defined(XNN_NO_QS8_OPERATORS)
+#ifndef XNN_NO_QU8_OPERATORS
+ case xnn_datatype_quint8:
+#endif // !defined(XNN_NO_QU8_OPERATORS)
break;
default:
xnn_log_error(
@@ -129,9 +187,21 @@ enum xnn_status xnn_define_static_reshape(
return xnn_status_invalid_parameter;
}
+ enum xnn_compute_type compute_type = xnn_compute_type_invalid;
switch (output_value->datatype) {
case xnn_datatype_fp32:
+ compute_type = xnn_compute_type_fp32;
+ break;
+#ifndef XNN_NO_QS8_OPERATORS
+ case xnn_datatype_qint8:
+ compute_type = xnn_compute_type_qs8;
break;
+#endif // !defined(XNN_NO_QS8_OPERATORS)
+#ifndef XNN_NO_QU8_OPERATORS
+ case xnn_datatype_quint8:
+ compute_type = xnn_compute_type_qu8;
+ break;
+#endif // !defined(XNN_NO_QU8_OPERATORS)
default:
xnn_log_error(
"failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
@@ -140,6 +210,37 @@ enum xnn_status xnn_define_static_reshape(
return xnn_status_invalid_parameter;
}
+ if (input_value->datatype != output_value->datatype) {
+ xnn_log_error(
+ "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
+ ": mismatching datatypes across input (%s) and output (%s)",
+ xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
+ xnn_datatype_to_string(input_value->datatype),
+ xnn_datatype_to_string(output_value->datatype));
+ return xnn_status_invalid_parameter;
+ }
+
+#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
+ if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
+ if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
+ xnn_log_error(
+ "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
+ ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
+ xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
+ input_value->quantization.zero_point, output_value->quantization.zero_point);
+ return xnn_status_invalid_parameter;
+ }
+ if (input_value->quantization.scale != output_value->quantization.scale) {
+ xnn_log_error(
+ "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
+ ": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
+ xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
+ input_value->quantization.scale, output_value->quantization.scale);
+ return xnn_status_invalid_parameter;
+ }
+ }
+#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
+
if (num_dims > XNN_MAX_TENSOR_DIMS) {
xnn_log_error(
"failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported",
@@ -156,7 +257,7 @@ enum xnn_status xnn_define_static_reshape(
memcpy(&node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t));
node->type = xnn_node_type_static_reshape;
- node->compute_type = xnn_compute_type_fp32;
+ node->compute_type = compute_type;
node->num_inputs = 1;
node->inputs[0] = input_id;
node->num_outputs = 1;
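
A note on the new quantization checks above: under the usual affine quantization mapping, real_value = scale * (quantized_value - zero_point), and a Static Reshape lowered to xnn_create_copy_nc_x8 moves the quantized bytes unchanged. The copied values therefore keep their meaning only if input and output share the same scale and zero point, so the validation added to xnn_define_static_reshape rejects graphs where either parameter differs instead of silently reinterpreting the data.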