aboutsummaryrefslogblamecommitdiffstats
path: root/contrib/restricted/aws/s2n/tls/s2n_handshake.c
blob: 922cd67c2b678c2328451c358f796b0f3cfef63e (plain) (tree)









































































































































































































































































































































                                                                                                                                         
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

#include <stdint.h>

#include "error/s2n_errno.h"

#include "tls/s2n_connection.h"
#include "tls/s2n_record.h"
#include "tls/s2n_cipher_suites.h"
#include "tls/s2n_tls.h"

#include "stuffer/s2n_stuffer.h"

#include "utils/s2n_safety.h"
#include "utils/s2n_map.h"

/* Begin a handshake message in an empty stuffer: emit the one-byte message type
 * followed by a three-byte length placeholder. The placeholder is filled in later
 * by s2n_handshake_finish_header() once the message body has been appended.
 */
int s2n_handshake_write_header(struct s2n_stuffer *out, uint8_t message_type)
{
    /* The header has to be the very first thing in the stuffer. */
    S2N_ERROR_IF(s2n_stuffer_data_available(out), S2N_ERR_HANDSHAKE_STATE);

    GUARD(s2n_stuffer_write_uint8(out, message_type));

    /* Zero-valued 24-bit length; the real value is not known yet. */
    GUARD(s2n_stuffer_write_uint24(out, 0));

    return 0;
}

/* Back-fill the 24-bit length field of a handshake message that was started with
 * s2n_handshake_write_header() and whose body has since been appended to out.
 * On success the stuffer holds the same bytes as before, with the length field
 * now equal to the number of body bytes.
 */
int s2n_handshake_finish_header(struct s2n_stuffer *out)
{
    /* Must be at least 32 bits wide: the handshake length field is 24 bits, so a
     * 16-bit local silently truncated any message larger than 65535 bytes (for
     * example, a long certificate chain). That produced a wrong length and left the
     * write cursor in the wrong place. */
    uint32_t length = s2n_stuffer_data_available(out);
    S2N_ERROR_IF(length < TLS_HANDSHAKE_HEADER_LENGTH, S2N_ERR_SIZE_MISMATCH);

    uint32_t payload = length - TLS_HANDSHAKE_HEADER_LENGTH;

    /* The on-the-wire length is only 3 bytes; refuse bodies it cannot represent. */
    S2N_ERROR_IF(payload > 0xFFFFFF, S2N_ERR_SIZE_MISMATCH);

    /* Rewind, step over the 1-byte message type, overwrite the length placeholder,
     * then advance the write cursor back over the already-written payload. */
    GUARD(s2n_stuffer_rewrite(out));
    GUARD(s2n_stuffer_skip_write(out, 1));
    GUARD(s2n_stuffer_write_uint24(out, payload));
    GUARD(s2n_stuffer_skip_write(out, payload));

    return 0;
}

/* Read a handshake message header from conn->handshake.io.
 * Outputs the one-byte message type and the 24-bit body length.
 * Fails with S2N_ERR_SIZE_MISMATCH if a complete header is not yet buffered.
 */
int s2n_handshake_parse_header(struct s2n_connection *conn, uint8_t * message_type, uint32_t * length)
{
    struct s2n_stuffer *io = &conn->handshake.io;

    /* Only proceed once all TLS_HANDSHAKE_HEADER_LENGTH bytes are available. */
    S2N_ERROR_IF(s2n_stuffer_data_available(io) < TLS_HANDSHAKE_HEADER_LENGTH, S2N_ERR_SIZE_MISMATCH);

    GUARD(s2n_stuffer_read_uint8(io, message_type));
    GUARD(s2n_stuffer_read_uint24(io, length));

    return 0;
}

/* Point *hash_state at the connection's running handshake hash for hash_alg.
 * Algorithms without a dedicated handshake hash are rejected with
 * S2N_ERR_HASH_INVALID_ALGORITHM, and *hash_state is left untouched in that case.
 */
