aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/restricted/aws/s2n/tls/s2n_protocol_preferences.c
blob: 2a4ea614a53557a1e7ced072eb2046193b4b9e63 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

#include "error/s2n_errno.h"
#include "tls/s2n_connection.h"
#include "utils/s2n_safety.h"

/*
 * Reads one length-prefixed entry from a protocol preference list.
 *
 * Each entry is a single length byte followed by that many bytes of protocol
 * name. On success `protocol` wraps the bytes returned by
 * s2n_stuffer_raw_read (no copy is made here), and the stuffer's read cursor
 * has moved past the entry.
 *
 * Fails when either argument is NULL, when the length byte is zero, or when
 * s2n_stuffer_raw_read returns NULL (presumably: not enough data remaining).
 */
S2N_RESULT s2n_protocol_preferences_read(struct s2n_stuffer *protocol_preferences, struct s2n_blob *protocol)
{
    RESULT_ENSURE_REF(protocol_preferences);
    RESULT_ENSURE_REF(protocol);

    /* Length prefix first; an empty entry is treated as malformed */
    uint8_t entry_len = 0;
    RESULT_GUARD_POSIX(s2n_stuffer_read_uint8(protocol_preferences, &entry_len));
    RESULT_ENSURE_GT(entry_len, 0);

    /* Then the entry payload itself, referenced in place */
    uint8_t *entry_bytes = s2n_stuffer_raw_read(protocol_preferences, entry_len);
    RESULT_ENSURE_REF(entry_bytes);
    RESULT_GUARD_POSIX(s2n_blob_init(protocol, entry_bytes, entry_len));

    return S2N_RESULT_OK;
}

S2N_RESULT s2n_protocol_preferences_contain(struct s2n_blob *protocol_preferences, struct s2n_blob *protocol, bool *contains)
{
    RESULT_ENSURE_REF(contains);
    *contains = false;
    RESULT_ENSURE_REF(protocol_preferences);
    RESULT_ENSURE_REF(protocol);

    struct s2n_stuffer app_protocols_stuffer = { 0 };
    RESULT_GUARD_POSIX(s2n_stuffer_init(&app_protocols_stuffer, protocol_preferences));
    RESULT_GUARD_POSIX(s2n_stuffer_skip_write(&app_protocols_stuffer, protocol_preferences->size));

    while (s2n_stuffer_data_available(&app_protocols_stuffer) > 0) {
        struct s2n_blob match_against = { 0 };
        RESULT_GUARD(s2n_protocol_preferences_read(&app_protocols_stuffer, &match_against));

        if (match_against.size == protocol->size && memcmp(match_against.data, protocol->data, protocol->size) == 0) {
            *contains = true;
            return S2N_RESULT_OK;
        }
    }
    return S2N_RESULT_OK;
}

S2N_RESULT s2n_protocol_preferences_append(struct s2n_blob *application_protocols, const uint8_t *protocol, uint8_t protocol_len)
{
    RESULT_ENSURE_MUT(application_protocols);
    RESULT_ENSURE_REF(protocol);

    /**
     *= https://tools.ietf.org/rfc/rfc7301#section-3.1
     *# Empty strings
     *# MUST NOT be included and byte strings MUST NOT be truncated.
     */
    RESULT_ENSURE(protocol_len != 0, S2N_ERR_INVALID_APPLICATION_PROTOCOL);

    uint32_t prev_len = application_protocols->size;
    uint32_t new_len = prev_len + /* len prefix */ 1 + protocol_len;
    RESULT_ENSURE(new_len <= UINT16_MAX, S2N_ERR_INVALID_APPLICATION_PROTOCOL);

    RESULT_GUARD_POSIX(s2n_realloc(application_protocols, new_len));

    struct s2n_stuffer protocol_stuffer = { 0 };
    RESULT_GUARD_POSIX(s2n_stuffer_init(&protocol_stuffer, application_protocols));
    RESULT_GUARD_POSIX(s2n_stuffer_skip_write(&protocol_stuffer, prev_len));
    RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&protocol_stuffer, protocol_len));
    RESULT_GUARD_POSIX(s2n_stuffer_write_bytes(&protocol_stuffer, protocol, protocol_len));

    return S2N_RESULT_OK;
}

