aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/restricted/aws/s2n/tls/s2n_handshake.c
blob: 922cd67c2b678c2328451c358f796b0f3cfef63e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

#include <stdint.h>

#include "error/s2n_errno.h"

#include "tls/s2n_connection.h"
#include "tls/s2n_record.h"
#include "tls/s2n_cipher_suites.h"
#include "tls/s2n_tls.h"

#include "stuffer/s2n_stuffer.h"

#include "utils/s2n_safety.h"
#include "utils/s2n_map.h"

/*
 * Start a handshake message in `out`: emit the one-byte message type followed
 * by a three-byte length field that is left as zero for the time being.
 *
 * Reports S2N_ERR_HANDSHAKE_STATE if `out` still has data available, i.e. a
 * header may only be started on a stuffer with nothing pending.
 * Returns 0 on success; a failing stuffer write is propagated through GUARD.
 */
int s2n_handshake_write_header(struct s2n_stuffer *out, uint8_t message_type)
{
    /* Placeholder for the length field; the real length is not known yet */
    const uint16_t blank_length = 0;

    S2N_ERROR_IF(s2n_stuffer_data_available(out), S2N_ERR_HANDSHAKE_STATE);

    GUARD(s2n_stuffer_write_uint8(out, message_type));
    GUARD(s2n_stuffer_write_uint24(out, blank_length));

    return 0;
}

/*
 * Fill in the length field of the handshake message held in `out`.
 *
 * The payload length is everything currently available in `out` minus the
 * TLS_HANDSHAKE_HEADER_LENGTH header bytes. The handshake length field is
 * 24 bits wide, so the byte counts are kept in 32-bit variables: a 16-bit
 * variable would silently wrap for messages of 64 KiB or more and produce a
 * header (and final write position) that no longer matches the message.
 *
 * Reports S2N_ERR_SIZE_MISMATCH if `out` holds less than a full header.
 * Returns 0 on success; stuffer failures are propagated through GUARD.
 */
int s2n_handshake_finish_header(struct s2n_stuffer *out)
{
    uint32_t length = s2n_stuffer_data_available(out);
    S2N_ERROR_IF(length < TLS_HANDSHAKE_HEADER_LENGTH, S2N_ERR_SIZE_MISMATCH);

    uint32_t payload = length - TLS_HANDSHAKE_HEADER_LENGTH;

    /* Revisit the header: step over the one-byte message type, then write the length */
    GUARD(s2n_stuffer_rewrite(out));
    GUARD(s2n_stuffer_skip_write(out, 1));
    GUARD(s2n_stuffer_write_uint24(out, payload));

    /* Move the write position forward again over the payload that is already in place */
    GUARD(s2n_stuffer_skip_write(out, payload));

    return 0;
}

/*
 * Read one handshake message header from conn->handshake.io.
 *
 * On success *message_type receives the leading byte and *length the 24-bit
 * value that follows it.
 * Reports S2N_ERR_SIZE_MISMATCH when fewer than TLS_HANDSHAKE_HEADER_LENGTH
 * bytes are available. Returns 0 on success.
 */
int s2n_handshake_parse_header(struct s2n_connection *conn, uint8_t * message_type, uint32_t * length)
{
    struct s2n_stuffer *io = &conn->handshake.io;

    /* Refuse to start parsing unless a complete header is buffered */
    S2N_ERROR_IF(s2n_stuffer_data_available(io) < TLS_HANDSHAKE_HEADER_LENGTH, S2N_ERR_SIZE_MISMATCH);

    GUARD(s2n_stuffer_read_uint8(io, message_type));
    GUARD(s2n_stuffer_read_uint24(io, length));

    return 0;
}

/*
 * Point *hash_state at the connection's handshake hash state for `hash_alg`.
 *
 * Supported algorithms: MD5, SHA1, SHA224, SHA256, SHA384, SHA512 and the
 * combined MD5_SHA1. Any other value is reported as
 * S2N_ERR_HASH_INVALID_ALGORITHM. Returns 0 on success.
 */
