/**
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#include <emmintrin.h>
#include <immintrin.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <aws/common/common.h>

/***** Decode logic *****/

/*
 * Translates a contiguous range of byte values.
 * For each byte of 'in' that is between lo and hi (inclusive), the corresponding byte of the
 * returned vector is (byte - lo) + offset; every other byte of the result is zero.
 */
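/*
 * For example, translate_range(in, 'A', 'Z', 1) maps 'A' to 1, 'B' to 2, ..., 'Z' to 26 and
 * zeroes every byte outside 'A'..'Z'; decode_vec() below builds its decode table exactly this way.
 */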
static inline __m256i translate_range(__m256i in, uint8_t lo, uint8_t hi, uint8_t offset) {
    __m256i lovec = _mm256_set1_epi8(lo);
    __m256i hivec = _mm256_set1_epi8((char)(hi - lo));
    __m256i offsetvec = _mm256_set1_epi8(offset);

    __m256i tmp = _mm256_sub_epi8(in, lovec);
    /*
     * we'll use the unsigned min operator to do our comparison. Note that
     * there's no unsigned compare as a comparison intrinsic.
     */
    __m256i mask = _mm256_min_epu8(tmp, hivec);
    /* if mask = tmp, then keep that byte */
    mask = _mm256_cmpeq_epi8(mask, tmp);

    tmp = _mm256_add_epi8(tmp, offsetvec);
    tmp = _mm256_and_si256(tmp, mask);
    return tmp;
}

/*
 * For each byte of 'in' that equals 'match', the corresponding byte of the returned vector is
 * 'decode'; every other byte of the result is zero.
 */
static inline __m256i translate_exact(__m256i in, uint8_t match, uint8_t decode) {
    __m256i mask = _mm256_cmpeq_epi8(in, _mm256_set1_epi8(match));
    return _mm256_and_si256(mask, _mm256_set1_epi8(decode));
}

/*
 * Input: a pointer to a 256-bit vector of base64 characters.
 * The pointed-to vector is replaced by a 256-bit vector of decoded 6-bit values.
 * Returns true on success, false if any input byte is not a valid base64 character.
 */
static inline bool decode_vec(__m256i *in) {
    __m256i tmp1, tmp2, tmp3;

    /*
     * Base64 decoding table, see RFC 4648.
     *
     * Note that we use multiple vector registers to try to allow the CPU to
     * parallelize the merging ORs.
     */
    tmp1 = translate_range(*in, 'A', 'Z', 0 + 1);
    tmp2 = translate_range(*in, 'a', 'z', 26 + 1);
    tmp3 = translate_range(*in, '0', '9', 52 + 1);
    tmp1 = _mm256_or_si256(tmp1, translate_exact(*in, '+', 62 + 1));
    tmp2 = _mm256_or_si256(tmp2, translate_exact(*in, '/', 63 + 1));
    tmp3 = _mm256_or_si256(tmp3, _mm256_or_si256(tmp1, tmp2));

    /*
     * We use 0 to mark decode failures, so everything is decoded to one higher
     * than normal. We'll shift this down now.
     */
    *in = _mm256_sub_epi8(tmp3, _mm256_set1_epi8(1));

    /* If any byte is now zero, we had a decode failure */
    __m256i mask = _mm256_cmpeq_epi8(tmp3, _mm256_set1_epi8(0));
    return _mm256_testz_si256(mask, mask);
}
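
/*
 * Worked example: the characters 'T', 'W', 'F', 'u' ("TWFu", RFC 4648's encoding of "Man")
 * decode to the 6-bit values 19, 22, 5, 46; pack_vec() below packs each such group of four
 * values into the three output bytes 'M', 'a', 'n'.
 */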

/* A 32-byte buffer with 32-byte alignment; used to stage shuffle-control vectors so they can
 * be loaded with the aligned _mm256_load_si256. */
AWS_ALIGNED_TYPEDEF(uint8_t, aligned256[32], 32);

/*
 * Input: a 256-bit vector, interpreted as 32 * 6-bit values
 * Output: a 256-bit vector, the lower 24 bytes of which contain the packed version of the input
 */
static inline __m256i pack_vec(__m256i in) {
    /*
     * Our basic strategy is to split the input vector into three vectors, for each 6-bit component
     * of each 24-bit group, shift the groups into place, then OR the vectors together. Conveniently,
     * we can do this on a (32 bit) dword-by-dword basis.
     *
     * It's important to note that we're interpreting the vector as being little-endian. That is,
     * on entry, we have dwords that look like this:
     *
     * MSB                                 LSB
     * 00DD DDDD 00CC CCCC 00BB BBBB 00AA AAAA
     *
     * And we want to translate to:
     *
     * MSB                                 LSB
     * 0000 0000 AAAA AABB BBBB CCCC CCDD DDDD
     *
     * After which point we can pack these dwords together to produce our final output.
     */
    __m256i maskA = _mm256_set1_epi32(0xFF); // low bits
    __m256i maskB = _mm256_set1_epi32(0xFF00);
    __m256i maskC = _mm256_set1_epi32(0xFF0000);
    __m256i maskD = _mm256_set1_epi32((int)0xFF000000);

    __m256i bitsA = _mm256_slli_epi32(_mm256_and_si256(in, maskA), 18);
    __m256i bitsB = _mm256_slli_epi32(_mm256_and_si256(in, maskB), 4);
    __m256i bitsC = _mm256_srli_epi32(_mm256_and_si256(in, maskC), 10);
    __m256i bitsD = _mm256_srli_epi32(_mm256_and_si256(in, maskD), 24);

    __m256i dwords = _mm256_or_si256(_mm256_or_si256(bitsA, bitsB), _mm256_or_si256(bitsC, bitsD));
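
    /*
     * Worked example: a dword holding the decoded values 19, 22, 5, 46 (base64 "TWFu",
     * RFC 4648's "Man") has now become 0x004D616E; reading that value from most to least
     * significant byte gives 0x00, 'M', 'a', 'n', and the byteswapping shuffle below is what
     * puts 'M', 'a', 'n' into memory order.
     */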
    /*
     * Now we have a series of dwords with empty MSBs.
     * We need to pack them together (and shift down) with a shuffle operation.
     * Unfortunately the shuffle operation operates independently within each 128-bit lane,
     * so we'll need to do this in two steps: First we compact dwords within each lane, then
     * we do a dword shuffle to compact the two lanes together.
     *
     * Within each lane, the byte shuffle below produces:
     *
     * 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 <- output byte index (little endian)
     * 0c 0d 0e 08 09 0a 04 05 06 00 01 02 -- -- -- -- <- source byte index ('--' = zeroed)
     *
     * Note that this also swaps the byte order within each 3-byte fragment: we constructed the
     * fragments as little-endian dword values, but the decoded bytes must be written to memory
     * most-significant byte first (big-endian order).
     */
    const aligned256 shufvec_buf = {
        /* clang-format off */
        /* LSB (byte 0 of the control vector comes first) */
        0xFF, 0xFF, 0xFF, 0xFF, /* Zero out the low 4 bytes of the lane */
        2,  1,  0,
        6,  5,  4,
        10,  9,  8,
        14, 13, 12,

        0xFF, 0xFF, 0xFF, 0xFF, /* Zero out the low 4 bytes of the lane */
        2,  1,  0,
        6,  5,  4,
        10,  9,  8,
        14, 13, 12
        /* MSB */
        /* clang-format on */
    };
    __m256i shufvec = _mm256_load_si256((__m256i const *)&shufvec_buf);

    dwords = _mm256_shuffle_epi8(dwords, shufvec);
    /*
     * Now shuffle the 32-bit words (shown most-significant dword first):
     * A B C 0 D E F 0 -> 0 0 A B C D E F
     */
    __m256i shuf32 = _mm256_set_epi32(0, 0, 7, 6, 5, 3, 2, 1);

    dwords = _mm256_permutevar8x32_epi32(dwords, shuf32);

    return dwords;
}

static inline bool decode(const unsigned char *in, unsigned char *out) {
    __m256i vec = _mm256_loadu_si256((__m256i const *)in);
    if (!decode_vec(&vec)) {
        return false;
    }
    vec = pack_vec(vec);

    /*
     * We write the low 128 bits with a vector store and the next 64 bits with a memcpy (24 bytes in total).
     * Input (memory order): 0 1 2 3 4 5 - - (dwords)
     * Input (little endian) - - 5 4 3 2 1 0
     * Output in memory:
     * [0 1 2 3] [4 5]
     */
    __m128i lo = _mm256_extracti128_si256(vec, 0);
    /*
     * Unfortunately some compilers don't support _mm256_extract_epi64,
     * so we'll just copy right out of the vector as a fallback
     */

#ifdef HAVE_MM256_EXTRACT_EPI64
    uint64_t hi = _mm256_extract_epi64(vec, 2);
    const uint64_t *p_hi = &hi;
#else
    const uint64_t *p_hi = (uint64_t *)&vec + 2;
#endif

    _mm_storeu_si128((__m128i *)out, lo);
    memcpy(out + 16, p_hi, sizeof(*p_hi));

    return true;
}

size_t aws_common_private_base64_decode_sse41(const unsigned char *in, unsigned char *out, size_t len) {
    if (len % 4) {
        return (size_t)-1;
    }

    size_t outlen = 0;
    while (len > 32) {
        if (!decode(in, out)) {
            return (size_t)-1;
        }
        len -= 32;
        in += 32;
        out += 24;
        outlen += 24;
    }

    if (len > 0) {
        unsigned char tmp_in[32];
        unsigned char tmp_out[24];

        memset(tmp_out, 0xEE, sizeof(tmp_out));

        /* We need to ensure the vector contains valid b64 characters */
        memset(tmp_in, 'A', sizeof(tmp_in));
        memcpy(tmp_in, in, len);

        size_t final_out = (3 * len) / 4;

        /* Check for end-of-string padding (up to 2 characters) */
        for (int i = 0; i < 2; i++) {
            if (tmp_in[len - 1] == '=') {
                tmp_in[len - 1] = 'A'; /* replace '=' with a valid character so decode() doesn't reject it */
                len--;
                final_out--;
            }
        }
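
        /*
         * For example, an 8-character tail that ends in a single '=' starts with final_out == 6
         * and leaves this loop with len == 7 and final_out == 5; the '=' slot now holds 'A',
         * which decodes to zero bits.
         */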

        if (!decode(tmp_in, tmp_out)) {
            return (size_t)-1;
        }

        /* Check that the bytes past the end of the decoded output are all zero (no stray bits) */
        for (size_t i = final_out; i < sizeof(tmp_out); i++) {
            if (tmp_out[i]) {
                return (size_t)-1;
            }
        }

        memcpy(out, tmp_out, final_out);
        outlen += final_out;
    }
    return outlen;
}
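
/*
 * Usage sketch (illustrative; callers normally reach this through the generic base64 API in
 * aws/common/encoding.h rather than invoking the CPU-specific routine directly):
 *
 *   unsigned char buf[12];
 *   size_t n = aws_common_private_base64_decode_sse41((const unsigned char *)"aGVsbG8gd29ybGQh", buf, 16);
 *   // n == 12 and buf holds "hello world!"; invalid input, or a length that is not a
 *   // multiple of 4, yields (size_t)-1.
 */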

/***** Encode logic *****/
static inline __m256i encode_chars(__m256i in) {
    __m256i tmp1, tmp2, tmp3;

    /*
     * Base64 encoding table, see RFC4648
     *
     * We again use fan-in for the ORs here.
     */
    tmp1 = translate_range(in, 0, 25, 'A');
    tmp2 = translate_range(in, 26, 26 + 25, 'a');
    tmp3 = translate_range(in, 52, 61, '0');
    tmp1 = _mm256_or_si256(tmp1, translate_exact(in, 62, '+'));
    tmp2 = _mm256_or_si256(tmp2, translate_exact(in, 63, '/'));

    return _mm256_or_si256(tmp3, _mm256_or_si256(tmp1, tmp2));
}
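
/*
 * For example, encode_chars() maps 0 to 'A', 25 to 'Z', 26 to 'a', 51 to 'z', 52 to '0',
 * 61 to '9', 62 to '+' and 63 to '/', i.e. the RFC 4648 alphabet.
 */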

/*
 * Input: A 256-bit vector, interpreted as 24 bytes (LSB) plus 8 bytes of high-byte padding
 * Output: A 256-bit vector of base64 characters
 */
static inline __m256i encode_stride(__m256i vec) {
    /*
     * First, since byte-shuffle operations operate within 128-bit subvectors, swap around the dwords
     * to balance the amount of actual data between 128-bit subvectors.
     * After this we want the LE representation to look like: -- XX XX XX -- XX XX XX
     */
    __m256i shuf32 = _mm256_set_epi32(7, 5, 4, 3, 6, 2, 1, 0);
    vec = _mm256_permutevar8x32_epi32(vec, shuf32);

    /*
     * Next, within each group of 3 bytes, we need to byteswap into little endian form so our bitshifts
     * will work properly. We also shuffle around so that each dword has one 3-byte group, plus one byte
     * (MSB) of zero-padding.
     * Because this is a byte-shuffle, indexes are within each 128-bit subvector.
     *
     * -- -- -- -- 11 10 09 08 07 06 05 04 03 02 01 00
     */

    const aligned256 shufvec_buf = {
        /* clang-format off */
        /* LSB (byte 0 of the control vector comes first) */
        2, 1, 0, 0xFF,
        5, 4, 3, 0xFF,
        8, 7, 6, 0xFF,
        11, 10, 9, 0xFF,

        2, 1, 0, 0xFF,
        5, 4, 3, 0xFF,
        8, 7, 6, 0xFF,
        11, 10, 9, 0xFF
        /* MSB */
        /* clang-format on */
    };
    vec = _mm256_shuffle_epi8(vec, _mm256_load_si256((__m256i const *)&shufvec_buf));

    /*
     * Now shift and mask to split out 6-bit groups.
     * We'll also do a second byteswap to get back into big-endian
     */
    __m256i mask0 = _mm256_set1_epi32(0x3F);
    __m256i mask1 = _mm256_set1_epi32(0x3F << 6);
    __m256i mask2 = _mm256_set1_epi32(0x3F << 12);
    __m256i mask3 = _mm256_set1_epi32(0x3F << 18);

    __m256i digit0 = _mm256_and_si256(mask0, vec);
    __m256i digit1 = _mm256_and_si256(mask1, vec);
    __m256i digit2 = _mm256_and_si256(mask2, vec);
    __m256i digit3 = _mm256_and_si256(mask3, vec);

    /*
     * Because we want to byteswap, the low-order digit0 goes into the
     * high-order byte
     */
    digit0 = _mm256_slli_epi32(digit0, 24);
    digit1 = _mm256_slli_epi32(digit1, 10);
    digit2 = _mm256_srli_epi32(digit2, 4);
    digit3 = _mm256_srli_epi32(digit3, 18);

    vec = _mm256_or_si256(_mm256_or_si256(digit0, digit1), _mm256_or_si256(digit2, digit3));
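
    /*
     * Worked example: the input group 'M', 'a', 'n' (dword value 0x004D616E after the byteswap
     * shuffle above) has now been split into the bytes 19, 22, 5, 46, which encode_chars()
     * turns into "TWFu" (RFC 4648's example).
     */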

    /* Finally translate to the base64 character set */
    return encode_chars(vec);
}

void aws_common_private_base64_encode_sse41(const uint8_t *input, uint8_t *output, size_t inlen) {
    __m256i instride, outstride;

    while (inlen >= 32) {
        /*
         * Where possible, we'll load a full vector at a time and ignore the over-read.
         * However, if we have < 32 bytes left, this would result in a potential read
         * of unreadable pages, so we use bounce buffers below.
         */
        instride = _mm256_loadu_si256((__m256i const *)input);
        outstride = encode_stride(instride);
        _mm256_storeu_si256((__m256i *)output, outstride);

        input += 24;
        output += 32;
        inlen -= 24;
    }
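
    /*
     * For example (illustrative), inlen == 70 runs two full strides above (48 bytes in, 64
     * characters out) and leaves 22 bytes for the bounce-buffer path below, which emits 32
     * characters ending in "==".
     */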

    while (inlen) {
        /*
         * We need to go through a bounce buffer for anything remaining, as we
         * don't want to over-read or over-write the ends of the buffers.
         */
        size_t stridelen = inlen > 24 ? 24 : inlen;
        size_t outlen = ((stridelen + 2) / 3) * 4;

        memset(&instride, 0, sizeof(instride));
        memcpy(&instride, input, stridelen);

        outstride = encode_stride(instride);
        memcpy(output, &outstride, outlen);

        if (inlen < 24) {
            if (inlen % 3 >= 1) {
                /* AA== or AAA= */
                output[outlen - 1] = '=';
            }
            if (inlen % 3 == 1) {
                /* AA== */
                output[outlen - 2] = '=';
            }

            return;
        }

        input += stridelen;
        output += outlen;
        inlen -= stridelen;
    }
}
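
/*
 * Usage sketch (illustrative; as with decode, callers normally go through the generic base64
 * API in aws/common/encoding.h rather than calling this routine directly). The output buffer
 * must hold 4 * ((inlen + 2) / 3) bytes; no NUL terminator is written:
 *
 *   uint8_t encoded[16];
 *   aws_common_private_base64_encode_sse41((const uint8_t *)"hello world!", encoded, 12);
 *   // encoded now holds the 16 characters "aGVsbG8gd29ybGQh"
 */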