diff options
author | Marat Dukhan <maratek@google.com> | 2022-08-17 15:48:43 -0700 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2022-08-17 15:49:36 -0700 |
commit | c6a48f1bac4fd45d8e4daf0fc3e89bc2b5382416 (patch) | |
tree | 69aa47e8853d1e11048374ffe5d2103143c58193 /src | |
parent | f7ffbc7e3f56d6b4f4fd5cf5a09edb4a0a02b714 (diff) | |
download | XNNPACK-c6a48f1bac4fd45d8e4daf0fc3e89bc2b5382416.tar.gz |
Specialize binary elementwise operation task for 1D-4D cases
PiperOrigin-RevId: 468312516
Diffstat (limited to 'src')
-rw-r--r-- | src/operator-run.c | 46 | ||||
-rw-r--r-- | src/operators/binary-elementwise-nd.c | 46 | ||||
-rw-r--r-- | src/xnnpack/compute.h | 12 |
3 files changed, 95 insertions, 9 deletions
```diff
diff --git a/src/operator-run.c b/src/operator-run.c
index 9cea3625d..ce3b4d362 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -969,6 +969,52 @@ void xnn_compute_pad_5d(
   }
 }
 
+void xnn_compute_elementwise_binary_1d(
+    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+    size_t i)
+{
+  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]);
+  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]);
+  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]);
+  context->ukernel(context->elements, a, b, y, &context->params);
+}
+
+void xnn_compute_elementwise_binary_2d(
+    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+    size_t i, size_t j)
+{
+  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]);
+  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]);
+  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]);
+  context->ukernel(context->elements, a, b, y, &context->params);
+}
+
+void xnn_compute_elementwise_binary_3d(
+    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+    size_t i, size_t j, size_t k)
+{
+  const void* a = (const void*) ((uintptr_t) context->a +
+    i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]);
+  const void* b = (const void*) ((uintptr_t) context->b +
+    i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]);
+  void* y = (void*) ((uintptr_t) context->y +
+    i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]);
+  context->ukernel(context->elements, a, b, y, &context->params);
+}
+
+void xnn_compute_elementwise_binary_4d(
+    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+    size_t i, size_t j, size_t k, size_t l)
+{
+  const void* a = (const void*) ((uintptr_t) context->a +
+    i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]);
+  const void* b = (const void*) ((uintptr_t) context->b +
+    i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]);
+  void* y = (void*) ((uintptr_t) context->y +
+    i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]);
+  context->ukernel(context->elements, a, b, y, &context->params);
+}
+
 void xnn_compute_elementwise_binary_5d(
     const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
     size_t i, size_t j, size_t k, size_t l, size_t m)
diff --git a/src/operators/binary-elementwise-nd.c b/src/operators/binary-elementwise-nd.c
index 03ddc8205..014827519 100644
--- a/src/operators/binary-elementwise-nd.c
+++ b/src/operators/binary-elementwise-nd.c
@@ -1004,15 +1004,43 @@ static enum xnn_status setup_binary_elementwise_nd(
     y_stride *= compressed_output_shape[i];
   }
 
-  binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
-  binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
-  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
-  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
-  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
-  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
-  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
-  binary_elementwise_op->compute.tile[0] = 1;
-  binary_elementwise_op->compute.tile[1] = 1;
+  if (compressed_output_shape[5] == 1) {
+    if (compressed_output_shape[4] == 1) {
+      if (compressed_output_shape[3] == 1) {
+        if (compressed_output_shape[2] == 1) {
+          binary_elementwise_op->compute.type = xnn_parallelization_type_1d;
+          binary_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_elementwise_binary_1d;
+          binary_elementwise_op->compute.range[0] = compressed_output_shape[1];
+        } else {
+          binary_elementwise_op->compute.type = xnn_parallelization_type_2d;
+          binary_elementwise_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_elementwise_binary_2d;
+          binary_elementwise_op->compute.range[0] = compressed_output_shape[2];
+          binary_elementwise_op->compute.range[1] = compressed_output_shape[1];
+        }
+      } else {
+        binary_elementwise_op->compute.type = xnn_parallelization_type_3d;
+        binary_elementwise_op->compute.task_3d = (pthreadpool_task_3d_t) xnn_compute_elementwise_binary_3d;
+        binary_elementwise_op->compute.range[0] = compressed_output_shape[3];
+        binary_elementwise_op->compute.range[1] = compressed_output_shape[2];
+        binary_elementwise_op->compute.range[2] = compressed_output_shape[1];
+      }
+    } else {
+      binary_elementwise_op->compute.type = xnn_parallelization_type_4d;
+      binary_elementwise_op->compute.task_4d = (pthreadpool_task_4d_t) xnn_compute_elementwise_binary_4d;
+      binary_elementwise_op->compute.range[0] = compressed_output_shape[4];
+      binary_elementwise_op->compute.range[1] = compressed_output_shape[3];
+      binary_elementwise_op->compute.range[2] = compressed_output_shape[2];
+      binary_elementwise_op->compute.range[3] = compressed_output_shape[1];
+    }
+  } else {
+    binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
+    binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
+    binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
+    binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
+    binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
+    binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
+    binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
+  }
 
   binary_elementwise_op->state = xnn_run_state_ready;
   return xnn_status_success;
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index db6ff72ce..7370b143f 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -839,6 +839,18 @@ struct elementwise_binary_context {
 };
 
 #ifndef __cplusplus
+  XNN_PRIVATE void xnn_compute_elementwise_binary_1d(
+      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+      size_t i);
+  XNN_PRIVATE void xnn_compute_elementwise_binary_2d(
+      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+      size_t i, size_t j);
+  XNN_PRIVATE void xnn_compute_elementwise_binary_3d(
+      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+      size_t i, size_t j, size_t k);
+  XNN_PRIVATE void xnn_compute_elementwise_binary_4d(
+      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+      size_t i, size_t j, size_t k, size_t l);
   XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
       const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
       size_t i, size_t j, size_t k, size_t l, size_t m);
```