static int s2n_handshake_get_hash_state_ptr(struct s2n_connection *conn, s2n_hash_algorithm hash_alg, struct s2n_hash_state **hash_state)
{
    notnull_check(conn);

    struct s2n_handshake *handshake = &conn->handshake;

    switch (hash_alg) {
    case S2N_HASH_MD5:      *hash_state = &handshake->md5;      break;
    case S2N_HASH_SHA1:     *hash_state = &handshake->sha1;     break;
    case S2N_HASH_SHA224:   *hash_state = &handshake->sha224;   break;
    case S2N_HASH_SHA256:   *hash_state = &handshake->sha256;   break;
    case S2N_HASH_SHA384:   *hash_state = &handshake->sha384;   break;
    case S2N_HASH_SHA512:   *hash_state = &handshake->sha512;   break;
    case S2N_HASH_MD5_SHA1: *hash_state = &handshake->md5_sha1; break;
    default:
        /* No handshake hash is kept for this algorithm */
        S2N_ERROR(S2N_ERR_HASH_INVALID_ALGORITHM);
    }

    return 0;
}

/*
 * Reset the connection's handshake hash state for `hash_alg`.
 * Returns 0 on success; lookup or reset failures are propagated through GUARD.
 */
int s2n_handshake_reset_hash_state(struct s2n_connection *conn, s2n_hash_algorithm hash_alg)
{
    struct s2n_hash_state *target = NULL;

    GUARD(s2n_handshake_get_hash_state_ptr(conn, hash_alg, &target));
    GUARD(s2n_hash_reset(target));

    return 0;
}

/*
 * Copy the connection's current handshake hash state for `hash_alg` into the
 * caller-supplied `hash_state`.
 *
 * NOTE: this is a plain struct assignment. If the underlying digest
 * implementation uses the EVP API, only the pointers to the EVP ctx and md
 * are copied, so the caller ends up holding a reference rather than an
 * independent value. Use s2n_hash_copy() before working with the returned
 * hash_state to avoid modifying the connection's own state.
 *
 * Returns 0 on success.
 */
int s2n_handshake_get_hash_state(struct s2n_connection *conn, s2n_hash_algorithm hash_alg, struct s2n_hash_state *hash_state)
{
    struct s2n_hash_state *source = NULL;

    notnull_check(hash_state);
    GUARD(s2n_handshake_get_hash_state_ptr(conn, hash_alg, &source));

    /* Shallow, member-wise copy; see the note above */
    *hash_state = *source;

    return 0;
}

/*
 * Flag every hash algorithm as required for this handshake by filling the
 * whole required_hash_algs table with 1. Always returns 0.
 */
int s2n_handshake_require_all_hashes(struct s2n_handshake *handshake)
{
    const size_t table_size = sizeof(handshake->required_hash_algs);

    memset(handshake->required_hash_algs, 1, table_size);

    return 0;
}

/*
 * Flag a single hash algorithm as required for this handshake.
 * `hash_alg` is used directly as the index into required_hash_algs and is not
 * range-checked here. Always returns 0.
 */
static int s2n_handshake_require_hash(struct s2n_handshake *handshake, s2n_hash_algorithm hash_alg)
{
    handshake->required_hash_algs[hash_alg] = 1;

    return 0;
}

/*
 * Report whether `hash_alg` is currently flagged as required for this
 * handshake. Returns the stored flag: non-zero when required, 0 otherwise.
 * `hash_alg` is used directly as the table index and is not range-checked here.
 */
uint8_t s2n_handshake_is_hash_required(struct s2n_handshake *handshake, s2n_hash_algorithm hash_alg)
{
    const uint8_t required = handshake->required_hash_algs[hash_alg];

    return required;
}

/* Update the required handshake hash algs depending on current handshake session state.
 * This function must called at the end of a handshake message handler. Additionally it must be called after the
 * ClientHello or ServerHello is processed in client and server mode respectively. The relevant handshake parameters
 * are not available until those messages are processed.
 */