static int s2n_handshake_get_hash_state_ptr(struct s2n_connection *conn, s2n_hash_algorithm hash_alg, struct s2n_hash_state **hash_state)
{
    notnull_check(conn);

    struct s2n_handshake *handshake = &conn->handshake;
    struct s2n_hash_state *selected = NULL;

    switch (hash_alg) {
    case S2N_HASH_MD5:      selected = &handshake->md5;      break;
    case S2N_HASH_SHA1:     selected = &handshake->sha1;     break;
    case S2N_HASH_SHA224:   selected = &handshake->sha224;   break;
    case S2N_HASH_SHA256:   selected = &handshake->sha256;   break;
    case S2N_HASH_SHA384:   selected = &handshake->sha384;   break;
    case S2N_HASH_SHA512:   selected = &handshake->sha512;   break;
    case S2N_HASH_MD5_SHA1: selected = &handshake->md5_sha1; break;
    default:
        /* No transcript hash is kept for this algorithm; S2N_ERROR returns early. */
        S2N_ERROR(S2N_ERR_HASH_INVALID_ALGORITHM);
        break;
    }

    *hash_state = selected;
    return 0;
}

/* Reset the connection's running handshake hash for hash_alg via s2n_hash_reset(). */
int s2n_handshake_reset_hash_state(struct s2n_connection *conn, s2n_hash_algorithm hash_alg)
{
    struct s2n_hash_state *state = NULL;

    /* Look up the transcript hash for this algorithm, then reset it in place. */
    GUARD(s2n_handshake_get_hash_state_ptr(conn, hash_alg, &state));
    GUARD(s2n_hash_reset(state));

    return 0;
}

/* Shallow-copy the connection's current handshake hash state for hash_alg into
 * the caller-supplied *hash_state.
 * NOTE: When the underlying digest implementation uses the EVP API, only the
 * pointers to the EVP ctx and md are copied. The result therefore refers to the
 * live state rather than being an independent value. Call s2n_hash_copy() on it
 * before use so that the connection's own hash is not modified.
 */
int s2n_handshake_get_hash_state(struct s2n_connection *conn, s2n_hash_algorithm hash_alg, struct s2n_hash_state *hash_state)
{
    notnull_check(hash_state);

    struct s2n_hash_state *source = NULL;
    GUARD(s2n_handshake_get_hash_state_ptr(conn, hash_alg, &source));

    /* Plain struct assignment: embedded pointers are shared, not duplicated. */
    *hash_state = *source;
    return 0;
}

/* Mark every handshake hash algorithm as required by setting each byte of
 * handshake->required_hash_algs to 1. Always returns 0.
 */
int s2n_handshake_require_all_hashes(struct s2n_handshake *handshake)
{
    const size_t table_bytes = sizeof(handshake->required_hash_algs);
    memset(&handshake->required_hash_algs, 1, table_bytes);

    return 0;
}

/* Mark a single hash algorithm as required for this handshake. Always returns 0.
 * hash_alg is used directly as an index into required_hash_algs without a range
 * check; every caller in this file passes either a fixed algorithm or one obtained
 * from s2n_hmac_hash_alg().
 */
static int s2n_handshake_require_hash(struct s2n_handshake *handshake, s2n_hash_algorithm hash_alg)
{
    handshake->required_hash_algs[hash_alg] = 1;

    return 0;
}

/* Return the required flag for hash_alg. The result is nonzero once the algorithm
 * has been marked by s2n_handshake_require_hash() or
 * s2n_handshake_require_all_hashes(). hash_alg is not range-checked.
 */
uint8_t s2n_handshake_is_hash_required(struct s2n_handshake *handshake, s2n_hash_algorithm hash_alg)
{
    const uint8_t required = handshake->required_hash_algs[hash_alg];
    return required;
}

/* Update the required handshake hash algs depending on current handshake session state.
 * This function must called at the end of a handshake message handler. Additionally it must be called after the
 * ClientHello or ServerHello is processed in client and server mode respectively. The relevant handshake parameters
 * are not available until those messages are processed.
 */
