path: root/networking/aarch64/chksum_simd.c
/*
 * AArch64-specific checksum implementation using NEON
 *
 * Copyright (c) 2020, Arm Limited.
 * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
 */

#include "networking.h"
#include "../chksum_common.h"

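/*
 * arm_neon.h requires Advanced SIMD support; if this translation unit is
 * not already being compiled with it (__ARM_NEON undefined), enable it
 * for this file via the target pragma.
 */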
#ifndef __ARM_NEON
#pragma GCC target("+simd")
#endif

#include <arm_neon.h>

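/*
 * Consume the bytes before the next 8-byte boundary so the rest of the
 * buffer can be processed from an 8-byte-aligned address.  The aligned
 * 64-bit word containing those bytes is loaded whole and the bytes that
 * precede the buffer are masked off (this relies on little-endian byte
 * order).  *pptr and *nbytes are updated to describe the remaining data;
 * the returned partial sum fits in 33 bits.
 */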
always_inline
static inline uint64_t
slurp_head64(const void **pptr, uint32_t *nbytes)
{
    Assert(*nbytes >= 8);
    uint64_t sum = 0;
    uint32_t off = (uintptr_t) *pptr % 8;
    if (likely(off != 0))
    {
	/* Get rid of bytes 0..off-1 */
	const unsigned char *ptr64 = align_ptr(*pptr, 8);
	uint64_t mask = ALL_ONES << (CHAR_BIT * off);
	uint64_t val = load64(ptr64) & mask;
	/* Fold 64-bit sum to 33 bits */
	sum = val >> 32;
	sum += (uint32_t) val;
	*pptr = ptr64 + 8;
	*nbytes -= 8 - off;
    }
    return sum;
}

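/*
 * Add the final 0..7 bytes to the running sum.  A full 64-bit word is
 * loaded and the bytes past the end of the buffer are masked off; ptr is
 * 8-byte aligned when called from below, so the over-read stays within
 * the buffer's last aligned word and cannot cross a page boundary.
 */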
always_inline
static inline uint64_t
slurp_tail64(uint64_t sum, const void *ptr, uint32_t nbytes)
{
    Assert(nbytes < 8);
    if (likely(nbytes != 0))
    {
	/* Get rid of bytes nbytes..7 */
	uint64_t mask = ALL_ONES >> (CHAR_BIT * (8 - nbytes));
	Assert(__builtin_popcountl(mask) / CHAR_BIT == nbytes);
	uint64_t val = load64(ptr) & mask;
	sum += val >> 32;
	sum += (uint32_t) val;
	nbytes = 0;
    }
    Assert(nbytes == 0);
    return sum;
}

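/*
 * Ones' complement (end-around carry) sum of a buffer, folded to 16 bits,
 * as used by the Internet checksum family.  Large buffers are summed 64
 * bytes per iteration into NEON accumulators; the 0..63 trailing bytes
 * are handled in progressively smaller chunks.  A buffer starting at an
 * odd address is summed as if it were aligned and the result is
 * byte-swapped at the end, which produces the same 16-bit ones'
 * complement sum (RFC 1071).
 */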
unsigned short
__chksum_aarch64_simd(const void *ptr, unsigned int nbytes)
{
    bool swap = (uintptr_t) ptr & 1;
    uint64_t sum;

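    /*
     * Short buffers: handled by the scalar helper in chksum_common.h,
     * which avoids the SIMD setup and alignment work below.
     */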
    if (unlikely(nbytes < 50))
    {
	sum = slurp_small(ptr, nbytes);
	swap = false;
	goto fold;
    }

    /* 8-byte align pointer */
    Assert(nbytes >= 8);
    sum = slurp_head64(&ptr, &nbytes);
    Assert(((uintptr_t) ptr & 7) == 0);

    const uint32_t *may_alias ptr32 = ptr;

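    /*
     * Four independent accumulators let the pairwise add-accumulate
     * instructions in the loop below overlap instead of serialising on a
     * single register dependency.
     */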
    uint64x2_t vsum0 = { 0, 0 };
    uint64x2_t vsum1 = { 0, 0 };
    uint64x2_t vsum2 = { 0, 0 };
    uint64x2_t vsum3 = { 0, 0 };

    /* Sum groups of 64 bytes */
    for (uint32_t i = 0; i < nbytes / 64; i++)
    {
	uint32x4_t vtmp0 = vld1q_u32(ptr32);
	uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
	uint32x4_t vtmp2 = vld1q_u32(ptr32 + 8);
	uint32x4_t vtmp3 = vld1q_u32(ptr32 + 12);
	vsum0 = vpadalq_u32(vsum0, vtmp0);
	vsum1 = vpadalq_u32(vsum1, vtmp1);
	vsum2 = vpadalq_u32(vsum2, vtmp2);
	vsum3 = vpadalq_u32(vsum3, vtmp3);
	ptr32 += 16;
    }
    nbytes %= 64;

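    /*
     * "Folding" reinterprets an accumulator's two 64-bit lanes as four
     * 32-bit lanes and pairwise-accumulates them into another
     * accumulator; the widening add cannot overflow, and the ones'
     * complement sum is preserved because 2^32 == 1 (mod 2^16 - 1).
     */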
    /* Fold vsum2 and vsum3 into vsum0 and vsum1 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum2));
    vsum1 = vpadalq_u32(vsum1, vreinterpretq_u32_u64(vsum3));

    /* Add any trailing group of 32 bytes */
    if (nbytes & 32)
    {
	uint32x4_t vtmp0 = vld1q_u32(ptr32);
	uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
	vsum0 = vpadalq_u32(vsum0, vtmp0);
	vsum1 = vpadalq_u32(vsum1, vtmp1);
	ptr32 += 8;
	nbytes -= 32;
    }
    Assert(nbytes < 32);

    /* Fold vsum1 into vsum0 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum1));

    /* Add any trailing group of 16 bytes */
    if (nbytes & 16)
    {
	uint32x4_t vtmp = vld1q_u32(ptr32);
	vsum0 = vpadalq_u32(vsum0, vtmp);
	ptr32 += 4;
	nbytes -= 16;
    }
    Assert(nbytes < 16);

    /* Add any trailing group of 8 bytes */
    if (nbytes & 8)
    {
	uint32x2_t vtmp = vld1_u32(ptr32);
	vsum0 = vaddw_u32(vsum0, vtmp);
	ptr32 += 2;
	nbytes -= 8;
    }
    Assert(nbytes < 8);

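    /* Reduce vsum0: treat its two 64-bit lanes as four 32-bit lanes, add
     * them across the vector (widening to 64 bits), then fold the high
     * and low halves into the scalar sum. */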
    uint64_t val = vaddlvq_u32(vreinterpretq_u32_u64(vsum0));
    sum += val >> 32;
    sum += (uint32_t) val;

    /* Handle any trailing 0..7 bytes */
    sum = slurp_tail64(sum, ptr32, nbytes);

fold:
    return fold_and_swap(sum, swap);
}
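
/*
 * Illustrative usage sketch, not part of the original file: the guard
 * macro CHKSUM_SIMD_EXAMPLE is hypothetical.  It feeds a sample 20-byte
 * IPv4 header (checksum field zeroed) to __chksum_aarch64_simd and, on
 * the assumption that the caller applies the final ones' complement as
 * in RFC 1071, prints the value that would go in the checksum field.
 */
#ifdef CHKSUM_SIMD_EXAMPLE
#include <stdio.h>

int
main(void)
{
    /* Sample IPv4 header with the checksum field (bytes 10..11) zeroed */
    static const unsigned char hdr[20] = {
	0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00,
	0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a, 0x63,
	0xac, 0x10, 0x0a, 0x0c
    };
    unsigned short sum = __chksum_aarch64_simd(hdr, sizeof(hdr));
    printf("folded sum     = 0x%04x\n", sum);
    printf("checksum field = 0x%04x\n", (unsigned short) ~sum);
    return 0;
}
#endif /* CHKSUM_SIMD_EXAMPLE */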