int s2n_conn_update_required_handshake_hashes(struct s2n_connection *conn)
{
    /* Clear all of the required hashes */
    memset(conn->handshake.required_hash_algs, 0, sizeof(conn->handshake.required_hash_algs));

    message_type_t handshake_message = s2n_conn_get_current_message_type(conn);
    const uint8_t client_cert_verify_done = (handshake_message >= CLIENT_CERT_VERIFY) ? 1 : 0;
    s2n_cert_auth_type client_cert_auth_type;
    GUARD(s2n_connection_get_client_auth_type(conn, &client_cert_auth_type));

    /* If client authentication is possible, all hashes are needed until we're past CLIENT_CERT_VERIFY. */
    if ((client_cert_auth_type != S2N_CERT_AUTH_NONE) && !client_cert_verify_done) {
        GUARD(s2n_handshake_require_all_hashes(&conn->handshake));
        return 0;
    }

    /* We don't need all of the hashes. Set the hash alg(s) required for the PRF */
    switch (conn->actual_protocol_version) {
    case S2N_SSLv3:
    case S2N_TLS10:
    case S2N_TLS11:
        GUARD(s2n_handshake_require_hash(&conn->handshake, S2N_HASH_MD5));
        GUARD(s2n_handshake_require_hash(&conn->handshake, S2N_HASH_SHA1));
        break;
    case S2N_TLS12:
        /* fall through */
    case S2N_TLS13:
    {
        /* For TLS 1.2 and TLS 1.3, the cipher suite defines the PRF hash alg */
        s2n_hmac_algorithm prf_alg = conn->secure.cipher_suite->prf_alg;
        s2n_hash_algorithm hash_alg;
        GUARD(s2n_hmac_hash_alg(prf_alg, &hash_alg));
        GUARD(s2n_handshake_require_hash(&conn->handshake, hash_alg));
        break;
    }
    }

    return 0;
}

/*
 * Derive the single "simple" wildcard domain name that matches a hostname.
 *
 * The result is meant to be compared directly against a wildcard domain in a
 * certificate. A deliberately restrictive definition of wildcard is used so
 * that every input hostname maps to exactly one wildcard representation:
 * the first (left-most) DNS label is replaced by a single '*'. Embedded or
 * trailing wildcards and multi-level wildcard matching are not supported.
 *
 * Example:
 * - my.domain.name -> *.domain.name
 *
 * Not supported:
 * - my.domain.name -> m*.domain.name
 * - my.domain.name -> my.*.name
 * etc.
 *
 * Why the constrained definition:
 * - Support for issuing non-simple wildcard certificates is insignificant.
 * - Certificate selection can be implemented with a constant number of
 *   lookups (two).
 *
 * Data is consumed from `hostname_stuffer` and the wildcard name is written
 * to `output`. If nothing is left in `hostname_stuffer` after the first label
 * has been skipped, `output` is left untouched. Returns 0 in both cases;
 * stuffer failures are propagated through GUARD.
 */
int s2n_create_wildcard_hostname(struct s2n_stuffer *hostname_stuffer, struct s2n_stuffer *output)
{
    /* Move past the first label, up to its terminating '.' */
    GUARD(s2n_stuffer_skip_to_char(hostname_stuffer, '.'));

    /* Only when something follows the first label is there a wildcard to build */
    if (s2n_stuffer_data_available(hostname_stuffer) != 0) {
        /* The wildcard character stands in for the first label ... */
        GUARD(s2n_stuffer_write_uint8(output, '*'));

        /* ... and the remainder of the input is carried over unchanged */
        GUARD(s2n_stuffer_copy(hostname_stuffer, output, s2n_stuffer_data_available(hostname_stuffer)));
    }

    return 0;
}

/*
 * Look up `dns_name` in `domain_name_to_cert_map`.
 *
 * When the name is present, the per-type certificates stored for it are
 * copied into `matches` (S2N_CERT_TYPE_COUNT entries) and *match_exists is
 * set to 1. When it is absent, neither output is modified.
 * Returns 0 in both cases; a failed map lookup is propagated through
 * GUARD_AS_POSIX.
 */
static int s2n_find_cert_matches(struct s2n_map *domain_name_to_cert_map,
        struct s2n_blob *dns_name,
        struct s2n_cert_chain_and_key *matches[S2N_CERT_TYPE_COUNT],
        uint8_t *match_exists)
{
    struct s2n_blob map_value;
    bool key_found = false;
    GUARD_AS_POSIX(s2n_map_lookup(domain_name_to_cert_map, dns_name, &map_value, &key_found));

    if (!key_found) {
        return 0;
    }

    /* The map value is treated as a certs_by_type record */
    struct certs_by_type *entry = (void *) map_value.data;
    for (int cert_type = 0; cert_type < S2N_CERT_TYPE_COUNT; cert_type++) {
        matches[cert_type] = entry->certs[cert_type];
    }
    *match_exists = 1;

    return 0;
}

