Diffstat (limited to 'av1/encoder/x86/av1_fwd_txfm2d_avx2.c')
-rw-r--r-- | av1/encoder/x86/av1_fwd_txfm2d_avx2.c | 328
1 file changed, 327 insertions(+), 1 deletion(-)
diff --git a/av1/encoder/x86/av1_fwd_txfm2d_avx2.c b/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
index fa5c66abf..b898fc60d 100644
--- a/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
+++ b/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
@@ -1574,6 +1574,332 @@ static const transform_1d_avx2 row_txfm16x16_arr[TX_TYPES] = {
   fadst16x16_new_avx2   // H_FLIPADST
 };
 
+static const transform_1d_sse2 col_txfm8x8_arr[TX_TYPES] = {
+  fdct8x8_new_sse2,       // DCT_DCT
+  fadst8x8_new_sse2,      // ADST_DCT
+  fdct8x8_new_sse2,       // DCT_ADST
+  fadst8x8_new_sse2,      // ADST_ADST
+  fadst8x8_new_sse2,      // FLIPADST_DCT
+  fdct8x8_new_sse2,       // DCT_FLIPADST
+  fadst8x8_new_sse2,      // FLIPADST_FLIPADST
+  fadst8x8_new_sse2,      // ADST_FLIPADST
+  fadst8x8_new_sse2,      // FLIPADST_ADST
+  fidentity8x8_new_sse2,  // IDTX
+  fdct8x8_new_sse2,       // V_DCT
+  fidentity8x8_new_sse2,  // H_DCT
+  fadst8x8_new_sse2,      // V_ADST
+  fidentity8x8_new_sse2,  // H_ADST
+  fadst8x8_new_sse2,      // V_FLIPADST
+  fidentity8x8_new_sse2,  // H_FLIPADST
+};
+
+static const transform_1d_sse2 row_txfm8x8_arr[TX_TYPES] = {
+  fdct8x8_new_sse2,       // DCT_DCT
+  fdct8x8_new_sse2,       // ADST_DCT
+  fadst8x8_new_sse2,      // DCT_ADST
+  fadst8x8_new_sse2,      // ADST_ADST
+  fdct8x8_new_sse2,       // FLIPADST_DCT
+  fadst8x8_new_sse2,      // DCT_FLIPADST
+  fadst8x8_new_sse2,      // FLIPADST_FLIPADST
+  fadst8x8_new_sse2,      // ADST_FLIPADST
+  fadst8x8_new_sse2,      // FLIPADST_ADST
+  fidentity8x8_new_sse2,  // IDTX
+  fidentity8x8_new_sse2,  // V_DCT
+  fdct8x8_new_sse2,       // H_DCT
+  fidentity8x8_new_sse2,  // V_ADST
+  fadst8x8_new_sse2,      // H_ADST
+  fidentity8x8_new_sse2,  // V_FLIPADST
+  fadst8x8_new_sse2       // H_FLIPADST
+};
+
+static INLINE void load_buffer_and_round_shift(const int16_t *in, int stride,
+                                               __m128i *out, int bit) {
+  out[0] = _mm_load_si128((const __m128i *)(in + 0 * stride));
+  out[1] = _mm_load_si128((const __m128i *)(in + 1 * stride));
+  out[2] = _mm_load_si128((const __m128i *)(in + 2 * stride));
+  out[3] = _mm_load_si128((const __m128i *)(in + 3 * stride));
+  out[4] = _mm_load_si128((const __m128i *)(in + 4 * stride));
+  out[5] = _mm_load_si128((const __m128i *)(in + 5 * stride));
+  out[6] = _mm_load_si128((const __m128i *)(in + 6 * stride));
+  out[7] = _mm_load_si128((const __m128i *)(in + 7 * stride));
+  out[0] = _mm_slli_epi16(out[0], bit);
+  out[1] = _mm_slli_epi16(out[1], bit);
+  out[2] = _mm_slli_epi16(out[2], bit);
+  out[3] = _mm_slli_epi16(out[3], bit);
+  out[4] = _mm_slli_epi16(out[4], bit);
+  out[5] = _mm_slli_epi16(out[5], bit);
+  out[6] = _mm_slli_epi16(out[6], bit);
+  out[7] = _mm_slli_epi16(out[7], bit);
+}
+
+static INLINE void load_buffer_and_flip_round_shift(const int16_t *in,
+                                                    int stride, __m128i *out,
+                                                    int bit) {
+  out[7] = load_16bit_to_16bit(in + 0 * stride);
+  out[6] = load_16bit_to_16bit(in + 1 * stride);
+  out[5] = load_16bit_to_16bit(in + 2 * stride);
+  out[4] = load_16bit_to_16bit(in + 3 * stride);
+  out[3] = load_16bit_to_16bit(in + 4 * stride);
+  out[2] = load_16bit_to_16bit(in + 5 * stride);
+  out[1] = load_16bit_to_16bit(in + 6 * stride);
+  out[0] = load_16bit_to_16bit(in + 7 * stride);
+  out[7] = _mm_slli_epi16(out[7], bit);
+  out[6] = _mm_slli_epi16(out[6], bit);
+  out[5] = _mm_slli_epi16(out[5], bit);
+  out[4] = _mm_slli_epi16(out[4], bit);
+  out[3] = _mm_slli_epi16(out[3], bit);
+  out[2] = _mm_slli_epi16(out[2], bit);
+  out[1] = _mm_slli_epi16(out[1], bit);
+  out[0] = _mm_slli_epi16(out[0], bit);
+}
+
+#define TRANSPOSE_8X8_AVX2() \
+  { \
+    /* aa0: 00 10 01 11 02 12 03 13 | 40 50 41 51 42 52 43 53*/ \
+    /* aa1: 04 14 05 15 06 16 07 17 | 44 54 45 55 46 56 47 57*/ \
+    /* aa2: 20 30 21 31 22 32 23 33 | 60 70 61 71 62 72 63 73*/ \
+    /* aa3: 24 34 25 35 26 36 27 37 | 64 74 65 75 66 76 67 77*/ \
+    const __m256i aa0 = _mm256_unpacklo_epi16(b0, b1); \
+    const __m256i aa1 = _mm256_unpackhi_epi16(b0, b1); \
+    const __m256i aa2 = _mm256_unpacklo_epi16(b2, b3); \
+    const __m256i aa3 = _mm256_unpackhi_epi16(b2, b3); \
+    /* Unpack 32 bit elements resulting in: */ \
+    /* bb0: 00 10 20 30 01 11 21 31 | 40 50 60 70 41 51 61 71*/ \
+    /* bb1: 02 12 22 32 03 13 23 33 | 42 52 62 72 43 53 63 73*/ \
+    /* bb2: 04 14 24 34 05 15 25 35 | 44 54 64 74 45 55 65 75*/ \
+    /* bb3: 06 16 26 36 07 17 27 37 | 46 56 66 76 47 57 67 77*/ \
+    const __m256i bb0 = _mm256_unpacklo_epi32(aa0, aa2); \
+    const __m256i bb1 = _mm256_unpackhi_epi32(aa0, aa2); \
+    const __m256i bb2 = _mm256_unpacklo_epi32(aa1, aa3); \
+    const __m256i bb3 = _mm256_unpackhi_epi32(aa1, aa3); \
+    /* Permute 64 bit elements resulting in: */ \
+    /* c0: 00 10 20 30 40 50 60 70 | 01 11 21 31 41 51 61 71*/ \
+    /* c1: 02 12 22 32 42 52 62 72 | 03 13 23 33 43 53 63 73*/ \
+    /* c2: 04 14 24 34 44 54 64 74 | 05 15 25 35 45 55 65 75*/ \
+    /* c3: 06 16 26 36 46 56 66 76 | 07 17 27 37 47 57 67 77*/ \
+    c0 = _mm256_permute4x64_epi64(bb0, 0xd8); \
+    c1 = _mm256_permute4x64_epi64(bb1, 0xd8); \
+    c2 = _mm256_permute4x64_epi64(bb2, 0xd8); \
+    c3 = _mm256_permute4x64_epi64(bb3, 0xd8); \
+  }
+
+static INLINE void transpose_round_shift_flip_8x8(__m128i *const in,
+                                                  __m128i *const out, int bit) {
+  __m256i c0, c1, c2, c3;
+  bit = -bit;
+  const __m256i rounding = _mm256_set1_epi16(1 << (bit - 1));
+  const __m256i s04 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[0]), in[4], 0x1);
+  const __m256i s15 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[1]), in[5], 0x1);
+  const __m256i s26 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[2]), in[6], 0x1);
+  const __m256i s37 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[3]), in[7], 0x1);
+
+  const __m256i a0 = _mm256_adds_epi16(s04, rounding);
+  const __m256i a1 = _mm256_adds_epi16(s15, rounding);
+  const __m256i a2 = _mm256_adds_epi16(s26, rounding);
+  const __m256i a3 = _mm256_adds_epi16(s37, rounding);
+
+  // b0: 00 01 02 03 04 05 06 07 | 40 41 42 43 44 45 46 47
+  // b1: 10 11 12 13 14 15 16 17 | 50 51 52 53 54 55 56 57
+  // b2: 20 21 22 23 24 25 26 27 | 60 61 62 63 64 65 66 67
+  // b3: 30 31 32 33 34 35 36 37 | 70 71 72 73 74 75 76 77
+  const __m256i b0 = _mm256_srai_epi16(a0, bit);
+  const __m256i b1 = _mm256_srai_epi16(a1, bit);
+  const __m256i b2 = _mm256_srai_epi16(a2, bit);
+  const __m256i b3 = _mm256_srai_epi16(a3, bit);
+
+  TRANSPOSE_8X8_AVX2()
+
+  // Extract the two 128-bit lanes of each register, resulting in:
+  // out[7]: 00 10 20 30 40 50 60 70
+  // out[6]: 01 11 21 31 41 51 61 71
+  // out[5]: 02 12 22 32 42 52 62 72
+  // out[4]: 03 13 23 33 43 53 63 73
+  // out[3]: 04 14 24 34 44 54 64 74
+  // out[2]: 05 15 25 35 45 55 65 75
+  // out[1]: 06 16 26 36 46 56 66 76
+  // out[0]: 07 17 27 37 47 57 67 77
+  out[7] = _mm256_castsi256_si128(c0);
+  out[6] = _mm256_extractf128_si256(c0, 1);
+  out[5] = _mm256_castsi256_si128(c1);
+  out[4] = _mm256_extractf128_si256(c1, 1);
+  out[3] = _mm256_castsi256_si128(c2);
+  out[2] = _mm256_extractf128_si256(c2, 1);
+  out[1] = _mm256_castsi256_si128(c3);
+  out[0] = _mm256_extractf128_si256(c3, 1);
+}
+
+static INLINE void transpose_round_shift_8x8(__m128i *const in,
+                                             __m128i *const out, int bit) {
+  __m256i c0, c1, c2, c3;
+  bit = -bit;
+  const __m256i rounding = _mm256_set1_epi16(1 << (bit - 1));
+  const __m256i s04 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[0]), in[4], 0x1);
+  const __m256i s15 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[1]), in[5], 0x1);
+  const __m256i s26 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[2]), in[6], 0x1);
+  const __m256i s37 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[3]), in[7], 0x1);
+
+  const __m256i a0 = _mm256_adds_epi16(s04, rounding);
+  const __m256i a1 = _mm256_adds_epi16(s15, rounding);
+  const __m256i a2 = _mm256_adds_epi16(s26, rounding);
+  const __m256i a3 = _mm256_adds_epi16(s37, rounding);
+
+  // b0: 00 01 02 03 04 05 06 07 | 40 41 42 43 44 45 46 47
+  // b1: 10 11 12 13 14 15 16 17 | 50 51 52 53 54 55 56 57
+  // b2: 20 21 22 23 24 25 26 27 | 60 61 62 63 64 65 66 67
+  // b3: 30 31 32 33 34 35 36 37 | 70 71 72 73 74 75 76 77
+  const __m256i b0 = _mm256_srai_epi16(a0, bit);
+  const __m256i b1 = _mm256_srai_epi16(a1, bit);
+  const __m256i b2 = _mm256_srai_epi16(a2, bit);
+  const __m256i b3 = _mm256_srai_epi16(a3, bit);
+
+  TRANSPOSE_8X8_AVX2()
+  // Extract the two 128-bit lanes of each register, resulting in:
+  // out[0]: 00 10 20 30 40 50 60 70
+  // out[1]: 01 11 21 31 41 51 61 71
+  // out[2]: 02 12 22 32 42 52 62 72
+  // out[3]: 03 13 23 33 43 53 63 73
+  // out[4]: 04 14 24 34 44 54 64 74
+  // out[5]: 05 15 25 35 45 55 65 75
+  // out[6]: 06 16 26 36 46 56 66 76
+  // out[7]: 07 17 27 37 47 57 67 77
+  out[0] = _mm256_castsi256_si128(c0);
+  out[1] = _mm256_extractf128_si256(c0, 1);
+  out[2] = _mm256_castsi256_si128(c1);
+  out[3] = _mm256_extractf128_si256(c1, 1);
+  out[4] = _mm256_castsi256_si128(c2);
+  out[5] = _mm256_extractf128_si256(c2, 1);
+  out[6] = _mm256_castsi256_si128(c3);
+  out[7] = _mm256_extractf128_si256(c3, 1);
+}
+
+static INLINE void transpose_16bit_and_store_8x8(const __m128i *const in,
+                                                 int32_t *output) {
+  // in[0]: 00 01 02 03 04 05 06 07
+  // in[1]: 10 11 12 13 14 15 16 17
+  // in[2]: 20 21 22 23 24 25 26 27
+  // in[3]: 30 31 32 33 34 35 36 37
+  // in[4]: 40 41 42 43 44 45 46 47
+  // in[5]: 50 51 52 53 54 55 56 57
+  // in[6]: 60 61 62 63 64 65 66 67
+  // in[7]: 70 71 72 73 74 75 76 77
+  // to:
+  // s04: 00 01 02 03 04 05 06 07 | 40 41 42 43 44 45 46 47
+  // s15: 10 11 12 13 14 15 16 17 | 50 51 52 53 54 55 56 57
+  // s26: 20 21 22 23 24 25 26 27 | 60 61 62 63 64 65 66 67
+  // s37: 30 31 32 33 34 35 36 37 | 70 71 72 73 74 75 76 77
+  const __m256i s04 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[0]), in[4], 0x1);
+  const __m256i s15 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[1]), in[5], 0x1);
+  const __m256i s26 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[2]), in[6], 0x1);
+  const __m256i s37 =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(in[3]), in[7], 0x1);
+
+  // a0: 00 10 01 11 02 12 03 13 | 40 50 41 51 42 52 43 53
+  // a1: 04 14 05 15 06 16 07 17 | 44 54 45 55 46 56 47 57
+  // a2: 20 30 21 31 22 32 23 33 | 60 70 61 71 62 72 63 73
+  // a3: 24 34 25 35 26 36 27 37 | 64 74 65 75 66 76 67 77
+  const __m256i a0 = _mm256_unpacklo_epi16(s04, s15);
+  const __m256i a1 = _mm256_unpackhi_epi16(s04, s15);
+  const __m256i a2 = _mm256_unpacklo_epi16(s26, s37);
+  const __m256i a3 = _mm256_unpackhi_epi16(s26, s37);
+
+  // Unpack 32 bit elements resulting in:
+  // b0: 00 10 20 30 01 11 21 31 | 40 50 60 70 41 51 61 71
+  // b1: 02 12 22 32 03 13 23 33 | 42 52 62 72 43 53 63 73
+  // b2: 04 14 24 34 05 15 25 35 | 44 54 64 74 45 55 65 75
+  // b3: 06 16 26 36 07 17 27 37 | 46 56 66 76 47 57 67 77
+  const __m256i b0 = _mm256_unpacklo_epi32(a0, a2);
+  const __m256i b1 = _mm256_unpackhi_epi32(a0, a2);
+  const __m256i b2 = _mm256_unpacklo_epi32(a1, a3);
+  const __m256i b3 = _mm256_unpackhi_epi32(a1, a3);
+
+  // Sign extend each 16-bit coefficient to 32 bits; the stored rows are:
+  // 00 10 20 30 40 50 60 70
+  // 01 11 21 31 41 51 61 71
+  // 02 12 22 32 42 52 62 72
+  // 03 13 23 33 43 53 63 73
+  // 04 14 24 34 44 54 64 74
+  // 05 15 25 35 45 55 65 75
+  // 06 16 26 36 46 56 66 76
+  // 07 17 27 37 47 57 67 77
+  const __m256i a_lo = _mm256_unpacklo_epi16(b0, b0);
+  const __m256i a_hi = _mm256_unpackhi_epi16(b0, b0);
+  const __m256i b_lo = _mm256_unpacklo_epi16(b1, b1);
+  const __m256i b_hi = _mm256_unpackhi_epi16(b1, b1);
+  const __m256i c_lo = _mm256_unpacklo_epi16(b2, b2);
+  const __m256i c_hi = _mm256_unpackhi_epi16(b2, b2);
+  const __m256i d_lo = _mm256_unpacklo_epi16(b3, b3);
+  const __m256i d_hi = _mm256_unpackhi_epi16(b3, b3);
+
+  const __m256i a_1 = _mm256_srai_epi32(a_lo, 16);
+  const __m256i a_2 = _mm256_srai_epi32(a_hi, 16);
+  const __m256i a_3 = _mm256_srai_epi32(b_lo, 16);
+  const __m256i a_4 = _mm256_srai_epi32(b_hi, 16);
+  const __m256i a_5 = _mm256_srai_epi32(c_lo, 16);
+  const __m256i a_6 = _mm256_srai_epi32(c_hi, 16);
+  const __m256i a_7 = _mm256_srai_epi32(d_lo, 16);
+  const __m256i a_8 = _mm256_srai_epi32(d_hi, 16);
+
+  _mm256_store_si256((__m256i *)output, a_1);
+  _mm256_store_si256((__m256i *)(output + 8), a_2);
+  _mm256_store_si256((__m256i *)(output + 16), a_3);
+  _mm256_store_si256((__m256i *)(output + 24), a_4);
+  _mm256_store_si256((__m256i *)(output + 32), a_5);
+  _mm256_store_si256((__m256i *)(output + 40), a_6);
+  _mm256_store_si256((__m256i *)(output + 48), a_7);
+  _mm256_store_si256((__m256i *)(output + 56), a_8);
+}
+
+static void av1_lowbd_fwd_txfm2d_8x8_avx2(const int16_t *input, int32_t *output,
+                                          int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf0[8], buf1[8], *buf;
+  const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X8];
+  const int txw_idx = get_txw_idx(TX_8X8);
+  const int txh_idx = get_txh_idx(TX_8X8);
+  const int cos_bit_col = av1_fwd_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = av1_fwd_cos_bit_row[txw_idx][txh_idx];
+  const transform_1d_sse2 col_txfm = col_txfm8x8_arr[tx_type];
+  const transform_1d_sse2 row_txfm = row_txfm8x8_arr[tx_type];
+  int ud_flip, lr_flip;
+
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  // The sign check on the shift value is skipped while round shifting by
+  // assuming that shift[0] is always positive.
+  assert(shift[0] > 0);
+  if (ud_flip)
+    load_buffer_and_flip_round_shift(input, stride, buf0, shift[0]);
+  else
+    load_buffer_and_round_shift(input, stride, buf0, shift[0]);
+
+  col_txfm(buf0, buf0, cos_bit_col);
+  // The sign check on the shift value is skipped while round shifting by
+  // assuming that shift[1] is always negative.
+  assert(shift[1] < 0);
+
+  if (lr_flip) {
+    transpose_round_shift_flip_8x8(buf0, buf1, shift[1]);
+  } else {
+    transpose_round_shift_8x8(buf0, buf1, shift[1]);
+  }
+
+  buf = buf1;
+  row_txfm(buf, buf, cos_bit_row);
+
+  // The final round shift is skipped here, as shift[2] is assumed to always
+  // be zero.
+  assert(shift[2] == 0);
+  transpose_16bit_and_store_8x8(buf, output);
+}
+
 static void lowbd_fwd_txfm2d_16x16_avx2(const int16_t *input, int32_t *output,
                                         int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -2781,7 +3107,7 @@ static void lowbd_fwd_txfm2d_16x8_avx2(const int16_t *input, int32_t *output,
 
 static FwdTxfm2dFunc fwd_txfm2d_func_ls[TX_SIZES_ALL] = {
   av1_lowbd_fwd_txfm2d_4x4_sse2,  // 4x4 transform
-  av1_lowbd_fwd_txfm2d_8x8_sse2,  // 8x8 transform
+  av1_lowbd_fwd_txfm2d_8x8_avx2,  // 8x8 transform
   lowbd_fwd_txfm2d_16x16_avx2,    // 16x16 transform
   lowbd_fwd_txfm2d_32x32_avx2,    // 32x32 transform
   lowbd_fwd_txfm2d_64x64_avx2,    // 64x64 transform
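
Below is a minimal, self-contained scalar sketch (not part of the patch) of the rounding convention the asserts in av1_lowbd_fwd_txfm2d_8x8_avx2 rely on: a positive shift[0] is applied as a plain left shift while loading (load_buffer_and_round_shift), and a negative shift[1] is applied as "add 1 << (-shift - 1), then arithmetic right shift", fused into the 8x8 transpose (transpose_round_shift_8x8). The scalar_* helper names, the main() driver, and the example shift values 2 and -1 are illustrative assumptions only; the real values come from av1_fwd_txfm_shift_ls[TX_8X8].

#include <stdint.h>
#include <stdio.h>

// Scalar model of load_buffer_and_round_shift: up-shift on load (shift > 0).
// Multiplication is used instead of << to avoid undefined behavior for
// negative inputs in plain C.
static void scalar_load_round_shift(const int16_t in[8][8], int16_t out[8][8],
                                    int shift) {
  for (int r = 0; r < 8; ++r)
    for (int c = 0; c < 8; ++c) out[r][c] = (int16_t)(in[r][c] * (1 << shift));
}

// Scalar model of transpose_round_shift_8x8: a rounding down-shift (shift < 0)
// fused with the transpose, out[c][r] = round(in[r][c] / 2^(-shift)).
// The SIMD code adds the rounding term with a saturating add
// (_mm256_adds_epi16); saturation is ignored here for clarity.
static void scalar_transpose_round_shift(const int16_t in[8][8],
                                         int16_t out[8][8], int shift) {
  const int bit = -shift;  // shift is negative, so bit > 0
  const int rounding = 1 << (bit - 1);
  for (int r = 0; r < 8; ++r)
    for (int c = 0; c < 8; ++c)
      out[c][r] = (int16_t)((in[r][c] + rounding) >> bit);
}

int main(void) {
  int16_t in[8][8], up[8][8], tr[8][8];
  for (int r = 0; r < 8; ++r)
    for (int c = 0; c < 8; ++c) in[r][c] = (int16_t)(r * 8 + c - 32);

  scalar_load_round_shift(in, up, 2);        // plays the role of shift[0] > 0
  scalar_transpose_round_shift(up, tr, -1);  // plays the role of shift[1] < 0

  printf("in[3][5]=%d up[3][5]=%d tr[5][3]=%d\n", in[3][5], up[3][5], tr[5][3]);
  return 0;
}

Keeping shift[0] strictly positive and shift[1] strictly negative is what lets the AVX2 path drop the per-call sign checks and fold the intermediate round shift into the transpose step, which is the design choice the two asserts in the new function document.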