aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFrank Barchard <fbarchard@google.com>2022-08-20 09:28:36 -0700
committerXNNPACK Team <xnnpack-github-robot@google.com>2022-08-20 09:29:17 -0700
commit3fd2d48a6bb392fbe033c03ae82f190ba7d186a9 (patch)
treef136a47681df521f797f78440d3cf4a76e19f1de /src
parenta141684389c9561ecf9e8cf063d59da167495726 (diff)
downloadXNNPACK-3fd2d48a6bb392fbe033c03ae82f190ba7d186a9.tar.gz
Specialized M1 variant of bfly4 scalar
PiperOrigin-RevId: 468906177
Diffstat (limited to 'src')
-rw-r--r--src/cs16-bfly4/gen/scalar-m1-x1.c97
-rw-r--r--src/cs16-bfly4/gen/scalar-x1.c3
-rw-r--r--src/cs16-bfly4/gen/scalar-x2.c3
-rw-r--r--src/cs16-bfly4/gen/scalar-x3.c3
-rw-r--r--src/cs16-bfly4/gen/scalar-x4.c3
-rw-r--r--src/cs16-bfly4/scalar.c.in57
-rw-r--r--src/xnnpack/fft.h1
7 files changed, 142 insertions, 25 deletions
diff --git a/src/cs16-bfly4/gen/scalar-m1-x1.c b/src/cs16-bfly4/gen/scalar-m1-x1.c
new file mode 100644
index 000000000..26361abd7
--- /dev/null
+++ b/src/cs16-bfly4/gen/scalar-m1-x1.c
@@ -0,0 +1,97 @@
+// Auto-generated file. Do not edit!
+// Template: src/cs16-bfly4/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2022 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/fft.h>
+
+
+void xnn_cs16_bfly4m1_ukernel__scalar_x1(
+ size_t samples,
+ int16_t* data,
+ const size_t stride,
+ const int16_t* twiddle) {
+
+ int16_t* out0 = data;
+ int16_t* out1 = data + samples * 2;
+ int16_t* out2 = data + samples * 4;
+ int16_t* out3 = data + samples * 6;
+
+ assert(samples == 1);
+ assert(data != NULL);
+ assert(stride != 0);
+ assert(twiddle != NULL);
+
+
+ if XNN_UNLIKELY(samples != 0) {
+ do {
+ int32_t vout0r = (int32_t) out0[0];
+ int32_t vout0i = (int32_t) out0[1];
+ int32_t vout1r = (int32_t) out1[0];
+ int32_t vout1i = (int32_t) out1[1];
+ int32_t vout2r = (int32_t) out2[0];
+ int32_t vout2i = (int32_t) out2[1];
+ int32_t vout3r = (int32_t) out3[0];
+ int32_t vout3i = (int32_t) out3[1];
+
+
+ // Note 32767 / 4 = 8191. Should be 8192.
+ vout0r = math_asr_s32(vout0r * 8191 + 16384, 15);
+ vout0i = math_asr_s32(vout0i * 8191 + 16384, 15);
+ vout1r = math_asr_s32(vout1r * 8191 + 16384, 15);
+ vout1i = math_asr_s32(vout1i * 8191 + 16384, 15);
+ vout2r = math_asr_s32(vout2r * 8191 + 16384, 15);
+ vout2i = math_asr_s32(vout2i * 8191 + 16384, 15);
+ vout3r = math_asr_s32(vout3r * 8191 + 16384, 15);
+ vout3i = math_asr_s32(vout3i * 8191 + 16384, 15);
+
+ const int32_t vtmp0r = math_asr_s32(vout1r * 32767 + 16384, 15);
+ const int32_t vtmp0i = math_asr_s32(vout1i * 32767 + 16384, 15);
+ const int32_t vtmp1r = math_asr_s32(vout2r * 32767 + 16384, 15);
+ const int32_t vtmp1i = math_asr_s32(vout2i * 32767 + 16384, 15);
+ const int32_t vtmp2r = math_asr_s32(vout3r * 32767 + 16384, 15);
+ const int32_t vtmp2i = math_asr_s32(vout3i * 32767 + 16384, 15);
+
+ const int32_t vtmp5r = vout0r - vtmp1r;
+ const int32_t vtmp5i = vout0i - vtmp1i;
+ vout0r += vtmp1r;
+ vout0i += vtmp1i;
+ const int32_t vtmp3r = vtmp0r + vtmp2r;
+ const int32_t vtmp3i = vtmp0i + vtmp2i;
+ const int32_t vtmp4r = vtmp0r - vtmp2r;
+ const int32_t vtmp4i = vtmp0i - vtmp2i;
+ vout2r = vout0r - vtmp3r;
+ vout2i = vout0i - vtmp3i;
+
+ vout0r += vtmp3r;
+ vout0i += vtmp3i;
+
+ vout1r = vtmp5r + vtmp4i;
+ vout1i = vtmp5i - vtmp4r;
+ vout3r = vtmp5r - vtmp4i;
+ vout3i = vtmp5i + vtmp4r;
+
+ out0[0] = (int16_t) vout0r;
+ out0[1] = (int16_t) vout0i;
+ out1[0] = (int16_t) vout1r;
+ out1[1] = (int16_t) vout1i;
+ out2[0] = (int16_t) vout2r;
+ out2[1] = (int16_t) vout2i;
+ out3[0] = (int16_t) vout3r;
+ out3[1] = (int16_t) vout3i;
+ out0 += 2;
+ out1 += 2;
+ out2 += 2;
+ out3 += 2;
+ } while(--samples != 0);
+ }
+}
diff --git a/src/cs16-bfly4/gen/scalar-x1.c b/src/cs16-bfly4/gen/scalar-x1.c
index b15829038..89dd4b8eb 100644
--- a/src/cs16-bfly4/gen/scalar-x1.c
+++ b/src/cs16-bfly4/gen/scalar-x1.c
@@ -14,6 +14,7 @@
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
+
void xnn_cs16_bfly4_ukernel__scalar_x1(
size_t samples,
int16_t* data,
@@ -29,9 +30,9 @@ void xnn_cs16_bfly4_ukernel__scalar_x1(
int16_t* out3 = data + samples * 6;
assert(samples != 0);
+ assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
- assert(data != NULL);
if XNN_UNLIKELY(samples != 0) {
diff --git a/src/cs16-bfly4/gen/scalar-x2.c b/src/cs16-bfly4/gen/scalar-x2.c
index aba94b9f3..99bd8653e 100644
--- a/src/cs16-bfly4/gen/scalar-x2.c
+++ b/src/cs16-bfly4/gen/scalar-x2.c
@@ -14,6 +14,7 @@
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
+
void xnn_cs16_bfly4_ukernel__scalar_x2(
size_t samples,
int16_t* data,
@@ -29,9 +30,9 @@ void xnn_cs16_bfly4_ukernel__scalar_x2(
int16_t* out3 = data + samples * 6;
assert(samples != 0);
+ assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
- assert(data != NULL);
for (; samples >= 2; samples -= 2) {
int32_t vout0r0 = (int32_t) out0[0];
diff --git a/src/cs16-bfly4/gen/scalar-x3.c b/src/cs16-bfly4/gen/scalar-x3.c
index eeacc671c..fcb487110 100644
--- a/src/cs16-bfly4/gen/scalar-x3.c
+++ b/src/cs16-bfly4/gen/scalar-x3.c
@@ -14,6 +14,7 @@
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
+
void xnn_cs16_bfly4_ukernel__scalar_x3(
size_t samples,
int16_t* data,
@@ -29,9 +30,9 @@ void xnn_cs16_bfly4_ukernel__scalar_x3(
int16_t* out3 = data + samples * 6;
assert(samples != 0);
+ assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
- assert(data != NULL);
for (; samples >= 3; samples -= 3) {
int32_t vout0r0 = (int32_t) out0[0];
diff --git a/src/cs16-bfly4/gen/scalar-x4.c b/src/cs16-bfly4/gen/scalar-x4.c
index 39f1850f7..ac7117c1e 100644
--- a/src/cs16-bfly4/gen/scalar-x4.c
+++ b/src/cs16-bfly4/gen/scalar-x4.c
@@ -14,6 +14,7 @@
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
+
void xnn_cs16_bfly4_ukernel__scalar_x4(
size_t samples,
int16_t* data,
@@ -29,9 +30,9 @@ void xnn_cs16_bfly4_ukernel__scalar_x4(
int16_t* out3 = data + samples * 6;
assert(samples != 0);
+ assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
- assert(data != NULL);
for (; samples >= 4; samples -= 4) {
int32_t vout0r0 = (int32_t) out0[0];
diff --git a/src/cs16-bfly4/scalar.c.in b/src/cs16-bfly4/scalar.c.in
index 486d96587..4f31c277f 100644
--- a/src/cs16-bfly4/scalar.c.in
+++ b/src/cs16-bfly4/scalar.c.in
@@ -11,24 +11,30 @@ $assert SAMPLE_TILE >= 1
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
-void xnn_cs16_bfly4_ukernel__scalar_x${SAMPLE_TILE}(
+
+$VARIANT = "m%s" % M if M else ""
+void xnn_cs16_bfly4${VARIANT}_ukernel__scalar_x${SAMPLE_TILE}(
size_t samples,
int16_t* data,
const size_t stride,
const int16_t* twiddle) {
- const int16_t* tw1 = twiddle;
- const int16_t* tw2 = tw1;
- const int16_t* tw3 = tw1;
+ $if M != 1:
+ const int16_t* tw1 = twiddle;
+ const int16_t* tw2 = tw1;
+ const int16_t* tw3 = tw1;
int16_t* out0 = data;
int16_t* out1 = data + samples * 2;
int16_t* out2 = data + samples * 4;
int16_t* out3 = data + samples * 6;
- assert(samples != 0);
+ $if M != 0:
+ assert(samples == ${M});
+ $else:
+ assert(samples != 0);
+ assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
- assert(data != NULL);
$if SAMPLE_TILE > 1:
for (; samples >= ${SAMPLE_TILE}; samples -= ${SAMPLE_TILE}) {
@@ -151,15 +157,16 @@ void xnn_cs16_bfly4_ukernel__scalar_x${SAMPLE_TILE}(
int32_t vout3r = (int32_t) out3[0];
int32_t vout3i = (int32_t) out3[1];
- const int32_t vtw1r = (const int32_t) tw1[0];
- const int32_t vtw1i = (const int32_t) tw1[1];
- const int32_t vtw2r = (const int32_t) tw2[0];
- const int32_t vtw2i = (const int32_t) tw2[1];
- const int32_t vtw3r = (const int32_t) tw3[0];
- const int32_t vtw3i = (const int32_t) tw3[1];
- tw1 += stride * 2;
- tw2 += stride * 4;
- tw3 += stride * 6;
+ $if M != 1:
+ const int32_t vtw1r = (const int32_t) tw1[0];
+ const int32_t vtw1i = (const int32_t) tw1[1];
+ const int32_t vtw2r = (const int32_t) tw2[0];
+ const int32_t vtw2i = (const int32_t) tw2[1];
+ const int32_t vtw3r = (const int32_t) tw3[0];
+ const int32_t vtw3i = (const int32_t) tw3[1];
+ tw1 += stride * 2;
+ tw2 += stride * 4;
+ tw3 += stride * 6;
// Note 32767 / 4 = 8191. Should be 8192.
vout0r = math_asr_s32(vout0r * 8191 + 16384, 15);
@@ -171,12 +178,20 @@ void xnn_cs16_bfly4_ukernel__scalar_x${SAMPLE_TILE}(
vout3r = math_asr_s32(vout3r * 8191 + 16384, 15);
vout3i = math_asr_s32(vout3i * 8191 + 16384, 15);
- const int32_t vtmp0r = math_asr_s32(vout1r * vtw1r - vout1i * vtw1i + 16384, 15);
- const int32_t vtmp0i = math_asr_s32(vout1r * vtw1i + vout1i * vtw1r + 16384, 15);
- const int32_t vtmp1r = math_asr_s32(vout2r * vtw2r - vout2i * vtw2i + 16384, 15);
- const int32_t vtmp1i = math_asr_s32(vout2r * vtw2i + vout2i * vtw2r + 16384, 15);
- const int32_t vtmp2r = math_asr_s32(vout3r * vtw3r - vout3i * vtw3i + 16384, 15);
- const int32_t vtmp2i = math_asr_s32(vout3r * vtw3i + vout3i * vtw3r + 16384, 15);
+ $if M == 1:
+ const int32_t vtmp0r = math_asr_s32(vout1r * 32767 + 16384, 15);
+ const int32_t vtmp0i = math_asr_s32(vout1i * 32767 + 16384, 15);
+ const int32_t vtmp1r = math_asr_s32(vout2r * 32767 + 16384, 15);
+ const int32_t vtmp1i = math_asr_s32(vout2i * 32767 + 16384, 15);
+ const int32_t vtmp2r = math_asr_s32(vout3r * 32767 + 16384, 15);
+ const int32_t vtmp2i = math_asr_s32(vout3i * 32767 + 16384, 15);
+ $else:
+ const int32_t vtmp0r = math_asr_s32(vout1r * vtw1r - vout1i * vtw1i + 16384, 15);
+ const int32_t vtmp0i = math_asr_s32(vout1r * vtw1i + vout1i * vtw1r + 16384, 15);
+ const int32_t vtmp1r = math_asr_s32(vout2r * vtw2r - vout2i * vtw2i + 16384, 15);
+ const int32_t vtmp1i = math_asr_s32(vout2r * vtw2i + vout2i * vtw2r + 16384, 15);
+ const int32_t vtmp2r = math_asr_s32(vout3r * vtw3r - vout3i * vtw3i + 16384, 15);
+ const int32_t vtmp2i = math_asr_s32(vout3r * vtw3i + vout3i * vtw3r + 16384, 15);
const int32_t vtmp5r = vout0r - vtmp1r;
const int32_t vtmp5i = vout0i - vtmp1i;
diff --git a/src/xnnpack/fft.h b/src/xnnpack/fft.h
index 1b673bafd..50bb7a2e9 100644
--- a/src/xnnpack/fft.h
+++ b/src/xnnpack/fft.h
@@ -26,6 +26,7 @@ DECLARE_CS16_BFLY4_UKERNEL_FUNCTION(xnn_cs16_bfly4_ukernel__scalar_x1)
DECLARE_CS16_BFLY4_UKERNEL_FUNCTION(xnn_cs16_bfly4_ukernel__scalar_x2)
DECLARE_CS16_BFLY4_UKERNEL_FUNCTION(xnn_cs16_bfly4_ukernel__scalar_x3)
DECLARE_CS16_BFLY4_UKERNEL_FUNCTION(xnn_cs16_bfly4_ukernel__scalar_x4)
+DECLARE_CS16_BFLY4_UKERNEL_FUNCTION(xnn_cs16_bfly4m1_ukernel__scalar_x1)
#define DECLARE_CS16_FFTR_UKERNEL_FUNCTION(fn_name) \
XNN_INTERNAL void fn_name( \