aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Marat Dukhan <maratek@google.com> 2022-08-17 15:48:43 -0700
committer XNNPACK Team <xnnpack-github-robot@google.com> 2022-08-17 15:49:36 -0700
commit c6a48f1bac4fd45d8e4daf0fc3e89bc2b5382416 (patch)
tree 69aa47e8853d1e11048374ffe5d2103143c58193 /src
parent f7ffbc7e3f56d6b4f4fd5cf5a09edb4a0a02b714 (diff)
download XNNPACK-c6a48f1bac4fd45d8e4daf0fc3e89bc2b5382416.tar.gz
Specialize binary elementwise operation task for 1D-4D cases
PiperOrigin-RevId: 468312516
Diffstat (limited to 'src')
-rw-r--r-- src/operator-run.c 46
-rw-r--r-- src/operators/binary-elementwise-nd.c 46
-rw-r--r-- src/xnnpack/compute.h 12
3 files changed, 95 insertions, 9 deletions
diff --git a/src/operator-run.c b/src/operator-run.c
index 9cea3625d..ce3b4d362 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -969,6 +969,52 @@ void xnn_compute_pad_5d(
}
}
+void xnn_compute_elementwise_binary_1d( // One-index task: invokes context->ukernel once for index i.
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i)
+{
+ const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]); // first input: base address plus i * a_stride[4], computed on the integer address
+ const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]); // second input: same offsetting with b_stride[4]
+ void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]); // output: same offsetting with y_stride[4]
+ context->ukernel(context->elements, a, b, y, &context->params); // context->elements is forwarded unchanged; presumably the size of the run the ukernel processes -- TODO confirm units
+}
+
+void xnn_compute_elementwise_binary_2d( // Two-index task: invokes context->ukernel once for the index pair (i, j).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j)
+{
+ const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]); // first input: i scales stride[3], j scales stride[4]
+ const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]); // second input: same index-to-stride pairing
+ void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]); // output: same index-to-stride pairing
+ context->ukernel(context->elements, a, b, y, &context->params); // context->elements is forwarded unchanged; presumably the size of the run the ukernel processes -- TODO confirm units
+}
+
+void xnn_compute_elementwise_binary_3d( // Three-index task: invokes context->ukernel once for the index triple (i, j, k).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j, size_t k)
+{
+ const void* a = (const void*) ((uintptr_t) context->a + // first input: offset is accumulated on the integer address
+ i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]); // i, j, k scale stride[2], stride[3], stride[4] respectively
+ const void* b = (const void*) ((uintptr_t) context->b + // second input: same index-to-stride pairing
+ i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]);
+ void* y = (void*) ((uintptr_t) context->y + // output: same index-to-stride pairing
+ i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]);
+ context->ukernel(context->elements, a, b, y, &context->params); // context->elements is forwarded unchanged; presumably the size of the run the ukernel processes -- TODO confirm units
+}
+
+void xnn_compute_elementwise_binary_4d( // Four-index task: invokes context->ukernel once for the index tuple (i, j, k, l).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j, size_t k, size_t l)
+{
+ const void* a = (const void*) ((uintptr_t) context->a + // first input: offset is accumulated on the integer address
+ i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]); // i, j, k, l scale stride[1]..stride[4] respectively
+ const void* b = (const void*) ((uintptr_t) context->b + // second input: same index-to-stride pairing
+ i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]);
+ void* y = (void*) ((uintptr_t) context->y + // output: same index-to-stride pairing
+ i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]);
+ context->ukernel(context->elements, a, b, y, &context->params); // context->elements is forwarded unchanged; presumably the size of the run the ukernel processes -- TODO confirm units
+}
+
void xnn_compute_elementwise_binary_5d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l, size_t m)
diff --git a/src/operators/binary-elementwise-nd.c b/src/operators/binary-elementwise-nd.c
index 03ddc8205..014827519 100644
--- a/src/operators/binary-elementwise-nd.c
+++ b/src/operators/binary-elementwise-nd.c
@@ -1004,15 +1004,43 @@ static enum xnn_status setup_binary_elementwise_nd(
y_stride *= compressed_output_shape[i];
}
- binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
- binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
- binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
- binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
- binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
- binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
- binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
- binary_elementwise_op->compute.tile[0] = 1;
- binary_elementwise_op->compute.tile[1] = 1;
+ if (compressed_output_shape[5] == 1) { // Pick the task with the fewest indices: extents 5, 4, 3, 2 are tested in turn for being 1.
+ if (compressed_output_shape[4] == 1) { // Extents 4 and 5 are 1: at most three indices are needed.
+ if (compressed_output_shape[3] == 1) { // Extents 3..5 are 1: at most two indices are needed.
+ if (compressed_output_shape[2] == 1) { // Extents 2..5 are 1: a single index is enough.
+ binary_elementwise_op->compute.type = xnn_parallelization_type_1d;
+ binary_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_elementwise_binary_1d;
+ binary_elementwise_op->compute.range[0] = compressed_output_shape[1]; // The only range is extent 1; extent 0 is not used as a range here.
+ } else { // Extent 2 is not 1: two indices.
+ binary_elementwise_op->compute.type = xnn_parallelization_type_2d;
+ binary_elementwise_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_elementwise_binary_2d;
+ binary_elementwise_op->compute.range[0] = compressed_output_shape[2]; // Ranges run from the highest used extent down to extent 1.
+ binary_elementwise_op->compute.range[1] = compressed_output_shape[1];
+ }
+ } else { // Extent 3 is not 1: three indices.
+ binary_elementwise_op->compute.type = xnn_parallelization_type_3d;
+ binary_elementwise_op->compute.task_3d = (pthreadpool_task_3d_t) xnn_compute_elementwise_binary_3d;
+ binary_elementwise_op->compute.range[0] = compressed_output_shape[3]; // Ranges run from extent 3 down to extent 1.
+ binary_elementwise_op->compute.range[1] = compressed_output_shape[2];
+ binary_elementwise_op->compute.range[2] = compressed_output_shape[1];
+ }
+ } else { // Extent 4 is not 1: four indices.
+ binary_elementwise_op->compute.type = xnn_parallelization_type_4d;
+ binary_elementwise_op->compute.task_4d = (pthreadpool_task_4d_t) xnn_compute_elementwise_binary_4d;
+ binary_elementwise_op->compute.range[0] = compressed_output_shape[4]; // Ranges run from extent 4 down to extent 1.
+ binary_elementwise_op->compute.range[1] = compressed_output_shape[3];
+ binary_elementwise_op->compute.range[2] = compressed_output_shape[2];
+ binary_elementwise_op->compute.range[3] = compressed_output_shape[1];
+ }
+ } else { // Extent 5 is not 1: the five-index task, with ranges from extent 5 down to extent 1.
+ binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
+ binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
+ binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
+ binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
+ binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
+ binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
+ binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
+ }
binary_elementwise_op->state = xnn_run_state_ready;
return xnn_status_success;
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index db6ff72ce..7370b143f 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -839,6 +839,18 @@ struct elementwise_binary_context {
};
#ifndef __cplusplus
+ XNN_PRIVATE void xnn_compute_elementwise_binary_1d( // Binary elementwise task taking one index (i).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i);
+ XNN_PRIVATE void xnn_compute_elementwise_binary_2d( // Binary elementwise task taking two indices (i, j).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j);
+ XNN_PRIVATE void xnn_compute_elementwise_binary_3d( // Binary elementwise task taking three indices (i, j, k).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j, size_t k);
+ XNN_PRIVATE void xnn_compute_elementwise_binary_4d( // Binary elementwise task taking four indices (i, j, k, l).
+ const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
+ size_t i, size_t j, size_t k, size_t l);
XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l, size_t m);