int s2n_conn_update_required_handshake_hashes(struct s2n_connection *conn)
{
    /* Clear all of the required hashes */
    memset(conn->handshake.required_hash_algs, 0, sizeof(conn->handshake.required_hash_algs));

    message_type_t handshake_message = s2n_conn_get_current_message_type(conn);
    const uint8_t client_cert_verify_done = (handshake_message >= CLIENT_CERT_VERIFY) ? 1 : 0;
    s2n_cert_auth_type client_cert_auth_type;
    GUARD(s2n_connection_get_client_auth_type(conn, &client_cert_auth_type));

    /* If client authentication is possible, all hashes are needed until we're past CLIENT_CERT_VERIFY. */
    if ((client_cert_auth_type != S2N_CERT_AUTH_NONE) && !client_cert_verify_done) {
        GUARD(s2n_handshake_require_all_hashes(&conn->handshake));
        return 0;
    }

    /* We don't need all of the hashes. Set the hash alg(s) required for the PRF */
    switch (conn->actual_protocol_version) {
    case S2N_SSLv3:
    case S2N_TLS10:
    case S2N_TLS11:
        GUARD(s2n_handshake_require_hash(&conn->handshake, S2N_HASH_MD5));
        GUARD(s2n_handshake_require_hash(&conn->handshake, S2N_HASH_SHA1));
        break;
    case S2N_TLS12:
        /* fall through */
    case S2N_TLS13:
    {
        /* For TLS 1.2 and TLS 1.3, the cipher suite defines the PRF hash alg */
        s2n_hmac_algorithm prf_alg = conn->secure.cipher_suite->prf_alg;
        s2n_hash_algorithm hash_alg;
        GUARD(s2n_hmac_hash_alg(prf_alg, &hash_alg));
        GUARD(s2n_handshake_require_hash(&conn->handshake, hash_alg));
        break;
    }
    }

    return 0;
}

/*
 * Convert a hostname into the single "simple" wildcard form that would match it.
 * The output is meant to be compared directly against a wildcard domain in a certificate.
 * A deliberately restrictive definition of wildcard is used so that every input hostname
 * has exactly one wildcard representation.
 * Embedded and trailing wildcards are not supported, and only one level of wildcard
 * matching is supported. The output therefore has a single '*' as its first (left-most)
 * DNS label.
 *
 * Example:
 * - my.domain.name -> *.domain.name
 *
 * Not supported:
 * - my.domain.name -> m*.domain.name
 * - my.domain.name -> my.*.name
 * etc.
 *
 * Why restrict wildcards this way:
 * - Issuing non-simple wildcard certificates is rare enough to ignore.
 * - Certificate selection then needs only a constant number of lookups (two).
 *
 * If hostname_stuffer has no data left after skipping to the first '.', nothing is
 * written to output. Callers detect that case by output being empty.
 */
int s2n_create_wildcard_hostname(struct s2n_stuffer *hostname_stuffer, struct s2n_stuffer *output)
{
    /* Move past the left-most label. Per the example above, the '.' itself is kept
     * in what remains. */
    GUARD(s2n_stuffer_skip_to_char(hostname_stuffer, '.'));

    /* Nothing remains when there was no label separator; then no wildcard can be formed. */
    if (s2n_stuffer_data_available(hostname_stuffer) > 0) {
        /* Replace the first label with '*', then append the rest of the name verbatim. */
        GUARD(s2n_stuffer_write_uint8(output, '*'));
        GUARD(s2n_stuffer_copy(hostname_stuffer, output, s2n_stuffer_data_available(hostname_stuffer)));
    }

    return 0;
}

static int s2n_find_cert_matches(struct s2n_map *domain_name_to_cert_map,
        struct s2n_blob *dns_name,
        struct s2n_cert_chain_and_key *matches[S2N_CERT_TYPE_COUNT],
        uint8_t *match_exists)
{
    struct s2n_blob map_value;
    bool key_found = false;
    GUARD_AS_POSIX(s2n_map_lookup(domain_name_to_cert_map, dns_name, &map_value, &key_found));
    if (key_found) {
        struct certs_by_type *value = (void *) map_value.data;
        for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) {
            matches[i] = value->certs[i];
        }
        *match_exists = 1;
    }

    return 0;
}

/* Find certificates that match the ServerName TLS extension sent by the client.
 * For a given ServerName there can be multiple matching certificates based on the
 * type of key in the certificate.
 *
 * A match is determined using s2n_map lookup by DNS name.
 * Wildcards that have a single * in the left most label are supported.
 */