S2N_RESULT s2n_protocol_preferences_set(struct s2n_blob *application_protocols, const char *const *protocols, int protocol_count)
{
    RESULT_ENSURE_MUT(application_protocols);

    /* NULL value indicates no preference so free the previous blob */
    if (protocols == NULL || protocol_count == 0) {
        RESULT_GUARD_POSIX(s2n_free(application_protocols));
        return S2N_RESULT_OK;
    }

    DEFER_CLEANUP(struct s2n_blob new_protocols = { 0 }, s2n_free);

    /* Allocate enough space to avoid a reallocation for every entry
     *
     * We assume that each protocol is most likely 8 bytes or less.
     * If it ends up being larger, we will expand the blob automatically
     * in the append method.
     */
    RESULT_GUARD_POSIX(s2n_realloc(&new_protocols, protocol_count * 8));

    /* set the size back to 0 so we start at the beginning.
     * s2n_realloc will just update the size field here
     */
    RESULT_GUARD_POSIX(s2n_realloc(&new_protocols, 0));

    for (size_t i = 0; i < protocol_count; i++) {
        const uint8_t *protocol = (const uint8_t *) protocols[i];
        size_t length = strlen(protocols[i]);

        /**
         *= https://tools.ietf.org/rfc/rfc7301#section-3.1
         *# Empty strings
         *# MUST NOT be included and byte strings MUST NOT be truncated.
         */
        RESULT_ENSURE(length < 256, S2N_ERR_INVALID_APPLICATION_PROTOCOL);

        RESULT_GUARD(s2n_protocol_preferences_append(&new_protocols, protocol, (uint8_t) length));
    }

    /* now we can free the previous list since we've validated all new input */
    RESULT_GUARD_POSIX(s2n_free(application_protocols));

    /* update the connection/config application_protocols with the newly allocated blob */
    *application_protocols = new_protocols;

    /* zero out new_protocols so the DEFER_CLEANUP from above doesn't free
     * the blob that we created and assigned to application_protocols
     */
    /* cppcheck-suppress unreadVariable */
    new_protocols = (struct s2n_blob){ 0 };

    return S2N_RESULT_OK;
}

S2N_RESULT s2n_select_server_preference_protocol(struct s2n_connection *conn, struct s2n_stuffer *server_list,
        struct s2n_blob *client_list)
{
    RESULT_ENSURE_REF(conn);
    RESULT_ENSURE_REF(server_list);
    RESULT_ENSURE_REF(client_list);

    while (s2n_stuffer_data_available(server_list) > 0) {
        struct s2n_blob protocol = { 0 };
        RESULT_ENSURE_OK(s2n_protocol_preferences_read(server_list, &protocol), S2N_ERR_BAD_MESSAGE);

        bool match_found = false;
        RESULT_ENSURE_OK(s2n_protocol_preferences_contain(client_list, &protocol, &match_found), S2N_ERR_BAD_MESSAGE);

        if (match_found) {
            RESULT_ENSURE_LT(protocol.size, sizeof(conn->application_protocol));
            RESULT_CHECKED_MEMCPY(conn->application_protocol, protocol.data, protocol.size);
            conn->application_protocol[protocol.size] = '\0';
            return S2N_RESULT_OK;
        }
    }

    return S2N_RESULT_OK;
}

/*
 * Replaces the protocol preference list stored on `config`.
 * See s2n_protocol_preferences_set for the accepted inputs.
 *
 * Returns S2N_SUCCESS, or a POSIX-style failure if `config` is NULL or the
 * list is rejected.
 */
int s2n_config_set_protocol_preferences(struct s2n_config *config, const char *const *protocols, int protocol_count)
{
    /* &config->application_protocols on a NULL config is undefined behavior and
     * yields a non-NULL pointer that the callee's own NULL check cannot catch */
    POSIX_ENSURE_REF(config);
    POSIX_GUARD_RESULT(s2n_protocol_preferences_set(&config->application_protocols, protocols, protocol_count));
    return S2N_SUCCESS;
}

/*
 * Appends one protocol (`protocol_len` bytes, no NUL required) to the
 * protocol preference list stored on `config`.
 *
 * Returns S2N_SUCCESS, or a POSIX-style failure if `config` is NULL or the
 * protocol is rejected (see s2n_protocol_preferences_append).
 */
int s2n_config_append_protocol_preference(struct s2n_config *config, const uint8_t *protocol, uint8_t protocol_len)
{
    /* Taking the member address of a NULL config is undefined behavior */
    POSIX_ENSURE_REF(config);
    POSIX_GUARD_RESULT(s2n_protocol_preferences_append(&config->application_protocols, protocol, protocol_len));
    return S2N_SUCCESS;
}

/*
 * Replaces the connection-level protocol preference list
 * (conn->application_protocols_overridden).
 * See s2n_protocol_preferences_set for the accepted inputs.
 *
 * Returns S2N_SUCCESS, or a POSIX-style failure if `conn` is NULL or the
 * list is rejected.
 */
int s2n_connection_set_protocol_preferences(struct s2n_connection *conn, const char *const *protocols, int protocol_count)
{
    /* &conn->application_protocols_overridden on a NULL conn is undefined
     * behavior and would slip past the callee's NULL check */
    POSIX_ENSURE_REF(conn);
    POSIX_GUARD_RESULT(s2n_protocol_preferences_set(&conn->application_protocols_overridden, protocols, protocol_count));
    return S2N_SUCCESS;
}

/*
 * Appends one protocol (`protocol_len` bytes, no NUL required) to the
 * connection-level protocol preference list
 * (conn->application_protocols_overridden).
 *
 * Returns S2N_SUCCESS, or a POSIX-style failure if `conn` is NULL or the
 * protocol is rejected (see s2n_protocol_preferences_append).
 */
int s2n_connection_append_protocol_preference(struct s2n_connection *conn, const uint8_t *protocol, uint8_t protocol_len)
{
    /* Taking the member address of a NULL conn is undefined behavior */
    POSIX_ENSURE_REF(conn);
    POSIX_GUARD_RESULT(s2n_protocol_preferences_append(&conn->application_protocols_overridden, protocol, protocol_len));
    return S2N_SUCCESS;
}