/* Find certificates that match the ServerName TLS extension sent by the client.
 * For a given ServerName there can be multiple matching certificates based on the
 * type of key in the certificate.
 *
 * A match is determined using s2n_map lookup by DNS name.
 * Wildcards that have a single * in the left most label are supported.
 */
int s2n_conn_find_name_matching_certs(struct s2n_connection *conn)
{
    if (!s2n_server_received_server_name(conn)) {
        return 0;
    }
    const char *name = conn->server_name;
    struct s2n_blob hostname_blob = { .data = (uint8_t *) (uintptr_t) name, .size = strlen(name) };
    lte_check(hostname_blob.size, S2N_MAX_SERVER_NAME);
    char normalized_hostname[S2N_MAX_SERVER_NAME + 1] = { 0 };
    memcpy_check(normalized_hostname, hostname_blob.data, hostname_blob.size);
    struct s2n_blob normalized_name = { .data = (uint8_t *) normalized_hostname, .size = hostname_blob.size };
    GUARD(s2n_blob_char_to_lower(&normalized_name));
    struct s2n_stuffer normalized_hostname_stuffer;
    GUARD(s2n_stuffer_init(&normalized_hostname_stuffer, &normalized_name));
    GUARD(s2n_stuffer_skip_write(&normalized_hostname_stuffer, normalized_name.size));

    /* Find the exact matches for the ServerName */
    GUARD(s2n_find_cert_matches(conn->config->domain_name_to_cert_map,
                &normalized_name,
                conn->handshake_params.exact_sni_matches,
                &(conn->handshake_params.exact_sni_match_exists)));

    if (!conn->handshake_params.exact_sni_match_exists) {
        /* We have not yet found an exact domain match. Try to find wildcard matches. */
        char wildcard_hostname[S2N_MAX_SERVER_NAME + 1] = { 0 };
        struct s2n_blob wildcard_blob = { .data = (uint8_t *) wildcard_hostname, .size = sizeof(wildcard_hostname) };
        struct s2n_stuffer wildcard_stuffer;
        GUARD(s2n_stuffer_init(&wildcard_stuffer, &wildcard_blob));
        GUARD(s2n_create_wildcard_hostname(&normalized_hostname_stuffer, &wildcard_stuffer));
        const uint32_t wildcard_len = s2n_stuffer_data_available(&wildcard_stuffer);

        /* Couldn't create a valid wildcard from the input */
        if (wildcard_len == 0) {
            return 0;
        }

        /* The client's SNI is wildcardified, do an exact match against the set of server certs. */
        wildcard_blob.size = wildcard_len;
        GUARD(s2n_find_cert_matches(conn->config->domain_name_to_cert_map,
                    &wildcard_blob,
                    conn->handshake_params.wc_sni_matches,
                    &(conn->handshake_params.wc_sni_match_exists)));
    }

    /* If we found a suitable cert, we should send back the ServerName extension.
     * Note that this may have already been set by the client hello callback, so we won't override its value
     */
    conn->server_name_used = conn->server_name_used
        || conn->handshake_params.exact_sni_match_exists
        || conn->handshake_params.wc_sni_match_exists;

    return 0;
}

/*
 * Pick the best certificate chain of the given type for this connection.
 *
 * Order of preference:
 * 1. Certificates that match the client's ServerName extension exactly.
 * 2. Certificates that match the ServerName through a wildcard.
 * 3. The default certificate for the key type.
 *
 * Once a name match of either kind exists, only that set is consulted, so the
 * result may be NULL when the name matched but no certificate of `cert_type`
 * (the cipher suite's authentication type) is present in the matching set.
 */
struct s2n_cert_chain_and_key *s2n_get_compatible_cert_chain_and_key(struct s2n_connection *conn, const s2n_pkey_type cert_type)
{
    if (conn->handshake_params.exact_sni_match_exists) {
        return conn->handshake_params.exact_sni_matches[cert_type];
    }

    if (conn->handshake_params.wc_sni_match_exists) {
        return conn->handshake_params.wc_sni_matches[cert_type];
    }

    /* No name matches at all: use the default certificate that works with the key type */
    return conn->config->default_certs_by_type.certs[cert_type];
}