int s2n_conn_find_name_matching_certs(struct s2n_connection *conn)
{
    if (!s2n_server_received_server_name(conn)) {
        return 0;
    }
    const char *name = conn->server_name;
    struct s2n_blob hostname_blob = { .data = (uint8_t *) (uintptr_t) name, .size = strlen(name) };
    lte_check(hostname_blob.size, S2N_MAX_SERVER_NAME);
    char normalized_hostname[S2N_MAX_SERVER_NAME + 1] = { 0 };
    memcpy_check(normalized_hostname, hostname_blob.data, hostname_blob.size);
    struct s2n_blob normalized_name = { .data = (uint8_t *) normalized_hostname, .size = hostname_blob.size };
    GUARD(s2n_blob_char_to_lower(&normalized_name));
    struct s2n_stuffer normalized_hostname_stuffer;
    GUARD(s2n_stuffer_init(&normalized_hostname_stuffer, &normalized_name));
    GUARD(s2n_stuffer_skip_write(&normalized_hostname_stuffer, normalized_name.size));

    /* Find the exact matches for the ServerName */
    GUARD(s2n_find_cert_matches(conn->config->domain_name_to_cert_map,
                &normalized_name,
                conn->handshake_params.exact_sni_matches,
                &(conn->handshake_params.exact_sni_match_exists)));

    if (!conn->handshake_params.exact_sni_match_exists) {
        /* We have not yet found an exact domain match. Try to find wildcard matches. */
        char wildcard_hostname[S2N_MAX_SERVER_NAME + 1] = { 0 };
        struct s2n_blob wildcard_blob = { .data = (uint8_t *) wildcard_hostname, .size = sizeof(wildcard_hostname) };
        struct s2n_stuffer wildcard_stuffer;
        GUARD(s2n_stuffer_init(&wildcard_stuffer, &wildcard_blob));
        GUARD(s2n_create_wildcard_hostname(&normalized_hostname_stuffer, &wildcard_stuffer));
        const uint32_t wildcard_len = s2n_stuffer_data_available(&wildcard_stuffer);

        /* Couldn't create a valid wildcard from the input */
        if (wildcard_len == 0) {
            return 0;
        }

        /* The client's SNI is wildcardified, do an exact match against the set of server certs. */
        wildcard_blob.size = wildcard_len;
        GUARD(s2n_find_cert_matches(conn->config->domain_name_to_cert_map,
                    &wildcard_blob,
                    conn->handshake_params.wc_sni_matches,
                    &(conn->handshake_params.wc_sni_match_exists)));
    }

    /* If we found a suitable cert, we should send back the ServerName extension.
     * Note that this may have already been set by the client hello callback, so we won't override its value
     */
    conn->server_name_used = conn->server_name_used
        || conn->handshake_params.exact_sni_match_exists
        || conn->handshake_params.wc_sni_match_exists;

    return 0;
}

/* Find the best certificate of a specific key type.
 * Candidate sets, in priority order:
 * 1. Certificates matching the client's ServerName extension (exact match first,
 *    then wildcard match).
 * 2. The default certificates.
 *
 * Returns NULL when the selected set has no certificate for cert_type, or when
 * cert_type lies outside the per-type tables.
 */
struct s2n_cert_chain_and_key *s2n_get_compatible_cert_chain_and_key(struct s2n_connection *conn, const s2n_pkey_type cert_type)
{
    /* Converting through size_t turns any negative cert_type into a huge value, so one
     * upper-bound comparison rejects it instead of indexing out of bounds.
     * NOTE(review): s2n_pkey_type may include an "unknown" member that is negative;
     * confirm against its definition. */
    const size_t idx = (size_t) cert_type;

    if (conn->handshake_params.exact_sni_match_exists) {
        /* This may return NULL if there was an SNI match, but not a match for the cipher_suite's authentication type. */
        if (idx >= sizeof(conn->handshake_params.exact_sni_matches) / sizeof(conn->handshake_params.exact_sni_matches[0])) {
            return NULL;
        }
        return conn->handshake_params.exact_sni_matches[idx];
    }

    if (conn->handshake_params.wc_sni_match_exists) {
        if (idx >= sizeof(conn->handshake_params.wc_sni_matches) / sizeof(conn->handshake_params.wc_sni_matches[0])) {
            return NULL;
        }
        return conn->handshake_params.wc_sni_matches[idx];
    }

    /* We don't have any name matches. Use the default certificate that works with the key type. */
    if (idx >= sizeof(conn->config->default_certs_by_type.certs) / sizeof(conn->config->default_certs_by_type.certs[0])) {
        return NULL;
    }
    return conn->config->default_certs_by_type.certs[idx];
}