diff options
author | orivej <orivej@yandex-team.ru> | 2022-02-10 16:44:49 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:49 +0300 |
commit | 718c552901d703c502ccbefdfc3c9028d608b947 (patch) | |
tree | 46534a98bbefcd7b1f3faa5b52c138ab27db75b7 /contrib/restricted/aws/aws-c-io/source | |
parent | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (diff) | |
download | ydb-718c552901d703c502ccbefdfc3c9028d608b947.tar.gz |
Restoring authorship annotation for <orivej@yandex-team.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/restricted/aws/aws-c-io/source')
25 files changed, 13192 insertions, 13192 deletions
diff --git a/contrib/restricted/aws/aws-c-io/source/alpn_handler.c b/contrib/restricted/aws/aws-c-io/source/alpn_handler.c index 5ad2882602..fe49568332 100644 --- a/contrib/restricted/aws/aws-c-io/source/alpn_handler.c +++ b/contrib/restricted/aws/aws-c-io/source/alpn_handler.c @@ -1,110 +1,110 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/channel.h> -#include <aws/io/tls_channel_handler.h> - -struct alpn_handler { - aws_tls_on_protocol_negotiated on_protocol_negotiated; - void *user_data; -}; - -static int s_alpn_process_read_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - - if (message->message_tag != AWS_TLS_NEGOTIATED_PROTOCOL_MESSAGE) { - return aws_raise_error(AWS_IO_MISSING_ALPN_MESSAGE); - } - - struct aws_tls_negotiated_protocol_message *protocol_message = - (struct aws_tls_negotiated_protocol_message *)message->message_data.buffer; - - struct aws_channel_slot *new_slot = aws_channel_slot_new(slot->channel); - - struct alpn_handler *alpn_handler = (struct alpn_handler *)handler->impl; - - if (!new_slot) { - return AWS_OP_ERR; - } - - struct aws_channel_handler *new_handler = - alpn_handler->on_protocol_negotiated(new_slot, &protocol_message->protocol, alpn_handler->user_data); - - if (!new_handler) { - aws_mem_release(handler->alloc, (void *)new_slot); - return aws_raise_error(AWS_IO_UNHANDLED_ALPN_PROTOCOL_MESSAGE); - } - - aws_channel_slot_replace(slot, new_slot); - aws_channel_slot_set_handler(new_slot, new_handler); - return AWS_OP_SUCCESS; -} - -static int s_alpn_shutdown( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int error_code, - bool abort_immediately) { - (void)handler; - return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); -} - -static size_t s_alpn_get_initial_window_size(struct aws_channel_handler *handler) { - (void)handler; - return sizeof(struct aws_tls_negotiated_protocol_message); -} - -static void s_alpn_destroy(struct aws_channel_handler *handler) { - struct alpn_handler *alpn_handler = (struct alpn_handler *)handler->impl; - aws_mem_release(handler->alloc, alpn_handler); - aws_mem_release(handler->alloc, handler); -} - -static size_t s_alpn_message_overhead(struct aws_channel_handler *handler) { - (void)handler; - return 0; -} - -static struct aws_channel_handler_vtable s_alpn_handler_vtable = { - .initial_window_size = s_alpn_get_initial_window_size, - .increment_read_window = NULL, - .shutdown = s_alpn_shutdown, - .process_write_message = NULL, - .process_read_message = s_alpn_process_read_message, - .destroy = s_alpn_destroy, - .message_overhead = s_alpn_message_overhead, -}; - -struct aws_channel_handler *aws_tls_alpn_handler_new( - struct aws_allocator *allocator, - aws_tls_on_protocol_negotiated on_protocol_negotiated, - void *user_data) { - struct aws_channel_handler *channel_handler = - (struct aws_channel_handler *)aws_mem_calloc(allocator, 1, sizeof(struct aws_channel_handler)); - - if (!channel_handler) { - return NULL; - } - - struct alpn_handler *alpn_handler = - (struct alpn_handler *)aws_mem_calloc(allocator, 1, sizeof(struct alpn_handler)); - - if (!alpn_handler) { - aws_mem_release(allocator, (void *)channel_handler); - return NULL; - } - - alpn_handler->on_protocol_negotiated = on_protocol_negotiated; - alpn_handler->user_data = user_data; - channel_handler->impl = alpn_handler; - channel_handler->alloc = allocator; - - channel_handler->vtable = &s_alpn_handler_vtable; - - return channel_handler; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/channel.h> +#include <aws/io/tls_channel_handler.h> + +struct alpn_handler { + aws_tls_on_protocol_negotiated on_protocol_negotiated; + void *user_data; +}; + +static int s_alpn_process_read_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + + if (message->message_tag != AWS_TLS_NEGOTIATED_PROTOCOL_MESSAGE) { + return aws_raise_error(AWS_IO_MISSING_ALPN_MESSAGE); + } + + struct aws_tls_negotiated_protocol_message *protocol_message = + (struct aws_tls_negotiated_protocol_message *)message->message_data.buffer; + + struct aws_channel_slot *new_slot = aws_channel_slot_new(slot->channel); + + struct alpn_handler *alpn_handler = (struct alpn_handler *)handler->impl; + + if (!new_slot) { + return AWS_OP_ERR; + } + + struct aws_channel_handler *new_handler = + alpn_handler->on_protocol_negotiated(new_slot, &protocol_message->protocol, alpn_handler->user_data); + + if (!new_handler) { + aws_mem_release(handler->alloc, (void *)new_slot); + return aws_raise_error(AWS_IO_UNHANDLED_ALPN_PROTOCOL_MESSAGE); + } + + aws_channel_slot_replace(slot, new_slot); + aws_channel_slot_set_handler(new_slot, new_handler); + return AWS_OP_SUCCESS; +} + +static int s_alpn_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int error_code, + bool abort_immediately) { + (void)handler; + return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); +} + +static size_t s_alpn_get_initial_window_size(struct aws_channel_handler *handler) { + (void)handler; + return sizeof(struct aws_tls_negotiated_protocol_message); +} + +static void s_alpn_destroy(struct aws_channel_handler *handler) { + struct alpn_handler *alpn_handler = (struct alpn_handler *)handler->impl; + aws_mem_release(handler->alloc, alpn_handler); + aws_mem_release(handler->alloc, handler); +} + +static size_t s_alpn_message_overhead(struct aws_channel_handler *handler) { + (void)handler; + return 0; +} + +static struct aws_channel_handler_vtable s_alpn_handler_vtable = { + .initial_window_size = s_alpn_get_initial_window_size, + .increment_read_window = NULL, + .shutdown = s_alpn_shutdown, + .process_write_message = NULL, + .process_read_message = s_alpn_process_read_message, + .destroy = s_alpn_destroy, + .message_overhead = s_alpn_message_overhead, +}; + +struct aws_channel_handler *aws_tls_alpn_handler_new( + struct aws_allocator *allocator, + aws_tls_on_protocol_negotiated on_protocol_negotiated, + void *user_data) { + struct aws_channel_handler *channel_handler = + (struct aws_channel_handler *)aws_mem_calloc(allocator, 1, sizeof(struct aws_channel_handler)); + + if (!channel_handler) { + return NULL; + } + + struct alpn_handler *alpn_handler = + (struct alpn_handler *)aws_mem_calloc(allocator, 1, sizeof(struct alpn_handler)); + + if (!alpn_handler) { + aws_mem_release(allocator, (void *)channel_handler); + return NULL; + } + + alpn_handler->on_protocol_negotiated = on_protocol_negotiated; + alpn_handler->user_data = user_data; + channel_handler->impl = alpn_handler; + channel_handler->alloc = allocator; + + channel_handler->vtable = &s_alpn_handler_vtable; + + return channel_handler; +} diff --git a/contrib/restricted/aws/aws-c-io/source/bsd/kqueue_event_loop.c b/contrib/restricted/aws/aws-c-io/source/bsd/kqueue_event_loop.c index 3de882d045..949549b941 100644 --- a/contrib/restricted/aws/aws-c-io/source/bsd/kqueue_event_loop.c +++ b/contrib/restricted/aws/aws-c-io/source/bsd/kqueue_event_loop.c @@ -1,960 +1,960 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/event_loop.h> - -#include <aws/io/logging.h> - -#include <aws/common/atomics.h> -#include <aws/common/clock.h> -#include <aws/common/mutex.h> -#include <aws/common/task_scheduler.h> -#include <aws/common/thread.h> - -#if defined(__FreeBSD__) || defined(__NetBSD__) -# define __BSD_VISIBLE 1 -# include <sys/types.h> -#endif - -#include <sys/event.h> - -#include <aws/io/io.h> -#include <limits.h> -#include <unistd.h> - -static void s_destroy(struct aws_event_loop *event_loop); -static int s_run(struct aws_event_loop *event_loop); -static int s_stop(struct aws_event_loop *event_loop); -static int s_wait_for_stop_completion(struct aws_event_loop *event_loop); -static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task); -static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos); -static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task); -static int s_subscribe_to_io_events( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - aws_event_loop_on_event_fn *on_event, - void *user_data); -static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle); -static void s_free_io_event_resources(void *user_data); -static bool s_is_event_thread(struct aws_event_loop *event_loop); - -static void s_event_thread_main(void *user_data); - -int aws_open_nonblocking_posix_pipe(int pipe_fds[2]); - -enum event_thread_state { - EVENT_THREAD_STATE_READY_TO_RUN, - EVENT_THREAD_STATE_RUNNING, - EVENT_THREAD_STATE_STOPPING, -}; - -enum pipe_fd_index { - READ_FD, - WRITE_FD, -}; - -struct kqueue_loop { - /* thread_created_on is the handle to the event loop thread. */ - struct aws_thread thread_created_on; - /* thread_joined_to is used by the thread destroying the event loop. */ - aws_thread_id_t thread_joined_to; - /* running_thread_id is NULL if the event loop thread is stopped or points-to the thread_id of the thread running - * the event loop (either thread_created_on or thread_joined_to). Atomic because of concurrent writes (e.g., - * run/stop) and reads (e.g., is_event_loop_thread). - * An aws_thread_id_t variable itself cannot be atomic because it is an opaque type that is platform-dependent. */ - struct aws_atomic_var running_thread_id; - int kq_fd; /* kqueue file descriptor */ - - /* Pipe for signaling to event-thread that cross_thread_data has changed. */ - int cross_thread_signal_pipe[2]; - - /* cross_thread_data holds things that must be communicated across threads. - * When the event-thread is running, the mutex must be locked while anyone touches anything in cross_thread_data. - * If this data is modified outside the thread, the thread is signaled via activity on a pipe. */ - struct { - struct aws_mutex mutex; - bool thread_signaled; /* whether thread has been signaled about changes to cross_thread_data */ - struct aws_linked_list tasks_to_schedule; - enum event_thread_state state; - } cross_thread_data; - - /* thread_data holds things which, when the event-thread is running, may only be touched by the thread */ - struct { - struct aws_task_scheduler scheduler; - - int connected_handle_count; - - /* These variables duplicate ones in cross_thread_data. We move values out while holding the mutex and operate - * on them later */ - enum event_thread_state state; - } thread_data; -}; - -/* Data attached to aws_io_handle while the handle is subscribed to io events */ -struct handle_data { - struct aws_io_handle *owner; - struct aws_event_loop *event_loop; - aws_event_loop_on_event_fn *on_event; - void *on_event_user_data; - - int events_subscribed; /* aws_io_event_types this handle should be subscribed to */ - int events_this_loop; /* aws_io_event_types received during current loop of the event-thread */ - - enum { HANDLE_STATE_SUBSCRIBING, HANDLE_STATE_SUBSCRIBED, HANDLE_STATE_UNSUBSCRIBED } state; - - struct aws_task subscribe_task; - struct aws_task cleanup_task; -}; - -enum { - DEFAULT_TIMEOUT_SEC = 100, /* Max kevent() timeout per loop of the event-thread */ - MAX_EVENTS = 100, /* Max kevents to process per loop of the event-thread */ -}; - -struct aws_event_loop_vtable s_kqueue_vtable = { - .destroy = s_destroy, - .run = s_run, - .stop = s_stop, - .wait_for_stop_completion = s_wait_for_stop_completion, - .schedule_task_now = s_schedule_task_now, - .schedule_task_future = s_schedule_task_future, - .subscribe_to_io_events = s_subscribe_to_io_events, - .cancel_task = s_cancel_task, - .unsubscribe_from_io_events = s_unsubscribe_from_io_events, - .free_io_event_resources = s_free_io_event_resources, - .is_on_callers_thread = s_is_event_thread, -}; - -struct aws_event_loop *aws_event_loop_new_default(struct aws_allocator *alloc, aws_io_clock_fn *clock) { - AWS_ASSERT(alloc); - AWS_ASSERT(clock); - - bool clean_up_event_loop_mem = false; - bool clean_up_event_loop_base = false; - bool clean_up_impl_mem = false; - bool clean_up_thread = false; - bool clean_up_kqueue = false; - bool clean_up_signal_pipe = false; - bool clean_up_signal_kevent = false; - bool clean_up_mutex = false; - - struct aws_event_loop *event_loop = aws_mem_acquire(alloc, sizeof(struct aws_event_loop)); - if (!event_loop) { - return NULL; - } - - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing edge-triggered kqueue", (void *)event_loop); - clean_up_event_loop_mem = true; - - int err = aws_event_loop_init_base(event_loop, alloc, clock); - if (err) { - goto clean_up; - } - clean_up_event_loop_base = true; - - struct kqueue_loop *impl = aws_mem_calloc(alloc, 1, sizeof(struct kqueue_loop)); - if (!impl) { - goto clean_up; - } - /* intialize thread id to NULL. It will be set when the event loop thread starts. */ - aws_atomic_init_ptr(&impl->running_thread_id, NULL); - clean_up_impl_mem = true; - - err = aws_thread_init(&impl->thread_created_on, alloc); - if (err) { - goto clean_up; - } - clean_up_thread = true; - - impl->kq_fd = kqueue(); - if (impl->kq_fd == -1) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open kqueue handle.", (void *)event_loop); - aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - goto clean_up; - } - clean_up_kqueue = true; - - err = aws_open_nonblocking_posix_pipe(impl->cross_thread_signal_pipe); - if (err) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to open pipe handle.", (void *)event_loop); - goto clean_up; - } - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: pipe descriptors read %d, write %d.", - (void *)event_loop, - impl->cross_thread_signal_pipe[READ_FD], - impl->cross_thread_signal_pipe[WRITE_FD]); - clean_up_signal_pipe = true; - - /* Set up kevent to handle activity on the cross_thread_signal_pipe */ - struct kevent thread_signal_kevent; - EV_SET( - &thread_signal_kevent, - impl->cross_thread_signal_pipe[READ_FD], - EVFILT_READ /*filter*/, - EV_ADD | EV_CLEAR /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - NULL /*udata*/); - - int res = kevent( - impl->kq_fd, - &thread_signal_kevent /*changelist*/, - 1 /*nchanges*/, - NULL /*eventlist*/, - 0 /*nevents*/, - NULL /*timeout*/); - - if (res == -1) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to create cross-thread signal kevent.", (void *)event_loop); - aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - goto clean_up; - } - clean_up_signal_kevent = true; - - err = aws_mutex_init(&impl->cross_thread_data.mutex); - if (err) { - goto clean_up; - } - clean_up_mutex = true; - - impl->cross_thread_data.thread_signaled = false; - - aws_linked_list_init(&impl->cross_thread_data.tasks_to_schedule); - - impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; - - err = aws_task_scheduler_init(&impl->thread_data.scheduler, alloc); - if (err) { - goto clean_up; - } - - impl->thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; - - event_loop->impl_data = impl; - - event_loop->vtable = &s_kqueue_vtable; - - /* success */ - return event_loop; - -clean_up: - if (clean_up_mutex) { - aws_mutex_clean_up(&impl->cross_thread_data.mutex); - } - if (clean_up_signal_kevent) { - thread_signal_kevent.flags = EV_DELETE; - kevent( - impl->kq_fd, - &thread_signal_kevent /*changelist*/, - 1 /*nchanges*/, - NULL /*eventlist*/, - 0 /*nevents*/, - NULL /*timeout*/); - } - if (clean_up_signal_pipe) { - close(impl->cross_thread_signal_pipe[READ_FD]); - close(impl->cross_thread_signal_pipe[WRITE_FD]); - } - if (clean_up_kqueue) { - close(impl->kq_fd); - } - if (clean_up_thread) { - aws_thread_clean_up(&impl->thread_created_on); - } - if (clean_up_impl_mem) { - aws_mem_release(alloc, impl); - } - if (clean_up_event_loop_base) { - aws_event_loop_clean_up_base(event_loop); - } - if (clean_up_event_loop_mem) { - aws_mem_release(alloc, event_loop); - } - return NULL; -} - -static void s_destroy(struct aws_event_loop *event_loop) { - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: destroying event_loop", (void *)event_loop); - struct kqueue_loop *impl = event_loop->impl_data; - - /* Stop the event-thread. This might have already happened. It's safe to call multiple times. */ - s_stop(event_loop); - int err = s_wait_for_stop_completion(event_loop); - if (err) { - AWS_LOGF_WARN( - AWS_LS_IO_EVENT_LOOP, - "id=%p: failed to destroy event-thread, resources have been leaked", - (void *)event_loop); - AWS_ASSERT("Failed to destroy event-thread, resources have been leaked." == NULL); - return; - } - /* setting this so that canceled tasks don't blow up when asking if they're on the event-loop thread. */ - impl->thread_joined_to = aws_thread_current_thread_id(); - aws_atomic_store_ptr(&impl->running_thread_id, &impl->thread_joined_to); - - /* Clean up task-related stuff first. It's possible the a cancelled task adds further tasks to this event_loop. - * Tasks added in this way will be in cross_thread_data.tasks_to_schedule, so we clean that up last */ - - aws_task_scheduler_clean_up(&impl->thread_data.scheduler); /* Tasks in scheduler get cancelled*/ - - while (!aws_linked_list_empty(&impl->cross_thread_data.tasks_to_schedule)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&impl->cross_thread_data.tasks_to_schedule); - struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); - task->fn(task, task->arg, AWS_TASK_STATUS_CANCELED); - } - - /* Warn user if aws_io_handle was subscribed, but never unsubscribed. This would cause memory leaks. */ - AWS_ASSERT(impl->thread_data.connected_handle_count == 0); - - /* Clean up everything else */ - aws_mutex_clean_up(&impl->cross_thread_data.mutex); - - struct kevent thread_signal_kevent; - EV_SET( - &thread_signal_kevent, - impl->cross_thread_signal_pipe[READ_FD], - EVFILT_READ /*filter*/, - EV_DELETE /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - NULL /*udata*/); - - kevent( - impl->kq_fd, - &thread_signal_kevent /*changelist*/, - 1 /*nchanges*/, - NULL /*eventlist*/, - 0 /*nevents*/, - NULL /*timeout*/); - - close(impl->cross_thread_signal_pipe[READ_FD]); - close(impl->cross_thread_signal_pipe[WRITE_FD]); - close(impl->kq_fd); - aws_thread_clean_up(&impl->thread_created_on); - aws_mem_release(event_loop->alloc, impl); - aws_event_loop_clean_up_base(event_loop); - aws_mem_release(event_loop->alloc, event_loop); -} - -static int s_run(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: starting event-loop thread.", (void *)event_loop); - /* to re-run, call stop() and wait_for_stop_completion() */ - AWS_ASSERT(impl->cross_thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); - AWS_ASSERT(impl->thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); - - /* Since thread isn't running it's ok to touch thread_data, - * and it's ok to touch cross_thread_data without locking the mutex */ - impl->cross_thread_data.state = EVENT_THREAD_STATE_RUNNING; - - int err = aws_thread_launch(&impl->thread_created_on, s_event_thread_main, (void *)event_loop, NULL); - if (err) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: thread creation failed.", (void *)event_loop); - goto clean_up; - } - - return AWS_OP_SUCCESS; - -clean_up: - impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; - return AWS_OP_ERR; -} - -/* This function can't fail, we're relying on the thread responding to critical messages (ex: stop thread) */ -void signal_cross_thread_data_changed(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: signaling event-loop that cross-thread tasks need to be scheduled.", - (void *)event_loop); - /* Doesn't actually matter what we write, any activity on pipe signals that cross_thread_data has changed, - * If the pipe is full and the write fails, that's fine, the event-thread will get the signal from some previous - * write */ - uint32_t write_whatever = 0xC0FFEE; - write(impl->cross_thread_signal_pipe[WRITE_FD], &write_whatever, sizeof(write_whatever)); -} - -static int s_stop(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - - bool signal_thread = false; - - { /* Begin critical section */ - aws_mutex_lock(&impl->cross_thread_data.mutex); - if (impl->cross_thread_data.state == EVENT_THREAD_STATE_RUNNING) { - impl->cross_thread_data.state = EVENT_THREAD_STATE_STOPPING; - signal_thread = !impl->cross_thread_data.thread_signaled; - impl->cross_thread_data.thread_signaled = true; - } - aws_mutex_unlock(&impl->cross_thread_data.mutex); - } /* End critical section */ - - if (signal_thread) { - signal_cross_thread_data_changed(event_loop); - } - - return AWS_OP_SUCCESS; -} - -static int s_wait_for_stop_completion(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - -#ifdef DEBUG_BUILD - aws_mutex_lock(&impl->cross_thread_data.mutex); - /* call stop() before wait_for_stop_completion() or you'll wait forever */ - AWS_ASSERT(impl->cross_thread_data.state != EVENT_THREAD_STATE_RUNNING); - aws_mutex_unlock(&impl->cross_thread_data.mutex); -#endif - - int err = aws_thread_join(&impl->thread_created_on); - if (err) { - return AWS_OP_ERR; - } - - /* Since thread is no longer running it's ok to touch thread_data, - * and it's ok to touch cross_thread_data without locking the mutex */ - impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; - impl->thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; - - return AWS_OP_SUCCESS; -} - -/* Common functionality for "now" and "future" task scheduling. - * If `run_at_nanos` is zero then the task is scheduled as a "now" task. */ -static void s_schedule_task_common(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { - AWS_ASSERT(task); - struct kqueue_loop *impl = event_loop->impl_data; - - /* If we're on the event-thread, just schedule it directly */ - if (s_is_event_thread(event_loop)) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: scheduling task %p in-thread for timestamp %llu", - (void *)event_loop, - (void *)task, - (unsigned long long)run_at_nanos); - if (run_at_nanos == 0) { - aws_task_scheduler_schedule_now(&impl->thread_data.scheduler, task); - } else { - aws_task_scheduler_schedule_future(&impl->thread_data.scheduler, task, run_at_nanos); - } - return; - } - - /* Otherwise, add it to cross_thread_data.tasks_to_schedule and signal the event-thread to process it */ - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: scheduling task %p cross-thread for timestamp %llu", - (void *)event_loop, - (void *)task, - (unsigned long long)run_at_nanos); - task->timestamp = run_at_nanos; - bool should_signal_thread = false; - - /* Begin critical section */ - aws_mutex_lock(&impl->cross_thread_data.mutex); - aws_linked_list_push_back(&impl->cross_thread_data.tasks_to_schedule, &task->node); - - /* Signal thread that cross_thread_data has changed (unless it's been signaled already) */ - if (!impl->cross_thread_data.thread_signaled) { - should_signal_thread = true; - impl->cross_thread_data.thread_signaled = true; - } - - aws_mutex_unlock(&impl->cross_thread_data.mutex); - /* End critical section */ - - if (should_signal_thread) { - signal_cross_thread_data_changed(event_loop); - } -} - -static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { - s_schedule_task_common(event_loop, task, 0); /* Zero is used to denote "now" tasks */ -} - -static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { - s_schedule_task_common(event_loop, task, run_at_nanos); -} - -static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { - struct kqueue_loop *kqueue_loop = event_loop->impl_data; - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: cancelling task %p", (void *)event_loop, (void *)task); - aws_task_scheduler_cancel_task(&kqueue_loop->thread_data.scheduler, task); -} - -/* Scheduled task that connects aws_io_handle with the kqueue */ -static void s_subscribe_task(struct aws_task *task, void *user_data, enum aws_task_status status) { - (void)task; - struct handle_data *handle_data = user_data; - struct aws_event_loop *event_loop = handle_data->event_loop; - struct kqueue_loop *impl = handle_data->event_loop->impl_data; - - impl->thread_data.connected_handle_count++; - - /* if task was cancelled, nothing to do */ - if (status == AWS_TASK_STATUS_CANCELED) { - return; - } - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: subscribing to events on fd %d", (void *)event_loop, handle_data->owner->data.fd); - - /* If handle was unsubscribed before this task could execute, nothing to do */ - if (handle_data->state == HANDLE_STATE_UNSUBSCRIBED) { - return; - } - - AWS_ASSERT(handle_data->state == HANDLE_STATE_SUBSCRIBING); - - /* In order to monitor both reads and writes, kqueue requires you to add two separate kevents. - * If we're adding two separate kevents, but one of those fails, we need to remove the other kevent. - * Therefore we use the EV_RECEIPT flag. This causes kevent() to tell whether each EV_ADD succeeded, - * rather than the usual behavior of telling us about recent events. */ - struct kevent changelist[2]; - AWS_ZERO_ARRAY(changelist); - - int changelist_size = 0; - - if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_READABLE) { - EV_SET( - &changelist[changelist_size++], - handle_data->owner->data.fd, - EVFILT_READ /*filter*/, - EV_ADD | EV_RECEIPT | EV_CLEAR /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - handle_data /*udata*/); - } - if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_WRITABLE) { - EV_SET( - &changelist[changelist_size++], - handle_data->owner->data.fd, - EVFILT_WRITE /*filter*/, - EV_ADD | EV_RECEIPT | EV_CLEAR /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - handle_data /*udata*/); - } - - int num_events = kevent( - impl->kq_fd, - changelist /*changelist*/, - changelist_size /*nchanges*/, - changelist /*eventlist. It's OK to re-use the same memory for changelist input and eventlist output*/, - changelist_size /*nevents*/, - NULL /*timeout*/); - if (num_events == -1) { - goto subscribe_failed; - } - - /* Look through results to see if any failed */ - for (int i = 0; i < num_events; ++i) { - /* Every result should be flagged as error, that's just how EV_RECEIPT works */ - AWS_ASSERT(changelist[i].flags & EV_ERROR); - - /* If a real error occurred, .data contains the error code */ - if (changelist[i].data != 0) { - goto subscribe_failed; - } - } - - /* Success */ - handle_data->state = HANDLE_STATE_SUBSCRIBED; - return; - -subscribe_failed: - AWS_LOGF_ERROR( - AWS_LS_IO_EVENT_LOOP, - "id=%p: failed to subscribe to events on fd %d", - (void *)event_loop, - handle_data->owner->data.fd); - /* Remove any related kevents that succeeded */ - for (int i = 0; i < num_events; ++i) { - if (changelist[i].data == 0) { - changelist[i].flags = EV_DELETE; - kevent( - impl->kq_fd, - &changelist[i] /*changelist*/, - 1 /*nchanges*/, - NULL /*eventlist*/, - 0 /*nevents*/, - NULL /*timeout*/); - } - } - - /* We can't return an error code because this was a scheduled task. - * Notify the user of the failed subscription by passing AWS_IO_EVENT_TYPE_ERROR to the callback. */ - handle_data->on_event(event_loop, handle_data->owner, AWS_IO_EVENT_TYPE_ERROR, handle_data->on_event_user_data); -} - -static int s_subscribe_to_io_events( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - aws_event_loop_on_event_fn *on_event, - void *user_data) { - - AWS_ASSERT(event_loop); - AWS_ASSERT(handle->data.fd != -1); - AWS_ASSERT(handle->additional_data == NULL); - AWS_ASSERT(on_event); - /* Must subscribe for read, write, or both */ - AWS_ASSERT(events & (AWS_IO_EVENT_TYPE_READABLE | AWS_IO_EVENT_TYPE_WRITABLE)); - - struct handle_data *handle_data = aws_mem_calloc(event_loop->alloc, 1, sizeof(struct handle_data)); - if (!handle_data) { - return AWS_OP_ERR; - } - - handle_data->owner = handle; - handle_data->event_loop = event_loop; - handle_data->on_event = on_event; - handle_data->on_event_user_data = user_data; - handle_data->events_subscribed = events; - handle_data->state = HANDLE_STATE_SUBSCRIBING; - - handle->additional_data = handle_data; - - /* We schedule a task to perform the actual changes to the kqueue, read on for an explanation why... - * - * kqueue requires separate registrations for read and write events. - * If the user wants to know about both read and write, we need register once for read and once for write. - * If the first registration succeeds, but the second registration fails, we need to delete the first registration. - * If this all happened outside the event-thread, the successful registration's events could begin processing - * in the brief window of time before the registration is deleted. */ - - aws_task_init(&handle_data->subscribe_task, s_subscribe_task, handle_data, "kqueue_event_loop_subscribe"); - s_schedule_task_now(event_loop, &handle_data->subscribe_task); - - return AWS_OP_SUCCESS; -} - -static void s_free_io_event_resources(void *user_data) { - struct handle_data *handle_data = user_data; - struct kqueue_loop *impl = handle_data->event_loop->impl_data; - - impl->thread_data.connected_handle_count--; - - aws_mem_release(handle_data->event_loop->alloc, handle_data); -} - -static void s_clean_up_handle_data_task(struct aws_task *task, void *user_data, enum aws_task_status status) { - (void)task; - (void)status; - - struct handle_data *handle_data = user_data; - s_free_io_event_resources(handle_data); -} - -static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: un-subscribing from events on fd %d", (void *)event_loop, handle->data.fd); - AWS_ASSERT(handle->additional_data); - struct handle_data *handle_data = handle->additional_data; - struct kqueue_loop *impl = event_loop->impl_data; - - AWS_ASSERT(event_loop == handle_data->event_loop); - - /* If the handle was successfully subscribed to kqueue, then remove it. */ - if (handle_data->state == HANDLE_STATE_SUBSCRIBED) { - struct kevent changelist[2]; - int changelist_size = 0; - - if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_READABLE) { - EV_SET( - &changelist[changelist_size++], - handle_data->owner->data.fd, - EVFILT_READ /*filter*/, - EV_DELETE /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - handle_data /*udata*/); - } - if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_WRITABLE) { - EV_SET( - &changelist[changelist_size++], - handle_data->owner->data.fd, - EVFILT_WRITE /*filter*/, - EV_DELETE /*flags*/, - 0 /*fflags*/, - 0 /*data*/, - handle_data /*udata*/); - } - - kevent(impl->kq_fd, changelist, changelist_size, NULL /*eventlist*/, 0 /*nevents*/, NULL /*timeout*/); - } - - /* Schedule a task to clean up the memory. This is done in a task to prevent the following scenario: - * - While processing a batch of events, some callback unsubscribes another aws_io_handle. - * - One of the other events in this batch belongs to that other aws_io_handle. - * - If the handle_data were already deleted, there would be an access invalid memory. */ - - aws_task_init( - &handle_data->cleanup_task, s_clean_up_handle_data_task, handle_data, "kqueue_event_loop_clean_up_handle_data"); - aws_event_loop_schedule_task_now(event_loop, &handle_data->cleanup_task); - - handle_data->state = HANDLE_STATE_UNSUBSCRIBED; - handle->additional_data = NULL; - - return AWS_OP_SUCCESS; -} - -static bool s_is_event_thread(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - - aws_thread_id_t *thread_id = aws_atomic_load_ptr(&impl->running_thread_id); - return thread_id && aws_thread_thread_id_equal(*thread_id, aws_thread_current_thread_id()); -} - -/* Called from thread. - * Takes tasks from tasks_to_schedule and adds them to the scheduler. */ -static void s_process_tasks_to_schedule(struct aws_event_loop *event_loop, struct aws_linked_list *tasks_to_schedule) { - struct kqueue_loop *impl = event_loop->impl_data; - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: processing cross-thread tasks", (void *)event_loop); - - while (!aws_linked_list_empty(tasks_to_schedule)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(tasks_to_schedule); - struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: task %p pulled to event-loop, scheduling now.", - (void *)event_loop, - (void *)task); - /* Timestamp 0 is used to denote "now" tasks */ - if (task->timestamp == 0) { - aws_task_scheduler_schedule_now(&impl->thread_data.scheduler, task); - } else { - aws_task_scheduler_schedule_future(&impl->thread_data.scheduler, task, task->timestamp); - } - } -} - -static void s_process_cross_thread_data(struct aws_event_loop *event_loop) { - struct kqueue_loop *impl = event_loop->impl_data; - - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: notified of cross-thread data to process", (void *)event_loop); - /* If there are tasks to schedule, grab them all out of synced_data.tasks_to_schedule. - * We'll process them later, so that we minimize time spent holding the mutex. */ - struct aws_linked_list tasks_to_schedule; - aws_linked_list_init(&tasks_to_schedule); - - { /* Begin critical section */ - aws_mutex_lock(&impl->cross_thread_data.mutex); - impl->cross_thread_data.thread_signaled = false; - - bool initiate_stop = (impl->cross_thread_data.state == EVENT_THREAD_STATE_STOPPING) && - (impl->thread_data.state == EVENT_THREAD_STATE_RUNNING); - if (AWS_UNLIKELY(initiate_stop)) { - impl->thread_data.state = EVENT_THREAD_STATE_STOPPING; - } - - aws_linked_list_swap_contents(&impl->cross_thread_data.tasks_to_schedule, &tasks_to_schedule); - - aws_mutex_unlock(&impl->cross_thread_data.mutex); - } /* End critical section */ - - s_process_tasks_to_schedule(event_loop, &tasks_to_schedule); -} - -static int s_aws_event_flags_from_kevent(struct kevent *kevent) { - int event_flags = 0; - - if (kevent->flags & EV_ERROR) { - event_flags |= AWS_IO_EVENT_TYPE_ERROR; - } else if (kevent->filter == EVFILT_READ) { - if (kevent->data != 0) { - event_flags |= AWS_IO_EVENT_TYPE_READABLE; - } - - if (kevent->flags & EV_EOF) { - event_flags |= AWS_IO_EVENT_TYPE_CLOSED; - } - } else if (kevent->filter == EVFILT_WRITE) { - if (kevent->data != 0) { - event_flags |= AWS_IO_EVENT_TYPE_WRITABLE; - } - - if (kevent->flags & EV_EOF) { - event_flags |= AWS_IO_EVENT_TYPE_CLOSED; - } - } - - return event_flags; -} - -static void s_event_thread_main(void *user_data) { - struct aws_event_loop *event_loop = user_data; - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: main loop started", (void *)event_loop); - struct kqueue_loop *impl = event_loop->impl_data; - - /* set thread id to the event-loop's thread. */ - aws_atomic_store_ptr(&impl->running_thread_id, &impl->thread_created_on.thread_id); - - AWS_ASSERT(impl->thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); - impl->thread_data.state = EVENT_THREAD_STATE_RUNNING; - - struct kevent kevents[MAX_EVENTS]; - - /* A single aws_io_handle could have two separate kevents if subscribed for both read and write. - * If both the read and write kevents fire in the same loop of the event-thread, - * combine the event-flags and deliver them in a single callback. - * This makes the kqueue_event_loop behave more like the other platform implementations. */ - struct handle_data *io_handle_events[MAX_EVENTS]; - - struct timespec timeout = { - .tv_sec = DEFAULT_TIMEOUT_SEC, - .tv_nsec = 0, - }; - - AWS_LOGF_INFO( - AWS_LS_IO_EVENT_LOOP, - "id=%p: default timeout %ds, and max events to process per tick %d", - (void *)event_loop, - DEFAULT_TIMEOUT_SEC, - MAX_EVENTS); - - while (impl->thread_data.state == EVENT_THREAD_STATE_RUNNING) { - int num_io_handle_events = 0; - bool should_process_cross_thread_data = false; - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: waiting for a maximum of %ds %lluns", - (void *)event_loop, - (int)timeout.tv_sec, - (unsigned long long)timeout.tv_nsec); - - /* Process kqueue events */ - int num_kevents = kevent( - impl->kq_fd, NULL /*changelist*/, 0 /*nchanges*/, kevents /*eventlist*/, MAX_EVENTS /*nevents*/, &timeout); - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: wake up with %d events to process.", (void *)event_loop, num_kevents); - if (num_kevents == -1) { - /* Raise an error, in case this is interesting to anyone monitoring, - * and continue on with this loop. We can't process events, - * but we can still process scheduled tasks */ - aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - - /* Force the cross_thread_data to be processed. - * There might be valuable info in there, like the message to stop the thread. - * It's fine to do this even if nothing has changed, it just costs a mutex lock/unlock. */ - should_process_cross_thread_data = true; - } - - for (int i = 0; i < num_kevents; ++i) { - struct kevent *kevent = &kevents[i]; - - /* Was this event to signal that cross_thread_data has changed? */ - if ((int)kevent->ident == impl->cross_thread_signal_pipe[READ_FD]) { - should_process_cross_thread_data = true; - - /* Drain whatever data was written to the signaling pipe */ - uint32_t read_whatever; - while (read((int)kevent->ident, &read_whatever, sizeof(read_whatever)) > 0) { - } - - continue; - } - - /* Otherwise this was a normal event on a subscribed handle. Figure out which flags to report. */ - int event_flags = s_aws_event_flags_from_kevent(kevent); - if (event_flags == 0) { - continue; - } - - /* Combine flags, in case multiple kevents correspond to one handle. (see notes at top of function) */ - struct handle_data *handle_data = kevent->udata; - if (handle_data->events_this_loop == 0) { - io_handle_events[num_io_handle_events++] = handle_data; - } - handle_data->events_this_loop |= event_flags; - } - - /* Invoke each handle's event callback (unless the handle has been unsubscribed) */ - for (int i = 0; i < num_io_handle_events; ++i) { - struct handle_data *handle_data = io_handle_events[i]; - - if (handle_data->state == HANDLE_STATE_SUBSCRIBED) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: activity on fd %d, invoking handler.", - (void *)event_loop, - handle_data->owner->data.fd); - handle_data->on_event( - event_loop, handle_data->owner, handle_data->events_this_loop, handle_data->on_event_user_data); - } - - handle_data->events_this_loop = 0; - } - - /* Process cross_thread_data */ - if (should_process_cross_thread_data) { - s_process_cross_thread_data(event_loop); - } - - /* Run scheduled tasks */ - uint64_t now_ns = 0; - event_loop->clock(&now_ns); /* If clock fails, now_ns will be 0 and tasks scheduled for a specific time - will not be run. That's ok, we'll handle them next time around. */ - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: running scheduled tasks.", (void *)event_loop); - aws_task_scheduler_run_all(&impl->thread_data.scheduler, now_ns); - - /* Set timeout for next kevent() call. - * If clock fails, or scheduler has no tasks, use default timeout */ - bool use_default_timeout = false; - - int err = event_loop->clock(&now_ns); - if (err) { - use_default_timeout = true; - } - - uint64_t next_run_time_ns; - if (!aws_task_scheduler_has_tasks(&impl->thread_data.scheduler, &next_run_time_ns)) { - - use_default_timeout = true; - } - - if (use_default_timeout) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: no more scheduled tasks using default timeout.", (void *)event_loop); - timeout.tv_sec = DEFAULT_TIMEOUT_SEC; - timeout.tv_nsec = 0; - } else { - /* Convert from timestamp in nanoseconds, to timeout in seconds with nanosecond remainder */ - uint64_t timeout_ns = next_run_time_ns > now_ns ? next_run_time_ns - now_ns : 0; - - uint64_t timeout_remainder_ns = 0; - uint64_t timeout_sec = - aws_timestamp_convert(timeout_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_SECS, &timeout_remainder_ns); - - if (timeout_sec > LONG_MAX) { /* Check for overflow. On Darwin, these values are stored as longs */ - timeout_sec = LONG_MAX; - timeout_remainder_ns = 0; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: detected more scheduled tasks with the next occurring at " - "%llu using timeout of %ds %lluns.", - (void *)event_loop, - (unsigned long long)timeout_ns, - (int)timeout_sec, - (unsigned long long)timeout_remainder_ns); - timeout.tv_sec = (time_t)(timeout_sec); - timeout.tv_nsec = (long)(timeout_remainder_ns); - } - } - - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: exiting main loop", (void *)event_loop); - /* reset to NULL. This should be updated again during destroy before tasks are canceled. */ - aws_atomic_store_ptr(&impl->running_thread_id, NULL); -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/event_loop.h> + +#include <aws/io/logging.h> + +#include <aws/common/atomics.h> +#include <aws/common/clock.h> +#include <aws/common/mutex.h> +#include <aws/common/task_scheduler.h> +#include <aws/common/thread.h> + +#if defined(__FreeBSD__) || defined(__NetBSD__) +# define __BSD_VISIBLE 1 +# include <sys/types.h> +#endif + +#include <sys/event.h> + +#include <aws/io/io.h> +#include <limits.h> +#include <unistd.h> + +static void s_destroy(struct aws_event_loop *event_loop); +static int s_run(struct aws_event_loop *event_loop); +static int s_stop(struct aws_event_loop *event_loop); +static int s_wait_for_stop_completion(struct aws_event_loop *event_loop); +static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task); +static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos); +static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task); +static int s_subscribe_to_io_events( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + aws_event_loop_on_event_fn *on_event, + void *user_data); +static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle); +static void s_free_io_event_resources(void *user_data); +static bool s_is_event_thread(struct aws_event_loop *event_loop); + +static void s_event_thread_main(void *user_data); + +int aws_open_nonblocking_posix_pipe(int pipe_fds[2]); + +enum event_thread_state { + EVENT_THREAD_STATE_READY_TO_RUN, + EVENT_THREAD_STATE_RUNNING, + EVENT_THREAD_STATE_STOPPING, +}; + +enum pipe_fd_index { + READ_FD, + WRITE_FD, +}; + +struct kqueue_loop { + /* thread_created_on is the handle to the event loop thread. */ + struct aws_thread thread_created_on; + /* thread_joined_to is used by the thread destroying the event loop. */ + aws_thread_id_t thread_joined_to; + /* running_thread_id is NULL if the event loop thread is stopped or points-to the thread_id of the thread running + * the event loop (either thread_created_on or thread_joined_to). Atomic because of concurrent writes (e.g., + * run/stop) and reads (e.g., is_event_loop_thread). + * An aws_thread_id_t variable itself cannot be atomic because it is an opaque type that is platform-dependent. */ + struct aws_atomic_var running_thread_id; + int kq_fd; /* kqueue file descriptor */ + + /* Pipe for signaling to event-thread that cross_thread_data has changed. */ + int cross_thread_signal_pipe[2]; + + /* cross_thread_data holds things that must be communicated across threads. + * When the event-thread is running, the mutex must be locked while anyone touches anything in cross_thread_data. + * If this data is modified outside the thread, the thread is signaled via activity on a pipe. */ + struct { + struct aws_mutex mutex; + bool thread_signaled; /* whether thread has been signaled about changes to cross_thread_data */ + struct aws_linked_list tasks_to_schedule; + enum event_thread_state state; + } cross_thread_data; + + /* thread_data holds things which, when the event-thread is running, may only be touched by the thread */ + struct { + struct aws_task_scheduler scheduler; + + int connected_handle_count; + + /* These variables duplicate ones in cross_thread_data. We move values out while holding the mutex and operate + * on them later */ + enum event_thread_state state; + } thread_data; +}; + +/* Data attached to aws_io_handle while the handle is subscribed to io events */ +struct handle_data { + struct aws_io_handle *owner; + struct aws_event_loop *event_loop; + aws_event_loop_on_event_fn *on_event; + void *on_event_user_data; + + int events_subscribed; /* aws_io_event_types this handle should be subscribed to */ + int events_this_loop; /* aws_io_event_types received during current loop of the event-thread */ + + enum { HANDLE_STATE_SUBSCRIBING, HANDLE_STATE_SUBSCRIBED, HANDLE_STATE_UNSUBSCRIBED } state; + + struct aws_task subscribe_task; + struct aws_task cleanup_task; +}; + +enum { + DEFAULT_TIMEOUT_SEC = 100, /* Max kevent() timeout per loop of the event-thread */ + MAX_EVENTS = 100, /* Max kevents to process per loop of the event-thread */ +}; + +struct aws_event_loop_vtable s_kqueue_vtable = { + .destroy = s_destroy, + .run = s_run, + .stop = s_stop, + .wait_for_stop_completion = s_wait_for_stop_completion, + .schedule_task_now = s_schedule_task_now, + .schedule_task_future = s_schedule_task_future, + .subscribe_to_io_events = s_subscribe_to_io_events, + .cancel_task = s_cancel_task, + .unsubscribe_from_io_events = s_unsubscribe_from_io_events, + .free_io_event_resources = s_free_io_event_resources, + .is_on_callers_thread = s_is_event_thread, +}; + +struct aws_event_loop *aws_event_loop_new_default(struct aws_allocator *alloc, aws_io_clock_fn *clock) { + AWS_ASSERT(alloc); + AWS_ASSERT(clock); + + bool clean_up_event_loop_mem = false; + bool clean_up_event_loop_base = false; + bool clean_up_impl_mem = false; + bool clean_up_thread = false; + bool clean_up_kqueue = false; + bool clean_up_signal_pipe = false; + bool clean_up_signal_kevent = false; + bool clean_up_mutex = false; + + struct aws_event_loop *event_loop = aws_mem_acquire(alloc, sizeof(struct aws_event_loop)); + if (!event_loop) { + return NULL; + } + + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing edge-triggered kqueue", (void *)event_loop); + clean_up_event_loop_mem = true; + + int err = aws_event_loop_init_base(event_loop, alloc, clock); + if (err) { + goto clean_up; + } + clean_up_event_loop_base = true; + + struct kqueue_loop *impl = aws_mem_calloc(alloc, 1, sizeof(struct kqueue_loop)); + if (!impl) { + goto clean_up; + } + /* intialize thread id to NULL. It will be set when the event loop thread starts. */ + aws_atomic_init_ptr(&impl->running_thread_id, NULL); + clean_up_impl_mem = true; + + err = aws_thread_init(&impl->thread_created_on, alloc); + if (err) { + goto clean_up; + } + clean_up_thread = true; + + impl->kq_fd = kqueue(); + if (impl->kq_fd == -1) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open kqueue handle.", (void *)event_loop); + aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + goto clean_up; + } + clean_up_kqueue = true; + + err = aws_open_nonblocking_posix_pipe(impl->cross_thread_signal_pipe); + if (err) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to open pipe handle.", (void *)event_loop); + goto clean_up; + } + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: pipe descriptors read %d, write %d.", + (void *)event_loop, + impl->cross_thread_signal_pipe[READ_FD], + impl->cross_thread_signal_pipe[WRITE_FD]); + clean_up_signal_pipe = true; + + /* Set up kevent to handle activity on the cross_thread_signal_pipe */ + struct kevent thread_signal_kevent; + EV_SET( + &thread_signal_kevent, + impl->cross_thread_signal_pipe[READ_FD], + EVFILT_READ /*filter*/, + EV_ADD | EV_CLEAR /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + NULL /*udata*/); + + int res = kevent( + impl->kq_fd, + &thread_signal_kevent /*changelist*/, + 1 /*nchanges*/, + NULL /*eventlist*/, + 0 /*nevents*/, + NULL /*timeout*/); + + if (res == -1) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to create cross-thread signal kevent.", (void *)event_loop); + aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + goto clean_up; + } + clean_up_signal_kevent = true; + + err = aws_mutex_init(&impl->cross_thread_data.mutex); + if (err) { + goto clean_up; + } + clean_up_mutex = true; + + impl->cross_thread_data.thread_signaled = false; + + aws_linked_list_init(&impl->cross_thread_data.tasks_to_schedule); + + impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; + + err = aws_task_scheduler_init(&impl->thread_data.scheduler, alloc); + if (err) { + goto clean_up; + } + + impl->thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; + + event_loop->impl_data = impl; + + event_loop->vtable = &s_kqueue_vtable; + + /* success */ + return event_loop; + +clean_up: + if (clean_up_mutex) { + aws_mutex_clean_up(&impl->cross_thread_data.mutex); + } + if (clean_up_signal_kevent) { + thread_signal_kevent.flags = EV_DELETE; + kevent( + impl->kq_fd, + &thread_signal_kevent /*changelist*/, + 1 /*nchanges*/, + NULL /*eventlist*/, + 0 /*nevents*/, + NULL /*timeout*/); + } + if (clean_up_signal_pipe) { + close(impl->cross_thread_signal_pipe[READ_FD]); + close(impl->cross_thread_signal_pipe[WRITE_FD]); + } + if (clean_up_kqueue) { + close(impl->kq_fd); + } + if (clean_up_thread) { + aws_thread_clean_up(&impl->thread_created_on); + } + if (clean_up_impl_mem) { + aws_mem_release(alloc, impl); + } + if (clean_up_event_loop_base) { + aws_event_loop_clean_up_base(event_loop); + } + if (clean_up_event_loop_mem) { + aws_mem_release(alloc, event_loop); + } + return NULL; +} + +static void s_destroy(struct aws_event_loop *event_loop) { + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: destroying event_loop", (void *)event_loop); + struct kqueue_loop *impl = event_loop->impl_data; + + /* Stop the event-thread. This might have already happened. It's safe to call multiple times. */ + s_stop(event_loop); + int err = s_wait_for_stop_completion(event_loop); + if (err) { + AWS_LOGF_WARN( + AWS_LS_IO_EVENT_LOOP, + "id=%p: failed to destroy event-thread, resources have been leaked", + (void *)event_loop); + AWS_ASSERT("Failed to destroy event-thread, resources have been leaked." == NULL); + return; + } + /* setting this so that canceled tasks don't blow up when asking if they're on the event-loop thread. */ + impl->thread_joined_to = aws_thread_current_thread_id(); + aws_atomic_store_ptr(&impl->running_thread_id, &impl->thread_joined_to); + + /* Clean up task-related stuff first. It's possible the a cancelled task adds further tasks to this event_loop. + * Tasks added in this way will be in cross_thread_data.tasks_to_schedule, so we clean that up last */ + + aws_task_scheduler_clean_up(&impl->thread_data.scheduler); /* Tasks in scheduler get cancelled*/ + + while (!aws_linked_list_empty(&impl->cross_thread_data.tasks_to_schedule)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&impl->cross_thread_data.tasks_to_schedule); + struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); + task->fn(task, task->arg, AWS_TASK_STATUS_CANCELED); + } + + /* Warn user if aws_io_handle was subscribed, but never unsubscribed. This would cause memory leaks. */ + AWS_ASSERT(impl->thread_data.connected_handle_count == 0); + + /* Clean up everything else */ + aws_mutex_clean_up(&impl->cross_thread_data.mutex); + + struct kevent thread_signal_kevent; + EV_SET( + &thread_signal_kevent, + impl->cross_thread_signal_pipe[READ_FD], + EVFILT_READ /*filter*/, + EV_DELETE /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + NULL /*udata*/); + + kevent( + impl->kq_fd, + &thread_signal_kevent /*changelist*/, + 1 /*nchanges*/, + NULL /*eventlist*/, + 0 /*nevents*/, + NULL /*timeout*/); + + close(impl->cross_thread_signal_pipe[READ_FD]); + close(impl->cross_thread_signal_pipe[WRITE_FD]); + close(impl->kq_fd); + aws_thread_clean_up(&impl->thread_created_on); + aws_mem_release(event_loop->alloc, impl); + aws_event_loop_clean_up_base(event_loop); + aws_mem_release(event_loop->alloc, event_loop); +} + +static int s_run(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: starting event-loop thread.", (void *)event_loop); + /* to re-run, call stop() and wait_for_stop_completion() */ + AWS_ASSERT(impl->cross_thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); + AWS_ASSERT(impl->thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); + + /* Since thread isn't running it's ok to touch thread_data, + * and it's ok to touch cross_thread_data without locking the mutex */ + impl->cross_thread_data.state = EVENT_THREAD_STATE_RUNNING; + + int err = aws_thread_launch(&impl->thread_created_on, s_event_thread_main, (void *)event_loop, NULL); + if (err) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: thread creation failed.", (void *)event_loop); + goto clean_up; + } + + return AWS_OP_SUCCESS; + +clean_up: + impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; + return AWS_OP_ERR; +} + +/* This function can't fail, we're relying on the thread responding to critical messages (ex: stop thread) */ +void signal_cross_thread_data_changed(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: signaling event-loop that cross-thread tasks need to be scheduled.", + (void *)event_loop); + /* Doesn't actually matter what we write, any activity on pipe signals that cross_thread_data has changed, + * If the pipe is full and the write fails, that's fine, the event-thread will get the signal from some previous + * write */ + uint32_t write_whatever = 0xC0FFEE; + write(impl->cross_thread_signal_pipe[WRITE_FD], &write_whatever, sizeof(write_whatever)); +} + +static int s_stop(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + + bool signal_thread = false; + + { /* Begin critical section */ + aws_mutex_lock(&impl->cross_thread_data.mutex); + if (impl->cross_thread_data.state == EVENT_THREAD_STATE_RUNNING) { + impl->cross_thread_data.state = EVENT_THREAD_STATE_STOPPING; + signal_thread = !impl->cross_thread_data.thread_signaled; + impl->cross_thread_data.thread_signaled = true; + } + aws_mutex_unlock(&impl->cross_thread_data.mutex); + } /* End critical section */ + + if (signal_thread) { + signal_cross_thread_data_changed(event_loop); + } + + return AWS_OP_SUCCESS; +} + +static int s_wait_for_stop_completion(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + +#ifdef DEBUG_BUILD + aws_mutex_lock(&impl->cross_thread_data.mutex); + /* call stop() before wait_for_stop_completion() or you'll wait forever */ + AWS_ASSERT(impl->cross_thread_data.state != EVENT_THREAD_STATE_RUNNING); + aws_mutex_unlock(&impl->cross_thread_data.mutex); +#endif + + int err = aws_thread_join(&impl->thread_created_on); + if (err) { + return AWS_OP_ERR; + } + + /* Since thread is no longer running it's ok to touch thread_data, + * and it's ok to touch cross_thread_data without locking the mutex */ + impl->cross_thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; + impl->thread_data.state = EVENT_THREAD_STATE_READY_TO_RUN; + + return AWS_OP_SUCCESS; +} + +/* Common functionality for "now" and "future" task scheduling. + * If `run_at_nanos` is zero then the task is scheduled as a "now" task. */ +static void s_schedule_task_common(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { + AWS_ASSERT(task); + struct kqueue_loop *impl = event_loop->impl_data; + + /* If we're on the event-thread, just schedule it directly */ + if (s_is_event_thread(event_loop)) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: scheduling task %p in-thread for timestamp %llu", + (void *)event_loop, + (void *)task, + (unsigned long long)run_at_nanos); + if (run_at_nanos == 0) { + aws_task_scheduler_schedule_now(&impl->thread_data.scheduler, task); + } else { + aws_task_scheduler_schedule_future(&impl->thread_data.scheduler, task, run_at_nanos); + } + return; + } + + /* Otherwise, add it to cross_thread_data.tasks_to_schedule and signal the event-thread to process it */ + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: scheduling task %p cross-thread for timestamp %llu", + (void *)event_loop, + (void *)task, + (unsigned long long)run_at_nanos); + task->timestamp = run_at_nanos; + bool should_signal_thread = false; + + /* Begin critical section */ + aws_mutex_lock(&impl->cross_thread_data.mutex); + aws_linked_list_push_back(&impl->cross_thread_data.tasks_to_schedule, &task->node); + + /* Signal thread that cross_thread_data has changed (unless it's been signaled already) */ + if (!impl->cross_thread_data.thread_signaled) { + should_signal_thread = true; + impl->cross_thread_data.thread_signaled = true; + } + + aws_mutex_unlock(&impl->cross_thread_data.mutex); + /* End critical section */ + + if (should_signal_thread) { + signal_cross_thread_data_changed(event_loop); + } +} + +static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { + s_schedule_task_common(event_loop, task, 0); /* Zero is used to denote "now" tasks */ +} + +static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { + s_schedule_task_common(event_loop, task, run_at_nanos); +} + +static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { + struct kqueue_loop *kqueue_loop = event_loop->impl_data; + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: cancelling task %p", (void *)event_loop, (void *)task); + aws_task_scheduler_cancel_task(&kqueue_loop->thread_data.scheduler, task); +} + +/* Scheduled task that connects aws_io_handle with the kqueue */ +static void s_subscribe_task(struct aws_task *task, void *user_data, enum aws_task_status status) { + (void)task; + struct handle_data *handle_data = user_data; + struct aws_event_loop *event_loop = handle_data->event_loop; + struct kqueue_loop *impl = handle_data->event_loop->impl_data; + + impl->thread_data.connected_handle_count++; + + /* if task was cancelled, nothing to do */ + if (status == AWS_TASK_STATUS_CANCELED) { + return; + } + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: subscribing to events on fd %d", (void *)event_loop, handle_data->owner->data.fd); + + /* If handle was unsubscribed before this task could execute, nothing to do */ + if (handle_data->state == HANDLE_STATE_UNSUBSCRIBED) { + return; + } + + AWS_ASSERT(handle_data->state == HANDLE_STATE_SUBSCRIBING); + + /* In order to monitor both reads and writes, kqueue requires you to add two separate kevents. + * If we're adding two separate kevents, but one of those fails, we need to remove the other kevent. + * Therefore we use the EV_RECEIPT flag. This causes kevent() to tell whether each EV_ADD succeeded, + * rather than the usual behavior of telling us about recent events. */ + struct kevent changelist[2]; + AWS_ZERO_ARRAY(changelist); + + int changelist_size = 0; + + if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_READABLE) { + EV_SET( + &changelist[changelist_size++], + handle_data->owner->data.fd, + EVFILT_READ /*filter*/, + EV_ADD | EV_RECEIPT | EV_CLEAR /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + handle_data /*udata*/); + } + if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_WRITABLE) { + EV_SET( + &changelist[changelist_size++], + handle_data->owner->data.fd, + EVFILT_WRITE /*filter*/, + EV_ADD | EV_RECEIPT | EV_CLEAR /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + handle_data /*udata*/); + } + + int num_events = kevent( + impl->kq_fd, + changelist /*changelist*/, + changelist_size /*nchanges*/, + changelist /*eventlist. It's OK to re-use the same memory for changelist input and eventlist output*/, + changelist_size /*nevents*/, + NULL /*timeout*/); + if (num_events == -1) { + goto subscribe_failed; + } + + /* Look through results to see if any failed */ + for (int i = 0; i < num_events; ++i) { + /* Every result should be flagged as error, that's just how EV_RECEIPT works */ + AWS_ASSERT(changelist[i].flags & EV_ERROR); + + /* If a real error occurred, .data contains the error code */ + if (changelist[i].data != 0) { + goto subscribe_failed; + } + } + + /* Success */ + handle_data->state = HANDLE_STATE_SUBSCRIBED; + return; + +subscribe_failed: + AWS_LOGF_ERROR( + AWS_LS_IO_EVENT_LOOP, + "id=%p: failed to subscribe to events on fd %d", + (void *)event_loop, + handle_data->owner->data.fd); + /* Remove any related kevents that succeeded */ + for (int i = 0; i < num_events; ++i) { + if (changelist[i].data == 0) { + changelist[i].flags = EV_DELETE; + kevent( + impl->kq_fd, + &changelist[i] /*changelist*/, + 1 /*nchanges*/, + NULL /*eventlist*/, + 0 /*nevents*/, + NULL /*timeout*/); + } + } + + /* We can't return an error code because this was a scheduled task. + * Notify the user of the failed subscription by passing AWS_IO_EVENT_TYPE_ERROR to the callback. */ + handle_data->on_event(event_loop, handle_data->owner, AWS_IO_EVENT_TYPE_ERROR, handle_data->on_event_user_data); +} + +static int s_subscribe_to_io_events( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + aws_event_loop_on_event_fn *on_event, + void *user_data) { + + AWS_ASSERT(event_loop); + AWS_ASSERT(handle->data.fd != -1); + AWS_ASSERT(handle->additional_data == NULL); + AWS_ASSERT(on_event); + /* Must subscribe for read, write, or both */ + AWS_ASSERT(events & (AWS_IO_EVENT_TYPE_READABLE | AWS_IO_EVENT_TYPE_WRITABLE)); + + struct handle_data *handle_data = aws_mem_calloc(event_loop->alloc, 1, sizeof(struct handle_data)); + if (!handle_data) { + return AWS_OP_ERR; + } + + handle_data->owner = handle; + handle_data->event_loop = event_loop; + handle_data->on_event = on_event; + handle_data->on_event_user_data = user_data; + handle_data->events_subscribed = events; + handle_data->state = HANDLE_STATE_SUBSCRIBING; + + handle->additional_data = handle_data; + + /* We schedule a task to perform the actual changes to the kqueue, read on for an explanation why... + * + * kqueue requires separate registrations for read and write events. + * If the user wants to know about both read and write, we need register once for read and once for write. + * If the first registration succeeds, but the second registration fails, we need to delete the first registration. + * If this all happened outside the event-thread, the successful registration's events could begin processing + * in the brief window of time before the registration is deleted. */ + + aws_task_init(&handle_data->subscribe_task, s_subscribe_task, handle_data, "kqueue_event_loop_subscribe"); + s_schedule_task_now(event_loop, &handle_data->subscribe_task); + + return AWS_OP_SUCCESS; +} + +static void s_free_io_event_resources(void *user_data) { + struct handle_data *handle_data = user_data; + struct kqueue_loop *impl = handle_data->event_loop->impl_data; + + impl->thread_data.connected_handle_count--; + + aws_mem_release(handle_data->event_loop->alloc, handle_data); +} + +static void s_clean_up_handle_data_task(struct aws_task *task, void *user_data, enum aws_task_status status) { + (void)task; + (void)status; + + struct handle_data *handle_data = user_data; + s_free_io_event_resources(handle_data); +} + +static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: un-subscribing from events on fd %d", (void *)event_loop, handle->data.fd); + AWS_ASSERT(handle->additional_data); + struct handle_data *handle_data = handle->additional_data; + struct kqueue_loop *impl = event_loop->impl_data; + + AWS_ASSERT(event_loop == handle_data->event_loop); + + /* If the handle was successfully subscribed to kqueue, then remove it. */ + if (handle_data->state == HANDLE_STATE_SUBSCRIBED) { + struct kevent changelist[2]; + int changelist_size = 0; + + if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_READABLE) { + EV_SET( + &changelist[changelist_size++], + handle_data->owner->data.fd, + EVFILT_READ /*filter*/, + EV_DELETE /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + handle_data /*udata*/); + } + if (handle_data->events_subscribed & AWS_IO_EVENT_TYPE_WRITABLE) { + EV_SET( + &changelist[changelist_size++], + handle_data->owner->data.fd, + EVFILT_WRITE /*filter*/, + EV_DELETE /*flags*/, + 0 /*fflags*/, + 0 /*data*/, + handle_data /*udata*/); + } + + kevent(impl->kq_fd, changelist, changelist_size, NULL /*eventlist*/, 0 /*nevents*/, NULL /*timeout*/); + } + + /* Schedule a task to clean up the memory. This is done in a task to prevent the following scenario: + * - While processing a batch of events, some callback unsubscribes another aws_io_handle. + * - One of the other events in this batch belongs to that other aws_io_handle. + * - If the handle_data were already deleted, there would be an access invalid memory. */ + + aws_task_init( + &handle_data->cleanup_task, s_clean_up_handle_data_task, handle_data, "kqueue_event_loop_clean_up_handle_data"); + aws_event_loop_schedule_task_now(event_loop, &handle_data->cleanup_task); + + handle_data->state = HANDLE_STATE_UNSUBSCRIBED; + handle->additional_data = NULL; + + return AWS_OP_SUCCESS; +} + +static bool s_is_event_thread(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + + aws_thread_id_t *thread_id = aws_atomic_load_ptr(&impl->running_thread_id); + return thread_id && aws_thread_thread_id_equal(*thread_id, aws_thread_current_thread_id()); +} + +/* Called from thread. + * Takes tasks from tasks_to_schedule and adds them to the scheduler. */ +static void s_process_tasks_to_schedule(struct aws_event_loop *event_loop, struct aws_linked_list *tasks_to_schedule) { + struct kqueue_loop *impl = event_loop->impl_data; + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: processing cross-thread tasks", (void *)event_loop); + + while (!aws_linked_list_empty(tasks_to_schedule)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(tasks_to_schedule); + struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: task %p pulled to event-loop, scheduling now.", + (void *)event_loop, + (void *)task); + /* Timestamp 0 is used to denote "now" tasks */ + if (task->timestamp == 0) { + aws_task_scheduler_schedule_now(&impl->thread_data.scheduler, task); + } else { + aws_task_scheduler_schedule_future(&impl->thread_data.scheduler, task, task->timestamp); + } + } +} + +static void s_process_cross_thread_data(struct aws_event_loop *event_loop) { + struct kqueue_loop *impl = event_loop->impl_data; + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: notified of cross-thread data to process", (void *)event_loop); + /* If there are tasks to schedule, grab them all out of synced_data.tasks_to_schedule. + * We'll process them later, so that we minimize time spent holding the mutex. */ + struct aws_linked_list tasks_to_schedule; + aws_linked_list_init(&tasks_to_schedule); + + { /* Begin critical section */ + aws_mutex_lock(&impl->cross_thread_data.mutex); + impl->cross_thread_data.thread_signaled = false; + + bool initiate_stop = (impl->cross_thread_data.state == EVENT_THREAD_STATE_STOPPING) && + (impl->thread_data.state == EVENT_THREAD_STATE_RUNNING); + if (AWS_UNLIKELY(initiate_stop)) { + impl->thread_data.state = EVENT_THREAD_STATE_STOPPING; + } + + aws_linked_list_swap_contents(&impl->cross_thread_data.tasks_to_schedule, &tasks_to_schedule); + + aws_mutex_unlock(&impl->cross_thread_data.mutex); + } /* End critical section */ + + s_process_tasks_to_schedule(event_loop, &tasks_to_schedule); +} + +static int s_aws_event_flags_from_kevent(struct kevent *kevent) { + int event_flags = 0; + + if (kevent->flags & EV_ERROR) { + event_flags |= AWS_IO_EVENT_TYPE_ERROR; + } else if (kevent->filter == EVFILT_READ) { + if (kevent->data != 0) { + event_flags |= AWS_IO_EVENT_TYPE_READABLE; + } + + if (kevent->flags & EV_EOF) { + event_flags |= AWS_IO_EVENT_TYPE_CLOSED; + } + } else if (kevent->filter == EVFILT_WRITE) { + if (kevent->data != 0) { + event_flags |= AWS_IO_EVENT_TYPE_WRITABLE; + } + + if (kevent->flags & EV_EOF) { + event_flags |= AWS_IO_EVENT_TYPE_CLOSED; + } + } + + return event_flags; +} + +static void s_event_thread_main(void *user_data) { + struct aws_event_loop *event_loop = user_data; + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: main loop started", (void *)event_loop); + struct kqueue_loop *impl = event_loop->impl_data; + + /* set thread id to the event-loop's thread. */ + aws_atomic_store_ptr(&impl->running_thread_id, &impl->thread_created_on.thread_id); + + AWS_ASSERT(impl->thread_data.state == EVENT_THREAD_STATE_READY_TO_RUN); + impl->thread_data.state = EVENT_THREAD_STATE_RUNNING; + + struct kevent kevents[MAX_EVENTS]; + + /* A single aws_io_handle could have two separate kevents if subscribed for both read and write. + * If both the read and write kevents fire in the same loop of the event-thread, + * combine the event-flags and deliver them in a single callback. + * This makes the kqueue_event_loop behave more like the other platform implementations. */ + struct handle_data *io_handle_events[MAX_EVENTS]; + + struct timespec timeout = { + .tv_sec = DEFAULT_TIMEOUT_SEC, + .tv_nsec = 0, + }; + + AWS_LOGF_INFO( + AWS_LS_IO_EVENT_LOOP, + "id=%p: default timeout %ds, and max events to process per tick %d", + (void *)event_loop, + DEFAULT_TIMEOUT_SEC, + MAX_EVENTS); + + while (impl->thread_data.state == EVENT_THREAD_STATE_RUNNING) { + int num_io_handle_events = 0; + bool should_process_cross_thread_data = false; + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: waiting for a maximum of %ds %lluns", + (void *)event_loop, + (int)timeout.tv_sec, + (unsigned long long)timeout.tv_nsec); + + /* Process kqueue events */ + int num_kevents = kevent( + impl->kq_fd, NULL /*changelist*/, 0 /*nchanges*/, kevents /*eventlist*/, MAX_EVENTS /*nevents*/, &timeout); + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: wake up with %d events to process.", (void *)event_loop, num_kevents); + if (num_kevents == -1) { + /* Raise an error, in case this is interesting to anyone monitoring, + * and continue on with this loop. We can't process events, + * but we can still process scheduled tasks */ + aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + + /* Force the cross_thread_data to be processed. + * There might be valuable info in there, like the message to stop the thread. + * It's fine to do this even if nothing has changed, it just costs a mutex lock/unlock. */ + should_process_cross_thread_data = true; + } + + for (int i = 0; i < num_kevents; ++i) { + struct kevent *kevent = &kevents[i]; + + /* Was this event to signal that cross_thread_data has changed? */ + if ((int)kevent->ident == impl->cross_thread_signal_pipe[READ_FD]) { + should_process_cross_thread_data = true; + + /* Drain whatever data was written to the signaling pipe */ + uint32_t read_whatever; + while (read((int)kevent->ident, &read_whatever, sizeof(read_whatever)) > 0) { + } + + continue; + } + + /* Otherwise this was a normal event on a subscribed handle. Figure out which flags to report. */ + int event_flags = s_aws_event_flags_from_kevent(kevent); + if (event_flags == 0) { + continue; + } + + /* Combine flags, in case multiple kevents correspond to one handle. (see notes at top of function) */ + struct handle_data *handle_data = kevent->udata; + if (handle_data->events_this_loop == 0) { + io_handle_events[num_io_handle_events++] = handle_data; + } + handle_data->events_this_loop |= event_flags; + } + + /* Invoke each handle's event callback (unless the handle has been unsubscribed) */ + for (int i = 0; i < num_io_handle_events; ++i) { + struct handle_data *handle_data = io_handle_events[i]; + + if (handle_data->state == HANDLE_STATE_SUBSCRIBED) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: activity on fd %d, invoking handler.", + (void *)event_loop, + handle_data->owner->data.fd); + handle_data->on_event( + event_loop, handle_data->owner, handle_data->events_this_loop, handle_data->on_event_user_data); + } + + handle_data->events_this_loop = 0; + } + + /* Process cross_thread_data */ + if (should_process_cross_thread_data) { + s_process_cross_thread_data(event_loop); + } + + /* Run scheduled tasks */ + uint64_t now_ns = 0; + event_loop->clock(&now_ns); /* If clock fails, now_ns will be 0 and tasks scheduled for a specific time + will not be run. That's ok, we'll handle them next time around. */ + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: running scheduled tasks.", (void *)event_loop); + aws_task_scheduler_run_all(&impl->thread_data.scheduler, now_ns); + + /* Set timeout for next kevent() call. + * If clock fails, or scheduler has no tasks, use default timeout */ + bool use_default_timeout = false; + + int err = event_loop->clock(&now_ns); + if (err) { + use_default_timeout = true; + } + + uint64_t next_run_time_ns; + if (!aws_task_scheduler_has_tasks(&impl->thread_data.scheduler, &next_run_time_ns)) { + + use_default_timeout = true; + } + + if (use_default_timeout) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: no more scheduled tasks using default timeout.", (void *)event_loop); + timeout.tv_sec = DEFAULT_TIMEOUT_SEC; + timeout.tv_nsec = 0; + } else { + /* Convert from timestamp in nanoseconds, to timeout in seconds with nanosecond remainder */ + uint64_t timeout_ns = next_run_time_ns > now_ns ? next_run_time_ns - now_ns : 0; + + uint64_t timeout_remainder_ns = 0; + uint64_t timeout_sec = + aws_timestamp_convert(timeout_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_SECS, &timeout_remainder_ns); + + if (timeout_sec > LONG_MAX) { /* Check for overflow. On Darwin, these values are stored as longs */ + timeout_sec = LONG_MAX; + timeout_remainder_ns = 0; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: detected more scheduled tasks with the next occurring at " + "%llu using timeout of %ds %lluns.", + (void *)event_loop, + (unsigned long long)timeout_ns, + (int)timeout_sec, + (unsigned long long)timeout_remainder_ns); + timeout.tv_sec = (time_t)(timeout_sec); + timeout.tv_nsec = (long)(timeout_remainder_ns); + } + } + + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: exiting main loop", (void *)event_loop); + /* reset to NULL. This should be updated again during destroy before tasks are canceled. */ + aws_atomic_store_ptr(&impl->running_thread_id, NULL); +} diff --git a/contrib/restricted/aws/aws-c-io/source/channel.c b/contrib/restricted/aws/aws-c-io/source/channel.c index a36fa269d7..55ec1636cb 100644 --- a/contrib/restricted/aws/aws-c-io/source/channel.c +++ b/contrib/restricted/aws/aws-c-io/source/channel.c @@ -1,1144 +1,1144 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/channel.h> - -#include <aws/common/atomics.h> -#include <aws/common/clock.h> -#include <aws/common/mutex.h> - -#include <aws/io/event_loop.h> -#include <aws/io/logging.h> -#include <aws/io/message_pool.h> -#include <aws/io/statistics.h> - -#if _MSC_VER -# pragma warning(disable : 4204) /* non-constant aggregate initializer */ -#endif - -static size_t s_message_pool_key = 0; /* Address of variable serves as key in hash table */ - -enum { - KB_16 = 16 * 1024, -}; - -size_t g_aws_channel_max_fragment_size = KB_16; - -#define INITIAL_STATISTIC_LIST_SIZE 5 - -enum aws_channel_state { - AWS_CHANNEL_SETTING_UP, - AWS_CHANNEL_ACTIVE, - AWS_CHANNEL_SHUTTING_DOWN, - AWS_CHANNEL_SHUT_DOWN, -}; - -struct aws_shutdown_notification_task { - struct aws_task task; - int error_code; - struct aws_channel_slot *slot; - bool shutdown_immediately; -}; - -struct shutdown_task { - struct aws_channel_task task; - struct aws_channel *channel; - int error_code; - bool shutdown_immediately; -}; - -struct aws_channel { - struct aws_allocator *alloc; - struct aws_event_loop *loop; - struct aws_channel_slot *first; - struct aws_message_pool *msg_pool; - enum aws_channel_state channel_state; - struct aws_shutdown_notification_task shutdown_notify_task; - aws_channel_on_shutdown_completed_fn *on_shutdown_completed; - void *shutdown_user_data; - struct aws_atomic_var refcount; - struct aws_task deletion_task; - - struct aws_task statistics_task; - struct aws_crt_statistics_handler *statistics_handler; - uint64_t statistics_interval_start_time_ms; - struct aws_array_list statistic_list; - - struct { - struct aws_linked_list list; - } channel_thread_tasks; - struct { - struct aws_mutex lock; - struct aws_linked_list list; - struct aws_task scheduling_task; - struct shutdown_task shutdown_task; - bool is_channel_shut_down; - } cross_thread_tasks; - - size_t window_update_batch_emit_threshold; - struct aws_channel_task window_update_task; - bool read_back_pressure_enabled; - bool window_update_in_progress; -}; - -struct channel_setup_args { - struct aws_allocator *alloc; - struct aws_channel *channel; - aws_channel_on_setup_completed_fn *on_setup_completed; - void *user_data; - struct aws_task task; -}; - -static void s_on_msg_pool_removed(struct aws_event_loop_local_object *object) { - struct aws_message_pool *msg_pool = object->object; - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "static: message pool %p has been purged " - "from the event-loop: likely because of shutdown", - (void *)msg_pool); - struct aws_allocator *alloc = msg_pool->alloc; - aws_message_pool_clean_up(msg_pool); - aws_mem_release(alloc, msg_pool); - aws_mem_release(alloc, object); -} - -static void s_on_channel_setup_complete(struct aws_task *task, void *arg, enum aws_task_status task_status) { - - (void)task; - struct channel_setup_args *setup_args = arg; - struct aws_message_pool *message_pool = NULL; - struct aws_event_loop_local_object *local_object = NULL; - - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: setup complete, notifying caller.", (void *)setup_args->channel); - if (task_status == AWS_TASK_STATUS_RUN_READY) { - struct aws_event_loop_local_object stack_obj; - AWS_ZERO_STRUCT(stack_obj); - local_object = &stack_obj; - - if (aws_event_loop_fetch_local_object(setup_args->channel->loop, &s_message_pool_key, local_object)) { - - local_object = aws_mem_calloc(setup_args->alloc, 1, sizeof(struct aws_event_loop_local_object)); - if (!local_object) { - goto cleanup_setup_args; - } - - message_pool = aws_mem_acquire(setup_args->alloc, sizeof(struct aws_message_pool)); - if (!message_pool) { - goto cleanup_local_obj; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, - "id=%p: no message pool is currently stored in the event-loop " - "local storage, adding %p with max message size %zu, " - "message count 4, with 4 small blocks of 128 bytes.", - (void *)setup_args->channel, - (void *)message_pool, - g_aws_channel_max_fragment_size); - - struct aws_message_pool_creation_args creation_args = { - .application_data_msg_data_size = g_aws_channel_max_fragment_size, - .application_data_msg_count = 4, - .small_block_msg_count = 4, - .small_block_msg_data_size = 128, - }; - - if (aws_message_pool_init(message_pool, setup_args->alloc, &creation_args)) { - goto cleanup_msg_pool_mem; - } - - local_object->key = &s_message_pool_key; - local_object->object = message_pool; - local_object->on_object_removed = s_on_msg_pool_removed; - - if (aws_event_loop_put_local_object(setup_args->channel->loop, local_object)) { - goto cleanup_msg_pool; - } - } else { - message_pool = local_object->object; - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, - "id=%p: message pool %p found in event-loop local storage: using it.", - (void *)setup_args->channel, - (void *)message_pool) - } - - setup_args->channel->msg_pool = message_pool; - setup_args->channel->channel_state = AWS_CHANNEL_ACTIVE; - setup_args->on_setup_completed(setup_args->channel, AWS_OP_SUCCESS, setup_args->user_data); - aws_channel_release_hold(setup_args->channel); - aws_mem_release(setup_args->alloc, setup_args); - return; - } - - goto cleanup_setup_args; - -cleanup_msg_pool: - aws_message_pool_clean_up(message_pool); - -cleanup_msg_pool_mem: - aws_mem_release(setup_args->alloc, message_pool); - -cleanup_local_obj: - aws_mem_release(setup_args->alloc, local_object); - -cleanup_setup_args: - setup_args->on_setup_completed(setup_args->channel, AWS_OP_ERR, setup_args->user_data); - aws_channel_release_hold(setup_args->channel); - aws_mem_release(setup_args->alloc, setup_args); -} - -static void s_schedule_cross_thread_tasks(struct aws_task *task, void *arg, enum aws_task_status status); - -static void s_destroy_partially_constructed_channel(struct aws_channel *channel) { - if (channel == NULL) { - return; - } - - aws_array_list_clean_up(&channel->statistic_list); - - aws_mem_release(channel->alloc, channel); -} - -struct aws_channel *aws_channel_new(struct aws_allocator *alloc, const struct aws_channel_options *creation_args) { - AWS_PRECONDITION(creation_args); - AWS_PRECONDITION(creation_args->event_loop); - AWS_PRECONDITION(creation_args->on_setup_completed); - - struct aws_channel *channel = aws_mem_calloc(alloc, 1, sizeof(struct aws_channel)); - if (!channel) { - return NULL; - } - - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: Beginning creation and setup of new channel.", (void *)channel); - channel->alloc = alloc; - channel->loop = creation_args->event_loop; - channel->on_shutdown_completed = creation_args->on_shutdown_completed; - channel->shutdown_user_data = creation_args->shutdown_user_data; - - if (aws_array_list_init_dynamic( - &channel->statistic_list, alloc, INITIAL_STATISTIC_LIST_SIZE, sizeof(struct aws_crt_statistics_base *))) { - goto on_error; - } - - /* Start refcount at 2: - * 1 for self-reference, released from aws_channel_destroy() - * 1 for the setup task, released when task executes */ - aws_atomic_init_int(&channel->refcount, 2); - - struct channel_setup_args *setup_args = aws_mem_calloc(alloc, 1, sizeof(struct channel_setup_args)); - if (!setup_args) { - goto on_error; - } - - channel->channel_state = AWS_CHANNEL_SETTING_UP; - aws_linked_list_init(&channel->channel_thread_tasks.list); - aws_linked_list_init(&channel->cross_thread_tasks.list); - channel->cross_thread_tasks.lock = (struct aws_mutex)AWS_MUTEX_INIT; - - if (creation_args->enable_read_back_pressure) { - channel->read_back_pressure_enabled = true; - /* we probably only need room for one fragment, but let's avoid potential deadlocks - * on things like tls that need extra head-room. */ - channel->window_update_batch_emit_threshold = g_aws_channel_max_fragment_size * 2; - } - - aws_task_init( - &channel->cross_thread_tasks.scheduling_task, - s_schedule_cross_thread_tasks, - channel, - "schedule_cross_thread_tasks"); - - setup_args->alloc = alloc; - setup_args->channel = channel; - setup_args->on_setup_completed = creation_args->on_setup_completed; - setup_args->user_data = creation_args->setup_user_data; - - aws_task_init(&setup_args->task, s_on_channel_setup_complete, setup_args, "on_channel_setup_complete"); - aws_event_loop_schedule_task_now(creation_args->event_loop, &setup_args->task); - - return channel; - -on_error: - - s_destroy_partially_constructed_channel(channel); - - return NULL; -} - -static void s_cleanup_slot(struct aws_channel_slot *slot) { - if (slot) { - if (slot->handler) { - aws_channel_handler_destroy(slot->handler); - } - aws_mem_release(slot->alloc, slot); - } -} - -void aws_channel_destroy(struct aws_channel *channel) { - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: destroying channel.", (void *)channel); - - aws_channel_release_hold(channel); -} - -static void s_final_channel_deletion_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - (void)status; - struct aws_channel *channel = arg; - - struct aws_channel_slot *current = channel->first; - - if (!current || !current->handler) { - /* Allow channels with no valid slots to skip shutdown process */ - channel->channel_state = AWS_CHANNEL_SHUT_DOWN; - } - - AWS_ASSERT(channel->channel_state == AWS_CHANNEL_SHUT_DOWN); - - while (current) { - struct aws_channel_slot *tmp = current->adj_right; - s_cleanup_slot(current); - current = tmp; - } - - aws_array_list_clean_up(&channel->statistic_list); - - aws_channel_set_statistics_handler(channel, NULL); - - aws_mem_release(channel->alloc, channel); -} - -void aws_channel_acquire_hold(struct aws_channel *channel) { - size_t prev_refcount = aws_atomic_fetch_add(&channel->refcount, 1); - AWS_ASSERT(prev_refcount != 0); - (void)prev_refcount; -} - -void aws_channel_release_hold(struct aws_channel *channel) { - size_t prev_refcount = aws_atomic_fetch_sub(&channel->refcount, 1); - AWS_ASSERT(prev_refcount != 0); - - if (prev_refcount == 1) { - /* Refcount is now 0, finish cleaning up channel memory. */ - if (aws_channel_thread_is_callers_thread(channel)) { - s_final_channel_deletion_task(NULL, channel, AWS_TASK_STATUS_RUN_READY); - } else { - aws_task_init(&channel->deletion_task, s_final_channel_deletion_task, channel, "final_channel_deletion"); - aws_event_loop_schedule_task_now(channel->loop, &channel->deletion_task); - } - } -} - -struct channel_shutdown_task_args { - struct aws_channel *channel; - struct aws_allocator *alloc; - int error_code; - struct aws_task task; -}; - -static int s_channel_shutdown(struct aws_channel *channel, int error_code, bool shutdown_immediately); - -static void s_on_shutdown_completion_task(struct aws_task *task, void *arg, enum aws_task_status status); - -static void s_shutdown_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { - - (void)task; - (void)status; - struct shutdown_task *shutdown_task = arg; - struct aws_channel *channel = shutdown_task->channel; - int error_code = shutdown_task->error_code; - bool shutdown_immediately = shutdown_task->shutdown_immediately; - if (channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: beginning shutdown process", (void *)channel); - - struct aws_channel_slot *slot = channel->first; - channel->channel_state = AWS_CHANNEL_SHUTTING_DOWN; - - if (slot) { - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: shutting down slot %p (the first one) in the read direction", - (void *)channel, - (void *)slot); - - aws_channel_slot_shutdown(slot, AWS_CHANNEL_DIR_READ, error_code, shutdown_immediately); - return; - } - - channel->channel_state = AWS_CHANNEL_SHUT_DOWN; - AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: shutdown completed", (void *)channel); - - aws_mutex_lock(&channel->cross_thread_tasks.lock); - channel->cross_thread_tasks.is_channel_shut_down = true; - aws_mutex_unlock(&channel->cross_thread_tasks.lock); - - if (channel->on_shutdown_completed) { - channel->shutdown_notify_task.task.fn = s_on_shutdown_completion_task; - channel->shutdown_notify_task.task.arg = channel; - channel->shutdown_notify_task.error_code = error_code; - aws_event_loop_schedule_task_now(channel->loop, &channel->shutdown_notify_task.task); - } - } -} - -static int s_channel_shutdown(struct aws_channel *channel, int error_code, bool shutdown_immediately) { - bool need_to_schedule = true; - aws_mutex_lock(&channel->cross_thread_tasks.lock); - if (channel->cross_thread_tasks.shutdown_task.task.task_fn) { - need_to_schedule = false; - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, "id=%p: Channel shutdown is already pending, not scheduling another.", (void *)channel); - - } else { - aws_channel_task_init( - &channel->cross_thread_tasks.shutdown_task.task, - s_shutdown_task, - &channel->cross_thread_tasks.shutdown_task, - "channel_shutdown"); - channel->cross_thread_tasks.shutdown_task.shutdown_immediately = shutdown_immediately; - channel->cross_thread_tasks.shutdown_task.channel = channel; - channel->cross_thread_tasks.shutdown_task.error_code = error_code; - } - - aws_mutex_unlock(&channel->cross_thread_tasks.lock); - - if (need_to_schedule) { - AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: channel shutdown task is scheduled", (void *)channel); - aws_channel_schedule_task_now(channel, &channel->cross_thread_tasks.shutdown_task.task); - } - - return AWS_OP_SUCCESS; -} - -int aws_channel_shutdown(struct aws_channel *channel, int error_code) { - return s_channel_shutdown(channel, error_code, false); -} - -struct aws_io_message *aws_channel_acquire_message_from_pool( - struct aws_channel *channel, - enum aws_io_message_type message_type, - size_t size_hint) { - - struct aws_io_message *message = aws_message_pool_acquire(channel->msg_pool, message_type, size_hint); - - if (AWS_LIKELY(message)) { - message->owning_channel = channel; - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: acquired message %p of capacity %zu from pool %p. Requested size was %zu", - (void *)channel, - (void *)message, - message->message_data.capacity, - (void *)channel->msg_pool, - size_hint); - } - - return message; -} - -struct aws_channel_slot *aws_channel_slot_new(struct aws_channel *channel) { - struct aws_channel_slot *new_slot = aws_mem_calloc(channel->alloc, 1, sizeof(struct aws_channel_slot)); - if (!new_slot) { - return NULL; - } - - AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: creating new slot %p.", (void *)channel, (void *)new_slot); - new_slot->alloc = channel->alloc; - new_slot->channel = channel; - - if (!channel->first) { - channel->first = new_slot; - } - - return new_slot; -} - -int aws_channel_current_clock_time(struct aws_channel *channel, uint64_t *time_nanos) { - return aws_event_loop_current_clock_time(channel->loop, time_nanos); -} - -int aws_channel_fetch_local_object( - struct aws_channel *channel, - const void *key, - struct aws_event_loop_local_object *obj) { - - return aws_event_loop_fetch_local_object(channel->loop, (void *)key, obj); -} -int aws_channel_put_local_object( - struct aws_channel *channel, - const void *key, - const struct aws_event_loop_local_object *obj) { - - (void)key; - return aws_event_loop_put_local_object(channel->loop, (struct aws_event_loop_local_object *)obj); -} - -int aws_channel_remove_local_object( - struct aws_channel *channel, - const void *key, - struct aws_event_loop_local_object *removed_obj) { - - return aws_event_loop_remove_local_object(channel->loop, (void *)key, removed_obj); -} - -static void s_channel_task_run(struct aws_task *task, void *arg, enum aws_task_status status) { - struct aws_channel_task *channel_task = AWS_CONTAINER_OF(task, struct aws_channel_task, wrapper_task); - struct aws_channel *channel = arg; - - /* Any task that runs after shutdown completes is considered canceled */ - if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { - status = AWS_TASK_STATUS_CANCELED; - } - - aws_linked_list_remove(&channel_task->node); - channel_task->task_fn(channel_task, channel_task->arg, status); -} - -static void s_schedule_cross_thread_tasks(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - struct aws_channel *channel = arg; - - struct aws_linked_list cross_thread_task_list; - aws_linked_list_init(&cross_thread_task_list); - - /* Grab contents of cross-thread task list while we have the lock */ - aws_mutex_lock(&channel->cross_thread_tasks.lock); - aws_linked_list_swap_contents(&channel->cross_thread_tasks.list, &cross_thread_task_list); - aws_mutex_unlock(&channel->cross_thread_tasks.lock); - - /* If the channel has shut down since the cross-thread tasks were scheduled, run tasks immediately as canceled */ - if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { - status = AWS_TASK_STATUS_CANCELED; - } - - while (!aws_linked_list_empty(&cross_thread_task_list)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&cross_thread_task_list); - struct aws_channel_task *channel_task = AWS_CONTAINER_OF(node, struct aws_channel_task, node); - - if ((channel_task->wrapper_task.timestamp == 0) || (status == AWS_TASK_STATUS_CANCELED)) { - /* Run "now" tasks, and canceled tasks, immediately */ - channel_task->task_fn(channel_task, channel_task->arg, status); - } else { - /* "Future" tasks are scheduled with the event-loop. */ - aws_linked_list_push_back(&channel->channel_thread_tasks.list, &channel_task->node); - aws_event_loop_schedule_task_future( - channel->loop, &channel_task->wrapper_task, channel_task->wrapper_task.timestamp); - } - } -} - -void aws_channel_task_init( - struct aws_channel_task *channel_task, - aws_channel_task_fn *task_fn, - void *arg, - const char *type_tag) { - AWS_ZERO_STRUCT(*channel_task); - channel_task->task_fn = task_fn; - channel_task->arg = arg; - channel_task->type_tag = type_tag; -} - -/* Common functionality for scheduling "now" and "future" tasks. - * For "now" tasks, pass 0 for `run_at_nanos` */ -static void s_register_pending_task( - struct aws_channel *channel, - struct aws_channel_task *channel_task, - uint64_t run_at_nanos) { - - /* Reset every property on channel task other than user's fn & arg.*/ - aws_task_init(&channel_task->wrapper_task, s_channel_task_run, channel, channel_task->type_tag); - channel_task->wrapper_task.timestamp = run_at_nanos; - aws_linked_list_node_reset(&channel_task->node); - - if (aws_channel_thread_is_callers_thread(channel)) { - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: scheduling task with wrapper task id %p.", - (void *)channel, - (void *)&channel_task->wrapper_task); - - /* If channel is shut down, run task immediately as canceled */ - if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, - "id=%p: Running %s channel task immediately as canceled due to shut down channel", - (void *)channel, - channel_task->type_tag); - channel_task->task_fn(channel_task, channel_task->arg, AWS_TASK_STATUS_CANCELED); - return; - } - - aws_linked_list_push_back(&channel->channel_thread_tasks.list, &channel_task->node); - if (run_at_nanos == 0) { - aws_event_loop_schedule_task_now(channel->loop, &channel_task->wrapper_task); - } else { - aws_event_loop_schedule_task_future( - channel->loop, &channel_task->wrapper_task, channel_task->wrapper_task.timestamp); - } - return; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: scheduling task with wrapper task id %p from " - "outside the event-loop thread.", - (void *)channel, - (void *)&channel_task->wrapper_task); - /* Outside event-loop thread... */ - bool should_cancel_task = false; - - /* Begin Critical Section */ - aws_mutex_lock(&channel->cross_thread_tasks.lock); - if (channel->cross_thread_tasks.is_channel_shut_down) { - should_cancel_task = true; /* run task outside critical section to avoid deadlock */ - } else { - bool list_was_empty = aws_linked_list_empty(&channel->cross_thread_tasks.list); - aws_linked_list_push_back(&channel->cross_thread_tasks.list, &channel_task->node); - - if (list_was_empty) { - aws_event_loop_schedule_task_now(channel->loop, &channel->cross_thread_tasks.scheduling_task); - } - } - aws_mutex_unlock(&channel->cross_thread_tasks.lock); - /* End Critical Section */ - - if (should_cancel_task) { - channel_task->task_fn(channel_task, channel_task->arg, AWS_TASK_STATUS_CANCELED); - } -} - -void aws_channel_schedule_task_now(struct aws_channel *channel, struct aws_channel_task *task) { - s_register_pending_task(channel, task, 0); -} - -void aws_channel_schedule_task_future( - struct aws_channel *channel, - struct aws_channel_task *task, - uint64_t run_at_nanos) { - - s_register_pending_task(channel, task, run_at_nanos); -} - -bool aws_channel_thread_is_callers_thread(struct aws_channel *channel) { - return aws_event_loop_thread_is_callers_thread(channel->loop); -} - -static void s_update_channel_slot_message_overheads(struct aws_channel *channel) { - size_t overhead = 0; - struct aws_channel_slot *slot_iter = channel->first; - while (slot_iter) { - slot_iter->upstream_message_overhead = overhead; - - if (slot_iter->handler) { - overhead += slot_iter->handler->vtable->message_overhead(slot_iter->handler); - } - slot_iter = slot_iter->adj_right; - } -} - -int aws_channel_slot_set_handler(struct aws_channel_slot *slot, struct aws_channel_handler *handler) { - slot->handler = handler; - slot->handler->slot = slot; - s_update_channel_slot_message_overheads(slot->channel); - - return aws_channel_slot_increment_read_window(slot, slot->handler->vtable->initial_window_size(handler)); -} - -int aws_channel_slot_remove(struct aws_channel_slot *slot) { - if (slot->adj_right) { - slot->adj_right->adj_left = slot->adj_left; - - if (slot == slot->channel->first) { - slot->channel->first = slot->adj_right; - } - } - - if (slot->adj_left) { - slot->adj_left->adj_right = slot->adj_right; - } - - if (slot == slot->channel->first) { - slot->channel->first = NULL; - } - - s_update_channel_slot_message_overheads(slot->channel); - s_cleanup_slot(slot); - return AWS_OP_SUCCESS; -} - -int aws_channel_slot_replace(struct aws_channel_slot *remove, struct aws_channel_slot *new_slot) { - new_slot->adj_left = remove->adj_left; - - if (remove->adj_left) { - remove->adj_left->adj_right = new_slot; - } - - new_slot->adj_right = remove->adj_right; - - if (remove->adj_right) { - remove->adj_right->adj_left = new_slot; - } - - if (remove == remove->channel->first) { - remove->channel->first = new_slot; - } - - s_update_channel_slot_message_overheads(remove->channel); - s_cleanup_slot(remove); - return AWS_OP_SUCCESS; -} - -int aws_channel_slot_insert_right(struct aws_channel_slot *slot, struct aws_channel_slot *to_add) { - to_add->adj_right = slot->adj_right; - - if (slot->adj_right) { - slot->adj_right->adj_left = to_add; - } - - slot->adj_right = to_add; - to_add->adj_left = slot; - - return AWS_OP_SUCCESS; -} - -int aws_channel_slot_insert_end(struct aws_channel *channel, struct aws_channel_slot *to_add) { - /* It's actually impossible there's not a first if the user went through the aws_channel_slot_new() function. - * But also check that a user didn't call insert_end if it's the first slot in the channel since first would already - * have been set. */ - if (AWS_LIKELY(channel->first && channel->first != to_add)) { - struct aws_channel_slot *cur = channel->first; - while (cur->adj_right) { - cur = cur->adj_right; - } - - return aws_channel_slot_insert_right(cur, to_add); - } - - AWS_ASSERT(0); - return AWS_OP_ERR; -} - -int aws_channel_slot_insert_left(struct aws_channel_slot *slot, struct aws_channel_slot *to_add) { - to_add->adj_left = slot->adj_left; - - if (slot->adj_left) { - slot->adj_left->adj_right = to_add; - } - - slot->adj_left = to_add; - to_add->adj_right = slot; - - if (slot == slot->channel->first) { - slot->channel->first = to_add; - } - - return AWS_OP_SUCCESS; -} - -int aws_channel_slot_send_message( - struct aws_channel_slot *slot, - struct aws_io_message *message, - enum aws_channel_direction dir) { - - if (dir == AWS_CHANNEL_DIR_READ) { - AWS_ASSERT(slot->adj_right); - AWS_ASSERT(slot->adj_right->handler); - - if (!slot->channel->read_back_pressure_enabled || slot->adj_right->window_size >= message->message_data.len) { - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: sending read message of size %zu, " - "from slot %p to slot %p with handler %p.", - (void *)slot->channel, - message->message_data.len, - (void *)slot, - (void *)slot->adj_right, - (void *)slot->adj_right->handler); - slot->adj_right->window_size -= message->message_data.len; - return aws_channel_handler_process_read_message(slot->adj_right->handler, slot->adj_right, message); - } - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL, - "id=%p: sending message of size %zu, " - "from slot %p to slot %p with handler %p, but this would exceed the channel's " - "read window, this is always a programming error.", - (void *)slot->channel, - message->message_data.len, - (void *)slot, - (void *)slot->adj_right, - (void *)slot->adj_right->handler); - return aws_raise_error(AWS_IO_CHANNEL_READ_WOULD_EXCEED_WINDOW); - } - - AWS_ASSERT(slot->adj_left); - AWS_ASSERT(slot->adj_left->handler); - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: sending write message of size %zu, " - "from slot %p to slot %p with handler %p.", - (void *)slot->channel, - message->message_data.len, - (void *)slot, - (void *)slot->adj_left, - (void *)slot->adj_left->handler); - return aws_channel_handler_process_write_message(slot->adj_left->handler, slot->adj_left, message); -} - -struct aws_io_message *aws_channel_slot_acquire_max_message_for_write(struct aws_channel_slot *slot) { - AWS_PRECONDITION(slot); - AWS_PRECONDITION(slot->channel); - AWS_PRECONDITION(aws_channel_thread_is_callers_thread(slot->channel)); - - const size_t overhead = aws_channel_slot_upstream_message_overhead(slot); - if (overhead >= g_aws_channel_max_fragment_size) { - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL, "id=%p: Upstream overhead exceeds channel's max message size.", (void *)slot->channel); - aws_raise_error(AWS_ERROR_INVALID_STATE); - return NULL; - } - - const size_t size_hint = g_aws_channel_max_fragment_size - overhead; - return aws_channel_acquire_message_from_pool(slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, size_hint); -} - -static void s_window_update_task(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) { - (void)channel_task; - struct aws_channel *channel = arg; - - if (status == AWS_TASK_STATUS_RUN_READY && channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { - /* get the right-most slot to start the updates. */ - struct aws_channel_slot *slot = channel->first; - while (slot->adj_right) { - slot = slot->adj_right; - } - - while (slot->adj_left) { - struct aws_channel_slot *upstream_slot = slot->adj_left; - if (upstream_slot->handler) { - slot->window_size = aws_add_size_saturating(slot->window_size, slot->current_window_update_batch_size); - size_t update_size = slot->current_window_update_batch_size; - slot->current_window_update_batch_size = 0; - if (aws_channel_handler_increment_read_window(upstream_slot->handler, upstream_slot, update_size)) { - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL, - "channel %p: channel update task failed with status %d", - (void *)slot->channel, - aws_last_error()); - slot->channel->window_update_in_progress = false; - aws_channel_shutdown(channel, aws_last_error()); - return; - } - } - slot = slot->adj_left; - } - } - channel->window_update_in_progress = false; -} - -int aws_channel_slot_increment_read_window(struct aws_channel_slot *slot, size_t window) { - - if (slot->channel->read_back_pressure_enabled && slot->channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { - slot->current_window_update_batch_size = - aws_add_size_saturating(slot->current_window_update_batch_size, window); - - if (!slot->channel->window_update_in_progress && - slot->window_size <= slot->channel->window_update_batch_emit_threshold) { - slot->channel->window_update_in_progress = true; - aws_channel_task_init( - &slot->channel->window_update_task, s_window_update_task, slot->channel, "window update task"); - aws_channel_schedule_task_now(slot->channel, &slot->channel->window_update_task); - } - } - - return AWS_OP_SUCCESS; -} - -int aws_channel_slot_shutdown( - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int err_code, - bool free_scarce_resources_immediately) { - AWS_ASSERT(slot->handler); - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: shutting down slot %p, with handler %p " - "in %s direction with error code %d", - (void *)slot->channel, - (void *)slot, - (void *)slot->handler, - (dir == AWS_CHANNEL_DIR_READ) ? "read" : "write", - err_code); - return aws_channel_handler_shutdown(slot->handler, slot, dir, err_code, free_scarce_resources_immediately); -} - -static void s_on_shutdown_completion_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)status; - - struct aws_shutdown_notification_task *shutdown_notify = (struct aws_shutdown_notification_task *)task; - struct aws_channel *channel = arg; - AWS_ASSERT(channel->channel_state == AWS_CHANNEL_SHUT_DOWN); - - /* Cancel tasks that have been scheduled with the event loop */ - while (!aws_linked_list_empty(&channel->channel_thread_tasks.list)) { - struct aws_linked_list_node *node = aws_linked_list_front(&channel->channel_thread_tasks.list); - struct aws_channel_task *channel_task = AWS_CONTAINER_OF(node, struct aws_channel_task, node); - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, - "id=%p: during shutdown, canceling task %p", - (void *)channel, - (void *)&channel_task->wrapper_task); - /* The task will remove itself from the list when it's canceled */ - aws_event_loop_cancel_task(channel->loop, &channel_task->wrapper_task); - } - - /* Cancel off-thread tasks, which haven't made it to the event-loop thread yet */ - aws_mutex_lock(&channel->cross_thread_tasks.lock); - bool cancel_cross_thread_tasks = !aws_linked_list_empty(&channel->cross_thread_tasks.list); - aws_mutex_unlock(&channel->cross_thread_tasks.lock); - - if (cancel_cross_thread_tasks) { - aws_event_loop_cancel_task(channel->loop, &channel->cross_thread_tasks.scheduling_task); - } - - AWS_ASSERT(aws_linked_list_empty(&channel->channel_thread_tasks.list)); - AWS_ASSERT(aws_linked_list_empty(&channel->cross_thread_tasks.list)); - - channel->on_shutdown_completed(channel, shutdown_notify->error_code, channel->shutdown_user_data); -} - -static void s_run_shutdown_write_direction(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)arg; - (void)status; - - struct aws_shutdown_notification_task *shutdown_notify = (struct aws_shutdown_notification_task *)task; - task->fn = NULL; - task->arg = NULL; - struct aws_channel_slot *slot = shutdown_notify->slot; - aws_channel_handler_shutdown( - slot->handler, slot, AWS_CHANNEL_DIR_WRITE, shutdown_notify->error_code, shutdown_notify->shutdown_immediately); -} - -int aws_channel_slot_on_handler_shutdown_complete( - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int err_code, - bool free_scarce_resources_immediately) { - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL, - "id=%p: handler %p shutdown in %s dir completed.", - (void *)slot->channel, - (void *)slot->handler, - (dir == AWS_CHANNEL_DIR_READ) ? "read" : "write"); - - if (slot->channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { - return AWS_OP_SUCCESS; - } - - if (dir == AWS_CHANNEL_DIR_READ) { - if (slot->adj_right && slot->adj_right->handler) { - return aws_channel_handler_shutdown( - slot->adj_right->handler, slot->adj_right, dir, err_code, free_scarce_resources_immediately); - } - - /* break the shutdown sequence so we don't have handlers having to deal with their memory disappearing out from - * under them during a shutdown process. */ - slot->channel->shutdown_notify_task.slot = slot; - slot->channel->shutdown_notify_task.shutdown_immediately = free_scarce_resources_immediately; - slot->channel->shutdown_notify_task.error_code = err_code; - slot->channel->shutdown_notify_task.task.fn = s_run_shutdown_write_direction; - slot->channel->shutdown_notify_task.task.arg = NULL; - - aws_event_loop_schedule_task_now(slot->channel->loop, &slot->channel->shutdown_notify_task.task); - return AWS_OP_SUCCESS; - } - - if (slot->adj_left && slot->adj_left->handler) { - return aws_channel_handler_shutdown( - slot->adj_left->handler, slot->adj_left, dir, err_code, free_scarce_resources_immediately); - } - - if (slot->channel->first == slot) { - slot->channel->channel_state = AWS_CHANNEL_SHUT_DOWN; - aws_mutex_lock(&slot->channel->cross_thread_tasks.lock); - slot->channel->cross_thread_tasks.is_channel_shut_down = true; - aws_mutex_unlock(&slot->channel->cross_thread_tasks.lock); - - if (slot->channel->on_shutdown_completed) { - slot->channel->shutdown_notify_task.task.fn = s_on_shutdown_completion_task; - slot->channel->shutdown_notify_task.task.arg = slot->channel; - slot->channel->shutdown_notify_task.error_code = err_code; - aws_event_loop_schedule_task_now(slot->channel->loop, &slot->channel->shutdown_notify_task.task); - } - } - - return AWS_OP_SUCCESS; -} - -size_t aws_channel_slot_downstream_read_window(struct aws_channel_slot *slot) { - AWS_ASSERT(slot->adj_right); - return slot->channel->read_back_pressure_enabled ? slot->adj_right->window_size : SIZE_MAX; -} - -size_t aws_channel_slot_upstream_message_overhead(struct aws_channel_slot *slot) { - return slot->upstream_message_overhead; -} - -void aws_channel_handler_destroy(struct aws_channel_handler *handler) { - AWS_ASSERT(handler->vtable && handler->vtable->destroy); - handler->vtable->destroy(handler); -} - -int aws_channel_handler_process_read_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - - AWS_ASSERT(handler->vtable && handler->vtable->process_read_message); - return handler->vtable->process_read_message(handler, slot, message); -} - -int aws_channel_handler_process_write_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - - AWS_ASSERT(handler->vtable && handler->vtable->process_write_message); - return handler->vtable->process_write_message(handler, slot, message); -} - -int aws_channel_handler_increment_read_window( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - size_t size) { - - AWS_ASSERT(handler->vtable && handler->vtable->increment_read_window); - - return handler->vtable->increment_read_window(handler, slot, size); -} - -int aws_channel_handler_shutdown( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int error_code, - bool free_scarce_resources_immediately) { - - AWS_ASSERT(handler->vtable && handler->vtable->shutdown); - return handler->vtable->shutdown(handler, slot, dir, error_code, free_scarce_resources_immediately); -} - -size_t aws_channel_handler_initial_window_size(struct aws_channel_handler *handler) { - AWS_ASSERT(handler->vtable && handler->vtable->initial_window_size); - return handler->vtable->initial_window_size(handler); -} - -struct aws_channel_slot *aws_channel_get_first_slot(struct aws_channel *channel) { - return channel->first; -} - -static void s_reset_statistics(struct aws_channel *channel) { - AWS_FATAL_ASSERT(aws_channel_thread_is_callers_thread(channel)); - - struct aws_channel_slot *current_slot = channel->first; - while (current_slot) { - struct aws_channel_handler *handler = current_slot->handler; - if (handler != NULL && handler->vtable->reset_statistics != NULL) { - handler->vtable->reset_statistics(handler); - } - current_slot = current_slot->adj_right; - } -} - -static void s_channel_gather_statistics_task(struct aws_task *task, void *arg, enum aws_task_status status) { - if (status != AWS_TASK_STATUS_RUN_READY) { - return; - } - - struct aws_channel *channel = arg; - if (channel->statistics_handler == NULL) { - return; - } - - if (channel->channel_state == AWS_CHANNEL_SHUTTING_DOWN || channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { - return; - } - - uint64_t now_ns = 0; - if (aws_channel_current_clock_time(channel, &now_ns)) { - return; - } - - uint64_t now_ms = aws_timestamp_convert(now_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); - - struct aws_array_list *statistics_list = &channel->statistic_list; - aws_array_list_clear(statistics_list); - - struct aws_channel_slot *current_slot = channel->first; - while (current_slot) { - struct aws_channel_handler *handler = current_slot->handler; - if (handler != NULL && handler->vtable->gather_statistics != NULL) { - handler->vtable->gather_statistics(handler, statistics_list); - } - current_slot = current_slot->adj_right; - } - - struct aws_crt_statistics_sample_interval sample_interval = { - .begin_time_ms = channel->statistics_interval_start_time_ms, .end_time_ms = now_ms}; - - aws_crt_statistics_handler_process_statistics( - channel->statistics_handler, &sample_interval, statistics_list, channel); - - s_reset_statistics(channel); - - uint64_t reschedule_interval_ns = aws_timestamp_convert( - aws_crt_statistics_handler_get_report_interval_ms(channel->statistics_handler), - AWS_TIMESTAMP_MILLIS, - AWS_TIMESTAMP_NANOS, - NULL); - - aws_event_loop_schedule_task_future(channel->loop, task, now_ns + reschedule_interval_ns); - - channel->statistics_interval_start_time_ms = now_ms; -} - -int aws_channel_set_statistics_handler(struct aws_channel *channel, struct aws_crt_statistics_handler *handler) { - AWS_FATAL_ASSERT(aws_channel_thread_is_callers_thread(channel)); - - if (channel->statistics_handler) { - aws_crt_statistics_handler_destroy(channel->statistics_handler); - aws_event_loop_cancel_task(channel->loop, &channel->statistics_task); - channel->statistics_handler = NULL; - } - - if (handler != NULL) { - aws_task_init(&channel->statistics_task, s_channel_gather_statistics_task, channel, "gather_statistics"); - - uint64_t now_ns = 0; - if (aws_channel_current_clock_time(channel, &now_ns)) { - return AWS_OP_ERR; - } - - uint64_t report_time_ns = now_ns + aws_timestamp_convert( - aws_crt_statistics_handler_get_report_interval_ms(handler), - AWS_TIMESTAMP_MILLIS, - AWS_TIMESTAMP_NANOS, - NULL); - - channel->statistics_interval_start_time_ms = - aws_timestamp_convert(now_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); - s_reset_statistics(channel); - - aws_event_loop_schedule_task_future(channel->loop, &channel->statistics_task, report_time_ns); - } - - channel->statistics_handler = handler; - - return AWS_OP_SUCCESS; -} - -struct aws_event_loop *aws_channel_get_event_loop(struct aws_channel *channel) { - return channel->loop; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/channel.h> + +#include <aws/common/atomics.h> +#include <aws/common/clock.h> +#include <aws/common/mutex.h> + +#include <aws/io/event_loop.h> +#include <aws/io/logging.h> +#include <aws/io/message_pool.h> +#include <aws/io/statistics.h> + +#if _MSC_VER +# pragma warning(disable : 4204) /* non-constant aggregate initializer */ +#endif + +static size_t s_message_pool_key = 0; /* Address of variable serves as key in hash table */ + +enum { + KB_16 = 16 * 1024, +}; + +size_t g_aws_channel_max_fragment_size = KB_16; + +#define INITIAL_STATISTIC_LIST_SIZE 5 + +enum aws_channel_state { + AWS_CHANNEL_SETTING_UP, + AWS_CHANNEL_ACTIVE, + AWS_CHANNEL_SHUTTING_DOWN, + AWS_CHANNEL_SHUT_DOWN, +}; + +struct aws_shutdown_notification_task { + struct aws_task task; + int error_code; + struct aws_channel_slot *slot; + bool shutdown_immediately; +}; + +struct shutdown_task { + struct aws_channel_task task; + struct aws_channel *channel; + int error_code; + bool shutdown_immediately; +}; + +struct aws_channel { + struct aws_allocator *alloc; + struct aws_event_loop *loop; + struct aws_channel_slot *first; + struct aws_message_pool *msg_pool; + enum aws_channel_state channel_state; + struct aws_shutdown_notification_task shutdown_notify_task; + aws_channel_on_shutdown_completed_fn *on_shutdown_completed; + void *shutdown_user_data; + struct aws_atomic_var refcount; + struct aws_task deletion_task; + + struct aws_task statistics_task; + struct aws_crt_statistics_handler *statistics_handler; + uint64_t statistics_interval_start_time_ms; + struct aws_array_list statistic_list; + + struct { + struct aws_linked_list list; + } channel_thread_tasks; + struct { + struct aws_mutex lock; + struct aws_linked_list list; + struct aws_task scheduling_task; + struct shutdown_task shutdown_task; + bool is_channel_shut_down; + } cross_thread_tasks; + + size_t window_update_batch_emit_threshold; + struct aws_channel_task window_update_task; + bool read_back_pressure_enabled; + bool window_update_in_progress; +}; + +struct channel_setup_args { + struct aws_allocator *alloc; + struct aws_channel *channel; + aws_channel_on_setup_completed_fn *on_setup_completed; + void *user_data; + struct aws_task task; +}; + +static void s_on_msg_pool_removed(struct aws_event_loop_local_object *object) { + struct aws_message_pool *msg_pool = object->object; + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "static: message pool %p has been purged " + "from the event-loop: likely because of shutdown", + (void *)msg_pool); + struct aws_allocator *alloc = msg_pool->alloc; + aws_message_pool_clean_up(msg_pool); + aws_mem_release(alloc, msg_pool); + aws_mem_release(alloc, object); +} + +static void s_on_channel_setup_complete(struct aws_task *task, void *arg, enum aws_task_status task_status) { + + (void)task; + struct channel_setup_args *setup_args = arg; + struct aws_message_pool *message_pool = NULL; + struct aws_event_loop_local_object *local_object = NULL; + + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: setup complete, notifying caller.", (void *)setup_args->channel); + if (task_status == AWS_TASK_STATUS_RUN_READY) { + struct aws_event_loop_local_object stack_obj; + AWS_ZERO_STRUCT(stack_obj); + local_object = &stack_obj; + + if (aws_event_loop_fetch_local_object(setup_args->channel->loop, &s_message_pool_key, local_object)) { + + local_object = aws_mem_calloc(setup_args->alloc, 1, sizeof(struct aws_event_loop_local_object)); + if (!local_object) { + goto cleanup_setup_args; + } + + message_pool = aws_mem_acquire(setup_args->alloc, sizeof(struct aws_message_pool)); + if (!message_pool) { + goto cleanup_local_obj; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, + "id=%p: no message pool is currently stored in the event-loop " + "local storage, adding %p with max message size %zu, " + "message count 4, with 4 small blocks of 128 bytes.", + (void *)setup_args->channel, + (void *)message_pool, + g_aws_channel_max_fragment_size); + + struct aws_message_pool_creation_args creation_args = { + .application_data_msg_data_size = g_aws_channel_max_fragment_size, + .application_data_msg_count = 4, + .small_block_msg_count = 4, + .small_block_msg_data_size = 128, + }; + + if (aws_message_pool_init(message_pool, setup_args->alloc, &creation_args)) { + goto cleanup_msg_pool_mem; + } + + local_object->key = &s_message_pool_key; + local_object->object = message_pool; + local_object->on_object_removed = s_on_msg_pool_removed; + + if (aws_event_loop_put_local_object(setup_args->channel->loop, local_object)) { + goto cleanup_msg_pool; + } + } else { + message_pool = local_object->object; + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, + "id=%p: message pool %p found in event-loop local storage: using it.", + (void *)setup_args->channel, + (void *)message_pool) + } + + setup_args->channel->msg_pool = message_pool; + setup_args->channel->channel_state = AWS_CHANNEL_ACTIVE; + setup_args->on_setup_completed(setup_args->channel, AWS_OP_SUCCESS, setup_args->user_data); + aws_channel_release_hold(setup_args->channel); + aws_mem_release(setup_args->alloc, setup_args); + return; + } + + goto cleanup_setup_args; + +cleanup_msg_pool: + aws_message_pool_clean_up(message_pool); + +cleanup_msg_pool_mem: + aws_mem_release(setup_args->alloc, message_pool); + +cleanup_local_obj: + aws_mem_release(setup_args->alloc, local_object); + +cleanup_setup_args: + setup_args->on_setup_completed(setup_args->channel, AWS_OP_ERR, setup_args->user_data); + aws_channel_release_hold(setup_args->channel); + aws_mem_release(setup_args->alloc, setup_args); +} + +static void s_schedule_cross_thread_tasks(struct aws_task *task, void *arg, enum aws_task_status status); + +static void s_destroy_partially_constructed_channel(struct aws_channel *channel) { + if (channel == NULL) { + return; + } + + aws_array_list_clean_up(&channel->statistic_list); + + aws_mem_release(channel->alloc, channel); +} + +struct aws_channel *aws_channel_new(struct aws_allocator *alloc, const struct aws_channel_options *creation_args) { + AWS_PRECONDITION(creation_args); + AWS_PRECONDITION(creation_args->event_loop); + AWS_PRECONDITION(creation_args->on_setup_completed); + + struct aws_channel *channel = aws_mem_calloc(alloc, 1, sizeof(struct aws_channel)); + if (!channel) { + return NULL; + } + + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: Beginning creation and setup of new channel.", (void *)channel); + channel->alloc = alloc; + channel->loop = creation_args->event_loop; + channel->on_shutdown_completed = creation_args->on_shutdown_completed; + channel->shutdown_user_data = creation_args->shutdown_user_data; + + if (aws_array_list_init_dynamic( + &channel->statistic_list, alloc, INITIAL_STATISTIC_LIST_SIZE, sizeof(struct aws_crt_statistics_base *))) { + goto on_error; + } + + /* Start refcount at 2: + * 1 for self-reference, released from aws_channel_destroy() + * 1 for the setup task, released when task executes */ + aws_atomic_init_int(&channel->refcount, 2); + + struct channel_setup_args *setup_args = aws_mem_calloc(alloc, 1, sizeof(struct channel_setup_args)); + if (!setup_args) { + goto on_error; + } + + channel->channel_state = AWS_CHANNEL_SETTING_UP; + aws_linked_list_init(&channel->channel_thread_tasks.list); + aws_linked_list_init(&channel->cross_thread_tasks.list); + channel->cross_thread_tasks.lock = (struct aws_mutex)AWS_MUTEX_INIT; + + if (creation_args->enable_read_back_pressure) { + channel->read_back_pressure_enabled = true; + /* we probably only need room for one fragment, but let's avoid potential deadlocks + * on things like tls that need extra head-room. */ + channel->window_update_batch_emit_threshold = g_aws_channel_max_fragment_size * 2; + } + + aws_task_init( + &channel->cross_thread_tasks.scheduling_task, + s_schedule_cross_thread_tasks, + channel, + "schedule_cross_thread_tasks"); + + setup_args->alloc = alloc; + setup_args->channel = channel; + setup_args->on_setup_completed = creation_args->on_setup_completed; + setup_args->user_data = creation_args->setup_user_data; + + aws_task_init(&setup_args->task, s_on_channel_setup_complete, setup_args, "on_channel_setup_complete"); + aws_event_loop_schedule_task_now(creation_args->event_loop, &setup_args->task); + + return channel; + +on_error: + + s_destroy_partially_constructed_channel(channel); + + return NULL; +} + +static void s_cleanup_slot(struct aws_channel_slot *slot) { + if (slot) { + if (slot->handler) { + aws_channel_handler_destroy(slot->handler); + } + aws_mem_release(slot->alloc, slot); + } +} + +void aws_channel_destroy(struct aws_channel *channel) { + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: destroying channel.", (void *)channel); + + aws_channel_release_hold(channel); +} + +static void s_final_channel_deletion_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + (void)status; + struct aws_channel *channel = arg; + + struct aws_channel_slot *current = channel->first; + + if (!current || !current->handler) { + /* Allow channels with no valid slots to skip shutdown process */ + channel->channel_state = AWS_CHANNEL_SHUT_DOWN; + } + + AWS_ASSERT(channel->channel_state == AWS_CHANNEL_SHUT_DOWN); + + while (current) { + struct aws_channel_slot *tmp = current->adj_right; + s_cleanup_slot(current); + current = tmp; + } + + aws_array_list_clean_up(&channel->statistic_list); + + aws_channel_set_statistics_handler(channel, NULL); + + aws_mem_release(channel->alloc, channel); +} + +void aws_channel_acquire_hold(struct aws_channel *channel) { + size_t prev_refcount = aws_atomic_fetch_add(&channel->refcount, 1); + AWS_ASSERT(prev_refcount != 0); + (void)prev_refcount; +} + +void aws_channel_release_hold(struct aws_channel *channel) { + size_t prev_refcount = aws_atomic_fetch_sub(&channel->refcount, 1); + AWS_ASSERT(prev_refcount != 0); + + if (prev_refcount == 1) { + /* Refcount is now 0, finish cleaning up channel memory. */ + if (aws_channel_thread_is_callers_thread(channel)) { + s_final_channel_deletion_task(NULL, channel, AWS_TASK_STATUS_RUN_READY); + } else { + aws_task_init(&channel->deletion_task, s_final_channel_deletion_task, channel, "final_channel_deletion"); + aws_event_loop_schedule_task_now(channel->loop, &channel->deletion_task); + } + } +} + +struct channel_shutdown_task_args { + struct aws_channel *channel; + struct aws_allocator *alloc; + int error_code; + struct aws_task task; +}; + +static int s_channel_shutdown(struct aws_channel *channel, int error_code, bool shutdown_immediately); + +static void s_on_shutdown_completion_task(struct aws_task *task, void *arg, enum aws_task_status status); + +static void s_shutdown_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { + + (void)task; + (void)status; + struct shutdown_task *shutdown_task = arg; + struct aws_channel *channel = shutdown_task->channel; + int error_code = shutdown_task->error_code; + bool shutdown_immediately = shutdown_task->shutdown_immediately; + if (channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL, "id=%p: beginning shutdown process", (void *)channel); + + struct aws_channel_slot *slot = channel->first; + channel->channel_state = AWS_CHANNEL_SHUTTING_DOWN; + + if (slot) { + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: shutting down slot %p (the first one) in the read direction", + (void *)channel, + (void *)slot); + + aws_channel_slot_shutdown(slot, AWS_CHANNEL_DIR_READ, error_code, shutdown_immediately); + return; + } + + channel->channel_state = AWS_CHANNEL_SHUT_DOWN; + AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: shutdown completed", (void *)channel); + + aws_mutex_lock(&channel->cross_thread_tasks.lock); + channel->cross_thread_tasks.is_channel_shut_down = true; + aws_mutex_unlock(&channel->cross_thread_tasks.lock); + + if (channel->on_shutdown_completed) { + channel->shutdown_notify_task.task.fn = s_on_shutdown_completion_task; + channel->shutdown_notify_task.task.arg = channel; + channel->shutdown_notify_task.error_code = error_code; + aws_event_loop_schedule_task_now(channel->loop, &channel->shutdown_notify_task.task); + } + } +} + +static int s_channel_shutdown(struct aws_channel *channel, int error_code, bool shutdown_immediately) { + bool need_to_schedule = true; + aws_mutex_lock(&channel->cross_thread_tasks.lock); + if (channel->cross_thread_tasks.shutdown_task.task.task_fn) { + need_to_schedule = false; + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, "id=%p: Channel shutdown is already pending, not scheduling another.", (void *)channel); + + } else { + aws_channel_task_init( + &channel->cross_thread_tasks.shutdown_task.task, + s_shutdown_task, + &channel->cross_thread_tasks.shutdown_task, + "channel_shutdown"); + channel->cross_thread_tasks.shutdown_task.shutdown_immediately = shutdown_immediately; + channel->cross_thread_tasks.shutdown_task.channel = channel; + channel->cross_thread_tasks.shutdown_task.error_code = error_code; + } + + aws_mutex_unlock(&channel->cross_thread_tasks.lock); + + if (need_to_schedule) { + AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: channel shutdown task is scheduled", (void *)channel); + aws_channel_schedule_task_now(channel, &channel->cross_thread_tasks.shutdown_task.task); + } + + return AWS_OP_SUCCESS; +} + +int aws_channel_shutdown(struct aws_channel *channel, int error_code) { + return s_channel_shutdown(channel, error_code, false); +} + +struct aws_io_message *aws_channel_acquire_message_from_pool( + struct aws_channel *channel, + enum aws_io_message_type message_type, + size_t size_hint) { + + struct aws_io_message *message = aws_message_pool_acquire(channel->msg_pool, message_type, size_hint); + + if (AWS_LIKELY(message)) { + message->owning_channel = channel; + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: acquired message %p of capacity %zu from pool %p. Requested size was %zu", + (void *)channel, + (void *)message, + message->message_data.capacity, + (void *)channel->msg_pool, + size_hint); + } + + return message; +} + +struct aws_channel_slot *aws_channel_slot_new(struct aws_channel *channel) { + struct aws_channel_slot *new_slot = aws_mem_calloc(channel->alloc, 1, sizeof(struct aws_channel_slot)); + if (!new_slot) { + return NULL; + } + + AWS_LOGF_TRACE(AWS_LS_IO_CHANNEL, "id=%p: creating new slot %p.", (void *)channel, (void *)new_slot); + new_slot->alloc = channel->alloc; + new_slot->channel = channel; + + if (!channel->first) { + channel->first = new_slot; + } + + return new_slot; +} + +int aws_channel_current_clock_time(struct aws_channel *channel, uint64_t *time_nanos) { + return aws_event_loop_current_clock_time(channel->loop, time_nanos); +} + +int aws_channel_fetch_local_object( + struct aws_channel *channel, + const void *key, + struct aws_event_loop_local_object *obj) { + + return aws_event_loop_fetch_local_object(channel->loop, (void *)key, obj); +} +int aws_channel_put_local_object( + struct aws_channel *channel, + const void *key, + const struct aws_event_loop_local_object *obj) { + + (void)key; + return aws_event_loop_put_local_object(channel->loop, (struct aws_event_loop_local_object *)obj); +} + +int aws_channel_remove_local_object( + struct aws_channel *channel, + const void *key, + struct aws_event_loop_local_object *removed_obj) { + + return aws_event_loop_remove_local_object(channel->loop, (void *)key, removed_obj); +} + +static void s_channel_task_run(struct aws_task *task, void *arg, enum aws_task_status status) { + struct aws_channel_task *channel_task = AWS_CONTAINER_OF(task, struct aws_channel_task, wrapper_task); + struct aws_channel *channel = arg; + + /* Any task that runs after shutdown completes is considered canceled */ + if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { + status = AWS_TASK_STATUS_CANCELED; + } + + aws_linked_list_remove(&channel_task->node); + channel_task->task_fn(channel_task, channel_task->arg, status); +} + +static void s_schedule_cross_thread_tasks(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + struct aws_channel *channel = arg; + + struct aws_linked_list cross_thread_task_list; + aws_linked_list_init(&cross_thread_task_list); + + /* Grab contents of cross-thread task list while we have the lock */ + aws_mutex_lock(&channel->cross_thread_tasks.lock); + aws_linked_list_swap_contents(&channel->cross_thread_tasks.list, &cross_thread_task_list); + aws_mutex_unlock(&channel->cross_thread_tasks.lock); + + /* If the channel has shut down since the cross-thread tasks were scheduled, run tasks immediately as canceled */ + if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { + status = AWS_TASK_STATUS_CANCELED; + } + + while (!aws_linked_list_empty(&cross_thread_task_list)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&cross_thread_task_list); + struct aws_channel_task *channel_task = AWS_CONTAINER_OF(node, struct aws_channel_task, node); + + if ((channel_task->wrapper_task.timestamp == 0) || (status == AWS_TASK_STATUS_CANCELED)) { + /* Run "now" tasks, and canceled tasks, immediately */ + channel_task->task_fn(channel_task, channel_task->arg, status); + } else { + /* "Future" tasks are scheduled with the event-loop. */ + aws_linked_list_push_back(&channel->channel_thread_tasks.list, &channel_task->node); + aws_event_loop_schedule_task_future( + channel->loop, &channel_task->wrapper_task, channel_task->wrapper_task.timestamp); + } + } +} + +void aws_channel_task_init( + struct aws_channel_task *channel_task, + aws_channel_task_fn *task_fn, + void *arg, + const char *type_tag) { + AWS_ZERO_STRUCT(*channel_task); + channel_task->task_fn = task_fn; + channel_task->arg = arg; + channel_task->type_tag = type_tag; +} + +/* Common functionality for scheduling "now" and "future" tasks. + * For "now" tasks, pass 0 for `run_at_nanos` */ +static void s_register_pending_task( + struct aws_channel *channel, + struct aws_channel_task *channel_task, + uint64_t run_at_nanos) { + + /* Reset every property on channel task other than user's fn & arg.*/ + aws_task_init(&channel_task->wrapper_task, s_channel_task_run, channel, channel_task->type_tag); + channel_task->wrapper_task.timestamp = run_at_nanos; + aws_linked_list_node_reset(&channel_task->node); + + if (aws_channel_thread_is_callers_thread(channel)) { + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: scheduling task with wrapper task id %p.", + (void *)channel, + (void *)&channel_task->wrapper_task); + + /* If channel is shut down, run task immediately as canceled */ + if (channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, + "id=%p: Running %s channel task immediately as canceled due to shut down channel", + (void *)channel, + channel_task->type_tag); + channel_task->task_fn(channel_task, channel_task->arg, AWS_TASK_STATUS_CANCELED); + return; + } + + aws_linked_list_push_back(&channel->channel_thread_tasks.list, &channel_task->node); + if (run_at_nanos == 0) { + aws_event_loop_schedule_task_now(channel->loop, &channel_task->wrapper_task); + } else { + aws_event_loop_schedule_task_future( + channel->loop, &channel_task->wrapper_task, channel_task->wrapper_task.timestamp); + } + return; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: scheduling task with wrapper task id %p from " + "outside the event-loop thread.", + (void *)channel, + (void *)&channel_task->wrapper_task); + /* Outside event-loop thread... */ + bool should_cancel_task = false; + + /* Begin Critical Section */ + aws_mutex_lock(&channel->cross_thread_tasks.lock); + if (channel->cross_thread_tasks.is_channel_shut_down) { + should_cancel_task = true; /* run task outside critical section to avoid deadlock */ + } else { + bool list_was_empty = aws_linked_list_empty(&channel->cross_thread_tasks.list); + aws_linked_list_push_back(&channel->cross_thread_tasks.list, &channel_task->node); + + if (list_was_empty) { + aws_event_loop_schedule_task_now(channel->loop, &channel->cross_thread_tasks.scheduling_task); + } + } + aws_mutex_unlock(&channel->cross_thread_tasks.lock); + /* End Critical Section */ + + if (should_cancel_task) { + channel_task->task_fn(channel_task, channel_task->arg, AWS_TASK_STATUS_CANCELED); + } +} + +void aws_channel_schedule_task_now(struct aws_channel *channel, struct aws_channel_task *task) { + s_register_pending_task(channel, task, 0); +} + +void aws_channel_schedule_task_future( + struct aws_channel *channel, + struct aws_channel_task *task, + uint64_t run_at_nanos) { + + s_register_pending_task(channel, task, run_at_nanos); +} + +bool aws_channel_thread_is_callers_thread(struct aws_channel *channel) { + return aws_event_loop_thread_is_callers_thread(channel->loop); +} + +static void s_update_channel_slot_message_overheads(struct aws_channel *channel) { + size_t overhead = 0; + struct aws_channel_slot *slot_iter = channel->first; + while (slot_iter) { + slot_iter->upstream_message_overhead = overhead; + + if (slot_iter->handler) { + overhead += slot_iter->handler->vtable->message_overhead(slot_iter->handler); + } + slot_iter = slot_iter->adj_right; + } +} + +int aws_channel_slot_set_handler(struct aws_channel_slot *slot, struct aws_channel_handler *handler) { + slot->handler = handler; + slot->handler->slot = slot; + s_update_channel_slot_message_overheads(slot->channel); + + return aws_channel_slot_increment_read_window(slot, slot->handler->vtable->initial_window_size(handler)); +} + +int aws_channel_slot_remove(struct aws_channel_slot *slot) { + if (slot->adj_right) { + slot->adj_right->adj_left = slot->adj_left; + + if (slot == slot->channel->first) { + slot->channel->first = slot->adj_right; + } + } + + if (slot->adj_left) { + slot->adj_left->adj_right = slot->adj_right; + } + + if (slot == slot->channel->first) { + slot->channel->first = NULL; + } + + s_update_channel_slot_message_overheads(slot->channel); + s_cleanup_slot(slot); + return AWS_OP_SUCCESS; +} + +int aws_channel_slot_replace(struct aws_channel_slot *remove, struct aws_channel_slot *new_slot) { + new_slot->adj_left = remove->adj_left; + + if (remove->adj_left) { + remove->adj_left->adj_right = new_slot; + } + + new_slot->adj_right = remove->adj_right; + + if (remove->adj_right) { + remove->adj_right->adj_left = new_slot; + } + + if (remove == remove->channel->first) { + remove->channel->first = new_slot; + } + + s_update_channel_slot_message_overheads(remove->channel); + s_cleanup_slot(remove); + return AWS_OP_SUCCESS; +} + +int aws_channel_slot_insert_right(struct aws_channel_slot *slot, struct aws_channel_slot *to_add) { + to_add->adj_right = slot->adj_right; + + if (slot->adj_right) { + slot->adj_right->adj_left = to_add; + } + + slot->adj_right = to_add; + to_add->adj_left = slot; + + return AWS_OP_SUCCESS; +} + +int aws_channel_slot_insert_end(struct aws_channel *channel, struct aws_channel_slot *to_add) { + /* It's actually impossible there's not a first if the user went through the aws_channel_slot_new() function. + * But also check that a user didn't call insert_end if it's the first slot in the channel since first would already + * have been set. */ + if (AWS_LIKELY(channel->first && channel->first != to_add)) { + struct aws_channel_slot *cur = channel->first; + while (cur->adj_right) { + cur = cur->adj_right; + } + + return aws_channel_slot_insert_right(cur, to_add); + } + + AWS_ASSERT(0); + return AWS_OP_ERR; +} + +int aws_channel_slot_insert_left(struct aws_channel_slot *slot, struct aws_channel_slot *to_add) { + to_add->adj_left = slot->adj_left; + + if (slot->adj_left) { + slot->adj_left->adj_right = to_add; + } + + slot->adj_left = to_add; + to_add->adj_right = slot; + + if (slot == slot->channel->first) { + slot->channel->first = to_add; + } + + return AWS_OP_SUCCESS; +} + +int aws_channel_slot_send_message( + struct aws_channel_slot *slot, + struct aws_io_message *message, + enum aws_channel_direction dir) { + + if (dir == AWS_CHANNEL_DIR_READ) { + AWS_ASSERT(slot->adj_right); + AWS_ASSERT(slot->adj_right->handler); + + if (!slot->channel->read_back_pressure_enabled || slot->adj_right->window_size >= message->message_data.len) { + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: sending read message of size %zu, " + "from slot %p to slot %p with handler %p.", + (void *)slot->channel, + message->message_data.len, + (void *)slot, + (void *)slot->adj_right, + (void *)slot->adj_right->handler); + slot->adj_right->window_size -= message->message_data.len; + return aws_channel_handler_process_read_message(slot->adj_right->handler, slot->adj_right, message); + } + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL, + "id=%p: sending message of size %zu, " + "from slot %p to slot %p with handler %p, but this would exceed the channel's " + "read window, this is always a programming error.", + (void *)slot->channel, + message->message_data.len, + (void *)slot, + (void *)slot->adj_right, + (void *)slot->adj_right->handler); + return aws_raise_error(AWS_IO_CHANNEL_READ_WOULD_EXCEED_WINDOW); + } + + AWS_ASSERT(slot->adj_left); + AWS_ASSERT(slot->adj_left->handler); + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: sending write message of size %zu, " + "from slot %p to slot %p with handler %p.", + (void *)slot->channel, + message->message_data.len, + (void *)slot, + (void *)slot->adj_left, + (void *)slot->adj_left->handler); + return aws_channel_handler_process_write_message(slot->adj_left->handler, slot->adj_left, message); +} + +struct aws_io_message *aws_channel_slot_acquire_max_message_for_write(struct aws_channel_slot *slot) { + AWS_PRECONDITION(slot); + AWS_PRECONDITION(slot->channel); + AWS_PRECONDITION(aws_channel_thread_is_callers_thread(slot->channel)); + + const size_t overhead = aws_channel_slot_upstream_message_overhead(slot); + if (overhead >= g_aws_channel_max_fragment_size) { + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL, "id=%p: Upstream overhead exceeds channel's max message size.", (void *)slot->channel); + aws_raise_error(AWS_ERROR_INVALID_STATE); + return NULL; + } + + const size_t size_hint = g_aws_channel_max_fragment_size - overhead; + return aws_channel_acquire_message_from_pool(slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, size_hint); +} + +static void s_window_update_task(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) { + (void)channel_task; + struct aws_channel *channel = arg; + + if (status == AWS_TASK_STATUS_RUN_READY && channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { + /* get the right-most slot to start the updates. */ + struct aws_channel_slot *slot = channel->first; + while (slot->adj_right) { + slot = slot->adj_right; + } + + while (slot->adj_left) { + struct aws_channel_slot *upstream_slot = slot->adj_left; + if (upstream_slot->handler) { + slot->window_size = aws_add_size_saturating(slot->window_size, slot->current_window_update_batch_size); + size_t update_size = slot->current_window_update_batch_size; + slot->current_window_update_batch_size = 0; + if (aws_channel_handler_increment_read_window(upstream_slot->handler, upstream_slot, update_size)) { + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL, + "channel %p: channel update task failed with status %d", + (void *)slot->channel, + aws_last_error()); + slot->channel->window_update_in_progress = false; + aws_channel_shutdown(channel, aws_last_error()); + return; + } + } + slot = slot->adj_left; + } + } + channel->window_update_in_progress = false; +} + +int aws_channel_slot_increment_read_window(struct aws_channel_slot *slot, size_t window) { + + if (slot->channel->read_back_pressure_enabled && slot->channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { + slot->current_window_update_batch_size = + aws_add_size_saturating(slot->current_window_update_batch_size, window); + + if (!slot->channel->window_update_in_progress && + slot->window_size <= slot->channel->window_update_batch_emit_threshold) { + slot->channel->window_update_in_progress = true; + aws_channel_task_init( + &slot->channel->window_update_task, s_window_update_task, slot->channel, "window update task"); + aws_channel_schedule_task_now(slot->channel, &slot->channel->window_update_task); + } + } + + return AWS_OP_SUCCESS; +} + +int aws_channel_slot_shutdown( + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int err_code, + bool free_scarce_resources_immediately) { + AWS_ASSERT(slot->handler); + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: shutting down slot %p, with handler %p " + "in %s direction with error code %d", + (void *)slot->channel, + (void *)slot, + (void *)slot->handler, + (dir == AWS_CHANNEL_DIR_READ) ? "read" : "write", + err_code); + return aws_channel_handler_shutdown(slot->handler, slot, dir, err_code, free_scarce_resources_immediately); +} + +static void s_on_shutdown_completion_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)status; + + struct aws_shutdown_notification_task *shutdown_notify = (struct aws_shutdown_notification_task *)task; + struct aws_channel *channel = arg; + AWS_ASSERT(channel->channel_state == AWS_CHANNEL_SHUT_DOWN); + + /* Cancel tasks that have been scheduled with the event loop */ + while (!aws_linked_list_empty(&channel->channel_thread_tasks.list)) { + struct aws_linked_list_node *node = aws_linked_list_front(&channel->channel_thread_tasks.list); + struct aws_channel_task *channel_task = AWS_CONTAINER_OF(node, struct aws_channel_task, node); + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, + "id=%p: during shutdown, canceling task %p", + (void *)channel, + (void *)&channel_task->wrapper_task); + /* The task will remove itself from the list when it's canceled */ + aws_event_loop_cancel_task(channel->loop, &channel_task->wrapper_task); + } + + /* Cancel off-thread tasks, which haven't made it to the event-loop thread yet */ + aws_mutex_lock(&channel->cross_thread_tasks.lock); + bool cancel_cross_thread_tasks = !aws_linked_list_empty(&channel->cross_thread_tasks.list); + aws_mutex_unlock(&channel->cross_thread_tasks.lock); + + if (cancel_cross_thread_tasks) { + aws_event_loop_cancel_task(channel->loop, &channel->cross_thread_tasks.scheduling_task); + } + + AWS_ASSERT(aws_linked_list_empty(&channel->channel_thread_tasks.list)); + AWS_ASSERT(aws_linked_list_empty(&channel->cross_thread_tasks.list)); + + channel->on_shutdown_completed(channel, shutdown_notify->error_code, channel->shutdown_user_data); +} + +static void s_run_shutdown_write_direction(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)arg; + (void)status; + + struct aws_shutdown_notification_task *shutdown_notify = (struct aws_shutdown_notification_task *)task; + task->fn = NULL; + task->arg = NULL; + struct aws_channel_slot *slot = shutdown_notify->slot; + aws_channel_handler_shutdown( + slot->handler, slot, AWS_CHANNEL_DIR_WRITE, shutdown_notify->error_code, shutdown_notify->shutdown_immediately); +} + +int aws_channel_slot_on_handler_shutdown_complete( + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int err_code, + bool free_scarce_resources_immediately) { + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL, + "id=%p: handler %p shutdown in %s dir completed.", + (void *)slot->channel, + (void *)slot->handler, + (dir == AWS_CHANNEL_DIR_READ) ? "read" : "write"); + + if (slot->channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { + return AWS_OP_SUCCESS; + } + + if (dir == AWS_CHANNEL_DIR_READ) { + if (slot->adj_right && slot->adj_right->handler) { + return aws_channel_handler_shutdown( + slot->adj_right->handler, slot->adj_right, dir, err_code, free_scarce_resources_immediately); + } + + /* break the shutdown sequence so we don't have handlers having to deal with their memory disappearing out from + * under them during a shutdown process. */ + slot->channel->shutdown_notify_task.slot = slot; + slot->channel->shutdown_notify_task.shutdown_immediately = free_scarce_resources_immediately; + slot->channel->shutdown_notify_task.error_code = err_code; + slot->channel->shutdown_notify_task.task.fn = s_run_shutdown_write_direction; + slot->channel->shutdown_notify_task.task.arg = NULL; + + aws_event_loop_schedule_task_now(slot->channel->loop, &slot->channel->shutdown_notify_task.task); + return AWS_OP_SUCCESS; + } + + if (slot->adj_left && slot->adj_left->handler) { + return aws_channel_handler_shutdown( + slot->adj_left->handler, slot->adj_left, dir, err_code, free_scarce_resources_immediately); + } + + if (slot->channel->first == slot) { + slot->channel->channel_state = AWS_CHANNEL_SHUT_DOWN; + aws_mutex_lock(&slot->channel->cross_thread_tasks.lock); + slot->channel->cross_thread_tasks.is_channel_shut_down = true; + aws_mutex_unlock(&slot->channel->cross_thread_tasks.lock); + + if (slot->channel->on_shutdown_completed) { + slot->channel->shutdown_notify_task.task.fn = s_on_shutdown_completion_task; + slot->channel->shutdown_notify_task.task.arg = slot->channel; + slot->channel->shutdown_notify_task.error_code = err_code; + aws_event_loop_schedule_task_now(slot->channel->loop, &slot->channel->shutdown_notify_task.task); + } + } + + return AWS_OP_SUCCESS; +} + +size_t aws_channel_slot_downstream_read_window(struct aws_channel_slot *slot) { + AWS_ASSERT(slot->adj_right); + return slot->channel->read_back_pressure_enabled ? slot->adj_right->window_size : SIZE_MAX; +} + +size_t aws_channel_slot_upstream_message_overhead(struct aws_channel_slot *slot) { + return slot->upstream_message_overhead; +} + +void aws_channel_handler_destroy(struct aws_channel_handler *handler) { + AWS_ASSERT(handler->vtable && handler->vtable->destroy); + handler->vtable->destroy(handler); +} + +int aws_channel_handler_process_read_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + + AWS_ASSERT(handler->vtable && handler->vtable->process_read_message); + return handler->vtable->process_read_message(handler, slot, message); +} + +int aws_channel_handler_process_write_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + + AWS_ASSERT(handler->vtable && handler->vtable->process_write_message); + return handler->vtable->process_write_message(handler, slot, message); +} + +int aws_channel_handler_increment_read_window( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + size_t size) { + + AWS_ASSERT(handler->vtable && handler->vtable->increment_read_window); + + return handler->vtable->increment_read_window(handler, slot, size); +} + +int aws_channel_handler_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int error_code, + bool free_scarce_resources_immediately) { + + AWS_ASSERT(handler->vtable && handler->vtable->shutdown); + return handler->vtable->shutdown(handler, slot, dir, error_code, free_scarce_resources_immediately); +} + +size_t aws_channel_handler_initial_window_size(struct aws_channel_handler *handler) { + AWS_ASSERT(handler->vtable && handler->vtable->initial_window_size); + return handler->vtable->initial_window_size(handler); +} + +struct aws_channel_slot *aws_channel_get_first_slot(struct aws_channel *channel) { + return channel->first; +} + +static void s_reset_statistics(struct aws_channel *channel) { + AWS_FATAL_ASSERT(aws_channel_thread_is_callers_thread(channel)); + + struct aws_channel_slot *current_slot = channel->first; + while (current_slot) { + struct aws_channel_handler *handler = current_slot->handler; + if (handler != NULL && handler->vtable->reset_statistics != NULL) { + handler->vtable->reset_statistics(handler); + } + current_slot = current_slot->adj_right; + } +} + +static void s_channel_gather_statistics_task(struct aws_task *task, void *arg, enum aws_task_status status) { + if (status != AWS_TASK_STATUS_RUN_READY) { + return; + } + + struct aws_channel *channel = arg; + if (channel->statistics_handler == NULL) { + return; + } + + if (channel->channel_state == AWS_CHANNEL_SHUTTING_DOWN || channel->channel_state == AWS_CHANNEL_SHUT_DOWN) { + return; + } + + uint64_t now_ns = 0; + if (aws_channel_current_clock_time(channel, &now_ns)) { + return; + } + + uint64_t now_ms = aws_timestamp_convert(now_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); + + struct aws_array_list *statistics_list = &channel->statistic_list; + aws_array_list_clear(statistics_list); + + struct aws_channel_slot *current_slot = channel->first; + while (current_slot) { + struct aws_channel_handler *handler = current_slot->handler; + if (handler != NULL && handler->vtable->gather_statistics != NULL) { + handler->vtable->gather_statistics(handler, statistics_list); + } + current_slot = current_slot->adj_right; + } + + struct aws_crt_statistics_sample_interval sample_interval = { + .begin_time_ms = channel->statistics_interval_start_time_ms, .end_time_ms = now_ms}; + + aws_crt_statistics_handler_process_statistics( + channel->statistics_handler, &sample_interval, statistics_list, channel); + + s_reset_statistics(channel); + + uint64_t reschedule_interval_ns = aws_timestamp_convert( + aws_crt_statistics_handler_get_report_interval_ms(channel->statistics_handler), + AWS_TIMESTAMP_MILLIS, + AWS_TIMESTAMP_NANOS, + NULL); + + aws_event_loop_schedule_task_future(channel->loop, task, now_ns + reschedule_interval_ns); + + channel->statistics_interval_start_time_ms = now_ms; +} + +int aws_channel_set_statistics_handler(struct aws_channel *channel, struct aws_crt_statistics_handler *handler) { + AWS_FATAL_ASSERT(aws_channel_thread_is_callers_thread(channel)); + + if (channel->statistics_handler) { + aws_crt_statistics_handler_destroy(channel->statistics_handler); + aws_event_loop_cancel_task(channel->loop, &channel->statistics_task); + channel->statistics_handler = NULL; + } + + if (handler != NULL) { + aws_task_init(&channel->statistics_task, s_channel_gather_statistics_task, channel, "gather_statistics"); + + uint64_t now_ns = 0; + if (aws_channel_current_clock_time(channel, &now_ns)) { + return AWS_OP_ERR; + } + + uint64_t report_time_ns = now_ns + aws_timestamp_convert( + aws_crt_statistics_handler_get_report_interval_ms(handler), + AWS_TIMESTAMP_MILLIS, + AWS_TIMESTAMP_NANOS, + NULL); + + channel->statistics_interval_start_time_ms = + aws_timestamp_convert(now_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); + s_reset_statistics(channel); + + aws_event_loop_schedule_task_future(channel->loop, &channel->statistics_task, report_time_ns); + } + + channel->statistics_handler = handler; + + return AWS_OP_SUCCESS; +} + +struct aws_event_loop *aws_channel_get_event_loop(struct aws_channel *channel) { + return channel->loop; +} diff --git a/contrib/restricted/aws/aws-c-io/source/channel_bootstrap.c b/contrib/restricted/aws/aws-c-io/source/channel_bootstrap.c index f5e0ad7aff..581e72572f 100644 --- a/contrib/restricted/aws/aws-c-io/source/channel_bootstrap.c +++ b/contrib/restricted/aws/aws-c-io/source/channel_bootstrap.c @@ -1,1421 +1,1421 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/channel_bootstrap.h> - -#include <aws/common/ref_count.h> -#include <aws/common/string.h> -#include <aws/io/event_loop.h> -#include <aws/io/logging.h> -#include <aws/io/socket.h> -#include <aws/io/socket_channel_handler.h> -#include <aws/io/tls_channel_handler.h> - -#if _MSC_VER -/* non-constant aggregate initializer */ -# pragma warning(disable : 4204) -/* allow automatic variable to escape scope - (it's intentional and we make sure it doesn't actually return - before the task is finished).*/ -# pragma warning(disable : 4221) -#endif - -#define DEFAULT_DNS_TTL 30 - -static void s_client_bootstrap_destroy_impl(struct aws_client_bootstrap *bootstrap) { - AWS_ASSERT(bootstrap); - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: destroying", (void *)bootstrap); - aws_client_bootstrap_shutdown_complete_fn *on_shutdown_complete = bootstrap->on_shutdown_complete; - void *user_data = bootstrap->user_data; - - aws_event_loop_group_release(bootstrap->event_loop_group); - aws_host_resolver_release(bootstrap->host_resolver); - - aws_mem_release(bootstrap->allocator, bootstrap); - - if (on_shutdown_complete) { - on_shutdown_complete(user_data); - } -} - -struct aws_client_bootstrap *aws_client_bootstrap_acquire(struct aws_client_bootstrap *bootstrap) { - if (bootstrap != NULL) { - aws_ref_count_acquire(&bootstrap->ref_count); - } - - return bootstrap; -} - -void aws_client_bootstrap_release(struct aws_client_bootstrap *bootstrap) { - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing bootstrap reference", (void *)bootstrap); - if (bootstrap != NULL) { - aws_ref_count_release(&bootstrap->ref_count); - } -} - -struct aws_client_bootstrap *aws_client_bootstrap_new( - struct aws_allocator *allocator, - const struct aws_client_bootstrap_options *options) { - AWS_ASSERT(allocator); - AWS_ASSERT(options); - AWS_ASSERT(options->event_loop_group); - - struct aws_client_bootstrap *bootstrap = aws_mem_calloc(allocator, 1, sizeof(struct aws_client_bootstrap)); - if (!bootstrap) { - return NULL; - } - - AWS_LOGF_INFO( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Initializing client bootstrap with event-loop group %p", - (void *)bootstrap, - (void *)options->event_loop_group); - - bootstrap->allocator = allocator; - bootstrap->event_loop_group = aws_event_loop_group_acquire(options->event_loop_group); - bootstrap->on_protocol_negotiated = NULL; - aws_ref_count_init( - &bootstrap->ref_count, bootstrap, (aws_simple_completion_callback *)s_client_bootstrap_destroy_impl); - bootstrap->host_resolver = aws_host_resolver_acquire(options->host_resolver); - bootstrap->on_shutdown_complete = options->on_shutdown_complete; - bootstrap->user_data = options->user_data; - - if (options->host_resolution_config) { - bootstrap->host_resolver_config = *options->host_resolution_config; - } else { - bootstrap->host_resolver_config = (struct aws_host_resolution_config){ - .impl = aws_default_dns_resolve, - .max_ttl = DEFAULT_DNS_TTL, - .impl_data = NULL, - }; - } - - return bootstrap; -} - -int aws_client_bootstrap_set_alpn_callback( - struct aws_client_bootstrap *bootstrap, - aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated) { - AWS_ASSERT(on_protocol_negotiated); - - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: Setting ALPN callback", (void *)bootstrap); - bootstrap->on_protocol_negotiated = on_protocol_negotiated; - return AWS_OP_SUCCESS; -} - -struct client_channel_data { - struct aws_channel *channel; - struct aws_socket *socket; - struct aws_tls_connection_options tls_options; - aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated; - aws_tls_on_data_read_fn *user_on_data_read; - aws_tls_on_negotiation_result_fn *user_on_negotiation_result; - aws_tls_on_error_fn *user_on_error; - void *tls_user_data; - bool use_tls; -}; - -struct client_connection_args { - struct aws_client_bootstrap *bootstrap; - aws_client_bootstrap_on_channel_event_fn *creation_callback; - aws_client_bootstrap_on_channel_event_fn *setup_callback; - aws_client_bootstrap_on_channel_event_fn *shutdown_callback; - struct client_channel_data channel_data; - struct aws_socket_options outgoing_options; - uint16_t outgoing_port; - struct aws_string *host_name; - void *user_data; - uint8_t addresses_count; - uint8_t failed_count; - bool connection_chosen; - bool setup_called; - bool enable_read_back_pressure; - - /* - * It is likely that all reference adjustments to the connection args take place in a single event loop - * thread and are thus thread-safe. I can imagine some complex future scenarios where that might not hold true - * and so it seems reasonable to switch now to a safe pattern. - * - */ - struct aws_ref_count ref_count; -}; - -static struct client_connection_args *s_client_connection_args_acquire(struct client_connection_args *args) { - if (args != NULL) { - aws_ref_count_acquire(&args->ref_count); - } - - return args; -} - -static void s_client_connection_args_destroy(struct client_connection_args *args) { - AWS_ASSERT(args); - - struct aws_allocator *allocator = args->bootstrap->allocator; - aws_client_bootstrap_release(args->bootstrap); - if (args->host_name) { - aws_string_destroy(args->host_name); - } - - if (args->channel_data.use_tls) { - aws_tls_connection_options_clean_up(&args->channel_data.tls_options); - } - - aws_mem_release(allocator, args); -} - -static void s_client_connection_args_release(struct client_connection_args *args) { - if (args != NULL) { - aws_ref_count_release(&args->ref_count); - } -} - -static void s_connection_args_setup_callback( - struct client_connection_args *args, - int error_code, - struct aws_channel *channel) { - /* setup_callback is always called exactly once */ - AWS_ASSERT(!args->setup_called); - if (!args->setup_called) { - AWS_ASSERT((error_code == AWS_OP_SUCCESS) == (channel != NULL)); - aws_client_bootstrap_on_channel_event_fn *setup_callback = args->setup_callback; - setup_callback(args->bootstrap, error_code, channel, args->user_data); - args->setup_called = true; - /* if setup_callback is called with an error, we will not call shutdown_callback */ - if (error_code) { - args->shutdown_callback = NULL; - } - s_client_connection_args_release(args); - } -} - -static void s_connection_args_creation_callback(struct client_connection_args *args, struct aws_channel *channel) { - - AWS_FATAL_ASSERT(channel != NULL); - - if (args->creation_callback) { - args->creation_callback(args->bootstrap, AWS_ERROR_SUCCESS, channel, args->user_data); - } -} - -static void s_connection_args_shutdown_callback( - struct client_connection_args *args, - int error_code, - struct aws_channel *channel) { - - if (!args->setup_called) { - /* if setup_callback was not called yet, an error occurred, ensure we tell the user *SOMETHING* */ - error_code = (error_code) ? error_code : AWS_ERROR_UNKNOWN; - s_connection_args_setup_callback(args, error_code, NULL); - return; - } - - aws_client_bootstrap_on_channel_event_fn *shutdown_callback = args->shutdown_callback; - if (shutdown_callback) { - shutdown_callback(args->bootstrap, error_code, channel, args->user_data); - } -} - -static void s_tls_client_on_negotiation_result( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - int err_code, - void *user_data) { - struct client_connection_args *connection_args = user_data; - - if (connection_args->channel_data.user_on_negotiation_result) { - connection_args->channel_data.user_on_negotiation_result( - handler, slot, err_code, connection_args->channel_data.tls_user_data); - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: tls negotiation result %d on channel %p", - (void *)connection_args->bootstrap, - err_code, - (void *)slot->channel); - - /* if an error occurred, the user callback will be delivered in shutdown */ - if (err_code) { - aws_channel_shutdown(slot->channel, err_code); - return; - } - - struct aws_channel *channel = connection_args->channel_data.channel; - s_connection_args_setup_callback(connection_args, AWS_ERROR_SUCCESS, channel); -} - -/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to - * provide a proxy for the user actually receiving their callbacks. */ -static void s_tls_client_on_data_read( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_byte_buf *buffer, - void *user_data) { - struct client_connection_args *connection_args = user_data; - - if (connection_args->channel_data.user_on_data_read) { - connection_args->channel_data.user_on_data_read( - handler, slot, buffer, connection_args->channel_data.tls_user_data); - } -} - -/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to - * provide a proxy for the user actually receiving their callbacks. */ -static void s_tls_client_on_error( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - int err, - const char *message, - void *user_data) { - struct client_connection_args *connection_args = user_data; - - if (connection_args->channel_data.user_on_error) { - connection_args->channel_data.user_on_error( - handler, slot, err, message, connection_args->channel_data.tls_user_data); - } -} - -static inline int s_setup_client_tls(struct client_connection_args *connection_args, struct aws_channel *channel) { - struct aws_channel_slot *tls_slot = aws_channel_slot_new(channel); - - /* as far as cleanup goes, since this stuff is being added to a channel, the caller will free this memory - when they clean up the channel. */ - if (!tls_slot) { - return AWS_OP_ERR; - } - - struct aws_channel_handler *tls_handler = aws_tls_client_handler_new( - connection_args->bootstrap->allocator, &connection_args->channel_data.tls_options, tls_slot); - - if (!tls_handler) { - aws_mem_release(connection_args->bootstrap->allocator, (void *)tls_slot); - return AWS_OP_ERR; - } - - aws_channel_slot_insert_end(channel, tls_slot); - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up client TLS on channel %p with handler %p on slot %p", - (void *)connection_args->bootstrap, - (void *)channel, - (void *)tls_handler, - (void *)tls_slot); - - if (aws_channel_slot_set_handler(tls_slot, tls_handler) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - - if (connection_args->channel_data.on_protocol_negotiated) { - struct aws_channel_slot *alpn_slot = aws_channel_slot_new(channel); - - if (!alpn_slot) { - return AWS_OP_ERR; - } - - struct aws_channel_handler *alpn_handler = aws_tls_alpn_handler_new( - connection_args->bootstrap->allocator, - connection_args->channel_data.on_protocol_negotiated, - connection_args->user_data); - - if (!alpn_handler) { - aws_mem_release(connection_args->bootstrap->allocator, (void *)alpn_slot); - return AWS_OP_ERR; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up ALPN handler on channel " - "%p with handler %p on slot %p", - (void *)connection_args->bootstrap, - (void *)channel, - (void *)alpn_handler, - (void *)alpn_slot); - - aws_channel_slot_insert_right(tls_slot, alpn_slot); - if (aws_channel_slot_set_handler(alpn_slot, alpn_handler) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - } - - if (aws_tls_client_handler_start_negotiation(tls_handler) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -static void s_on_client_channel_on_setup_completed(struct aws_channel *channel, int error_code, void *user_data) { - struct client_connection_args *connection_args = user_data; - int err_code = error_code; - - if (!err_code) { - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p setup succeeded: bootstrapping.", - (void *)connection_args->bootstrap, - (void *)channel); - - struct aws_channel_slot *socket_slot = aws_channel_slot_new(channel); - - if (!socket_slot) { - err_code = aws_last_error(); - goto error; - } - - struct aws_channel_handler *socket_channel_handler = aws_socket_handler_new( - connection_args->bootstrap->allocator, - connection_args->channel_data.socket, - socket_slot, - g_aws_channel_max_fragment_size); - - if (!socket_channel_handler) { - err_code = aws_last_error(); - aws_channel_slot_remove(socket_slot); - socket_slot = NULL; - goto error; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up socket handler on channel " - "%p with handler %p on slot %p.", - (void *)connection_args->bootstrap, - (void *)channel, - (void *)socket_channel_handler, - (void *)socket_slot); - - if (aws_channel_slot_set_handler(socket_slot, socket_channel_handler)) { - err_code = aws_last_error(); - goto error; - } - - if (connection_args->channel_data.use_tls) { - /* we don't want to notify the user that the channel is ready yet, since tls is still negotiating, wait - * for the negotiation callback and handle it then.*/ - if (s_setup_client_tls(connection_args, channel)) { - err_code = aws_last_error(); - goto error; - } - } else { - s_connection_args_setup_callback(connection_args, AWS_OP_SUCCESS, channel); - } - - return; - } - -error: - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p setup failed with error %d.", - (void *)connection_args->bootstrap, - (void *)channel, - err_code); - aws_channel_shutdown(channel, err_code); - /* the channel shutdown callback will clean the channel up */ -} - -static void s_on_client_channel_on_shutdown(struct aws_channel *channel, int error_code, void *user_data) { - struct client_connection_args *connection_args = user_data; - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p shutdown with error %d.", - (void *)connection_args->bootstrap, - (void *)channel, - error_code); - - /* note it's not safe to reference the bootstrap after the callback. */ - struct aws_allocator *allocator = connection_args->bootstrap->allocator; - s_connection_args_shutdown_callback(connection_args, error_code, channel); - - aws_channel_destroy(channel); - aws_socket_clean_up(connection_args->channel_data.socket); - aws_mem_release(allocator, connection_args->channel_data.socket); - s_client_connection_args_release(connection_args); -} - -static bool s_aws_socket_domain_uses_dns(enum aws_socket_domain domain) { - return domain == AWS_SOCKET_IPV4 || domain == AWS_SOCKET_IPV6; -} - -static void s_on_client_connection_established(struct aws_socket *socket, int error_code, void *user_data) { - struct client_connection_args *connection_args = user_data; - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: client connection on socket %p completed with error %d.", - (void *)connection_args->bootstrap, - (void *)socket, - error_code); - - if (error_code) { - connection_args->failed_count++; - } - - if (error_code || connection_args->connection_chosen) { - if (s_aws_socket_domain_uses_dns(connection_args->outgoing_options.domain) && error_code) { - struct aws_host_address host_address; - host_address.host = connection_args->host_name; - host_address.address = - aws_string_new_from_c_str(connection_args->bootstrap->allocator, socket->remote_endpoint.address); - host_address.record_type = connection_args->outgoing_options.domain == AWS_SOCKET_IPV6 - ? AWS_ADDRESS_RECORD_TYPE_AAAA - : AWS_ADDRESS_RECORD_TYPE_A; - - if (host_address.address) { - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: recording bad address %s.", - (void *)connection_args->bootstrap, - socket->remote_endpoint.address); - aws_host_resolver_record_connection_failure(connection_args->bootstrap->host_resolver, &host_address); - aws_string_destroy((void *)host_address.address); - } - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: releasing socket %p either because we already have a " - "successful connection or because it errored out.", - (void *)connection_args->bootstrap, - (void *)socket); - aws_socket_close(socket); - - aws_socket_clean_up(socket); - aws_mem_release(connection_args->bootstrap->allocator, socket); - - /* if this is the last attempted connection and it failed, notify the user */ - if (connection_args->failed_count == connection_args->addresses_count) { - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Connection failed with error_code %d.", - (void *)connection_args->bootstrap, - error_code); - /* connection_args will be released after setup_callback */ - s_connection_args_setup_callback(connection_args, error_code, NULL); - } - - /* every connection task adds a ref, so every failure or cancel needs to dec one */ - s_client_connection_args_release(connection_args); - return; - } - - connection_args->connection_chosen = true; - connection_args->channel_data.socket = socket; - - struct aws_channel_options args = { - .on_setup_completed = s_on_client_channel_on_setup_completed, - .setup_user_data = connection_args, - .shutdown_user_data = connection_args, - .on_shutdown_completed = s_on_client_channel_on_shutdown, - }; - - args.enable_read_back_pressure = connection_args->enable_read_back_pressure; - args.event_loop = aws_socket_get_event_loop(socket); - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Successful connection, creating a new channel using socket %p.", - (void *)connection_args->bootstrap, - (void *)socket); - - connection_args->channel_data.channel = aws_channel_new(connection_args->bootstrap->allocator, &args); - - if (!connection_args->channel_data.channel) { - aws_socket_clean_up(socket); - aws_mem_release(connection_args->bootstrap->allocator, connection_args->channel_data.socket); - connection_args->failed_count++; - - /* if this is the last attempted connection and it failed, notify the user */ - if (connection_args->failed_count == connection_args->addresses_count) { - s_connection_args_setup_callback(connection_args, error_code, NULL); - } - } else { - s_connection_args_creation_callback(connection_args, connection_args->channel_data.channel); - } -} - -struct connection_task_data { - struct aws_task task; - struct aws_socket_endpoint endpoint; - struct aws_socket_options options; - struct aws_host_address host_address; - struct client_connection_args *args; - struct aws_event_loop *connect_loop; -}; - -static void s_attempt_connection(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - struct connection_task_data *task_data = arg; - struct aws_allocator *allocator = task_data->args->bootstrap->allocator; - int err_code = 0; - - if (status != AWS_TASK_STATUS_RUN_READY) { - goto task_cancelled; - } - - struct aws_socket *outgoing_socket = aws_mem_acquire(allocator, sizeof(struct aws_socket)); - if (!outgoing_socket) { - goto socket_alloc_failed; - } - - if (aws_socket_init(outgoing_socket, allocator, &task_data->options)) { - goto socket_init_failed; - } - - if (aws_socket_connect( - outgoing_socket, - &task_data->endpoint, - task_data->connect_loop, - s_on_client_connection_established, - task_data->args)) { - - goto socket_connect_failed; - } - - goto cleanup_task; - -socket_connect_failed: - aws_host_resolver_record_connection_failure(task_data->args->bootstrap->host_resolver, &task_data->host_address); - aws_socket_clean_up(outgoing_socket); -socket_init_failed: - aws_mem_release(allocator, outgoing_socket); -socket_alloc_failed: - err_code = aws_last_error(); - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: failed to create socket with error %d", - (void *)task_data->args->bootstrap, - err_code); -task_cancelled: - task_data->args->failed_count++; - /* if this is the last attempted connection and it failed, notify the user */ - if (task_data->args->failed_count == task_data->args->addresses_count) { - s_connection_args_setup_callback(task_data->args, err_code, NULL); - } - s_client_connection_args_release(task_data->args); - -cleanup_task: - aws_host_address_clean_up(&task_data->host_address); - aws_mem_release(allocator, task_data); -} - -static void s_on_host_resolved( - struct aws_host_resolver *resolver, - const struct aws_string *host_name, - int err_code, - const struct aws_array_list *host_addresses, - void *user_data) { - (void)resolver; - (void)host_name; - - struct client_connection_args *client_connection_args = user_data; - struct aws_allocator *allocator = client_connection_args->bootstrap->allocator; - - if (err_code) { - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: dns resolution failed, or all socket connections to the endpoint failed.", - (void *)client_connection_args->bootstrap); - s_connection_args_setup_callback(client_connection_args, err_code, NULL); - return; - } - - size_t host_addresses_len = aws_array_list_length(host_addresses); - AWS_FATAL_ASSERT(host_addresses_len > 0); - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: dns resolution completed. Kicking off connections" - " on %llu addresses. First one back wins.", - (void *)client_connection_args->bootstrap, - (unsigned long long)host_addresses_len); - /* use this event loop for all outgoing connection attempts (only one will ultimately win). */ - struct aws_event_loop *connect_loop = - aws_event_loop_group_get_next_loop(client_connection_args->bootstrap->event_loop_group); - client_connection_args->addresses_count = (uint8_t)host_addresses_len; - - /* allocate all the task data first, in case it fails... */ - AWS_VARIABLE_LENGTH_ARRAY(struct connection_task_data *, tasks, host_addresses_len); - for (size_t i = 0; i < host_addresses_len; ++i) { - struct connection_task_data *task_data = tasks[i] = - aws_mem_calloc(allocator, 1, sizeof(struct connection_task_data)); - bool failed = task_data == NULL; - if (!failed) { - struct aws_host_address *host_address_ptr = NULL; - aws_array_list_get_at_ptr(host_addresses, (void **)&host_address_ptr, i); - - task_data->endpoint.port = client_connection_args->outgoing_port; - AWS_ASSERT(sizeof(task_data->endpoint.address) >= host_address_ptr->address->len + 1); - memcpy( - task_data->endpoint.address, - aws_string_bytes(host_address_ptr->address), - host_address_ptr->address->len); - task_data->endpoint.address[host_address_ptr->address->len] = 0; - - task_data->options = client_connection_args->outgoing_options; - task_data->options.domain = - host_address_ptr->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA ? AWS_SOCKET_IPV6 : AWS_SOCKET_IPV4; - - failed = aws_host_address_copy(host_address_ptr, &task_data->host_address) != AWS_OP_SUCCESS; - task_data->args = client_connection_args; - task_data->connect_loop = connect_loop; - } - - if (failed) { - for (size_t j = 0; j <= i; ++j) { - if (tasks[j]) { - aws_host_address_clean_up(&tasks[j]->host_address); - aws_mem_release(allocator, tasks[j]); - } - } - int alloc_err_code = aws_last_error(); - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: failed to allocate connection task data: err=%d", - (void *)client_connection_args->bootstrap, - alloc_err_code); - s_connection_args_setup_callback(client_connection_args, alloc_err_code, NULL); - return; - } - } - - /* ...then schedule all the tasks, which cannot fail */ - for (size_t i = 0; i < host_addresses_len; ++i) { - struct connection_task_data *task_data = tasks[i]; - /* each task needs to hold a ref to the args until completed */ - s_client_connection_args_acquire(task_data->args); - - aws_task_init(&task_data->task, s_attempt_connection, task_data, "attempt_connection"); - aws_event_loop_schedule_task_now(connect_loop, &task_data->task); - } -} - -int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_options *options) { - - struct aws_client_bootstrap *bootstrap = options->bootstrap; - AWS_FATAL_ASSERT(options->setup_callback); - AWS_FATAL_ASSERT(options->shutdown_callback); - AWS_FATAL_ASSERT(bootstrap); - - const struct aws_socket_options *socket_options = options->socket_options; - AWS_FATAL_ASSERT(socket_options != NULL); - - const struct aws_tls_connection_options *tls_options = options->tls_options; - - AWS_FATAL_ASSERT(tls_options == NULL || socket_options->type == AWS_SOCKET_STREAM); - aws_io_fatal_assert_library_initialized(); - - struct client_connection_args *client_connection_args = - aws_mem_calloc(bootstrap->allocator, 1, sizeof(struct client_connection_args)); - - if (!client_connection_args) { - return AWS_OP_ERR; - } - - const char *host_name = options->host_name; - uint16_t port = options->port; - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: attempting to initialize a new client channel to %s:%d", - (void *)bootstrap, - host_name, - (int)port); - - aws_ref_count_init( - &client_connection_args->ref_count, - client_connection_args, - (aws_simple_completion_callback *)s_client_connection_args_destroy); - client_connection_args->user_data = options->user_data; - client_connection_args->bootstrap = aws_client_bootstrap_acquire(bootstrap); - client_connection_args->creation_callback = options->creation_callback; - client_connection_args->setup_callback = options->setup_callback; - client_connection_args->shutdown_callback = options->shutdown_callback; - client_connection_args->outgoing_options = *socket_options; - client_connection_args->outgoing_port = port; - client_connection_args->enable_read_back_pressure = options->enable_read_back_pressure; - - if (tls_options) { - if (aws_tls_connection_options_copy(&client_connection_args->channel_data.tls_options, tls_options)) { - goto error; - } - client_connection_args->channel_data.use_tls = true; - - client_connection_args->channel_data.on_protocol_negotiated = bootstrap->on_protocol_negotiated; - client_connection_args->channel_data.tls_user_data = tls_options->user_data; - - /* in order to honor any callbacks a user may have installed on their tls_connection_options, - * we need to wrap them if they were set.*/ - if (bootstrap->on_protocol_negotiated) { - client_connection_args->channel_data.tls_options.advertise_alpn_message = true; - } - - if (tls_options->on_data_read) { - client_connection_args->channel_data.user_on_data_read = tls_options->on_data_read; - client_connection_args->channel_data.tls_options.on_data_read = s_tls_client_on_data_read; - } - - if (tls_options->on_error) { - client_connection_args->channel_data.user_on_error = tls_options->on_error; - client_connection_args->channel_data.tls_options.on_error = s_tls_client_on_error; - } - - if (tls_options->on_negotiation_result) { - client_connection_args->channel_data.user_on_negotiation_result = tls_options->on_negotiation_result; - } - - client_connection_args->channel_data.tls_options.on_negotiation_result = s_tls_client_on_negotiation_result; - client_connection_args->channel_data.tls_options.user_data = client_connection_args; - } - - if (s_aws_socket_domain_uses_dns(socket_options->domain)) { - client_connection_args->host_name = aws_string_new_from_c_str(bootstrap->allocator, host_name); - - if (!client_connection_args->host_name) { - goto error; - } - - if (aws_host_resolver_resolve_host( - bootstrap->host_resolver, - client_connection_args->host_name, - s_on_host_resolved, - &bootstrap->host_resolver_config, - client_connection_args)) { - goto error; - } - } else { - /* ensure that the pipe/domain socket name will fit in the endpoint address */ - const size_t host_name_len = strlen(host_name); - if (host_name_len >= AWS_ADDRESS_MAX_LEN) { - aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); - goto error; - } - - struct aws_socket_endpoint endpoint; - AWS_ZERO_STRUCT(endpoint); - memcpy(endpoint.address, host_name, host_name_len); - if (socket_options->domain == AWS_SOCKET_VSOCK) { - endpoint.port = port; - } else { - endpoint.port = 0; - } - - struct aws_socket *outgoing_socket = aws_mem_acquire(bootstrap->allocator, sizeof(struct aws_socket)); - - if (!outgoing_socket) { - goto error; - } - - if (aws_socket_init(outgoing_socket, bootstrap->allocator, socket_options)) { - aws_mem_release(bootstrap->allocator, outgoing_socket); - goto error; - } - - client_connection_args->addresses_count = 1; - - struct aws_event_loop *connect_loop = aws_event_loop_group_get_next_loop(bootstrap->event_loop_group); - - s_client_connection_args_acquire(client_connection_args); - if (aws_socket_connect( - outgoing_socket, &endpoint, connect_loop, s_on_client_connection_established, client_connection_args)) { - aws_socket_clean_up(outgoing_socket); - aws_mem_release(client_connection_args->bootstrap->allocator, outgoing_socket); - s_client_connection_args_release(client_connection_args); - goto error; - } - } - - return AWS_OP_SUCCESS; - -error: - if (client_connection_args) { - /* tls opt will also be freed when we clean up the connection arg */ - s_client_connection_args_release(client_connection_args); - } - return AWS_OP_ERR; -} - -void s_server_bootstrap_destroy_impl(struct aws_server_bootstrap *bootstrap) { - AWS_ASSERT(bootstrap); - aws_event_loop_group_release(bootstrap->event_loop_group); - aws_mem_release(bootstrap->allocator, bootstrap); -} - -struct aws_server_bootstrap *aws_server_bootstrap_acquire(struct aws_server_bootstrap *bootstrap) { - if (bootstrap != NULL) { - aws_ref_count_acquire(&bootstrap->ref_count); - } - - return bootstrap; -} - -void aws_server_bootstrap_release(struct aws_server_bootstrap *bootstrap) { - /* if destroy is being called, the user intends to not use the bootstrap anymore - * so we clean up the thread local state while the event loop thread is - * still alive */ - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing server bootstrap reference", (void *)bootstrap); - if (bootstrap != NULL) { - aws_ref_count_release(&bootstrap->ref_count); - } -} - -struct aws_server_bootstrap *aws_server_bootstrap_new( - struct aws_allocator *allocator, - struct aws_event_loop_group *el_group) { - AWS_ASSERT(allocator); - AWS_ASSERT(el_group); - - struct aws_server_bootstrap *bootstrap = aws_mem_calloc(allocator, 1, sizeof(struct aws_server_bootstrap)); - if (!bootstrap) { - return NULL; - } - - AWS_LOGF_INFO( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Initializing server bootstrap with event-loop group %p", - (void *)bootstrap, - (void *)el_group); - - bootstrap->allocator = allocator; - bootstrap->event_loop_group = aws_event_loop_group_acquire(el_group); - bootstrap->on_protocol_negotiated = NULL; - aws_ref_count_init( - &bootstrap->ref_count, bootstrap, (aws_simple_completion_callback *)s_server_bootstrap_destroy_impl); - - return bootstrap; -} - -struct server_connection_args { - struct aws_server_bootstrap *bootstrap; - struct aws_socket listener; - aws_server_bootstrap_on_accept_channel_setup_fn *incoming_callback; - aws_server_bootstrap_on_accept_channel_shutdown_fn *shutdown_callback; - aws_server_bootstrap_on_server_listener_destroy_fn *destroy_callback; - struct aws_tls_connection_options tls_options; - aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated; - aws_tls_on_data_read_fn *user_on_data_read; - aws_tls_on_negotiation_result_fn *user_on_negotiation_result; - aws_tls_on_error_fn *user_on_error; - struct aws_task listener_destroy_task; - void *tls_user_data; - void *user_data; - bool use_tls; - bool enable_read_back_pressure; - struct aws_ref_count ref_count; -}; - -struct server_channel_data { - struct aws_channel *channel; - struct aws_socket *socket; - struct server_connection_args *server_connection_args; - bool incoming_called; -}; - -static struct server_connection_args *s_server_connection_args_acquire(struct server_connection_args *args) { - if (args != NULL) { - aws_ref_count_acquire(&args->ref_count); - } - - return args; -} - -static void s_server_connection_args_destroy(struct server_connection_args *args) { - if (args == NULL) { - return; - } - - /* fire the destroy callback */ - if (args->destroy_callback) { - args->destroy_callback(args->bootstrap, args->user_data); - } - - struct aws_allocator *allocator = args->bootstrap->allocator; - aws_server_bootstrap_release(args->bootstrap); - if (args->use_tls) { - aws_tls_connection_options_clean_up(&args->tls_options); - } - - aws_mem_release(allocator, args); -} - -static void s_server_connection_args_release(struct server_connection_args *args) { - if (args != NULL) { - aws_ref_count_release(&args->ref_count); - } -} - -static void s_server_incoming_callback( - struct server_channel_data *channel_data, - int error_code, - struct aws_channel *channel) { - /* incoming_callback is always called exactly once for each channel */ - AWS_ASSERT(!channel_data->incoming_called); - struct server_connection_args *args = channel_data->server_connection_args; - args->incoming_callback(args->bootstrap, error_code, channel, args->user_data); - channel_data->incoming_called = true; -} - -static void s_tls_server_on_negotiation_result( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - int err_code, - void *user_data) { - struct server_channel_data *channel_data = user_data; - struct server_connection_args *connection_args = channel_data->server_connection_args; - - if (connection_args->user_on_negotiation_result) { - connection_args->user_on_negotiation_result(handler, slot, err_code, connection_args->tls_user_data); - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: tls negotiation result %d on channel %p", - (void *)connection_args->bootstrap, - err_code, - (void *)slot->channel); - - struct aws_channel *channel = slot->channel; - if (err_code) { - /* shut down the channel */ - aws_channel_shutdown(channel, err_code); - } else { - s_server_incoming_callback(channel_data, err_code, channel); - } -} - -/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to - * provide a proxy for the user actually receiving their callbacks. */ -static void s_tls_server_on_data_read( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_byte_buf *buffer, - void *user_data) { - struct server_connection_args *connection_args = user_data; - - if (connection_args->user_on_data_read) { - connection_args->user_on_data_read(handler, slot, buffer, connection_args->tls_user_data); - } -} - -/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to - * provide a proxy for the user actually receiving their callbacks. */ -static void s_tls_server_on_error( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - int err, - const char *message, - void *user_data) { - struct server_connection_args *connection_args = user_data; - - if (connection_args->user_on_error) { - connection_args->user_on_error(handler, slot, err, message, connection_args->tls_user_data); - } -} - -static inline int s_setup_server_tls(struct server_channel_data *channel_data, struct aws_channel *channel) { - struct aws_channel_slot *tls_slot = NULL; - struct aws_channel_handler *tls_handler = NULL; - struct server_connection_args *connection_args = channel_data->server_connection_args; - - /* as far as cleanup goes here, since we're adding things to a channel, if a slot is ever successfully - added to the channel, we leave it there. The caller will clean up the channel and it will clean this memory - up as well. */ - tls_slot = aws_channel_slot_new(channel); - - if (!tls_slot) { - return AWS_OP_ERR; - } - - /* Shallow-copy tls_options so we can override the user_data, making it specific to this channel */ - struct aws_tls_connection_options tls_options = connection_args->tls_options; - tls_options.user_data = channel_data; - tls_handler = aws_tls_server_handler_new(connection_args->bootstrap->allocator, &tls_options, tls_slot); - - if (!tls_handler) { - aws_mem_release(connection_args->bootstrap->allocator, tls_slot); - return AWS_OP_ERR; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up server TLS on channel %p with handler %p on slot %p", - (void *)connection_args->bootstrap, - (void *)channel, - (void *)tls_handler, - (void *)tls_slot); - - aws_channel_slot_insert_end(channel, tls_slot); - - if (aws_channel_slot_set_handler(tls_slot, tls_handler)) { - return AWS_OP_ERR; - } - - if (connection_args->on_protocol_negotiated) { - struct aws_channel_slot *alpn_slot = NULL; - struct aws_channel_handler *alpn_handler = NULL; - alpn_slot = aws_channel_slot_new(channel); - - if (!alpn_slot) { - return AWS_OP_ERR; - } - - alpn_handler = aws_tls_alpn_handler_new( - connection_args->bootstrap->allocator, connection_args->on_protocol_negotiated, connection_args->user_data); - - if (!alpn_handler) { - aws_channel_slot_remove(alpn_slot); - return AWS_OP_ERR; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up ALPN handler on channel " - "%p with handler %p on slot %p", - (void *)connection_args->bootstrap, - (void *)channel, - (void *)alpn_handler, - (void *)alpn_slot); - - aws_channel_slot_insert_right(tls_slot, alpn_slot); - - if (aws_channel_slot_set_handler(alpn_slot, alpn_handler)) { - return AWS_OP_ERR; - } - } - - return AWS_OP_SUCCESS; -} - -static void s_on_server_channel_on_setup_completed(struct aws_channel *channel, int error_code, void *user_data) { - struct server_channel_data *channel_data = user_data; - - int err_code = error_code; - if (err_code) { - /* channel fail to set up no destroy callback will fire */ - AWS_LOGF_ERROR( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p setup failed with error %d.", - (void *)channel_data->server_connection_args->bootstrap, - (void *)channel, - err_code); - - aws_channel_destroy(channel); - struct aws_allocator *allocator = channel_data->socket->allocator; - aws_socket_clean_up(channel_data->socket); - aws_mem_release(allocator, (void *)channel_data->socket); - s_server_incoming_callback(channel_data, err_code, NULL); - aws_mem_release(channel_data->server_connection_args->bootstrap->allocator, channel_data); - /* no shutdown call back will be fired, we release the ref_count of connection arg here */ - s_server_connection_args_release(channel_data->server_connection_args); - return; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p setup succeeded: bootstrapping.", - (void *)channel_data->server_connection_args->bootstrap, - (void *)channel); - - struct aws_channel_slot *socket_slot = aws_channel_slot_new(channel); - - if (!socket_slot) { - err_code = aws_last_error(); - goto error; - } - - struct aws_channel_handler *socket_channel_handler = aws_socket_handler_new( - channel_data->server_connection_args->bootstrap->allocator, - channel_data->socket, - socket_slot, - g_aws_channel_max_fragment_size); - - if (!socket_channel_handler) { - err_code = aws_last_error(); - aws_channel_slot_remove(socket_slot); - socket_slot = NULL; - goto error; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: Setting up socket handler on channel " - "%p with handler %p on slot %p.", - (void *)channel_data->server_connection_args->bootstrap, - (void *)channel, - (void *)socket_channel_handler, - (void *)socket_slot); - - if (aws_channel_slot_set_handler(socket_slot, socket_channel_handler)) { - err_code = aws_last_error(); - goto error; - } - - if (channel_data->server_connection_args->use_tls) { - /* incoming callback will be invoked upon the negotiation completion so don't do it - * here. */ - if (s_setup_server_tls(channel_data, channel)) { - err_code = aws_last_error(); - goto error; - } - } else { - s_server_incoming_callback(channel_data, AWS_OP_SUCCESS, channel); - } - return; - -error: - /* shut down the channel */ - aws_channel_shutdown(channel, err_code); -} - -static void s_on_server_channel_on_shutdown(struct aws_channel *channel, int error_code, void *user_data) { - struct server_channel_data *channel_data = user_data; - struct server_connection_args *args = channel_data->server_connection_args; - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: channel %p shutdown with error %d.", - (void *)args->bootstrap, - (void *)channel, - error_code); - - void *server_shutdown_user_data = args->user_data; - struct aws_server_bootstrap *server_bootstrap = args->bootstrap; - struct aws_allocator *allocator = server_bootstrap->allocator; - - if (!channel_data->incoming_called) { - error_code = (error_code) ? error_code : AWS_ERROR_UNKNOWN; - s_server_incoming_callback(channel_data, error_code, NULL); - } else { - args->shutdown_callback(server_bootstrap, error_code, channel, server_shutdown_user_data); - } - - aws_channel_destroy(channel); - aws_socket_clean_up(channel_data->socket); - aws_mem_release(allocator, channel_data->socket); - s_server_connection_args_release(channel_data->server_connection_args); - - aws_mem_release(allocator, channel_data); -} - -void s_on_server_connection_result( - struct aws_socket *socket, - int error_code, - struct aws_socket *new_socket, - void *user_data) { - (void)socket; - struct server_connection_args *connection_args = user_data; - - s_server_connection_args_acquire(connection_args); - AWS_LOGF_DEBUG( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: server connection on socket %p completed with error %d.", - (void *)connection_args->bootstrap, - (void *)socket, - error_code); - - if (!error_code) { - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: creating a new channel for incoming " - "connection using socket %p.", - (void *)connection_args->bootstrap, - (void *)socket); - struct server_channel_data *channel_data = - aws_mem_calloc(connection_args->bootstrap->allocator, 1, sizeof(struct server_channel_data)); - if (!channel_data) { - goto error_cleanup; - } - channel_data->incoming_called = false; - channel_data->socket = new_socket; - channel_data->server_connection_args = connection_args; - - struct aws_event_loop *event_loop = - aws_event_loop_group_get_next_loop(connection_args->bootstrap->event_loop_group); - - struct aws_channel_options channel_args = { - .on_setup_completed = s_on_server_channel_on_setup_completed, - .setup_user_data = channel_data, - .shutdown_user_data = channel_data, - .on_shutdown_completed = s_on_server_channel_on_shutdown, - }; - - channel_args.event_loop = event_loop; - channel_args.enable_read_back_pressure = channel_data->server_connection_args->enable_read_back_pressure; - - if (aws_socket_assign_to_event_loop(new_socket, event_loop)) { - aws_mem_release(connection_args->bootstrap->allocator, (void *)channel_data); - goto error_cleanup; - } - - channel_data->channel = aws_channel_new(connection_args->bootstrap->allocator, &channel_args); - - if (!channel_data->channel) { - aws_mem_release(connection_args->bootstrap->allocator, (void *)channel_data); - goto error_cleanup; - } - } else { - /* no channel is created */ - connection_args->incoming_callback(connection_args->bootstrap, error_code, NULL, connection_args->user_data); - s_server_connection_args_release(connection_args); - } - - return; - -error_cleanup: - /* no channel is created */ - connection_args->incoming_callback(connection_args->bootstrap, aws_last_error(), NULL, connection_args->user_data); - - struct aws_allocator *allocator = new_socket->allocator; - aws_socket_clean_up(new_socket); - aws_mem_release(allocator, (void *)new_socket); - s_server_connection_args_release(connection_args); -} - -static void s_listener_destroy_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)status; - (void)task; - struct server_connection_args *server_connection_args = arg; - - aws_socket_stop_accept(&server_connection_args->listener); - aws_socket_clean_up(&server_connection_args->listener); - s_server_connection_args_release(server_connection_args); -} - -struct aws_socket *aws_server_bootstrap_new_socket_listener( - const struct aws_server_socket_channel_bootstrap_options *bootstrap_options) { - AWS_PRECONDITION(bootstrap_options); - AWS_PRECONDITION(bootstrap_options->bootstrap); - AWS_PRECONDITION(bootstrap_options->incoming_callback) - AWS_PRECONDITION(bootstrap_options->shutdown_callback) - - struct server_connection_args *server_connection_args = - aws_mem_calloc(bootstrap_options->bootstrap->allocator, 1, sizeof(struct server_connection_args)); - if (!server_connection_args) { - return NULL; - } - - AWS_LOGF_INFO( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: attempting to initialize a new " - "server socket listener for %s:%d", - (void *)server_connection_args->bootstrap, - bootstrap_options->host_name, - (int)bootstrap_options->port); - - aws_ref_count_init( - &server_connection_args->ref_count, - server_connection_args, - (aws_simple_completion_callback *)s_server_connection_args_destroy); - server_connection_args->user_data = bootstrap_options->user_data; - server_connection_args->bootstrap = aws_server_bootstrap_acquire(bootstrap_options->bootstrap); - server_connection_args->shutdown_callback = bootstrap_options->shutdown_callback; - server_connection_args->incoming_callback = bootstrap_options->incoming_callback; - server_connection_args->destroy_callback = bootstrap_options->destroy_callback; - server_connection_args->on_protocol_negotiated = bootstrap_options->bootstrap->on_protocol_negotiated; - server_connection_args->enable_read_back_pressure = bootstrap_options->enable_read_back_pressure; - - aws_task_init( - &server_connection_args->listener_destroy_task, - s_listener_destroy_task, - server_connection_args, - "listener socket destroy"); - - if (bootstrap_options->tls_options) { - AWS_LOGF_INFO( - AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: using tls on listener", (void *)bootstrap_options->tls_options); - if (aws_tls_connection_options_copy(&server_connection_args->tls_options, bootstrap_options->tls_options)) { - goto cleanup_server_connection_args; - } - - server_connection_args->use_tls = true; - - server_connection_args->tls_user_data = bootstrap_options->tls_options->user_data; - - /* in order to honor any callbacks a user may have installed on their tls_connection_options, - * we need to wrap them if they were set.*/ - if (bootstrap_options->bootstrap->on_protocol_negotiated) { - server_connection_args->tls_options.advertise_alpn_message = true; - } - - if (bootstrap_options->tls_options->on_data_read) { - server_connection_args->user_on_data_read = bootstrap_options->tls_options->on_data_read; - server_connection_args->tls_options.on_data_read = s_tls_server_on_data_read; - } - - if (bootstrap_options->tls_options->on_error) { - server_connection_args->user_on_error = bootstrap_options->tls_options->on_error; - server_connection_args->tls_options.on_error = s_tls_server_on_error; - } - - if (bootstrap_options->tls_options->on_negotiation_result) { - server_connection_args->user_on_negotiation_result = bootstrap_options->tls_options->on_negotiation_result; - } - - server_connection_args->tls_options.on_negotiation_result = s_tls_server_on_negotiation_result; - server_connection_args->tls_options.user_data = server_connection_args; - } - - struct aws_event_loop *connection_loop = - aws_event_loop_group_get_next_loop(bootstrap_options->bootstrap->event_loop_group); - - if (aws_socket_init( - &server_connection_args->listener, - bootstrap_options->bootstrap->allocator, - bootstrap_options->socket_options)) { - goto cleanup_server_connection_args; - } - - struct aws_socket_endpoint endpoint; - AWS_ZERO_STRUCT(endpoint); - size_t host_name_len = 0; - if (aws_secure_strlen(bootstrap_options->host_name, sizeof(endpoint.address), &host_name_len)) { - goto cleanup_server_connection_args; - } - - memcpy(endpoint.address, bootstrap_options->host_name, host_name_len); - endpoint.port = bootstrap_options->port; - - if (aws_socket_bind(&server_connection_args->listener, &endpoint)) { - goto cleanup_listener; - } - - if (aws_socket_listen(&server_connection_args->listener, 1024)) { - goto cleanup_listener; - } - - if (aws_socket_start_accept( - &server_connection_args->listener, - connection_loop, - s_on_server_connection_result, - server_connection_args)) { - goto cleanup_listener; - } - - return &server_connection_args->listener; - -cleanup_listener: - aws_socket_clean_up(&server_connection_args->listener); - -cleanup_server_connection_args: - s_server_connection_args_release(server_connection_args); - - return NULL; -} - -void aws_server_bootstrap_destroy_socket_listener(struct aws_server_bootstrap *bootstrap, struct aws_socket *listener) { - struct server_connection_args *server_connection_args = - AWS_CONTAINER_OF(listener, struct server_connection_args, listener); - - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing bootstrap reference", (void *)bootstrap); - aws_event_loop_schedule_task_now(listener->event_loop, &server_connection_args->listener_destroy_task); -} - -int aws_server_bootstrap_set_alpn_callback( - struct aws_server_bootstrap *bootstrap, - aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated) { - AWS_ASSERT(on_protocol_negotiated); - AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: Setting ALPN callback", (void *)bootstrap); - bootstrap->on_protocol_negotiated = on_protocol_negotiated; - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/channel_bootstrap.h> + +#include <aws/common/ref_count.h> +#include <aws/common/string.h> +#include <aws/io/event_loop.h> +#include <aws/io/logging.h> +#include <aws/io/socket.h> +#include <aws/io/socket_channel_handler.h> +#include <aws/io/tls_channel_handler.h> + +#if _MSC_VER +/* non-constant aggregate initializer */ +# pragma warning(disable : 4204) +/* allow automatic variable to escape scope + (it's intentional and we make sure it doesn't actually return + before the task is finished).*/ +# pragma warning(disable : 4221) +#endif + +#define DEFAULT_DNS_TTL 30 + +static void s_client_bootstrap_destroy_impl(struct aws_client_bootstrap *bootstrap) { + AWS_ASSERT(bootstrap); + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: destroying", (void *)bootstrap); + aws_client_bootstrap_shutdown_complete_fn *on_shutdown_complete = bootstrap->on_shutdown_complete; + void *user_data = bootstrap->user_data; + + aws_event_loop_group_release(bootstrap->event_loop_group); + aws_host_resolver_release(bootstrap->host_resolver); + + aws_mem_release(bootstrap->allocator, bootstrap); + + if (on_shutdown_complete) { + on_shutdown_complete(user_data); + } +} + +struct aws_client_bootstrap *aws_client_bootstrap_acquire(struct aws_client_bootstrap *bootstrap) { + if (bootstrap != NULL) { + aws_ref_count_acquire(&bootstrap->ref_count); + } + + return bootstrap; +} + +void aws_client_bootstrap_release(struct aws_client_bootstrap *bootstrap) { + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing bootstrap reference", (void *)bootstrap); + if (bootstrap != NULL) { + aws_ref_count_release(&bootstrap->ref_count); + } +} + +struct aws_client_bootstrap *aws_client_bootstrap_new( + struct aws_allocator *allocator, + const struct aws_client_bootstrap_options *options) { + AWS_ASSERT(allocator); + AWS_ASSERT(options); + AWS_ASSERT(options->event_loop_group); + + struct aws_client_bootstrap *bootstrap = aws_mem_calloc(allocator, 1, sizeof(struct aws_client_bootstrap)); + if (!bootstrap) { + return NULL; + } + + AWS_LOGF_INFO( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Initializing client bootstrap with event-loop group %p", + (void *)bootstrap, + (void *)options->event_loop_group); + + bootstrap->allocator = allocator; + bootstrap->event_loop_group = aws_event_loop_group_acquire(options->event_loop_group); + bootstrap->on_protocol_negotiated = NULL; + aws_ref_count_init( + &bootstrap->ref_count, bootstrap, (aws_simple_completion_callback *)s_client_bootstrap_destroy_impl); + bootstrap->host_resolver = aws_host_resolver_acquire(options->host_resolver); + bootstrap->on_shutdown_complete = options->on_shutdown_complete; + bootstrap->user_data = options->user_data; + + if (options->host_resolution_config) { + bootstrap->host_resolver_config = *options->host_resolution_config; + } else { + bootstrap->host_resolver_config = (struct aws_host_resolution_config){ + .impl = aws_default_dns_resolve, + .max_ttl = DEFAULT_DNS_TTL, + .impl_data = NULL, + }; + } + + return bootstrap; +} + +int aws_client_bootstrap_set_alpn_callback( + struct aws_client_bootstrap *bootstrap, + aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated) { + AWS_ASSERT(on_protocol_negotiated); + + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: Setting ALPN callback", (void *)bootstrap); + bootstrap->on_protocol_negotiated = on_protocol_negotiated; + return AWS_OP_SUCCESS; +} + +struct client_channel_data { + struct aws_channel *channel; + struct aws_socket *socket; + struct aws_tls_connection_options tls_options; + aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated; + aws_tls_on_data_read_fn *user_on_data_read; + aws_tls_on_negotiation_result_fn *user_on_negotiation_result; + aws_tls_on_error_fn *user_on_error; + void *tls_user_data; + bool use_tls; +}; + +struct client_connection_args { + struct aws_client_bootstrap *bootstrap; + aws_client_bootstrap_on_channel_event_fn *creation_callback; + aws_client_bootstrap_on_channel_event_fn *setup_callback; + aws_client_bootstrap_on_channel_event_fn *shutdown_callback; + struct client_channel_data channel_data; + struct aws_socket_options outgoing_options; + uint16_t outgoing_port; + struct aws_string *host_name; + void *user_data; + uint8_t addresses_count; + uint8_t failed_count; + bool connection_chosen; + bool setup_called; + bool enable_read_back_pressure; + + /* + * It is likely that all reference adjustments to the connection args take place in a single event loop + * thread and are thus thread-safe. I can imagine some complex future scenarios where that might not hold true + * and so it seems reasonable to switch now to a safe pattern. + * + */ + struct aws_ref_count ref_count; +}; + +static struct client_connection_args *s_client_connection_args_acquire(struct client_connection_args *args) { + if (args != NULL) { + aws_ref_count_acquire(&args->ref_count); + } + + return args; +} + +static void s_client_connection_args_destroy(struct client_connection_args *args) { + AWS_ASSERT(args); + + struct aws_allocator *allocator = args->bootstrap->allocator; + aws_client_bootstrap_release(args->bootstrap); + if (args->host_name) { + aws_string_destroy(args->host_name); + } + + if (args->channel_data.use_tls) { + aws_tls_connection_options_clean_up(&args->channel_data.tls_options); + } + + aws_mem_release(allocator, args); +} + +static void s_client_connection_args_release(struct client_connection_args *args) { + if (args != NULL) { + aws_ref_count_release(&args->ref_count); + } +} + +static void s_connection_args_setup_callback( + struct client_connection_args *args, + int error_code, + struct aws_channel *channel) { + /* setup_callback is always called exactly once */ + AWS_ASSERT(!args->setup_called); + if (!args->setup_called) { + AWS_ASSERT((error_code == AWS_OP_SUCCESS) == (channel != NULL)); + aws_client_bootstrap_on_channel_event_fn *setup_callback = args->setup_callback; + setup_callback(args->bootstrap, error_code, channel, args->user_data); + args->setup_called = true; + /* if setup_callback is called with an error, we will not call shutdown_callback */ + if (error_code) { + args->shutdown_callback = NULL; + } + s_client_connection_args_release(args); + } +} + +static void s_connection_args_creation_callback(struct client_connection_args *args, struct aws_channel *channel) { + + AWS_FATAL_ASSERT(channel != NULL); + + if (args->creation_callback) { + args->creation_callback(args->bootstrap, AWS_ERROR_SUCCESS, channel, args->user_data); + } +} + +static void s_connection_args_shutdown_callback( + struct client_connection_args *args, + int error_code, + struct aws_channel *channel) { + + if (!args->setup_called) { + /* if setup_callback was not called yet, an error occurred, ensure we tell the user *SOMETHING* */ + error_code = (error_code) ? error_code : AWS_ERROR_UNKNOWN; + s_connection_args_setup_callback(args, error_code, NULL); + return; + } + + aws_client_bootstrap_on_channel_event_fn *shutdown_callback = args->shutdown_callback; + if (shutdown_callback) { + shutdown_callback(args->bootstrap, error_code, channel, args->user_data); + } +} + +static void s_tls_client_on_negotiation_result( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int err_code, + void *user_data) { + struct client_connection_args *connection_args = user_data; + + if (connection_args->channel_data.user_on_negotiation_result) { + connection_args->channel_data.user_on_negotiation_result( + handler, slot, err_code, connection_args->channel_data.tls_user_data); + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: tls negotiation result %d on channel %p", + (void *)connection_args->bootstrap, + err_code, + (void *)slot->channel); + + /* if an error occurred, the user callback will be delivered in shutdown */ + if (err_code) { + aws_channel_shutdown(slot->channel, err_code); + return; + } + + struct aws_channel *channel = connection_args->channel_data.channel; + s_connection_args_setup_callback(connection_args, AWS_ERROR_SUCCESS, channel); +} + +/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to + * provide a proxy for the user actually receiving their callbacks. */ +static void s_tls_client_on_data_read( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_byte_buf *buffer, + void *user_data) { + struct client_connection_args *connection_args = user_data; + + if (connection_args->channel_data.user_on_data_read) { + connection_args->channel_data.user_on_data_read( + handler, slot, buffer, connection_args->channel_data.tls_user_data); + } +} + +/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to + * provide a proxy for the user actually receiving their callbacks. */ +static void s_tls_client_on_error( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int err, + const char *message, + void *user_data) { + struct client_connection_args *connection_args = user_data; + + if (connection_args->channel_data.user_on_error) { + connection_args->channel_data.user_on_error( + handler, slot, err, message, connection_args->channel_data.tls_user_data); + } +} + +static inline int s_setup_client_tls(struct client_connection_args *connection_args, struct aws_channel *channel) { + struct aws_channel_slot *tls_slot = aws_channel_slot_new(channel); + + /* as far as cleanup goes, since this stuff is being added to a channel, the caller will free this memory + when they clean up the channel. */ + if (!tls_slot) { + return AWS_OP_ERR; + } + + struct aws_channel_handler *tls_handler = aws_tls_client_handler_new( + connection_args->bootstrap->allocator, &connection_args->channel_data.tls_options, tls_slot); + + if (!tls_handler) { + aws_mem_release(connection_args->bootstrap->allocator, (void *)tls_slot); + return AWS_OP_ERR; + } + + aws_channel_slot_insert_end(channel, tls_slot); + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up client TLS on channel %p with handler %p on slot %p", + (void *)connection_args->bootstrap, + (void *)channel, + (void *)tls_handler, + (void *)tls_slot); + + if (aws_channel_slot_set_handler(tls_slot, tls_handler) != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + + if (connection_args->channel_data.on_protocol_negotiated) { + struct aws_channel_slot *alpn_slot = aws_channel_slot_new(channel); + + if (!alpn_slot) { + return AWS_OP_ERR; + } + + struct aws_channel_handler *alpn_handler = aws_tls_alpn_handler_new( + connection_args->bootstrap->allocator, + connection_args->channel_data.on_protocol_negotiated, + connection_args->user_data); + + if (!alpn_handler) { + aws_mem_release(connection_args->bootstrap->allocator, (void *)alpn_slot); + return AWS_OP_ERR; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up ALPN handler on channel " + "%p with handler %p on slot %p", + (void *)connection_args->bootstrap, + (void *)channel, + (void *)alpn_handler, + (void *)alpn_slot); + + aws_channel_slot_insert_right(tls_slot, alpn_slot); + if (aws_channel_slot_set_handler(alpn_slot, alpn_handler) != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + } + + if (aws_tls_client_handler_start_negotiation(tls_handler) != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +static void s_on_client_channel_on_setup_completed(struct aws_channel *channel, int error_code, void *user_data) { + struct client_connection_args *connection_args = user_data; + int err_code = error_code; + + if (!err_code) { + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p setup succeeded: bootstrapping.", + (void *)connection_args->bootstrap, + (void *)channel); + + struct aws_channel_slot *socket_slot = aws_channel_slot_new(channel); + + if (!socket_slot) { + err_code = aws_last_error(); + goto error; + } + + struct aws_channel_handler *socket_channel_handler = aws_socket_handler_new( + connection_args->bootstrap->allocator, + connection_args->channel_data.socket, + socket_slot, + g_aws_channel_max_fragment_size); + + if (!socket_channel_handler) { + err_code = aws_last_error(); + aws_channel_slot_remove(socket_slot); + socket_slot = NULL; + goto error; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up socket handler on channel " + "%p with handler %p on slot %p.", + (void *)connection_args->bootstrap, + (void *)channel, + (void *)socket_channel_handler, + (void *)socket_slot); + + if (aws_channel_slot_set_handler(socket_slot, socket_channel_handler)) { + err_code = aws_last_error(); + goto error; + } + + if (connection_args->channel_data.use_tls) { + /* we don't want to notify the user that the channel is ready yet, since tls is still negotiating, wait + * for the negotiation callback and handle it then.*/ + if (s_setup_client_tls(connection_args, channel)) { + err_code = aws_last_error(); + goto error; + } + } else { + s_connection_args_setup_callback(connection_args, AWS_OP_SUCCESS, channel); + } + + return; + } + +error: + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p setup failed with error %d.", + (void *)connection_args->bootstrap, + (void *)channel, + err_code); + aws_channel_shutdown(channel, err_code); + /* the channel shutdown callback will clean the channel up */ +} + +static void s_on_client_channel_on_shutdown(struct aws_channel *channel, int error_code, void *user_data) { + struct client_connection_args *connection_args = user_data; + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p shutdown with error %d.", + (void *)connection_args->bootstrap, + (void *)channel, + error_code); + + /* note it's not safe to reference the bootstrap after the callback. */ + struct aws_allocator *allocator = connection_args->bootstrap->allocator; + s_connection_args_shutdown_callback(connection_args, error_code, channel); + + aws_channel_destroy(channel); + aws_socket_clean_up(connection_args->channel_data.socket); + aws_mem_release(allocator, connection_args->channel_data.socket); + s_client_connection_args_release(connection_args); +} + +static bool s_aws_socket_domain_uses_dns(enum aws_socket_domain domain) { + return domain == AWS_SOCKET_IPV4 || domain == AWS_SOCKET_IPV6; +} + +static void s_on_client_connection_established(struct aws_socket *socket, int error_code, void *user_data) { + struct client_connection_args *connection_args = user_data; + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: client connection on socket %p completed with error %d.", + (void *)connection_args->bootstrap, + (void *)socket, + error_code); + + if (error_code) { + connection_args->failed_count++; + } + + if (error_code || connection_args->connection_chosen) { + if (s_aws_socket_domain_uses_dns(connection_args->outgoing_options.domain) && error_code) { + struct aws_host_address host_address; + host_address.host = connection_args->host_name; + host_address.address = + aws_string_new_from_c_str(connection_args->bootstrap->allocator, socket->remote_endpoint.address); + host_address.record_type = connection_args->outgoing_options.domain == AWS_SOCKET_IPV6 + ? AWS_ADDRESS_RECORD_TYPE_AAAA + : AWS_ADDRESS_RECORD_TYPE_A; + + if (host_address.address) { + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: recording bad address %s.", + (void *)connection_args->bootstrap, + socket->remote_endpoint.address); + aws_host_resolver_record_connection_failure(connection_args->bootstrap->host_resolver, &host_address); + aws_string_destroy((void *)host_address.address); + } + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: releasing socket %p either because we already have a " + "successful connection or because it errored out.", + (void *)connection_args->bootstrap, + (void *)socket); + aws_socket_close(socket); + + aws_socket_clean_up(socket); + aws_mem_release(connection_args->bootstrap->allocator, socket); + + /* if this is the last attempted connection and it failed, notify the user */ + if (connection_args->failed_count == connection_args->addresses_count) { + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Connection failed with error_code %d.", + (void *)connection_args->bootstrap, + error_code); + /* connection_args will be released after setup_callback */ + s_connection_args_setup_callback(connection_args, error_code, NULL); + } + + /* every connection task adds a ref, so every failure or cancel needs to dec one */ + s_client_connection_args_release(connection_args); + return; + } + + connection_args->connection_chosen = true; + connection_args->channel_data.socket = socket; + + struct aws_channel_options args = { + .on_setup_completed = s_on_client_channel_on_setup_completed, + .setup_user_data = connection_args, + .shutdown_user_data = connection_args, + .on_shutdown_completed = s_on_client_channel_on_shutdown, + }; + + args.enable_read_back_pressure = connection_args->enable_read_back_pressure; + args.event_loop = aws_socket_get_event_loop(socket); + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Successful connection, creating a new channel using socket %p.", + (void *)connection_args->bootstrap, + (void *)socket); + + connection_args->channel_data.channel = aws_channel_new(connection_args->bootstrap->allocator, &args); + + if (!connection_args->channel_data.channel) { + aws_socket_clean_up(socket); + aws_mem_release(connection_args->bootstrap->allocator, connection_args->channel_data.socket); + connection_args->failed_count++; + + /* if this is the last attempted connection and it failed, notify the user */ + if (connection_args->failed_count == connection_args->addresses_count) { + s_connection_args_setup_callback(connection_args, error_code, NULL); + } + } else { + s_connection_args_creation_callback(connection_args, connection_args->channel_data.channel); + } +} + +struct connection_task_data { + struct aws_task task; + struct aws_socket_endpoint endpoint; + struct aws_socket_options options; + struct aws_host_address host_address; + struct client_connection_args *args; + struct aws_event_loop *connect_loop; +}; + +static void s_attempt_connection(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + struct connection_task_data *task_data = arg; + struct aws_allocator *allocator = task_data->args->bootstrap->allocator; + int err_code = 0; + + if (status != AWS_TASK_STATUS_RUN_READY) { + goto task_cancelled; + } + + struct aws_socket *outgoing_socket = aws_mem_acquire(allocator, sizeof(struct aws_socket)); + if (!outgoing_socket) { + goto socket_alloc_failed; + } + + if (aws_socket_init(outgoing_socket, allocator, &task_data->options)) { + goto socket_init_failed; + } + + if (aws_socket_connect( + outgoing_socket, + &task_data->endpoint, + task_data->connect_loop, + s_on_client_connection_established, + task_data->args)) { + + goto socket_connect_failed; + } + + goto cleanup_task; + +socket_connect_failed: + aws_host_resolver_record_connection_failure(task_data->args->bootstrap->host_resolver, &task_data->host_address); + aws_socket_clean_up(outgoing_socket); +socket_init_failed: + aws_mem_release(allocator, outgoing_socket); +socket_alloc_failed: + err_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: failed to create socket with error %d", + (void *)task_data->args->bootstrap, + err_code); +task_cancelled: + task_data->args->failed_count++; + /* if this is the last attempted connection and it failed, notify the user */ + if (task_data->args->failed_count == task_data->args->addresses_count) { + s_connection_args_setup_callback(task_data->args, err_code, NULL); + } + s_client_connection_args_release(task_data->args); + +cleanup_task: + aws_host_address_clean_up(&task_data->host_address); + aws_mem_release(allocator, task_data); +} + +static void s_on_host_resolved( + struct aws_host_resolver *resolver, + const struct aws_string *host_name, + int err_code, + const struct aws_array_list *host_addresses, + void *user_data) { + (void)resolver; + (void)host_name; + + struct client_connection_args *client_connection_args = user_data; + struct aws_allocator *allocator = client_connection_args->bootstrap->allocator; + + if (err_code) { + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: dns resolution failed, or all socket connections to the endpoint failed.", + (void *)client_connection_args->bootstrap); + s_connection_args_setup_callback(client_connection_args, err_code, NULL); + return; + } + + size_t host_addresses_len = aws_array_list_length(host_addresses); + AWS_FATAL_ASSERT(host_addresses_len > 0); + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: dns resolution completed. Kicking off connections" + " on %llu addresses. First one back wins.", + (void *)client_connection_args->bootstrap, + (unsigned long long)host_addresses_len); + /* use this event loop for all outgoing connection attempts (only one will ultimately win). */ + struct aws_event_loop *connect_loop = + aws_event_loop_group_get_next_loop(client_connection_args->bootstrap->event_loop_group); + client_connection_args->addresses_count = (uint8_t)host_addresses_len; + + /* allocate all the task data first, in case it fails... */ + AWS_VARIABLE_LENGTH_ARRAY(struct connection_task_data *, tasks, host_addresses_len); + for (size_t i = 0; i < host_addresses_len; ++i) { + struct connection_task_data *task_data = tasks[i] = + aws_mem_calloc(allocator, 1, sizeof(struct connection_task_data)); + bool failed = task_data == NULL; + if (!failed) { + struct aws_host_address *host_address_ptr = NULL; + aws_array_list_get_at_ptr(host_addresses, (void **)&host_address_ptr, i); + + task_data->endpoint.port = client_connection_args->outgoing_port; + AWS_ASSERT(sizeof(task_data->endpoint.address) >= host_address_ptr->address->len + 1); + memcpy( + task_data->endpoint.address, + aws_string_bytes(host_address_ptr->address), + host_address_ptr->address->len); + task_data->endpoint.address[host_address_ptr->address->len] = 0; + + task_data->options = client_connection_args->outgoing_options; + task_data->options.domain = + host_address_ptr->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA ? AWS_SOCKET_IPV6 : AWS_SOCKET_IPV4; + + failed = aws_host_address_copy(host_address_ptr, &task_data->host_address) != AWS_OP_SUCCESS; + task_data->args = client_connection_args; + task_data->connect_loop = connect_loop; + } + + if (failed) { + for (size_t j = 0; j <= i; ++j) { + if (tasks[j]) { + aws_host_address_clean_up(&tasks[j]->host_address); + aws_mem_release(allocator, tasks[j]); + } + } + int alloc_err_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: failed to allocate connection task data: err=%d", + (void *)client_connection_args->bootstrap, + alloc_err_code); + s_connection_args_setup_callback(client_connection_args, alloc_err_code, NULL); + return; + } + } + + /* ...then schedule all the tasks, which cannot fail */ + for (size_t i = 0; i < host_addresses_len; ++i) { + struct connection_task_data *task_data = tasks[i]; + /* each task needs to hold a ref to the args until completed */ + s_client_connection_args_acquire(task_data->args); + + aws_task_init(&task_data->task, s_attempt_connection, task_data, "attempt_connection"); + aws_event_loop_schedule_task_now(connect_loop, &task_data->task); + } +} + +int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_options *options) { + + struct aws_client_bootstrap *bootstrap = options->bootstrap; + AWS_FATAL_ASSERT(options->setup_callback); + AWS_FATAL_ASSERT(options->shutdown_callback); + AWS_FATAL_ASSERT(bootstrap); + + const struct aws_socket_options *socket_options = options->socket_options; + AWS_FATAL_ASSERT(socket_options != NULL); + + const struct aws_tls_connection_options *tls_options = options->tls_options; + + AWS_FATAL_ASSERT(tls_options == NULL || socket_options->type == AWS_SOCKET_STREAM); + aws_io_fatal_assert_library_initialized(); + + struct client_connection_args *client_connection_args = + aws_mem_calloc(bootstrap->allocator, 1, sizeof(struct client_connection_args)); + + if (!client_connection_args) { + return AWS_OP_ERR; + } + + const char *host_name = options->host_name; + uint16_t port = options->port; + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: attempting to initialize a new client channel to %s:%d", + (void *)bootstrap, + host_name, + (int)port); + + aws_ref_count_init( + &client_connection_args->ref_count, + client_connection_args, + (aws_simple_completion_callback *)s_client_connection_args_destroy); + client_connection_args->user_data = options->user_data; + client_connection_args->bootstrap = aws_client_bootstrap_acquire(bootstrap); + client_connection_args->creation_callback = options->creation_callback; + client_connection_args->setup_callback = options->setup_callback; + client_connection_args->shutdown_callback = options->shutdown_callback; + client_connection_args->outgoing_options = *socket_options; + client_connection_args->outgoing_port = port; + client_connection_args->enable_read_back_pressure = options->enable_read_back_pressure; + + if (tls_options) { + if (aws_tls_connection_options_copy(&client_connection_args->channel_data.tls_options, tls_options)) { + goto error; + } + client_connection_args->channel_data.use_tls = true; + + client_connection_args->channel_data.on_protocol_negotiated = bootstrap->on_protocol_negotiated; + client_connection_args->channel_data.tls_user_data = tls_options->user_data; + + /* in order to honor any callbacks a user may have installed on their tls_connection_options, + * we need to wrap them if they were set.*/ + if (bootstrap->on_protocol_negotiated) { + client_connection_args->channel_data.tls_options.advertise_alpn_message = true; + } + + if (tls_options->on_data_read) { + client_connection_args->channel_data.user_on_data_read = tls_options->on_data_read; + client_connection_args->channel_data.tls_options.on_data_read = s_tls_client_on_data_read; + } + + if (tls_options->on_error) { + client_connection_args->channel_data.user_on_error = tls_options->on_error; + client_connection_args->channel_data.tls_options.on_error = s_tls_client_on_error; + } + + if (tls_options->on_negotiation_result) { + client_connection_args->channel_data.user_on_negotiation_result = tls_options->on_negotiation_result; + } + + client_connection_args->channel_data.tls_options.on_negotiation_result = s_tls_client_on_negotiation_result; + client_connection_args->channel_data.tls_options.user_data = client_connection_args; + } + + if (s_aws_socket_domain_uses_dns(socket_options->domain)) { + client_connection_args->host_name = aws_string_new_from_c_str(bootstrap->allocator, host_name); + + if (!client_connection_args->host_name) { + goto error; + } + + if (aws_host_resolver_resolve_host( + bootstrap->host_resolver, + client_connection_args->host_name, + s_on_host_resolved, + &bootstrap->host_resolver_config, + client_connection_args)) { + goto error; + } + } else { + /* ensure that the pipe/domain socket name will fit in the endpoint address */ + const size_t host_name_len = strlen(host_name); + if (host_name_len >= AWS_ADDRESS_MAX_LEN) { + aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); + goto error; + } + + struct aws_socket_endpoint endpoint; + AWS_ZERO_STRUCT(endpoint); + memcpy(endpoint.address, host_name, host_name_len); + if (socket_options->domain == AWS_SOCKET_VSOCK) { + endpoint.port = port; + } else { + endpoint.port = 0; + } + + struct aws_socket *outgoing_socket = aws_mem_acquire(bootstrap->allocator, sizeof(struct aws_socket)); + + if (!outgoing_socket) { + goto error; + } + + if (aws_socket_init(outgoing_socket, bootstrap->allocator, socket_options)) { + aws_mem_release(bootstrap->allocator, outgoing_socket); + goto error; + } + + client_connection_args->addresses_count = 1; + + struct aws_event_loop *connect_loop = aws_event_loop_group_get_next_loop(bootstrap->event_loop_group); + + s_client_connection_args_acquire(client_connection_args); + if (aws_socket_connect( + outgoing_socket, &endpoint, connect_loop, s_on_client_connection_established, client_connection_args)) { + aws_socket_clean_up(outgoing_socket); + aws_mem_release(client_connection_args->bootstrap->allocator, outgoing_socket); + s_client_connection_args_release(client_connection_args); + goto error; + } + } + + return AWS_OP_SUCCESS; + +error: + if (client_connection_args) { + /* tls opt will also be freed when we clean up the connection arg */ + s_client_connection_args_release(client_connection_args); + } + return AWS_OP_ERR; +} + +void s_server_bootstrap_destroy_impl(struct aws_server_bootstrap *bootstrap) { + AWS_ASSERT(bootstrap); + aws_event_loop_group_release(bootstrap->event_loop_group); + aws_mem_release(bootstrap->allocator, bootstrap); +} + +struct aws_server_bootstrap *aws_server_bootstrap_acquire(struct aws_server_bootstrap *bootstrap) { + if (bootstrap != NULL) { + aws_ref_count_acquire(&bootstrap->ref_count); + } + + return bootstrap; +} + +void aws_server_bootstrap_release(struct aws_server_bootstrap *bootstrap) { + /* if destroy is being called, the user intends to not use the bootstrap anymore + * so we clean up the thread local state while the event loop thread is + * still alive */ + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing server bootstrap reference", (void *)bootstrap); + if (bootstrap != NULL) { + aws_ref_count_release(&bootstrap->ref_count); + } +} + +struct aws_server_bootstrap *aws_server_bootstrap_new( + struct aws_allocator *allocator, + struct aws_event_loop_group *el_group) { + AWS_ASSERT(allocator); + AWS_ASSERT(el_group); + + struct aws_server_bootstrap *bootstrap = aws_mem_calloc(allocator, 1, sizeof(struct aws_server_bootstrap)); + if (!bootstrap) { + return NULL; + } + + AWS_LOGF_INFO( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Initializing server bootstrap with event-loop group %p", + (void *)bootstrap, + (void *)el_group); + + bootstrap->allocator = allocator; + bootstrap->event_loop_group = aws_event_loop_group_acquire(el_group); + bootstrap->on_protocol_negotiated = NULL; + aws_ref_count_init( + &bootstrap->ref_count, bootstrap, (aws_simple_completion_callback *)s_server_bootstrap_destroy_impl); + + return bootstrap; +} + +struct server_connection_args { + struct aws_server_bootstrap *bootstrap; + struct aws_socket listener; + aws_server_bootstrap_on_accept_channel_setup_fn *incoming_callback; + aws_server_bootstrap_on_accept_channel_shutdown_fn *shutdown_callback; + aws_server_bootstrap_on_server_listener_destroy_fn *destroy_callback; + struct aws_tls_connection_options tls_options; + aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated; + aws_tls_on_data_read_fn *user_on_data_read; + aws_tls_on_negotiation_result_fn *user_on_negotiation_result; + aws_tls_on_error_fn *user_on_error; + struct aws_task listener_destroy_task; + void *tls_user_data; + void *user_data; + bool use_tls; + bool enable_read_back_pressure; + struct aws_ref_count ref_count; +}; + +struct server_channel_data { + struct aws_channel *channel; + struct aws_socket *socket; + struct server_connection_args *server_connection_args; + bool incoming_called; +}; + +static struct server_connection_args *s_server_connection_args_acquire(struct server_connection_args *args) { + if (args != NULL) { + aws_ref_count_acquire(&args->ref_count); + } + + return args; +} + +static void s_server_connection_args_destroy(struct server_connection_args *args) { + if (args == NULL) { + return; + } + + /* fire the destroy callback */ + if (args->destroy_callback) { + args->destroy_callback(args->bootstrap, args->user_data); + } + + struct aws_allocator *allocator = args->bootstrap->allocator; + aws_server_bootstrap_release(args->bootstrap); + if (args->use_tls) { + aws_tls_connection_options_clean_up(&args->tls_options); + } + + aws_mem_release(allocator, args); +} + +static void s_server_connection_args_release(struct server_connection_args *args) { + if (args != NULL) { + aws_ref_count_release(&args->ref_count); + } +} + +static void s_server_incoming_callback( + struct server_channel_data *channel_data, + int error_code, + struct aws_channel *channel) { + /* incoming_callback is always called exactly once for each channel */ + AWS_ASSERT(!channel_data->incoming_called); + struct server_connection_args *args = channel_data->server_connection_args; + args->incoming_callback(args->bootstrap, error_code, channel, args->user_data); + channel_data->incoming_called = true; +} + +static void s_tls_server_on_negotiation_result( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int err_code, + void *user_data) { + struct server_channel_data *channel_data = user_data; + struct server_connection_args *connection_args = channel_data->server_connection_args; + + if (connection_args->user_on_negotiation_result) { + connection_args->user_on_negotiation_result(handler, slot, err_code, connection_args->tls_user_data); + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: tls negotiation result %d on channel %p", + (void *)connection_args->bootstrap, + err_code, + (void *)slot->channel); + + struct aws_channel *channel = slot->channel; + if (err_code) { + /* shut down the channel */ + aws_channel_shutdown(channel, err_code); + } else { + s_server_incoming_callback(channel_data, err_code, channel); + } +} + +/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to + * provide a proxy for the user actually receiving their callbacks. */ +static void s_tls_server_on_data_read( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_byte_buf *buffer, + void *user_data) { + struct server_connection_args *connection_args = user_data; + + if (connection_args->user_on_data_read) { + connection_args->user_on_data_read(handler, slot, buffer, connection_args->tls_user_data); + } +} + +/* in the context of a channel bootstrap, we don't care about these, but since we're hooking into these APIs we have to + * provide a proxy for the user actually receiving their callbacks. */ +static void s_tls_server_on_error( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int err, + const char *message, + void *user_data) { + struct server_connection_args *connection_args = user_data; + + if (connection_args->user_on_error) { + connection_args->user_on_error(handler, slot, err, message, connection_args->tls_user_data); + } +} + +static inline int s_setup_server_tls(struct server_channel_data *channel_data, struct aws_channel *channel) { + struct aws_channel_slot *tls_slot = NULL; + struct aws_channel_handler *tls_handler = NULL; + struct server_connection_args *connection_args = channel_data->server_connection_args; + + /* as far as cleanup goes here, since we're adding things to a channel, if a slot is ever successfully + added to the channel, we leave it there. The caller will clean up the channel and it will clean this memory + up as well. */ + tls_slot = aws_channel_slot_new(channel); + + if (!tls_slot) { + return AWS_OP_ERR; + } + + /* Shallow-copy tls_options so we can override the user_data, making it specific to this channel */ + struct aws_tls_connection_options tls_options = connection_args->tls_options; + tls_options.user_data = channel_data; + tls_handler = aws_tls_server_handler_new(connection_args->bootstrap->allocator, &tls_options, tls_slot); + + if (!tls_handler) { + aws_mem_release(connection_args->bootstrap->allocator, tls_slot); + return AWS_OP_ERR; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up server TLS on channel %p with handler %p on slot %p", + (void *)connection_args->bootstrap, + (void *)channel, + (void *)tls_handler, + (void *)tls_slot); + + aws_channel_slot_insert_end(channel, tls_slot); + + if (aws_channel_slot_set_handler(tls_slot, tls_handler)) { + return AWS_OP_ERR; + } + + if (connection_args->on_protocol_negotiated) { + struct aws_channel_slot *alpn_slot = NULL; + struct aws_channel_handler *alpn_handler = NULL; + alpn_slot = aws_channel_slot_new(channel); + + if (!alpn_slot) { + return AWS_OP_ERR; + } + + alpn_handler = aws_tls_alpn_handler_new( + connection_args->bootstrap->allocator, connection_args->on_protocol_negotiated, connection_args->user_data); + + if (!alpn_handler) { + aws_channel_slot_remove(alpn_slot); + return AWS_OP_ERR; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up ALPN handler on channel " + "%p with handler %p on slot %p", + (void *)connection_args->bootstrap, + (void *)channel, + (void *)alpn_handler, + (void *)alpn_slot); + + aws_channel_slot_insert_right(tls_slot, alpn_slot); + + if (aws_channel_slot_set_handler(alpn_slot, alpn_handler)) { + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + +static void s_on_server_channel_on_setup_completed(struct aws_channel *channel, int error_code, void *user_data) { + struct server_channel_data *channel_data = user_data; + + int err_code = error_code; + if (err_code) { + /* channel fail to set up no destroy callback will fire */ + AWS_LOGF_ERROR( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p setup failed with error %d.", + (void *)channel_data->server_connection_args->bootstrap, + (void *)channel, + err_code); + + aws_channel_destroy(channel); + struct aws_allocator *allocator = channel_data->socket->allocator; + aws_socket_clean_up(channel_data->socket); + aws_mem_release(allocator, (void *)channel_data->socket); + s_server_incoming_callback(channel_data, err_code, NULL); + aws_mem_release(channel_data->server_connection_args->bootstrap->allocator, channel_data); + /* no shutdown call back will be fired, we release the ref_count of connection arg here */ + s_server_connection_args_release(channel_data->server_connection_args); + return; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p setup succeeded: bootstrapping.", + (void *)channel_data->server_connection_args->bootstrap, + (void *)channel); + + struct aws_channel_slot *socket_slot = aws_channel_slot_new(channel); + + if (!socket_slot) { + err_code = aws_last_error(); + goto error; + } + + struct aws_channel_handler *socket_channel_handler = aws_socket_handler_new( + channel_data->server_connection_args->bootstrap->allocator, + channel_data->socket, + socket_slot, + g_aws_channel_max_fragment_size); + + if (!socket_channel_handler) { + err_code = aws_last_error(); + aws_channel_slot_remove(socket_slot); + socket_slot = NULL; + goto error; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: Setting up socket handler on channel " + "%p with handler %p on slot %p.", + (void *)channel_data->server_connection_args->bootstrap, + (void *)channel, + (void *)socket_channel_handler, + (void *)socket_slot); + + if (aws_channel_slot_set_handler(socket_slot, socket_channel_handler)) { + err_code = aws_last_error(); + goto error; + } + + if (channel_data->server_connection_args->use_tls) { + /* incoming callback will be invoked upon the negotiation completion so don't do it + * here. */ + if (s_setup_server_tls(channel_data, channel)) { + err_code = aws_last_error(); + goto error; + } + } else { + s_server_incoming_callback(channel_data, AWS_OP_SUCCESS, channel); + } + return; + +error: + /* shut down the channel */ + aws_channel_shutdown(channel, err_code); +} + +static void s_on_server_channel_on_shutdown(struct aws_channel *channel, int error_code, void *user_data) { + struct server_channel_data *channel_data = user_data; + struct server_connection_args *args = channel_data->server_connection_args; + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: channel %p shutdown with error %d.", + (void *)args->bootstrap, + (void *)channel, + error_code); + + void *server_shutdown_user_data = args->user_data; + struct aws_server_bootstrap *server_bootstrap = args->bootstrap; + struct aws_allocator *allocator = server_bootstrap->allocator; + + if (!channel_data->incoming_called) { + error_code = (error_code) ? error_code : AWS_ERROR_UNKNOWN; + s_server_incoming_callback(channel_data, error_code, NULL); + } else { + args->shutdown_callback(server_bootstrap, error_code, channel, server_shutdown_user_data); + } + + aws_channel_destroy(channel); + aws_socket_clean_up(channel_data->socket); + aws_mem_release(allocator, channel_data->socket); + s_server_connection_args_release(channel_data->server_connection_args); + + aws_mem_release(allocator, channel_data); +} + +void s_on_server_connection_result( + struct aws_socket *socket, + int error_code, + struct aws_socket *new_socket, + void *user_data) { + (void)socket; + struct server_connection_args *connection_args = user_data; + + s_server_connection_args_acquire(connection_args); + AWS_LOGF_DEBUG( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: server connection on socket %p completed with error %d.", + (void *)connection_args->bootstrap, + (void *)socket, + error_code); + + if (!error_code) { + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: creating a new channel for incoming " + "connection using socket %p.", + (void *)connection_args->bootstrap, + (void *)socket); + struct server_channel_data *channel_data = + aws_mem_calloc(connection_args->bootstrap->allocator, 1, sizeof(struct server_channel_data)); + if (!channel_data) { + goto error_cleanup; + } + channel_data->incoming_called = false; + channel_data->socket = new_socket; + channel_data->server_connection_args = connection_args; + + struct aws_event_loop *event_loop = + aws_event_loop_group_get_next_loop(connection_args->bootstrap->event_loop_group); + + struct aws_channel_options channel_args = { + .on_setup_completed = s_on_server_channel_on_setup_completed, + .setup_user_data = channel_data, + .shutdown_user_data = channel_data, + .on_shutdown_completed = s_on_server_channel_on_shutdown, + }; + + channel_args.event_loop = event_loop; + channel_args.enable_read_back_pressure = channel_data->server_connection_args->enable_read_back_pressure; + + if (aws_socket_assign_to_event_loop(new_socket, event_loop)) { + aws_mem_release(connection_args->bootstrap->allocator, (void *)channel_data); + goto error_cleanup; + } + + channel_data->channel = aws_channel_new(connection_args->bootstrap->allocator, &channel_args); + + if (!channel_data->channel) { + aws_mem_release(connection_args->bootstrap->allocator, (void *)channel_data); + goto error_cleanup; + } + } else { + /* no channel is created */ + connection_args->incoming_callback(connection_args->bootstrap, error_code, NULL, connection_args->user_data); + s_server_connection_args_release(connection_args); + } + + return; + +error_cleanup: + /* no channel is created */ + connection_args->incoming_callback(connection_args->bootstrap, aws_last_error(), NULL, connection_args->user_data); + + struct aws_allocator *allocator = new_socket->allocator; + aws_socket_clean_up(new_socket); + aws_mem_release(allocator, (void *)new_socket); + s_server_connection_args_release(connection_args); +} + +static void s_listener_destroy_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)status; + (void)task; + struct server_connection_args *server_connection_args = arg; + + aws_socket_stop_accept(&server_connection_args->listener); + aws_socket_clean_up(&server_connection_args->listener); + s_server_connection_args_release(server_connection_args); +} + +struct aws_socket *aws_server_bootstrap_new_socket_listener( + const struct aws_server_socket_channel_bootstrap_options *bootstrap_options) { + AWS_PRECONDITION(bootstrap_options); + AWS_PRECONDITION(bootstrap_options->bootstrap); + AWS_PRECONDITION(bootstrap_options->incoming_callback) + AWS_PRECONDITION(bootstrap_options->shutdown_callback) + + struct server_connection_args *server_connection_args = + aws_mem_calloc(bootstrap_options->bootstrap->allocator, 1, sizeof(struct server_connection_args)); + if (!server_connection_args) { + return NULL; + } + + AWS_LOGF_INFO( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "id=%p: attempting to initialize a new " + "server socket listener for %s:%d", + (void *)server_connection_args->bootstrap, + bootstrap_options->host_name, + (int)bootstrap_options->port); + + aws_ref_count_init( + &server_connection_args->ref_count, + server_connection_args, + (aws_simple_completion_callback *)s_server_connection_args_destroy); + server_connection_args->user_data = bootstrap_options->user_data; + server_connection_args->bootstrap = aws_server_bootstrap_acquire(bootstrap_options->bootstrap); + server_connection_args->shutdown_callback = bootstrap_options->shutdown_callback; + server_connection_args->incoming_callback = bootstrap_options->incoming_callback; + server_connection_args->destroy_callback = bootstrap_options->destroy_callback; + server_connection_args->on_protocol_negotiated = bootstrap_options->bootstrap->on_protocol_negotiated; + server_connection_args->enable_read_back_pressure = bootstrap_options->enable_read_back_pressure; + + aws_task_init( + &server_connection_args->listener_destroy_task, + s_listener_destroy_task, + server_connection_args, + "listener socket destroy"); + + if (bootstrap_options->tls_options) { + AWS_LOGF_INFO( + AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: using tls on listener", (void *)bootstrap_options->tls_options); + if (aws_tls_connection_options_copy(&server_connection_args->tls_options, bootstrap_options->tls_options)) { + goto cleanup_server_connection_args; + } + + server_connection_args->use_tls = true; + + server_connection_args->tls_user_data = bootstrap_options->tls_options->user_data; + + /* in order to honor any callbacks a user may have installed on their tls_connection_options, + * we need to wrap them if they were set.*/ + if (bootstrap_options->bootstrap->on_protocol_negotiated) { + server_connection_args->tls_options.advertise_alpn_message = true; + } + + if (bootstrap_options->tls_options->on_data_read) { + server_connection_args->user_on_data_read = bootstrap_options->tls_options->on_data_read; + server_connection_args->tls_options.on_data_read = s_tls_server_on_data_read; + } + + if (bootstrap_options->tls_options->on_error) { + server_connection_args->user_on_error = bootstrap_options->tls_options->on_error; + server_connection_args->tls_options.on_error = s_tls_server_on_error; + } + + if (bootstrap_options->tls_options->on_negotiation_result) { + server_connection_args->user_on_negotiation_result = bootstrap_options->tls_options->on_negotiation_result; + } + + server_connection_args->tls_options.on_negotiation_result = s_tls_server_on_negotiation_result; + server_connection_args->tls_options.user_data = server_connection_args; + } + + struct aws_event_loop *connection_loop = + aws_event_loop_group_get_next_loop(bootstrap_options->bootstrap->event_loop_group); + + if (aws_socket_init( + &server_connection_args->listener, + bootstrap_options->bootstrap->allocator, + bootstrap_options->socket_options)) { + goto cleanup_server_connection_args; + } + + struct aws_socket_endpoint endpoint; + AWS_ZERO_STRUCT(endpoint); + size_t host_name_len = 0; + if (aws_secure_strlen(bootstrap_options->host_name, sizeof(endpoint.address), &host_name_len)) { + goto cleanup_server_connection_args; + } + + memcpy(endpoint.address, bootstrap_options->host_name, host_name_len); + endpoint.port = bootstrap_options->port; + + if (aws_socket_bind(&server_connection_args->listener, &endpoint)) { + goto cleanup_listener; + } + + if (aws_socket_listen(&server_connection_args->listener, 1024)) { + goto cleanup_listener; + } + + if (aws_socket_start_accept( + &server_connection_args->listener, + connection_loop, + s_on_server_connection_result, + server_connection_args)) { + goto cleanup_listener; + } + + return &server_connection_args->listener; + +cleanup_listener: + aws_socket_clean_up(&server_connection_args->listener); + +cleanup_server_connection_args: + s_server_connection_args_release(server_connection_args); + + return NULL; +} + +void aws_server_bootstrap_destroy_socket_listener(struct aws_server_bootstrap *bootstrap, struct aws_socket *listener) { + struct server_connection_args *server_connection_args = + AWS_CONTAINER_OF(listener, struct server_connection_args, listener); + + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: releasing bootstrap reference", (void *)bootstrap); + aws_event_loop_schedule_task_now(listener->event_loop, &server_connection_args->listener_destroy_task); +} + +int aws_server_bootstrap_set_alpn_callback( + struct aws_server_bootstrap *bootstrap, + aws_channel_on_protocol_negotiated_fn *on_protocol_negotiated) { + AWS_ASSERT(on_protocol_negotiated); + AWS_LOGF_DEBUG(AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: Setting ALPN callback", (void *)bootstrap); + bootstrap->on_protocol_negotiated = on_protocol_negotiated; + return AWS_OP_SUCCESS; +} diff --git a/contrib/restricted/aws/aws-c-io/source/event_loop.c b/contrib/restricted/aws/aws-c-io/source/event_loop.c index 0b4a41b374..18af6e9fac 100644 --- a/contrib/restricted/aws/aws-c-io/source/event_loop.c +++ b/contrib/restricted/aws/aws-c-io/source/event_loop.c @@ -1,367 +1,367 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/event_loop.h> - -#include <aws/common/clock.h> -#include <aws/common/system_info.h> -#include <aws/common/thread.h> - -static void s_event_loop_group_thread_exit(void *user_data) { - struct aws_event_loop_group *el_group = user_data; - - aws_simple_completion_callback *completion_callback = el_group->shutdown_options.shutdown_callback_fn; - void *completion_user_data = el_group->shutdown_options.shutdown_callback_user_data; - - aws_mem_release(el_group->allocator, el_group); - - if (completion_callback != NULL) { - completion_callback(completion_user_data); - } - - aws_global_thread_creator_decrement(); -} - -static void s_aws_event_loop_group_shutdown_sync(struct aws_event_loop_group *el_group) { - while (aws_array_list_length(&el_group->event_loops) > 0) { - struct aws_event_loop *loop = NULL; - - if (!aws_array_list_back(&el_group->event_loops, &loop)) { - aws_event_loop_destroy(loop); - } - - aws_array_list_pop_back(&el_group->event_loops); - } - - aws_array_list_clean_up(&el_group->event_loops); -} - -static void s_event_loop_destroy_async_thread_fn(void *thread_data) { - struct aws_event_loop_group *el_group = thread_data; - - s_aws_event_loop_group_shutdown_sync(el_group); - - aws_thread_current_at_exit(s_event_loop_group_thread_exit, el_group); -} - -static void s_aws_event_loop_group_shutdown_async(struct aws_event_loop_group *el_group) { - - /* It's possible that the last refcount was released on an event-loop thread, - * so we would deadlock if we waited here for all the event-loop threads to shut down. - * Therefore, we spawn a NEW thread and have it wait for all the event-loop threads to shut down - */ - struct aws_thread cleanup_thread; - AWS_ZERO_STRUCT(cleanup_thread); - - AWS_FATAL_ASSERT(aws_thread_init(&cleanup_thread, el_group->allocator) == AWS_OP_SUCCESS); - - struct aws_thread_options thread_options; - AWS_ZERO_STRUCT(thread_options); - - AWS_FATAL_ASSERT( - aws_thread_launch(&cleanup_thread, s_event_loop_destroy_async_thread_fn, el_group, &thread_options) == - AWS_OP_SUCCESS); - - aws_thread_clean_up(&cleanup_thread); -} - -struct aws_event_loop_group *aws_event_loop_group_new( - struct aws_allocator *alloc, - aws_io_clock_fn *clock, - uint16_t el_count, - aws_new_event_loop_fn *new_loop_fn, - void *new_loop_user_data, - const struct aws_shutdown_callback_options *shutdown_options) { - - AWS_ASSERT(new_loop_fn); - - struct aws_event_loop_group *el_group = aws_mem_calloc(alloc, 1, sizeof(struct aws_event_loop_group)); - if (el_group == NULL) { - return NULL; - } - - el_group->allocator = alloc; - aws_ref_count_init( - &el_group->ref_count, el_group, (aws_simple_completion_callback *)s_aws_event_loop_group_shutdown_async); - aws_atomic_init_int(&el_group->current_index, 0); - - if (aws_array_list_init_dynamic(&el_group->event_loops, alloc, el_count, sizeof(struct aws_event_loop *))) { - goto on_error; - } - - for (uint16_t i = 0; i < el_count; ++i) { - struct aws_event_loop *loop = new_loop_fn(alloc, clock, new_loop_user_data); - - if (!loop) { - goto on_error; - } - - if (aws_array_list_push_back(&el_group->event_loops, (const void *)&loop)) { - aws_event_loop_destroy(loop); - goto on_error; - } - - if (aws_event_loop_run(loop)) { - goto on_error; - } - } - - if (shutdown_options != NULL) { - el_group->shutdown_options = *shutdown_options; - } - - aws_global_thread_creator_increment(); - - return el_group; - -on_error: - - s_aws_event_loop_group_shutdown_sync(el_group); - s_event_loop_group_thread_exit(el_group); - - return NULL; -} - -static struct aws_event_loop *default_new_event_loop( - struct aws_allocator *allocator, - aws_io_clock_fn *clock, - void *user_data) { - - (void)user_data; - return aws_event_loop_new_default(allocator, clock); -} - -struct aws_event_loop_group *aws_event_loop_group_new_default( - struct aws_allocator *alloc, - uint16_t max_threads, - const struct aws_shutdown_callback_options *shutdown_options) { - if (!max_threads) { - max_threads = (uint16_t)aws_system_info_processor_count(); - } - - return aws_event_loop_group_new( - alloc, aws_high_res_clock_get_ticks, max_threads, default_new_event_loop, NULL, shutdown_options); -} - -struct aws_event_loop_group *aws_event_loop_group_acquire(struct aws_event_loop_group *el_group) { - if (el_group != NULL) { - aws_ref_count_acquire(&el_group->ref_count); - } - - return el_group; -} - -void aws_event_loop_group_release(struct aws_event_loop_group *el_group) { - if (el_group != NULL) { - aws_ref_count_release(&el_group->ref_count); - } -} - -size_t aws_event_loop_group_get_loop_count(struct aws_event_loop_group *el_group) { - return aws_array_list_length(&el_group->event_loops); -} - -struct aws_event_loop *aws_event_loop_group_get_loop_at(struct aws_event_loop_group *el_group, size_t index) { - struct aws_event_loop *el = NULL; - aws_array_list_get_at(&el_group->event_loops, &el, index); - return el; -} - -struct aws_event_loop *aws_event_loop_group_get_next_loop(struct aws_event_loop_group *el_group) { - size_t loop_count = aws_array_list_length(&el_group->event_loops); - AWS_ASSERT(loop_count > 0); - if (loop_count == 0) { - return NULL; - } - - /* thread safety: atomic CAS to ensure we got the best loop, and that the index is within bounds */ - size_t old_index = 0; - size_t new_index = 0; - do { - old_index = aws_atomic_load_int(&el_group->current_index); - new_index = (old_index + 1) % loop_count; - } while (!aws_atomic_compare_exchange_int(&el_group->current_index, &old_index, new_index)); - - struct aws_event_loop *loop = NULL; - - /* if the fetch fails, we don't really care since loop will be NULL and error code will already be set. */ - aws_array_list_get_at(&el_group->event_loops, &loop, old_index); - return loop; -} - -static void s_object_removed(void *value) { - struct aws_event_loop_local_object *object = (struct aws_event_loop_local_object *)value; - if (object->on_object_removed) { - object->on_object_removed(object); - } -} - -int aws_event_loop_init_base(struct aws_event_loop *event_loop, struct aws_allocator *alloc, aws_io_clock_fn *clock) { - AWS_ZERO_STRUCT(*event_loop); - - event_loop->alloc = alloc; - event_loop->clock = clock; - - if (aws_hash_table_init(&event_loop->local_data, alloc, 20, aws_hash_ptr, aws_ptr_eq, NULL, s_object_removed)) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -void aws_event_loop_clean_up_base(struct aws_event_loop *event_loop) { - aws_hash_table_clean_up(&event_loop->local_data); -} - -void aws_event_loop_destroy(struct aws_event_loop *event_loop) { - if (!event_loop) { - return; - } - - AWS_ASSERT(event_loop->vtable && event_loop->vtable->destroy); - AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(event_loop)); - - event_loop->vtable->destroy(event_loop); -} - -int aws_event_loop_fetch_local_object( - struct aws_event_loop *event_loop, - void *key, - struct aws_event_loop_local_object *obj) { - - AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); - - struct aws_hash_element *object = NULL; - if (!aws_hash_table_find(&event_loop->local_data, key, &object) && object) { - *obj = *(struct aws_event_loop_local_object *)object->value; - return AWS_OP_SUCCESS; - } - - return AWS_OP_ERR; -} - -int aws_event_loop_put_local_object(struct aws_event_loop *event_loop, struct aws_event_loop_local_object *obj) { - AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); - - struct aws_hash_element *object = NULL; - int was_created = 0; - - if (!aws_hash_table_create(&event_loop->local_data, obj->key, &object, &was_created)) { - object->key = obj->key; - object->value = obj; - return AWS_OP_SUCCESS; - } - - return AWS_OP_ERR; -} - -int aws_event_loop_remove_local_object( - struct aws_event_loop *event_loop, - void *key, - struct aws_event_loop_local_object *removed_obj) { - - AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); - - struct aws_hash_element existing_object; - AWS_ZERO_STRUCT(existing_object); - - int was_present = 0; - - struct aws_hash_element *remove_candidate = removed_obj ? &existing_object : NULL; - - if (!aws_hash_table_remove(&event_loop->local_data, key, remove_candidate, &was_present)) { - if (remove_candidate && was_present) { - *removed_obj = *(struct aws_event_loop_local_object *)existing_object.value; - } - - return AWS_OP_SUCCESS; - } - - return AWS_OP_ERR; -} - -int aws_event_loop_run(struct aws_event_loop *event_loop) { - AWS_ASSERT(event_loop->vtable && event_loop->vtable->run); - return event_loop->vtable->run(event_loop); -} - -int aws_event_loop_stop(struct aws_event_loop *event_loop) { - AWS_ASSERT(event_loop->vtable && event_loop->vtable->stop); - return event_loop->vtable->stop(event_loop); -} - -int aws_event_loop_wait_for_stop_completion(struct aws_event_loop *event_loop) { - AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(event_loop)); - AWS_ASSERT(event_loop->vtable && event_loop->vtable->wait_for_stop_completion); - return event_loop->vtable->wait_for_stop_completion(event_loop); -} - -void aws_event_loop_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { - AWS_ASSERT(event_loop->vtable && event_loop->vtable->schedule_task_now); - AWS_ASSERT(task); - event_loop->vtable->schedule_task_now(event_loop, task); -} - -void aws_event_loop_schedule_task_future( - struct aws_event_loop *event_loop, - struct aws_task *task, - uint64_t run_at_nanos) { - - AWS_ASSERT(event_loop->vtable && event_loop->vtable->schedule_task_future); - AWS_ASSERT(task); - event_loop->vtable->schedule_task_future(event_loop, task, run_at_nanos); -} - -void aws_event_loop_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { - AWS_ASSERT(event_loop->vtable && event_loop->vtable->cancel_task); - AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); - AWS_ASSERT(task); - event_loop->vtable->cancel_task(event_loop, task); -} - -#if AWS_USE_IO_COMPLETION_PORTS - -int aws_event_loop_connect_handle_to_io_completion_port( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle) { - - AWS_ASSERT(event_loop->vtable && event_loop->vtable->connect_to_io_completion_port); - return event_loop->vtable->connect_to_io_completion_port(event_loop, handle); -} - -#else /* !AWS_USE_IO_COMPLETION_PORTS */ - -int aws_event_loop_subscribe_to_io_events( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - aws_event_loop_on_event_fn *on_event, - void *user_data) { - - AWS_ASSERT(event_loop->vtable && event_loop->vtable->subscribe_to_io_events); - return event_loop->vtable->subscribe_to_io_events(event_loop, handle, events, on_event, user_data); -} -#endif /* AWS_USE_IO_COMPLETION_PORTS */ - -int aws_event_loop_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { - AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); - AWS_ASSERT(event_loop->vtable && event_loop->vtable->unsubscribe_from_io_events); - return event_loop->vtable->unsubscribe_from_io_events(event_loop, handle); -} - -void aws_event_loop_free_io_event_resources(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { - AWS_ASSERT(event_loop && event_loop->vtable->free_io_event_resources); - event_loop->vtable->free_io_event_resources(handle->additional_data); -} - -bool aws_event_loop_thread_is_callers_thread(struct aws_event_loop *event_loop) { - AWS_ASSERT(event_loop->vtable && event_loop->vtable->is_on_callers_thread); - return event_loop->vtable->is_on_callers_thread(event_loop); -} - -int aws_event_loop_current_clock_time(struct aws_event_loop *event_loop, uint64_t *time_nanos) { - AWS_ASSERT(event_loop->clock); - return event_loop->clock(time_nanos); -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/event_loop.h> + +#include <aws/common/clock.h> +#include <aws/common/system_info.h> +#include <aws/common/thread.h> + +static void s_event_loop_group_thread_exit(void *user_data) { + struct aws_event_loop_group *el_group = user_data; + + aws_simple_completion_callback *completion_callback = el_group->shutdown_options.shutdown_callback_fn; + void *completion_user_data = el_group->shutdown_options.shutdown_callback_user_data; + + aws_mem_release(el_group->allocator, el_group); + + if (completion_callback != NULL) { + completion_callback(completion_user_data); + } + + aws_global_thread_creator_decrement(); +} + +static void s_aws_event_loop_group_shutdown_sync(struct aws_event_loop_group *el_group) { + while (aws_array_list_length(&el_group->event_loops) > 0) { + struct aws_event_loop *loop = NULL; + + if (!aws_array_list_back(&el_group->event_loops, &loop)) { + aws_event_loop_destroy(loop); + } + + aws_array_list_pop_back(&el_group->event_loops); + } + + aws_array_list_clean_up(&el_group->event_loops); +} + +static void s_event_loop_destroy_async_thread_fn(void *thread_data) { + struct aws_event_loop_group *el_group = thread_data; + + s_aws_event_loop_group_shutdown_sync(el_group); + + aws_thread_current_at_exit(s_event_loop_group_thread_exit, el_group); +} + +static void s_aws_event_loop_group_shutdown_async(struct aws_event_loop_group *el_group) { + + /* It's possible that the last refcount was released on an event-loop thread, + * so we would deadlock if we waited here for all the event-loop threads to shut down. + * Therefore, we spawn a NEW thread and have it wait for all the event-loop threads to shut down + */ + struct aws_thread cleanup_thread; + AWS_ZERO_STRUCT(cleanup_thread); + + AWS_FATAL_ASSERT(aws_thread_init(&cleanup_thread, el_group->allocator) == AWS_OP_SUCCESS); + + struct aws_thread_options thread_options; + AWS_ZERO_STRUCT(thread_options); + + AWS_FATAL_ASSERT( + aws_thread_launch(&cleanup_thread, s_event_loop_destroy_async_thread_fn, el_group, &thread_options) == + AWS_OP_SUCCESS); + + aws_thread_clean_up(&cleanup_thread); +} + +struct aws_event_loop_group *aws_event_loop_group_new( + struct aws_allocator *alloc, + aws_io_clock_fn *clock, + uint16_t el_count, + aws_new_event_loop_fn *new_loop_fn, + void *new_loop_user_data, + const struct aws_shutdown_callback_options *shutdown_options) { + + AWS_ASSERT(new_loop_fn); + + struct aws_event_loop_group *el_group = aws_mem_calloc(alloc, 1, sizeof(struct aws_event_loop_group)); + if (el_group == NULL) { + return NULL; + } + + el_group->allocator = alloc; + aws_ref_count_init( + &el_group->ref_count, el_group, (aws_simple_completion_callback *)s_aws_event_loop_group_shutdown_async); + aws_atomic_init_int(&el_group->current_index, 0); + + if (aws_array_list_init_dynamic(&el_group->event_loops, alloc, el_count, sizeof(struct aws_event_loop *))) { + goto on_error; + } + + for (uint16_t i = 0; i < el_count; ++i) { + struct aws_event_loop *loop = new_loop_fn(alloc, clock, new_loop_user_data); + + if (!loop) { + goto on_error; + } + + if (aws_array_list_push_back(&el_group->event_loops, (const void *)&loop)) { + aws_event_loop_destroy(loop); + goto on_error; + } + + if (aws_event_loop_run(loop)) { + goto on_error; + } + } + + if (shutdown_options != NULL) { + el_group->shutdown_options = *shutdown_options; + } + + aws_global_thread_creator_increment(); + + return el_group; + +on_error: + + s_aws_event_loop_group_shutdown_sync(el_group); + s_event_loop_group_thread_exit(el_group); + + return NULL; +} + +static struct aws_event_loop *default_new_event_loop( + struct aws_allocator *allocator, + aws_io_clock_fn *clock, + void *user_data) { + + (void)user_data; + return aws_event_loop_new_default(allocator, clock); +} + +struct aws_event_loop_group *aws_event_loop_group_new_default( + struct aws_allocator *alloc, + uint16_t max_threads, + const struct aws_shutdown_callback_options *shutdown_options) { + if (!max_threads) { + max_threads = (uint16_t)aws_system_info_processor_count(); + } + + return aws_event_loop_group_new( + alloc, aws_high_res_clock_get_ticks, max_threads, default_new_event_loop, NULL, shutdown_options); +} + +struct aws_event_loop_group *aws_event_loop_group_acquire(struct aws_event_loop_group *el_group) { + if (el_group != NULL) { + aws_ref_count_acquire(&el_group->ref_count); + } + + return el_group; +} + +void aws_event_loop_group_release(struct aws_event_loop_group *el_group) { + if (el_group != NULL) { + aws_ref_count_release(&el_group->ref_count); + } +} + +size_t aws_event_loop_group_get_loop_count(struct aws_event_loop_group *el_group) { + return aws_array_list_length(&el_group->event_loops); +} + +struct aws_event_loop *aws_event_loop_group_get_loop_at(struct aws_event_loop_group *el_group, size_t index) { + struct aws_event_loop *el = NULL; + aws_array_list_get_at(&el_group->event_loops, &el, index); + return el; +} + +struct aws_event_loop *aws_event_loop_group_get_next_loop(struct aws_event_loop_group *el_group) { + size_t loop_count = aws_array_list_length(&el_group->event_loops); + AWS_ASSERT(loop_count > 0); + if (loop_count == 0) { + return NULL; + } + + /* thread safety: atomic CAS to ensure we got the best loop, and that the index is within bounds */ + size_t old_index = 0; + size_t new_index = 0; + do { + old_index = aws_atomic_load_int(&el_group->current_index); + new_index = (old_index + 1) % loop_count; + } while (!aws_atomic_compare_exchange_int(&el_group->current_index, &old_index, new_index)); + + struct aws_event_loop *loop = NULL; + + /* if the fetch fails, we don't really care since loop will be NULL and error code will already be set. */ + aws_array_list_get_at(&el_group->event_loops, &loop, old_index); + return loop; +} + +static void s_object_removed(void *value) { + struct aws_event_loop_local_object *object = (struct aws_event_loop_local_object *)value; + if (object->on_object_removed) { + object->on_object_removed(object); + } +} + +int aws_event_loop_init_base(struct aws_event_loop *event_loop, struct aws_allocator *alloc, aws_io_clock_fn *clock) { + AWS_ZERO_STRUCT(*event_loop); + + event_loop->alloc = alloc; + event_loop->clock = clock; + + if (aws_hash_table_init(&event_loop->local_data, alloc, 20, aws_hash_ptr, aws_ptr_eq, NULL, s_object_removed)) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +void aws_event_loop_clean_up_base(struct aws_event_loop *event_loop) { + aws_hash_table_clean_up(&event_loop->local_data); +} + +void aws_event_loop_destroy(struct aws_event_loop *event_loop) { + if (!event_loop) { + return; + } + + AWS_ASSERT(event_loop->vtable && event_loop->vtable->destroy); + AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(event_loop)); + + event_loop->vtable->destroy(event_loop); +} + +int aws_event_loop_fetch_local_object( + struct aws_event_loop *event_loop, + void *key, + struct aws_event_loop_local_object *obj) { + + AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); + + struct aws_hash_element *object = NULL; + if (!aws_hash_table_find(&event_loop->local_data, key, &object) && object) { + *obj = *(struct aws_event_loop_local_object *)object->value; + return AWS_OP_SUCCESS; + } + + return AWS_OP_ERR; +} + +int aws_event_loop_put_local_object(struct aws_event_loop *event_loop, struct aws_event_loop_local_object *obj) { + AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); + + struct aws_hash_element *object = NULL; + int was_created = 0; + + if (!aws_hash_table_create(&event_loop->local_data, obj->key, &object, &was_created)) { + object->key = obj->key; + object->value = obj; + return AWS_OP_SUCCESS; + } + + return AWS_OP_ERR; +} + +int aws_event_loop_remove_local_object( + struct aws_event_loop *event_loop, + void *key, + struct aws_event_loop_local_object *removed_obj) { + + AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); + + struct aws_hash_element existing_object; + AWS_ZERO_STRUCT(existing_object); + + int was_present = 0; + + struct aws_hash_element *remove_candidate = removed_obj ? &existing_object : NULL; + + if (!aws_hash_table_remove(&event_loop->local_data, key, remove_candidate, &was_present)) { + if (remove_candidate && was_present) { + *removed_obj = *(struct aws_event_loop_local_object *)existing_object.value; + } + + return AWS_OP_SUCCESS; + } + + return AWS_OP_ERR; +} + +int aws_event_loop_run(struct aws_event_loop *event_loop) { + AWS_ASSERT(event_loop->vtable && event_loop->vtable->run); + return event_loop->vtable->run(event_loop); +} + +int aws_event_loop_stop(struct aws_event_loop *event_loop) { + AWS_ASSERT(event_loop->vtable && event_loop->vtable->stop); + return event_loop->vtable->stop(event_loop); +} + +int aws_event_loop_wait_for_stop_completion(struct aws_event_loop *event_loop) { + AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(event_loop)); + AWS_ASSERT(event_loop->vtable && event_loop->vtable->wait_for_stop_completion); + return event_loop->vtable->wait_for_stop_completion(event_loop); +} + +void aws_event_loop_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { + AWS_ASSERT(event_loop->vtable && event_loop->vtable->schedule_task_now); + AWS_ASSERT(task); + event_loop->vtable->schedule_task_now(event_loop, task); +} + +void aws_event_loop_schedule_task_future( + struct aws_event_loop *event_loop, + struct aws_task *task, + uint64_t run_at_nanos) { + + AWS_ASSERT(event_loop->vtable && event_loop->vtable->schedule_task_future); + AWS_ASSERT(task); + event_loop->vtable->schedule_task_future(event_loop, task, run_at_nanos); +} + +void aws_event_loop_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { + AWS_ASSERT(event_loop->vtable && event_loop->vtable->cancel_task); + AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); + AWS_ASSERT(task); + event_loop->vtable->cancel_task(event_loop, task); +} + +#if AWS_USE_IO_COMPLETION_PORTS + +int aws_event_loop_connect_handle_to_io_completion_port( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle) { + + AWS_ASSERT(event_loop->vtable && event_loop->vtable->connect_to_io_completion_port); + return event_loop->vtable->connect_to_io_completion_port(event_loop, handle); +} + +#else /* !AWS_USE_IO_COMPLETION_PORTS */ + +int aws_event_loop_subscribe_to_io_events( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + aws_event_loop_on_event_fn *on_event, + void *user_data) { + + AWS_ASSERT(event_loop->vtable && event_loop->vtable->subscribe_to_io_events); + return event_loop->vtable->subscribe_to_io_events(event_loop, handle, events, on_event, user_data); +} +#endif /* AWS_USE_IO_COMPLETION_PORTS */ + +int aws_event_loop_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { + AWS_ASSERT(aws_event_loop_thread_is_callers_thread(event_loop)); + AWS_ASSERT(event_loop->vtable && event_loop->vtable->unsubscribe_from_io_events); + return event_loop->vtable->unsubscribe_from_io_events(event_loop, handle); +} + +void aws_event_loop_free_io_event_resources(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { + AWS_ASSERT(event_loop && event_loop->vtable->free_io_event_resources); + event_loop->vtable->free_io_event_resources(handle->additional_data); +} + +bool aws_event_loop_thread_is_callers_thread(struct aws_event_loop *event_loop) { + AWS_ASSERT(event_loop->vtable && event_loop->vtable->is_on_callers_thread); + return event_loop->vtable->is_on_callers_thread(event_loop); +} + +int aws_event_loop_current_clock_time(struct aws_event_loop *event_loop, uint64_t *time_nanos) { + AWS_ASSERT(event_loop->clock); + return event_loop->clock(time_nanos); +} diff --git a/contrib/restricted/aws/aws-c-io/source/exponential_backoff_retry_strategy.c b/contrib/restricted/aws/aws-c-io/source/exponential_backoff_retry_strategy.c index f064f2118c..f45a101ca7 100644 --- a/contrib/restricted/aws/aws-c-io/source/exponential_backoff_retry_strategy.c +++ b/contrib/restricted/aws/aws-c-io/source/exponential_backoff_retry_strategy.c @@ -1,346 +1,346 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/retry_strategy.h> - -#include <aws/io/event_loop.h> -#include <aws/io/logging.h> - -#include <aws/common/clock.h> -#include <aws/common/device_random.h> -#include <aws/common/logging.h> -#include <aws/common/mutex.h> -#include <aws/common/task_scheduler.h> - -#include <inttypes.h> - -struct exponential_backoff_strategy { - struct aws_retry_strategy base; - struct aws_exponential_backoff_retry_options config; -}; - -struct exponential_backoff_retry_token { - struct aws_retry_token base; - struct aws_atomic_var current_retry_count; - struct aws_atomic_var last_backoff; - size_t max_retries; - uint64_t backoff_scale_factor_ns; - enum aws_exponential_backoff_jitter_mode jitter_mode; - /* Let's not make this worst by constantly moving across threads if we can help it */ - struct aws_event_loop *bound_loop; - uint64_t (*generate_random)(void); - struct aws_task retry_task; - - struct { - struct aws_mutex mutex; - aws_retry_strategy_on_retry_token_acquired_fn *acquired_fn; - aws_retry_strategy_on_retry_ready_fn *retry_ready_fn; - void *user_data; - } thread_data; -}; - -static void s_exponential_retry_destroy(struct aws_retry_strategy *retry_strategy) { - if (retry_strategy) { - aws_mem_release(retry_strategy->allocator, retry_strategy); - } -} - -static void s_exponential_retry_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - - int error_code = AWS_ERROR_IO_OPERATION_CANCELLED; - if (status == AWS_TASK_STATUS_RUN_READY) { - error_code = AWS_OP_SUCCESS; - } - - struct exponential_backoff_retry_token *backoff_retry_token = arg; - aws_retry_strategy_on_retry_token_acquired_fn *acquired_fn = NULL; - aws_retry_strategy_on_retry_ready_fn *retry_ready_fn = NULL; - void *user_data = NULL; - - { /***** BEGIN CRITICAL SECTION *********/ - AWS_FATAL_ASSERT( - !aws_mutex_lock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex acquisition failed"); - acquired_fn = backoff_retry_token->thread_data.acquired_fn; - retry_ready_fn = backoff_retry_token->thread_data.retry_ready_fn; - user_data = backoff_retry_token->thread_data.user_data; - backoff_retry_token->thread_data.user_data = NULL; - backoff_retry_token->thread_data.retry_ready_fn = NULL; - backoff_retry_token->thread_data.acquired_fn = NULL; - AWS_FATAL_ASSERT( - !aws_mutex_unlock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex release failed"); - } /**** END CRITICAL SECTION ***********/ - - if (acquired_fn) { - AWS_LOGF_DEBUG( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Vending retry_token %p", - (void *)backoff_retry_token->base.retry_strategy, - (void *)&backoff_retry_token->base); - acquired_fn(backoff_retry_token->base.retry_strategy, error_code, &backoff_retry_token->base, user_data); - } else if (retry_ready_fn) { - AWS_LOGF_DEBUG( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Invoking retry_ready for token %p", - (void *)backoff_retry_token->base.retry_strategy, - (void *)&backoff_retry_token->base); - retry_ready_fn(&backoff_retry_token->base, error_code, user_data); - } -} - -static int s_exponential_retry_acquire_token( - struct aws_retry_strategy *retry_strategy, - const struct aws_byte_cursor *partition_id, - aws_retry_strategy_on_retry_token_acquired_fn *on_acquired, - void *user_data, - uint64_t timeout_ms) { - (void)partition_id; - /* no resource contention here so no timeouts. */ - (void)timeout_ms; - - struct exponential_backoff_retry_token *backoff_retry_token = - aws_mem_calloc(retry_strategy->allocator, 1, sizeof(struct exponential_backoff_retry_token)); - - if (!backoff_retry_token) { - return AWS_OP_ERR; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Initializing retry token %p", - (void *)retry_strategy, - (void *)&backoff_retry_token->base); - - backoff_retry_token->base.allocator = retry_strategy->allocator; - backoff_retry_token->base.retry_strategy = retry_strategy; - aws_retry_strategy_acquire(retry_strategy); - backoff_retry_token->base.impl = backoff_retry_token; - - struct exponential_backoff_strategy *exponential_backoff_strategy = retry_strategy->impl; - backoff_retry_token->bound_loop = aws_event_loop_group_get_next_loop(exponential_backoff_strategy->config.el_group); - backoff_retry_token->max_retries = exponential_backoff_strategy->config.max_retries; - backoff_retry_token->backoff_scale_factor_ns = aws_timestamp_convert( - exponential_backoff_strategy->config.backoff_scale_factor_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); - backoff_retry_token->jitter_mode = exponential_backoff_strategy->config.jitter_mode; - backoff_retry_token->generate_random = exponential_backoff_strategy->config.generate_random; - aws_atomic_init_int(&backoff_retry_token->current_retry_count, 0); - aws_atomic_init_int(&backoff_retry_token->last_backoff, 0); - - backoff_retry_token->thread_data.acquired_fn = on_acquired; - backoff_retry_token->thread_data.user_data = user_data; - AWS_FATAL_ASSERT( - !aws_mutex_init(&backoff_retry_token->thread_data.mutex) && "Retry strategy mutex initialization failed"); - - aws_task_init( - &backoff_retry_token->retry_task, - s_exponential_retry_task, - backoff_retry_token, - "aws_exponential_backoff_retry_task"); - aws_event_loop_schedule_task_now(backoff_retry_token->bound_loop, &backoff_retry_token->retry_task); - - return AWS_OP_SUCCESS; -} - -static inline uint64_t s_random_in_range(uint64_t from, uint64_t to, struct exponential_backoff_retry_token *token) { - uint64_t max = aws_max_u64(from, to); - uint64_t min = aws_min_u64(from, to); - - uint64_t diff = max - min; - - if (!diff) { - return 0; - } - - uint64_t random = token->generate_random(); - return min + random % (diff); -} - -typedef uint64_t(compute_backoff_fn)(struct exponential_backoff_retry_token *token); - -static uint64_t s_compute_no_jitter(struct exponential_backoff_retry_token *token) { - uint64_t retry_count = aws_min_u64(aws_atomic_load_int(&token->current_retry_count), 63); - return aws_mul_u64_saturating((uint64_t)1 << retry_count, token->backoff_scale_factor_ns); -} - -static uint64_t s_compute_full_jitter(struct exponential_backoff_retry_token *token) { - uint64_t non_jittered = s_compute_no_jitter(token); - return s_random_in_range(0, non_jittered, token); -} - -static uint64_t s_compute_deccorelated_jitter(struct exponential_backoff_retry_token *token) { - size_t last_backoff_val = aws_atomic_load_int(&token->last_backoff); - - if (!last_backoff_val) { - return s_compute_full_jitter(token); - } - - return s_random_in_range(token->backoff_scale_factor_ns, aws_mul_u64_saturating(last_backoff_val, 3), token); -} - -static compute_backoff_fn *s_backoff_compute_table[] = { - [AWS_EXPONENTIAL_BACKOFF_JITTER_DEFAULT] = s_compute_full_jitter, - [AWS_EXPONENTIAL_BACKOFF_JITTER_NONE] = s_compute_no_jitter, - [AWS_EXPONENTIAL_BACKOFF_JITTER_FULL] = s_compute_full_jitter, - [AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED] = s_compute_deccorelated_jitter, -}; - -static int s_exponential_retry_schedule_retry( - struct aws_retry_token *token, - enum aws_retry_error_type error_type, - aws_retry_strategy_on_retry_ready_fn *retry_ready, - void *user_data) { - struct exponential_backoff_retry_token *backoff_retry_token = token->impl; - - AWS_LOGF_DEBUG( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Attempting retry on token %p with error type %d", - (void *)backoff_retry_token->base.retry_strategy, - (void *)token, - error_type); - uint64_t schedule_at = 0; - - /* AWS_RETRY_ERROR_TYPE_CLIENT_ERROR does not count against your retry budget since you were responding to an - * improperly crafted request. */ - if (error_type != AWS_RETRY_ERROR_TYPE_CLIENT_ERROR) { - size_t retry_count = aws_atomic_load_int(&backoff_retry_token->current_retry_count); - - if (retry_count >= backoff_retry_token->max_retries) { - AWS_LOGF_WARN( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: token %p has exhausted allowed retries. Retry count %zu max retries %zu", - (void *)backoff_retry_token->base.retry_strategy, - (void *)token, - backoff_retry_token->max_retries, - retry_count); - return aws_raise_error(AWS_IO_MAX_RETRIES_EXCEEDED); - } - - uint64_t backoff = s_backoff_compute_table[backoff_retry_token->jitter_mode](backoff_retry_token); - uint64_t current_time = 0; - - aws_event_loop_current_clock_time(backoff_retry_token->bound_loop, ¤t_time); - schedule_at = backoff + current_time; - aws_atomic_init_int(&backoff_retry_token->last_backoff, (size_t)backoff); - aws_atomic_fetch_add(&backoff_retry_token->current_retry_count, 1); - AWS_LOGF_DEBUG( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Computed backoff value of %" PRIu64 "ns on token %p", - (void *)backoff_retry_token->base.retry_strategy, - backoff, - (void *)token); - } - - bool already_scheduled = false; - - { /***** BEGIN CRITICAL SECTION *********/ - AWS_FATAL_ASSERT( - !aws_mutex_lock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex acquisition failed"); - - if (backoff_retry_token->thread_data.user_data) { - already_scheduled = true; - } else { - backoff_retry_token->thread_data.retry_ready_fn = retry_ready; - backoff_retry_token->thread_data.user_data = user_data; - aws_task_init( - &backoff_retry_token->retry_task, - s_exponential_retry_task, - backoff_retry_token, - "aws_exponential_backoff_retry_task"); - } - AWS_FATAL_ASSERT( - !aws_mutex_unlock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex release failed"); - } /**** END CRITICAL SECTION ***********/ - - if (already_scheduled) { - AWS_LOGF_ERROR( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: retry token %p is already scheduled.", - (void *)backoff_retry_token->base.retry_strategy, - (void *)token) - return aws_raise_error(AWS_ERROR_INVALID_STATE); - } - - aws_event_loop_schedule_task_future(backoff_retry_token->bound_loop, &backoff_retry_token->retry_task, schedule_at); - return AWS_OP_SUCCESS; -} - -static int s_exponential_backoff_record_success(struct aws_retry_token *token) { - /* we don't do book keeping in this mode. */ - (void)token; - return AWS_OP_SUCCESS; -} - -static void s_exponential_backoff_release_token(struct aws_retry_token *token) { - if (token) { - aws_retry_strategy_release(token->retry_strategy); - struct exponential_backoff_retry_token *backoff_retry_token = token->impl; - aws_mutex_clean_up(&backoff_retry_token->thread_data.mutex); - aws_mem_release(token->allocator, backoff_retry_token); - } -} - -static struct aws_retry_strategy_vtable s_exponential_retry_vtable = { - .destroy = s_exponential_retry_destroy, - .acquire_token = s_exponential_retry_acquire_token, - .schedule_retry = s_exponential_retry_schedule_retry, - .record_success = s_exponential_backoff_record_success, - .release_token = s_exponential_backoff_release_token, -}; - -static uint64_t s_default_gen_rand(void) { - uint64_t res = 0; - aws_device_random_u64(&res); - return res; -} - -struct aws_retry_strategy *aws_retry_strategy_new_exponential_backoff( - struct aws_allocator *allocator, - const struct aws_exponential_backoff_retry_options *config) { - AWS_PRECONDITION(config); - AWS_PRECONDITION(config->el_group) - AWS_PRECONDITION(config->jitter_mode <= AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED) - AWS_PRECONDITION(config->max_retries) - - if (config->max_retries > 63 || !config->el_group || - config->jitter_mode > AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED) { - aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - return NULL; - } - - struct exponential_backoff_strategy *exponential_backoff_strategy = - aws_mem_calloc(allocator, 1, sizeof(struct exponential_backoff_strategy)); - - if (!exponential_backoff_strategy) { - return NULL; - } - - AWS_LOGF_INFO( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "id=%p: Initializing exponential backoff retry strategy with scale factor: %" PRIu32 - " jitter mode: %d and max retries %zu", - (void *)&exponential_backoff_strategy->base, - config->backoff_scale_factor_ms, - config->jitter_mode, - config->max_retries); - - exponential_backoff_strategy->base.allocator = allocator; - exponential_backoff_strategy->base.impl = exponential_backoff_strategy; - exponential_backoff_strategy->base.vtable = &s_exponential_retry_vtable; - aws_atomic_init_int(&exponential_backoff_strategy->base.ref_count, 1); - exponential_backoff_strategy->config = *config; - - if (!exponential_backoff_strategy->config.generate_random) { - exponential_backoff_strategy->config.generate_random = s_default_gen_rand; - } - - if (!exponential_backoff_strategy->config.max_retries) { - exponential_backoff_strategy->config.max_retries = 10; - } - - if (!exponential_backoff_strategy->config.backoff_scale_factor_ms) { - exponential_backoff_strategy->config.backoff_scale_factor_ms = 25; - } - - return &exponential_backoff_strategy->base; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/retry_strategy.h> + +#include <aws/io/event_loop.h> +#include <aws/io/logging.h> + +#include <aws/common/clock.h> +#include <aws/common/device_random.h> +#include <aws/common/logging.h> +#include <aws/common/mutex.h> +#include <aws/common/task_scheduler.h> + +#include <inttypes.h> + +struct exponential_backoff_strategy { + struct aws_retry_strategy base; + struct aws_exponential_backoff_retry_options config; +}; + +struct exponential_backoff_retry_token { + struct aws_retry_token base; + struct aws_atomic_var current_retry_count; + struct aws_atomic_var last_backoff; + size_t max_retries; + uint64_t backoff_scale_factor_ns; + enum aws_exponential_backoff_jitter_mode jitter_mode; + /* Let's not make this worst by constantly moving across threads if we can help it */ + struct aws_event_loop *bound_loop; + uint64_t (*generate_random)(void); + struct aws_task retry_task; + + struct { + struct aws_mutex mutex; + aws_retry_strategy_on_retry_token_acquired_fn *acquired_fn; + aws_retry_strategy_on_retry_ready_fn *retry_ready_fn; + void *user_data; + } thread_data; +}; + +static void s_exponential_retry_destroy(struct aws_retry_strategy *retry_strategy) { + if (retry_strategy) { + aws_mem_release(retry_strategy->allocator, retry_strategy); + } +} + +static void s_exponential_retry_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + + int error_code = AWS_ERROR_IO_OPERATION_CANCELLED; + if (status == AWS_TASK_STATUS_RUN_READY) { + error_code = AWS_OP_SUCCESS; + } + + struct exponential_backoff_retry_token *backoff_retry_token = arg; + aws_retry_strategy_on_retry_token_acquired_fn *acquired_fn = NULL; + aws_retry_strategy_on_retry_ready_fn *retry_ready_fn = NULL; + void *user_data = NULL; + + { /***** BEGIN CRITICAL SECTION *********/ + AWS_FATAL_ASSERT( + !aws_mutex_lock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex acquisition failed"); + acquired_fn = backoff_retry_token->thread_data.acquired_fn; + retry_ready_fn = backoff_retry_token->thread_data.retry_ready_fn; + user_data = backoff_retry_token->thread_data.user_data; + backoff_retry_token->thread_data.user_data = NULL; + backoff_retry_token->thread_data.retry_ready_fn = NULL; + backoff_retry_token->thread_data.acquired_fn = NULL; + AWS_FATAL_ASSERT( + !aws_mutex_unlock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex release failed"); + } /**** END CRITICAL SECTION ***********/ + + if (acquired_fn) { + AWS_LOGF_DEBUG( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Vending retry_token %p", + (void *)backoff_retry_token->base.retry_strategy, + (void *)&backoff_retry_token->base); + acquired_fn(backoff_retry_token->base.retry_strategy, error_code, &backoff_retry_token->base, user_data); + } else if (retry_ready_fn) { + AWS_LOGF_DEBUG( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Invoking retry_ready for token %p", + (void *)backoff_retry_token->base.retry_strategy, + (void *)&backoff_retry_token->base); + retry_ready_fn(&backoff_retry_token->base, error_code, user_data); + } +} + +static int s_exponential_retry_acquire_token( + struct aws_retry_strategy *retry_strategy, + const struct aws_byte_cursor *partition_id, + aws_retry_strategy_on_retry_token_acquired_fn *on_acquired, + void *user_data, + uint64_t timeout_ms) { + (void)partition_id; + /* no resource contention here so no timeouts. */ + (void)timeout_ms; + + struct exponential_backoff_retry_token *backoff_retry_token = + aws_mem_calloc(retry_strategy->allocator, 1, sizeof(struct exponential_backoff_retry_token)); + + if (!backoff_retry_token) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Initializing retry token %p", + (void *)retry_strategy, + (void *)&backoff_retry_token->base); + + backoff_retry_token->base.allocator = retry_strategy->allocator; + backoff_retry_token->base.retry_strategy = retry_strategy; + aws_retry_strategy_acquire(retry_strategy); + backoff_retry_token->base.impl = backoff_retry_token; + + struct exponential_backoff_strategy *exponential_backoff_strategy = retry_strategy->impl; + backoff_retry_token->bound_loop = aws_event_loop_group_get_next_loop(exponential_backoff_strategy->config.el_group); + backoff_retry_token->max_retries = exponential_backoff_strategy->config.max_retries; + backoff_retry_token->backoff_scale_factor_ns = aws_timestamp_convert( + exponential_backoff_strategy->config.backoff_scale_factor_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); + backoff_retry_token->jitter_mode = exponential_backoff_strategy->config.jitter_mode; + backoff_retry_token->generate_random = exponential_backoff_strategy->config.generate_random; + aws_atomic_init_int(&backoff_retry_token->current_retry_count, 0); + aws_atomic_init_int(&backoff_retry_token->last_backoff, 0); + + backoff_retry_token->thread_data.acquired_fn = on_acquired; + backoff_retry_token->thread_data.user_data = user_data; + AWS_FATAL_ASSERT( + !aws_mutex_init(&backoff_retry_token->thread_data.mutex) && "Retry strategy mutex initialization failed"); + + aws_task_init( + &backoff_retry_token->retry_task, + s_exponential_retry_task, + backoff_retry_token, + "aws_exponential_backoff_retry_task"); + aws_event_loop_schedule_task_now(backoff_retry_token->bound_loop, &backoff_retry_token->retry_task); + + return AWS_OP_SUCCESS; +} + +static inline uint64_t s_random_in_range(uint64_t from, uint64_t to, struct exponential_backoff_retry_token *token) { + uint64_t max = aws_max_u64(from, to); + uint64_t min = aws_min_u64(from, to); + + uint64_t diff = max - min; + + if (!diff) { + return 0; + } + + uint64_t random = token->generate_random(); + return min + random % (diff); +} + +typedef uint64_t(compute_backoff_fn)(struct exponential_backoff_retry_token *token); + +static uint64_t s_compute_no_jitter(struct exponential_backoff_retry_token *token) { + uint64_t retry_count = aws_min_u64(aws_atomic_load_int(&token->current_retry_count), 63); + return aws_mul_u64_saturating((uint64_t)1 << retry_count, token->backoff_scale_factor_ns); +} + +static uint64_t s_compute_full_jitter(struct exponential_backoff_retry_token *token) { + uint64_t non_jittered = s_compute_no_jitter(token); + return s_random_in_range(0, non_jittered, token); +} + +static uint64_t s_compute_deccorelated_jitter(struct exponential_backoff_retry_token *token) { + size_t last_backoff_val = aws_atomic_load_int(&token->last_backoff); + + if (!last_backoff_val) { + return s_compute_full_jitter(token); + } + + return s_random_in_range(token->backoff_scale_factor_ns, aws_mul_u64_saturating(last_backoff_val, 3), token); +} + +static compute_backoff_fn *s_backoff_compute_table[] = { + [AWS_EXPONENTIAL_BACKOFF_JITTER_DEFAULT] = s_compute_full_jitter, + [AWS_EXPONENTIAL_BACKOFF_JITTER_NONE] = s_compute_no_jitter, + [AWS_EXPONENTIAL_BACKOFF_JITTER_FULL] = s_compute_full_jitter, + [AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED] = s_compute_deccorelated_jitter, +}; + +static int s_exponential_retry_schedule_retry( + struct aws_retry_token *token, + enum aws_retry_error_type error_type, + aws_retry_strategy_on_retry_ready_fn *retry_ready, + void *user_data) { + struct exponential_backoff_retry_token *backoff_retry_token = token->impl; + + AWS_LOGF_DEBUG( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Attempting retry on token %p with error type %d", + (void *)backoff_retry_token->base.retry_strategy, + (void *)token, + error_type); + uint64_t schedule_at = 0; + + /* AWS_RETRY_ERROR_TYPE_CLIENT_ERROR does not count against your retry budget since you were responding to an + * improperly crafted request. */ + if (error_type != AWS_RETRY_ERROR_TYPE_CLIENT_ERROR) { + size_t retry_count = aws_atomic_load_int(&backoff_retry_token->current_retry_count); + + if (retry_count >= backoff_retry_token->max_retries) { + AWS_LOGF_WARN( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: token %p has exhausted allowed retries. Retry count %zu max retries %zu", + (void *)backoff_retry_token->base.retry_strategy, + (void *)token, + backoff_retry_token->max_retries, + retry_count); + return aws_raise_error(AWS_IO_MAX_RETRIES_EXCEEDED); + } + + uint64_t backoff = s_backoff_compute_table[backoff_retry_token->jitter_mode](backoff_retry_token); + uint64_t current_time = 0; + + aws_event_loop_current_clock_time(backoff_retry_token->bound_loop, ¤t_time); + schedule_at = backoff + current_time; + aws_atomic_init_int(&backoff_retry_token->last_backoff, (size_t)backoff); + aws_atomic_fetch_add(&backoff_retry_token->current_retry_count, 1); + AWS_LOGF_DEBUG( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Computed backoff value of %" PRIu64 "ns on token %p", + (void *)backoff_retry_token->base.retry_strategy, + backoff, + (void *)token); + } + + bool already_scheduled = false; + + { /***** BEGIN CRITICAL SECTION *********/ + AWS_FATAL_ASSERT( + !aws_mutex_lock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex acquisition failed"); + + if (backoff_retry_token->thread_data.user_data) { + already_scheduled = true; + } else { + backoff_retry_token->thread_data.retry_ready_fn = retry_ready; + backoff_retry_token->thread_data.user_data = user_data; + aws_task_init( + &backoff_retry_token->retry_task, + s_exponential_retry_task, + backoff_retry_token, + "aws_exponential_backoff_retry_task"); + } + AWS_FATAL_ASSERT( + !aws_mutex_unlock(&backoff_retry_token->thread_data.mutex) && "Retry token mutex release failed"); + } /**** END CRITICAL SECTION ***********/ + + if (already_scheduled) { + AWS_LOGF_ERROR( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: retry token %p is already scheduled.", + (void *)backoff_retry_token->base.retry_strategy, + (void *)token) + return aws_raise_error(AWS_ERROR_INVALID_STATE); + } + + aws_event_loop_schedule_task_future(backoff_retry_token->bound_loop, &backoff_retry_token->retry_task, schedule_at); + return AWS_OP_SUCCESS; +} + +static int s_exponential_backoff_record_success(struct aws_retry_token *token) { + /* we don't do book keeping in this mode. */ + (void)token; + return AWS_OP_SUCCESS; +} + +static void s_exponential_backoff_release_token(struct aws_retry_token *token) { + if (token) { + aws_retry_strategy_release(token->retry_strategy); + struct exponential_backoff_retry_token *backoff_retry_token = token->impl; + aws_mutex_clean_up(&backoff_retry_token->thread_data.mutex); + aws_mem_release(token->allocator, backoff_retry_token); + } +} + +static struct aws_retry_strategy_vtable s_exponential_retry_vtable = { + .destroy = s_exponential_retry_destroy, + .acquire_token = s_exponential_retry_acquire_token, + .schedule_retry = s_exponential_retry_schedule_retry, + .record_success = s_exponential_backoff_record_success, + .release_token = s_exponential_backoff_release_token, +}; + +static uint64_t s_default_gen_rand(void) { + uint64_t res = 0; + aws_device_random_u64(&res); + return res; +} + +struct aws_retry_strategy *aws_retry_strategy_new_exponential_backoff( + struct aws_allocator *allocator, + const struct aws_exponential_backoff_retry_options *config) { + AWS_PRECONDITION(config); + AWS_PRECONDITION(config->el_group) + AWS_PRECONDITION(config->jitter_mode <= AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED) + AWS_PRECONDITION(config->max_retries) + + if (config->max_retries > 63 || !config->el_group || + config->jitter_mode > AWS_EXPONENTIAL_BACKOFF_JITTER_DECORRELATED) { + aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + return NULL; + } + + struct exponential_backoff_strategy *exponential_backoff_strategy = + aws_mem_calloc(allocator, 1, sizeof(struct exponential_backoff_strategy)); + + if (!exponential_backoff_strategy) { + return NULL; + } + + AWS_LOGF_INFO( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "id=%p: Initializing exponential backoff retry strategy with scale factor: %" PRIu32 + " jitter mode: %d and max retries %zu", + (void *)&exponential_backoff_strategy->base, + config->backoff_scale_factor_ms, + config->jitter_mode, + config->max_retries); + + exponential_backoff_strategy->base.allocator = allocator; + exponential_backoff_strategy->base.impl = exponential_backoff_strategy; + exponential_backoff_strategy->base.vtable = &s_exponential_retry_vtable; + aws_atomic_init_int(&exponential_backoff_strategy->base.ref_count, 1); + exponential_backoff_strategy->config = *config; + + if (!exponential_backoff_strategy->config.generate_random) { + exponential_backoff_strategy->config.generate_random = s_default_gen_rand; + } + + if (!exponential_backoff_strategy->config.max_retries) { + exponential_backoff_strategy->config.max_retries = 10; + } + + if (!exponential_backoff_strategy->config.backoff_scale_factor_ms) { + exponential_backoff_strategy->config.backoff_scale_factor_ms = 25; + } + + return &exponential_backoff_strategy->base; +} diff --git a/contrib/restricted/aws/aws-c-io/source/file_utils_shared.c b/contrib/restricted/aws/aws-c-io/source/file_utils_shared.c index 00a5f38800..091b27a66e 100644 --- a/contrib/restricted/aws/aws-c-io/source/file_utils_shared.c +++ b/contrib/restricted/aws/aws-c-io/source/file_utils_shared.c @@ -1,68 +1,68 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/file_utils.h> - -#include <aws/common/environment.h> -#include <aws/common/string.h> -#include <aws/io/logging.h> - -#include <errno.h> -#include <stdio.h> - -#ifdef _MSC_VER -# pragma warning(disable : 4996) /* Disable warnings about fopen() being insecure */ -#endif /* _MSC_VER */ - -int aws_byte_buf_init_from_file(struct aws_byte_buf *out_buf, struct aws_allocator *alloc, const char *filename) { - AWS_ZERO_STRUCT(*out_buf); - FILE *fp = fopen(filename, "rb"); - - if (fp) { - if (fseek(fp, 0L, SEEK_END)) { - AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to seek file %s with errno %d", filename, errno); - fclose(fp); - return aws_translate_and_raise_io_error(errno); - } - - size_t allocation_size = (size_t)ftell(fp) + 1; - /* Tell the user that we allocate here and if success they're responsible for the free. */ - if (aws_byte_buf_init(out_buf, alloc, allocation_size)) { - fclose(fp); - return AWS_OP_ERR; - } - - /* Ensure compatibility with null-terminated APIs, but don't consider - * the null terminator part of the length of the payload */ - out_buf->len = out_buf->capacity - 1; - out_buf->buffer[out_buf->len] = 0; - - if (fseek(fp, 0L, SEEK_SET)) { - AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to seek file %s with errno %d", filename, errno); - aws_byte_buf_clean_up(out_buf); - fclose(fp); - return aws_translate_and_raise_io_error(errno); - } - - size_t read = fread(out_buf->buffer, 1, out_buf->len, fp); - fclose(fp); - if (read < out_buf->len) { - AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to read file %s with errno %d", filename, errno); - aws_secure_zero(out_buf->buffer, out_buf->len); - aws_byte_buf_clean_up(out_buf); - return aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); - } - - return AWS_OP_SUCCESS; - } - - AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to open file %s with errno %d", filename, errno); - - return aws_translate_and_raise_io_error(errno); -} - -bool aws_is_any_directory_separator(char value) { - return value == '\\' || value == '/'; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/file_utils.h> + +#include <aws/common/environment.h> +#include <aws/common/string.h> +#include <aws/io/logging.h> + +#include <errno.h> +#include <stdio.h> + +#ifdef _MSC_VER +# pragma warning(disable : 4996) /* Disable warnings about fopen() being insecure */ +#endif /* _MSC_VER */ + +int aws_byte_buf_init_from_file(struct aws_byte_buf *out_buf, struct aws_allocator *alloc, const char *filename) { + AWS_ZERO_STRUCT(*out_buf); + FILE *fp = fopen(filename, "rb"); + + if (fp) { + if (fseek(fp, 0L, SEEK_END)) { + AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to seek file %s with errno %d", filename, errno); + fclose(fp); + return aws_translate_and_raise_io_error(errno); + } + + size_t allocation_size = (size_t)ftell(fp) + 1; + /* Tell the user that we allocate here and if success they're responsible for the free. */ + if (aws_byte_buf_init(out_buf, alloc, allocation_size)) { + fclose(fp); + return AWS_OP_ERR; + } + + /* Ensure compatibility with null-terminated APIs, but don't consider + * the null terminator part of the length of the payload */ + out_buf->len = out_buf->capacity - 1; + out_buf->buffer[out_buf->len] = 0; + + if (fseek(fp, 0L, SEEK_SET)) { + AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to seek file %s with errno %d", filename, errno); + aws_byte_buf_clean_up(out_buf); + fclose(fp); + return aws_translate_and_raise_io_error(errno); + } + + size_t read = fread(out_buf->buffer, 1, out_buf->len, fp); + fclose(fp); + if (read < out_buf->len) { + AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to read file %s with errno %d", filename, errno); + aws_secure_zero(out_buf->buffer, out_buf->len); + aws_byte_buf_clean_up(out_buf); + return aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); + } + + return AWS_OP_SUCCESS; + } + + AWS_LOGF_ERROR(AWS_LS_IO_FILE_UTILS, "static: Failed to open file %s with errno %d", filename, errno); + + return aws_translate_and_raise_io_error(errno); +} + +bool aws_is_any_directory_separator(char value) { + return value == '\\' || value == '/'; +} diff --git a/contrib/restricted/aws/aws-c-io/source/host_resolver.c b/contrib/restricted/aws/aws-c-io/source/host_resolver.c index 2df732a904..4e6eeb40a3 100644 --- a/contrib/restricted/aws/aws-c-io/source/host_resolver.c +++ b/contrib/restricted/aws/aws-c-io/source/host_resolver.c @@ -1,1772 +1,1772 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/host_resolver.h> - -#include <aws/common/atomics.h> -#include <aws/common/clock.h> -#include <aws/common/condition_variable.h> -#include <aws/common/hash_table.h> -#include <aws/common/lru_cache.h> -#include <aws/common/mutex.h> -#include <aws/common/string.h> -#include <aws/common/thread.h> - -#include <aws/io/logging.h> - -#include <inttypes.h> - -const uint64_t NS_PER_SEC = 1000000000; - -int aws_host_address_copy(const struct aws_host_address *from, struct aws_host_address *to) { - to->allocator = from->allocator; - to->address = aws_string_new_from_string(to->allocator, from->address); - - if (!to->address) { - return AWS_OP_ERR; - } - - to->host = aws_string_new_from_string(to->allocator, from->host); - - if (!to->host) { - aws_string_destroy((void *)to->address); - return AWS_OP_ERR; - } - - to->record_type = from->record_type; - to->use_count = from->use_count; - to->connection_failure_count = from->connection_failure_count; - to->expiry = from->expiry; - to->weight = from->weight; - - return AWS_OP_SUCCESS; -} - -void aws_host_address_move(struct aws_host_address *from, struct aws_host_address *to) { - to->allocator = from->allocator; - to->address = from->address; - to->host = from->host; - to->record_type = from->record_type; - to->use_count = from->use_count; - to->connection_failure_count = from->connection_failure_count; - to->expiry = from->expiry; - to->weight = from->weight; - AWS_ZERO_STRUCT(*from); -} - -void aws_host_address_clean_up(struct aws_host_address *address) { - if (address->address) { - aws_string_destroy((void *)address->address); - } - if (address->host) { - aws_string_destroy((void *)address->host); - } - AWS_ZERO_STRUCT(*address); -} - -int aws_host_resolver_resolve_host( - struct aws_host_resolver *resolver, - const struct aws_string *host_name, - aws_on_host_resolved_result_fn *res, - struct aws_host_resolution_config *config, - void *user_data) { - AWS_ASSERT(resolver->vtable && resolver->vtable->resolve_host); - return resolver->vtable->resolve_host(resolver, host_name, res, config, user_data); -} - -int aws_host_resolver_purge_cache(struct aws_host_resolver *resolver) { - AWS_ASSERT(resolver->vtable && resolver->vtable->purge_cache); - return resolver->vtable->purge_cache(resolver); -} - -int aws_host_resolver_record_connection_failure(struct aws_host_resolver *resolver, struct aws_host_address *address) { - AWS_ASSERT(resolver->vtable && resolver->vtable->record_connection_failure); - return resolver->vtable->record_connection_failure(resolver, address); -} - -struct aws_host_listener *aws_host_resolver_add_host_listener( - struct aws_host_resolver *resolver, - const struct aws_host_listener_options *options) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(resolver->vtable); - - if (resolver->vtable->add_host_listener) { - return resolver->vtable->add_host_listener(resolver, options); - } - - aws_raise_error(AWS_ERROR_UNSUPPORTED_OPERATION); - return NULL; -} - -int aws_host_resolver_remove_host_listener(struct aws_host_resolver *resolver, struct aws_host_listener *listener) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(resolver->vtable); - - if (resolver->vtable->remove_host_listener) { - return resolver->vtable->remove_host_listener(resolver, listener); - } - - aws_raise_error(AWS_ERROR_UNSUPPORTED_OPERATION); - return AWS_OP_ERR; -} - -/* - * Used by both the resolver for its lifetime state as well as individual host entries for theirs. - */ -enum default_resolver_state { - DRS_ACTIVE, - DRS_SHUTTING_DOWN, -}; - -struct default_host_resolver { - struct aws_allocator *allocator; - - /* - * Mutually exclusion for the whole resolver, includes all member data and all host_entry_table operations. Once - * an entry is retrieved, this lock MAY be dropped but certain logic may hold both the resolver and the entry lock. - * The two locks must be taken in that order. - */ - struct aws_mutex resolver_lock; - - /* host_name (aws_string*) -> host_entry* */ - struct aws_hash_table host_entry_table; - - /* Hash table of listener entries per host name. We keep this decoupled from the host entry table to allow for - * listeners to be added/removed regardless of whether or not a corresponding host entry exists. - * - * Any time the listener list in the listener entry becomes empty, we remove the entry from the table. This - * includes when a resolver thread moves all of the available listeners to its local list. - */ - /* host_name (aws_string*) -> host_listener_entry* */ - struct aws_hash_table listener_entry_table; - - enum default_resolver_state state; - - /* - * Tracks the number of launched resolution threads that have not yet invoked their shutdown completion - * callback. - */ - uint32_t pending_host_entry_shutdown_completion_callbacks; -}; - -/* Default host resolver implementation for listener. */ -struct host_listener { - - /* Reference to the host resolver that owns this listener */ - struct aws_host_resolver *resolver; - - /* String copy of the host name */ - struct aws_string *host_name; - - /* User-supplied callbacks/user_data */ - aws_host_listener_resolved_address_fn *resolved_address_callback; - aws_host_listener_shutdown_fn *shutdown_callback; - void *user_data; - - /* Synchronous data, requires host resolver lock to read/modify*/ - /* TODO Add a lock-synced-data function for the host resolver, replacing all current places where the host resolver - * mutex is locked. */ - struct host_listener_synced_data { - /* It's important that the node structure is always first, so that the HOST_LISTENER_FROM_SYNCED_NODE macro - * works properly.*/ - struct aws_linked_list_node node; - uint32_t owned_by_resolver_thread : 1; - uint32_t pending_destroy : 1; - } synced_data; - - /* Threaded data that can only be used in the resolver thread. */ - struct host_listener_threaded_data { - /* It's important that the node structure is always first, so that the HOST_LISTENER_FROM_THREADED_NODE macro - * works properly.*/ - struct aws_linked_list_node node; - } threaded_data; -}; - -/* AWS_CONTAINER_OF does not compile under Clang when using a member in a nested structure, ie, synced_data.node or - * threaded_data.node. To get around this, we define two local macros that rely on the node being the first member of - * the synced_data/threaded_data structures.*/ -#define HOST_LISTENER_FROM_SYNCED_NODE(listener_node) \ - AWS_CONTAINER_OF((listener_node), struct host_listener, synced_data) -#define HOST_LISTENER_FROM_THREADED_NODE(listener_node) \ - AWS_CONTAINER_OF((listener_node), struct host_listener, threaded_data) - -/* Structure for holding all listeners for a particular host name. */ -struct host_listener_entry { - struct default_host_resolver *resolver; - - /* Linked list of struct host_listener */ - struct aws_linked_list listeners; -}; - -struct host_entry { - /* immutable post-creation */ - struct aws_allocator *allocator; - struct aws_host_resolver *resolver; - struct aws_thread resolver_thread; - const struct aws_string *host_name; - int64_t resolve_frequency_ns; - struct aws_host_resolution_config resolution_config; - - /* synchronized data and its lock */ - struct aws_mutex entry_lock; - struct aws_condition_variable entry_signal; - struct aws_cache *aaaa_records; - struct aws_cache *a_records; - struct aws_cache *failed_connection_aaaa_records; - struct aws_cache *failed_connection_a_records; - struct aws_linked_list pending_resolution_callbacks; - uint32_t resolves_since_last_request; - uint64_t last_resolve_request_timestamp_ns; - enum default_resolver_state state; -}; - -static void s_shutdown_host_entry(struct host_entry *entry) { - aws_mutex_lock(&entry->entry_lock); - entry->state = DRS_SHUTTING_DOWN; - aws_mutex_unlock(&entry->entry_lock); -} - -static struct aws_host_listener *default_add_host_listener( - struct aws_host_resolver *host_resolver, - const struct aws_host_listener_options *options); - -static int default_remove_host_listener( - struct aws_host_resolver *host_resolver, - struct aws_host_listener *listener_opaque); - -static void s_host_listener_entry_destroy(void *listener_entry_void); - -static struct host_listener *s_pop_host_listener_from_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener_entry **in_out_listener_entry); - -static int s_add_host_listener_to_listener_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener *listener); - -static void s_remove_host_listener_from_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener *listener); - -static void s_host_listener_destroy(struct host_listener *listener); - -/* - * resolver lock must be held before calling this function - */ -static void s_clear_default_resolver_entry_table(struct default_host_resolver *resolver) { - struct aws_hash_table *table = &resolver->host_entry_table; - for (struct aws_hash_iter iter = aws_hash_iter_begin(table); !aws_hash_iter_done(&iter); - aws_hash_iter_next(&iter)) { - struct host_entry *entry = iter.element.value; - s_shutdown_host_entry(entry); - } - - aws_hash_table_clear(table); -} - -static int resolver_purge_cache(struct aws_host_resolver *resolver) { - struct default_host_resolver *default_host_resolver = resolver->impl; - aws_mutex_lock(&default_host_resolver->resolver_lock); - s_clear_default_resolver_entry_table(default_host_resolver); - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - return AWS_OP_SUCCESS; -} - -static void s_cleanup_default_resolver(struct aws_host_resolver *resolver) { - struct default_host_resolver *default_host_resolver = resolver->impl; - - aws_hash_table_clean_up(&default_host_resolver->host_entry_table); - aws_hash_table_clean_up(&default_host_resolver->listener_entry_table); - - aws_mutex_clean_up(&default_host_resolver->resolver_lock); - - aws_simple_completion_callback *shutdown_callback = resolver->shutdown_options.shutdown_callback_fn; - void *shutdown_completion_user_data = resolver->shutdown_options.shutdown_callback_user_data; - - aws_mem_release(resolver->allocator, resolver); - - /* invoke shutdown completion finally */ - if (shutdown_callback != NULL) { - shutdown_callback(shutdown_completion_user_data); - } - - aws_global_thread_creator_decrement(); -} - -static void resolver_destroy(struct aws_host_resolver *resolver) { - struct default_host_resolver *default_host_resolver = resolver->impl; - - bool cleanup_resolver = false; - - aws_mutex_lock(&default_host_resolver->resolver_lock); - - AWS_FATAL_ASSERT(default_host_resolver->state == DRS_ACTIVE); - - s_clear_default_resolver_entry_table(default_host_resolver); - default_host_resolver->state = DRS_SHUTTING_DOWN; - if (default_host_resolver->pending_host_entry_shutdown_completion_callbacks == 0) { - cleanup_resolver = true; - } - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - if (cleanup_resolver) { - s_cleanup_default_resolver(resolver); - } -} - -struct pending_callback { - aws_on_host_resolved_result_fn *callback; - void *user_data; - struct aws_linked_list_node node; -}; - -static void s_clean_up_host_entry(struct host_entry *entry) { - if (entry == NULL) { - return; - } - - /* - * This can happen if the resolver's final reference drops while an unanswered query is pending on an entry. - * - * You could add an assertion that the resolver is in the shut down state if this condition hits but that - * requires additional locking just to make the assert. - */ - if (!aws_linked_list_empty(&entry->pending_resolution_callbacks)) { - aws_raise_error(AWS_IO_DNS_HOST_REMOVED_FROM_CACHE); - } - - while (!aws_linked_list_empty(&entry->pending_resolution_callbacks)) { - struct aws_linked_list_node *resolution_callback_node = - aws_linked_list_pop_front(&entry->pending_resolution_callbacks); - struct pending_callback *pending_callback = - AWS_CONTAINER_OF(resolution_callback_node, struct pending_callback, node); - - pending_callback->callback( - entry->resolver, entry->host_name, AWS_IO_DNS_HOST_REMOVED_FROM_CACHE, NULL, pending_callback->user_data); - - aws_mem_release(entry->allocator, pending_callback); - } - - aws_cache_destroy(entry->aaaa_records); - aws_cache_destroy(entry->a_records); - aws_cache_destroy(entry->failed_connection_a_records); - aws_cache_destroy(entry->failed_connection_aaaa_records); - aws_string_destroy((void *)entry->host_name); - aws_mem_release(entry->allocator, entry); -} - -static void s_on_host_entry_shutdown_completion(void *user_data) { - struct host_entry *entry = user_data; - struct aws_host_resolver *resolver = entry->resolver; - struct default_host_resolver *default_host_resolver = resolver->impl; - - s_clean_up_host_entry(entry); - - bool cleanup_resolver = false; - - aws_mutex_lock(&default_host_resolver->resolver_lock); - --default_host_resolver->pending_host_entry_shutdown_completion_callbacks; - if (default_host_resolver->state == DRS_SHUTTING_DOWN && - default_host_resolver->pending_host_entry_shutdown_completion_callbacks == 0) { - cleanup_resolver = true; - } - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - if (cleanup_resolver) { - s_cleanup_default_resolver(resolver); - } -} - -/* this only ever gets called after resolution has already run. We expect that the entry's lock - has been acquired for writing before this function is called and released afterwards. */ -static inline void process_records( - struct aws_allocator *allocator, - struct aws_cache *records, - struct aws_cache *failed_records) { - uint64_t timestamp = 0; - aws_sys_clock_get_ticks(×tamp); - - size_t record_count = aws_cache_get_element_count(records); - size_t expired_records = 0; - - /* since this only ever gets called after resolution has already run, we're in a dns outage - * if everything is expired. Leave an element so we can keep trying. */ - for (size_t index = 0; index < record_count && expired_records < record_count - 1; ++index) { - struct aws_host_address *lru_element = aws_lru_cache_use_lru_element(records); - - if (lru_element->expiry < timestamp) { - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "static: purging expired record %s for %s", - lru_element->address->bytes, - lru_element->host->bytes); - expired_records++; - aws_cache_remove(records, lru_element->address); - } - } - - record_count = aws_cache_get_element_count(records); - AWS_LOGF_TRACE(AWS_LS_IO_DNS, "static: remaining record count for host %d", (int)record_count); - - /* if we don't have any known good addresses, take the least recently used, but not expired address with a history - * of spotty behavior and upgrade it for reuse. If it's expired, leave it and let the resolve fail. Better to fail - * than accidentally give a kids' app an IP address to somebody's adult website when the IP address gets rebound to - * a different endpoint. The moral of the story here is to not disable SSL verification! */ - if (!record_count) { - size_t failed_count = aws_cache_get_element_count(failed_records); - for (size_t index = 0; index < failed_count; ++index) { - struct aws_host_address *lru_element = aws_lru_cache_use_lru_element(failed_records); - - if (timestamp < lru_element->expiry) { - struct aws_host_address *to_add = aws_mem_acquire(allocator, sizeof(struct aws_host_address)); - - if (to_add && !aws_host_address_copy(lru_element, to_add)) { - AWS_LOGF_INFO( - AWS_LS_IO_DNS, - "static: promoting spotty record %s for %s back to good list", - lru_element->address->bytes, - lru_element->host->bytes); - if (aws_cache_put(records, to_add->address, to_add)) { - aws_mem_release(allocator, to_add); - continue; - } - /* we only want to promote one per process run.*/ - aws_cache_remove(failed_records, lru_element->address); - break; - } - - if (to_add) { - aws_mem_release(allocator, to_add); - } - } - } - } -} - -static int resolver_record_connection_failure(struct aws_host_resolver *resolver, struct aws_host_address *address) { - struct default_host_resolver *default_host_resolver = resolver->impl; - - AWS_LOGF_INFO( - AWS_LS_IO_DNS, - "id=%p: recording failure for record %s for %s, moving to bad list", - (void *)resolver, - address->address->bytes, - address->host->bytes); - - aws_mutex_lock(&default_host_resolver->resolver_lock); - - struct aws_hash_element *element = NULL; - if (aws_hash_table_find(&default_host_resolver->host_entry_table, address->host, &element)) { - aws_mutex_unlock(&default_host_resolver->resolver_lock); - return AWS_OP_ERR; - } - - struct host_entry *host_entry = NULL; - if (element != NULL) { - host_entry = element->value; - AWS_FATAL_ASSERT(host_entry); - } - - if (host_entry) { - struct aws_host_address *cached_address = NULL; - - aws_mutex_lock(&host_entry->entry_lock); - aws_mutex_unlock(&default_host_resolver->resolver_lock); - struct aws_cache *address_table = - address->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA ? host_entry->aaaa_records : host_entry->a_records; - - struct aws_cache *failed_table = address->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA - ? host_entry->failed_connection_aaaa_records - : host_entry->failed_connection_a_records; - - aws_cache_find(address_table, address->address, (void **)&cached_address); - - struct aws_host_address *address_copy = NULL; - if (cached_address) { - address_copy = aws_mem_acquire(resolver->allocator, sizeof(struct aws_host_address)); - - if (!address_copy || aws_host_address_copy(cached_address, address_copy)) { - goto error_host_entry_cleanup; - } - - if (aws_cache_remove(address_table, cached_address->address)) { - goto error_host_entry_cleanup; - } - - address_copy->connection_failure_count += 1; - - if (aws_cache_put(failed_table, address_copy->address, address_copy)) { - goto error_host_entry_cleanup; - } - } else { - if (aws_cache_find(failed_table, address->address, (void **)&cached_address)) { - goto error_host_entry_cleanup; - } - - if (cached_address) { - cached_address->connection_failure_count += 1; - } - } - aws_mutex_unlock(&host_entry->entry_lock); - return AWS_OP_SUCCESS; - - error_host_entry_cleanup: - if (address_copy) { - aws_host_address_clean_up(address_copy); - aws_mem_release(resolver->allocator, address_copy); - } - aws_mutex_unlock(&host_entry->entry_lock); - return AWS_OP_ERR; - } - - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - return AWS_OP_SUCCESS; -} - -/* - * A bunch of convenience functions for the host resolver background thread function - */ - -static struct aws_host_address *s_find_cached_address_aux( - struct aws_cache *primary_records, - struct aws_cache *fallback_records, - const struct aws_string *address) { - - struct aws_host_address *found = NULL; - aws_cache_find(primary_records, address, (void **)&found); - if (found == NULL) { - aws_cache_find(fallback_records, address, (void **)&found); - } - - return found; -} - -/* - * Looks in both the good and failed connection record sets for a given host record - */ -static struct aws_host_address *s_find_cached_address( - struct host_entry *entry, - const struct aws_string *address, - enum aws_address_record_type record_type) { - - switch (record_type) { - case AWS_ADDRESS_RECORD_TYPE_AAAA: - return s_find_cached_address_aux(entry->aaaa_records, entry->failed_connection_aaaa_records, address); - - case AWS_ADDRESS_RECORD_TYPE_A: - return s_find_cached_address_aux(entry->a_records, entry->failed_connection_a_records, address); - - default: - return NULL; - } -} - -static struct aws_host_address *s_get_lru_address_aux( - struct aws_cache *primary_records, - struct aws_cache *fallback_records) { - - struct aws_host_address *address = aws_lru_cache_use_lru_element(primary_records); - if (address == NULL) { - aws_lru_cache_use_lru_element(fallback_records); - } - - return address; -} - -/* - * Looks in both the good and failed connection record sets for the LRU host record - */ -static struct aws_host_address *s_get_lru_address(struct host_entry *entry, enum aws_address_record_type record_type) { - switch (record_type) { - case AWS_ADDRESS_RECORD_TYPE_AAAA: - return s_get_lru_address_aux(entry->aaaa_records, entry->failed_connection_aaaa_records); - - case AWS_ADDRESS_RECORD_TYPE_A: - return s_get_lru_address_aux(entry->a_records, entry->failed_connection_a_records); - - default: - return NULL; - } -} - -static void s_clear_address_list(struct aws_array_list *address_list) { - for (size_t i = 0; i < aws_array_list_length(address_list); ++i) { - struct aws_host_address *address = NULL; - aws_array_list_get_at_ptr(address_list, (void **)&address, i); - aws_host_address_clean_up(address); - } - - aws_array_list_clear(address_list); -} - -static void s_update_address_cache( - struct host_entry *host_entry, - struct aws_array_list *address_list, - uint64_t new_expiration, - struct aws_array_list *out_new_address_list) { - - AWS_PRECONDITION(host_entry); - AWS_PRECONDITION(address_list); - AWS_PRECONDITION(out_new_address_list); - - for (size_t i = 0; i < aws_array_list_length(address_list); ++i) { - struct aws_host_address *fresh_resolved_address = NULL; - aws_array_list_get_at_ptr(address_list, (void **)&fresh_resolved_address, i); - - struct aws_host_address *address_to_cache = - s_find_cached_address(host_entry, fresh_resolved_address->address, fresh_resolved_address->record_type); - - if (address_to_cache) { - address_to_cache->expiry = new_expiration; - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "static: updating expiry for %s for host %s to %llu", - address_to_cache->address->bytes, - host_entry->host_name->bytes, - (unsigned long long)new_expiration); - } else { - address_to_cache = aws_mem_acquire(host_entry->allocator, sizeof(struct aws_host_address)); - - aws_host_address_move(fresh_resolved_address, address_to_cache); - address_to_cache->expiry = new_expiration; - - struct aws_cache *address_table = address_to_cache->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA - ? host_entry->aaaa_records - : host_entry->a_records; - - if (aws_cache_put(address_table, address_to_cache->address, address_to_cache)) { - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, - "static: could not add new address to host entry cache for host '%s' in " - "s_update_address_cache.", - host_entry->host_name->bytes); - - continue; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "static: new address resolved %s for host %s caching", - address_to_cache->address->bytes, - host_entry->host_name->bytes); - - struct aws_host_address new_address_copy; - - if (aws_host_address_copy(address_to_cache, &new_address_copy)) { - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, - "static: could not copy address for new-address list for host '%s' in s_update_address_cache.", - host_entry->host_name->bytes); - - continue; - } - - if (aws_array_list_push_back(out_new_address_list, &new_address_copy)) { - aws_host_address_clean_up(&new_address_copy); - - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, - "static: could not push address to new-address list for host '%s' in s_update_address_cache.", - host_entry->host_name->bytes); - - continue; - } - } - } -} - -static void s_copy_address_into_callback_set( - struct aws_host_address *address, - struct aws_array_list *callback_addresses, - const struct aws_string *host_name) { - - if (address) { - address->use_count += 1; - - /* - * This is the worst. - * - * We have to copy the cache address while we still have a write lock. Otherwise, connection failures - * can sneak in and destroy our address by moving the address to/from the various lru caches. - * - * But there's no nice copy construction into an array list, so we get to - * (1) Push a zeroed dummy element onto the array list - * (2) Get its pointer - * (3) Call aws_host_address_copy onto it. If that fails, pop the dummy element. - */ - struct aws_host_address dummy; - AWS_ZERO_STRUCT(dummy); - - if (aws_array_list_push_back(callback_addresses, &dummy)) { - return; - } - - struct aws_host_address *dest_copy = NULL; - aws_array_list_get_at_ptr( - callback_addresses, (void **)&dest_copy, aws_array_list_length(callback_addresses) - 1); - AWS_FATAL_ASSERT(dest_copy != NULL); - - if (aws_host_address_copy(address, dest_copy)) { - aws_array_list_pop_back(callback_addresses); - return; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "static: vending address %s for host %s to caller", - address->address->bytes, - host_name->bytes); - } -} - -static bool s_host_entry_finished_pred(void *user_data) { - struct host_entry *entry = user_data; - - return entry->state == DRS_SHUTTING_DOWN; -} - -/* Move all of the listeners in the host-resolver-owned listener entry to the resolver thread owned list. */ -/* Assumes resolver_lock is held so that we can pop from the listener entry and access the listener's synced_data. */ -static void s_resolver_thread_move_listeners_from_listener_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct aws_linked_list *listener_list) { - - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - AWS_PRECONDITION(listener_list); - - struct host_listener_entry *listener_entry = NULL; - struct host_listener *listener = s_pop_host_listener_from_entry(resolver, host_name, &listener_entry); - - while (listener != NULL) { - /* Flag this listener as in-use by the resolver thread so that it can't be destroyed from outside of that - * thread. */ - listener->synced_data.owned_by_resolver_thread = true; - - aws_linked_list_push_back(listener_list, &listener->threaded_data.node); - - listener = s_pop_host_listener_from_entry(resolver, host_name, &listener_entry); - } -} - -/* When the thread is ready to exit, we move all of the listeners back to the host-resolver-owned listener entry.*/ -/* Assumes that we have already removed all pending_destroy listeners via - * s_resolver_thread_cull_pending_destroy_listeners. */ -/* Assumes resolver_lock is held so that we can write to the listener entry and read/write from the listener's - * synced_data. */ -static int s_resolver_thread_move_listeners_to_listener_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct aws_linked_list *listener_list) { - - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - AWS_PRECONDITION(listener_list); - - int result = 0; - size_t num_listeners_not_moved = 0; - - while (!aws_linked_list_empty(listener_list)) { - struct aws_linked_list_node *listener_node = aws_linked_list_pop_back(listener_list); - struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); - - /* Flag this listener as no longer in-use by the resolver thread. */ - listener->synced_data.owned_by_resolver_thread = false; - - AWS_ASSERT(!listener->synced_data.pending_destroy); - - if (s_add_host_listener_to_listener_entry(resolver, host_name, listener)) { - result = AWS_OP_ERR; - ++num_listeners_not_moved; - } - } - - if (result == AWS_OP_ERR) { - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, - "static: could not move %" PRIu64 " listeners back to listener entry", - (uint64_t)num_listeners_not_moved); - } - - return result; -} - -/* Remove the listeners from the resolver-thread-owned listener_list that are marked pending destroy, and move them into - * the destroy list. */ -/* Assumes resolver_lock is held. (This lock is necessary for reading from the listener's synced_data.) */ -static void s_resolver_thread_cull_pending_destroy_listeners( - struct aws_linked_list *listener_list, - struct aws_linked_list *listener_destroy_list) { - - AWS_PRECONDITION(listener_list); - AWS_PRECONDITION(listener_destroy_list); - - struct aws_linked_list_node *listener_node = aws_linked_list_begin(listener_list); - - /* Find all listeners in our current list that are marked for destroy. */ - while (listener_node != aws_linked_list_end(listener_list)) { - struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); - - /* Advance our node pointer early to allow for a removal. */ - listener_node = aws_linked_list_next(listener_node); - - /* If listener is pending destroy, remove it from the local list, and push it into the destroy list. */ - if (listener->synced_data.pending_destroy) { - aws_linked_list_remove(&listener->threaded_data.node); - aws_linked_list_push_back(listener_destroy_list, &listener->threaded_data.node); - } - } -} - -/* Destroys all of the listeners in the resolver thread's destroy list. */ -/* Assumes no lock is held. (We don't want any lock held so that any shutdown callbacks happen outside of a lock.) */ -static void s_resolver_thread_destroy_listeners(struct aws_linked_list *listener_destroy_list) { - - AWS_PRECONDITION(listener_destroy_list); - - while (!aws_linked_list_empty(listener_destroy_list)) { - struct aws_linked_list_node *listener_node = aws_linked_list_pop_back(listener_destroy_list); - struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); - s_host_listener_destroy(listener); - } -} - -/* Notify all listeners with resolve address callbacks, and also clean up any that are waiting to be cleaned up. */ -/* Assumes no lock is held. The listener_list is owned by the resolver thread, so no lock is necessary. We also don't - * want a lock held when calling the resolver-address callback.*/ -static void s_resolver_thread_notify_listeners( - const struct aws_array_list *new_address_list, - struct aws_linked_list *listener_list) { - - AWS_PRECONDITION(new_address_list); - AWS_PRECONDITION(listener_list); - - /* Go through each listener in our list. */ - for (struct aws_linked_list_node *listener_node = aws_linked_list_begin(listener_list); - listener_node != aws_linked_list_end(listener_list); - listener_node = aws_linked_list_next(listener_node)) { - struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); - - /* If we have new adddresses, notify the resolved-address callback if one exists */ - if (aws_array_list_length(new_address_list) > 0 && listener->resolved_address_callback != NULL) { - listener->resolved_address_callback( - (struct aws_host_listener *)listener, new_address_list, listener->user_data); - } - } -} - -static void resolver_thread_fn(void *arg) { - struct host_entry *host_entry = arg; - - size_t unsolicited_resolve_max = host_entry->resolution_config.max_ttl; - if (unsolicited_resolve_max == 0) { - unsolicited_resolve_max = 1; - } - - uint64_t max_no_solicitation_interval = - aws_timestamp_convert(unsolicited_resolve_max, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL); - - struct aws_array_list address_list; - if (aws_array_list_init_dynamic(&address_list, host_entry->allocator, 4, sizeof(struct aws_host_address))) { - return; - } - - struct aws_array_list new_address_list; - if (aws_array_list_init_dynamic(&new_address_list, host_entry->allocator, 4, sizeof(struct aws_host_address))) { - aws_array_list_clean_up(&address_list); - return; - } - - struct aws_linked_list listener_list; - aws_linked_list_init(&listener_list); - - struct aws_linked_list listener_destroy_list; - aws_linked_list_init(&listener_destroy_list); - - bool keep_going = true; - while (keep_going) { - - AWS_LOGF_TRACE(AWS_LS_IO_DNS, "static, resolving %s", aws_string_c_str(host_entry->host_name)); - - /* resolve and then process each record */ - int err_code = AWS_ERROR_SUCCESS; - if (host_entry->resolution_config.impl( - host_entry->allocator, host_entry->host_name, &address_list, host_entry->resolution_config.impl_data)) { - - err_code = aws_last_error(); - } - uint64_t timestamp = 0; - aws_sys_clock_get_ticks(×tamp); - uint64_t new_expiry = timestamp + (host_entry->resolution_config.max_ttl * NS_PER_SEC); - - struct aws_linked_list pending_resolve_copy; - aws_linked_list_init(&pending_resolve_copy); - - /* - * Within the lock we - * (1) Update the cache with the newly resolved addresses - * (2) Process all held addresses looking for expired or promotable ones - * (3) Prep for callback invocations - */ - aws_mutex_lock(&host_entry->entry_lock); - - if (!err_code) { - s_update_address_cache(host_entry, &address_list, new_expiry, &new_address_list); - } - - /* - * process and clean_up records in the entry. occasionally, failed connect records will be upgraded - * for retry. - */ - process_records(host_entry->allocator, host_entry->aaaa_records, host_entry->failed_connection_aaaa_records); - process_records(host_entry->allocator, host_entry->a_records, host_entry->failed_connection_a_records); - - aws_linked_list_swap_contents(&pending_resolve_copy, &host_entry->pending_resolution_callbacks); - - aws_mutex_unlock(&host_entry->entry_lock); - - /* - * Clean up resolved addressed outside of the lock - */ - s_clear_address_list(&address_list); - - struct aws_host_address address_array[2]; - AWS_ZERO_ARRAY(address_array); - - /* - * Perform the actual subscriber notifications - */ - while (!aws_linked_list_empty(&pending_resolve_copy)) { - struct aws_linked_list_node *resolution_callback_node = aws_linked_list_pop_front(&pending_resolve_copy); - struct pending_callback *pending_callback = - AWS_CONTAINER_OF(resolution_callback_node, struct pending_callback, node); - - struct aws_array_list callback_address_list; - aws_array_list_init_static(&callback_address_list, address_array, 2, sizeof(struct aws_host_address)); - - aws_mutex_lock(&host_entry->entry_lock); - s_copy_address_into_callback_set( - s_get_lru_address(host_entry, AWS_ADDRESS_RECORD_TYPE_AAAA), - &callback_address_list, - host_entry->host_name); - s_copy_address_into_callback_set( - s_get_lru_address(host_entry, AWS_ADDRESS_RECORD_TYPE_A), - &callback_address_list, - host_entry->host_name); - aws_mutex_unlock(&host_entry->entry_lock); - - AWS_ASSERT(err_code != AWS_ERROR_SUCCESS || aws_array_list_length(&callback_address_list) > 0); - - if (aws_array_list_length(&callback_address_list) > 0) { - pending_callback->callback( - host_entry->resolver, - host_entry->host_name, - AWS_OP_SUCCESS, - &callback_address_list, - pending_callback->user_data); - - } else { - pending_callback->callback( - host_entry->resolver, host_entry->host_name, err_code, NULL, pending_callback->user_data); - } - - s_clear_address_list(&callback_address_list); - - aws_mem_release(host_entry->allocator, pending_callback); - } - - aws_mutex_lock(&host_entry->entry_lock); - - ++host_entry->resolves_since_last_request; - - /* wait for a quit notification or the base resolve frequency time interval */ - aws_condition_variable_wait_for_pred( - &host_entry->entry_signal, - &host_entry->entry_lock, - host_entry->resolve_frequency_ns, - s_host_entry_finished_pred, - host_entry); - - aws_mutex_unlock(&host_entry->entry_lock); - - /* - * This is a bit awkward that we unlock the entry and then relock both the resolver and the entry, but it - * is mandatory that -- in order to maintain the consistent view of the resolver table (entry exist => entry - * is alive and can be queried) -- we have the resolver lock as well before making the decision to remove - * the entry from the table and terminate the thread. - */ - struct default_host_resolver *resolver = host_entry->resolver->impl; - aws_mutex_lock(&resolver->resolver_lock); - - /* Remove any listeners from our listener list that have been marked pending destroy, moving them into the - * destroy list. */ - s_resolver_thread_cull_pending_destroy_listeners(&listener_list, &listener_destroy_list); - - /* Grab any listeners on the listener entry, moving them into the local list. */ - s_resolver_thread_move_listeners_from_listener_entry(resolver, host_entry->host_name, &listener_list); - - aws_mutex_lock(&host_entry->entry_lock); - - uint64_t now = 0; - aws_sys_clock_get_ticks(&now); - - /* - * Ideally this should just be time-based, but given the non-determinism of waits (and spurious wake ups) and - * clock time, I feel much more comfortable keeping an additional constraint in terms of iterations. - * - * Note that we have the entry lock now and if any queries have arrived since our last resolution, - * resolves_since_last_request will be 0 or 1 (depending on timing) and so, regardless of wait and wake up - * timings, this check will always fail in that case leading to another iteration to satisfy the pending - * query(ies). - * - * The only way we terminate the loop with pending queries is if the resolver itself has no more references - * to it and is going away. In that case, the pending queries will be completed (with failure) by the - * final clean up of this entry. - */ - if (host_entry->resolves_since_last_request > unsolicited_resolve_max && - host_entry->last_resolve_request_timestamp_ns + max_no_solicitation_interval < now) { - host_entry->state = DRS_SHUTTING_DOWN; - } - - keep_going = host_entry->state == DRS_ACTIVE; - if (!keep_going) { - aws_hash_table_remove(&resolver->host_entry_table, host_entry->host_name, NULL, NULL); - - /* Move any local listeners we have back to the listener entry */ - if (s_resolver_thread_move_listeners_to_listener_entry(resolver, host_entry->host_name, &listener_list)) { - AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: could not clean up all listeners from resolver thread."); - } - } - - aws_mutex_unlock(&host_entry->entry_lock); - aws_mutex_unlock(&resolver->resolver_lock); - - /* Destroy any listeners in our destroy list. */ - s_resolver_thread_destroy_listeners(&listener_destroy_list); - - /* Notify our local listeners of new addresses. */ - s_resolver_thread_notify_listeners(&new_address_list, &listener_list); - - s_clear_address_list(&new_address_list); - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "static: Either no requests have been made for an address for %s for the duration " - "of the ttl, or this thread is being forcibly shutdown. Killing thread.", - host_entry->host_name->bytes) - - aws_array_list_clean_up(&address_list); - aws_array_list_clean_up(&new_address_list); - - /* please don't fail */ - aws_thread_current_at_exit(s_on_host_entry_shutdown_completion, host_entry); -} - -static void on_address_value_removed(void *value) { - struct aws_host_address *host_address = value; - - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "static: purging address %s for host %s from " - "the cache due to cache eviction or shutdown", - host_address->address->bytes, - host_address->host->bytes); - - struct aws_allocator *allocator = host_address->allocator; - aws_host_address_clean_up(host_address); - aws_mem_release(allocator, host_address); -} - -/* - * The resolver lock must be held before calling this function - */ -static inline int create_and_init_host_entry( - struct aws_host_resolver *resolver, - const struct aws_string *host_name, - aws_on_host_resolved_result_fn *res, - struct aws_host_resolution_config *config, - uint64_t timestamp, - void *user_data) { - struct host_entry *new_host_entry = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_entry)); - if (!new_host_entry) { - return AWS_OP_ERR; - } - - new_host_entry->resolver = resolver; - new_host_entry->allocator = resolver->allocator; - new_host_entry->last_resolve_request_timestamp_ns = timestamp; - new_host_entry->resolves_since_last_request = 0; - new_host_entry->resolve_frequency_ns = NS_PER_SEC; - new_host_entry->state = DRS_ACTIVE; - - bool thread_init = false; - struct pending_callback *pending_callback = NULL; - const struct aws_string *host_string_copy = aws_string_new_from_string(resolver->allocator, host_name); - if (AWS_UNLIKELY(!host_string_copy)) { - goto setup_host_entry_error; - } - - new_host_entry->host_name = host_string_copy; - new_host_entry->a_records = aws_cache_new_lru( - new_host_entry->allocator, - aws_hash_string, - aws_hash_callback_string_eq, - NULL, - on_address_value_removed, - config->max_ttl); - if (AWS_UNLIKELY(!new_host_entry->a_records)) { - goto setup_host_entry_error; - } - - new_host_entry->aaaa_records = aws_cache_new_lru( - new_host_entry->allocator, - aws_hash_string, - aws_hash_callback_string_eq, - NULL, - on_address_value_removed, - config->max_ttl); - if (AWS_UNLIKELY(!new_host_entry->aaaa_records)) { - goto setup_host_entry_error; - } - - new_host_entry->failed_connection_a_records = aws_cache_new_lru( - new_host_entry->allocator, - aws_hash_string, - aws_hash_callback_string_eq, - NULL, - on_address_value_removed, - config->max_ttl); - if (AWS_UNLIKELY(!new_host_entry->failed_connection_a_records)) { - goto setup_host_entry_error; - } - - new_host_entry->failed_connection_aaaa_records = aws_cache_new_lru( - new_host_entry->allocator, - aws_hash_string, - aws_hash_callback_string_eq, - NULL, - on_address_value_removed, - config->max_ttl); - if (AWS_UNLIKELY(!new_host_entry->failed_connection_aaaa_records)) { - goto setup_host_entry_error; - } - - aws_linked_list_init(&new_host_entry->pending_resolution_callbacks); - - pending_callback = aws_mem_acquire(resolver->allocator, sizeof(struct pending_callback)); - - if (AWS_UNLIKELY(!pending_callback)) { - goto setup_host_entry_error; - } - - /*add the current callback here */ - pending_callback->user_data = user_data; - pending_callback->callback = res; - aws_linked_list_push_back(&new_host_entry->pending_resolution_callbacks, &pending_callback->node); - - aws_mutex_init(&new_host_entry->entry_lock); - new_host_entry->resolution_config = *config; - aws_condition_variable_init(&new_host_entry->entry_signal); - - if (aws_thread_init(&new_host_entry->resolver_thread, resolver->allocator)) { - goto setup_host_entry_error; - } - - thread_init = true; - struct default_host_resolver *default_host_resolver = resolver->impl; - if (AWS_UNLIKELY( - aws_hash_table_put(&default_host_resolver->host_entry_table, host_string_copy, new_host_entry, NULL))) { - goto setup_host_entry_error; - } - - aws_thread_launch(&new_host_entry->resolver_thread, resolver_thread_fn, new_host_entry, NULL); - ++default_host_resolver->pending_host_entry_shutdown_completion_callbacks; - - return AWS_OP_SUCCESS; - -setup_host_entry_error: - - if (thread_init) { - aws_thread_clean_up(&new_host_entry->resolver_thread); - } - - s_clean_up_host_entry(new_host_entry); - - return AWS_OP_ERR; -} - -static int default_resolve_host( - struct aws_host_resolver *resolver, - const struct aws_string *host_name, - aws_on_host_resolved_result_fn *res, - struct aws_host_resolution_config *config, - void *user_data) { - int result = AWS_OP_SUCCESS; - - AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "id=%p: Host resolution requested for %s", (void *)resolver, host_name->bytes); - - uint64_t timestamp = 0; - aws_sys_clock_get_ticks(×tamp); - - struct default_host_resolver *default_host_resolver = resolver->impl; - aws_mutex_lock(&default_host_resolver->resolver_lock); - - struct aws_hash_element *element = NULL; - /* we don't care about the error code here, only that the host_entry was found or not. */ - aws_hash_table_find(&default_host_resolver->host_entry_table, host_name, &element); - - struct host_entry *host_entry = NULL; - if (element != NULL) { - host_entry = element->value; - AWS_FATAL_ASSERT(host_entry != NULL); - } - - if (!host_entry) { - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "id=%p: No cached entries found for %s starting new resolver thread.", - (void *)resolver, - host_name->bytes); - - result = create_and_init_host_entry(resolver, host_name, res, config, timestamp, user_data); - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - return result; - } - - aws_mutex_lock(&host_entry->entry_lock); - - /* - * We don't need to make any resolver side-affects in the remaining logic and it's impossible for the entry - * to disappear underneath us while holding its lock, so its safe to release the resolver lock and let other - * things query other entries. - */ - aws_mutex_unlock(&default_host_resolver->resolver_lock); - host_entry->last_resolve_request_timestamp_ns = timestamp; - host_entry->resolves_since_last_request = 0; - - struct aws_host_address *aaaa_record = aws_lru_cache_use_lru_element(host_entry->aaaa_records); - struct aws_host_address *a_record = aws_lru_cache_use_lru_element(host_entry->a_records); - struct aws_host_address address_array[2]; - AWS_ZERO_ARRAY(address_array); - struct aws_array_list callback_address_list; - aws_array_list_init_static(&callback_address_list, address_array, 2, sizeof(struct aws_host_address)); - - if ((aaaa_record || a_record)) { - AWS_LOGF_DEBUG( - AWS_LS_IO_DNS, - "id=%p: cached entries found for %s returning to caller.", - (void *)resolver, - host_name->bytes); - - /* these will all need to be copied so that we don't hold the lock during the callback. */ - if (aaaa_record) { - struct aws_host_address aaaa_record_cpy; - if (!aws_host_address_copy(aaaa_record, &aaaa_record_cpy)) { - aws_array_list_push_back(&callback_address_list, &aaaa_record_cpy); - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "id=%p: vending address %s for host %s to caller", - (void *)resolver, - aaaa_record->address->bytes, - host_entry->host_name->bytes); - } - } - if (a_record) { - struct aws_host_address a_record_cpy; - if (!aws_host_address_copy(a_record, &a_record_cpy)) { - aws_array_list_push_back(&callback_address_list, &a_record_cpy); - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "id=%p: vending address %s for host %s to caller", - (void *)resolver, - a_record->address->bytes, - host_entry->host_name->bytes); - } - } - aws_mutex_unlock(&host_entry->entry_lock); - - /* we don't want to do the callback WHILE we hold the lock someone may reentrantly call us. */ - if (aws_array_list_length(&callback_address_list)) { - res(resolver, host_name, AWS_OP_SUCCESS, &callback_address_list, user_data); - } else { - res(resolver, host_name, aws_last_error(), NULL, user_data); - result = AWS_OP_ERR; - } - - for (size_t i = 0; i < aws_array_list_length(&callback_address_list); ++i) { - struct aws_host_address *address_ptr = NULL; - aws_array_list_get_at_ptr(&callback_address_list, (void **)&address_ptr, i); - aws_host_address_clean_up(address_ptr); - } - - aws_array_list_clean_up(&callback_address_list); - - return result; - } - - struct pending_callback *pending_callback = - aws_mem_acquire(default_host_resolver->allocator, sizeof(struct pending_callback)); - if (pending_callback != NULL) { - pending_callback->user_data = user_data; - pending_callback->callback = res; - aws_linked_list_push_back(&host_entry->pending_resolution_callbacks, &pending_callback->node); - } else { - result = AWS_OP_ERR; - } - - aws_mutex_unlock(&host_entry->entry_lock); - - return result; -} - -static size_t default_get_host_address_count( - struct aws_host_resolver *host_resolver, - const struct aws_string *host_name, - uint32_t flags) { - struct default_host_resolver *default_host_resolver = host_resolver->impl; - size_t address_count = 0; - - aws_mutex_lock(&default_host_resolver->resolver_lock); - - struct aws_hash_element *element = NULL; - aws_hash_table_find(&default_host_resolver->host_entry_table, host_name, &element); - if (element != NULL) { - struct host_entry *host_entry = element->value; - if (host_entry != NULL) { - aws_mutex_lock(&host_entry->entry_lock); - - if ((flags & AWS_GET_HOST_ADDRESS_COUNT_RECORD_TYPE_A) != 0) { - address_count += aws_cache_get_element_count(host_entry->a_records); - } - - if ((flags & AWS_GET_HOST_ADDRESS_COUNT_RECORD_TYPE_AAAA) != 0) { - address_count += aws_cache_get_element_count(host_entry->aaaa_records); - } - - aws_mutex_unlock(&host_entry->entry_lock); - } - } - - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - return address_count; -} - -static struct aws_host_resolver_vtable s_vtable = { - .purge_cache = resolver_purge_cache, - .resolve_host = default_resolve_host, - .record_connection_failure = resolver_record_connection_failure, - .get_host_address_count = default_get_host_address_count, - .add_host_listener = default_add_host_listener, - .remove_host_listener = default_remove_host_listener, - .destroy = resolver_destroy, -}; - -static void s_aws_host_resolver_destroy(struct aws_host_resolver *resolver) { - AWS_ASSERT(resolver->vtable && resolver->vtable->destroy); - resolver->vtable->destroy(resolver); -} - -struct aws_host_resolver *aws_host_resolver_new_default( - struct aws_allocator *allocator, - size_t max_entries, - struct aws_event_loop_group *el_group, - const struct aws_shutdown_callback_options *shutdown_options) { - /* NOTE: we don't use el_group yet, but we will in the future. Also, we - don't want host resolvers getting cleaned up after el_groups; this will force that - in bindings, and encourage it in C land. */ - (void)el_group; - AWS_ASSERT(el_group); - - struct aws_host_resolver *resolver = NULL; - struct default_host_resolver *default_host_resolver = NULL; - if (!aws_mem_acquire_many( - allocator, - 2, - &resolver, - sizeof(struct aws_host_resolver), - &default_host_resolver, - sizeof(struct default_host_resolver))) { - return NULL; - } - - AWS_ZERO_STRUCT(*resolver); - AWS_ZERO_STRUCT(*default_host_resolver); - - AWS_LOGF_INFO( - AWS_LS_IO_DNS, - "id=%p: Initializing default host resolver with %llu max host entries.", - (void *)resolver, - (unsigned long long)max_entries); - - resolver->vtable = &s_vtable; - resolver->allocator = allocator; - resolver->impl = default_host_resolver; - - default_host_resolver->allocator = allocator; - default_host_resolver->pending_host_entry_shutdown_completion_callbacks = 0; - default_host_resolver->state = DRS_ACTIVE; - aws_mutex_init(&default_host_resolver->resolver_lock); - - aws_global_thread_creator_increment(); - - if (aws_hash_table_init( - &default_host_resolver->host_entry_table, - allocator, - max_entries, - aws_hash_string, - aws_hash_callback_string_eq, - NULL, - NULL)) { - goto on_error; - } - - if (aws_hash_table_init( - &default_host_resolver->listener_entry_table, - allocator, - max_entries, - aws_hash_string, - aws_hash_callback_string_eq, - aws_hash_callback_string_destroy, - s_host_listener_entry_destroy)) { - goto on_error; - } - - aws_ref_count_init(&resolver->ref_count, resolver, (aws_simple_completion_callback *)s_aws_host_resolver_destroy); - - if (shutdown_options != NULL) { - resolver->shutdown_options = *shutdown_options; - } - - return resolver; - -on_error: - - s_cleanup_default_resolver(resolver); - - return NULL; -} - -struct aws_host_resolver *aws_host_resolver_acquire(struct aws_host_resolver *resolver) { - if (resolver != NULL) { - aws_ref_count_acquire(&resolver->ref_count); - } - - return resolver; -} - -void aws_host_resolver_release(struct aws_host_resolver *resolver) { - if (resolver != NULL) { - aws_ref_count_release(&resolver->ref_count); - } -} - -size_t aws_host_resolver_get_host_address_count( - struct aws_host_resolver *resolver, - const struct aws_string *host_name, - uint32_t flags) { - return resolver->vtable->get_host_address_count(resolver, host_name, flags); -} - -enum find_listener_entry_flags { - FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND = 0x00000001, -}; - -static struct host_listener_entry *s_find_host_listener_entry( - struct default_host_resolver *default_resolver, - const struct aws_string *host_name, - uint32_t flags); - -static struct aws_host_listener *default_add_host_listener( - struct aws_host_resolver *resolver, - const struct aws_host_listener_options *options) { - AWS_PRECONDITION(resolver); - - if (options == NULL) { - AWS_LOGF_ERROR(AWS_LS_IO_DNS, "Cannot create host resolver listener; options structure is NULL."); - aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - return NULL; - } - - if (options->host_name.len == 0) { - AWS_LOGF_ERROR(AWS_LS_IO_DNS, "Cannot create host resolver listener; invalid host name specified."); - aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - return NULL; - } - - /* Allocate and set up the listener. */ - struct host_listener *listener = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_listener)); - - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "id=%p Adding listener %p for host name %s", - (void *)resolver, - (void *)listener, - (const char *)options->host_name.ptr); - - aws_host_resolver_acquire(resolver); - listener->resolver = resolver; - listener->host_name = aws_string_new_from_cursor(resolver->allocator, &options->host_name); - listener->resolved_address_callback = options->resolved_address_callback; - listener->shutdown_callback = options->shutdown_callback; - listener->user_data = options->user_data; - - struct default_host_resolver *default_host_resolver = resolver->impl; - - /* Add the listener to a host listener entry in the host listener entry table. */ - aws_mutex_lock(&default_host_resolver->resolver_lock); - - if (s_add_host_listener_to_listener_entry(default_host_resolver, listener->host_name, listener)) { - aws_mem_release(resolver->allocator, listener); - listener = NULL; - } - - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - return (struct aws_host_listener *)listener; -} - -static int default_remove_host_listener( - struct aws_host_resolver *host_resolver, - struct aws_host_listener *listener_opaque) { - AWS_PRECONDITION(host_resolver); - AWS_PRECONDITION(listener_opaque); - - struct host_listener *listener = (struct host_listener *)listener_opaque; - struct default_host_resolver *default_host_resolver = host_resolver->impl; - - if (listener->resolver != host_resolver) { - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, - "id=%p Trying to remove listener from incorrect host resolver. Listener belongs to host resolver %p", - (void *)host_resolver, - (void *)listener->resolver); - aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - return AWS_OP_ERR; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_DNS, - "id=%p Removing listener %p for host name %s", - (void *)host_resolver, - (void *)listener, - (const char *)listener->host_name->bytes); - - bool destroy_listener_immediate = false; - - aws_mutex_lock(&default_host_resolver->resolver_lock); - - /* If owned by the resolver thread, flag the listener as pending destroy, so that resolver thread knows to destroy - * it. */ - if (listener->synced_data.owned_by_resolver_thread) { - listener->synced_data.pending_destroy = true; - } else { - /* Else, remove the listener from the listener entry and clean it up once outside of the mutex. */ - s_remove_host_listener_from_entry(default_host_resolver, listener->host_name, listener); - destroy_listener_immediate = true; - } - - aws_mutex_unlock(&default_host_resolver->resolver_lock); - - if (destroy_listener_immediate) { - s_host_listener_destroy(listener); - } - - return AWS_OP_SUCCESS; -} - -/* Find listener entry on the host resolver, optionally creating it if it doesn't exist. */ -/* Assumes host resolver lock is held. */ -static struct host_listener_entry *s_find_host_listener_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - uint32_t flags) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - - struct host_listener_entry *listener_entry = NULL; - struct aws_string *host_string_copy = NULL; - - struct aws_hash_element *listener_entry_hash_element = NULL; - bool create_if_not_found = (flags & FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND) != 0; - - if (aws_hash_table_find(&resolver->listener_entry_table, host_name, &listener_entry_hash_element)) { - AWS_LOGF_ERROR( - AWS_LS_IO_DNS, "static: error when trying to find a listener entry in the listener entry table."); - goto error_clean_up; - } - - if (listener_entry_hash_element != NULL) { - listener_entry = listener_entry_hash_element->value; - AWS_FATAL_ASSERT(listener_entry); - } else if (create_if_not_found) { - - listener_entry = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_listener_entry)); - listener_entry->resolver = resolver; - aws_linked_list_init(&listener_entry->listeners); - - host_string_copy = aws_string_new_from_string(resolver->allocator, host_name); - - if (aws_hash_table_put(&resolver->listener_entry_table, host_string_copy, listener_entry, NULL)) { - AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: could not put new listener entry into listener entry table."); - goto error_clean_up; - } - } - - return listener_entry; - -error_clean_up: - - s_host_listener_entry_destroy(listener_entry); - - aws_string_destroy(host_string_copy); - - return NULL; -} - -/* Destroy function for listener entries. Takes a void* so that it can be used by the listener entry hash table. */ -static void s_host_listener_entry_destroy(void *listener_entry_void) { - if (listener_entry_void == NULL) { - return; - } - - struct host_listener_entry *listener_entry = listener_entry_void; - struct default_host_resolver *resolver = listener_entry->resolver; - - aws_mem_release(resolver->allocator, listener_entry); -} - -/* Add a listener to the relevant host listener entry. */ -/* Assumes host resolver lock is held. */ -static int s_add_host_listener_to_listener_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener *listener) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - AWS_PRECONDITION(listener); - - struct host_listener_entry *listener_entry = - s_find_host_listener_entry(resolver, host_name, FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND); - - if (listener_entry == NULL) { - return AWS_OP_ERR; - } - - aws_linked_list_push_back(&listener_entry->listeners, &listener->synced_data.node); - return AWS_OP_SUCCESS; -} - -/* Assumes host resolver lock is held. */ -static struct host_listener *s_pop_host_listener_from_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener_entry **in_out_listener_entry) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - - struct host_listener_entry *listener_entry = NULL; - - if (in_out_listener_entry) { - listener_entry = *in_out_listener_entry; - } - - if (listener_entry == NULL) { - listener_entry = s_find_host_listener_entry(resolver, host_name, 0); - - if (listener_entry == NULL) { - return NULL; - } - } - - /* We should never have a listener entry without any listeners. Whenever a listener entry has no listeners, it - * should be cleaned up immediately. */ - AWS_ASSERT(!aws_linked_list_empty(&listener_entry->listeners)); - - struct aws_linked_list_node *node = aws_linked_list_pop_back(&listener_entry->listeners); - - struct host_listener *listener = HOST_LISTENER_FROM_SYNCED_NODE(node); - AWS_FATAL_ASSERT(listener); - - /* If the listener list on the listener entry is now empty, remove it. */ - if (aws_linked_list_empty(&listener_entry->listeners)) { - aws_hash_table_remove(&resolver->listener_entry_table, host_name, NULL, NULL); - listener_entry = NULL; - } - - if (in_out_listener_entry) { - *in_out_listener_entry = listener_entry; - } - - return listener; -} - -/* Assumes host resolver lock is held. */ -static void s_remove_host_listener_from_entry( - struct default_host_resolver *resolver, - const struct aws_string *host_name, - struct host_listener *listener) { - AWS_PRECONDITION(resolver); - AWS_PRECONDITION(host_name); - AWS_PRECONDITION(listener); - - struct host_listener_entry *listener_entry = s_find_host_listener_entry(resolver, host_name, 0); - - if (listener_entry == NULL) { - AWS_LOGF_WARN(AWS_LS_IO_DNS, "id=%p: Could not find listener entry for listener.", (void *)listener); - return; - } - - /* We should never have a listener entry without any listeners. Whenever a listener entry has no listeners, it - * should be cleaned up immediately. */ - AWS_ASSERT(!aws_linked_list_empty(&listener_entry->listeners)); - - aws_linked_list_remove(&listener->synced_data.node); - - /* If the listener list on the listener entry is now empty, remove it. */ - if (aws_linked_list_empty(&listener_entry->listeners)) { - aws_hash_table_remove(&resolver->listener_entry_table, host_name, NULL, NULL); - } -} - -/* Finish destroying a default resolver listener, releasing any remaining memory for it and triggering its shutdown - * callack. Since a shutdown callback is triggered, no lock should be held when calling this function. */ -static void s_host_listener_destroy(struct host_listener *listener) { - if (listener == NULL) { - return; - } - - AWS_LOGF_TRACE(AWS_LS_IO_DNS, "id=%p: Finishing clean up of host listener.", (void *)listener); - - struct aws_host_resolver *host_resolver = listener->resolver; - - aws_host_listener_shutdown_fn *shutdown_callback = listener->shutdown_callback; - void *shutdown_user_data = listener->user_data; - - aws_string_destroy(listener->host_name); - listener->host_name = NULL; - - aws_mem_release(host_resolver->allocator, listener); - listener = NULL; - - if (shutdown_callback != NULL) { - shutdown_callback(shutdown_user_data); - } - - if (host_resolver != NULL) { - aws_host_resolver_release(host_resolver); - host_resolver = NULL; - } -} - -#undef HOST_LISTENER_FROM_SYNCED_NODE -#undef HOST_LISTENER_FROM_THREADED_NODE +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/host_resolver.h> + +#include <aws/common/atomics.h> +#include <aws/common/clock.h> +#include <aws/common/condition_variable.h> +#include <aws/common/hash_table.h> +#include <aws/common/lru_cache.h> +#include <aws/common/mutex.h> +#include <aws/common/string.h> +#include <aws/common/thread.h> + +#include <aws/io/logging.h> + +#include <inttypes.h> + +const uint64_t NS_PER_SEC = 1000000000; + +int aws_host_address_copy(const struct aws_host_address *from, struct aws_host_address *to) { + to->allocator = from->allocator; + to->address = aws_string_new_from_string(to->allocator, from->address); + + if (!to->address) { + return AWS_OP_ERR; + } + + to->host = aws_string_new_from_string(to->allocator, from->host); + + if (!to->host) { + aws_string_destroy((void *)to->address); + return AWS_OP_ERR; + } + + to->record_type = from->record_type; + to->use_count = from->use_count; + to->connection_failure_count = from->connection_failure_count; + to->expiry = from->expiry; + to->weight = from->weight; + + return AWS_OP_SUCCESS; +} + +void aws_host_address_move(struct aws_host_address *from, struct aws_host_address *to) { + to->allocator = from->allocator; + to->address = from->address; + to->host = from->host; + to->record_type = from->record_type; + to->use_count = from->use_count; + to->connection_failure_count = from->connection_failure_count; + to->expiry = from->expiry; + to->weight = from->weight; + AWS_ZERO_STRUCT(*from); +} + +void aws_host_address_clean_up(struct aws_host_address *address) { + if (address->address) { + aws_string_destroy((void *)address->address); + } + if (address->host) { + aws_string_destroy((void *)address->host); + } + AWS_ZERO_STRUCT(*address); +} + +int aws_host_resolver_resolve_host( + struct aws_host_resolver *resolver, + const struct aws_string *host_name, + aws_on_host_resolved_result_fn *res, + struct aws_host_resolution_config *config, + void *user_data) { + AWS_ASSERT(resolver->vtable && resolver->vtable->resolve_host); + return resolver->vtable->resolve_host(resolver, host_name, res, config, user_data); +} + +int aws_host_resolver_purge_cache(struct aws_host_resolver *resolver) { + AWS_ASSERT(resolver->vtable && resolver->vtable->purge_cache); + return resolver->vtable->purge_cache(resolver); +} + +int aws_host_resolver_record_connection_failure(struct aws_host_resolver *resolver, struct aws_host_address *address) { + AWS_ASSERT(resolver->vtable && resolver->vtable->record_connection_failure); + return resolver->vtable->record_connection_failure(resolver, address); +} + +struct aws_host_listener *aws_host_resolver_add_host_listener( + struct aws_host_resolver *resolver, + const struct aws_host_listener_options *options) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(resolver->vtable); + + if (resolver->vtable->add_host_listener) { + return resolver->vtable->add_host_listener(resolver, options); + } + + aws_raise_error(AWS_ERROR_UNSUPPORTED_OPERATION); + return NULL; +} + +int aws_host_resolver_remove_host_listener(struct aws_host_resolver *resolver, struct aws_host_listener *listener) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(resolver->vtable); + + if (resolver->vtable->remove_host_listener) { + return resolver->vtable->remove_host_listener(resolver, listener); + } + + aws_raise_error(AWS_ERROR_UNSUPPORTED_OPERATION); + return AWS_OP_ERR; +} + +/* + * Used by both the resolver for its lifetime state as well as individual host entries for theirs. + */ +enum default_resolver_state { + DRS_ACTIVE, + DRS_SHUTTING_DOWN, +}; + +struct default_host_resolver { + struct aws_allocator *allocator; + + /* + * Mutually exclusion for the whole resolver, includes all member data and all host_entry_table operations. Once + * an entry is retrieved, this lock MAY be dropped but certain logic may hold both the resolver and the entry lock. + * The two locks must be taken in that order. + */ + struct aws_mutex resolver_lock; + + /* host_name (aws_string*) -> host_entry* */ + struct aws_hash_table host_entry_table; + + /* Hash table of listener entries per host name. We keep this decoupled from the host entry table to allow for + * listeners to be added/removed regardless of whether or not a corresponding host entry exists. + * + * Any time the listener list in the listener entry becomes empty, we remove the entry from the table. This + * includes when a resolver thread moves all of the available listeners to its local list. + */ + /* host_name (aws_string*) -> host_listener_entry* */ + struct aws_hash_table listener_entry_table; + + enum default_resolver_state state; + + /* + * Tracks the number of launched resolution threads that have not yet invoked their shutdown completion + * callback. + */ + uint32_t pending_host_entry_shutdown_completion_callbacks; +}; + +/* Default host resolver implementation for listener. */ +struct host_listener { + + /* Reference to the host resolver that owns this listener */ + struct aws_host_resolver *resolver; + + /* String copy of the host name */ + struct aws_string *host_name; + + /* User-supplied callbacks/user_data */ + aws_host_listener_resolved_address_fn *resolved_address_callback; + aws_host_listener_shutdown_fn *shutdown_callback; + void *user_data; + + /* Synchronous data, requires host resolver lock to read/modify*/ + /* TODO Add a lock-synced-data function for the host resolver, replacing all current places where the host resolver + * mutex is locked. */ + struct host_listener_synced_data { + /* It's important that the node structure is always first, so that the HOST_LISTENER_FROM_SYNCED_NODE macro + * works properly.*/ + struct aws_linked_list_node node; + uint32_t owned_by_resolver_thread : 1; + uint32_t pending_destroy : 1; + } synced_data; + + /* Threaded data that can only be used in the resolver thread. */ + struct host_listener_threaded_data { + /* It's important that the node structure is always first, so that the HOST_LISTENER_FROM_THREADED_NODE macro + * works properly.*/ + struct aws_linked_list_node node; + } threaded_data; +}; + +/* AWS_CONTAINER_OF does not compile under Clang when using a member in a nested structure, ie, synced_data.node or + * threaded_data.node. To get around this, we define two local macros that rely on the node being the first member of + * the synced_data/threaded_data structures.*/ +#define HOST_LISTENER_FROM_SYNCED_NODE(listener_node) \ + AWS_CONTAINER_OF((listener_node), struct host_listener, synced_data) +#define HOST_LISTENER_FROM_THREADED_NODE(listener_node) \ + AWS_CONTAINER_OF((listener_node), struct host_listener, threaded_data) + +/* Structure for holding all listeners for a particular host name. */ +struct host_listener_entry { + struct default_host_resolver *resolver; + + /* Linked list of struct host_listener */ + struct aws_linked_list listeners; +}; + +struct host_entry { + /* immutable post-creation */ + struct aws_allocator *allocator; + struct aws_host_resolver *resolver; + struct aws_thread resolver_thread; + const struct aws_string *host_name; + int64_t resolve_frequency_ns; + struct aws_host_resolution_config resolution_config; + + /* synchronized data and its lock */ + struct aws_mutex entry_lock; + struct aws_condition_variable entry_signal; + struct aws_cache *aaaa_records; + struct aws_cache *a_records; + struct aws_cache *failed_connection_aaaa_records; + struct aws_cache *failed_connection_a_records; + struct aws_linked_list pending_resolution_callbacks; + uint32_t resolves_since_last_request; + uint64_t last_resolve_request_timestamp_ns; + enum default_resolver_state state; +}; + +static void s_shutdown_host_entry(struct host_entry *entry) { + aws_mutex_lock(&entry->entry_lock); + entry->state = DRS_SHUTTING_DOWN; + aws_mutex_unlock(&entry->entry_lock); +} + +static struct aws_host_listener *default_add_host_listener( + struct aws_host_resolver *host_resolver, + const struct aws_host_listener_options *options); + +static int default_remove_host_listener( + struct aws_host_resolver *host_resolver, + struct aws_host_listener *listener_opaque); + +static void s_host_listener_entry_destroy(void *listener_entry_void); + +static struct host_listener *s_pop_host_listener_from_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener_entry **in_out_listener_entry); + +static int s_add_host_listener_to_listener_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener *listener); + +static void s_remove_host_listener_from_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener *listener); + +static void s_host_listener_destroy(struct host_listener *listener); + +/* + * resolver lock must be held before calling this function + */ +static void s_clear_default_resolver_entry_table(struct default_host_resolver *resolver) { + struct aws_hash_table *table = &resolver->host_entry_table; + for (struct aws_hash_iter iter = aws_hash_iter_begin(table); !aws_hash_iter_done(&iter); + aws_hash_iter_next(&iter)) { + struct host_entry *entry = iter.element.value; + s_shutdown_host_entry(entry); + } + + aws_hash_table_clear(table); +} + +static int resolver_purge_cache(struct aws_host_resolver *resolver) { + struct default_host_resolver *default_host_resolver = resolver->impl; + aws_mutex_lock(&default_host_resolver->resolver_lock); + s_clear_default_resolver_entry_table(default_host_resolver); + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + return AWS_OP_SUCCESS; +} + +static void s_cleanup_default_resolver(struct aws_host_resolver *resolver) { + struct default_host_resolver *default_host_resolver = resolver->impl; + + aws_hash_table_clean_up(&default_host_resolver->host_entry_table); + aws_hash_table_clean_up(&default_host_resolver->listener_entry_table); + + aws_mutex_clean_up(&default_host_resolver->resolver_lock); + + aws_simple_completion_callback *shutdown_callback = resolver->shutdown_options.shutdown_callback_fn; + void *shutdown_completion_user_data = resolver->shutdown_options.shutdown_callback_user_data; + + aws_mem_release(resolver->allocator, resolver); + + /* invoke shutdown completion finally */ + if (shutdown_callback != NULL) { + shutdown_callback(shutdown_completion_user_data); + } + + aws_global_thread_creator_decrement(); +} + +static void resolver_destroy(struct aws_host_resolver *resolver) { + struct default_host_resolver *default_host_resolver = resolver->impl; + + bool cleanup_resolver = false; + + aws_mutex_lock(&default_host_resolver->resolver_lock); + + AWS_FATAL_ASSERT(default_host_resolver->state == DRS_ACTIVE); + + s_clear_default_resolver_entry_table(default_host_resolver); + default_host_resolver->state = DRS_SHUTTING_DOWN; + if (default_host_resolver->pending_host_entry_shutdown_completion_callbacks == 0) { + cleanup_resolver = true; + } + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + if (cleanup_resolver) { + s_cleanup_default_resolver(resolver); + } +} + +struct pending_callback { + aws_on_host_resolved_result_fn *callback; + void *user_data; + struct aws_linked_list_node node; +}; + +static void s_clean_up_host_entry(struct host_entry *entry) { + if (entry == NULL) { + return; + } + + /* + * This can happen if the resolver's final reference drops while an unanswered query is pending on an entry. + * + * You could add an assertion that the resolver is in the shut down state if this condition hits but that + * requires additional locking just to make the assert. + */ + if (!aws_linked_list_empty(&entry->pending_resolution_callbacks)) { + aws_raise_error(AWS_IO_DNS_HOST_REMOVED_FROM_CACHE); + } + + while (!aws_linked_list_empty(&entry->pending_resolution_callbacks)) { + struct aws_linked_list_node *resolution_callback_node = + aws_linked_list_pop_front(&entry->pending_resolution_callbacks); + struct pending_callback *pending_callback = + AWS_CONTAINER_OF(resolution_callback_node, struct pending_callback, node); + + pending_callback->callback( + entry->resolver, entry->host_name, AWS_IO_DNS_HOST_REMOVED_FROM_CACHE, NULL, pending_callback->user_data); + + aws_mem_release(entry->allocator, pending_callback); + } + + aws_cache_destroy(entry->aaaa_records); + aws_cache_destroy(entry->a_records); + aws_cache_destroy(entry->failed_connection_a_records); + aws_cache_destroy(entry->failed_connection_aaaa_records); + aws_string_destroy((void *)entry->host_name); + aws_mem_release(entry->allocator, entry); +} + +static void s_on_host_entry_shutdown_completion(void *user_data) { + struct host_entry *entry = user_data; + struct aws_host_resolver *resolver = entry->resolver; + struct default_host_resolver *default_host_resolver = resolver->impl; + + s_clean_up_host_entry(entry); + + bool cleanup_resolver = false; + + aws_mutex_lock(&default_host_resolver->resolver_lock); + --default_host_resolver->pending_host_entry_shutdown_completion_callbacks; + if (default_host_resolver->state == DRS_SHUTTING_DOWN && + default_host_resolver->pending_host_entry_shutdown_completion_callbacks == 0) { + cleanup_resolver = true; + } + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + if (cleanup_resolver) { + s_cleanup_default_resolver(resolver); + } +} + +/* this only ever gets called after resolution has already run. We expect that the entry's lock + has been acquired for writing before this function is called and released afterwards. */ +static inline void process_records( + struct aws_allocator *allocator, + struct aws_cache *records, + struct aws_cache *failed_records) { + uint64_t timestamp = 0; + aws_sys_clock_get_ticks(×tamp); + + size_t record_count = aws_cache_get_element_count(records); + size_t expired_records = 0; + + /* since this only ever gets called after resolution has already run, we're in a dns outage + * if everything is expired. Leave an element so we can keep trying. */ + for (size_t index = 0; index < record_count && expired_records < record_count - 1; ++index) { + struct aws_host_address *lru_element = aws_lru_cache_use_lru_element(records); + + if (lru_element->expiry < timestamp) { + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "static: purging expired record %s for %s", + lru_element->address->bytes, + lru_element->host->bytes); + expired_records++; + aws_cache_remove(records, lru_element->address); + } + } + + record_count = aws_cache_get_element_count(records); + AWS_LOGF_TRACE(AWS_LS_IO_DNS, "static: remaining record count for host %d", (int)record_count); + + /* if we don't have any known good addresses, take the least recently used, but not expired address with a history + * of spotty behavior and upgrade it for reuse. If it's expired, leave it and let the resolve fail. Better to fail + * than accidentally give a kids' app an IP address to somebody's adult website when the IP address gets rebound to + * a different endpoint. The moral of the story here is to not disable SSL verification! */ + if (!record_count) { + size_t failed_count = aws_cache_get_element_count(failed_records); + for (size_t index = 0; index < failed_count; ++index) { + struct aws_host_address *lru_element = aws_lru_cache_use_lru_element(failed_records); + + if (timestamp < lru_element->expiry) { + struct aws_host_address *to_add = aws_mem_acquire(allocator, sizeof(struct aws_host_address)); + + if (to_add && !aws_host_address_copy(lru_element, to_add)) { + AWS_LOGF_INFO( + AWS_LS_IO_DNS, + "static: promoting spotty record %s for %s back to good list", + lru_element->address->bytes, + lru_element->host->bytes); + if (aws_cache_put(records, to_add->address, to_add)) { + aws_mem_release(allocator, to_add); + continue; + } + /* we only want to promote one per process run.*/ + aws_cache_remove(failed_records, lru_element->address); + break; + } + + if (to_add) { + aws_mem_release(allocator, to_add); + } + } + } + } +} + +static int resolver_record_connection_failure(struct aws_host_resolver *resolver, struct aws_host_address *address) { + struct default_host_resolver *default_host_resolver = resolver->impl; + + AWS_LOGF_INFO( + AWS_LS_IO_DNS, + "id=%p: recording failure for record %s for %s, moving to bad list", + (void *)resolver, + address->address->bytes, + address->host->bytes); + + aws_mutex_lock(&default_host_resolver->resolver_lock); + + struct aws_hash_element *element = NULL; + if (aws_hash_table_find(&default_host_resolver->host_entry_table, address->host, &element)) { + aws_mutex_unlock(&default_host_resolver->resolver_lock); + return AWS_OP_ERR; + } + + struct host_entry *host_entry = NULL; + if (element != NULL) { + host_entry = element->value; + AWS_FATAL_ASSERT(host_entry); + } + + if (host_entry) { + struct aws_host_address *cached_address = NULL; + + aws_mutex_lock(&host_entry->entry_lock); + aws_mutex_unlock(&default_host_resolver->resolver_lock); + struct aws_cache *address_table = + address->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA ? host_entry->aaaa_records : host_entry->a_records; + + struct aws_cache *failed_table = address->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA + ? host_entry->failed_connection_aaaa_records + : host_entry->failed_connection_a_records; + + aws_cache_find(address_table, address->address, (void **)&cached_address); + + struct aws_host_address *address_copy = NULL; + if (cached_address) { + address_copy = aws_mem_acquire(resolver->allocator, sizeof(struct aws_host_address)); + + if (!address_copy || aws_host_address_copy(cached_address, address_copy)) { + goto error_host_entry_cleanup; + } + + if (aws_cache_remove(address_table, cached_address->address)) { + goto error_host_entry_cleanup; + } + + address_copy->connection_failure_count += 1; + + if (aws_cache_put(failed_table, address_copy->address, address_copy)) { + goto error_host_entry_cleanup; + } + } else { + if (aws_cache_find(failed_table, address->address, (void **)&cached_address)) { + goto error_host_entry_cleanup; + } + + if (cached_address) { + cached_address->connection_failure_count += 1; + } + } + aws_mutex_unlock(&host_entry->entry_lock); + return AWS_OP_SUCCESS; + + error_host_entry_cleanup: + if (address_copy) { + aws_host_address_clean_up(address_copy); + aws_mem_release(resolver->allocator, address_copy); + } + aws_mutex_unlock(&host_entry->entry_lock); + return AWS_OP_ERR; + } + + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + return AWS_OP_SUCCESS; +} + +/* + * A bunch of convenience functions for the host resolver background thread function + */ + +static struct aws_host_address *s_find_cached_address_aux( + struct aws_cache *primary_records, + struct aws_cache *fallback_records, + const struct aws_string *address) { + + struct aws_host_address *found = NULL; + aws_cache_find(primary_records, address, (void **)&found); + if (found == NULL) { + aws_cache_find(fallback_records, address, (void **)&found); + } + + return found; +} + +/* + * Looks in both the good and failed connection record sets for a given host record + */ +static struct aws_host_address *s_find_cached_address( + struct host_entry *entry, + const struct aws_string *address, + enum aws_address_record_type record_type) { + + switch (record_type) { + case AWS_ADDRESS_RECORD_TYPE_AAAA: + return s_find_cached_address_aux(entry->aaaa_records, entry->failed_connection_aaaa_records, address); + + case AWS_ADDRESS_RECORD_TYPE_A: + return s_find_cached_address_aux(entry->a_records, entry->failed_connection_a_records, address); + + default: + return NULL; + } +} + +static struct aws_host_address *s_get_lru_address_aux( + struct aws_cache *primary_records, + struct aws_cache *fallback_records) { + + struct aws_host_address *address = aws_lru_cache_use_lru_element(primary_records); + if (address == NULL) { + aws_lru_cache_use_lru_element(fallback_records); + } + + return address; +} + +/* + * Looks in both the good and failed connection record sets for the LRU host record + */ +static struct aws_host_address *s_get_lru_address(struct host_entry *entry, enum aws_address_record_type record_type) { + switch (record_type) { + case AWS_ADDRESS_RECORD_TYPE_AAAA: + return s_get_lru_address_aux(entry->aaaa_records, entry->failed_connection_aaaa_records); + + case AWS_ADDRESS_RECORD_TYPE_A: + return s_get_lru_address_aux(entry->a_records, entry->failed_connection_a_records); + + default: + return NULL; + } +} + +static void s_clear_address_list(struct aws_array_list *address_list) { + for (size_t i = 0; i < aws_array_list_length(address_list); ++i) { + struct aws_host_address *address = NULL; + aws_array_list_get_at_ptr(address_list, (void **)&address, i); + aws_host_address_clean_up(address); + } + + aws_array_list_clear(address_list); +} + +static void s_update_address_cache( + struct host_entry *host_entry, + struct aws_array_list *address_list, + uint64_t new_expiration, + struct aws_array_list *out_new_address_list) { + + AWS_PRECONDITION(host_entry); + AWS_PRECONDITION(address_list); + AWS_PRECONDITION(out_new_address_list); + + for (size_t i = 0; i < aws_array_list_length(address_list); ++i) { + struct aws_host_address *fresh_resolved_address = NULL; + aws_array_list_get_at_ptr(address_list, (void **)&fresh_resolved_address, i); + + struct aws_host_address *address_to_cache = + s_find_cached_address(host_entry, fresh_resolved_address->address, fresh_resolved_address->record_type); + + if (address_to_cache) { + address_to_cache->expiry = new_expiration; + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "static: updating expiry for %s for host %s to %llu", + address_to_cache->address->bytes, + host_entry->host_name->bytes, + (unsigned long long)new_expiration); + } else { + address_to_cache = aws_mem_acquire(host_entry->allocator, sizeof(struct aws_host_address)); + + aws_host_address_move(fresh_resolved_address, address_to_cache); + address_to_cache->expiry = new_expiration; + + struct aws_cache *address_table = address_to_cache->record_type == AWS_ADDRESS_RECORD_TYPE_AAAA + ? host_entry->aaaa_records + : host_entry->a_records; + + if (aws_cache_put(address_table, address_to_cache->address, address_to_cache)) { + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, + "static: could not add new address to host entry cache for host '%s' in " + "s_update_address_cache.", + host_entry->host_name->bytes); + + continue; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "static: new address resolved %s for host %s caching", + address_to_cache->address->bytes, + host_entry->host_name->bytes); + + struct aws_host_address new_address_copy; + + if (aws_host_address_copy(address_to_cache, &new_address_copy)) { + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, + "static: could not copy address for new-address list for host '%s' in s_update_address_cache.", + host_entry->host_name->bytes); + + continue; + } + + if (aws_array_list_push_back(out_new_address_list, &new_address_copy)) { + aws_host_address_clean_up(&new_address_copy); + + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, + "static: could not push address to new-address list for host '%s' in s_update_address_cache.", + host_entry->host_name->bytes); + + continue; + } + } + } +} + +static void s_copy_address_into_callback_set( + struct aws_host_address *address, + struct aws_array_list *callback_addresses, + const struct aws_string *host_name) { + + if (address) { + address->use_count += 1; + + /* + * This is the worst. + * + * We have to copy the cache address while we still have a write lock. Otherwise, connection failures + * can sneak in and destroy our address by moving the address to/from the various lru caches. + * + * But there's no nice copy construction into an array list, so we get to + * (1) Push a zeroed dummy element onto the array list + * (2) Get its pointer + * (3) Call aws_host_address_copy onto it. If that fails, pop the dummy element. + */ + struct aws_host_address dummy; + AWS_ZERO_STRUCT(dummy); + + if (aws_array_list_push_back(callback_addresses, &dummy)) { + return; + } + + struct aws_host_address *dest_copy = NULL; + aws_array_list_get_at_ptr( + callback_addresses, (void **)&dest_copy, aws_array_list_length(callback_addresses) - 1); + AWS_FATAL_ASSERT(dest_copy != NULL); + + if (aws_host_address_copy(address, dest_copy)) { + aws_array_list_pop_back(callback_addresses); + return; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "static: vending address %s for host %s to caller", + address->address->bytes, + host_name->bytes); + } +} + +static bool s_host_entry_finished_pred(void *user_data) { + struct host_entry *entry = user_data; + + return entry->state == DRS_SHUTTING_DOWN; +} + +/* Move all of the listeners in the host-resolver-owned listener entry to the resolver thread owned list. */ +/* Assumes resolver_lock is held so that we can pop from the listener entry and access the listener's synced_data. */ +static void s_resolver_thread_move_listeners_from_listener_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct aws_linked_list *listener_list) { + + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + AWS_PRECONDITION(listener_list); + + struct host_listener_entry *listener_entry = NULL; + struct host_listener *listener = s_pop_host_listener_from_entry(resolver, host_name, &listener_entry); + + while (listener != NULL) { + /* Flag this listener as in-use by the resolver thread so that it can't be destroyed from outside of that + * thread. */ + listener->synced_data.owned_by_resolver_thread = true; + + aws_linked_list_push_back(listener_list, &listener->threaded_data.node); + + listener = s_pop_host_listener_from_entry(resolver, host_name, &listener_entry); + } +} + +/* When the thread is ready to exit, we move all of the listeners back to the host-resolver-owned listener entry.*/ +/* Assumes that we have already removed all pending_destroy listeners via + * s_resolver_thread_cull_pending_destroy_listeners. */ +/* Assumes resolver_lock is held so that we can write to the listener entry and read/write from the listener's + * synced_data. */ +static int s_resolver_thread_move_listeners_to_listener_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct aws_linked_list *listener_list) { + + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + AWS_PRECONDITION(listener_list); + + int result = 0; + size_t num_listeners_not_moved = 0; + + while (!aws_linked_list_empty(listener_list)) { + struct aws_linked_list_node *listener_node = aws_linked_list_pop_back(listener_list); + struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); + + /* Flag this listener as no longer in-use by the resolver thread. */ + listener->synced_data.owned_by_resolver_thread = false; + + AWS_ASSERT(!listener->synced_data.pending_destroy); + + if (s_add_host_listener_to_listener_entry(resolver, host_name, listener)) { + result = AWS_OP_ERR; + ++num_listeners_not_moved; + } + } + + if (result == AWS_OP_ERR) { + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, + "static: could not move %" PRIu64 " listeners back to listener entry", + (uint64_t)num_listeners_not_moved); + } + + return result; +} + +/* Remove the listeners from the resolver-thread-owned listener_list that are marked pending destroy, and move them into + * the destroy list. */ +/* Assumes resolver_lock is held. (This lock is necessary for reading from the listener's synced_data.) */ +static void s_resolver_thread_cull_pending_destroy_listeners( + struct aws_linked_list *listener_list, + struct aws_linked_list *listener_destroy_list) { + + AWS_PRECONDITION(listener_list); + AWS_PRECONDITION(listener_destroy_list); + + struct aws_linked_list_node *listener_node = aws_linked_list_begin(listener_list); + + /* Find all listeners in our current list that are marked for destroy. */ + while (listener_node != aws_linked_list_end(listener_list)) { + struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); + + /* Advance our node pointer early to allow for a removal. */ + listener_node = aws_linked_list_next(listener_node); + + /* If listener is pending destroy, remove it from the local list, and push it into the destroy list. */ + if (listener->synced_data.pending_destroy) { + aws_linked_list_remove(&listener->threaded_data.node); + aws_linked_list_push_back(listener_destroy_list, &listener->threaded_data.node); + } + } +} + +/* Destroys all of the listeners in the resolver thread's destroy list. */ +/* Assumes no lock is held. (We don't want any lock held so that any shutdown callbacks happen outside of a lock.) */ +static void s_resolver_thread_destroy_listeners(struct aws_linked_list *listener_destroy_list) { + + AWS_PRECONDITION(listener_destroy_list); + + while (!aws_linked_list_empty(listener_destroy_list)) { + struct aws_linked_list_node *listener_node = aws_linked_list_pop_back(listener_destroy_list); + struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); + s_host_listener_destroy(listener); + } +} + +/* Notify all listeners with resolve address callbacks, and also clean up any that are waiting to be cleaned up. */ +/* Assumes no lock is held. The listener_list is owned by the resolver thread, so no lock is necessary. We also don't + * want a lock held when calling the resolver-address callback.*/ +static void s_resolver_thread_notify_listeners( + const struct aws_array_list *new_address_list, + struct aws_linked_list *listener_list) { + + AWS_PRECONDITION(new_address_list); + AWS_PRECONDITION(listener_list); + + /* Go through each listener in our list. */ + for (struct aws_linked_list_node *listener_node = aws_linked_list_begin(listener_list); + listener_node != aws_linked_list_end(listener_list); + listener_node = aws_linked_list_next(listener_node)) { + struct host_listener *listener = HOST_LISTENER_FROM_THREADED_NODE(listener_node); + + /* If we have new adddresses, notify the resolved-address callback if one exists */ + if (aws_array_list_length(new_address_list) > 0 && listener->resolved_address_callback != NULL) { + listener->resolved_address_callback( + (struct aws_host_listener *)listener, new_address_list, listener->user_data); + } + } +} + +static void resolver_thread_fn(void *arg) { + struct host_entry *host_entry = arg; + + size_t unsolicited_resolve_max = host_entry->resolution_config.max_ttl; + if (unsolicited_resolve_max == 0) { + unsolicited_resolve_max = 1; + } + + uint64_t max_no_solicitation_interval = + aws_timestamp_convert(unsolicited_resolve_max, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL); + + struct aws_array_list address_list; + if (aws_array_list_init_dynamic(&address_list, host_entry->allocator, 4, sizeof(struct aws_host_address))) { + return; + } + + struct aws_array_list new_address_list; + if (aws_array_list_init_dynamic(&new_address_list, host_entry->allocator, 4, sizeof(struct aws_host_address))) { + aws_array_list_clean_up(&address_list); + return; + } + + struct aws_linked_list listener_list; + aws_linked_list_init(&listener_list); + + struct aws_linked_list listener_destroy_list; + aws_linked_list_init(&listener_destroy_list); + + bool keep_going = true; + while (keep_going) { + + AWS_LOGF_TRACE(AWS_LS_IO_DNS, "static, resolving %s", aws_string_c_str(host_entry->host_name)); + + /* resolve and then process each record */ + int err_code = AWS_ERROR_SUCCESS; + if (host_entry->resolution_config.impl( + host_entry->allocator, host_entry->host_name, &address_list, host_entry->resolution_config.impl_data)) { + + err_code = aws_last_error(); + } + uint64_t timestamp = 0; + aws_sys_clock_get_ticks(×tamp); + uint64_t new_expiry = timestamp + (host_entry->resolution_config.max_ttl * NS_PER_SEC); + + struct aws_linked_list pending_resolve_copy; + aws_linked_list_init(&pending_resolve_copy); + + /* + * Within the lock we + * (1) Update the cache with the newly resolved addresses + * (2) Process all held addresses looking for expired or promotable ones + * (3) Prep for callback invocations + */ + aws_mutex_lock(&host_entry->entry_lock); + + if (!err_code) { + s_update_address_cache(host_entry, &address_list, new_expiry, &new_address_list); + } + + /* + * process and clean_up records in the entry. occasionally, failed connect records will be upgraded + * for retry. + */ + process_records(host_entry->allocator, host_entry->aaaa_records, host_entry->failed_connection_aaaa_records); + process_records(host_entry->allocator, host_entry->a_records, host_entry->failed_connection_a_records); + + aws_linked_list_swap_contents(&pending_resolve_copy, &host_entry->pending_resolution_callbacks); + + aws_mutex_unlock(&host_entry->entry_lock); + + /* + * Clean up resolved addressed outside of the lock + */ + s_clear_address_list(&address_list); + + struct aws_host_address address_array[2]; + AWS_ZERO_ARRAY(address_array); + + /* + * Perform the actual subscriber notifications + */ + while (!aws_linked_list_empty(&pending_resolve_copy)) { + struct aws_linked_list_node *resolution_callback_node = aws_linked_list_pop_front(&pending_resolve_copy); + struct pending_callback *pending_callback = + AWS_CONTAINER_OF(resolution_callback_node, struct pending_callback, node); + + struct aws_array_list callback_address_list; + aws_array_list_init_static(&callback_address_list, address_array, 2, sizeof(struct aws_host_address)); + + aws_mutex_lock(&host_entry->entry_lock); + s_copy_address_into_callback_set( + s_get_lru_address(host_entry, AWS_ADDRESS_RECORD_TYPE_AAAA), + &callback_address_list, + host_entry->host_name); + s_copy_address_into_callback_set( + s_get_lru_address(host_entry, AWS_ADDRESS_RECORD_TYPE_A), + &callback_address_list, + host_entry->host_name); + aws_mutex_unlock(&host_entry->entry_lock); + + AWS_ASSERT(err_code != AWS_ERROR_SUCCESS || aws_array_list_length(&callback_address_list) > 0); + + if (aws_array_list_length(&callback_address_list) > 0) { + pending_callback->callback( + host_entry->resolver, + host_entry->host_name, + AWS_OP_SUCCESS, + &callback_address_list, + pending_callback->user_data); + + } else { + pending_callback->callback( + host_entry->resolver, host_entry->host_name, err_code, NULL, pending_callback->user_data); + } + + s_clear_address_list(&callback_address_list); + + aws_mem_release(host_entry->allocator, pending_callback); + } + + aws_mutex_lock(&host_entry->entry_lock); + + ++host_entry->resolves_since_last_request; + + /* wait for a quit notification or the base resolve frequency time interval */ + aws_condition_variable_wait_for_pred( + &host_entry->entry_signal, + &host_entry->entry_lock, + host_entry->resolve_frequency_ns, + s_host_entry_finished_pred, + host_entry); + + aws_mutex_unlock(&host_entry->entry_lock); + + /* + * This is a bit awkward that we unlock the entry and then relock both the resolver and the entry, but it + * is mandatory that -- in order to maintain the consistent view of the resolver table (entry exist => entry + * is alive and can be queried) -- we have the resolver lock as well before making the decision to remove + * the entry from the table and terminate the thread. + */ + struct default_host_resolver *resolver = host_entry->resolver->impl; + aws_mutex_lock(&resolver->resolver_lock); + + /* Remove any listeners from our listener list that have been marked pending destroy, moving them into the + * destroy list. */ + s_resolver_thread_cull_pending_destroy_listeners(&listener_list, &listener_destroy_list); + + /* Grab any listeners on the listener entry, moving them into the local list. */ + s_resolver_thread_move_listeners_from_listener_entry(resolver, host_entry->host_name, &listener_list); + + aws_mutex_lock(&host_entry->entry_lock); + + uint64_t now = 0; + aws_sys_clock_get_ticks(&now); + + /* + * Ideally this should just be time-based, but given the non-determinism of waits (and spurious wake ups) and + * clock time, I feel much more comfortable keeping an additional constraint in terms of iterations. + * + * Note that we have the entry lock now and if any queries have arrived since our last resolution, + * resolves_since_last_request will be 0 or 1 (depending on timing) and so, regardless of wait and wake up + * timings, this check will always fail in that case leading to another iteration to satisfy the pending + * query(ies). + * + * The only way we terminate the loop with pending queries is if the resolver itself has no more references + * to it and is going away. In that case, the pending queries will be completed (with failure) by the + * final clean up of this entry. + */ + if (host_entry->resolves_since_last_request > unsolicited_resolve_max && + host_entry->last_resolve_request_timestamp_ns + max_no_solicitation_interval < now) { + host_entry->state = DRS_SHUTTING_DOWN; + } + + keep_going = host_entry->state == DRS_ACTIVE; + if (!keep_going) { + aws_hash_table_remove(&resolver->host_entry_table, host_entry->host_name, NULL, NULL); + + /* Move any local listeners we have back to the listener entry */ + if (s_resolver_thread_move_listeners_to_listener_entry(resolver, host_entry->host_name, &listener_list)) { + AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: could not clean up all listeners from resolver thread."); + } + } + + aws_mutex_unlock(&host_entry->entry_lock); + aws_mutex_unlock(&resolver->resolver_lock); + + /* Destroy any listeners in our destroy list. */ + s_resolver_thread_destroy_listeners(&listener_destroy_list); + + /* Notify our local listeners of new addresses. */ + s_resolver_thread_notify_listeners(&new_address_list, &listener_list); + + s_clear_address_list(&new_address_list); + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "static: Either no requests have been made for an address for %s for the duration " + "of the ttl, or this thread is being forcibly shutdown. Killing thread.", + host_entry->host_name->bytes) + + aws_array_list_clean_up(&address_list); + aws_array_list_clean_up(&new_address_list); + + /* please don't fail */ + aws_thread_current_at_exit(s_on_host_entry_shutdown_completion, host_entry); +} + +static void on_address_value_removed(void *value) { + struct aws_host_address *host_address = value; + + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "static: purging address %s for host %s from " + "the cache due to cache eviction or shutdown", + host_address->address->bytes, + host_address->host->bytes); + + struct aws_allocator *allocator = host_address->allocator; + aws_host_address_clean_up(host_address); + aws_mem_release(allocator, host_address); +} + +/* + * The resolver lock must be held before calling this function + */ +static inline int create_and_init_host_entry( + struct aws_host_resolver *resolver, + const struct aws_string *host_name, + aws_on_host_resolved_result_fn *res, + struct aws_host_resolution_config *config, + uint64_t timestamp, + void *user_data) { + struct host_entry *new_host_entry = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_entry)); + if (!new_host_entry) { + return AWS_OP_ERR; + } + + new_host_entry->resolver = resolver; + new_host_entry->allocator = resolver->allocator; + new_host_entry->last_resolve_request_timestamp_ns = timestamp; + new_host_entry->resolves_since_last_request = 0; + new_host_entry->resolve_frequency_ns = NS_PER_SEC; + new_host_entry->state = DRS_ACTIVE; + + bool thread_init = false; + struct pending_callback *pending_callback = NULL; + const struct aws_string *host_string_copy = aws_string_new_from_string(resolver->allocator, host_name); + if (AWS_UNLIKELY(!host_string_copy)) { + goto setup_host_entry_error; + } + + new_host_entry->host_name = host_string_copy; + new_host_entry->a_records = aws_cache_new_lru( + new_host_entry->allocator, + aws_hash_string, + aws_hash_callback_string_eq, + NULL, + on_address_value_removed, + config->max_ttl); + if (AWS_UNLIKELY(!new_host_entry->a_records)) { + goto setup_host_entry_error; + } + + new_host_entry->aaaa_records = aws_cache_new_lru( + new_host_entry->allocator, + aws_hash_string, + aws_hash_callback_string_eq, + NULL, + on_address_value_removed, + config->max_ttl); + if (AWS_UNLIKELY(!new_host_entry->aaaa_records)) { + goto setup_host_entry_error; + } + + new_host_entry->failed_connection_a_records = aws_cache_new_lru( + new_host_entry->allocator, + aws_hash_string, + aws_hash_callback_string_eq, + NULL, + on_address_value_removed, + config->max_ttl); + if (AWS_UNLIKELY(!new_host_entry->failed_connection_a_records)) { + goto setup_host_entry_error; + } + + new_host_entry->failed_connection_aaaa_records = aws_cache_new_lru( + new_host_entry->allocator, + aws_hash_string, + aws_hash_callback_string_eq, + NULL, + on_address_value_removed, + config->max_ttl); + if (AWS_UNLIKELY(!new_host_entry->failed_connection_aaaa_records)) { + goto setup_host_entry_error; + } + + aws_linked_list_init(&new_host_entry->pending_resolution_callbacks); + + pending_callback = aws_mem_acquire(resolver->allocator, sizeof(struct pending_callback)); + + if (AWS_UNLIKELY(!pending_callback)) { + goto setup_host_entry_error; + } + + /*add the current callback here */ + pending_callback->user_data = user_data; + pending_callback->callback = res; + aws_linked_list_push_back(&new_host_entry->pending_resolution_callbacks, &pending_callback->node); + + aws_mutex_init(&new_host_entry->entry_lock); + new_host_entry->resolution_config = *config; + aws_condition_variable_init(&new_host_entry->entry_signal); + + if (aws_thread_init(&new_host_entry->resolver_thread, resolver->allocator)) { + goto setup_host_entry_error; + } + + thread_init = true; + struct default_host_resolver *default_host_resolver = resolver->impl; + if (AWS_UNLIKELY( + aws_hash_table_put(&default_host_resolver->host_entry_table, host_string_copy, new_host_entry, NULL))) { + goto setup_host_entry_error; + } + + aws_thread_launch(&new_host_entry->resolver_thread, resolver_thread_fn, new_host_entry, NULL); + ++default_host_resolver->pending_host_entry_shutdown_completion_callbacks; + + return AWS_OP_SUCCESS; + +setup_host_entry_error: + + if (thread_init) { + aws_thread_clean_up(&new_host_entry->resolver_thread); + } + + s_clean_up_host_entry(new_host_entry); + + return AWS_OP_ERR; +} + +static int default_resolve_host( + struct aws_host_resolver *resolver, + const struct aws_string *host_name, + aws_on_host_resolved_result_fn *res, + struct aws_host_resolution_config *config, + void *user_data) { + int result = AWS_OP_SUCCESS; + + AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "id=%p: Host resolution requested for %s", (void *)resolver, host_name->bytes); + + uint64_t timestamp = 0; + aws_sys_clock_get_ticks(×tamp); + + struct default_host_resolver *default_host_resolver = resolver->impl; + aws_mutex_lock(&default_host_resolver->resolver_lock); + + struct aws_hash_element *element = NULL; + /* we don't care about the error code here, only that the host_entry was found or not. */ + aws_hash_table_find(&default_host_resolver->host_entry_table, host_name, &element); + + struct host_entry *host_entry = NULL; + if (element != NULL) { + host_entry = element->value; + AWS_FATAL_ASSERT(host_entry != NULL); + } + + if (!host_entry) { + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "id=%p: No cached entries found for %s starting new resolver thread.", + (void *)resolver, + host_name->bytes); + + result = create_and_init_host_entry(resolver, host_name, res, config, timestamp, user_data); + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + return result; + } + + aws_mutex_lock(&host_entry->entry_lock); + + /* + * We don't need to make any resolver side-affects in the remaining logic and it's impossible for the entry + * to disappear underneath us while holding its lock, so its safe to release the resolver lock and let other + * things query other entries. + */ + aws_mutex_unlock(&default_host_resolver->resolver_lock); + host_entry->last_resolve_request_timestamp_ns = timestamp; + host_entry->resolves_since_last_request = 0; + + struct aws_host_address *aaaa_record = aws_lru_cache_use_lru_element(host_entry->aaaa_records); + struct aws_host_address *a_record = aws_lru_cache_use_lru_element(host_entry->a_records); + struct aws_host_address address_array[2]; + AWS_ZERO_ARRAY(address_array); + struct aws_array_list callback_address_list; + aws_array_list_init_static(&callback_address_list, address_array, 2, sizeof(struct aws_host_address)); + + if ((aaaa_record || a_record)) { + AWS_LOGF_DEBUG( + AWS_LS_IO_DNS, + "id=%p: cached entries found for %s returning to caller.", + (void *)resolver, + host_name->bytes); + + /* these will all need to be copied so that we don't hold the lock during the callback. */ + if (aaaa_record) { + struct aws_host_address aaaa_record_cpy; + if (!aws_host_address_copy(aaaa_record, &aaaa_record_cpy)) { + aws_array_list_push_back(&callback_address_list, &aaaa_record_cpy); + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "id=%p: vending address %s for host %s to caller", + (void *)resolver, + aaaa_record->address->bytes, + host_entry->host_name->bytes); + } + } + if (a_record) { + struct aws_host_address a_record_cpy; + if (!aws_host_address_copy(a_record, &a_record_cpy)) { + aws_array_list_push_back(&callback_address_list, &a_record_cpy); + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "id=%p: vending address %s for host %s to caller", + (void *)resolver, + a_record->address->bytes, + host_entry->host_name->bytes); + } + } + aws_mutex_unlock(&host_entry->entry_lock); + + /* we don't want to do the callback WHILE we hold the lock someone may reentrantly call us. */ + if (aws_array_list_length(&callback_address_list)) { + res(resolver, host_name, AWS_OP_SUCCESS, &callback_address_list, user_data); + } else { + res(resolver, host_name, aws_last_error(), NULL, user_data); + result = AWS_OP_ERR; + } + + for (size_t i = 0; i < aws_array_list_length(&callback_address_list); ++i) { + struct aws_host_address *address_ptr = NULL; + aws_array_list_get_at_ptr(&callback_address_list, (void **)&address_ptr, i); + aws_host_address_clean_up(address_ptr); + } + + aws_array_list_clean_up(&callback_address_list); + + return result; + } + + struct pending_callback *pending_callback = + aws_mem_acquire(default_host_resolver->allocator, sizeof(struct pending_callback)); + if (pending_callback != NULL) { + pending_callback->user_data = user_data; + pending_callback->callback = res; + aws_linked_list_push_back(&host_entry->pending_resolution_callbacks, &pending_callback->node); + } else { + result = AWS_OP_ERR; + } + + aws_mutex_unlock(&host_entry->entry_lock); + + return result; +} + +static size_t default_get_host_address_count( + struct aws_host_resolver *host_resolver, + const struct aws_string *host_name, + uint32_t flags) { + struct default_host_resolver *default_host_resolver = host_resolver->impl; + size_t address_count = 0; + + aws_mutex_lock(&default_host_resolver->resolver_lock); + + struct aws_hash_element *element = NULL; + aws_hash_table_find(&default_host_resolver->host_entry_table, host_name, &element); + if (element != NULL) { + struct host_entry *host_entry = element->value; + if (host_entry != NULL) { + aws_mutex_lock(&host_entry->entry_lock); + + if ((flags & AWS_GET_HOST_ADDRESS_COUNT_RECORD_TYPE_A) != 0) { + address_count += aws_cache_get_element_count(host_entry->a_records); + } + + if ((flags & AWS_GET_HOST_ADDRESS_COUNT_RECORD_TYPE_AAAA) != 0) { + address_count += aws_cache_get_element_count(host_entry->aaaa_records); + } + + aws_mutex_unlock(&host_entry->entry_lock); + } + } + + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + return address_count; +} + +static struct aws_host_resolver_vtable s_vtable = { + .purge_cache = resolver_purge_cache, + .resolve_host = default_resolve_host, + .record_connection_failure = resolver_record_connection_failure, + .get_host_address_count = default_get_host_address_count, + .add_host_listener = default_add_host_listener, + .remove_host_listener = default_remove_host_listener, + .destroy = resolver_destroy, +}; + +static void s_aws_host_resolver_destroy(struct aws_host_resolver *resolver) { + AWS_ASSERT(resolver->vtable && resolver->vtable->destroy); + resolver->vtable->destroy(resolver); +} + +struct aws_host_resolver *aws_host_resolver_new_default( + struct aws_allocator *allocator, + size_t max_entries, + struct aws_event_loop_group *el_group, + const struct aws_shutdown_callback_options *shutdown_options) { + /* NOTE: we don't use el_group yet, but we will in the future. Also, we + don't want host resolvers getting cleaned up after el_groups; this will force that + in bindings, and encourage it in C land. */ + (void)el_group; + AWS_ASSERT(el_group); + + struct aws_host_resolver *resolver = NULL; + struct default_host_resolver *default_host_resolver = NULL; + if (!aws_mem_acquire_many( + allocator, + 2, + &resolver, + sizeof(struct aws_host_resolver), + &default_host_resolver, + sizeof(struct default_host_resolver))) { + return NULL; + } + + AWS_ZERO_STRUCT(*resolver); + AWS_ZERO_STRUCT(*default_host_resolver); + + AWS_LOGF_INFO( + AWS_LS_IO_DNS, + "id=%p: Initializing default host resolver with %llu max host entries.", + (void *)resolver, + (unsigned long long)max_entries); + + resolver->vtable = &s_vtable; + resolver->allocator = allocator; + resolver->impl = default_host_resolver; + + default_host_resolver->allocator = allocator; + default_host_resolver->pending_host_entry_shutdown_completion_callbacks = 0; + default_host_resolver->state = DRS_ACTIVE; + aws_mutex_init(&default_host_resolver->resolver_lock); + + aws_global_thread_creator_increment(); + + if (aws_hash_table_init( + &default_host_resolver->host_entry_table, + allocator, + max_entries, + aws_hash_string, + aws_hash_callback_string_eq, + NULL, + NULL)) { + goto on_error; + } + + if (aws_hash_table_init( + &default_host_resolver->listener_entry_table, + allocator, + max_entries, + aws_hash_string, + aws_hash_callback_string_eq, + aws_hash_callback_string_destroy, + s_host_listener_entry_destroy)) { + goto on_error; + } + + aws_ref_count_init(&resolver->ref_count, resolver, (aws_simple_completion_callback *)s_aws_host_resolver_destroy); + + if (shutdown_options != NULL) { + resolver->shutdown_options = *shutdown_options; + } + + return resolver; + +on_error: + + s_cleanup_default_resolver(resolver); + + return NULL; +} + +struct aws_host_resolver *aws_host_resolver_acquire(struct aws_host_resolver *resolver) { + if (resolver != NULL) { + aws_ref_count_acquire(&resolver->ref_count); + } + + return resolver; +} + +void aws_host_resolver_release(struct aws_host_resolver *resolver) { + if (resolver != NULL) { + aws_ref_count_release(&resolver->ref_count); + } +} + +size_t aws_host_resolver_get_host_address_count( + struct aws_host_resolver *resolver, + const struct aws_string *host_name, + uint32_t flags) { + return resolver->vtable->get_host_address_count(resolver, host_name, flags); +} + +enum find_listener_entry_flags { + FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND = 0x00000001, +}; + +static struct host_listener_entry *s_find_host_listener_entry( + struct default_host_resolver *default_resolver, + const struct aws_string *host_name, + uint32_t flags); + +static struct aws_host_listener *default_add_host_listener( + struct aws_host_resolver *resolver, + const struct aws_host_listener_options *options) { + AWS_PRECONDITION(resolver); + + if (options == NULL) { + AWS_LOGF_ERROR(AWS_LS_IO_DNS, "Cannot create host resolver listener; options structure is NULL."); + aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + return NULL; + } + + if (options->host_name.len == 0) { + AWS_LOGF_ERROR(AWS_LS_IO_DNS, "Cannot create host resolver listener; invalid host name specified."); + aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + return NULL; + } + + /* Allocate and set up the listener. */ + struct host_listener *listener = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_listener)); + + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "id=%p Adding listener %p for host name %s", + (void *)resolver, + (void *)listener, + (const char *)options->host_name.ptr); + + aws_host_resolver_acquire(resolver); + listener->resolver = resolver; + listener->host_name = aws_string_new_from_cursor(resolver->allocator, &options->host_name); + listener->resolved_address_callback = options->resolved_address_callback; + listener->shutdown_callback = options->shutdown_callback; + listener->user_data = options->user_data; + + struct default_host_resolver *default_host_resolver = resolver->impl; + + /* Add the listener to a host listener entry in the host listener entry table. */ + aws_mutex_lock(&default_host_resolver->resolver_lock); + + if (s_add_host_listener_to_listener_entry(default_host_resolver, listener->host_name, listener)) { + aws_mem_release(resolver->allocator, listener); + listener = NULL; + } + + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + return (struct aws_host_listener *)listener; +} + +static int default_remove_host_listener( + struct aws_host_resolver *host_resolver, + struct aws_host_listener *listener_opaque) { + AWS_PRECONDITION(host_resolver); + AWS_PRECONDITION(listener_opaque); + + struct host_listener *listener = (struct host_listener *)listener_opaque; + struct default_host_resolver *default_host_resolver = host_resolver->impl; + + if (listener->resolver != host_resolver) { + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, + "id=%p Trying to remove listener from incorrect host resolver. Listener belongs to host resolver %p", + (void *)host_resolver, + (void *)listener->resolver); + aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + return AWS_OP_ERR; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_DNS, + "id=%p Removing listener %p for host name %s", + (void *)host_resolver, + (void *)listener, + (const char *)listener->host_name->bytes); + + bool destroy_listener_immediate = false; + + aws_mutex_lock(&default_host_resolver->resolver_lock); + + /* If owned by the resolver thread, flag the listener as pending destroy, so that resolver thread knows to destroy + * it. */ + if (listener->synced_data.owned_by_resolver_thread) { + listener->synced_data.pending_destroy = true; + } else { + /* Else, remove the listener from the listener entry and clean it up once outside of the mutex. */ + s_remove_host_listener_from_entry(default_host_resolver, listener->host_name, listener); + destroy_listener_immediate = true; + } + + aws_mutex_unlock(&default_host_resolver->resolver_lock); + + if (destroy_listener_immediate) { + s_host_listener_destroy(listener); + } + + return AWS_OP_SUCCESS; +} + +/* Find listener entry on the host resolver, optionally creating it if it doesn't exist. */ +/* Assumes host resolver lock is held. */ +static struct host_listener_entry *s_find_host_listener_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + uint32_t flags) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + + struct host_listener_entry *listener_entry = NULL; + struct aws_string *host_string_copy = NULL; + + struct aws_hash_element *listener_entry_hash_element = NULL; + bool create_if_not_found = (flags & FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND) != 0; + + if (aws_hash_table_find(&resolver->listener_entry_table, host_name, &listener_entry_hash_element)) { + AWS_LOGF_ERROR( + AWS_LS_IO_DNS, "static: error when trying to find a listener entry in the listener entry table."); + goto error_clean_up; + } + + if (listener_entry_hash_element != NULL) { + listener_entry = listener_entry_hash_element->value; + AWS_FATAL_ASSERT(listener_entry); + } else if (create_if_not_found) { + + listener_entry = aws_mem_calloc(resolver->allocator, 1, sizeof(struct host_listener_entry)); + listener_entry->resolver = resolver; + aws_linked_list_init(&listener_entry->listeners); + + host_string_copy = aws_string_new_from_string(resolver->allocator, host_name); + + if (aws_hash_table_put(&resolver->listener_entry_table, host_string_copy, listener_entry, NULL)) { + AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: could not put new listener entry into listener entry table."); + goto error_clean_up; + } + } + + return listener_entry; + +error_clean_up: + + s_host_listener_entry_destroy(listener_entry); + + aws_string_destroy(host_string_copy); + + return NULL; +} + +/* Destroy function for listener entries. Takes a void* so that it can be used by the listener entry hash table. */ +static void s_host_listener_entry_destroy(void *listener_entry_void) { + if (listener_entry_void == NULL) { + return; + } + + struct host_listener_entry *listener_entry = listener_entry_void; + struct default_host_resolver *resolver = listener_entry->resolver; + + aws_mem_release(resolver->allocator, listener_entry); +} + +/* Add a listener to the relevant host listener entry. */ +/* Assumes host resolver lock is held. */ +static int s_add_host_listener_to_listener_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener *listener) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + AWS_PRECONDITION(listener); + + struct host_listener_entry *listener_entry = + s_find_host_listener_entry(resolver, host_name, FIND_LISTENER_ENTRY_FLAGS_CREATE_IF_NOT_FOUND); + + if (listener_entry == NULL) { + return AWS_OP_ERR; + } + + aws_linked_list_push_back(&listener_entry->listeners, &listener->synced_data.node); + return AWS_OP_SUCCESS; +} + +/* Assumes host resolver lock is held. */ +static struct host_listener *s_pop_host_listener_from_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener_entry **in_out_listener_entry) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + + struct host_listener_entry *listener_entry = NULL; + + if (in_out_listener_entry) { + listener_entry = *in_out_listener_entry; + } + + if (listener_entry == NULL) { + listener_entry = s_find_host_listener_entry(resolver, host_name, 0); + + if (listener_entry == NULL) { + return NULL; + } + } + + /* We should never have a listener entry without any listeners. Whenever a listener entry has no listeners, it + * should be cleaned up immediately. */ + AWS_ASSERT(!aws_linked_list_empty(&listener_entry->listeners)); + + struct aws_linked_list_node *node = aws_linked_list_pop_back(&listener_entry->listeners); + + struct host_listener *listener = HOST_LISTENER_FROM_SYNCED_NODE(node); + AWS_FATAL_ASSERT(listener); + + /* If the listener list on the listener entry is now empty, remove it. */ + if (aws_linked_list_empty(&listener_entry->listeners)) { + aws_hash_table_remove(&resolver->listener_entry_table, host_name, NULL, NULL); + listener_entry = NULL; + } + + if (in_out_listener_entry) { + *in_out_listener_entry = listener_entry; + } + + return listener; +} + +/* Assumes host resolver lock is held. */ +static void s_remove_host_listener_from_entry( + struct default_host_resolver *resolver, + const struct aws_string *host_name, + struct host_listener *listener) { + AWS_PRECONDITION(resolver); + AWS_PRECONDITION(host_name); + AWS_PRECONDITION(listener); + + struct host_listener_entry *listener_entry = s_find_host_listener_entry(resolver, host_name, 0); + + if (listener_entry == NULL) { + AWS_LOGF_WARN(AWS_LS_IO_DNS, "id=%p: Could not find listener entry for listener.", (void *)listener); + return; + } + + /* We should never have a listener entry without any listeners. Whenever a listener entry has no listeners, it + * should be cleaned up immediately. */ + AWS_ASSERT(!aws_linked_list_empty(&listener_entry->listeners)); + + aws_linked_list_remove(&listener->synced_data.node); + + /* If the listener list on the listener entry is now empty, remove it. */ + if (aws_linked_list_empty(&listener_entry->listeners)) { + aws_hash_table_remove(&resolver->listener_entry_table, host_name, NULL, NULL); + } +} + +/* Finish destroying a default resolver listener, releasing any remaining memory for it and triggering its shutdown + * callack. Since a shutdown callback is triggered, no lock should be held when calling this function. */ +static void s_host_listener_destroy(struct host_listener *listener) { + if (listener == NULL) { + return; + } + + AWS_LOGF_TRACE(AWS_LS_IO_DNS, "id=%p: Finishing clean up of host listener.", (void *)listener); + + struct aws_host_resolver *host_resolver = listener->resolver; + + aws_host_listener_shutdown_fn *shutdown_callback = listener->shutdown_callback; + void *shutdown_user_data = listener->user_data; + + aws_string_destroy(listener->host_name); + listener->host_name = NULL; + + aws_mem_release(host_resolver->allocator, listener); + listener = NULL; + + if (shutdown_callback != NULL) { + shutdown_callback(shutdown_user_data); + } + + if (host_resolver != NULL) { + aws_host_resolver_release(host_resolver); + host_resolver = NULL; + } +} + +#undef HOST_LISTENER_FROM_SYNCED_NODE +#undef HOST_LISTENER_FROM_THREADED_NODE diff --git a/contrib/restricted/aws/aws-c-io/source/io.c b/contrib/restricted/aws/aws-c-io/source/io.c index dc0092a76a..4660d19739 100644 --- a/contrib/restricted/aws/aws-c-io/source/io.c +++ b/contrib/restricted/aws/aws-c-io/source/io.c @@ -1,230 +1,230 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/io.h> - -#include <aws/io/logging.h> - -#include <aws/cal/cal.h> - -#define AWS_DEFINE_ERROR_INFO_IO(CODE, STR) [(CODE)-0x0400] = AWS_DEFINE_ERROR_INFO(CODE, STR, "aws-c-io") - -/* clang-format off */ -static struct aws_error_info s_errors[] = { - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_CHANNEL_ERROR_ERROR_CANT_ACCEPT_INPUT, - "Channel cannot accept input"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE, - "Channel unknown message type"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_CHANNEL_READ_WOULD_EXCEED_WINDOW, - "A channel handler attempted to propagate a read larger than the upstream window"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED, - "An attempt was made to assign an io handle to an event loop, but the handle was already assigned."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_EVENT_LOOP_SHUTDOWN, - "Event loop has shutdown and a resource was still using it, the resource has been removed from the loop."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE, - "TLS (SSL) negotiation failed"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_ERROR_NOT_NEGOTIATED, - "Attempt to read/write, but TLS (SSL) hasn't been negotiated"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_ERROR_WRITE_FAILURE, - "Failed to write to TLS handler"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_ERROR_ALERT_RECEIVED, - "Fatal TLS Alert was received"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_CTX_ERROR, - "Failed to create tls context"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_VERSION_UNSUPPORTED, - "A TLS version was specified that is currently not supported. Consider using AWS_IO_TLS_VER_SYS_DEFAULTS, " - " and when this lib or the operating system is updated, it will automatically be used."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED, - "A TLS Cipher Preference was specified that is currently not supported by the current platform. Consider " - " using AWS_IO_TLS_CIPHER_SYSTEM_DEFAULT, and when this lib or the operating system is updated, it will " - "automatically be used."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_MISSING_ALPN_MESSAGE, - "An ALPN message was expected but not received"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_UNHANDLED_ALPN_PROTOCOL_MESSAGE, - "An ALPN message was received but a handler was not created by the user"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_FILE_VALIDATION_FAILURE, - "A file was read and the input did not match the expected value"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY, - "Attempt to perform operation that must be run inside the event loop thread"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_ERROR_IO_ALREADY_SUBSCRIBED, - "Already subscribed to receive events"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_ERROR_IO_NOT_SUBSCRIBED, - "Not subscribed to receive events"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_ERROR_IO_OPERATION_CANCELLED, - "Operation cancelled before it could complete"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_READ_WOULD_BLOCK, - "Read operation would block, try again later"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_BROKEN_PIPE, - "Attempt to read or write to io handle that has already been closed."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY, - "Socket, unsupported address family."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_INVALID_OPERATION_FOR_TYPE, - "Invalid socket operation for socket type."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_CONNECTION_REFUSED, - "socket connection refused."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_TIMEOUT, - "socket operation timed out."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_NO_ROUTE_TO_HOST, - "socket connect failure, no route to host."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_NETWORK_DOWN, - "network is down."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_CLOSED, - "socket is closed."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_NOT_CONNECTED, - "socket not connected."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_INVALID_OPTIONS, - "Invalid socket options."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_ADDRESS_IN_USE, - "Socket address already in use."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_INVALID_ADDRESS, - "Invalid socket address."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE, - "Illegal operation for socket state."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SOCKET_CONNECT_ABORTED, - "Incoming connection was aborted."), - AWS_DEFINE_ERROR_INFO_IO ( - AWS_IO_DNS_QUERY_FAILED, - "A query to dns failed to resolve."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_DNS_INVALID_NAME, - "Host name was invalid for dns resolution."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_DNS_NO_ADDRESS_FOR_HOST, - "No address was found for the supplied host name."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_DNS_HOST_REMOVED_FROM_CACHE, - "The entries for host name were removed from the local dns cache."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_STREAM_INVALID_SEEK_POSITION, - "The seek position was outside of a stream's bounds"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_STREAM_READ_FAILED, - "Stream failed to read from the underlying io source"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_INVALID_FILE_HANDLE, - "Operation failed because the file handle was invalid"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SHARED_LIBRARY_LOAD_FAILURE, - "System call error during attempt to load shared library"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE, - "System call error during attempt to find shared library symbol"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_NEGOTIATION_TIMEOUT, - "Channel shutdown due to tls negotiation timeout"), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_TLS_ALERT_NOT_GRACEFUL, - "Channel shutdown due to tls alert. The alert was not for a graceful shutdown."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_MAX_RETRIES_EXCEEDED, - "Retry cannot be attempted because the maximum number of retries has been exceeded."), - AWS_DEFINE_ERROR_INFO_IO( - AWS_IO_RETRY_PERMISSION_DENIED, - "Retry cannot be attempted because the retry strategy has prevented the operation."), -}; -/* clang-format on */ - -static struct aws_error_info_list s_list = { - .error_list = s_errors, - .count = sizeof(s_errors) / sizeof(struct aws_error_info), -}; - -static struct aws_log_subject_info s_io_log_subject_infos[] = { - DEFINE_LOG_SUBJECT_INFO( - AWS_LS_IO_GENERAL, - "aws-c-io", - "Subject for IO logging that doesn't belong to any particular category"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_EVENT_LOOP, "event-loop", "Subject for Event-loop specific logging."), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SOCKET, "socket", "Subject for Socket specific logging."), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SOCKET_HANDLER, "socket-handler", "Subject for a socket channel handler."), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_TLS, "tls-handler", "Subject for TLS-related logging"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_ALPN, "alpn", "Subject for ALPN-related logging"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_DNS, "dns", "Subject for DNS-related logging"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_PKI, "pki-utils", "Subject for Pki utilities."), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_CHANNEL, "channel", "Subject for Channels"), - DEFINE_LOG_SUBJECT_INFO( - AWS_LS_IO_CHANNEL_BOOTSTRAP, - "channel-bootstrap", - "Subject for channel bootstrap (client and server modes)"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_FILE_UTILS, "file-utils", "Subject for file operations"), - DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SHARED_LIBRARY, "shared-library", "Subject for shared library operations"), - DEFINE_LOG_SUBJECT_INFO( - AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, - "exp-backoff-strategy", - "Subject for exponential backoff retry strategy")}; - -static struct aws_log_subject_info_list s_io_log_subject_list = { - .subject_list = s_io_log_subject_infos, - .count = AWS_ARRAY_SIZE(s_io_log_subject_infos), -}; - -static bool s_io_library_initialized = false; - -void aws_tls_init_static_state(struct aws_allocator *alloc); -void aws_tls_clean_up_static_state(void); - -void aws_io_library_init(struct aws_allocator *allocator) { - if (!s_io_library_initialized) { - s_io_library_initialized = true; - aws_common_library_init(allocator); - aws_cal_library_init(allocator); - aws_register_error_info(&s_list); - aws_register_log_subject_info_list(&s_io_log_subject_list); - aws_tls_init_static_state(allocator); - } -} - -void aws_io_library_clean_up(void) { - if (s_io_library_initialized) { - s_io_library_initialized = false; - aws_tls_clean_up_static_state(); - aws_unregister_error_info(&s_list); - aws_unregister_log_subject_info_list(&s_io_log_subject_list); - aws_cal_library_clean_up(); - aws_common_library_clean_up(); - } -} - -void aws_io_fatal_assert_library_initialized(void) { - if (!s_io_library_initialized) { - AWS_LOGF_FATAL( - AWS_LS_IO_GENERAL, "aws_io_library_init() must be called before using any functionality in aws-c-io."); - - AWS_FATAL_ASSERT(s_io_library_initialized); - } -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/io.h> + +#include <aws/io/logging.h> + +#include <aws/cal/cal.h> + +#define AWS_DEFINE_ERROR_INFO_IO(CODE, STR) [(CODE)-0x0400] = AWS_DEFINE_ERROR_INFO(CODE, STR, "aws-c-io") + +/* clang-format off */ +static struct aws_error_info s_errors[] = { + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_CHANNEL_ERROR_ERROR_CANT_ACCEPT_INPUT, + "Channel cannot accept input"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE, + "Channel unknown message type"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_CHANNEL_READ_WOULD_EXCEED_WINDOW, + "A channel handler attempted to propagate a read larger than the upstream window"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED, + "An attempt was made to assign an io handle to an event loop, but the handle was already assigned."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_EVENT_LOOP_SHUTDOWN, + "Event loop has shutdown and a resource was still using it, the resource has been removed from the loop."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE, + "TLS (SSL) negotiation failed"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_ERROR_NOT_NEGOTIATED, + "Attempt to read/write, but TLS (SSL) hasn't been negotiated"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_ERROR_WRITE_FAILURE, + "Failed to write to TLS handler"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_ERROR_ALERT_RECEIVED, + "Fatal TLS Alert was received"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_CTX_ERROR, + "Failed to create tls context"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_VERSION_UNSUPPORTED, + "A TLS version was specified that is currently not supported. Consider using AWS_IO_TLS_VER_SYS_DEFAULTS, " + " and when this lib or the operating system is updated, it will automatically be used."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED, + "A TLS Cipher Preference was specified that is currently not supported by the current platform. Consider " + " using AWS_IO_TLS_CIPHER_SYSTEM_DEFAULT, and when this lib or the operating system is updated, it will " + "automatically be used."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_MISSING_ALPN_MESSAGE, + "An ALPN message was expected but not received"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_UNHANDLED_ALPN_PROTOCOL_MESSAGE, + "An ALPN message was received but a handler was not created by the user"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_FILE_VALIDATION_FAILURE, + "A file was read and the input did not match the expected value"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY, + "Attempt to perform operation that must be run inside the event loop thread"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_ERROR_IO_ALREADY_SUBSCRIBED, + "Already subscribed to receive events"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_ERROR_IO_NOT_SUBSCRIBED, + "Not subscribed to receive events"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_ERROR_IO_OPERATION_CANCELLED, + "Operation cancelled before it could complete"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_READ_WOULD_BLOCK, + "Read operation would block, try again later"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_BROKEN_PIPE, + "Attempt to read or write to io handle that has already been closed."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY, + "Socket, unsupported address family."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_INVALID_OPERATION_FOR_TYPE, + "Invalid socket operation for socket type."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_CONNECTION_REFUSED, + "socket connection refused."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_TIMEOUT, + "socket operation timed out."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_NO_ROUTE_TO_HOST, + "socket connect failure, no route to host."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_NETWORK_DOWN, + "network is down."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_CLOSED, + "socket is closed."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_NOT_CONNECTED, + "socket not connected."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_INVALID_OPTIONS, + "Invalid socket options."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_ADDRESS_IN_USE, + "Socket address already in use."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_INVALID_ADDRESS, + "Invalid socket address."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE, + "Illegal operation for socket state."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SOCKET_CONNECT_ABORTED, + "Incoming connection was aborted."), + AWS_DEFINE_ERROR_INFO_IO ( + AWS_IO_DNS_QUERY_FAILED, + "A query to dns failed to resolve."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_DNS_INVALID_NAME, + "Host name was invalid for dns resolution."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_DNS_NO_ADDRESS_FOR_HOST, + "No address was found for the supplied host name."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_DNS_HOST_REMOVED_FROM_CACHE, + "The entries for host name were removed from the local dns cache."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_STREAM_INVALID_SEEK_POSITION, + "The seek position was outside of a stream's bounds"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_STREAM_READ_FAILED, + "Stream failed to read from the underlying io source"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_INVALID_FILE_HANDLE, + "Operation failed because the file handle was invalid"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SHARED_LIBRARY_LOAD_FAILURE, + "System call error during attempt to load shared library"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE, + "System call error during attempt to find shared library symbol"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_NEGOTIATION_TIMEOUT, + "Channel shutdown due to tls negotiation timeout"), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_TLS_ALERT_NOT_GRACEFUL, + "Channel shutdown due to tls alert. The alert was not for a graceful shutdown."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_MAX_RETRIES_EXCEEDED, + "Retry cannot be attempted because the maximum number of retries has been exceeded."), + AWS_DEFINE_ERROR_INFO_IO( + AWS_IO_RETRY_PERMISSION_DENIED, + "Retry cannot be attempted because the retry strategy has prevented the operation."), +}; +/* clang-format on */ + +static struct aws_error_info_list s_list = { + .error_list = s_errors, + .count = sizeof(s_errors) / sizeof(struct aws_error_info), +}; + +static struct aws_log_subject_info s_io_log_subject_infos[] = { + DEFINE_LOG_SUBJECT_INFO( + AWS_LS_IO_GENERAL, + "aws-c-io", + "Subject for IO logging that doesn't belong to any particular category"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_EVENT_LOOP, "event-loop", "Subject for Event-loop specific logging."), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SOCKET, "socket", "Subject for Socket specific logging."), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SOCKET_HANDLER, "socket-handler", "Subject for a socket channel handler."), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_TLS, "tls-handler", "Subject for TLS-related logging"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_ALPN, "alpn", "Subject for ALPN-related logging"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_DNS, "dns", "Subject for DNS-related logging"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_PKI, "pki-utils", "Subject for Pki utilities."), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_CHANNEL, "channel", "Subject for Channels"), + DEFINE_LOG_SUBJECT_INFO( + AWS_LS_IO_CHANNEL_BOOTSTRAP, + "channel-bootstrap", + "Subject for channel bootstrap (client and server modes)"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_FILE_UTILS, "file-utils", "Subject for file operations"), + DEFINE_LOG_SUBJECT_INFO(AWS_LS_IO_SHARED_LIBRARY, "shared-library", "Subject for shared library operations"), + DEFINE_LOG_SUBJECT_INFO( + AWS_LS_IO_EXPONENTIAL_BACKOFF_RETRY_STRATEGY, + "exp-backoff-strategy", + "Subject for exponential backoff retry strategy")}; + +static struct aws_log_subject_info_list s_io_log_subject_list = { + .subject_list = s_io_log_subject_infos, + .count = AWS_ARRAY_SIZE(s_io_log_subject_infos), +}; + +static bool s_io_library_initialized = false; + +void aws_tls_init_static_state(struct aws_allocator *alloc); +void aws_tls_clean_up_static_state(void); + +void aws_io_library_init(struct aws_allocator *allocator) { + if (!s_io_library_initialized) { + s_io_library_initialized = true; + aws_common_library_init(allocator); + aws_cal_library_init(allocator); + aws_register_error_info(&s_list); + aws_register_log_subject_info_list(&s_io_log_subject_list); + aws_tls_init_static_state(allocator); + } +} + +void aws_io_library_clean_up(void) { + if (s_io_library_initialized) { + s_io_library_initialized = false; + aws_tls_clean_up_static_state(); + aws_unregister_error_info(&s_list); + aws_unregister_log_subject_info_list(&s_io_log_subject_list); + aws_cal_library_clean_up(); + aws_common_library_clean_up(); + } +} + +void aws_io_fatal_assert_library_initialized(void) { + if (!s_io_library_initialized) { + AWS_LOGF_FATAL( + AWS_LS_IO_GENERAL, "aws_io_library_init() must be called before using any functionality in aws-c-io."); + + AWS_FATAL_ASSERT(s_io_library_initialized); + } +} diff --git a/contrib/restricted/aws/aws-c-io/source/linux/epoll_event_loop.c b/contrib/restricted/aws/aws-c-io/source/linux/epoll_event_loop.c index 8957e6c2b6..e8fd3e87d7 100644 --- a/contrib/restricted/aws/aws-c-io/source/linux/epoll_event_loop.c +++ b/contrib/restricted/aws/aws-c-io/source/linux/epoll_event_loop.c @@ -1,655 +1,655 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/event_loop.h> - -#include <aws/common/atomics.h> -#include <aws/common/clock.h> -#include <aws/common/mutex.h> -#include <aws/common/task_scheduler.h> -#include <aws/common/thread.h> - -#include <aws/io/logging.h> - -#include <sys/epoll.h> - -#include <errno.h> -#include <limits.h> -#include <unistd.h> - -#if !defined(COMPAT_MODE) && defined(__GLIBC__) && __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 8 -# define USE_EFD 1 -#else -# define USE_EFD 0 -#endif - -#if USE_EFD -# include <aws/io/io.h> -# include <sys/eventfd.h> - -#else -# include <aws/io/pipe.h> -#endif - -/* This isn't defined on ancient linux distros (breaking the builds). - * However, if this is a prebuild, we purposely build on an ancient system, but - * we want the kernel calls to still be the same as a modern build since that's likely the target of the application - * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag - * gets passed as long as it does. - */ -#ifndef EPOLLRDHUP -# define EPOLLRDHUP 0x2000 -#endif - -static void s_destroy(struct aws_event_loop *event_loop); -static int s_run(struct aws_event_loop *event_loop); -static int s_stop(struct aws_event_loop *event_loop); -static int s_wait_for_stop_completion(struct aws_event_loop *event_loop); -static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task); -static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos); -static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task); -static int s_subscribe_to_io_events( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - aws_event_loop_on_event_fn *on_event, - void *user_data); -static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle); -static void s_free_io_event_resources(void *user_data); -static bool s_is_on_callers_thread(struct aws_event_loop *event_loop); - -static void s_main_loop(void *args); - -static struct aws_event_loop_vtable s_vtable = { - .destroy = s_destroy, - .run = s_run, - .stop = s_stop, - .wait_for_stop_completion = s_wait_for_stop_completion, - .schedule_task_now = s_schedule_task_now, - .schedule_task_future = s_schedule_task_future, - .cancel_task = s_cancel_task, - .subscribe_to_io_events = s_subscribe_to_io_events, - .unsubscribe_from_io_events = s_unsubscribe_from_io_events, - .free_io_event_resources = s_free_io_event_resources, - .is_on_callers_thread = s_is_on_callers_thread, -}; - -struct epoll_loop { - struct aws_task_scheduler scheduler; - struct aws_thread thread_created_on; - aws_thread_id_t thread_joined_to; - struct aws_atomic_var running_thread_id; - struct aws_io_handle read_task_handle; - struct aws_io_handle write_task_handle; - struct aws_mutex task_pre_queue_mutex; - struct aws_linked_list task_pre_queue; - struct aws_task stop_task; - struct aws_atomic_var stop_task_ptr; - int epoll_fd; - bool should_process_task_pre_queue; - bool should_continue; -}; - -struct epoll_event_data { - struct aws_allocator *alloc; - struct aws_io_handle *handle; - aws_event_loop_on_event_fn *on_event; - void *user_data; - struct aws_task cleanup_task; - bool is_subscribed; /* false when handle is unsubscribed, but this struct hasn't been cleaned up yet */ -}; - -/* default timeout is 100 seconds */ -enum { - DEFAULT_TIMEOUT = 100 * 1000, - MAX_EVENTS = 100, -}; - -int aws_open_nonblocking_posix_pipe(int pipe_fds[2]); - -/* Setup edge triggered epoll with a scheduler. */ -struct aws_event_loop *aws_event_loop_new_default(struct aws_allocator *alloc, aws_io_clock_fn *clock) { - struct aws_event_loop *loop = aws_mem_calloc(alloc, 1, sizeof(struct aws_event_loop)); - if (!loop) { - return NULL; - } - - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing edge-triggered epoll", (void *)loop); - if (aws_event_loop_init_base(loop, alloc, clock)) { - goto clean_up_loop; - } - - struct epoll_loop *epoll_loop = aws_mem_calloc(alloc, 1, sizeof(struct epoll_loop)); - if (!epoll_loop) { - goto cleanup_base_loop; - } - - /* initialize thread id to NULL, it should be updated when the event loop thread starts. */ - aws_atomic_init_ptr(&epoll_loop->running_thread_id, NULL); - - aws_linked_list_init(&epoll_loop->task_pre_queue); - epoll_loop->task_pre_queue_mutex = (struct aws_mutex)AWS_MUTEX_INIT; - aws_atomic_init_ptr(&epoll_loop->stop_task_ptr, NULL); - - epoll_loop->epoll_fd = epoll_create(100); - if (epoll_loop->epoll_fd < 0) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open epoll handle.", (void *)loop); - aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - goto clean_up_epoll; - } - - if (aws_thread_init(&epoll_loop->thread_created_on, alloc)) { - goto clean_up_epoll; - } - -#if USE_EFD - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Using eventfd for cross-thread notifications.", (void *)loop); - int fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); - - if (fd < 0) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open eventfd handle.", (void *)loop); - aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - goto clean_up_thread; - } - - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: eventfd descriptor %d.", (void *)loop, fd); - epoll_loop->write_task_handle = (struct aws_io_handle){.data.fd = fd, .additional_data = NULL}; - epoll_loop->read_task_handle = (struct aws_io_handle){.data.fd = fd, .additional_data = NULL}; -#else - AWS_LOGF_DEBUG( - AWS_LS_IO_EVENT_LOOP, - "id=%p: Eventfd not available, falling back to pipe for cross-thread notification.", - (void *)loop); - - int pipe_fds[2] = {0}; - /* this pipe is for task scheduling. */ - if (aws_open_nonblocking_posix_pipe(pipe_fds)) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to open pipe handle.", (void *)loop); - goto clean_up_thread; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: pipe descriptors read %d, write %d.", (void *)loop, pipe_fds[0], pipe_fds[1]); - epoll_loop->write_task_handle.data.fd = pipe_fds[1]; - epoll_loop->read_task_handle.data.fd = pipe_fds[0]; -#endif - - if (aws_task_scheduler_init(&epoll_loop->scheduler, alloc)) { - goto clean_up_pipe; - } - - epoll_loop->should_continue = false; - - loop->impl_data = epoll_loop; - loop->vtable = &s_vtable; - - return loop; - -clean_up_pipe: -#if USE_EFD - close(epoll_loop->write_task_handle.data.fd); - epoll_loop->write_task_handle.data.fd = -1; - epoll_loop->read_task_handle.data.fd = -1; -#else - close(epoll_loop->read_task_handle.data.fd); - close(epoll_loop->write_task_handle.data.fd); -#endif - -clean_up_thread: - aws_thread_clean_up(&epoll_loop->thread_created_on); - -clean_up_epoll: - if (epoll_loop->epoll_fd >= 0) { - close(epoll_loop->epoll_fd); - } - - aws_mem_release(alloc, epoll_loop); - -cleanup_base_loop: - aws_event_loop_clean_up_base(loop); - -clean_up_loop: - aws_mem_release(alloc, loop); - - return NULL; -} - -static void s_destroy(struct aws_event_loop *event_loop) { - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroying event_loop", (void *)event_loop); - - struct epoll_loop *epoll_loop = event_loop->impl_data; - - /* we don't know if stop() has been called by someone else, - * just call stop() again and wait for event-loop to finish. */ - aws_event_loop_stop(event_loop); - s_wait_for_stop_completion(event_loop); - - /* setting this so that canceled tasks don't blow up when asking if they're on the event-loop thread. */ - epoll_loop->thread_joined_to = aws_thread_current_thread_id(); - aws_atomic_store_ptr(&epoll_loop->running_thread_id, &epoll_loop->thread_joined_to); - aws_task_scheduler_clean_up(&epoll_loop->scheduler); - - while (!aws_linked_list_empty(&epoll_loop->task_pre_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&epoll_loop->task_pre_queue); - struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); - task->fn(task, task->arg, AWS_TASK_STATUS_CANCELED); - } - - aws_thread_clean_up(&epoll_loop->thread_created_on); -#if USE_EFD - close(epoll_loop->write_task_handle.data.fd); - epoll_loop->write_task_handle.data.fd = -1; - epoll_loop->read_task_handle.data.fd = -1; -#else - close(epoll_loop->read_task_handle.data.fd); - close(epoll_loop->write_task_handle.data.fd); -#endif - - close(epoll_loop->epoll_fd); - aws_mem_release(event_loop->alloc, epoll_loop); - aws_event_loop_clean_up_base(event_loop); - aws_mem_release(event_loop->alloc, event_loop); -} - -static int s_run(struct aws_event_loop *event_loop) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Starting event-loop thread.", (void *)event_loop); - - epoll_loop->should_continue = true; - if (aws_thread_launch(&epoll_loop->thread_created_on, &s_main_loop, event_loop, NULL)) { - AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: thread creation failed.", (void *)event_loop); - epoll_loop->should_continue = false; - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -static void s_stop_task(struct aws_task *task, void *args, enum aws_task_status status) { - - (void)task; - struct aws_event_loop *event_loop = args; - struct epoll_loop *epoll_loop = event_loop->impl_data; - - /* now okay to reschedule stop tasks. */ - aws_atomic_store_ptr(&epoll_loop->stop_task_ptr, NULL); - if (status == AWS_TASK_STATUS_RUN_READY) { - /* - * this allows the event loop to invoke the callback once the event loop has completed. - */ - epoll_loop->should_continue = false; - } -} - -static int s_stop(struct aws_event_loop *event_loop) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - - void *expected_ptr = NULL; - bool update_succeeded = - aws_atomic_compare_exchange_ptr(&epoll_loop->stop_task_ptr, &expected_ptr, &epoll_loop->stop_task); - if (!update_succeeded) { - /* the stop task is already scheduled. */ - return AWS_OP_SUCCESS; - } - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Stopping event-loop thread.", (void *)event_loop); - aws_task_init(&epoll_loop->stop_task, s_stop_task, event_loop, "epoll_event_loop_stop"); - s_schedule_task_now(event_loop, &epoll_loop->stop_task); - - return AWS_OP_SUCCESS; -} - -static int s_wait_for_stop_completion(struct aws_event_loop *event_loop) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - return aws_thread_join(&epoll_loop->thread_created_on); -} - -static void s_schedule_task_common(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - - /* if event loop and the caller are the same thread, just schedule and be done with it. */ - if (s_is_on_callers_thread(event_loop)) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: scheduling task %p in-thread for timestamp %llu", - (void *)event_loop, - (void *)task, - (unsigned long long)run_at_nanos); - if (run_at_nanos == 0) { - /* zero denotes "now" task */ - aws_task_scheduler_schedule_now(&epoll_loop->scheduler, task); - } else { - aws_task_scheduler_schedule_future(&epoll_loop->scheduler, task, run_at_nanos); - } - return; - } - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: Scheduling task %p cross-thread for timestamp %llu", - (void *)event_loop, - (void *)task, - (unsigned long long)run_at_nanos); - task->timestamp = run_at_nanos; - aws_mutex_lock(&epoll_loop->task_pre_queue_mutex); - - uint64_t counter = 1; - - bool is_first_task = aws_linked_list_empty(&epoll_loop->task_pre_queue); - - aws_linked_list_push_back(&epoll_loop->task_pre_queue, &task->node); - - /* if the list was not empty, we already have a pending read on the pipe/eventfd, no need to write again. */ - if (is_first_task) { - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: Waking up event-loop thread", (void *)event_loop); - - /* If the write fails because the buffer is full, we don't actually care because that means there's a pending - * read on the pipe/eventfd and thus the event loop will end up checking to see if something has been queued.*/ - ssize_t do_not_care = write(epoll_loop->write_task_handle.data.fd, (void *)&counter, sizeof(counter)); - (void)do_not_care; - } - - aws_mutex_unlock(&epoll_loop->task_pre_queue_mutex); -} - -static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { - s_schedule_task_common(event_loop, task, 0 /* zero denotes "now" task */); -} - -static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { - s_schedule_task_common(event_loop, task, run_at_nanos); -} - -static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: cancelling task %p", (void *)event_loop, (void *)task); - struct epoll_loop *epoll_loop = event_loop->impl_data; - aws_task_scheduler_cancel_task(&epoll_loop->scheduler, task); -} - -static int s_subscribe_to_io_events( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - aws_event_loop_on_event_fn *on_event, - void *user_data) { - - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: subscribing to events on fd %d", (void *)event_loop, handle->data.fd); - struct epoll_event_data *epoll_event_data = aws_mem_calloc(event_loop->alloc, 1, sizeof(struct epoll_event_data)); - handle->additional_data = epoll_event_data; - if (!epoll_event_data) { - return AWS_OP_ERR; - } - - struct epoll_loop *epoll_loop = event_loop->impl_data; - epoll_event_data->alloc = event_loop->alloc; - epoll_event_data->user_data = user_data; - epoll_event_data->handle = handle; - epoll_event_data->on_event = on_event; - epoll_event_data->is_subscribed = true; - - /*everyone is always registered for edge-triggered, hang up, remote hang up, errors. */ - uint32_t event_mask = EPOLLET | EPOLLHUP | EPOLLRDHUP | EPOLLERR; - - if (events & AWS_IO_EVENT_TYPE_READABLE) { - event_mask |= EPOLLIN; - } - - if (events & AWS_IO_EVENT_TYPE_WRITABLE) { - event_mask |= EPOLLOUT; - } - - /* this guy is copied by epoll_ctl */ - struct epoll_event epoll_event = { - .data = {.ptr = epoll_event_data}, - .events = event_mask, - }; - - if (epoll_ctl(epoll_loop->epoll_fd, EPOLL_CTL_ADD, handle->data.fd, &epoll_event)) { - AWS_LOGF_ERROR( - AWS_LS_IO_EVENT_LOOP, "id=%p: failed to subscribe to events on fd %d", (void *)event_loop, handle->data.fd); - handle->additional_data = NULL; - aws_mem_release(event_loop->alloc, epoll_event_data); - return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - } - - return AWS_OP_SUCCESS; -} - -static void s_free_io_event_resources(void *user_data) { - struct epoll_event_data *event_data = user_data; - aws_mem_release(event_data->alloc, (void *)event_data); -} - -static void s_unsubscribe_cleanup_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - (void)status; - struct epoll_event_data *event_data = (struct epoll_event_data *)arg; - s_free_io_event_resources(event_data); -} - -static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: un-subscribing from events on fd %d", (void *)event_loop, handle->data.fd); - struct epoll_loop *epoll_loop = event_loop->impl_data; - - AWS_ASSERT(handle->additional_data); - struct epoll_event_data *additional_handle_data = handle->additional_data; - - struct epoll_event dummy_event; - - if (AWS_UNLIKELY(epoll_ctl(epoll_loop->epoll_fd, EPOLL_CTL_DEL, handle->data.fd, &dummy_event /*ignored*/))) { - AWS_LOGF_ERROR( - AWS_LS_IO_EVENT_LOOP, - "id=%p: failed to un-subscribe from events on fd %d", - (void *)event_loop, - handle->data.fd); - return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - } - - /* We can't clean up yet, because we have schedule tasks and more events to process, - * mark it as unsubscribed and schedule a cleanup task. */ - additional_handle_data->is_subscribed = false; - - aws_task_init( - &additional_handle_data->cleanup_task, - s_unsubscribe_cleanup_task, - additional_handle_data, - "epoll_event_loop_unsubscribe_cleanup"); - s_schedule_task_now(event_loop, &additional_handle_data->cleanup_task); - - handle->additional_data = NULL; - return AWS_OP_SUCCESS; -} - -static bool s_is_on_callers_thread(struct aws_event_loop *event_loop) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - - aws_thread_id_t *thread_id = aws_atomic_load_ptr(&epoll_loop->running_thread_id); - return thread_id && aws_thread_thread_id_equal(*thread_id, aws_thread_current_thread_id()); -} - -/* We treat the pipe fd with a subscription to io events just like any other managed file descriptor. - * This is the event handler for events on that pipe.*/ -static void s_on_tasks_to_schedule( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - - (void)handle; - (void)user_data; - - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: notified of cross-thread tasks to schedule", (void *)event_loop); - struct epoll_loop *epoll_loop = event_loop->impl_data; - if (events & AWS_IO_EVENT_TYPE_READABLE) { - epoll_loop->should_process_task_pre_queue = true; - } -} - -static void s_process_task_pre_queue(struct aws_event_loop *event_loop) { - struct epoll_loop *epoll_loop = event_loop->impl_data; - - if (!epoll_loop->should_process_task_pre_queue) { - return; - } - - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: processing cross-thread tasks", (void *)event_loop); - epoll_loop->should_process_task_pre_queue = false; - - struct aws_linked_list task_pre_queue; - aws_linked_list_init(&task_pre_queue); - - uint64_t count_ignore = 0; - - aws_mutex_lock(&epoll_loop->task_pre_queue_mutex); - - /* several tasks could theoretically have been written (though this should never happen), make sure we drain the - * eventfd/pipe. */ - while (read(epoll_loop->read_task_handle.data.fd, &count_ignore, sizeof(count_ignore)) > -1) { - } - - aws_linked_list_swap_contents(&epoll_loop->task_pre_queue, &task_pre_queue); - - aws_mutex_unlock(&epoll_loop->task_pre_queue_mutex); - - while (!aws_linked_list_empty(&task_pre_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&task_pre_queue); - struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: task %p pulled to event-loop, scheduling now.", - (void *)event_loop, - (void *)task); - /* Timestamp 0 is used to denote "now" tasks */ - if (task->timestamp == 0) { - aws_task_scheduler_schedule_now(&epoll_loop->scheduler, task); - } else { - aws_task_scheduler_schedule_future(&epoll_loop->scheduler, task, task->timestamp); - } - } -} - -static void s_main_loop(void *args) { - struct aws_event_loop *event_loop = args; - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: main loop started", (void *)event_loop); - struct epoll_loop *epoll_loop = event_loop->impl_data; - - /* set thread id to the thread of the event loop */ - aws_atomic_store_ptr(&epoll_loop->running_thread_id, &epoll_loop->thread_created_on.thread_id); - - int err = s_subscribe_to_io_events( - event_loop, &epoll_loop->read_task_handle, AWS_IO_EVENT_TYPE_READABLE, s_on_tasks_to_schedule, NULL); - if (err) { - return; - } - - int timeout = DEFAULT_TIMEOUT; - - struct epoll_event events[MAX_EVENTS]; - - AWS_LOGF_INFO( - AWS_LS_IO_EVENT_LOOP, - "id=%p: default timeout %d, and max events to process per tick %d", - (void *)event_loop, - timeout, - MAX_EVENTS); - - /* - * until stop is called, - * call epoll_wait, if a task is scheduled, or a file descriptor has activity, it will - * return. - * - * process all events, - * - * run all scheduled tasks. - * - * process queued subscription cleanups. - */ - while (epoll_loop->should_continue) { - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: waiting for a maximum of %d ms", (void *)event_loop, timeout); - int event_count = epoll_wait(epoll_loop->epoll_fd, events, MAX_EVENTS, timeout); - - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: wake up with %d events to process.", (void *)event_loop, event_count); - for (int i = 0; i < event_count; ++i) { - struct epoll_event_data *event_data = (struct epoll_event_data *)events[i].data.ptr; - - int event_mask = 0; - if (events[i].events & EPOLLIN) { - event_mask |= AWS_IO_EVENT_TYPE_READABLE; - } - - if (events[i].events & EPOLLOUT) { - event_mask |= AWS_IO_EVENT_TYPE_WRITABLE; - } - - if (events[i].events & EPOLLRDHUP) { - event_mask |= AWS_IO_EVENT_TYPE_REMOTE_HANG_UP; - } - - if (events[i].events & EPOLLHUP) { - event_mask |= AWS_IO_EVENT_TYPE_CLOSED; - } - - if (events[i].events & EPOLLERR) { - event_mask |= AWS_IO_EVENT_TYPE_ERROR; - } - - if (event_data->is_subscribed) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: activity on fd %d, invoking handler.", - (void *)event_loop, - event_data->handle->data.fd); - event_data->on_event(event_loop, event_data->handle, event_mask, event_data->user_data); - } - } - - /* run scheduled tasks */ - s_process_task_pre_queue(event_loop); - - uint64_t now_ns = 0; - event_loop->clock(&now_ns); /* if clock fails, now_ns will be 0 and tasks scheduled for a specific time - will not be run. That's ok, we'll handle them next time around. */ - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: running scheduled tasks.", (void *)event_loop); - aws_task_scheduler_run_all(&epoll_loop->scheduler, now_ns); - - /* set timeout for next epoll_wait() call. - * if clock fails, or scheduler has no tasks, use default timeout */ - bool use_default_timeout = false; - - if (event_loop->clock(&now_ns)) { - use_default_timeout = true; - } - - uint64_t next_run_time_ns; - if (!aws_task_scheduler_has_tasks(&epoll_loop->scheduler, &next_run_time_ns)) { - use_default_timeout = true; - } - - if (use_default_timeout) { - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, "id=%p: no more scheduled tasks using default timeout.", (void *)event_loop); - timeout = DEFAULT_TIMEOUT; - } else { - /* Translate timestamp (in nanoseconds) to timeout (in milliseconds) */ - uint64_t timeout_ns = (next_run_time_ns > now_ns) ? (next_run_time_ns - now_ns) : 0; - uint64_t timeout_ms64 = aws_timestamp_convert(timeout_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); - timeout = timeout_ms64 > INT_MAX ? INT_MAX : (int)timeout_ms64; - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: detected more scheduled tasks with the next occurring at " - "%llu, using timeout of %d.", - (void *)event_loop, - (unsigned long long)timeout_ns, - timeout); - } - } - - AWS_LOGF_DEBUG(AWS_LS_IO_EVENT_LOOP, "id=%p: exiting main loop", (void *)event_loop); - s_unsubscribe_from_io_events(event_loop, &epoll_loop->read_task_handle); - /* set thread id back to NULL. This should be updated again in destroy, before tasks are canceled. */ - aws_atomic_store_ptr(&epoll_loop->running_thread_id, NULL); -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/event_loop.h> + +#include <aws/common/atomics.h> +#include <aws/common/clock.h> +#include <aws/common/mutex.h> +#include <aws/common/task_scheduler.h> +#include <aws/common/thread.h> + +#include <aws/io/logging.h> + +#include <sys/epoll.h> + +#include <errno.h> +#include <limits.h> +#include <unistd.h> + +#if !defined(COMPAT_MODE) && defined(__GLIBC__) && __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 8 +# define USE_EFD 1 +#else +# define USE_EFD 0 +#endif + +#if USE_EFD +# include <aws/io/io.h> +# include <sys/eventfd.h> + +#else +# include <aws/io/pipe.h> +#endif + +/* This isn't defined on ancient linux distros (breaking the builds). + * However, if this is a prebuild, we purposely build on an ancient system, but + * we want the kernel calls to still be the same as a modern build since that's likely the target of the application + * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag + * gets passed as long as it does. + */ +#ifndef EPOLLRDHUP +# define EPOLLRDHUP 0x2000 +#endif + +static void s_destroy(struct aws_event_loop *event_loop); +static int s_run(struct aws_event_loop *event_loop); +static int s_stop(struct aws_event_loop *event_loop); +static int s_wait_for_stop_completion(struct aws_event_loop *event_loop); +static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task); +static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos); +static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task); +static int s_subscribe_to_io_events( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + aws_event_loop_on_event_fn *on_event, + void *user_data); +static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle); +static void s_free_io_event_resources(void *user_data); +static bool s_is_on_callers_thread(struct aws_event_loop *event_loop); + +static void s_main_loop(void *args); + +static struct aws_event_loop_vtable s_vtable = { + .destroy = s_destroy, + .run = s_run, + .stop = s_stop, + .wait_for_stop_completion = s_wait_for_stop_completion, + .schedule_task_now = s_schedule_task_now, + .schedule_task_future = s_schedule_task_future, + .cancel_task = s_cancel_task, + .subscribe_to_io_events = s_subscribe_to_io_events, + .unsubscribe_from_io_events = s_unsubscribe_from_io_events, + .free_io_event_resources = s_free_io_event_resources, + .is_on_callers_thread = s_is_on_callers_thread, +}; + +struct epoll_loop { + struct aws_task_scheduler scheduler; + struct aws_thread thread_created_on; + aws_thread_id_t thread_joined_to; + struct aws_atomic_var running_thread_id; + struct aws_io_handle read_task_handle; + struct aws_io_handle write_task_handle; + struct aws_mutex task_pre_queue_mutex; + struct aws_linked_list task_pre_queue; + struct aws_task stop_task; + struct aws_atomic_var stop_task_ptr; + int epoll_fd; + bool should_process_task_pre_queue; + bool should_continue; +}; + +struct epoll_event_data { + struct aws_allocator *alloc; + struct aws_io_handle *handle; + aws_event_loop_on_event_fn *on_event; + void *user_data; + struct aws_task cleanup_task; + bool is_subscribed; /* false when handle is unsubscribed, but this struct hasn't been cleaned up yet */ +}; + +/* default timeout is 100 seconds */ +enum { + DEFAULT_TIMEOUT = 100 * 1000, + MAX_EVENTS = 100, +}; + +int aws_open_nonblocking_posix_pipe(int pipe_fds[2]); + +/* Setup edge triggered epoll with a scheduler. */ +struct aws_event_loop *aws_event_loop_new_default(struct aws_allocator *alloc, aws_io_clock_fn *clock) { + struct aws_event_loop *loop = aws_mem_calloc(alloc, 1, sizeof(struct aws_event_loop)); + if (!loop) { + return NULL; + } + + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing edge-triggered epoll", (void *)loop); + if (aws_event_loop_init_base(loop, alloc, clock)) { + goto clean_up_loop; + } + + struct epoll_loop *epoll_loop = aws_mem_calloc(alloc, 1, sizeof(struct epoll_loop)); + if (!epoll_loop) { + goto cleanup_base_loop; + } + + /* initialize thread id to NULL, it should be updated when the event loop thread starts. */ + aws_atomic_init_ptr(&epoll_loop->running_thread_id, NULL); + + aws_linked_list_init(&epoll_loop->task_pre_queue); + epoll_loop->task_pre_queue_mutex = (struct aws_mutex)AWS_MUTEX_INIT; + aws_atomic_init_ptr(&epoll_loop->stop_task_ptr, NULL); + + epoll_loop->epoll_fd = epoll_create(100); + if (epoll_loop->epoll_fd < 0) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open epoll handle.", (void *)loop); + aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + goto clean_up_epoll; + } + + if (aws_thread_init(&epoll_loop->thread_created_on, alloc)) { + goto clean_up_epoll; + } + +#if USE_EFD + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Using eventfd for cross-thread notifications.", (void *)loop); + int fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + + if (fd < 0) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: Failed to open eventfd handle.", (void *)loop); + aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + goto clean_up_thread; + } + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: eventfd descriptor %d.", (void *)loop, fd); + epoll_loop->write_task_handle = (struct aws_io_handle){.data.fd = fd, .additional_data = NULL}; + epoll_loop->read_task_handle = (struct aws_io_handle){.data.fd = fd, .additional_data = NULL}; +#else + AWS_LOGF_DEBUG( + AWS_LS_IO_EVENT_LOOP, + "id=%p: Eventfd not available, falling back to pipe for cross-thread notification.", + (void *)loop); + + int pipe_fds[2] = {0}; + /* this pipe is for task scheduling. */ + if (aws_open_nonblocking_posix_pipe(pipe_fds)) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: failed to open pipe handle.", (void *)loop); + goto clean_up_thread; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: pipe descriptors read %d, write %d.", (void *)loop, pipe_fds[0], pipe_fds[1]); + epoll_loop->write_task_handle.data.fd = pipe_fds[1]; + epoll_loop->read_task_handle.data.fd = pipe_fds[0]; +#endif + + if (aws_task_scheduler_init(&epoll_loop->scheduler, alloc)) { + goto clean_up_pipe; + } + + epoll_loop->should_continue = false; + + loop->impl_data = epoll_loop; + loop->vtable = &s_vtable; + + return loop; + +clean_up_pipe: +#if USE_EFD + close(epoll_loop->write_task_handle.data.fd); + epoll_loop->write_task_handle.data.fd = -1; + epoll_loop->read_task_handle.data.fd = -1; +#else + close(epoll_loop->read_task_handle.data.fd); + close(epoll_loop->write_task_handle.data.fd); +#endif + +clean_up_thread: + aws_thread_clean_up(&epoll_loop->thread_created_on); + +clean_up_epoll: + if (epoll_loop->epoll_fd >= 0) { + close(epoll_loop->epoll_fd); + } + + aws_mem_release(alloc, epoll_loop); + +cleanup_base_loop: + aws_event_loop_clean_up_base(loop); + +clean_up_loop: + aws_mem_release(alloc, loop); + + return NULL; +} + +static void s_destroy(struct aws_event_loop *event_loop) { + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroying event_loop", (void *)event_loop); + + struct epoll_loop *epoll_loop = event_loop->impl_data; + + /* we don't know if stop() has been called by someone else, + * just call stop() again and wait for event-loop to finish. */ + aws_event_loop_stop(event_loop); + s_wait_for_stop_completion(event_loop); + + /* setting this so that canceled tasks don't blow up when asking if they're on the event-loop thread. */ + epoll_loop->thread_joined_to = aws_thread_current_thread_id(); + aws_atomic_store_ptr(&epoll_loop->running_thread_id, &epoll_loop->thread_joined_to); + aws_task_scheduler_clean_up(&epoll_loop->scheduler); + + while (!aws_linked_list_empty(&epoll_loop->task_pre_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&epoll_loop->task_pre_queue); + struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); + task->fn(task, task->arg, AWS_TASK_STATUS_CANCELED); + } + + aws_thread_clean_up(&epoll_loop->thread_created_on); +#if USE_EFD + close(epoll_loop->write_task_handle.data.fd); + epoll_loop->write_task_handle.data.fd = -1; + epoll_loop->read_task_handle.data.fd = -1; +#else + close(epoll_loop->read_task_handle.data.fd); + close(epoll_loop->write_task_handle.data.fd); +#endif + + close(epoll_loop->epoll_fd); + aws_mem_release(event_loop->alloc, epoll_loop); + aws_event_loop_clean_up_base(event_loop); + aws_mem_release(event_loop->alloc, event_loop); +} + +static int s_run(struct aws_event_loop *event_loop) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Starting event-loop thread.", (void *)event_loop); + + epoll_loop->should_continue = true; + if (aws_thread_launch(&epoll_loop->thread_created_on, &s_main_loop, event_loop, NULL)) { + AWS_LOGF_FATAL(AWS_LS_IO_EVENT_LOOP, "id=%p: thread creation failed.", (void *)event_loop); + epoll_loop->should_continue = false; + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +static void s_stop_task(struct aws_task *task, void *args, enum aws_task_status status) { + + (void)task; + struct aws_event_loop *event_loop = args; + struct epoll_loop *epoll_loop = event_loop->impl_data; + + /* now okay to reschedule stop tasks. */ + aws_atomic_store_ptr(&epoll_loop->stop_task_ptr, NULL); + if (status == AWS_TASK_STATUS_RUN_READY) { + /* + * this allows the event loop to invoke the callback once the event loop has completed. + */ + epoll_loop->should_continue = false; + } +} + +static int s_stop(struct aws_event_loop *event_loop) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + + void *expected_ptr = NULL; + bool update_succeeded = + aws_atomic_compare_exchange_ptr(&epoll_loop->stop_task_ptr, &expected_ptr, &epoll_loop->stop_task); + if (!update_succeeded) { + /* the stop task is already scheduled. */ + return AWS_OP_SUCCESS; + } + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Stopping event-loop thread.", (void *)event_loop); + aws_task_init(&epoll_loop->stop_task, s_stop_task, event_loop, "epoll_event_loop_stop"); + s_schedule_task_now(event_loop, &epoll_loop->stop_task); + + return AWS_OP_SUCCESS; +} + +static int s_wait_for_stop_completion(struct aws_event_loop *event_loop) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + return aws_thread_join(&epoll_loop->thread_created_on); +} + +static void s_schedule_task_common(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + + /* if event loop and the caller are the same thread, just schedule and be done with it. */ + if (s_is_on_callers_thread(event_loop)) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: scheduling task %p in-thread for timestamp %llu", + (void *)event_loop, + (void *)task, + (unsigned long long)run_at_nanos); + if (run_at_nanos == 0) { + /* zero denotes "now" task */ + aws_task_scheduler_schedule_now(&epoll_loop->scheduler, task); + } else { + aws_task_scheduler_schedule_future(&epoll_loop->scheduler, task, run_at_nanos); + } + return; + } + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: Scheduling task %p cross-thread for timestamp %llu", + (void *)event_loop, + (void *)task, + (unsigned long long)run_at_nanos); + task->timestamp = run_at_nanos; + aws_mutex_lock(&epoll_loop->task_pre_queue_mutex); + + uint64_t counter = 1; + + bool is_first_task = aws_linked_list_empty(&epoll_loop->task_pre_queue); + + aws_linked_list_push_back(&epoll_loop->task_pre_queue, &task->node); + + /* if the list was not empty, we already have a pending read on the pipe/eventfd, no need to write again. */ + if (is_first_task) { + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: Waking up event-loop thread", (void *)event_loop); + + /* If the write fails because the buffer is full, we don't actually care because that means there's a pending + * read on the pipe/eventfd and thus the event loop will end up checking to see if something has been queued.*/ + ssize_t do_not_care = write(epoll_loop->write_task_handle.data.fd, (void *)&counter, sizeof(counter)); + (void)do_not_care; + } + + aws_mutex_unlock(&epoll_loop->task_pre_queue_mutex); +} + +static void s_schedule_task_now(struct aws_event_loop *event_loop, struct aws_task *task) { + s_schedule_task_common(event_loop, task, 0 /* zero denotes "now" task */); +} + +static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { + s_schedule_task_common(event_loop, task, run_at_nanos); +} + +static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: cancelling task %p", (void *)event_loop, (void *)task); + struct epoll_loop *epoll_loop = event_loop->impl_data; + aws_task_scheduler_cancel_task(&epoll_loop->scheduler, task); +} + +static int s_subscribe_to_io_events( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + aws_event_loop_on_event_fn *on_event, + void *user_data) { + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: subscribing to events on fd %d", (void *)event_loop, handle->data.fd); + struct epoll_event_data *epoll_event_data = aws_mem_calloc(event_loop->alloc, 1, sizeof(struct epoll_event_data)); + handle->additional_data = epoll_event_data; + if (!epoll_event_data) { + return AWS_OP_ERR; + } + + struct epoll_loop *epoll_loop = event_loop->impl_data; + epoll_event_data->alloc = event_loop->alloc; + epoll_event_data->user_data = user_data; + epoll_event_data->handle = handle; + epoll_event_data->on_event = on_event; + epoll_event_data->is_subscribed = true; + + /*everyone is always registered for edge-triggered, hang up, remote hang up, errors. */ + uint32_t event_mask = EPOLLET | EPOLLHUP | EPOLLRDHUP | EPOLLERR; + + if (events & AWS_IO_EVENT_TYPE_READABLE) { + event_mask |= EPOLLIN; + } + + if (events & AWS_IO_EVENT_TYPE_WRITABLE) { + event_mask |= EPOLLOUT; + } + + /* this guy is copied by epoll_ctl */ + struct epoll_event epoll_event = { + .data = {.ptr = epoll_event_data}, + .events = event_mask, + }; + + if (epoll_ctl(epoll_loop->epoll_fd, EPOLL_CTL_ADD, handle->data.fd, &epoll_event)) { + AWS_LOGF_ERROR( + AWS_LS_IO_EVENT_LOOP, "id=%p: failed to subscribe to events on fd %d", (void *)event_loop, handle->data.fd); + handle->additional_data = NULL; + aws_mem_release(event_loop->alloc, epoll_event_data); + return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + } + + return AWS_OP_SUCCESS; +} + +static void s_free_io_event_resources(void *user_data) { + struct epoll_event_data *event_data = user_data; + aws_mem_release(event_data->alloc, (void *)event_data); +} + +static void s_unsubscribe_cleanup_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + (void)status; + struct epoll_event_data *event_data = (struct epoll_event_data *)arg; + s_free_io_event_resources(event_data); +} + +static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: un-subscribing from events on fd %d", (void *)event_loop, handle->data.fd); + struct epoll_loop *epoll_loop = event_loop->impl_data; + + AWS_ASSERT(handle->additional_data); + struct epoll_event_data *additional_handle_data = handle->additional_data; + + struct epoll_event dummy_event; + + if (AWS_UNLIKELY(epoll_ctl(epoll_loop->epoll_fd, EPOLL_CTL_DEL, handle->data.fd, &dummy_event /*ignored*/))) { + AWS_LOGF_ERROR( + AWS_LS_IO_EVENT_LOOP, + "id=%p: failed to un-subscribe from events on fd %d", + (void *)event_loop, + handle->data.fd); + return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + } + + /* We can't clean up yet, because we have schedule tasks and more events to process, + * mark it as unsubscribed and schedule a cleanup task. */ + additional_handle_data->is_subscribed = false; + + aws_task_init( + &additional_handle_data->cleanup_task, + s_unsubscribe_cleanup_task, + additional_handle_data, + "epoll_event_loop_unsubscribe_cleanup"); + s_schedule_task_now(event_loop, &additional_handle_data->cleanup_task); + + handle->additional_data = NULL; + return AWS_OP_SUCCESS; +} + +static bool s_is_on_callers_thread(struct aws_event_loop *event_loop) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + + aws_thread_id_t *thread_id = aws_atomic_load_ptr(&epoll_loop->running_thread_id); + return thread_id && aws_thread_thread_id_equal(*thread_id, aws_thread_current_thread_id()); +} + +/* We treat the pipe fd with a subscription to io events just like any other managed file descriptor. + * This is the event handler for events on that pipe.*/ +static void s_on_tasks_to_schedule( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + + (void)handle; + (void)user_data; + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: notified of cross-thread tasks to schedule", (void *)event_loop); + struct epoll_loop *epoll_loop = event_loop->impl_data; + if (events & AWS_IO_EVENT_TYPE_READABLE) { + epoll_loop->should_process_task_pre_queue = true; + } +} + +static void s_process_task_pre_queue(struct aws_event_loop *event_loop) { + struct epoll_loop *epoll_loop = event_loop->impl_data; + + if (!epoll_loop->should_process_task_pre_queue) { + return; + } + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: processing cross-thread tasks", (void *)event_loop); + epoll_loop->should_process_task_pre_queue = false; + + struct aws_linked_list task_pre_queue; + aws_linked_list_init(&task_pre_queue); + + uint64_t count_ignore = 0; + + aws_mutex_lock(&epoll_loop->task_pre_queue_mutex); + + /* several tasks could theoretically have been written (though this should never happen), make sure we drain the + * eventfd/pipe. */ + while (read(epoll_loop->read_task_handle.data.fd, &count_ignore, sizeof(count_ignore)) > -1) { + } + + aws_linked_list_swap_contents(&epoll_loop->task_pre_queue, &task_pre_queue); + + aws_mutex_unlock(&epoll_loop->task_pre_queue_mutex); + + while (!aws_linked_list_empty(&task_pre_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&task_pre_queue); + struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: task %p pulled to event-loop, scheduling now.", + (void *)event_loop, + (void *)task); + /* Timestamp 0 is used to denote "now" tasks */ + if (task->timestamp == 0) { + aws_task_scheduler_schedule_now(&epoll_loop->scheduler, task); + } else { + aws_task_scheduler_schedule_future(&epoll_loop->scheduler, task, task->timestamp); + } + } +} + +static void s_main_loop(void *args) { + struct aws_event_loop *event_loop = args; + AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: main loop started", (void *)event_loop); + struct epoll_loop *epoll_loop = event_loop->impl_data; + + /* set thread id to the thread of the event loop */ + aws_atomic_store_ptr(&epoll_loop->running_thread_id, &epoll_loop->thread_created_on.thread_id); + + int err = s_subscribe_to_io_events( + event_loop, &epoll_loop->read_task_handle, AWS_IO_EVENT_TYPE_READABLE, s_on_tasks_to_schedule, NULL); + if (err) { + return; + } + + int timeout = DEFAULT_TIMEOUT; + + struct epoll_event events[MAX_EVENTS]; + + AWS_LOGF_INFO( + AWS_LS_IO_EVENT_LOOP, + "id=%p: default timeout %d, and max events to process per tick %d", + (void *)event_loop, + timeout, + MAX_EVENTS); + + /* + * until stop is called, + * call epoll_wait, if a task is scheduled, or a file descriptor has activity, it will + * return. + * + * process all events, + * + * run all scheduled tasks. + * + * process queued subscription cleanups. + */ + while (epoll_loop->should_continue) { + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: waiting for a maximum of %d ms", (void *)event_loop, timeout); + int event_count = epoll_wait(epoll_loop->epoll_fd, events, MAX_EVENTS, timeout); + + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: wake up with %d events to process.", (void *)event_loop, event_count); + for (int i = 0; i < event_count; ++i) { + struct epoll_event_data *event_data = (struct epoll_event_data *)events[i].data.ptr; + + int event_mask = 0; + if (events[i].events & EPOLLIN) { + event_mask |= AWS_IO_EVENT_TYPE_READABLE; + } + + if (events[i].events & EPOLLOUT) { + event_mask |= AWS_IO_EVENT_TYPE_WRITABLE; + } + + if (events[i].events & EPOLLRDHUP) { + event_mask |= AWS_IO_EVENT_TYPE_REMOTE_HANG_UP; + } + + if (events[i].events & EPOLLHUP) { + event_mask |= AWS_IO_EVENT_TYPE_CLOSED; + } + + if (events[i].events & EPOLLERR) { + event_mask |= AWS_IO_EVENT_TYPE_ERROR; + } + + if (event_data->is_subscribed) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: activity on fd %d, invoking handler.", + (void *)event_loop, + event_data->handle->data.fd); + event_data->on_event(event_loop, event_data->handle, event_mask, event_data->user_data); + } + } + + /* run scheduled tasks */ + s_process_task_pre_queue(event_loop); + + uint64_t now_ns = 0; + event_loop->clock(&now_ns); /* if clock fails, now_ns will be 0 and tasks scheduled for a specific time + will not be run. That's ok, we'll handle them next time around. */ + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: running scheduled tasks.", (void *)event_loop); + aws_task_scheduler_run_all(&epoll_loop->scheduler, now_ns); + + /* set timeout for next epoll_wait() call. + * if clock fails, or scheduler has no tasks, use default timeout */ + bool use_default_timeout = false; + + if (event_loop->clock(&now_ns)) { + use_default_timeout = true; + } + + uint64_t next_run_time_ns; + if (!aws_task_scheduler_has_tasks(&epoll_loop->scheduler, &next_run_time_ns)) { + use_default_timeout = true; + } + + if (use_default_timeout) { + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, "id=%p: no more scheduled tasks using default timeout.", (void *)event_loop); + timeout = DEFAULT_TIMEOUT; + } else { + /* Translate timestamp (in nanoseconds) to timeout (in milliseconds) */ + uint64_t timeout_ns = (next_run_time_ns > now_ns) ? (next_run_time_ns - now_ns) : 0; + uint64_t timeout_ms64 = aws_timestamp_convert(timeout_ns, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL); + timeout = timeout_ms64 > INT_MAX ? INT_MAX : (int)timeout_ms64; + AWS_LOGF_TRACE( + AWS_LS_IO_EVENT_LOOP, + "id=%p: detected more scheduled tasks with the next occurring at " + "%llu, using timeout of %d.", + (void *)event_loop, + (unsigned long long)timeout_ns, + timeout); + } + } + + AWS_LOGF_DEBUG(AWS_LS_IO_EVENT_LOOP, "id=%p: exiting main loop", (void *)event_loop); + s_unsubscribe_from_io_events(event_loop, &epoll_loop->read_task_handle); + /* set thread id back to NULL. This should be updated again in destroy, before tasks are canceled. */ + aws_atomic_store_ptr(&epoll_loop->running_thread_id, NULL); +} diff --git a/contrib/restricted/aws/aws-c-io/source/message_pool.c b/contrib/restricted/aws/aws-c-io/source/message_pool.c index 3090f57e6b..5de8038315 100644 --- a/contrib/restricted/aws/aws-c-io/source/message_pool.c +++ b/contrib/restricted/aws/aws-c-io/source/message_pool.c @@ -1,205 +1,205 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/message_pool.h> - -#include <aws/common/thread.h> - -int aws_memory_pool_init( - struct aws_memory_pool *mempool, - struct aws_allocator *alloc, - uint16_t ideal_segment_count, - size_t segment_size) { - - mempool->alloc = alloc; - mempool->ideal_segment_count = ideal_segment_count; - mempool->segment_size = segment_size; - mempool->data_ptr = aws_mem_calloc(alloc, ideal_segment_count, sizeof(void *)); - if (!mempool->data_ptr) { - return AWS_OP_ERR; - } - - aws_array_list_init_static(&mempool->stack, mempool->data_ptr, ideal_segment_count, sizeof(void *)); - - for (uint16_t i = 0; i < ideal_segment_count; ++i) { - void *memory = aws_mem_acquire(alloc, segment_size); - if (memory) { - aws_array_list_push_back(&mempool->stack, &memory); - } else { - goto clean_up; - } - } - - return AWS_OP_SUCCESS; - -clean_up: - aws_memory_pool_clean_up(mempool); - return AWS_OP_ERR; -} - -void aws_memory_pool_clean_up(struct aws_memory_pool *mempool) { - void *cur = NULL; - - while (aws_array_list_length(&mempool->stack) > 0) { - /* the only way this fails is not possible since I already checked the length. */ - aws_array_list_back(&mempool->stack, &cur); - aws_array_list_pop_back(&mempool->stack); - aws_mem_release(mempool->alloc, cur); - } - - aws_array_list_clean_up(&mempool->stack); - aws_mem_release(mempool->alloc, mempool->data_ptr); -} - -void *aws_memory_pool_acquire(struct aws_memory_pool *mempool) { - void *back = NULL; - if (aws_array_list_length(&mempool->stack) > 0) { - aws_array_list_back(&mempool->stack, &back); - aws_array_list_pop_back(&mempool->stack); - - return back; - } - - void *mem = aws_mem_acquire(mempool->alloc, mempool->segment_size); - return mem; -} - -void aws_memory_pool_release(struct aws_memory_pool *mempool, void *to_release) { - size_t pool_size = aws_array_list_length(&mempool->stack); - - if (pool_size >= mempool->ideal_segment_count) { - aws_mem_release(mempool->alloc, to_release); - return; - } - - aws_array_list_push_back(&mempool->stack, &to_release); -} - -struct message_pool_allocator { - struct aws_allocator base_allocator; - struct aws_message_pool *msg_pool; -}; - -void *s_message_pool_mem_acquire(struct aws_allocator *allocator, size_t size) { - (void)allocator; - (void)size; - - /* no one should ever call this ever. */ - AWS_ASSERT(0); - return NULL; -} - -void s_message_pool_mem_release(struct aws_allocator *allocator, void *ptr) { - struct message_pool_allocator *msg_pool_alloc = allocator->impl; - - aws_message_pool_release(msg_pool_alloc->msg_pool, (struct aws_io_message *)ptr); -} - -static size_t MSG_OVERHEAD = sizeof(struct aws_io_message) + sizeof(struct message_pool_allocator); - -int aws_message_pool_init( - struct aws_message_pool *msg_pool, - struct aws_allocator *alloc, - struct aws_message_pool_creation_args *args) { - - msg_pool->alloc = alloc; - - size_t msg_data_size = args->application_data_msg_data_size + MSG_OVERHEAD; - - if (aws_memory_pool_init( - &msg_pool->application_data_pool, alloc, args->application_data_msg_count, msg_data_size)) { - return AWS_OP_ERR; - } - - size_t small_blk_data_size = args->small_block_msg_data_size + MSG_OVERHEAD; - - if (aws_memory_pool_init(&msg_pool->small_block_pool, alloc, args->small_block_msg_count, small_blk_data_size)) { - aws_memory_pool_clean_up(&msg_pool->application_data_pool); - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -void aws_message_pool_clean_up(struct aws_message_pool *msg_pool) { - aws_memory_pool_clean_up(&msg_pool->application_data_pool); - aws_memory_pool_clean_up(&msg_pool->small_block_pool); - AWS_ZERO_STRUCT(*msg_pool); -} - -struct message_wrapper { - struct aws_io_message message; - struct message_pool_allocator msg_allocator; - uint8_t buffer_start[1]; -}; - -struct aws_io_message *aws_message_pool_acquire( - struct aws_message_pool *msg_pool, - enum aws_io_message_type message_type, - size_t size_hint) { - - struct message_wrapper *message_wrapper = NULL; - size_t max_size = 0; - switch (message_type) { - case AWS_IO_MESSAGE_APPLICATION_DATA: - if (size_hint > msg_pool->small_block_pool.segment_size - MSG_OVERHEAD) { - message_wrapper = aws_memory_pool_acquire(&msg_pool->application_data_pool); - max_size = msg_pool->application_data_pool.segment_size - MSG_OVERHEAD; - } else { - message_wrapper = aws_memory_pool_acquire(&msg_pool->small_block_pool); - max_size = msg_pool->small_block_pool.segment_size - MSG_OVERHEAD; - } - break; - default: - AWS_ASSERT(0); - aws_raise_error(AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE); - return NULL; - } - - if (!message_wrapper) { - return NULL; - } - - message_wrapper->message.message_type = message_type; - message_wrapper->message.message_tag = 0; - message_wrapper->message.user_data = NULL; - message_wrapper->message.copy_mark = 0; - message_wrapper->message.on_completion = NULL; - /* the buffer shares the allocation with the message. It's the bit at the end. */ - message_wrapper->message.message_data.buffer = message_wrapper->buffer_start; - message_wrapper->message.message_data.len = 0; - message_wrapper->message.message_data.capacity = size_hint <= max_size ? size_hint : max_size; - - /* set the allocator ptr */ - message_wrapper->msg_allocator.base_allocator.impl = &message_wrapper->msg_allocator; - message_wrapper->msg_allocator.base_allocator.mem_acquire = s_message_pool_mem_acquire; - message_wrapper->msg_allocator.base_allocator.mem_realloc = NULL; - message_wrapper->msg_allocator.base_allocator.mem_release = s_message_pool_mem_release; - message_wrapper->msg_allocator.msg_pool = msg_pool; - - message_wrapper->message.allocator = &message_wrapper->msg_allocator.base_allocator; - return &message_wrapper->message; -} - -void aws_message_pool_release(struct aws_message_pool *msg_pool, struct aws_io_message *message) { - - memset(message->message_data.buffer, 0, message->message_data.len); - message->allocator = NULL; - - struct message_wrapper *wrapper = AWS_CONTAINER_OF(message, struct message_wrapper, message); - - switch (message->message_type) { - case AWS_IO_MESSAGE_APPLICATION_DATA: - if (message->message_data.capacity > msg_pool->small_block_pool.segment_size - MSG_OVERHEAD) { - aws_memory_pool_release(&msg_pool->application_data_pool, wrapper); - } else { - aws_memory_pool_release(&msg_pool->small_block_pool, wrapper); - } - break; - default: - AWS_ASSERT(0); - aws_raise_error(AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE); - } -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/message_pool.h> + +#include <aws/common/thread.h> + +int aws_memory_pool_init( + struct aws_memory_pool *mempool, + struct aws_allocator *alloc, + uint16_t ideal_segment_count, + size_t segment_size) { + + mempool->alloc = alloc; + mempool->ideal_segment_count = ideal_segment_count; + mempool->segment_size = segment_size; + mempool->data_ptr = aws_mem_calloc(alloc, ideal_segment_count, sizeof(void *)); + if (!mempool->data_ptr) { + return AWS_OP_ERR; + } + + aws_array_list_init_static(&mempool->stack, mempool->data_ptr, ideal_segment_count, sizeof(void *)); + + for (uint16_t i = 0; i < ideal_segment_count; ++i) { + void *memory = aws_mem_acquire(alloc, segment_size); + if (memory) { + aws_array_list_push_back(&mempool->stack, &memory); + } else { + goto clean_up; + } + } + + return AWS_OP_SUCCESS; + +clean_up: + aws_memory_pool_clean_up(mempool); + return AWS_OP_ERR; +} + +void aws_memory_pool_clean_up(struct aws_memory_pool *mempool) { + void *cur = NULL; + + while (aws_array_list_length(&mempool->stack) > 0) { + /* the only way this fails is not possible since I already checked the length. */ + aws_array_list_back(&mempool->stack, &cur); + aws_array_list_pop_back(&mempool->stack); + aws_mem_release(mempool->alloc, cur); + } + + aws_array_list_clean_up(&mempool->stack); + aws_mem_release(mempool->alloc, mempool->data_ptr); +} + +void *aws_memory_pool_acquire(struct aws_memory_pool *mempool) { + void *back = NULL; + if (aws_array_list_length(&mempool->stack) > 0) { + aws_array_list_back(&mempool->stack, &back); + aws_array_list_pop_back(&mempool->stack); + + return back; + } + + void *mem = aws_mem_acquire(mempool->alloc, mempool->segment_size); + return mem; +} + +void aws_memory_pool_release(struct aws_memory_pool *mempool, void *to_release) { + size_t pool_size = aws_array_list_length(&mempool->stack); + + if (pool_size >= mempool->ideal_segment_count) { + aws_mem_release(mempool->alloc, to_release); + return; + } + + aws_array_list_push_back(&mempool->stack, &to_release); +} + +struct message_pool_allocator { + struct aws_allocator base_allocator; + struct aws_message_pool *msg_pool; +}; + +void *s_message_pool_mem_acquire(struct aws_allocator *allocator, size_t size) { + (void)allocator; + (void)size; + + /* no one should ever call this ever. */ + AWS_ASSERT(0); + return NULL; +} + +void s_message_pool_mem_release(struct aws_allocator *allocator, void *ptr) { + struct message_pool_allocator *msg_pool_alloc = allocator->impl; + + aws_message_pool_release(msg_pool_alloc->msg_pool, (struct aws_io_message *)ptr); +} + +static size_t MSG_OVERHEAD = sizeof(struct aws_io_message) + sizeof(struct message_pool_allocator); + +int aws_message_pool_init( + struct aws_message_pool *msg_pool, + struct aws_allocator *alloc, + struct aws_message_pool_creation_args *args) { + + msg_pool->alloc = alloc; + + size_t msg_data_size = args->application_data_msg_data_size + MSG_OVERHEAD; + + if (aws_memory_pool_init( + &msg_pool->application_data_pool, alloc, args->application_data_msg_count, msg_data_size)) { + return AWS_OP_ERR; + } + + size_t small_blk_data_size = args->small_block_msg_data_size + MSG_OVERHEAD; + + if (aws_memory_pool_init(&msg_pool->small_block_pool, alloc, args->small_block_msg_count, small_blk_data_size)) { + aws_memory_pool_clean_up(&msg_pool->application_data_pool); + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +void aws_message_pool_clean_up(struct aws_message_pool *msg_pool) { + aws_memory_pool_clean_up(&msg_pool->application_data_pool); + aws_memory_pool_clean_up(&msg_pool->small_block_pool); + AWS_ZERO_STRUCT(*msg_pool); +} + +struct message_wrapper { + struct aws_io_message message; + struct message_pool_allocator msg_allocator; + uint8_t buffer_start[1]; +}; + +struct aws_io_message *aws_message_pool_acquire( + struct aws_message_pool *msg_pool, + enum aws_io_message_type message_type, + size_t size_hint) { + + struct message_wrapper *message_wrapper = NULL; + size_t max_size = 0; + switch (message_type) { + case AWS_IO_MESSAGE_APPLICATION_DATA: + if (size_hint > msg_pool->small_block_pool.segment_size - MSG_OVERHEAD) { + message_wrapper = aws_memory_pool_acquire(&msg_pool->application_data_pool); + max_size = msg_pool->application_data_pool.segment_size - MSG_OVERHEAD; + } else { + message_wrapper = aws_memory_pool_acquire(&msg_pool->small_block_pool); + max_size = msg_pool->small_block_pool.segment_size - MSG_OVERHEAD; + } + break; + default: + AWS_ASSERT(0); + aws_raise_error(AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE); + return NULL; + } + + if (!message_wrapper) { + return NULL; + } + + message_wrapper->message.message_type = message_type; + message_wrapper->message.message_tag = 0; + message_wrapper->message.user_data = NULL; + message_wrapper->message.copy_mark = 0; + message_wrapper->message.on_completion = NULL; + /* the buffer shares the allocation with the message. It's the bit at the end. */ + message_wrapper->message.message_data.buffer = message_wrapper->buffer_start; + message_wrapper->message.message_data.len = 0; + message_wrapper->message.message_data.capacity = size_hint <= max_size ? size_hint : max_size; + + /* set the allocator ptr */ + message_wrapper->msg_allocator.base_allocator.impl = &message_wrapper->msg_allocator; + message_wrapper->msg_allocator.base_allocator.mem_acquire = s_message_pool_mem_acquire; + message_wrapper->msg_allocator.base_allocator.mem_realloc = NULL; + message_wrapper->msg_allocator.base_allocator.mem_release = s_message_pool_mem_release; + message_wrapper->msg_allocator.msg_pool = msg_pool; + + message_wrapper->message.allocator = &message_wrapper->msg_allocator.base_allocator; + return &message_wrapper->message; +} + +void aws_message_pool_release(struct aws_message_pool *msg_pool, struct aws_io_message *message) { + + memset(message->message_data.buffer, 0, message->message_data.len); + message->allocator = NULL; + + struct message_wrapper *wrapper = AWS_CONTAINER_OF(message, struct message_wrapper, message); + + switch (message->message_type) { + case AWS_IO_MESSAGE_APPLICATION_DATA: + if (message->message_data.capacity > msg_pool->small_block_pool.segment_size - MSG_OVERHEAD) { + aws_memory_pool_release(&msg_pool->application_data_pool, wrapper); + } else { + aws_memory_pool_release(&msg_pool->small_block_pool, wrapper); + } + break; + default: + AWS_ASSERT(0); + aws_raise_error(AWS_IO_CHANNEL_UNKNOWN_MESSAGE_TYPE); + } +} diff --git a/contrib/restricted/aws/aws-c-io/source/pki_utils.c b/contrib/restricted/aws/aws-c-io/source/pki_utils.c index e8b3719089..8e3dc97602 100644 --- a/contrib/restricted/aws/aws-c-io/source/pki_utils.c +++ b/contrib/restricted/aws/aws-c-io/source/pki_utils.c @@ -1,228 +1,228 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/pki_utils.h> - -#include <aws/common/encoding.h> - -#include <aws/io/file_utils.h> -#include <aws/io/logging.h> - -#include <ctype.h> -#include <errno.h> -#include <string.h> - -enum PEM_PARSE_STATE { - BEGIN, - ON_DATA, -}; - -void aws_cert_chain_clean_up(struct aws_array_list *cert_chain) { - for (size_t i = 0; i < aws_array_list_length(cert_chain); ++i) { - struct aws_byte_buf *decoded_buffer_ptr = NULL; - aws_array_list_get_at_ptr(cert_chain, (void **)&decoded_buffer_ptr, i); - - if (decoded_buffer_ptr) { - aws_secure_zero(decoded_buffer_ptr->buffer, decoded_buffer_ptr->len); - aws_byte_buf_clean_up(decoded_buffer_ptr); - } - } - - /* remember, we don't own it so we don't free it, just undo whatever mutations we've done at this point. */ - aws_array_list_clear(cert_chain); -} - -static int s_convert_pem_to_raw_base64( - struct aws_allocator *allocator, - const struct aws_byte_cursor *pem, - struct aws_array_list *cert_chain_or_key) { - enum PEM_PARSE_STATE state = BEGIN; - - struct aws_byte_buf current_cert; - const char *begin_header = "-----BEGIN"; - const char *end_header = "-----END"; - size_t begin_header_len = strlen(begin_header); - size_t end_header_len = strlen(end_header); - bool on_length_calc = true; - - struct aws_array_list split_buffers; - if (aws_array_list_init_dynamic(&split_buffers, allocator, 16, sizeof(struct aws_byte_cursor))) { - return AWS_OP_ERR; - } - - if (aws_byte_cursor_split_on_char(pem, '\n', &split_buffers)) { - aws_array_list_clean_up(&split_buffers); - AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer: failed to split on newline"); - return AWS_OP_ERR; - } - - size_t split_count = aws_array_list_length(&split_buffers); - size_t i = 0; - size_t index_of_current_cert_start = 0; - size_t current_cert_len = 0; - - while (i < split_count) { - struct aws_byte_cursor *current_cur_ptr = NULL; - aws_array_list_get_at_ptr(&split_buffers, (void **)¤t_cur_ptr, i); - - /* burn off the padding in the buffer first. - * Worst case we'll only have to do this once per line in the buffer. */ - while (current_cur_ptr->len && aws_isspace(*current_cur_ptr->ptr)) { - aws_byte_cursor_advance(current_cur_ptr, 1); - } - - /* handle CRLF on Windows by burning '\r' off the end of the buffer */ - if (current_cur_ptr->len && (current_cur_ptr->ptr[current_cur_ptr->len - 1] == '\r')) { - current_cur_ptr->len--; - } - - switch (state) { - case BEGIN: - if (current_cur_ptr->len > begin_header_len && - !strncmp((const char *)current_cur_ptr->ptr, begin_header, begin_header_len)) { - state = ON_DATA; - index_of_current_cert_start = i + 1; - } - ++i; - break; - /* this loops through the lines containing data twice. First to figure out the length, a second - * time to actually copy the data. */ - case ON_DATA: - /* Found end tag. */ - if (current_cur_ptr->len > end_header_len && - !strncmp((const char *)current_cur_ptr->ptr, end_header, end_header_len)) { - if (on_length_calc) { - on_length_calc = false; - state = ON_DATA; - i = index_of_current_cert_start; - - if (aws_byte_buf_init(¤t_cert, allocator, current_cert_len)) { - goto end_of_loop; - } - - } else { - if (aws_array_list_push_back(cert_chain_or_key, ¤t_cert)) { - aws_secure_zero(¤t_cert.buffer, current_cert.len); - aws_byte_buf_clean_up(¤t_cert); - goto end_of_loop; - } - state = BEGIN; - on_length_calc = true; - current_cert_len = 0; - ++i; - } - /* actually on a line with data in it. */ - } else { - if (!on_length_calc) { - aws_byte_buf_write(¤t_cert, current_cur_ptr->ptr, current_cur_ptr->len); - } else { - current_cert_len += current_cur_ptr->len; - } - ++i; - } - break; - } - } - -end_of_loop: - aws_array_list_clean_up(&split_buffers); - - if (state == BEGIN && aws_array_list_length(cert_chain_or_key) > 0) { - return AWS_OP_SUCCESS; - } - - AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer."); - aws_cert_chain_clean_up(cert_chain_or_key); - return aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); -} - -int aws_decode_pem_to_buffer_list( - struct aws_allocator *alloc, - const struct aws_byte_cursor *pem_cursor, - struct aws_array_list *cert_chain_or_key) { - AWS_ASSERT(aws_array_list_length(cert_chain_or_key) == 0); - struct aws_array_list base_64_buffer_list; - - if (aws_array_list_init_dynamic(&base_64_buffer_list, alloc, 2, sizeof(struct aws_byte_buf))) { - return AWS_OP_ERR; - } - - int err_code = AWS_OP_ERR; - - if (s_convert_pem_to_raw_base64(alloc, pem_cursor, &base_64_buffer_list)) { - goto cleanup_base64_buffer_list; - } - - for (size_t i = 0; i < aws_array_list_length(&base_64_buffer_list); ++i) { - size_t decoded_len = 0; - struct aws_byte_buf *byte_buf_ptr = NULL; - aws_array_list_get_at_ptr(&base_64_buffer_list, (void **)&byte_buf_ptr, i); - struct aws_byte_cursor byte_cur = aws_byte_cursor_from_buf(byte_buf_ptr); - - if (aws_base64_compute_decoded_len(&byte_cur, &decoded_len)) { - aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); - goto cleanup_output_due_to_error; - } - - struct aws_byte_buf decoded_buffer; - - if (aws_byte_buf_init(&decoded_buffer, alloc, decoded_len)) { - goto cleanup_output_due_to_error; - } - - if (aws_base64_decode(&byte_cur, &decoded_buffer)) { - aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); - aws_byte_buf_clean_up(&decoded_buffer); - goto cleanup_output_due_to_error; - } - - if (aws_array_list_push_back(cert_chain_or_key, &decoded_buffer)) { - aws_byte_buf_clean_up(&decoded_buffer); - goto cleanup_output_due_to_error; - } - } - - err_code = AWS_OP_SUCCESS; - -cleanup_base64_buffer_list: - aws_cert_chain_clean_up(&base_64_buffer_list); - aws_array_list_clean_up(&base_64_buffer_list); - - return err_code; - -cleanup_output_due_to_error: - AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer."); - aws_cert_chain_clean_up(&base_64_buffer_list); - aws_array_list_clean_up(&base_64_buffer_list); - - aws_cert_chain_clean_up(cert_chain_or_key); - - return AWS_OP_ERR; -} - -int aws_read_and_decode_pem_file_to_buffer_list( - struct aws_allocator *alloc, - const char *filename, - struct aws_array_list *cert_chain_or_key) { - - struct aws_byte_buf raw_file_buffer; - if (aws_byte_buf_init_from_file(&raw_file_buffer, alloc, filename)) { - AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Failed to read file %s.", filename); - return AWS_OP_ERR; - } - AWS_ASSERT(raw_file_buffer.buffer); - - struct aws_byte_cursor file_cursor = aws_byte_cursor_from_buf(&raw_file_buffer); - if (aws_decode_pem_to_buffer_list(alloc, &file_cursor, cert_chain_or_key)) { - aws_secure_zero(raw_file_buffer.buffer, raw_file_buffer.len); - aws_byte_buf_clean_up(&raw_file_buffer); - AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Failed to decode PEM file %s.", filename); - return AWS_OP_ERR; - } - - aws_secure_zero(raw_file_buffer.buffer, raw_file_buffer.len); - aws_byte_buf_clean_up(&raw_file_buffer); - - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/pki_utils.h> + +#include <aws/common/encoding.h> + +#include <aws/io/file_utils.h> +#include <aws/io/logging.h> + +#include <ctype.h> +#include <errno.h> +#include <string.h> + +enum PEM_PARSE_STATE { + BEGIN, + ON_DATA, +}; + +void aws_cert_chain_clean_up(struct aws_array_list *cert_chain) { + for (size_t i = 0; i < aws_array_list_length(cert_chain); ++i) { + struct aws_byte_buf *decoded_buffer_ptr = NULL; + aws_array_list_get_at_ptr(cert_chain, (void **)&decoded_buffer_ptr, i); + + if (decoded_buffer_ptr) { + aws_secure_zero(decoded_buffer_ptr->buffer, decoded_buffer_ptr->len); + aws_byte_buf_clean_up(decoded_buffer_ptr); + } + } + + /* remember, we don't own it so we don't free it, just undo whatever mutations we've done at this point. */ + aws_array_list_clear(cert_chain); +} + +static int s_convert_pem_to_raw_base64( + struct aws_allocator *allocator, + const struct aws_byte_cursor *pem, + struct aws_array_list *cert_chain_or_key) { + enum PEM_PARSE_STATE state = BEGIN; + + struct aws_byte_buf current_cert; + const char *begin_header = "-----BEGIN"; + const char *end_header = "-----END"; + size_t begin_header_len = strlen(begin_header); + size_t end_header_len = strlen(end_header); + bool on_length_calc = true; + + struct aws_array_list split_buffers; + if (aws_array_list_init_dynamic(&split_buffers, allocator, 16, sizeof(struct aws_byte_cursor))) { + return AWS_OP_ERR; + } + + if (aws_byte_cursor_split_on_char(pem, '\n', &split_buffers)) { + aws_array_list_clean_up(&split_buffers); + AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer: failed to split on newline"); + return AWS_OP_ERR; + } + + size_t split_count = aws_array_list_length(&split_buffers); + size_t i = 0; + size_t index_of_current_cert_start = 0; + size_t current_cert_len = 0; + + while (i < split_count) { + struct aws_byte_cursor *current_cur_ptr = NULL; + aws_array_list_get_at_ptr(&split_buffers, (void **)¤t_cur_ptr, i); + + /* burn off the padding in the buffer first. + * Worst case we'll only have to do this once per line in the buffer. */ + while (current_cur_ptr->len && aws_isspace(*current_cur_ptr->ptr)) { + aws_byte_cursor_advance(current_cur_ptr, 1); + } + + /* handle CRLF on Windows by burning '\r' off the end of the buffer */ + if (current_cur_ptr->len && (current_cur_ptr->ptr[current_cur_ptr->len - 1] == '\r')) { + current_cur_ptr->len--; + } + + switch (state) { + case BEGIN: + if (current_cur_ptr->len > begin_header_len && + !strncmp((const char *)current_cur_ptr->ptr, begin_header, begin_header_len)) { + state = ON_DATA; + index_of_current_cert_start = i + 1; + } + ++i; + break; + /* this loops through the lines containing data twice. First to figure out the length, a second + * time to actually copy the data. */ + case ON_DATA: + /* Found end tag. */ + if (current_cur_ptr->len > end_header_len && + !strncmp((const char *)current_cur_ptr->ptr, end_header, end_header_len)) { + if (on_length_calc) { + on_length_calc = false; + state = ON_DATA; + i = index_of_current_cert_start; + + if (aws_byte_buf_init(¤t_cert, allocator, current_cert_len)) { + goto end_of_loop; + } + + } else { + if (aws_array_list_push_back(cert_chain_or_key, ¤t_cert)) { + aws_secure_zero(¤t_cert.buffer, current_cert.len); + aws_byte_buf_clean_up(¤t_cert); + goto end_of_loop; + } + state = BEGIN; + on_length_calc = true; + current_cert_len = 0; + ++i; + } + /* actually on a line with data in it. */ + } else { + if (!on_length_calc) { + aws_byte_buf_write(¤t_cert, current_cur_ptr->ptr, current_cur_ptr->len); + } else { + current_cert_len += current_cur_ptr->len; + } + ++i; + } + break; + } + } + +end_of_loop: + aws_array_list_clean_up(&split_buffers); + + if (state == BEGIN && aws_array_list_length(cert_chain_or_key) > 0) { + return AWS_OP_SUCCESS; + } + + AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer."); + aws_cert_chain_clean_up(cert_chain_or_key); + return aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); +} + +int aws_decode_pem_to_buffer_list( + struct aws_allocator *alloc, + const struct aws_byte_cursor *pem_cursor, + struct aws_array_list *cert_chain_or_key) { + AWS_ASSERT(aws_array_list_length(cert_chain_or_key) == 0); + struct aws_array_list base_64_buffer_list; + + if (aws_array_list_init_dynamic(&base_64_buffer_list, alloc, 2, sizeof(struct aws_byte_buf))) { + return AWS_OP_ERR; + } + + int err_code = AWS_OP_ERR; + + if (s_convert_pem_to_raw_base64(alloc, pem_cursor, &base_64_buffer_list)) { + goto cleanup_base64_buffer_list; + } + + for (size_t i = 0; i < aws_array_list_length(&base_64_buffer_list); ++i) { + size_t decoded_len = 0; + struct aws_byte_buf *byte_buf_ptr = NULL; + aws_array_list_get_at_ptr(&base_64_buffer_list, (void **)&byte_buf_ptr, i); + struct aws_byte_cursor byte_cur = aws_byte_cursor_from_buf(byte_buf_ptr); + + if (aws_base64_compute_decoded_len(&byte_cur, &decoded_len)) { + aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); + goto cleanup_output_due_to_error; + } + + struct aws_byte_buf decoded_buffer; + + if (aws_byte_buf_init(&decoded_buffer, alloc, decoded_len)) { + goto cleanup_output_due_to_error; + } + + if (aws_base64_decode(&byte_cur, &decoded_buffer)) { + aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); + aws_byte_buf_clean_up(&decoded_buffer); + goto cleanup_output_due_to_error; + } + + if (aws_array_list_push_back(cert_chain_or_key, &decoded_buffer)) { + aws_byte_buf_clean_up(&decoded_buffer); + goto cleanup_output_due_to_error; + } + } + + err_code = AWS_OP_SUCCESS; + +cleanup_base64_buffer_list: + aws_cert_chain_clean_up(&base_64_buffer_list); + aws_array_list_clean_up(&base_64_buffer_list); + + return err_code; + +cleanup_output_due_to_error: + AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Invalid PEM buffer."); + aws_cert_chain_clean_up(&base_64_buffer_list); + aws_array_list_clean_up(&base_64_buffer_list); + + aws_cert_chain_clean_up(cert_chain_or_key); + + return AWS_OP_ERR; +} + +int aws_read_and_decode_pem_file_to_buffer_list( + struct aws_allocator *alloc, + const char *filename, + struct aws_array_list *cert_chain_or_key) { + + struct aws_byte_buf raw_file_buffer; + if (aws_byte_buf_init_from_file(&raw_file_buffer, alloc, filename)) { + AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Failed to read file %s.", filename); + return AWS_OP_ERR; + } + AWS_ASSERT(raw_file_buffer.buffer); + + struct aws_byte_cursor file_cursor = aws_byte_cursor_from_buf(&raw_file_buffer); + if (aws_decode_pem_to_buffer_list(alloc, &file_cursor, cert_chain_or_key)) { + aws_secure_zero(raw_file_buffer.buffer, raw_file_buffer.len); + aws_byte_buf_clean_up(&raw_file_buffer); + AWS_LOGF_ERROR(AWS_LS_IO_PKI, "static: Failed to decode PEM file %s.", filename); + return AWS_OP_ERR; + } + + aws_secure_zero(raw_file_buffer.buffer, raw_file_buffer.len); + aws_byte_buf_clean_up(&raw_file_buffer); + + return AWS_OP_SUCCESS; +} diff --git a/contrib/restricted/aws/aws-c-io/source/posix/file_utils.c b/contrib/restricted/aws/aws-c-io/source/posix/file_utils.c index fcb96260eb..03b5f6c734 100644 --- a/contrib/restricted/aws/aws-c-io/source/posix/file_utils.c +++ b/contrib/restricted/aws/aws-c-io/source/posix/file_utils.c @@ -1,69 +1,69 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/file_utils.h> - -#include <aws/common/environment.h> -#include <aws/common/string.h> - -#include <errno.h> -#include <sys/stat.h> -#include <unistd.h> - -char aws_get_platform_directory_separator(void) { - return '/'; -} - -AWS_STATIC_STRING_FROM_LITERAL(s_home_env_var, "HOME"); - -struct aws_string *aws_get_home_directory(struct aws_allocator *allocator) { - - /* ToDo: check getpwuid_r if environment check fails */ - struct aws_string *home_env_var_value = NULL; - if (aws_get_environment_value(allocator, s_home_env_var, &home_env_var_value) == 0 && home_env_var_value != NULL) { - return home_env_var_value; - } - - return NULL; -} - -bool aws_path_exists(const char *path) { - struct stat buffer; - return stat(path, &buffer) == 0; -} - -int aws_fseek(FILE *file, aws_off_t offset, int whence) { - - int result = -#if _FILE_OFFSET_BITS == 64 || _POSIX_C_SOURCE >= 200112L - fseeko(file, offset, whence); -#else - fseek(file, offset, whence); -#endif - - if (result != 0) { - return aws_translate_and_raise_io_error(errno); - } - - return AWS_OP_SUCCESS; -} - -int aws_file_get_length(FILE *file, int64_t *length) { - - struct stat file_stats; - - int fd = fileno(file); - if (fd == -1) { - return aws_raise_error(AWS_IO_INVALID_FILE_HANDLE); - } - - if (fstat(fd, &file_stats)) { - return aws_translate_and_raise_io_error(errno); - } - - *length = file_stats.st_size; - - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/file_utils.h> + +#include <aws/common/environment.h> +#include <aws/common/string.h> + +#include <errno.h> +#include <sys/stat.h> +#include <unistd.h> + +char aws_get_platform_directory_separator(void) { + return '/'; +} + +AWS_STATIC_STRING_FROM_LITERAL(s_home_env_var, "HOME"); + +struct aws_string *aws_get_home_directory(struct aws_allocator *allocator) { + + /* ToDo: check getpwuid_r if environment check fails */ + struct aws_string *home_env_var_value = NULL; + if (aws_get_environment_value(allocator, s_home_env_var, &home_env_var_value) == 0 && home_env_var_value != NULL) { + return home_env_var_value; + } + + return NULL; +} + +bool aws_path_exists(const char *path) { + struct stat buffer; + return stat(path, &buffer) == 0; +} + +int aws_fseek(FILE *file, aws_off_t offset, int whence) { + + int result = +#if _FILE_OFFSET_BITS == 64 || _POSIX_C_SOURCE >= 200112L + fseeko(file, offset, whence); +#else + fseek(file, offset, whence); +#endif + + if (result != 0) { + return aws_translate_and_raise_io_error(errno); + } + + return AWS_OP_SUCCESS; +} + +int aws_file_get_length(FILE *file, int64_t *length) { + + struct stat file_stats; + + int fd = fileno(file); + if (fd == -1) { + return aws_raise_error(AWS_IO_INVALID_FILE_HANDLE); + } + + if (fstat(fd, &file_stats)) { + return aws_translate_and_raise_io_error(errno); + } + + *length = file_stats.st_size; + + return AWS_OP_SUCCESS; +} diff --git a/contrib/restricted/aws/aws-c-io/source/posix/host_resolver.c b/contrib/restricted/aws/aws-c-io/source/posix/host_resolver.c index 6594723bb8..e9604107d7 100644 --- a/contrib/restricted/aws/aws-c-io/source/posix/host_resolver.c +++ b/contrib/restricted/aws/aws-c-io/source/posix/host_resolver.c @@ -1,118 +1,118 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/host_resolver.h> - -#include <aws/io/logging.h> - -#include <aws/common/string.h> - -#include <arpa/inet.h> -#include <netdb.h> -#include <sys/socket.h> -#include <sys/types.h> - -int aws_default_dns_resolve( - struct aws_allocator *allocator, - const struct aws_string *host_name, - struct aws_array_list *output_addresses, - void *user_data) { - - (void)user_data; - struct addrinfo *result = NULL; - struct addrinfo *iter = NULL; - /* max string length for ipv6. */ - socklen_t max_len = INET6_ADDRSTRLEN; - char address_buffer[max_len]; - - const char *hostname_cstr = aws_string_c_str(host_name); - AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "static: resolving host %s", hostname_cstr); - - /* Android would prefer NO HINTS IF YOU DON'T MIND, SIR */ -#ifdef ANDROID - int err_code = getaddrinfo(hostname_cstr, NULL, NULL, &result); -#else - struct addrinfo hints; - AWS_ZERO_STRUCT(hints); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = AI_ALL | AI_V4MAPPED; - - int err_code = getaddrinfo(hostname_cstr, NULL, &hints, &result); -#endif - - if (err_code) { - AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: getaddrinfo failed with error_code %d", err_code); - goto clean_up; - } - - for (iter = result; iter != NULL; iter = iter->ai_next) { - struct aws_host_address host_address; - - AWS_ZERO_ARRAY(address_buffer); - - if (iter->ai_family == AF_INET6) { - host_address.record_type = AWS_ADDRESS_RECORD_TYPE_AAAA; - inet_ntop(iter->ai_family, &((struct sockaddr_in6 *)iter->ai_addr)->sin6_addr, address_buffer, max_len); - } else { - host_address.record_type = AWS_ADDRESS_RECORD_TYPE_A; - inet_ntop(iter->ai_family, &((struct sockaddr_in *)iter->ai_addr)->sin_addr, address_buffer, max_len); - } - - size_t address_len = strlen(address_buffer); - const struct aws_string *address = - aws_string_new_from_array(allocator, (const uint8_t *)address_buffer, address_len); - - if (!address) { - goto clean_up; - } - - const struct aws_string *host_cpy = aws_string_new_from_string(allocator, host_name); - - if (!host_cpy) { - aws_string_destroy((void *)address); - goto clean_up; - } - - AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "static: resolved record: %s", address_buffer); - - host_address.address = address; - host_address.weight = 0; - host_address.allocator = allocator; - host_address.use_count = 0; - host_address.connection_failure_count = 0; - host_address.host = host_cpy; - - if (aws_array_list_push_back(output_addresses, &host_address)) { - aws_host_address_clean_up(&host_address); - goto clean_up; - } - } - - freeaddrinfo(result); - return AWS_OP_SUCCESS; - -clean_up: - if (result) { - freeaddrinfo(result); - } - - if (err_code) { - switch (err_code) { - case EAI_FAIL: - case EAI_AGAIN: - return aws_raise_error(AWS_IO_DNS_QUERY_FAILED); - case EAI_MEMORY: - return aws_raise_error(AWS_ERROR_OOM); - case EAI_NONAME: - case EAI_SERVICE: - return aws_raise_error(AWS_IO_DNS_INVALID_NAME); - default: - return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); - } - } - - return AWS_OP_ERR; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/host_resolver.h> + +#include <aws/io/logging.h> + +#include <aws/common/string.h> + +#include <arpa/inet.h> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> + +int aws_default_dns_resolve( + struct aws_allocator *allocator, + const struct aws_string *host_name, + struct aws_array_list *output_addresses, + void *user_data) { + + (void)user_data; + struct addrinfo *result = NULL; + struct addrinfo *iter = NULL; + /* max string length for ipv6. */ + socklen_t max_len = INET6_ADDRSTRLEN; + char address_buffer[max_len]; + + const char *hostname_cstr = aws_string_c_str(host_name); + AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "static: resolving host %s", hostname_cstr); + + /* Android would prefer NO HINTS IF YOU DON'T MIND, SIR */ +#ifdef ANDROID + int err_code = getaddrinfo(hostname_cstr, NULL, NULL, &result); +#else + struct addrinfo hints; + AWS_ZERO_STRUCT(hints); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_ALL | AI_V4MAPPED; + + int err_code = getaddrinfo(hostname_cstr, NULL, &hints, &result); +#endif + + if (err_code) { + AWS_LOGF_ERROR(AWS_LS_IO_DNS, "static: getaddrinfo failed with error_code %d", err_code); + goto clean_up; + } + + for (iter = result; iter != NULL; iter = iter->ai_next) { + struct aws_host_address host_address; + + AWS_ZERO_ARRAY(address_buffer); + + if (iter->ai_family == AF_INET6) { + host_address.record_type = AWS_ADDRESS_RECORD_TYPE_AAAA; + inet_ntop(iter->ai_family, &((struct sockaddr_in6 *)iter->ai_addr)->sin6_addr, address_buffer, max_len); + } else { + host_address.record_type = AWS_ADDRESS_RECORD_TYPE_A; + inet_ntop(iter->ai_family, &((struct sockaddr_in *)iter->ai_addr)->sin_addr, address_buffer, max_len); + } + + size_t address_len = strlen(address_buffer); + const struct aws_string *address = + aws_string_new_from_array(allocator, (const uint8_t *)address_buffer, address_len); + + if (!address) { + goto clean_up; + } + + const struct aws_string *host_cpy = aws_string_new_from_string(allocator, host_name); + + if (!host_cpy) { + aws_string_destroy((void *)address); + goto clean_up; + } + + AWS_LOGF_DEBUG(AWS_LS_IO_DNS, "static: resolved record: %s", address_buffer); + + host_address.address = address; + host_address.weight = 0; + host_address.allocator = allocator; + host_address.use_count = 0; + host_address.connection_failure_count = 0; + host_address.host = host_cpy; + + if (aws_array_list_push_back(output_addresses, &host_address)) { + aws_host_address_clean_up(&host_address); + goto clean_up; + } + } + + freeaddrinfo(result); + return AWS_OP_SUCCESS; + +clean_up: + if (result) { + freeaddrinfo(result); + } + + if (err_code) { + switch (err_code) { + case EAI_FAIL: + case EAI_AGAIN: + return aws_raise_error(AWS_IO_DNS_QUERY_FAILED); + case EAI_MEMORY: + return aws_raise_error(AWS_ERROR_OOM); + case EAI_NONAME: + case EAI_SERVICE: + return aws_raise_error(AWS_IO_DNS_INVALID_NAME); + default: + return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); + } + } + + return AWS_OP_ERR; +} diff --git a/contrib/restricted/aws/aws-c-io/source/posix/pipe.c b/contrib/restricted/aws/aws-c-io/source/posix/pipe.c index 141cd05cbe..049a15d690 100644 --- a/contrib/restricted/aws/aws-c-io/source/posix/pipe.c +++ b/contrib/restricted/aws/aws-c-io/source/posix/pipe.c @@ -1,583 +1,583 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/pipe.h> - -#include <aws/io/event_loop.h> - -#ifdef __GLIBC__ -# define __USE_GNU -#endif - -/* TODO: move this detection to CMAKE and a config header */ -#if !defined(COMPAT_MODE) && defined(__GLIBC__) && __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 9 -# define HAVE_PIPE2 1 -#else -# define HAVE_PIPE2 0 -#endif - -#include <errno.h> -#include <fcntl.h> -#include <unistd.h> - -/* This isn't defined on ancient linux distros (breaking the builds). - * However, if this is a prebuild, we purposely build on an ancient system, but - * we want the kernel calls to still be the same as a modern build since that's likely the target of the application - * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag - * gets passed as long as it does. - */ -#ifndef O_CLOEXEC -# define O_CLOEXEC 02000000 -#endif - -struct read_end_impl { - struct aws_allocator *alloc; - struct aws_io_handle handle; - struct aws_event_loop *event_loop; - aws_pipe_on_readable_fn *on_readable_user_callback; - void *on_readable_user_data; - - /* Used in handshake for detecting whether user callback resulted in read-end being cleaned up. - * If clean_up() sees that the pointer is set, the bool it points to will get set true. */ - bool *did_user_callback_clean_up_read_end; - - bool is_subscribed; -}; - -struct write_request { - struct aws_byte_cursor original_cursor; - struct aws_byte_cursor cursor; /* tracks progress of write */ - size_t num_bytes_written; - aws_pipe_on_write_completed_fn *user_callback; - void *user_data; - struct aws_linked_list_node list_node; - - /* True if the write-end is cleaned up while the user callback is being invoked */ - bool did_user_callback_clean_up_write_end; -}; - -struct write_end_impl { - struct aws_allocator *alloc; - struct aws_io_handle handle; - struct aws_event_loop *event_loop; - struct aws_linked_list write_list; - - /* Valid while invoking user callback on a completed write request. */ - struct write_request *currently_invoking_write_callback; - - bool is_writable; - - /* Future optimization idea: avoid an allocation on each write by keeping 1 pre-allocated write_request around - * and re-using it whenever possible */ -}; - -static void s_write_end_on_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data); - -static int s_translate_posix_error(int err) { - AWS_ASSERT(err); - - switch (err) { - case EPIPE: - return AWS_IO_BROKEN_PIPE; - default: - return AWS_ERROR_SYS_CALL_FAILURE; - } -} - -static int s_raise_posix_error(int err) { - return aws_raise_error(s_translate_posix_error(err)); -} - -AWS_IO_API int aws_open_nonblocking_posix_pipe(int pipe_fds[2]) { - int err; - -#if HAVE_PIPE2 - err = pipe2(pipe_fds, O_NONBLOCK | O_CLOEXEC); - if (err) { - return s_raise_posix_error(err); - } - - return AWS_OP_SUCCESS; -#else - err = pipe(pipe_fds); - if (err) { - return s_raise_posix_error(err); - } - - for (int i = 0; i < 2; ++i) { - int flags = fcntl(pipe_fds[i], F_GETFL); - if (flags == -1) { - s_raise_posix_error(err); - goto error; - } - - flags |= O_NONBLOCK | O_CLOEXEC; - if (fcntl(pipe_fds[i], F_SETFL, flags) == -1) { - s_raise_posix_error(err); - goto error; - } - } - - return AWS_OP_SUCCESS; -error: - close(pipe_fds[0]); - close(pipe_fds[1]); - return AWS_OP_ERR; -#endif -} - -int aws_pipe_init( - struct aws_pipe_read_end *read_end, - struct aws_event_loop *read_end_event_loop, - struct aws_pipe_write_end *write_end, - struct aws_event_loop *write_end_event_loop, - struct aws_allocator *allocator) { - - AWS_ASSERT(read_end); - AWS_ASSERT(read_end_event_loop); - AWS_ASSERT(write_end); - AWS_ASSERT(write_end_event_loop); - AWS_ASSERT(allocator); - - AWS_ZERO_STRUCT(*read_end); - AWS_ZERO_STRUCT(*write_end); - - struct read_end_impl *read_impl = NULL; - struct write_end_impl *write_impl = NULL; - int err; - - /* Open pipe */ - int pipe_fds[2]; - err = aws_open_nonblocking_posix_pipe(pipe_fds); - if (err) { - return AWS_OP_ERR; - } - - /* Init read-end */ - read_impl = aws_mem_calloc(allocator, 1, sizeof(struct read_end_impl)); - if (!read_impl) { - goto error; - } - - read_impl->alloc = allocator; - read_impl->handle.data.fd = pipe_fds[0]; - read_impl->event_loop = read_end_event_loop; - - /* Init write-end */ - write_impl = aws_mem_calloc(allocator, 1, sizeof(struct write_end_impl)); - if (!write_impl) { - goto error; - } - - write_impl->alloc = allocator; - write_impl->handle.data.fd = pipe_fds[1]; - write_impl->event_loop = write_end_event_loop; - write_impl->is_writable = true; /* Assume pipe is writable to start. Even if it's not, things shouldn't break */ - aws_linked_list_init(&write_impl->write_list); - - read_end->impl_data = read_impl; - write_end->impl_data = write_impl; - - err = aws_event_loop_subscribe_to_io_events( - write_end_event_loop, &write_impl->handle, AWS_IO_EVENT_TYPE_WRITABLE, s_write_end_on_event, write_end); - if (err) { - goto error; - } - - return AWS_OP_SUCCESS; - -error: - close(pipe_fds[0]); - close(pipe_fds[1]); - - if (read_impl) { - aws_mem_release(allocator, read_impl); - } - - if (write_impl) { - aws_mem_release(allocator, write_impl); - } - - read_end->impl_data = NULL; - write_end->impl_data = NULL; - - return AWS_OP_ERR; -} - -int aws_pipe_clean_up_read_end(struct aws_pipe_read_end *read_end) { - struct read_end_impl *read_impl = read_end->impl_data; - if (!read_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - if (read_impl->is_subscribed) { - int err = aws_pipe_unsubscribe_from_readable_events(read_end); - if (err) { - return AWS_OP_ERR; - } - } - - /* If the event-handler is invoking a user callback, let it know that the read-end was cleaned up */ - if (read_impl->did_user_callback_clean_up_read_end) { - *read_impl->did_user_callback_clean_up_read_end = true; - } - - close(read_impl->handle.data.fd); - - aws_mem_release(read_impl->alloc, read_impl); - AWS_ZERO_STRUCT(*read_end); - return AWS_OP_SUCCESS; -} - -struct aws_event_loop *aws_pipe_get_read_end_event_loop(const struct aws_pipe_read_end *read_end) { - const struct read_end_impl *read_impl = read_end->impl_data; - if (!read_impl) { - aws_raise_error(AWS_IO_BROKEN_PIPE); - return NULL; - } - - return read_impl->event_loop; -} - -struct aws_event_loop *aws_pipe_get_write_end_event_loop(const struct aws_pipe_write_end *write_end) { - const struct write_end_impl *write_impl = write_end->impl_data; - if (!write_impl) { - aws_raise_error(AWS_IO_BROKEN_PIPE); - return NULL; - } - - return write_impl->event_loop; -} - -int aws_pipe_read(struct aws_pipe_read_end *read_end, struct aws_byte_buf *dst_buffer, size_t *num_bytes_read) { - AWS_ASSERT(dst_buffer && dst_buffer->buffer); - - struct read_end_impl *read_impl = read_end->impl_data; - if (!read_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (num_bytes_read) { - *num_bytes_read = 0; - } - - size_t num_bytes_to_read = dst_buffer->capacity - dst_buffer->len; - - ssize_t read_val = read(read_impl->handle.data.fd, dst_buffer->buffer + dst_buffer->len, num_bytes_to_read); - - if (read_val < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - return aws_raise_error(AWS_IO_READ_WOULD_BLOCK); - } - return s_raise_posix_error(errno); - } - - /* Success */ - dst_buffer->len += read_val; - - if (num_bytes_read) { - *num_bytes_read = read_val; - } - - return AWS_OP_SUCCESS; -} - -static void s_read_end_on_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - - (void)event_loop; - (void)handle; - - /* Note that it should be impossible for this to run after read-end has been unsubscribed or cleaned up */ - struct aws_pipe_read_end *read_end = user_data; - struct read_end_impl *read_impl = read_end->impl_data; - AWS_ASSERT(read_impl); - AWS_ASSERT(read_impl->event_loop == event_loop); - AWS_ASSERT(&read_impl->handle == handle); - AWS_ASSERT(read_impl->is_subscribed); - AWS_ASSERT(events != 0); - AWS_ASSERT(read_impl->did_user_callback_clean_up_read_end == NULL); - - /* Set up handshake, so we can be informed if the read-end is cleaned up while invoking a user callback */ - bool did_user_callback_clean_up_read_end = false; - read_impl->did_user_callback_clean_up_read_end = &did_user_callback_clean_up_read_end; - - /* If readable event received, tell user to try and read, even if "error" events have also occurred. */ - if (events & AWS_IO_EVENT_TYPE_READABLE) { - read_impl->on_readable_user_callback(read_end, AWS_ERROR_SUCCESS, read_impl->on_readable_user_data); - - if (did_user_callback_clean_up_read_end) { - return; - } - - events &= ~AWS_IO_EVENT_TYPE_READABLE; - } - - if (events) { - /* Check that user didn't unsubscribe in the previous callback */ - if (read_impl->is_subscribed) { - read_impl->on_readable_user_callback(read_end, AWS_IO_BROKEN_PIPE, read_impl->on_readable_user_data); - - if (did_user_callback_clean_up_read_end) { - return; - } - } - } - - read_impl->did_user_callback_clean_up_read_end = NULL; -} - -int aws_pipe_subscribe_to_readable_events( - struct aws_pipe_read_end *read_end, - aws_pipe_on_readable_fn *on_readable, - void *user_data) { - - AWS_ASSERT(on_readable); - - struct read_end_impl *read_impl = read_end->impl_data; - if (!read_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - if (read_impl->is_subscribed) { - return aws_raise_error(AWS_ERROR_IO_ALREADY_SUBSCRIBED); - } - - read_impl->is_subscribed = true; - read_impl->on_readable_user_callback = on_readable; - read_impl->on_readable_user_data = user_data; - - int err = aws_event_loop_subscribe_to_io_events( - read_impl->event_loop, &read_impl->handle, AWS_IO_EVENT_TYPE_READABLE, s_read_end_on_event, read_end); - if (err) { - read_impl->is_subscribed = false; - read_impl->on_readable_user_callback = NULL; - read_impl->on_readable_user_data = NULL; - - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_pipe_unsubscribe_from_readable_events(struct aws_pipe_read_end *read_end) { - struct read_end_impl *read_impl = read_end->impl_data; - if (!read_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - if (!read_impl->is_subscribed) { - return aws_raise_error(AWS_ERROR_IO_NOT_SUBSCRIBED); - } - - int err = aws_event_loop_unsubscribe_from_io_events(read_impl->event_loop, &read_impl->handle); - if (err) { - return AWS_OP_ERR; - } - - read_impl->is_subscribed = false; - read_impl->on_readable_user_callback = NULL; - read_impl->on_readable_user_data = NULL; - - return AWS_OP_SUCCESS; -} - -/* Pop front write request, invoke its callback, and delete it. - * Returns whether the callback resulted in the write-end getting cleaned up */ -static bool s_write_end_complete_front_write_request(struct aws_pipe_write_end *write_end, int error_code) { - struct write_end_impl *write_impl = write_end->impl_data; - - AWS_ASSERT(!aws_linked_list_empty(&write_impl->write_list)); - struct aws_linked_list_node *node = aws_linked_list_pop_front(&write_impl->write_list); - struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); - - struct aws_allocator *alloc = write_impl->alloc; - - /* Let the write-end know that a callback is in process, so the write-end can inform the callback - * whether it resulted in clean_up() being called. */ - bool write_end_cleaned_up_during_callback = false; - struct write_request *prev_invoking_request = write_impl->currently_invoking_write_callback; - write_impl->currently_invoking_write_callback = request; - - if (request->user_callback) { - request->user_callback(write_end, error_code, request->original_cursor, request->user_data); - write_end_cleaned_up_during_callback = request->did_user_callback_clean_up_write_end; - } - - if (!write_end_cleaned_up_during_callback) { - write_impl->currently_invoking_write_callback = prev_invoking_request; - } - - aws_mem_release(alloc, request); - - return write_end_cleaned_up_during_callback; -} - -/* Process write requests as long as the pipe remains writable */ -static void s_write_end_process_requests(struct aws_pipe_write_end *write_end) { - struct write_end_impl *write_impl = write_end->impl_data; - AWS_ASSERT(write_impl); - - while (!aws_linked_list_empty(&write_impl->write_list)) { - struct aws_linked_list_node *node = aws_linked_list_front(&write_impl->write_list); - struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); - - int completed_error_code = AWS_ERROR_SUCCESS; - - if (request->cursor.len > 0) { - ssize_t write_val = write(write_impl->handle.data.fd, request->cursor.ptr, request->cursor.len); - - if (write_val < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - /* The pipe is no longer writable. Bail out */ - write_impl->is_writable = false; - return; - } - - /* A non-recoverable error occurred during this write */ - completed_error_code = s_translate_posix_error(errno); - - } else { - aws_byte_cursor_advance(&request->cursor, write_val); - - if (request->cursor.len > 0) { - /* There was a partial write, loop again to try and write the rest. */ - continue; - } - } - } - - /* If we got this far in the loop, then the write request is complete. - * Note that the callback may result in the pipe being cleaned up. */ - bool write_end_cleaned_up = s_write_end_complete_front_write_request(write_end, completed_error_code); - if (write_end_cleaned_up) { - /* Bail out! Any remaining requests were canceled during clean_up() */ - return; - } - } -} - -/* Handle events on the write-end's file handle */ -static void s_write_end_on_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - - (void)event_loop; - (void)handle; - - /* Note that it should be impossible for this to run after write-end has been unsubscribed or cleaned up */ - struct aws_pipe_write_end *write_end = user_data; - struct write_end_impl *write_impl = write_end->impl_data; - AWS_ASSERT(write_impl); - AWS_ASSERT(write_impl->event_loop == event_loop); - AWS_ASSERT(&write_impl->handle == handle); - - /* Only care about the writable event. */ - if ((events & AWS_IO_EVENT_TYPE_WRITABLE) == 0) { - return; - } - - write_impl->is_writable = true; - - s_write_end_process_requests(write_end); -} - -int aws_pipe_write( - struct aws_pipe_write_end *write_end, - struct aws_byte_cursor src_buffer, - aws_pipe_on_write_completed_fn *on_completed, - void *user_data) { - - AWS_ASSERT(src_buffer.ptr); - - struct write_end_impl *write_impl = write_end->impl_data; - if (!write_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (!aws_event_loop_thread_is_callers_thread(write_impl->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - struct write_request *request = aws_mem_calloc(write_impl->alloc, 1, sizeof(struct write_request)); - if (!request) { - return AWS_OP_ERR; - } - - request->original_cursor = src_buffer; - request->cursor = src_buffer; - request->user_callback = on_completed; - request->user_data = user_data; - - aws_linked_list_push_back(&write_impl->write_list, &request->list_node); - - /* If the pipe is writable, process the request (unless pipe is already in the middle of processing, which could - * happen if a this aws_pipe_write() call was made by another write's completion callback */ - if (write_impl->is_writable && !write_impl->currently_invoking_write_callback) { - s_write_end_process_requests(write_end); - } - - return AWS_OP_SUCCESS; -} - -int aws_pipe_clean_up_write_end(struct aws_pipe_write_end *write_end) { - struct write_end_impl *write_impl = write_end->impl_data; - if (!write_impl) { - return aws_raise_error(AWS_IO_BROKEN_PIPE); - } - - if (!aws_event_loop_thread_is_callers_thread(write_impl->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - int err = aws_event_loop_unsubscribe_from_io_events(write_impl->event_loop, &write_impl->handle); - if (err) { - return AWS_OP_ERR; - } - - close(write_impl->handle.data.fd); - - /* Zero out write-end before invoking user callbacks so that it won't work anymore with public functions. */ - AWS_ZERO_STRUCT(*write_end); - - /* If a request callback is currently being invoked, let it know that the write-end was cleaned up */ - if (write_impl->currently_invoking_write_callback) { - write_impl->currently_invoking_write_callback->did_user_callback_clean_up_write_end = true; - } - - /* Force any outstanding write requests to complete with an error status. */ - while (!aws_linked_list_empty(&write_impl->write_list)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&write_impl->write_list); - struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); - if (request->user_callback) { - request->user_callback(NULL, AWS_IO_BROKEN_PIPE, request->original_cursor, request->user_data); - } - aws_mem_release(write_impl->alloc, request); - } - - aws_mem_release(write_impl->alloc, write_impl); - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/pipe.h> + +#include <aws/io/event_loop.h> + +#ifdef __GLIBC__ +# define __USE_GNU +#endif + +/* TODO: move this detection to CMAKE and a config header */ +#if !defined(COMPAT_MODE) && defined(__GLIBC__) && __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 9 +# define HAVE_PIPE2 1 +#else +# define HAVE_PIPE2 0 +#endif + +#include <errno.h> +#include <fcntl.h> +#include <unistd.h> + +/* This isn't defined on ancient linux distros (breaking the builds). + * However, if this is a prebuild, we purposely build on an ancient system, but + * we want the kernel calls to still be the same as a modern build since that's likely the target of the application + * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag + * gets passed as long as it does. + */ +#ifndef O_CLOEXEC +# define O_CLOEXEC 02000000 +#endif + +struct read_end_impl { + struct aws_allocator *alloc; + struct aws_io_handle handle; + struct aws_event_loop *event_loop; + aws_pipe_on_readable_fn *on_readable_user_callback; + void *on_readable_user_data; + + /* Used in handshake for detecting whether user callback resulted in read-end being cleaned up. + * If clean_up() sees that the pointer is set, the bool it points to will get set true. */ + bool *did_user_callback_clean_up_read_end; + + bool is_subscribed; +}; + +struct write_request { + struct aws_byte_cursor original_cursor; + struct aws_byte_cursor cursor; /* tracks progress of write */ + size_t num_bytes_written; + aws_pipe_on_write_completed_fn *user_callback; + void *user_data; + struct aws_linked_list_node list_node; + + /* True if the write-end is cleaned up while the user callback is being invoked */ + bool did_user_callback_clean_up_write_end; +}; + +struct write_end_impl { + struct aws_allocator *alloc; + struct aws_io_handle handle; + struct aws_event_loop *event_loop; + struct aws_linked_list write_list; + + /* Valid while invoking user callback on a completed write request. */ + struct write_request *currently_invoking_write_callback; + + bool is_writable; + + /* Future optimization idea: avoid an allocation on each write by keeping 1 pre-allocated write_request around + * and re-using it whenever possible */ +}; + +static void s_write_end_on_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data); + +static int s_translate_posix_error(int err) { + AWS_ASSERT(err); + + switch (err) { + case EPIPE: + return AWS_IO_BROKEN_PIPE; + default: + return AWS_ERROR_SYS_CALL_FAILURE; + } +} + +static int s_raise_posix_error(int err) { + return aws_raise_error(s_translate_posix_error(err)); +} + +AWS_IO_API int aws_open_nonblocking_posix_pipe(int pipe_fds[2]) { + int err; + +#if HAVE_PIPE2 + err = pipe2(pipe_fds, O_NONBLOCK | O_CLOEXEC); + if (err) { + return s_raise_posix_error(err); + } + + return AWS_OP_SUCCESS; +#else + err = pipe(pipe_fds); + if (err) { + return s_raise_posix_error(err); + } + + for (int i = 0; i < 2; ++i) { + int flags = fcntl(pipe_fds[i], F_GETFL); + if (flags == -1) { + s_raise_posix_error(err); + goto error; + } + + flags |= O_NONBLOCK | O_CLOEXEC; + if (fcntl(pipe_fds[i], F_SETFL, flags) == -1) { + s_raise_posix_error(err); + goto error; + } + } + + return AWS_OP_SUCCESS; +error: + close(pipe_fds[0]); + close(pipe_fds[1]); + return AWS_OP_ERR; +#endif +} + +int aws_pipe_init( + struct aws_pipe_read_end *read_end, + struct aws_event_loop *read_end_event_loop, + struct aws_pipe_write_end *write_end, + struct aws_event_loop *write_end_event_loop, + struct aws_allocator *allocator) { + + AWS_ASSERT(read_end); + AWS_ASSERT(read_end_event_loop); + AWS_ASSERT(write_end); + AWS_ASSERT(write_end_event_loop); + AWS_ASSERT(allocator); + + AWS_ZERO_STRUCT(*read_end); + AWS_ZERO_STRUCT(*write_end); + + struct read_end_impl *read_impl = NULL; + struct write_end_impl *write_impl = NULL; + int err; + + /* Open pipe */ + int pipe_fds[2]; + err = aws_open_nonblocking_posix_pipe(pipe_fds); + if (err) { + return AWS_OP_ERR; + } + + /* Init read-end */ + read_impl = aws_mem_calloc(allocator, 1, sizeof(struct read_end_impl)); + if (!read_impl) { + goto error; + } + + read_impl->alloc = allocator; + read_impl->handle.data.fd = pipe_fds[0]; + read_impl->event_loop = read_end_event_loop; + + /* Init write-end */ + write_impl = aws_mem_calloc(allocator, 1, sizeof(struct write_end_impl)); + if (!write_impl) { + goto error; + } + + write_impl->alloc = allocator; + write_impl->handle.data.fd = pipe_fds[1]; + write_impl->event_loop = write_end_event_loop; + write_impl->is_writable = true; /* Assume pipe is writable to start. Even if it's not, things shouldn't break */ + aws_linked_list_init(&write_impl->write_list); + + read_end->impl_data = read_impl; + write_end->impl_data = write_impl; + + err = aws_event_loop_subscribe_to_io_events( + write_end_event_loop, &write_impl->handle, AWS_IO_EVENT_TYPE_WRITABLE, s_write_end_on_event, write_end); + if (err) { + goto error; + } + + return AWS_OP_SUCCESS; + +error: + close(pipe_fds[0]); + close(pipe_fds[1]); + + if (read_impl) { + aws_mem_release(allocator, read_impl); + } + + if (write_impl) { + aws_mem_release(allocator, write_impl); + } + + read_end->impl_data = NULL; + write_end->impl_data = NULL; + + return AWS_OP_ERR; +} + +int aws_pipe_clean_up_read_end(struct aws_pipe_read_end *read_end) { + struct read_end_impl *read_impl = read_end->impl_data; + if (!read_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + if (read_impl->is_subscribed) { + int err = aws_pipe_unsubscribe_from_readable_events(read_end); + if (err) { + return AWS_OP_ERR; + } + } + + /* If the event-handler is invoking a user callback, let it know that the read-end was cleaned up */ + if (read_impl->did_user_callback_clean_up_read_end) { + *read_impl->did_user_callback_clean_up_read_end = true; + } + + close(read_impl->handle.data.fd); + + aws_mem_release(read_impl->alloc, read_impl); + AWS_ZERO_STRUCT(*read_end); + return AWS_OP_SUCCESS; +} + +struct aws_event_loop *aws_pipe_get_read_end_event_loop(const struct aws_pipe_read_end *read_end) { + const struct read_end_impl *read_impl = read_end->impl_data; + if (!read_impl) { + aws_raise_error(AWS_IO_BROKEN_PIPE); + return NULL; + } + + return read_impl->event_loop; +} + +struct aws_event_loop *aws_pipe_get_write_end_event_loop(const struct aws_pipe_write_end *write_end) { + const struct write_end_impl *write_impl = write_end->impl_data; + if (!write_impl) { + aws_raise_error(AWS_IO_BROKEN_PIPE); + return NULL; + } + + return write_impl->event_loop; +} + +int aws_pipe_read(struct aws_pipe_read_end *read_end, struct aws_byte_buf *dst_buffer, size_t *num_bytes_read) { + AWS_ASSERT(dst_buffer && dst_buffer->buffer); + + struct read_end_impl *read_impl = read_end->impl_data; + if (!read_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (num_bytes_read) { + *num_bytes_read = 0; + } + + size_t num_bytes_to_read = dst_buffer->capacity - dst_buffer->len; + + ssize_t read_val = read(read_impl->handle.data.fd, dst_buffer->buffer + dst_buffer->len, num_bytes_to_read); + + if (read_val < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return aws_raise_error(AWS_IO_READ_WOULD_BLOCK); + } + return s_raise_posix_error(errno); + } + + /* Success */ + dst_buffer->len += read_val; + + if (num_bytes_read) { + *num_bytes_read = read_val; + } + + return AWS_OP_SUCCESS; +} + +static void s_read_end_on_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + + (void)event_loop; + (void)handle; + + /* Note that it should be impossible for this to run after read-end has been unsubscribed or cleaned up */ + struct aws_pipe_read_end *read_end = user_data; + struct read_end_impl *read_impl = read_end->impl_data; + AWS_ASSERT(read_impl); + AWS_ASSERT(read_impl->event_loop == event_loop); + AWS_ASSERT(&read_impl->handle == handle); + AWS_ASSERT(read_impl->is_subscribed); + AWS_ASSERT(events != 0); + AWS_ASSERT(read_impl->did_user_callback_clean_up_read_end == NULL); + + /* Set up handshake, so we can be informed if the read-end is cleaned up while invoking a user callback */ + bool did_user_callback_clean_up_read_end = false; + read_impl->did_user_callback_clean_up_read_end = &did_user_callback_clean_up_read_end; + + /* If readable event received, tell user to try and read, even if "error" events have also occurred. */ + if (events & AWS_IO_EVENT_TYPE_READABLE) { + read_impl->on_readable_user_callback(read_end, AWS_ERROR_SUCCESS, read_impl->on_readable_user_data); + + if (did_user_callback_clean_up_read_end) { + return; + } + + events &= ~AWS_IO_EVENT_TYPE_READABLE; + } + + if (events) { + /* Check that user didn't unsubscribe in the previous callback */ + if (read_impl->is_subscribed) { + read_impl->on_readable_user_callback(read_end, AWS_IO_BROKEN_PIPE, read_impl->on_readable_user_data); + + if (did_user_callback_clean_up_read_end) { + return; + } + } + } + + read_impl->did_user_callback_clean_up_read_end = NULL; +} + +int aws_pipe_subscribe_to_readable_events( + struct aws_pipe_read_end *read_end, + aws_pipe_on_readable_fn *on_readable, + void *user_data) { + + AWS_ASSERT(on_readable); + + struct read_end_impl *read_impl = read_end->impl_data; + if (!read_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + if (read_impl->is_subscribed) { + return aws_raise_error(AWS_ERROR_IO_ALREADY_SUBSCRIBED); + } + + read_impl->is_subscribed = true; + read_impl->on_readable_user_callback = on_readable; + read_impl->on_readable_user_data = user_data; + + int err = aws_event_loop_subscribe_to_io_events( + read_impl->event_loop, &read_impl->handle, AWS_IO_EVENT_TYPE_READABLE, s_read_end_on_event, read_end); + if (err) { + read_impl->is_subscribed = false; + read_impl->on_readable_user_callback = NULL; + read_impl->on_readable_user_data = NULL; + + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_pipe_unsubscribe_from_readable_events(struct aws_pipe_read_end *read_end) { + struct read_end_impl *read_impl = read_end->impl_data; + if (!read_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (!aws_event_loop_thread_is_callers_thread(read_impl->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + if (!read_impl->is_subscribed) { + return aws_raise_error(AWS_ERROR_IO_NOT_SUBSCRIBED); + } + + int err = aws_event_loop_unsubscribe_from_io_events(read_impl->event_loop, &read_impl->handle); + if (err) { + return AWS_OP_ERR; + } + + read_impl->is_subscribed = false; + read_impl->on_readable_user_callback = NULL; + read_impl->on_readable_user_data = NULL; + + return AWS_OP_SUCCESS; +} + +/* Pop front write request, invoke its callback, and delete it. + * Returns whether the callback resulted in the write-end getting cleaned up */ +static bool s_write_end_complete_front_write_request(struct aws_pipe_write_end *write_end, int error_code) { + struct write_end_impl *write_impl = write_end->impl_data; + + AWS_ASSERT(!aws_linked_list_empty(&write_impl->write_list)); + struct aws_linked_list_node *node = aws_linked_list_pop_front(&write_impl->write_list); + struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); + + struct aws_allocator *alloc = write_impl->alloc; + + /* Let the write-end know that a callback is in process, so the write-end can inform the callback + * whether it resulted in clean_up() being called. */ + bool write_end_cleaned_up_during_callback = false; + struct write_request *prev_invoking_request = write_impl->currently_invoking_write_callback; + write_impl->currently_invoking_write_callback = request; + + if (request->user_callback) { + request->user_callback(write_end, error_code, request->original_cursor, request->user_data); + write_end_cleaned_up_during_callback = request->did_user_callback_clean_up_write_end; + } + + if (!write_end_cleaned_up_during_callback) { + write_impl->currently_invoking_write_callback = prev_invoking_request; + } + + aws_mem_release(alloc, request); + + return write_end_cleaned_up_during_callback; +} + +/* Process write requests as long as the pipe remains writable */ +static void s_write_end_process_requests(struct aws_pipe_write_end *write_end) { + struct write_end_impl *write_impl = write_end->impl_data; + AWS_ASSERT(write_impl); + + while (!aws_linked_list_empty(&write_impl->write_list)) { + struct aws_linked_list_node *node = aws_linked_list_front(&write_impl->write_list); + struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); + + int completed_error_code = AWS_ERROR_SUCCESS; + + if (request->cursor.len > 0) { + ssize_t write_val = write(write_impl->handle.data.fd, request->cursor.ptr, request->cursor.len); + + if (write_val < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + /* The pipe is no longer writable. Bail out */ + write_impl->is_writable = false; + return; + } + + /* A non-recoverable error occurred during this write */ + completed_error_code = s_translate_posix_error(errno); + + } else { + aws_byte_cursor_advance(&request->cursor, write_val); + + if (request->cursor.len > 0) { + /* There was a partial write, loop again to try and write the rest. */ + continue; + } + } + } + + /* If we got this far in the loop, then the write request is complete. + * Note that the callback may result in the pipe being cleaned up. */ + bool write_end_cleaned_up = s_write_end_complete_front_write_request(write_end, completed_error_code); + if (write_end_cleaned_up) { + /* Bail out! Any remaining requests were canceled during clean_up() */ + return; + } + } +} + +/* Handle events on the write-end's file handle */ +static void s_write_end_on_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + + (void)event_loop; + (void)handle; + + /* Note that it should be impossible for this to run after write-end has been unsubscribed or cleaned up */ + struct aws_pipe_write_end *write_end = user_data; + struct write_end_impl *write_impl = write_end->impl_data; + AWS_ASSERT(write_impl); + AWS_ASSERT(write_impl->event_loop == event_loop); + AWS_ASSERT(&write_impl->handle == handle); + + /* Only care about the writable event. */ + if ((events & AWS_IO_EVENT_TYPE_WRITABLE) == 0) { + return; + } + + write_impl->is_writable = true; + + s_write_end_process_requests(write_end); +} + +int aws_pipe_write( + struct aws_pipe_write_end *write_end, + struct aws_byte_cursor src_buffer, + aws_pipe_on_write_completed_fn *on_completed, + void *user_data) { + + AWS_ASSERT(src_buffer.ptr); + + struct write_end_impl *write_impl = write_end->impl_data; + if (!write_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (!aws_event_loop_thread_is_callers_thread(write_impl->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + struct write_request *request = aws_mem_calloc(write_impl->alloc, 1, sizeof(struct write_request)); + if (!request) { + return AWS_OP_ERR; + } + + request->original_cursor = src_buffer; + request->cursor = src_buffer; + request->user_callback = on_completed; + request->user_data = user_data; + + aws_linked_list_push_back(&write_impl->write_list, &request->list_node); + + /* If the pipe is writable, process the request (unless pipe is already in the middle of processing, which could + * happen if a this aws_pipe_write() call was made by another write's completion callback */ + if (write_impl->is_writable && !write_impl->currently_invoking_write_callback) { + s_write_end_process_requests(write_end); + } + + return AWS_OP_SUCCESS; +} + +int aws_pipe_clean_up_write_end(struct aws_pipe_write_end *write_end) { + struct write_end_impl *write_impl = write_end->impl_data; + if (!write_impl) { + return aws_raise_error(AWS_IO_BROKEN_PIPE); + } + + if (!aws_event_loop_thread_is_callers_thread(write_impl->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + int err = aws_event_loop_unsubscribe_from_io_events(write_impl->event_loop, &write_impl->handle); + if (err) { + return AWS_OP_ERR; + } + + close(write_impl->handle.data.fd); + + /* Zero out write-end before invoking user callbacks so that it won't work anymore with public functions. */ + AWS_ZERO_STRUCT(*write_end); + + /* If a request callback is currently being invoked, let it know that the write-end was cleaned up */ + if (write_impl->currently_invoking_write_callback) { + write_impl->currently_invoking_write_callback->did_user_callback_clean_up_write_end = true; + } + + /* Force any outstanding write requests to complete with an error status. */ + while (!aws_linked_list_empty(&write_impl->write_list)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&write_impl->write_list); + struct write_request *request = AWS_CONTAINER_OF(node, struct write_request, list_node); + if (request->user_callback) { + request->user_callback(NULL, AWS_IO_BROKEN_PIPE, request->original_cursor, request->user_data); + } + aws_mem_release(write_impl->alloc, request); + } + + aws_mem_release(write_impl->alloc, write_impl); + return AWS_OP_SUCCESS; +} diff --git a/contrib/restricted/aws/aws-c-io/source/posix/shared_library.c b/contrib/restricted/aws/aws-c-io/source/posix/shared_library.c index 751c99bc23..6261ea9ea8 100644 --- a/contrib/restricted/aws/aws-c-io/source/posix/shared_library.c +++ b/contrib/restricted/aws/aws-c-io/source/posix/shared_library.c @@ -1,66 +1,66 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/shared_library.h> - -#include <aws/io/logging.h> - -#include <dlfcn.h> - -static const char *s_null = "<NULL>"; -static const char *s_unknown_error = "<Unknown>"; - -int aws_shared_library_init(struct aws_shared_library *library, const char *library_path) { - AWS_ZERO_STRUCT(*library); - - library->library_handle = dlopen(library_path, RTLD_LAZY); - if (library->library_handle == NULL) { - const char *error = dlerror(); - AWS_LOGF_ERROR( - AWS_LS_IO_SHARED_LIBRARY, - "id=%p: Failed to load shared library at path \"%s\" with error: %s", - (void *)library, - library_path ? library_path : s_null, - error ? error : s_unknown_error); - return aws_raise_error(AWS_IO_SHARED_LIBRARY_LOAD_FAILURE); - } - - return AWS_OP_SUCCESS; -} - -void aws_shared_library_clean_up(struct aws_shared_library *library) { - if (library && library->library_handle) { - dlclose(library->library_handle); - library->library_handle = NULL; - } -} - -int aws_shared_library_find_function( - struct aws_shared_library *library, - const char *symbol_name, - aws_generic_function *function_address) { - if (library == NULL || library->library_handle == NULL) { - return aws_raise_error(AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE); - } - - /* - * Suggested work around for (undefined behavior) cast from void * to function pointer - * in POSIX.1-2003 standard, at least according to dlsym man page code sample. - */ - *(void **)(function_address) = dlsym(library->library_handle, symbol_name); - - if (*function_address == NULL) { - const char *error = dlerror(); - AWS_LOGF_ERROR( - AWS_LS_IO_SHARED_LIBRARY, - "id=%p: Failed to find shared library symbol \"%s\" with error: %s", - (void *)library, - symbol_name ? symbol_name : s_null, - error ? error : s_unknown_error); - return aws_raise_error(AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE); - } - - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/shared_library.h> + +#include <aws/io/logging.h> + +#include <dlfcn.h> + +static const char *s_null = "<NULL>"; +static const char *s_unknown_error = "<Unknown>"; + +int aws_shared_library_init(struct aws_shared_library *library, const char *library_path) { + AWS_ZERO_STRUCT(*library); + + library->library_handle = dlopen(library_path, RTLD_LAZY); + if (library->library_handle == NULL) { + const char *error = dlerror(); + AWS_LOGF_ERROR( + AWS_LS_IO_SHARED_LIBRARY, + "id=%p: Failed to load shared library at path \"%s\" with error: %s", + (void *)library, + library_path ? library_path : s_null, + error ? error : s_unknown_error); + return aws_raise_error(AWS_IO_SHARED_LIBRARY_LOAD_FAILURE); + } + + return AWS_OP_SUCCESS; +} + +void aws_shared_library_clean_up(struct aws_shared_library *library) { + if (library && library->library_handle) { + dlclose(library->library_handle); + library->library_handle = NULL; + } +} + +int aws_shared_library_find_function( + struct aws_shared_library *library, + const char *symbol_name, + aws_generic_function *function_address) { + if (library == NULL || library->library_handle == NULL) { + return aws_raise_error(AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE); + } + + /* + * Suggested work around for (undefined behavior) cast from void * to function pointer + * in POSIX.1-2003 standard, at least according to dlsym man page code sample. + */ + *(void **)(function_address) = dlsym(library->library_handle, symbol_name); + + if (*function_address == NULL) { + const char *error = dlerror(); + AWS_LOGF_ERROR( + AWS_LS_IO_SHARED_LIBRARY, + "id=%p: Failed to find shared library symbol \"%s\" with error: %s", + (void *)library, + symbol_name ? symbol_name : s_null, + error ? error : s_unknown_error); + return aws_raise_error(AWS_IO_SHARED_LIBRARY_FIND_SYMBOL_FAILURE); + } + + return AWS_OP_SUCCESS; +} diff --git a/contrib/restricted/aws/aws-c-io/source/posix/socket.c b/contrib/restricted/aws/aws-c-io/source/posix/socket.c index 5f11cdff52..7ac30b39c2 100644 --- a/contrib/restricted/aws/aws-c-io/source/posix/socket.c +++ b/contrib/restricted/aws/aws-c-io/source/posix/socket.c @@ -1,1777 +1,1777 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/socket.h> - -#include <aws/common/clock.h> -#include <aws/common/condition_variable.h> -#include <aws/common/mutex.h> -#include <aws/common/string.h> - -#include <aws/io/event_loop.h> -#include <aws/io/logging.h> - -#include <arpa/inet.h> -#include <aws/io/io.h> -#include <errno.h> -#include <fcntl.h> -#include <netinet/tcp.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#if defined(__MACH__) -# define NO_SIGNAL SO_NOSIGPIPE -# define TCP_KEEPIDLE TCP_KEEPALIVE -#else -# define NO_SIGNAL MSG_NOSIGNAL -#endif - -/* This isn't defined on ancient linux distros (breaking the builds). - * However, if this is a prebuild, we purposely build on an ancient system, but - * we want the kernel calls to still be the same as a modern build since that's likely the target of the application - * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag - * gets passed as long as it does. - */ -#ifndef O_CLOEXEC -# define O_CLOEXEC 02000000 -#endif - -#ifdef USE_VSOCK -# if defined(__linux__) && defined(AF_VSOCK) -# include <linux/vm_sockets.h> -# else -# error "USE_VSOCK not supported on current platform" -# endif -#endif - -/* other than CONNECTED_READ | CONNECTED_WRITE - * a socket is only in one of these states at a time. */ -enum socket_state { - INIT = 0x01, - CONNECTING = 0x02, - CONNECTED_READ = 0x04, - CONNECTED_WRITE = 0x08, - BOUND = 0x10, - LISTENING = 0x20, - TIMEDOUT = 0x40, - ERROR = 0x80, - CLOSED, -}; - -static int s_convert_domain(enum aws_socket_domain domain) { - switch (domain) { - case AWS_SOCKET_IPV4: - return AF_INET; - case AWS_SOCKET_IPV6: - return AF_INET6; - case AWS_SOCKET_LOCAL: - return AF_UNIX; -#ifdef USE_VSOCK - case AWS_SOCKET_VSOCK: - return AF_VSOCK; -#endif - default: - AWS_ASSERT(0); - return AF_INET; - } -} - -static int s_convert_type(enum aws_socket_type type) { - switch (type) { - case AWS_SOCKET_STREAM: - return SOCK_STREAM; - case AWS_SOCKET_DGRAM: - return SOCK_DGRAM; - default: - AWS_ASSERT(0); - return SOCK_STREAM; - } -} - -static int s_determine_socket_error(int error) { - switch (error) { - case ECONNREFUSED: - return AWS_IO_SOCKET_CONNECTION_REFUSED; - case ETIMEDOUT: - return AWS_IO_SOCKET_TIMEOUT; - case EHOSTUNREACH: - case ENETUNREACH: - return AWS_IO_SOCKET_NO_ROUTE_TO_HOST; - case EADDRNOTAVAIL: - return AWS_IO_SOCKET_INVALID_ADDRESS; - case ENETDOWN: - return AWS_IO_SOCKET_NETWORK_DOWN; - case ECONNABORTED: - return AWS_IO_SOCKET_CONNECT_ABORTED; - case EADDRINUSE: - return AWS_IO_SOCKET_ADDRESS_IN_USE; - case ENOBUFS: - case ENOMEM: - return AWS_ERROR_OOM; - case EAGAIN: - return AWS_IO_READ_WOULD_BLOCK; - case EMFILE: - case ENFILE: - return AWS_ERROR_MAX_FDS_EXCEEDED; - case ENOENT: - case EINVAL: - return AWS_ERROR_FILE_INVALID_PATH; - case EAFNOSUPPORT: - return AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY; - case EACCES: - return AWS_ERROR_NO_PERMISSION; - default: - return AWS_IO_SOCKET_NOT_CONNECTED; - } -} - -static int s_create_socket(struct aws_socket *sock, const struct aws_socket_options *options) { - - int fd = socket(s_convert_domain(options->domain), s_convert_type(options->type), 0); - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: initializing with domain %d and type %d", - (void *)sock, - fd, - options->domain, - options->type); - if (fd != -1) { - int flags = fcntl(fd, F_GETFL, 0); - flags |= O_NONBLOCK | O_CLOEXEC; - int success = fcntl(fd, F_SETFL, flags); - (void)success; - sock->io_handle.data.fd = fd; - sock->io_handle.additional_data = NULL; - return aws_socket_set_options(sock, options); - } - - int aws_error = s_determine_socket_error(errno); - return aws_raise_error(aws_error); -} - -struct posix_socket_connect_args { - struct aws_task task; - struct aws_allocator *allocator; - struct aws_socket *socket; -}; - -struct posix_socket { - struct aws_linked_list write_queue; - struct posix_socket_connect_args *connect_args; - bool write_in_progress; - bool currently_subscribed; - bool continue_accept; - bool currently_in_event; - bool clean_yourself_up; - bool *close_happened; -}; - -static int s_socket_init( - struct aws_socket *socket, - struct aws_allocator *alloc, - const struct aws_socket_options *options, - int existing_socket_fd) { - AWS_ASSERT(options); - AWS_ZERO_STRUCT(*socket); - - struct posix_socket *posix_socket = aws_mem_calloc(alloc, 1, sizeof(struct posix_socket)); - if (!posix_socket) { - socket->impl = NULL; - return AWS_OP_ERR; - } - - socket->allocator = alloc; - socket->io_handle.data.fd = -1; - socket->state = INIT; - socket->options = *options; - - if (existing_socket_fd < 0) { - int err = s_create_socket(socket, options); - if (err) { - aws_mem_release(alloc, posix_socket); - socket->impl = NULL; - return AWS_OP_ERR; - } - } else { - socket->io_handle = (struct aws_io_handle){ - .data = {.fd = existing_socket_fd}, - .additional_data = NULL, - }; - aws_socket_set_options(socket, options); - } - - aws_linked_list_init(&posix_socket->write_queue); - posix_socket->write_in_progress = false; - posix_socket->currently_subscribed = false; - posix_socket->continue_accept = false; - posix_socket->currently_in_event = false; - posix_socket->clean_yourself_up = false; - posix_socket->connect_args = NULL; - posix_socket->close_happened = NULL; - socket->impl = posix_socket; - return AWS_OP_SUCCESS; -} - -int aws_socket_init(struct aws_socket *socket, struct aws_allocator *alloc, const struct aws_socket_options *options) { - AWS_ASSERT(options); - return s_socket_init(socket, alloc, options, -1); -} - -void aws_socket_clean_up(struct aws_socket *socket) { - if (!socket->impl) { - /* protect from double clean */ - return; - } - if (aws_socket_is_open(socket)) { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, "id=%p fd=%d: is still open, closing...", (void *)socket, socket->io_handle.data.fd); - aws_socket_close(socket); - } - struct posix_socket *socket_impl = socket->impl; - - if (!socket_impl->currently_in_event) { - aws_mem_release(socket->allocator, socket->impl); - } else { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: is still pending io letting it dangle and cleaning up later.", - (void *)socket, - socket->io_handle.data.fd); - socket_impl->clean_yourself_up = true; - } - - AWS_ZERO_STRUCT(*socket); - socket->io_handle.data.fd = -1; -} - -static void s_on_connection_error(struct aws_socket *socket, int error); - -static int s_on_connection_success(struct aws_socket *socket) { - - struct aws_event_loop *event_loop = socket->event_loop; - struct posix_socket *socket_impl = socket->impl; - - if (socket_impl->currently_subscribed) { - aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); - socket_impl->currently_subscribed = false; - } - - socket->event_loop = NULL; - - int connect_result; - socklen_t result_length = sizeof(connect_result); - - if (getsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_ERROR, &connect_result, &result_length) < 0) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to determine connection error %d", - (void *)socket, - socket->io_handle.data.fd, - errno); - int aws_error = s_determine_socket_error(errno); - aws_raise_error(aws_error); - s_on_connection_error(socket, aws_error); - return AWS_OP_ERR; - } - - if (connect_result) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connection error %d", - (void *)socket, - socket->io_handle.data.fd, - connect_result); - int aws_error = s_determine_socket_error(connect_result); - aws_raise_error(aws_error); - s_on_connection_error(socket, aws_error); - return AWS_OP_ERR; - } - - AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection success", (void *)socket, socket->io_handle.data.fd); - - struct sockaddr_storage address; - AWS_ZERO_STRUCT(address); - socklen_t address_size = sizeof(address); - if (!getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size)) { - uint16_t port = 0; - - if (address.ss_family == AF_INET) { - struct sockaddr_in *s = (struct sockaddr_in *)&address; - port = ntohs(s->sin_port); - /* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal - * once we add logging, we can log this if it fails. */ - if (inet_ntop( - AF_INET, &s->sin_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: local endpoint %s:%d", - (void *)socket, - socket->io_handle.data.fd, - socket->local_endpoint.address, - port); - } else { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: determining local endpoint failed", - (void *)socket, - socket->io_handle.data.fd); - } - } else if (address.ss_family == AF_INET6) { - struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address; - port = ntohs(s->sin6_port); - /* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal - * once we add logging, we can log this if it fails. */ - if (inet_ntop( - AF_INET6, &s->sin6_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd %d: local endpoint %s:%d", - (void *)socket, - socket->io_handle.data.fd, - socket->local_endpoint.address, - port); - } else { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: determining local endpoint failed", - (void *)socket, - socket->io_handle.data.fd); - } - } - - socket->local_endpoint.port = port; - } else { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: getsockname() failed with error %d", - (void *)socket, - socket->io_handle.data.fd, - errno); - int aws_error = s_determine_socket_error(errno); - aws_raise_error(aws_error); - s_on_connection_error(socket, aws_error); - return AWS_OP_ERR; - } - - socket->state = CONNECTED_WRITE | CONNECTED_READ; - - if (aws_socket_assign_to_event_loop(socket, event_loop)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: assignment to event loop %p failed with error %d", - (void *)socket, - socket->io_handle.data.fd, - (void *)event_loop, - aws_last_error()); - s_on_connection_error(socket, aws_last_error()); - return AWS_OP_ERR; - } - - socket->connection_result_fn(socket, AWS_ERROR_SUCCESS, socket->connect_accept_user_data); - - return AWS_OP_SUCCESS; -} - -static void s_on_connection_error(struct aws_socket *socket, int error) { - socket->state = ERROR; - AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection failure", (void *)socket, socket->io_handle.data.fd); - if (socket->connection_result_fn) { - socket->connection_result_fn(socket, error, socket->connect_accept_user_data); - } else if (socket->accept_result_fn) { - socket->accept_result_fn(socket, error, NULL, socket->connect_accept_user_data); - } -} - -/* the next two callbacks compete based on which one runs first. if s_socket_connect_event - * comes back first, then we set socket_args->socket = NULL and continue on with the connection. - * if s_handle_socket_timeout() runs first, is sees socket_args->socket is NULL and just cleans up its memory. - * s_handle_socket_timeout() will always run so the memory for socket_connect_args is always cleaned up there. */ -static void s_socket_connect_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - - (void)event_loop; - (void)handle; - - struct posix_socket_connect_args *socket_args = (struct posix_socket_connect_args *)user_data; - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "fd=%d: connection activity handler triggered ", handle->data.fd); - - if (socket_args->socket) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: has not timed out yet proceeding with connection.", - (void *)socket_args->socket, - handle->data.fd); - - struct posix_socket *socket_impl = socket_args->socket->impl; - if (!(events & AWS_IO_EVENT_TYPE_ERROR || events & AWS_IO_EVENT_TYPE_CLOSED) && - (events & AWS_IO_EVENT_TYPE_READABLE || events & AWS_IO_EVENT_TYPE_WRITABLE)) { - struct aws_socket *socket = socket_args->socket; - socket_args->socket = NULL; - socket_impl->connect_args = NULL; - s_on_connection_success(socket); - return; - } - - int aws_error = aws_socket_get_error(socket_args->socket); - /* we'll get another notification. */ - if (aws_error == AWS_IO_READ_WOULD_BLOCK) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: spurious event, waiting for another notification.", - (void *)socket_args->socket, - handle->data.fd); - return; - } - - struct aws_socket *socket = socket_args->socket; - socket_args->socket = NULL; - socket_impl->connect_args = NULL; - aws_raise_error(aws_error); - s_on_connection_error(socket, aws_error); - } -} - -static void s_handle_socket_timeout(struct aws_task *task, void *args, aws_task_status status) { - (void)task; - (void)status; - - struct posix_socket_connect_args *socket_args = args; - - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "task_id=%p: timeout task triggered, evaluating timeouts.", (void *)task); - /* successful connection will have nulled out connect_args->socket */ - if (socket_args->socket) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: timed out, shutting down.", - (void *)socket_args->socket, - socket_args->socket->io_handle.data.fd); - - socket_args->socket->state = TIMEDOUT; - int error_code = AWS_IO_SOCKET_TIMEOUT; - - if (status == AWS_TASK_STATUS_RUN_READY) { - aws_event_loop_unsubscribe_from_io_events(socket_args->socket->event_loop, &socket_args->socket->io_handle); - } else { - error_code = AWS_IO_EVENT_LOOP_SHUTDOWN; - aws_event_loop_free_io_event_resources(socket_args->socket->event_loop, &socket_args->socket->io_handle); - } - socket_args->socket->event_loop = NULL; - struct posix_socket *socket_impl = socket_args->socket->impl; - socket_impl->currently_subscribed = false; - aws_raise_error(error_code); - struct aws_socket *socket = socket_args->socket; - /*socket close sets socket_args->socket to NULL and - * socket_impl->connect_args to NULL. */ - aws_socket_close(socket); - s_on_connection_error(socket, error_code); - } - - aws_mem_release(socket_args->allocator, socket_args); -} - -/* this is used simply for moving a connect_success callback when the connect finished immediately - * (like for unix domain sockets) into the event loop's thread. Also note, in that case there was no - * timeout task scheduled, so in this case the socket_args are cleaned up. */ -static void s_run_connect_success(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - struct posix_socket_connect_args *socket_args = arg; - - if (socket_args->socket) { - struct posix_socket *socket_impl = socket_args->socket->impl; - if (status == AWS_TASK_STATUS_RUN_READY) { - s_on_connection_success(socket_args->socket); - } else { - aws_raise_error(AWS_IO_SOCKET_CONNECT_ABORTED); - socket_args->socket->event_loop = NULL; - s_on_connection_error(socket_args->socket, AWS_IO_SOCKET_CONNECT_ABORTED); - } - socket_impl->connect_args = NULL; - } - - aws_mem_release(socket_args->allocator, socket_args); -} - -static inline int s_convert_pton_error(int pton_code) { - if (pton_code == 0) { - return AWS_IO_SOCKET_INVALID_ADDRESS; - } - - return s_determine_socket_error(errno); -} - -struct socket_address { - union sock_addr_types { - struct sockaddr_in addr_in; - struct sockaddr_in6 addr_in6; - struct sockaddr_un un_addr; -#ifdef USE_VSOCK - struct sockaddr_vm vm_addr; -#endif - } sock_addr_types; -}; - -#ifdef USE_VSOCK -/** Convert a string to a VSOCK CID. Respects the calling convetion of inet_pton: - * 0 on error, 1 on success. */ -static int parse_cid(const char *cid_str, unsigned int *value) { - if (cid_str == NULL || value == NULL) { - errno = EINVAL; - return 0; - } - /* strtoll returns 0 as both error and correct value */ - errno = 0; - /* unsigned long long to handle edge cases in convention explicitly */ - long long cid = strtoll(cid_str, NULL, 10); - if (errno != 0) { - return 0; - } - - /* -1U means any, so it's a valid value, but it needs to be converted to - * unsigned int. */ - if (cid == -1) { - *value = VMADDR_CID_ANY; - return 1; - } - - if (cid < 0 || cid > UINT_MAX) { - errno = ERANGE; - return 0; - } - - /* cast is safe here, edge cases already checked */ - *value = (unsigned int)cid; - return 1; -} -#endif - -int aws_socket_connect( - struct aws_socket *socket, - const struct aws_socket_endpoint *remote_endpoint, - struct aws_event_loop *event_loop, - aws_socket_on_connection_result_fn *on_connection_result, - void *user_data) { - AWS_ASSERT(event_loop); - AWS_ASSERT(!socket->event_loop); - - AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: beginning connect.", (void *)socket, socket->io_handle.data.fd); - - if (socket->event_loop) { - return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); - } - - if (socket->options.type != AWS_SOCKET_DGRAM) { - AWS_ASSERT(on_connection_result); - if (socket->state != INIT) { - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - } else { /* UDP socket */ - /* UDP sockets jump to CONNECT_READ if bind is called first */ - if (socket->state != CONNECTED_READ && socket->state != INIT) { - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - } - - size_t address_strlen; - if (aws_secure_strlen(remote_endpoint->address, AWS_ADDRESS_MAX_LEN, &address_strlen)) { - return AWS_OP_ERR; - } - - struct socket_address address; - AWS_ZERO_STRUCT(address); - socklen_t sock_size = 0; - int pton_err = 1; - if (socket->options.domain == AWS_SOCKET_IPV4) { - pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); - address.sock_addr_types.addr_in.sin_port = htons(remote_endpoint->port); - address.sock_addr_types.addr_in.sin_family = AF_INET; - sock_size = sizeof(address.sock_addr_types.addr_in); - } else if (socket->options.domain == AWS_SOCKET_IPV6) { - pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); - address.sock_addr_types.addr_in6.sin6_port = htons(remote_endpoint->port); - address.sock_addr_types.addr_in6.sin6_family = AF_INET6; - sock_size = sizeof(address.sock_addr_types.addr_in6); - } else if (socket->options.domain == AWS_SOCKET_LOCAL) { - address.sock_addr_types.un_addr.sun_family = AF_UNIX; - strncpy(address.sock_addr_types.un_addr.sun_path, remote_endpoint->address, AWS_ADDRESS_MAX_LEN); - sock_size = sizeof(address.sock_addr_types.un_addr); -#ifdef USE_VSOCK - } else if (socket->options.domain == AWS_SOCKET_VSOCK) { - pton_err = parse_cid(remote_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); - address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; - address.sock_addr_types.vm_addr.svm_port = (unsigned int)remote_endpoint->port; - sock_size = sizeof(address.sock_addr_types.vm_addr); -#endif - } else { - AWS_ASSERT(0); - return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY); - } - - if (pton_err != 1) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to parse address %s:%d.", - (void *)socket, - socket->io_handle.data.fd, - remote_endpoint->address, - (int)remote_endpoint->port); - return aws_raise_error(s_convert_pton_error(pton_err)); - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connecting to endpoint %s:%d.", - (void *)socket, - socket->io_handle.data.fd, - remote_endpoint->address, - (int)remote_endpoint->port); - - socket->state = CONNECTING; - socket->remote_endpoint = *remote_endpoint; - socket->connect_accept_user_data = user_data; - socket->connection_result_fn = on_connection_result; - - struct posix_socket *socket_impl = socket->impl; - - socket_impl->connect_args = aws_mem_calloc(socket->allocator, 1, sizeof(struct posix_socket_connect_args)); - if (!socket_impl->connect_args) { - return AWS_OP_ERR; - } - - socket_impl->connect_args->socket = socket; - socket_impl->connect_args->allocator = socket->allocator; - - socket_impl->connect_args->task.fn = s_handle_socket_timeout; - socket_impl->connect_args->task.arg = socket_impl->connect_args; - - int error_code = connect(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size); - socket->event_loop = event_loop; - - if (!error_code) { - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connected immediately, not scheduling timeout.", - (void *)socket, - socket->io_handle.data.fd); - socket_impl->connect_args->task.fn = s_run_connect_success; - /* the subscription for IO will happen once we setup the connection in the task. Since we already - * know the connection succeeded, we don't need to register for events yet. */ - aws_event_loop_schedule_task_now(event_loop, &socket_impl->connect_args->task); - } - - if (error_code) { - error_code = errno; - if (error_code == EINPROGRESS || error_code == EALREADY) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connection pending waiting on event-loop notification or timeout.", - (void *)socket, - socket->io_handle.data.fd); - /* cache the timeout task; it is possible for the IO subscription to come back virtually immediately - * and null out the connect args */ - struct aws_task *timeout_task = &socket_impl->connect_args->task; - - socket_impl->currently_subscribed = true; - /* This event is for when the connection finishes. (the fd will flip writable). */ - if (aws_event_loop_subscribe_to_io_events( - event_loop, - &socket->io_handle, - AWS_IO_EVENT_TYPE_WRITABLE, - s_socket_connect_event, - socket_impl->connect_args)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to register with event-loop %p.", - (void *)socket, - socket->io_handle.data.fd, - (void *)event_loop); - socket_impl->currently_subscribed = false; - socket->event_loop = NULL; - goto err_clean_up; - } - - /* schedule a task to run at the connect timeout interval, if this task runs before the connect - * happens, we consider that a timeout. */ - uint64_t timeout = 0; - aws_event_loop_current_clock_time(event_loop, &timeout); - timeout += aws_timestamp_convert( - socket->options.connect_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: scheduling timeout task for %llu.", - (void *)socket, - socket->io_handle.data.fd, - (unsigned long long)timeout); - aws_event_loop_schedule_task_future(event_loop, timeout_task, timeout); - } else { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connect failed with error code %d.", - (void *)socket, - socket->io_handle.data.fd, - error_code); - int aws_error = s_determine_socket_error(error_code); - aws_raise_error(aws_error); - socket->event_loop = NULL; - socket_impl->currently_subscribed = false; - goto err_clean_up; - } - } - return AWS_OP_SUCCESS; - -err_clean_up: - aws_mem_release(socket->allocator, socket_impl->connect_args); - socket_impl->connect_args = NULL; - return AWS_OP_ERR; -} - -int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint) { - if (socket->state != INIT) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: invalid state for bind operation.", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - - size_t address_strlen; - if (aws_secure_strlen(local_endpoint->address, AWS_ADDRESS_MAX_LEN, &address_strlen)) { - return AWS_OP_ERR; - } - - int error_code = -1; - - socket->local_endpoint = *local_endpoint; - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: binding to %s:%d.", - (void *)socket, - socket->io_handle.data.fd, - local_endpoint->address, - (int)local_endpoint->port); - - struct socket_address address; - AWS_ZERO_STRUCT(address); - socklen_t sock_size = 0; - int pton_err = 1; - if (socket->options.domain == AWS_SOCKET_IPV4) { - pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); - address.sock_addr_types.addr_in.sin_port = htons(local_endpoint->port); - address.sock_addr_types.addr_in.sin_family = AF_INET; - sock_size = sizeof(address.sock_addr_types.addr_in); - } else if (socket->options.domain == AWS_SOCKET_IPV6) { - pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); - address.sock_addr_types.addr_in6.sin6_port = htons(local_endpoint->port); - address.sock_addr_types.addr_in6.sin6_family = AF_INET6; - sock_size = sizeof(address.sock_addr_types.addr_in6); - } else if (socket->options.domain == AWS_SOCKET_LOCAL) { - address.sock_addr_types.un_addr.sun_family = AF_UNIX; - strncpy(address.sock_addr_types.un_addr.sun_path, local_endpoint->address, AWS_ADDRESS_MAX_LEN); - sock_size = sizeof(address.sock_addr_types.un_addr); -#ifdef USE_VSOCK - } else if (socket->options.domain == AWS_SOCKET_VSOCK) { - pton_err = parse_cid(local_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); - address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; - address.sock_addr_types.vm_addr.svm_port = (unsigned int)local_endpoint->port; - sock_size = sizeof(address.sock_addr_types.vm_addr); -#endif - } else { - AWS_ASSERT(0); - return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY); - } - - if (pton_err != 1) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to parse address %s:%d.", - (void *)socket, - socket->io_handle.data.fd, - local_endpoint->address, - (int)local_endpoint->port); - return aws_raise_error(s_convert_pton_error(pton_err)); - } - - error_code = bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size); - - if (!error_code) { - if (socket->options.type == AWS_SOCKET_STREAM) { - socket->state = BOUND; - } else { - /* e.g. UDP is now readable */ - socket->state = CONNECTED_READ; - } - AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully bound", (void *)socket, socket->io_handle.data.fd); - - return AWS_OP_SUCCESS; - } - - socket->state = ERROR; - error_code = errno; - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: bind failed with error code %d", - (void *)socket, - socket->io_handle.data.fd, - error_code); - - int aws_error = s_determine_socket_error(error_code); - return aws_raise_error(aws_error); -} - -int aws_socket_listen(struct aws_socket *socket, int backlog_size) { - if (socket->state != BOUND) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: invalid state for listen operation. You must call bind first.", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - - int error_code = listen(socket->io_handle.data.fd, backlog_size); - - if (!error_code) { - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully listening", (void *)socket, socket->io_handle.data.fd); - socket->state = LISTENING; - return AWS_OP_SUCCESS; - } - - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: listen failed with error code %d", - (void *)socket, - socket->io_handle.data.fd, - error_code); - error_code = errno; - socket->state = ERROR; - - return aws_raise_error(s_determine_socket_error(error_code)); -} - -/* this is called by the event loop handler that was installed in start_accept(). It runs once the FD goes readable, - * accepts as many as it can and then returns control to the event loop. */ -static void s_socket_accept_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - - (void)event_loop; - - struct aws_socket *socket = user_data; - struct posix_socket *socket_impl = socket->impl; - - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, "id=%p fd=%d: listening event received", (void *)socket, socket->io_handle.data.fd); - - if (socket_impl->continue_accept && events & AWS_IO_EVENT_TYPE_READABLE) { - int in_fd = 0; - while (socket_impl->continue_accept && in_fd != -1) { - struct sockaddr_storage in_addr; - socklen_t in_len = sizeof(struct sockaddr_storage); - - in_fd = accept(handle->data.fd, (struct sockaddr *)&in_addr, &in_len); - if (in_fd == -1) { - int error = errno; - - if (error == EAGAIN || error == EWOULDBLOCK) { - break; - } - - int aws_error = aws_socket_get_error(socket); - aws_raise_error(aws_error); - s_on_connection_error(socket, aws_error); - break; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, "id=%p fd=%d: incoming connection", (void *)socket, socket->io_handle.data.fd); - - struct aws_socket *new_sock = aws_mem_acquire(socket->allocator, sizeof(struct aws_socket)); - - if (!new_sock) { - close(in_fd); - s_on_connection_error(socket, aws_last_error()); - continue; - } - - if (s_socket_init(new_sock, socket->allocator, &socket->options, in_fd)) { - aws_mem_release(socket->allocator, new_sock); - s_on_connection_error(socket, aws_last_error()); - continue; - } - - new_sock->local_endpoint = socket->local_endpoint; - new_sock->state = CONNECTED_READ | CONNECTED_WRITE; - uint16_t port = 0; - - /* get the info on the incoming socket's address */ - if (in_addr.ss_family == AF_INET) { - struct sockaddr_in *s = (struct sockaddr_in *)&in_addr; - port = ntohs(s->sin_port); - /* this came from the kernel, a.) it won't fail. b.) even if it does - * its not fatal. come back and add logging later. */ - if (!inet_ntop( - AF_INET, - &s->sin_addr, - new_sock->remote_endpoint.address, - sizeof(new_sock->remote_endpoint.address))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d:. Failed to determine remote address.", - (void *)socket, - socket->io_handle.data.fd) - } - new_sock->options.domain = AWS_SOCKET_IPV4; - } else if (in_addr.ss_family == AF_INET6) { - /* this came from the kernel, a.) it won't fail. b.) even if it does - * its not fatal. come back and add logging later. */ - struct sockaddr_in6 *s = (struct sockaddr_in6 *)&in_addr; - port = ntohs(s->sin6_port); - if (!inet_ntop( - AF_INET6, - &s->sin6_addr, - new_sock->remote_endpoint.address, - sizeof(new_sock->remote_endpoint.address))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d:. Failed to determine remote address.", - (void *)socket, - socket->io_handle.data.fd) - } - new_sock->options.domain = AWS_SOCKET_IPV6; - } else if (in_addr.ss_family == AF_UNIX) { - new_sock->remote_endpoint = socket->local_endpoint; - new_sock->options.domain = AWS_SOCKET_LOCAL; - } - - new_sock->remote_endpoint.port = port; - - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: connected to %s:%d, incoming fd %d", - (void *)socket, - socket->io_handle.data.fd, - new_sock->remote_endpoint.address, - new_sock->remote_endpoint.port, - in_fd); - - int flags = fcntl(in_fd, F_GETFL, 0); - - flags |= O_NONBLOCK | O_CLOEXEC; - fcntl(in_fd, F_SETFL, flags); - - bool close_occurred = false; - socket_impl->close_happened = &close_occurred; - socket->accept_result_fn(socket, AWS_ERROR_SUCCESS, new_sock, socket->connect_accept_user_data); - - if (close_occurred) { - return; - } - - socket_impl->close_happened = NULL; - } - } - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: finished processing incoming connections, " - "waiting on event-loop notification", - (void *)socket, - socket->io_handle.data.fd); -} - -int aws_socket_start_accept( - struct aws_socket *socket, - struct aws_event_loop *accept_loop, - aws_socket_on_accept_result_fn *on_accept_result, - void *user_data) { - AWS_ASSERT(on_accept_result); - AWS_ASSERT(accept_loop); - - if (socket->event_loop) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: is already assigned to event-loop %p.", - (void *)socket, - socket->io_handle.data.fd, - (void *)socket->event_loop); - return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); - } - - if (socket->state != LISTENING) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: invalid state for start_accept operation. You must call listen first.", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - - socket->accept_result_fn = on_accept_result; - socket->connect_accept_user_data = user_data; - socket->event_loop = accept_loop; - struct posix_socket *socket_impl = socket->impl; - socket_impl->continue_accept = true; - socket_impl->currently_subscribed = true; - - if (aws_event_loop_subscribe_to_io_events( - socket->event_loop, &socket->io_handle, AWS_IO_EVENT_TYPE_READABLE, s_socket_accept_event, socket)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to subscribe to event-loop %p.", - (void *)socket, - socket->io_handle.data.fd, - (void *)socket->event_loop); - socket_impl->continue_accept = false; - socket_impl->currently_subscribed = false; - socket->event_loop = NULL; - - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -struct stop_accept_args { - struct aws_task task; - struct aws_mutex mutex; - struct aws_condition_variable condition_variable; - struct aws_socket *socket; - int ret_code; - bool invoked; -}; - -static bool s_stop_accept_pred(void *arg) { - struct stop_accept_args *stop_accept_args = arg; - return stop_accept_args->invoked; -} - -static void s_stop_accept_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - (void)status; - - struct stop_accept_args *stop_accept_args = arg; - aws_mutex_lock(&stop_accept_args->mutex); - stop_accept_args->ret_code = AWS_OP_SUCCESS; - if (aws_socket_stop_accept(stop_accept_args->socket)) { - stop_accept_args->ret_code = aws_last_error(); - } - stop_accept_args->invoked = true; - aws_condition_variable_notify_one(&stop_accept_args->condition_variable); - aws_mutex_unlock(&stop_accept_args->mutex); -} - -int aws_socket_stop_accept(struct aws_socket *socket) { - if (socket->state != LISTENING) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: is not in a listening state, can't stop_accept.", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, "id=%p fd=%d: stopping accepting new connections", (void *)socket, socket->io_handle.data.fd); - - if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { - struct stop_accept_args args = {.mutex = AWS_MUTEX_INIT, - .condition_variable = AWS_CONDITION_VARIABLE_INIT, - .invoked = false, - .socket = socket, - .ret_code = AWS_OP_SUCCESS, - .task = {.fn = s_stop_accept_task}}; - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: stopping accepting new connections from a different thread than " - "the socket is running from. Blocking until it shuts down.", - (void *)socket, - socket->io_handle.data.fd); - /* Look.... I know what I'm doing.... trust me, I'm an engineer. - * We wait on the completion before 'args' goes out of scope. - * NOLINTNEXTLINE */ - args.task.arg = &args; - aws_mutex_lock(&args.mutex); - aws_event_loop_schedule_task_now(socket->event_loop, &args.task); - aws_condition_variable_wait_pred(&args.condition_variable, &args.mutex, s_stop_accept_pred, &args); - aws_mutex_unlock(&args.mutex); - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: stop accept task finished running.", - (void *)socket, - socket->io_handle.data.fd); - - if (args.ret_code) { - return aws_raise_error(args.ret_code); - } - return AWS_OP_SUCCESS; - } - - int ret_val = AWS_OP_SUCCESS; - struct posix_socket *socket_impl = socket->impl; - if (socket_impl->currently_subscribed) { - ret_val = aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); - socket_impl->currently_subscribed = false; - socket_impl->continue_accept = false; - socket->event_loop = NULL; - } - - return ret_val; -} - -int aws_socket_set_options(struct aws_socket *socket, const struct aws_socket_options *options) { - if (socket->options.domain != options->domain || socket->options.type != options->type) { - return aws_raise_error(AWS_IO_SOCKET_INVALID_OPTIONS); - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setting socket options to: keep-alive %d, keep idle %d, keep-alive interval %d, keep-alive probe " - "count %d.", - (void *)socket, - socket->io_handle.data.fd, - (int)options->keepalive, - (int)options->keep_alive_timeout_sec, - (int)options->keep_alive_interval_sec, - (int)options->keep_alive_max_failed_probes); - - socket->options = *options; - - int option_value = 1; - if (AWS_UNLIKELY( - setsockopt(socket->io_handle.data.fd, SOL_SOCKET, NO_SIGNAL, &option_value, sizeof(option_value)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for NO_SIGNAL failed with errno %d. If you are having SIGPIPE signals thrown, " - "you may" - " want to install a signal trap in your application layer.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - - int reuse = 1; - if (AWS_UNLIKELY(setsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(int)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for SO_REUSEADDR failed with errno %d.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - - if (options->type == AWS_SOCKET_STREAM && options->domain != AWS_SOCKET_LOCAL) { - if (socket->options.keepalive) { - int keep_alive = 1; - if (AWS_UNLIKELY( - setsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_KEEPALIVE, &keep_alive, sizeof(int)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for enabling SO_KEEPALIVE failed with errno %d.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - } - - if (socket->options.keep_alive_interval_sec && socket->options.keep_alive_timeout_sec) { - int ival_in_secs = socket->options.keep_alive_interval_sec; - if (AWS_UNLIKELY(setsockopt( - socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPIDLE, &ival_in_secs, sizeof(ival_in_secs)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for enabling TCP_KEEPIDLE for TCP failed with errno %d.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - - ival_in_secs = socket->options.keep_alive_timeout_sec; - if (AWS_UNLIKELY(setsockopt( - socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPINTVL, &ival_in_secs, sizeof(ival_in_secs)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for enabling TCP_KEEPINTVL for TCP failed with errno %d.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - } - - if (socket->options.keep_alive_max_failed_probes) { - int max_probes = socket->options.keep_alive_max_failed_probes; - if (AWS_UNLIKELY( - setsockopt(socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPCNT, &max_probes, sizeof(max_probes)))) { - AWS_LOGF_WARN( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: setsockopt() for enabling TCP_KEEPCNT for TCP failed with errno %d.", - (void *)socket, - socket->io_handle.data.fd, - errno); - } - } - } - - return AWS_OP_SUCCESS; -} - -struct write_request { - struct aws_byte_cursor cursor_cpy; - aws_socket_on_write_completed_fn *written_fn; - void *write_user_data; - struct aws_linked_list_node node; - size_t original_buffer_len; -}; - -struct posix_socket_close_args { - struct aws_mutex mutex; - struct aws_condition_variable condition_variable; - struct aws_socket *socket; - bool invoked; - int ret_code; -}; - -static bool s_close_predicate(void *arg) { - struct posix_socket_close_args *close_args = arg; - return close_args->invoked; -} - -static void s_close_task(struct aws_task *task, void *arg, enum aws_task_status status) { - (void)task; - (void)status; - - struct posix_socket_close_args *close_args = arg; - aws_mutex_lock(&close_args->mutex); - close_args->ret_code = AWS_OP_SUCCESS; - - if (aws_socket_close(close_args->socket)) { - close_args->ret_code = aws_last_error(); - } - - close_args->invoked = true; - aws_condition_variable_notify_one(&close_args->condition_variable); - aws_mutex_unlock(&close_args->mutex); -} - -int aws_socket_close(struct aws_socket *socket) { - struct posix_socket *socket_impl = socket->impl; - AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: closing", (void *)socket, socket->io_handle.data.fd); - if (socket->event_loop) { - /* don't freak out on me, this almost never happens, and never occurs inside a channel - * it only gets hit from a listening socket shutting down or from a unit test. */ - if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: closing from a different thread than " - "the socket is running from. Blocking until it closes down.", - (void *)socket, - socket->io_handle.data.fd); - /* the only time we allow this kind of thing is when you're a listener.*/ - if (socket->state != LISTENING) { - return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); - } - - struct posix_socket_close_args args = { - .mutex = AWS_MUTEX_INIT, - .condition_variable = AWS_CONDITION_VARIABLE_INIT, - .socket = socket, - .ret_code = AWS_OP_SUCCESS, - .invoked = false, - }; - - struct aws_task close_task = { - .fn = s_close_task, - .arg = &args, - }; - - aws_mutex_lock(&args.mutex); - aws_event_loop_schedule_task_now(socket->event_loop, &close_task); - aws_condition_variable_wait_pred(&args.condition_variable, &args.mutex, s_close_predicate, &args); - aws_mutex_unlock(&args.mutex); - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, "id=%p fd=%d: close task completed.", (void *)socket, socket->io_handle.data.fd); - if (args.ret_code) { - return aws_raise_error(args.ret_code); - } - - return AWS_OP_SUCCESS; - } - - if (socket_impl->currently_subscribed) { - if (socket->state & LISTENING) { - aws_socket_stop_accept(socket); - } else { - int err_code = aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); - - if (err_code) { - return AWS_OP_ERR; - } - } - socket_impl->currently_subscribed = false; - socket->event_loop = NULL; - } - } - - if (socket_impl->close_happened) { - *socket_impl->close_happened = true; - } - - if (socket_impl->connect_args) { - socket_impl->connect_args->socket = NULL; - socket_impl->connect_args = NULL; - } - - if (aws_socket_is_open(socket)) { - close(socket->io_handle.data.fd); - socket->io_handle.data.fd = -1; - socket->state = CLOSED; - - /* after close, just go ahead and clear out the pending writes queue - * and tell the user they were cancelled. */ - while (!aws_linked_list_empty(&socket_impl->write_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&socket_impl->write_queue); - struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); - - write_request->written_fn( - socket, AWS_IO_SOCKET_CLOSED, write_request->original_buffer_len, write_request->write_user_data); - aws_mem_release(socket->allocator, write_request); - } - } - - return AWS_OP_SUCCESS; -} - -int aws_socket_shutdown_dir(struct aws_socket *socket, enum aws_channel_direction dir) { - int how = dir == AWS_CHANNEL_DIR_READ ? 0 : 1; - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, "id=%p fd=%d: shutting down in direction %d", (void *)socket, socket->io_handle.data.fd, dir); - if (shutdown(socket->io_handle.data.fd, how)) { - int aws_error = s_determine_socket_error(errno); - return aws_raise_error(aws_error); - } - - if (dir == AWS_CHANNEL_DIR_READ) { - socket->state &= ~CONNECTED_READ; - } else { - socket->state &= ~CONNECTED_WRITE; - } - - return AWS_OP_SUCCESS; -} - -/* this gets called in two scenarios. - * 1st scenario, someone called aws_socket_write() and we want to try writing now, so an error can be returned - * immediately if something bad has happened to the socket. In this case, `parent_request` is set. - * 2nd scenario, the event loop notified us that the socket went writable. In this case `parent_request` is NULL */ -static int s_process_write_requests(struct aws_socket *socket, struct write_request *parent_request) { - struct posix_socket *socket_impl = socket->impl; - struct aws_allocator *allocator = socket->allocator; - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, "id=%p fd=%d: processing write requests.", (void *)socket, socket->io_handle.data.fd); - - /* there's a potential deadlock where we notify the user that we wrote some data, the user - * says, "cool, now I can write more and then immediately calls aws_socket_write(). We need to make sure - * that we don't allow reentrancy in that case. */ - socket_impl->write_in_progress = true; - - if (parent_request) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: processing write requests, called from aws_socket_write", - (void *)socket, - socket->io_handle.data.fd); - socket_impl->currently_in_event = true; - } else { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: processing write requests, invoked by the event-loop", - (void *)socket, - socket->io_handle.data.fd); - } - - bool purge = false; - int aws_error = AWS_OP_SUCCESS; - bool parent_request_failed = false; - - /* if a close call happens in the middle, this queue will have been cleaned out from under us. */ - while (!aws_linked_list_empty(&socket_impl->write_queue)) { - struct aws_linked_list_node *node = aws_linked_list_front(&socket_impl->write_queue); - struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: dequeued write request of size %llu, remaining to write %llu", - (void *)socket, - socket->io_handle.data.fd, - (unsigned long long)write_request->original_buffer_len, - (unsigned long long)write_request->cursor_cpy.len); - - ssize_t written = - send(socket->io_handle.data.fd, write_request->cursor_cpy.ptr, write_request->cursor_cpy.len, NO_SIGNAL); - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: send written size %d", - (void *)socket, - socket->io_handle.data.fd, - (int)written); - - if (written < 0) { - int error = errno; - if (error == EAGAIN) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, "id=%p fd=%d: returned would block", (void *)socket, socket->io_handle.data.fd); - break; - } - - if (error == EPIPE) { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: already closed before write", - (void *)socket, - socket->io_handle.data.fd); - aws_error = AWS_IO_SOCKET_CLOSED; - aws_raise_error(aws_error); - purge = true; - break; - } - - purge = true; - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: write error with error code %d", - (void *)socket, - socket->io_handle.data.fd, - error); - aws_error = s_determine_socket_error(error); - aws_raise_error(aws_error); - break; - } - - size_t remaining_to_write = write_request->cursor_cpy.len; - - aws_byte_cursor_advance(&write_request->cursor_cpy, (size_t)written); - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: remaining write request to write %llu", - (void *)socket, - socket->io_handle.data.fd, - (unsigned long long)write_request->cursor_cpy.len); - - if ((size_t)written == remaining_to_write) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, "id=%p fd=%d: write request completed", (void *)socket, socket->io_handle.data.fd); - - aws_linked_list_remove(node); - write_request->written_fn( - socket, AWS_OP_SUCCESS, write_request->original_buffer_len, write_request->write_user_data); - aws_mem_release(allocator, write_request); - } - } - - if (purge) { - while (!aws_linked_list_empty(&socket_impl->write_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&socket_impl->write_queue); - struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); - - /* If this fn was invoked directly from aws_socket_write(), don't invoke the error callback - * as the user will be able to rely on the return value from aws_socket_write() */ - if (write_request == parent_request) { - parent_request_failed = true; - } else { - write_request->written_fn(socket, aws_error, 0, write_request->write_user_data); - } - - aws_mem_release(socket->allocator, write_request); - } - } - - socket_impl->write_in_progress = false; - - if (parent_request) { - socket_impl->currently_in_event = false; - } - - if (socket_impl->clean_yourself_up) { - aws_mem_release(allocator, socket_impl); - } - - /* Only report error if aws_socket_write() invoked this function and its write_request failed */ - if (!parent_request_failed) { - return AWS_OP_SUCCESS; - } - - aws_raise_error(aws_error); - return AWS_OP_ERR; -} - -static void s_on_socket_io_event( - struct aws_event_loop *event_loop, - struct aws_io_handle *handle, - int events, - void *user_data) { - (void)event_loop; - (void)handle; - /* this is to handle a race condition when an error kicks off a cleanup, or the user decides - * to close the socket based on something they read (SSL validation failed for example). - * if clean_up happens when currently_in_event is true, socket_impl is kept dangling but currently - * subscribed is set to false. */ - struct aws_socket *socket = user_data; - struct posix_socket *socket_impl = socket->impl; - struct aws_allocator *allocator = socket->allocator; - - socket_impl->currently_in_event = true; - - if (events & AWS_IO_EVENT_TYPE_REMOTE_HANG_UP || events & AWS_IO_EVENT_TYPE_CLOSED) { - aws_raise_error(AWS_IO_SOCKET_CLOSED); - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: closed remotely", (void *)socket, socket->io_handle.data.fd); - if (socket->readable_fn) { - socket->readable_fn(socket, AWS_IO_SOCKET_CLOSED, socket->readable_user_data); - } - goto end_check; - } - - if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_ERROR) { - int aws_error = aws_socket_get_error(socket); - aws_raise_error(aws_error); - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, "id=%p fd=%d: error event occurred", (void *)socket, socket->io_handle.data.fd); - if (socket->readable_fn) { - socket->readable_fn(socket, aws_error, socket->readable_user_data); - } - goto end_check; - } - - if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_READABLE) { - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: is readable", (void *)socket, socket->io_handle.data.fd); - if (socket->readable_fn) { - socket->readable_fn(socket, AWS_OP_SUCCESS, socket->readable_user_data); - } - } - /* if socket closed in between these branches, the currently_subscribed will be false and socket_impl will not - * have been cleaned up, so this next branch is safe. */ - if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_WRITABLE) { - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: is writable", (void *)socket, socket->io_handle.data.fd); - s_process_write_requests(socket, NULL); - } - -end_check: - socket_impl->currently_in_event = false; - - if (socket_impl->clean_yourself_up) { - aws_mem_release(allocator, socket_impl); - } -} - -int aws_socket_assign_to_event_loop(struct aws_socket *socket, struct aws_event_loop *event_loop) { - if (!socket->event_loop) { - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: assigning to event loop %p", - (void *)socket, - socket->io_handle.data.fd, - (void *)event_loop); - socket->event_loop = event_loop; - struct posix_socket *socket_impl = socket->impl; - socket_impl->currently_subscribed = true; - if (aws_event_loop_subscribe_to_io_events( - event_loop, - &socket->io_handle, - AWS_IO_EVENT_TYPE_WRITABLE | AWS_IO_EVENT_TYPE_READABLE, - s_on_socket_io_event, - socket)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: assigning to event loop %p failed with error %d", - (void *)socket, - socket->io_handle.data.fd, - (void *)event_loop, - aws_last_error()); - socket_impl->currently_subscribed = false; - socket->event_loop = NULL; - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; - } - - return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); -} - -struct aws_event_loop *aws_socket_get_event_loop(struct aws_socket *socket) { - return socket->event_loop; -} - -int aws_socket_subscribe_to_readable_events( - struct aws_socket *socket, - aws_socket_on_readable_fn *on_readable, - void *user_data) { - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, " id=%p fd=%d: subscribing to readable events", (void *)socket, socket->io_handle.data.fd); - if (!(socket->state & CONNECTED_READ)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: can't subscribe to readable events since the socket is not connected", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); - } - - if (socket->readable_fn) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: can't subscribe to readable events since it is already subscribed", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_ERROR_IO_ALREADY_SUBSCRIBED); - } - - AWS_ASSERT(on_readable); - socket->readable_user_data = user_data; - socket->readable_fn = on_readable; - - return AWS_OP_SUCCESS; -} - -int aws_socket_read(struct aws_socket *socket, struct aws_byte_buf *buffer, size_t *amount_read) { - AWS_ASSERT(amount_read); - - if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: cannot read from a different thread than event loop %p", - (void *)socket, - socket->io_handle.data.fd, - (void *)socket->event_loop); - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - if (!(socket->state & CONNECTED_READ)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: cannot read because it is not connected", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); - } - - ssize_t read_val = read(socket->io_handle.data.fd, buffer->buffer + buffer->len, buffer->capacity - buffer->len); - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET, "id=%p fd=%d: read of %d", (void *)socket, socket->io_handle.data.fd, (int)read_val); - - if (read_val > 0) { - *amount_read = (size_t)read_val; - buffer->len += *amount_read; - return AWS_OP_SUCCESS; - } - - /* read_val of 0 means EOF which we'll treat as AWS_IO_SOCKET_CLOSED */ - if (read_val == 0) { - AWS_LOGF_INFO( - AWS_LS_IO_SOCKET, "id=%p fd=%d: zero read, socket is closed", (void *)socket, socket->io_handle.data.fd); - *amount_read = 0; - - if (buffer->capacity - buffer->len > 0) { - return aws_raise_error(AWS_IO_SOCKET_CLOSED); - } - - return AWS_OP_SUCCESS; - } - - int error = errno; -#if defined(EWOULDBLOCK) - if (error == EAGAIN || error == EWOULDBLOCK) { -#else - if (error == EAGAIN) { -#endif - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: read would block", (void *)socket, socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_READ_WOULD_BLOCK); - } - - if (error == EPIPE) { - AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: socket is closed.", (void *)socket, socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_CLOSED); - } - - if (error == ETIMEDOUT) { - AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "id=%p fd=%d: socket timed out.", (void *)socket, socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_TIMEOUT); - } - - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: read failed with error: %s", - (void *)socket, - socket->io_handle.data.fd, - strerror(error)); - return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); -} - -int aws_socket_write( - struct aws_socket *socket, - const struct aws_byte_cursor *cursor, - aws_socket_on_write_completed_fn *written_fn, - void *user_data) { - if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { - return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); - } - - if (!(socket->state & CONNECTED_WRITE)) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: cannot write to because it is not connected", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); - } - - AWS_ASSERT(written_fn); - struct posix_socket *socket_impl = socket->impl; - struct write_request *write_request = aws_mem_calloc(socket->allocator, 1, sizeof(struct write_request)); - - if (!write_request) { - return AWS_OP_ERR; - } - - write_request->original_buffer_len = cursor->len; - write_request->written_fn = written_fn; - write_request->write_user_data = user_data; - write_request->cursor_cpy = *cursor; - aws_linked_list_push_back(&socket_impl->write_queue, &write_request->node); - - /* avoid reentrancy when a user calls write after receiving their completion callback. */ - if (!socket_impl->write_in_progress) { - return s_process_write_requests(socket, write_request); - } - - return AWS_OP_SUCCESS; -} - -int aws_socket_get_error(struct aws_socket *socket) { - int connect_result; - socklen_t result_length = sizeof(connect_result); - - if (getsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_ERROR, &connect_result, &result_length) < 0) { - return AWS_OP_ERR; - } - - if (connect_result) { - return s_determine_socket_error(connect_result); - } - - return AWS_OP_SUCCESS; -} - -bool aws_socket_is_open(struct aws_socket *socket) { - return socket->io_handle.data.fd >= 0; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/socket.h> + +#include <aws/common/clock.h> +#include <aws/common/condition_variable.h> +#include <aws/common/mutex.h> +#include <aws/common/string.h> + +#include <aws/io/event_loop.h> +#include <aws/io/logging.h> + +#include <arpa/inet.h> +#include <aws/io/io.h> +#include <errno.h> +#include <fcntl.h> +#include <netinet/tcp.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#if defined(__MACH__) +# define NO_SIGNAL SO_NOSIGPIPE +# define TCP_KEEPIDLE TCP_KEEPALIVE +#else +# define NO_SIGNAL MSG_NOSIGNAL +#endif + +/* This isn't defined on ancient linux distros (breaking the builds). + * However, if this is a prebuild, we purposely build on an ancient system, but + * we want the kernel calls to still be the same as a modern build since that's likely the target of the application + * calling this code. Just define this if it isn't there already. GlibC and the kernel don't really care how the flag + * gets passed as long as it does. + */ +#ifndef O_CLOEXEC +# define O_CLOEXEC 02000000 +#endif + +#ifdef USE_VSOCK +# if defined(__linux__) && defined(AF_VSOCK) +# include <linux/vm_sockets.h> +# else +# error "USE_VSOCK not supported on current platform" +# endif +#endif + +/* other than CONNECTED_READ | CONNECTED_WRITE + * a socket is only in one of these states at a time. */ +enum socket_state { + INIT = 0x01, + CONNECTING = 0x02, + CONNECTED_READ = 0x04, + CONNECTED_WRITE = 0x08, + BOUND = 0x10, + LISTENING = 0x20, + TIMEDOUT = 0x40, + ERROR = 0x80, + CLOSED, +}; + +static int s_convert_domain(enum aws_socket_domain domain) { + switch (domain) { + case AWS_SOCKET_IPV4: + return AF_INET; + case AWS_SOCKET_IPV6: + return AF_INET6; + case AWS_SOCKET_LOCAL: + return AF_UNIX; +#ifdef USE_VSOCK + case AWS_SOCKET_VSOCK: + return AF_VSOCK; +#endif + default: + AWS_ASSERT(0); + return AF_INET; + } +} + +static int s_convert_type(enum aws_socket_type type) { + switch (type) { + case AWS_SOCKET_STREAM: + return SOCK_STREAM; + case AWS_SOCKET_DGRAM: + return SOCK_DGRAM; + default: + AWS_ASSERT(0); + return SOCK_STREAM; + } +} + +static int s_determine_socket_error(int error) { + switch (error) { + case ECONNREFUSED: + return AWS_IO_SOCKET_CONNECTION_REFUSED; + case ETIMEDOUT: + return AWS_IO_SOCKET_TIMEOUT; + case EHOSTUNREACH: + case ENETUNREACH: + return AWS_IO_SOCKET_NO_ROUTE_TO_HOST; + case EADDRNOTAVAIL: + return AWS_IO_SOCKET_INVALID_ADDRESS; + case ENETDOWN: + return AWS_IO_SOCKET_NETWORK_DOWN; + case ECONNABORTED: + return AWS_IO_SOCKET_CONNECT_ABORTED; + case EADDRINUSE: + return AWS_IO_SOCKET_ADDRESS_IN_USE; + case ENOBUFS: + case ENOMEM: + return AWS_ERROR_OOM; + case EAGAIN: + return AWS_IO_READ_WOULD_BLOCK; + case EMFILE: + case ENFILE: + return AWS_ERROR_MAX_FDS_EXCEEDED; + case ENOENT: + case EINVAL: + return AWS_ERROR_FILE_INVALID_PATH; + case EAFNOSUPPORT: + return AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY; + case EACCES: + return AWS_ERROR_NO_PERMISSION; + default: + return AWS_IO_SOCKET_NOT_CONNECTED; + } +} + +static int s_create_socket(struct aws_socket *sock, const struct aws_socket_options *options) { + + int fd = socket(s_convert_domain(options->domain), s_convert_type(options->type), 0); + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: initializing with domain %d and type %d", + (void *)sock, + fd, + options->domain, + options->type); + if (fd != -1) { + int flags = fcntl(fd, F_GETFL, 0); + flags |= O_NONBLOCK | O_CLOEXEC; + int success = fcntl(fd, F_SETFL, flags); + (void)success; + sock->io_handle.data.fd = fd; + sock->io_handle.additional_data = NULL; + return aws_socket_set_options(sock, options); + } + + int aws_error = s_determine_socket_error(errno); + return aws_raise_error(aws_error); +} + +struct posix_socket_connect_args { + struct aws_task task; + struct aws_allocator *allocator; + struct aws_socket *socket; +}; + +struct posix_socket { + struct aws_linked_list write_queue; + struct posix_socket_connect_args *connect_args; + bool write_in_progress; + bool currently_subscribed; + bool continue_accept; + bool currently_in_event; + bool clean_yourself_up; + bool *close_happened; +}; + +static int s_socket_init( + struct aws_socket *socket, + struct aws_allocator *alloc, + const struct aws_socket_options *options, + int existing_socket_fd) { + AWS_ASSERT(options); + AWS_ZERO_STRUCT(*socket); + + struct posix_socket *posix_socket = aws_mem_calloc(alloc, 1, sizeof(struct posix_socket)); + if (!posix_socket) { + socket->impl = NULL; + return AWS_OP_ERR; + } + + socket->allocator = alloc; + socket->io_handle.data.fd = -1; + socket->state = INIT; + socket->options = *options; + + if (existing_socket_fd < 0) { + int err = s_create_socket(socket, options); + if (err) { + aws_mem_release(alloc, posix_socket); + socket->impl = NULL; + return AWS_OP_ERR; + } + } else { + socket->io_handle = (struct aws_io_handle){ + .data = {.fd = existing_socket_fd}, + .additional_data = NULL, + }; + aws_socket_set_options(socket, options); + } + + aws_linked_list_init(&posix_socket->write_queue); + posix_socket->write_in_progress = false; + posix_socket->currently_subscribed = false; + posix_socket->continue_accept = false; + posix_socket->currently_in_event = false; + posix_socket->clean_yourself_up = false; + posix_socket->connect_args = NULL; + posix_socket->close_happened = NULL; + socket->impl = posix_socket; + return AWS_OP_SUCCESS; +} + +int aws_socket_init(struct aws_socket *socket, struct aws_allocator *alloc, const struct aws_socket_options *options) { + AWS_ASSERT(options); + return s_socket_init(socket, alloc, options, -1); +} + +void aws_socket_clean_up(struct aws_socket *socket) { + if (!socket->impl) { + /* protect from double clean */ + return; + } + if (aws_socket_is_open(socket)) { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, "id=%p fd=%d: is still open, closing...", (void *)socket, socket->io_handle.data.fd); + aws_socket_close(socket); + } + struct posix_socket *socket_impl = socket->impl; + + if (!socket_impl->currently_in_event) { + aws_mem_release(socket->allocator, socket->impl); + } else { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: is still pending io letting it dangle and cleaning up later.", + (void *)socket, + socket->io_handle.data.fd); + socket_impl->clean_yourself_up = true; + } + + AWS_ZERO_STRUCT(*socket); + socket->io_handle.data.fd = -1; +} + +static void s_on_connection_error(struct aws_socket *socket, int error); + +static int s_on_connection_success(struct aws_socket *socket) { + + struct aws_event_loop *event_loop = socket->event_loop; + struct posix_socket *socket_impl = socket->impl; + + if (socket_impl->currently_subscribed) { + aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); + socket_impl->currently_subscribed = false; + } + + socket->event_loop = NULL; + + int connect_result; + socklen_t result_length = sizeof(connect_result); + + if (getsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_ERROR, &connect_result, &result_length) < 0) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: failed to determine connection error %d", + (void *)socket, + socket->io_handle.data.fd, + errno); + int aws_error = s_determine_socket_error(errno); + aws_raise_error(aws_error); + s_on_connection_error(socket, aws_error); + return AWS_OP_ERR; + } + + if (connect_result) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connection error %d", + (void *)socket, + socket->io_handle.data.fd, + connect_result); + int aws_error = s_determine_socket_error(connect_result); + aws_raise_error(aws_error); + s_on_connection_error(socket, aws_error); + return AWS_OP_ERR; + } + + AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection success", (void *)socket, socket->io_handle.data.fd); + + struct sockaddr_storage address; + AWS_ZERO_STRUCT(address); + socklen_t address_size = sizeof(address); + if (!getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size)) { + uint16_t port = 0; + + if (address.ss_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&address; + port = ntohs(s->sin_port); + /* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal + * once we add logging, we can log this if it fails. */ + if (inet_ntop( + AF_INET, &s->sin_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: local endpoint %s:%d", + (void *)socket, + socket->io_handle.data.fd, + socket->local_endpoint.address, + port); + } else { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: determining local endpoint failed", + (void *)socket, + socket->io_handle.data.fd); + } + } else if (address.ss_family == AF_INET6) { + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address; + port = ntohs(s->sin6_port); + /* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal + * once we add logging, we can log this if it fails. */ + if (inet_ntop( + AF_INET6, &s->sin6_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd %d: local endpoint %s:%d", + (void *)socket, + socket->io_handle.data.fd, + socket->local_endpoint.address, + port); + } else { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: determining local endpoint failed", + (void *)socket, + socket->io_handle.data.fd); + } + } + + socket->local_endpoint.port = port; + } else { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: getsockname() failed with error %d", + (void *)socket, + socket->io_handle.data.fd, + errno); + int aws_error = s_determine_socket_error(errno); + aws_raise_error(aws_error); + s_on_connection_error(socket, aws_error); + return AWS_OP_ERR; + } + + socket->state = CONNECTED_WRITE | CONNECTED_READ; + + if (aws_socket_assign_to_event_loop(socket, event_loop)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: assignment to event loop %p failed with error %d", + (void *)socket, + socket->io_handle.data.fd, + (void *)event_loop, + aws_last_error()); + s_on_connection_error(socket, aws_last_error()); + return AWS_OP_ERR; + } + + socket->connection_result_fn(socket, AWS_ERROR_SUCCESS, socket->connect_accept_user_data); + + return AWS_OP_SUCCESS; +} + +static void s_on_connection_error(struct aws_socket *socket, int error) { + socket->state = ERROR; + AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection failure", (void *)socket, socket->io_handle.data.fd); + if (socket->connection_result_fn) { + socket->connection_result_fn(socket, error, socket->connect_accept_user_data); + } else if (socket->accept_result_fn) { + socket->accept_result_fn(socket, error, NULL, socket->connect_accept_user_data); + } +} + +/* the next two callbacks compete based on which one runs first. if s_socket_connect_event + * comes back first, then we set socket_args->socket = NULL and continue on with the connection. + * if s_handle_socket_timeout() runs first, is sees socket_args->socket is NULL and just cleans up its memory. + * s_handle_socket_timeout() will always run so the memory for socket_connect_args is always cleaned up there. */ +static void s_socket_connect_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + + (void)event_loop; + (void)handle; + + struct posix_socket_connect_args *socket_args = (struct posix_socket_connect_args *)user_data; + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "fd=%d: connection activity handler triggered ", handle->data.fd); + + if (socket_args->socket) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: has not timed out yet proceeding with connection.", + (void *)socket_args->socket, + handle->data.fd); + + struct posix_socket *socket_impl = socket_args->socket->impl; + if (!(events & AWS_IO_EVENT_TYPE_ERROR || events & AWS_IO_EVENT_TYPE_CLOSED) && + (events & AWS_IO_EVENT_TYPE_READABLE || events & AWS_IO_EVENT_TYPE_WRITABLE)) { + struct aws_socket *socket = socket_args->socket; + socket_args->socket = NULL; + socket_impl->connect_args = NULL; + s_on_connection_success(socket); + return; + } + + int aws_error = aws_socket_get_error(socket_args->socket); + /* we'll get another notification. */ + if (aws_error == AWS_IO_READ_WOULD_BLOCK) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: spurious event, waiting for another notification.", + (void *)socket_args->socket, + handle->data.fd); + return; + } + + struct aws_socket *socket = socket_args->socket; + socket_args->socket = NULL; + socket_impl->connect_args = NULL; + aws_raise_error(aws_error); + s_on_connection_error(socket, aws_error); + } +} + +static void s_handle_socket_timeout(struct aws_task *task, void *args, aws_task_status status) { + (void)task; + (void)status; + + struct posix_socket_connect_args *socket_args = args; + + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "task_id=%p: timeout task triggered, evaluating timeouts.", (void *)task); + /* successful connection will have nulled out connect_args->socket */ + if (socket_args->socket) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: timed out, shutting down.", + (void *)socket_args->socket, + socket_args->socket->io_handle.data.fd); + + socket_args->socket->state = TIMEDOUT; + int error_code = AWS_IO_SOCKET_TIMEOUT; + + if (status == AWS_TASK_STATUS_RUN_READY) { + aws_event_loop_unsubscribe_from_io_events(socket_args->socket->event_loop, &socket_args->socket->io_handle); + } else { + error_code = AWS_IO_EVENT_LOOP_SHUTDOWN; + aws_event_loop_free_io_event_resources(socket_args->socket->event_loop, &socket_args->socket->io_handle); + } + socket_args->socket->event_loop = NULL; + struct posix_socket *socket_impl = socket_args->socket->impl; + socket_impl->currently_subscribed = false; + aws_raise_error(error_code); + struct aws_socket *socket = socket_args->socket; + /*socket close sets socket_args->socket to NULL and + * socket_impl->connect_args to NULL. */ + aws_socket_close(socket); + s_on_connection_error(socket, error_code); + } + + aws_mem_release(socket_args->allocator, socket_args); +} + +/* this is used simply for moving a connect_success callback when the connect finished immediately + * (like for unix domain sockets) into the event loop's thread. Also note, in that case there was no + * timeout task scheduled, so in this case the socket_args are cleaned up. */ +static void s_run_connect_success(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + struct posix_socket_connect_args *socket_args = arg; + + if (socket_args->socket) { + struct posix_socket *socket_impl = socket_args->socket->impl; + if (status == AWS_TASK_STATUS_RUN_READY) { + s_on_connection_success(socket_args->socket); + } else { + aws_raise_error(AWS_IO_SOCKET_CONNECT_ABORTED); + socket_args->socket->event_loop = NULL; + s_on_connection_error(socket_args->socket, AWS_IO_SOCKET_CONNECT_ABORTED); + } + socket_impl->connect_args = NULL; + } + + aws_mem_release(socket_args->allocator, socket_args); +} + +static inline int s_convert_pton_error(int pton_code) { + if (pton_code == 0) { + return AWS_IO_SOCKET_INVALID_ADDRESS; + } + + return s_determine_socket_error(errno); +} + +struct socket_address { + union sock_addr_types { + struct sockaddr_in addr_in; + struct sockaddr_in6 addr_in6; + struct sockaddr_un un_addr; +#ifdef USE_VSOCK + struct sockaddr_vm vm_addr; +#endif + } sock_addr_types; +}; + +#ifdef USE_VSOCK +/** Convert a string to a VSOCK CID. Respects the calling convetion of inet_pton: + * 0 on error, 1 on success. */ +static int parse_cid(const char *cid_str, unsigned int *value) { + if (cid_str == NULL || value == NULL) { + errno = EINVAL; + return 0; + } + /* strtoll returns 0 as both error and correct value */ + errno = 0; + /* unsigned long long to handle edge cases in convention explicitly */ + long long cid = strtoll(cid_str, NULL, 10); + if (errno != 0) { + return 0; + } + + /* -1U means any, so it's a valid value, but it needs to be converted to + * unsigned int. */ + if (cid == -1) { + *value = VMADDR_CID_ANY; + return 1; + } + + if (cid < 0 || cid > UINT_MAX) { + errno = ERANGE; + return 0; + } + + /* cast is safe here, edge cases already checked */ + *value = (unsigned int)cid; + return 1; +} +#endif + +int aws_socket_connect( + struct aws_socket *socket, + const struct aws_socket_endpoint *remote_endpoint, + struct aws_event_loop *event_loop, + aws_socket_on_connection_result_fn *on_connection_result, + void *user_data) { + AWS_ASSERT(event_loop); + AWS_ASSERT(!socket->event_loop); + + AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: beginning connect.", (void *)socket, socket->io_handle.data.fd); + + if (socket->event_loop) { + return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); + } + + if (socket->options.type != AWS_SOCKET_DGRAM) { + AWS_ASSERT(on_connection_result); + if (socket->state != INIT) { + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + } else { /* UDP socket */ + /* UDP sockets jump to CONNECT_READ if bind is called first */ + if (socket->state != CONNECTED_READ && socket->state != INIT) { + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + } + + size_t address_strlen; + if (aws_secure_strlen(remote_endpoint->address, AWS_ADDRESS_MAX_LEN, &address_strlen)) { + return AWS_OP_ERR; + } + + struct socket_address address; + AWS_ZERO_STRUCT(address); + socklen_t sock_size = 0; + int pton_err = 1; + if (socket->options.domain == AWS_SOCKET_IPV4) { + pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); + address.sock_addr_types.addr_in.sin_port = htons(remote_endpoint->port); + address.sock_addr_types.addr_in.sin_family = AF_INET; + sock_size = sizeof(address.sock_addr_types.addr_in); + } else if (socket->options.domain == AWS_SOCKET_IPV6) { + pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); + address.sock_addr_types.addr_in6.sin6_port = htons(remote_endpoint->port); + address.sock_addr_types.addr_in6.sin6_family = AF_INET6; + sock_size = sizeof(address.sock_addr_types.addr_in6); + } else if (socket->options.domain == AWS_SOCKET_LOCAL) { + address.sock_addr_types.un_addr.sun_family = AF_UNIX; + strncpy(address.sock_addr_types.un_addr.sun_path, remote_endpoint->address, AWS_ADDRESS_MAX_LEN); + sock_size = sizeof(address.sock_addr_types.un_addr); +#ifdef USE_VSOCK + } else if (socket->options.domain == AWS_SOCKET_VSOCK) { + pton_err = parse_cid(remote_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); + address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; + address.sock_addr_types.vm_addr.svm_port = (unsigned int)remote_endpoint->port; + sock_size = sizeof(address.sock_addr_types.vm_addr); +#endif + } else { + AWS_ASSERT(0); + return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY); + } + + if (pton_err != 1) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: failed to parse address %s:%d.", + (void *)socket, + socket->io_handle.data.fd, + remote_endpoint->address, + (int)remote_endpoint->port); + return aws_raise_error(s_convert_pton_error(pton_err)); + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connecting to endpoint %s:%d.", + (void *)socket, + socket->io_handle.data.fd, + remote_endpoint->address, + (int)remote_endpoint->port); + + socket->state = CONNECTING; + socket->remote_endpoint = *remote_endpoint; + socket->connect_accept_user_data = user_data; + socket->connection_result_fn = on_connection_result; + + struct posix_socket *socket_impl = socket->impl; + + socket_impl->connect_args = aws_mem_calloc(socket->allocator, 1, sizeof(struct posix_socket_connect_args)); + if (!socket_impl->connect_args) { + return AWS_OP_ERR; + } + + socket_impl->connect_args->socket = socket; + socket_impl->connect_args->allocator = socket->allocator; + + socket_impl->connect_args->task.fn = s_handle_socket_timeout; + socket_impl->connect_args->task.arg = socket_impl->connect_args; + + int error_code = connect(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size); + socket->event_loop = event_loop; + + if (!error_code) { + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connected immediately, not scheduling timeout.", + (void *)socket, + socket->io_handle.data.fd); + socket_impl->connect_args->task.fn = s_run_connect_success; + /* the subscription for IO will happen once we setup the connection in the task. Since we already + * know the connection succeeded, we don't need to register for events yet. */ + aws_event_loop_schedule_task_now(event_loop, &socket_impl->connect_args->task); + } + + if (error_code) { + error_code = errno; + if (error_code == EINPROGRESS || error_code == EALREADY) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connection pending waiting on event-loop notification or timeout.", + (void *)socket, + socket->io_handle.data.fd); + /* cache the timeout task; it is possible for the IO subscription to come back virtually immediately + * and null out the connect args */ + struct aws_task *timeout_task = &socket_impl->connect_args->task; + + socket_impl->currently_subscribed = true; + /* This event is for when the connection finishes. (the fd will flip writable). */ + if (aws_event_loop_subscribe_to_io_events( + event_loop, + &socket->io_handle, + AWS_IO_EVENT_TYPE_WRITABLE, + s_socket_connect_event, + socket_impl->connect_args)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: failed to register with event-loop %p.", + (void *)socket, + socket->io_handle.data.fd, + (void *)event_loop); + socket_impl->currently_subscribed = false; + socket->event_loop = NULL; + goto err_clean_up; + } + + /* schedule a task to run at the connect timeout interval, if this task runs before the connect + * happens, we consider that a timeout. */ + uint64_t timeout = 0; + aws_event_loop_current_clock_time(event_loop, &timeout); + timeout += aws_timestamp_convert( + socket->options.connect_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: scheduling timeout task for %llu.", + (void *)socket, + socket->io_handle.data.fd, + (unsigned long long)timeout); + aws_event_loop_schedule_task_future(event_loop, timeout_task, timeout); + } else { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connect failed with error code %d.", + (void *)socket, + socket->io_handle.data.fd, + error_code); + int aws_error = s_determine_socket_error(error_code); + aws_raise_error(aws_error); + socket->event_loop = NULL; + socket_impl->currently_subscribed = false; + goto err_clean_up; + } + } + return AWS_OP_SUCCESS; + +err_clean_up: + aws_mem_release(socket->allocator, socket_impl->connect_args); + socket_impl->connect_args = NULL; + return AWS_OP_ERR; +} + +int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint) { + if (socket->state != INIT) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: invalid state for bind operation.", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + + size_t address_strlen; + if (aws_secure_strlen(local_endpoint->address, AWS_ADDRESS_MAX_LEN, &address_strlen)) { + return AWS_OP_ERR; + } + + int error_code = -1; + + socket->local_endpoint = *local_endpoint; + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: binding to %s:%d.", + (void *)socket, + socket->io_handle.data.fd, + local_endpoint->address, + (int)local_endpoint->port); + + struct socket_address address; + AWS_ZERO_STRUCT(address); + socklen_t sock_size = 0; + int pton_err = 1; + if (socket->options.domain == AWS_SOCKET_IPV4) { + pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); + address.sock_addr_types.addr_in.sin_port = htons(local_endpoint->port); + address.sock_addr_types.addr_in.sin_family = AF_INET; + sock_size = sizeof(address.sock_addr_types.addr_in); + } else if (socket->options.domain == AWS_SOCKET_IPV6) { + pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); + address.sock_addr_types.addr_in6.sin6_port = htons(local_endpoint->port); + address.sock_addr_types.addr_in6.sin6_family = AF_INET6; + sock_size = sizeof(address.sock_addr_types.addr_in6); + } else if (socket->options.domain == AWS_SOCKET_LOCAL) { + address.sock_addr_types.un_addr.sun_family = AF_UNIX; + strncpy(address.sock_addr_types.un_addr.sun_path, local_endpoint->address, AWS_ADDRESS_MAX_LEN); + sock_size = sizeof(address.sock_addr_types.un_addr); +#ifdef USE_VSOCK + } else if (socket->options.domain == AWS_SOCKET_VSOCK) { + pton_err = parse_cid(local_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); + address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; + address.sock_addr_types.vm_addr.svm_port = (unsigned int)local_endpoint->port; + sock_size = sizeof(address.sock_addr_types.vm_addr); +#endif + } else { + AWS_ASSERT(0); + return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY); + } + + if (pton_err != 1) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: failed to parse address %s:%d.", + (void *)socket, + socket->io_handle.data.fd, + local_endpoint->address, + (int)local_endpoint->port); + return aws_raise_error(s_convert_pton_error(pton_err)); + } + + error_code = bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size); + + if (!error_code) { + if (socket->options.type == AWS_SOCKET_STREAM) { + socket->state = BOUND; + } else { + /* e.g. UDP is now readable */ + socket->state = CONNECTED_READ; + } + AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully bound", (void *)socket, socket->io_handle.data.fd); + + return AWS_OP_SUCCESS; + } + + socket->state = ERROR; + error_code = errno; + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: bind failed with error code %d", + (void *)socket, + socket->io_handle.data.fd, + error_code); + + int aws_error = s_determine_socket_error(error_code); + return aws_raise_error(aws_error); +} + +int aws_socket_listen(struct aws_socket *socket, int backlog_size) { + if (socket->state != BOUND) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: invalid state for listen operation. You must call bind first.", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + + int error_code = listen(socket->io_handle.data.fd, backlog_size); + + if (!error_code) { + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully listening", (void *)socket, socket->io_handle.data.fd); + socket->state = LISTENING; + return AWS_OP_SUCCESS; + } + + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: listen failed with error code %d", + (void *)socket, + socket->io_handle.data.fd, + error_code); + error_code = errno; + socket->state = ERROR; + + return aws_raise_error(s_determine_socket_error(error_code)); +} + +/* this is called by the event loop handler that was installed in start_accept(). It runs once the FD goes readable, + * accepts as many as it can and then returns control to the event loop. */ +static void s_socket_accept_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + + (void)event_loop; + + struct aws_socket *socket = user_data; + struct posix_socket *socket_impl = socket->impl; + + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, "id=%p fd=%d: listening event received", (void *)socket, socket->io_handle.data.fd); + + if (socket_impl->continue_accept && events & AWS_IO_EVENT_TYPE_READABLE) { + int in_fd = 0; + while (socket_impl->continue_accept && in_fd != -1) { + struct sockaddr_storage in_addr; + socklen_t in_len = sizeof(struct sockaddr_storage); + + in_fd = accept(handle->data.fd, (struct sockaddr *)&in_addr, &in_len); + if (in_fd == -1) { + int error = errno; + + if (error == EAGAIN || error == EWOULDBLOCK) { + break; + } + + int aws_error = aws_socket_get_error(socket); + aws_raise_error(aws_error); + s_on_connection_error(socket, aws_error); + break; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, "id=%p fd=%d: incoming connection", (void *)socket, socket->io_handle.data.fd); + + struct aws_socket *new_sock = aws_mem_acquire(socket->allocator, sizeof(struct aws_socket)); + + if (!new_sock) { + close(in_fd); + s_on_connection_error(socket, aws_last_error()); + continue; + } + + if (s_socket_init(new_sock, socket->allocator, &socket->options, in_fd)) { + aws_mem_release(socket->allocator, new_sock); + s_on_connection_error(socket, aws_last_error()); + continue; + } + + new_sock->local_endpoint = socket->local_endpoint; + new_sock->state = CONNECTED_READ | CONNECTED_WRITE; + uint16_t port = 0; + + /* get the info on the incoming socket's address */ + if (in_addr.ss_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&in_addr; + port = ntohs(s->sin_port); + /* this came from the kernel, a.) it won't fail. b.) even if it does + * its not fatal. come back and add logging later. */ + if (!inet_ntop( + AF_INET, + &s->sin_addr, + new_sock->remote_endpoint.address, + sizeof(new_sock->remote_endpoint.address))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d:. Failed to determine remote address.", + (void *)socket, + socket->io_handle.data.fd) + } + new_sock->options.domain = AWS_SOCKET_IPV4; + } else if (in_addr.ss_family == AF_INET6) { + /* this came from the kernel, a.) it won't fail. b.) even if it does + * its not fatal. come back and add logging later. */ + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&in_addr; + port = ntohs(s->sin6_port); + if (!inet_ntop( + AF_INET6, + &s->sin6_addr, + new_sock->remote_endpoint.address, + sizeof(new_sock->remote_endpoint.address))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d:. Failed to determine remote address.", + (void *)socket, + socket->io_handle.data.fd) + } + new_sock->options.domain = AWS_SOCKET_IPV6; + } else if (in_addr.ss_family == AF_UNIX) { + new_sock->remote_endpoint = socket->local_endpoint; + new_sock->options.domain = AWS_SOCKET_LOCAL; + } + + new_sock->remote_endpoint.port = port; + + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: connected to %s:%d, incoming fd %d", + (void *)socket, + socket->io_handle.data.fd, + new_sock->remote_endpoint.address, + new_sock->remote_endpoint.port, + in_fd); + + int flags = fcntl(in_fd, F_GETFL, 0); + + flags |= O_NONBLOCK | O_CLOEXEC; + fcntl(in_fd, F_SETFL, flags); + + bool close_occurred = false; + socket_impl->close_happened = &close_occurred; + socket->accept_result_fn(socket, AWS_ERROR_SUCCESS, new_sock, socket->connect_accept_user_data); + + if (close_occurred) { + return; + } + + socket_impl->close_happened = NULL; + } + } + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: finished processing incoming connections, " + "waiting on event-loop notification", + (void *)socket, + socket->io_handle.data.fd); +} + +int aws_socket_start_accept( + struct aws_socket *socket, + struct aws_event_loop *accept_loop, + aws_socket_on_accept_result_fn *on_accept_result, + void *user_data) { + AWS_ASSERT(on_accept_result); + AWS_ASSERT(accept_loop); + + if (socket->event_loop) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: is already assigned to event-loop %p.", + (void *)socket, + socket->io_handle.data.fd, + (void *)socket->event_loop); + return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); + } + + if (socket->state != LISTENING) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: invalid state for start_accept operation. You must call listen first.", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + + socket->accept_result_fn = on_accept_result; + socket->connect_accept_user_data = user_data; + socket->event_loop = accept_loop; + struct posix_socket *socket_impl = socket->impl; + socket_impl->continue_accept = true; + socket_impl->currently_subscribed = true; + + if (aws_event_loop_subscribe_to_io_events( + socket->event_loop, &socket->io_handle, AWS_IO_EVENT_TYPE_READABLE, s_socket_accept_event, socket)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: failed to subscribe to event-loop %p.", + (void *)socket, + socket->io_handle.data.fd, + (void *)socket->event_loop); + socket_impl->continue_accept = false; + socket_impl->currently_subscribed = false; + socket->event_loop = NULL; + + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +struct stop_accept_args { + struct aws_task task; + struct aws_mutex mutex; + struct aws_condition_variable condition_variable; + struct aws_socket *socket; + int ret_code; + bool invoked; +}; + +static bool s_stop_accept_pred(void *arg) { + struct stop_accept_args *stop_accept_args = arg; + return stop_accept_args->invoked; +} + +static void s_stop_accept_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + (void)status; + + struct stop_accept_args *stop_accept_args = arg; + aws_mutex_lock(&stop_accept_args->mutex); + stop_accept_args->ret_code = AWS_OP_SUCCESS; + if (aws_socket_stop_accept(stop_accept_args->socket)) { + stop_accept_args->ret_code = aws_last_error(); + } + stop_accept_args->invoked = true; + aws_condition_variable_notify_one(&stop_accept_args->condition_variable); + aws_mutex_unlock(&stop_accept_args->mutex); +} + +int aws_socket_stop_accept(struct aws_socket *socket) { + if (socket->state != LISTENING) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: is not in a listening state, can't stop_accept.", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, "id=%p fd=%d: stopping accepting new connections", (void *)socket, socket->io_handle.data.fd); + + if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { + struct stop_accept_args args = {.mutex = AWS_MUTEX_INIT, + .condition_variable = AWS_CONDITION_VARIABLE_INIT, + .invoked = false, + .socket = socket, + .ret_code = AWS_OP_SUCCESS, + .task = {.fn = s_stop_accept_task}}; + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: stopping accepting new connections from a different thread than " + "the socket is running from. Blocking until it shuts down.", + (void *)socket, + socket->io_handle.data.fd); + /* Look.... I know what I'm doing.... trust me, I'm an engineer. + * We wait on the completion before 'args' goes out of scope. + * NOLINTNEXTLINE */ + args.task.arg = &args; + aws_mutex_lock(&args.mutex); + aws_event_loop_schedule_task_now(socket->event_loop, &args.task); + aws_condition_variable_wait_pred(&args.condition_variable, &args.mutex, s_stop_accept_pred, &args); + aws_mutex_unlock(&args.mutex); + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: stop accept task finished running.", + (void *)socket, + socket->io_handle.data.fd); + + if (args.ret_code) { + return aws_raise_error(args.ret_code); + } + return AWS_OP_SUCCESS; + } + + int ret_val = AWS_OP_SUCCESS; + struct posix_socket *socket_impl = socket->impl; + if (socket_impl->currently_subscribed) { + ret_val = aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); + socket_impl->currently_subscribed = false; + socket_impl->continue_accept = false; + socket->event_loop = NULL; + } + + return ret_val; +} + +int aws_socket_set_options(struct aws_socket *socket, const struct aws_socket_options *options) { + if (socket->options.domain != options->domain || socket->options.type != options->type) { + return aws_raise_error(AWS_IO_SOCKET_INVALID_OPTIONS); + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setting socket options to: keep-alive %d, keep idle %d, keep-alive interval %d, keep-alive probe " + "count %d.", + (void *)socket, + socket->io_handle.data.fd, + (int)options->keepalive, + (int)options->keep_alive_timeout_sec, + (int)options->keep_alive_interval_sec, + (int)options->keep_alive_max_failed_probes); + + socket->options = *options; + + int option_value = 1; + if (AWS_UNLIKELY( + setsockopt(socket->io_handle.data.fd, SOL_SOCKET, NO_SIGNAL, &option_value, sizeof(option_value)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for NO_SIGNAL failed with errno %d. If you are having SIGPIPE signals thrown, " + "you may" + " want to install a signal trap in your application layer.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + + int reuse = 1; + if (AWS_UNLIKELY(setsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(int)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for SO_REUSEADDR failed with errno %d.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + + if (options->type == AWS_SOCKET_STREAM && options->domain != AWS_SOCKET_LOCAL) { + if (socket->options.keepalive) { + int keep_alive = 1; + if (AWS_UNLIKELY( + setsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_KEEPALIVE, &keep_alive, sizeof(int)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for enabling SO_KEEPALIVE failed with errno %d.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + } + + if (socket->options.keep_alive_interval_sec && socket->options.keep_alive_timeout_sec) { + int ival_in_secs = socket->options.keep_alive_interval_sec; + if (AWS_UNLIKELY(setsockopt( + socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPIDLE, &ival_in_secs, sizeof(ival_in_secs)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for enabling TCP_KEEPIDLE for TCP failed with errno %d.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + + ival_in_secs = socket->options.keep_alive_timeout_sec; + if (AWS_UNLIKELY(setsockopt( + socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPINTVL, &ival_in_secs, sizeof(ival_in_secs)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for enabling TCP_KEEPINTVL for TCP failed with errno %d.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + } + + if (socket->options.keep_alive_max_failed_probes) { + int max_probes = socket->options.keep_alive_max_failed_probes; + if (AWS_UNLIKELY( + setsockopt(socket->io_handle.data.fd, IPPROTO_TCP, TCP_KEEPCNT, &max_probes, sizeof(max_probes)))) { + AWS_LOGF_WARN( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: setsockopt() for enabling TCP_KEEPCNT for TCP failed with errno %d.", + (void *)socket, + socket->io_handle.data.fd, + errno); + } + } + } + + return AWS_OP_SUCCESS; +} + +struct write_request { + struct aws_byte_cursor cursor_cpy; + aws_socket_on_write_completed_fn *written_fn; + void *write_user_data; + struct aws_linked_list_node node; + size_t original_buffer_len; +}; + +struct posix_socket_close_args { + struct aws_mutex mutex; + struct aws_condition_variable condition_variable; + struct aws_socket *socket; + bool invoked; + int ret_code; +}; + +static bool s_close_predicate(void *arg) { + struct posix_socket_close_args *close_args = arg; + return close_args->invoked; +} + +static void s_close_task(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + (void)status; + + struct posix_socket_close_args *close_args = arg; + aws_mutex_lock(&close_args->mutex); + close_args->ret_code = AWS_OP_SUCCESS; + + if (aws_socket_close(close_args->socket)) { + close_args->ret_code = aws_last_error(); + } + + close_args->invoked = true; + aws_condition_variable_notify_one(&close_args->condition_variable); + aws_mutex_unlock(&close_args->mutex); +} + +int aws_socket_close(struct aws_socket *socket) { + struct posix_socket *socket_impl = socket->impl; + AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: closing", (void *)socket, socket->io_handle.data.fd); + if (socket->event_loop) { + /* don't freak out on me, this almost never happens, and never occurs inside a channel + * it only gets hit from a listening socket shutting down or from a unit test. */ + if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: closing from a different thread than " + "the socket is running from. Blocking until it closes down.", + (void *)socket, + socket->io_handle.data.fd); + /* the only time we allow this kind of thing is when you're a listener.*/ + if (socket->state != LISTENING) { + return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); + } + + struct posix_socket_close_args args = { + .mutex = AWS_MUTEX_INIT, + .condition_variable = AWS_CONDITION_VARIABLE_INIT, + .socket = socket, + .ret_code = AWS_OP_SUCCESS, + .invoked = false, + }; + + struct aws_task close_task = { + .fn = s_close_task, + .arg = &args, + }; + + aws_mutex_lock(&args.mutex); + aws_event_loop_schedule_task_now(socket->event_loop, &close_task); + aws_condition_variable_wait_pred(&args.condition_variable, &args.mutex, s_close_predicate, &args); + aws_mutex_unlock(&args.mutex); + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, "id=%p fd=%d: close task completed.", (void *)socket, socket->io_handle.data.fd); + if (args.ret_code) { + return aws_raise_error(args.ret_code); + } + + return AWS_OP_SUCCESS; + } + + if (socket_impl->currently_subscribed) { + if (socket->state & LISTENING) { + aws_socket_stop_accept(socket); + } else { + int err_code = aws_event_loop_unsubscribe_from_io_events(socket->event_loop, &socket->io_handle); + + if (err_code) { + return AWS_OP_ERR; + } + } + socket_impl->currently_subscribed = false; + socket->event_loop = NULL; + } + } + + if (socket_impl->close_happened) { + *socket_impl->close_happened = true; + } + + if (socket_impl->connect_args) { + socket_impl->connect_args->socket = NULL; + socket_impl->connect_args = NULL; + } + + if (aws_socket_is_open(socket)) { + close(socket->io_handle.data.fd); + socket->io_handle.data.fd = -1; + socket->state = CLOSED; + + /* after close, just go ahead and clear out the pending writes queue + * and tell the user they were cancelled. */ + while (!aws_linked_list_empty(&socket_impl->write_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&socket_impl->write_queue); + struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); + + write_request->written_fn( + socket, AWS_IO_SOCKET_CLOSED, write_request->original_buffer_len, write_request->write_user_data); + aws_mem_release(socket->allocator, write_request); + } + } + + return AWS_OP_SUCCESS; +} + +int aws_socket_shutdown_dir(struct aws_socket *socket, enum aws_channel_direction dir) { + int how = dir == AWS_CHANNEL_DIR_READ ? 0 : 1; + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, "id=%p fd=%d: shutting down in direction %d", (void *)socket, socket->io_handle.data.fd, dir); + if (shutdown(socket->io_handle.data.fd, how)) { + int aws_error = s_determine_socket_error(errno); + return aws_raise_error(aws_error); + } + + if (dir == AWS_CHANNEL_DIR_READ) { + socket->state &= ~CONNECTED_READ; + } else { + socket->state &= ~CONNECTED_WRITE; + } + + return AWS_OP_SUCCESS; +} + +/* this gets called in two scenarios. + * 1st scenario, someone called aws_socket_write() and we want to try writing now, so an error can be returned + * immediately if something bad has happened to the socket. In this case, `parent_request` is set. + * 2nd scenario, the event loop notified us that the socket went writable. In this case `parent_request` is NULL */ +static int s_process_write_requests(struct aws_socket *socket, struct write_request *parent_request) { + struct posix_socket *socket_impl = socket->impl; + struct aws_allocator *allocator = socket->allocator; + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p fd=%d: processing write requests.", (void *)socket, socket->io_handle.data.fd); + + /* there's a potential deadlock where we notify the user that we wrote some data, the user + * says, "cool, now I can write more and then immediately calls aws_socket_write(). We need to make sure + * that we don't allow reentrancy in that case. */ + socket_impl->write_in_progress = true; + + if (parent_request) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: processing write requests, called from aws_socket_write", + (void *)socket, + socket->io_handle.data.fd); + socket_impl->currently_in_event = true; + } else { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: processing write requests, invoked by the event-loop", + (void *)socket, + socket->io_handle.data.fd); + } + + bool purge = false; + int aws_error = AWS_OP_SUCCESS; + bool parent_request_failed = false; + + /* if a close call happens in the middle, this queue will have been cleaned out from under us. */ + while (!aws_linked_list_empty(&socket_impl->write_queue)) { + struct aws_linked_list_node *node = aws_linked_list_front(&socket_impl->write_queue); + struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: dequeued write request of size %llu, remaining to write %llu", + (void *)socket, + socket->io_handle.data.fd, + (unsigned long long)write_request->original_buffer_len, + (unsigned long long)write_request->cursor_cpy.len); + + ssize_t written = + send(socket->io_handle.data.fd, write_request->cursor_cpy.ptr, write_request->cursor_cpy.len, NO_SIGNAL); + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: send written size %d", + (void *)socket, + socket->io_handle.data.fd, + (int)written); + + if (written < 0) { + int error = errno; + if (error == EAGAIN) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p fd=%d: returned would block", (void *)socket, socket->io_handle.data.fd); + break; + } + + if (error == EPIPE) { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: already closed before write", + (void *)socket, + socket->io_handle.data.fd); + aws_error = AWS_IO_SOCKET_CLOSED; + aws_raise_error(aws_error); + purge = true; + break; + } + + purge = true; + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: write error with error code %d", + (void *)socket, + socket->io_handle.data.fd, + error); + aws_error = s_determine_socket_error(error); + aws_raise_error(aws_error); + break; + } + + size_t remaining_to_write = write_request->cursor_cpy.len; + + aws_byte_cursor_advance(&write_request->cursor_cpy, (size_t)written); + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: remaining write request to write %llu", + (void *)socket, + socket->io_handle.data.fd, + (unsigned long long)write_request->cursor_cpy.len); + + if ((size_t)written == remaining_to_write) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p fd=%d: write request completed", (void *)socket, socket->io_handle.data.fd); + + aws_linked_list_remove(node); + write_request->written_fn( + socket, AWS_OP_SUCCESS, write_request->original_buffer_len, write_request->write_user_data); + aws_mem_release(allocator, write_request); + } + } + + if (purge) { + while (!aws_linked_list_empty(&socket_impl->write_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&socket_impl->write_queue); + struct write_request *write_request = AWS_CONTAINER_OF(node, struct write_request, node); + + /* If this fn was invoked directly from aws_socket_write(), don't invoke the error callback + * as the user will be able to rely on the return value from aws_socket_write() */ + if (write_request == parent_request) { + parent_request_failed = true; + } else { + write_request->written_fn(socket, aws_error, 0, write_request->write_user_data); + } + + aws_mem_release(socket->allocator, write_request); + } + } + + socket_impl->write_in_progress = false; + + if (parent_request) { + socket_impl->currently_in_event = false; + } + + if (socket_impl->clean_yourself_up) { + aws_mem_release(allocator, socket_impl); + } + + /* Only report error if aws_socket_write() invoked this function and its write_request failed */ + if (!parent_request_failed) { + return AWS_OP_SUCCESS; + } + + aws_raise_error(aws_error); + return AWS_OP_ERR; +} + +static void s_on_socket_io_event( + struct aws_event_loop *event_loop, + struct aws_io_handle *handle, + int events, + void *user_data) { + (void)event_loop; + (void)handle; + /* this is to handle a race condition when an error kicks off a cleanup, or the user decides + * to close the socket based on something they read (SSL validation failed for example). + * if clean_up happens when currently_in_event is true, socket_impl is kept dangling but currently + * subscribed is set to false. */ + struct aws_socket *socket = user_data; + struct posix_socket *socket_impl = socket->impl; + struct aws_allocator *allocator = socket->allocator; + + socket_impl->currently_in_event = true; + + if (events & AWS_IO_EVENT_TYPE_REMOTE_HANG_UP || events & AWS_IO_EVENT_TYPE_CLOSED) { + aws_raise_error(AWS_IO_SOCKET_CLOSED); + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: closed remotely", (void *)socket, socket->io_handle.data.fd); + if (socket->readable_fn) { + socket->readable_fn(socket, AWS_IO_SOCKET_CLOSED, socket->readable_user_data); + } + goto end_check; + } + + if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_ERROR) { + int aws_error = aws_socket_get_error(socket); + aws_raise_error(aws_error); + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p fd=%d: error event occurred", (void *)socket, socket->io_handle.data.fd); + if (socket->readable_fn) { + socket->readable_fn(socket, aws_error, socket->readable_user_data); + } + goto end_check; + } + + if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_READABLE) { + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: is readable", (void *)socket, socket->io_handle.data.fd); + if (socket->readable_fn) { + socket->readable_fn(socket, AWS_OP_SUCCESS, socket->readable_user_data); + } + } + /* if socket closed in between these branches, the currently_subscribed will be false and socket_impl will not + * have been cleaned up, so this next branch is safe. */ + if (socket_impl->currently_subscribed && events & AWS_IO_EVENT_TYPE_WRITABLE) { + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: is writable", (void *)socket, socket->io_handle.data.fd); + s_process_write_requests(socket, NULL); + } + +end_check: + socket_impl->currently_in_event = false; + + if (socket_impl->clean_yourself_up) { + aws_mem_release(allocator, socket_impl); + } +} + +int aws_socket_assign_to_event_loop(struct aws_socket *socket, struct aws_event_loop *event_loop) { + if (!socket->event_loop) { + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: assigning to event loop %p", + (void *)socket, + socket->io_handle.data.fd, + (void *)event_loop); + socket->event_loop = event_loop; + struct posix_socket *socket_impl = socket->impl; + socket_impl->currently_subscribed = true; + if (aws_event_loop_subscribe_to_io_events( + event_loop, + &socket->io_handle, + AWS_IO_EVENT_TYPE_WRITABLE | AWS_IO_EVENT_TYPE_READABLE, + s_on_socket_io_event, + socket)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: assigning to event loop %p failed with error %d", + (void *)socket, + socket->io_handle.data.fd, + (void *)event_loop, + aws_last_error()); + socket_impl->currently_subscribed = false; + socket->event_loop = NULL; + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; + } + + return aws_raise_error(AWS_IO_EVENT_LOOP_ALREADY_ASSIGNED); +} + +struct aws_event_loop *aws_socket_get_event_loop(struct aws_socket *socket) { + return socket->event_loop; +} + +int aws_socket_subscribe_to_readable_events( + struct aws_socket *socket, + aws_socket_on_readable_fn *on_readable, + void *user_data) { + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, " id=%p fd=%d: subscribing to readable events", (void *)socket, socket->io_handle.data.fd); + if (!(socket->state & CONNECTED_READ)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: can't subscribe to readable events since the socket is not connected", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); + } + + if (socket->readable_fn) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: can't subscribe to readable events since it is already subscribed", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_ERROR_IO_ALREADY_SUBSCRIBED); + } + + AWS_ASSERT(on_readable); + socket->readable_user_data = user_data; + socket->readable_fn = on_readable; + + return AWS_OP_SUCCESS; +} + +int aws_socket_read(struct aws_socket *socket, struct aws_byte_buf *buffer, size_t *amount_read) { + AWS_ASSERT(amount_read); + + if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: cannot read from a different thread than event loop %p", + (void *)socket, + socket->io_handle.data.fd, + (void *)socket->event_loop); + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + if (!(socket->state & CONNECTED_READ)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: cannot read because it is not connected", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); + } + + ssize_t read_val = read(socket->io_handle.data.fd, buffer->buffer + buffer->len, buffer->capacity - buffer->len); + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p fd=%d: read of %d", (void *)socket, socket->io_handle.data.fd, (int)read_val); + + if (read_val > 0) { + *amount_read = (size_t)read_val; + buffer->len += *amount_read; + return AWS_OP_SUCCESS; + } + + /* read_val of 0 means EOF which we'll treat as AWS_IO_SOCKET_CLOSED */ + if (read_val == 0) { + AWS_LOGF_INFO( + AWS_LS_IO_SOCKET, "id=%p fd=%d: zero read, socket is closed", (void *)socket, socket->io_handle.data.fd); + *amount_read = 0; + + if (buffer->capacity - buffer->len > 0) { + return aws_raise_error(AWS_IO_SOCKET_CLOSED); + } + + return AWS_OP_SUCCESS; + } + + int error = errno; +#if defined(EWOULDBLOCK) + if (error == EAGAIN || error == EWOULDBLOCK) { +#else + if (error == EAGAIN) { +#endif + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "id=%p fd=%d: read would block", (void *)socket, socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_READ_WOULD_BLOCK); + } + + if (error == EPIPE) { + AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: socket is closed.", (void *)socket, socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_CLOSED); + } + + if (error == ETIMEDOUT) { + AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "id=%p fd=%d: socket timed out.", (void *)socket, socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_TIMEOUT); + } + + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: read failed with error: %s", + (void *)socket, + socket->io_handle.data.fd, + strerror(error)); + return aws_raise_error(AWS_ERROR_SYS_CALL_FAILURE); +} + +int aws_socket_write( + struct aws_socket *socket, + const struct aws_byte_cursor *cursor, + aws_socket_on_write_completed_fn *written_fn, + void *user_data) { + if (!aws_event_loop_thread_is_callers_thread(socket->event_loop)) { + return aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + } + + if (!(socket->state & CONNECTED_WRITE)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p fd=%d: cannot write to because it is not connected", + (void *)socket, + socket->io_handle.data.fd); + return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); + } + + AWS_ASSERT(written_fn); + struct posix_socket *socket_impl = socket->impl; + struct write_request *write_request = aws_mem_calloc(socket->allocator, 1, sizeof(struct write_request)); + + if (!write_request) { + return AWS_OP_ERR; + } + + write_request->original_buffer_len = cursor->len; + write_request->written_fn = written_fn; + write_request->write_user_data = user_data; + write_request->cursor_cpy = *cursor; + aws_linked_list_push_back(&socket_impl->write_queue, &write_request->node); + + /* avoid reentrancy when a user calls write after receiving their completion callback. */ + if (!socket_impl->write_in_progress) { + return s_process_write_requests(socket, write_request); + } + + return AWS_OP_SUCCESS; +} + +int aws_socket_get_error(struct aws_socket *socket) { + int connect_result; + socklen_t result_length = sizeof(connect_result); + + if (getsockopt(socket->io_handle.data.fd, SOL_SOCKET, SO_ERROR, &connect_result, &result_length) < 0) { + return AWS_OP_ERR; + } + + if (connect_result) { + return s_determine_socket_error(connect_result); + } + + return AWS_OP_SUCCESS; +} + +bool aws_socket_is_open(struct aws_socket *socket) { + return socket->io_handle.data.fd >= 0; +} diff --git a/contrib/restricted/aws/aws-c-io/source/retry_strategy.c b/contrib/restricted/aws/aws-c-io/source/retry_strategy.c index 69d444482b..0df0823fe4 100644 --- a/contrib/restricted/aws/aws-c-io/source/retry_strategy.c +++ b/contrib/restricted/aws/aws-c-io/source/retry_strategy.c @@ -1,57 +1,57 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/retry_strategy.h> - -void aws_retry_strategy_acquire(struct aws_retry_strategy *retry_strategy) { - aws_atomic_fetch_add_explicit(&retry_strategy->ref_count, 1, aws_memory_order_relaxed); -} - -void aws_retry_strategy_release(struct aws_retry_strategy *retry_strategy) { - size_t ref_count = aws_atomic_fetch_sub_explicit(&retry_strategy->ref_count, 1, aws_memory_order_seq_cst); - - if (ref_count == 1) { - retry_strategy->vtable->destroy(retry_strategy); - } -} - -int aws_retry_strategy_acquire_retry_token( - struct aws_retry_strategy *retry_strategy, - const struct aws_byte_cursor *partition_id, - aws_retry_strategy_on_retry_token_acquired_fn *on_acquired, - void *user_data, - uint64_t timeout_ms) { - AWS_PRECONDITION(retry_strategy); - AWS_PRECONDITION(retry_strategy->vtable->acquire_token); - return retry_strategy->vtable->acquire_token(retry_strategy, partition_id, on_acquired, user_data, timeout_ms); -} - -int aws_retry_strategy_schedule_retry( - struct aws_retry_token *token, - enum aws_retry_error_type error_type, - aws_retry_strategy_on_retry_ready_fn *retry_ready, - void *user_data) { - AWS_PRECONDITION(token); - AWS_PRECONDITION(token->retry_strategy); - AWS_PRECONDITION(token->retry_strategy->vtable->schedule_retry); - - return token->retry_strategy->vtable->schedule_retry(token, error_type, retry_ready, user_data); -} - -int aws_retry_strategy_token_record_success(struct aws_retry_token *token) { - AWS_PRECONDITION(token); - AWS_PRECONDITION(token->retry_strategy); - AWS_PRECONDITION(token->retry_strategy->vtable->record_success); - - return token->retry_strategy->vtable->record_success(token); -} - -void aws_retry_strategy_release_retry_token(struct aws_retry_token *token) { - if (token) { - AWS_PRECONDITION(token->retry_strategy); - AWS_PRECONDITION(token->retry_strategy->vtable->release_token); - - token->retry_strategy->vtable->release_token(token); - } -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/retry_strategy.h> + +void aws_retry_strategy_acquire(struct aws_retry_strategy *retry_strategy) { + aws_atomic_fetch_add_explicit(&retry_strategy->ref_count, 1, aws_memory_order_relaxed); +} + +void aws_retry_strategy_release(struct aws_retry_strategy *retry_strategy) { + size_t ref_count = aws_atomic_fetch_sub_explicit(&retry_strategy->ref_count, 1, aws_memory_order_seq_cst); + + if (ref_count == 1) { + retry_strategy->vtable->destroy(retry_strategy); + } +} + +int aws_retry_strategy_acquire_retry_token( + struct aws_retry_strategy *retry_strategy, + const struct aws_byte_cursor *partition_id, + aws_retry_strategy_on_retry_token_acquired_fn *on_acquired, + void *user_data, + uint64_t timeout_ms) { + AWS_PRECONDITION(retry_strategy); + AWS_PRECONDITION(retry_strategy->vtable->acquire_token); + return retry_strategy->vtable->acquire_token(retry_strategy, partition_id, on_acquired, user_data, timeout_ms); +} + +int aws_retry_strategy_schedule_retry( + struct aws_retry_token *token, + enum aws_retry_error_type error_type, + aws_retry_strategy_on_retry_ready_fn *retry_ready, + void *user_data) { + AWS_PRECONDITION(token); + AWS_PRECONDITION(token->retry_strategy); + AWS_PRECONDITION(token->retry_strategy->vtable->schedule_retry); + + return token->retry_strategy->vtable->schedule_retry(token, error_type, retry_ready, user_data); +} + +int aws_retry_strategy_token_record_success(struct aws_retry_token *token) { + AWS_PRECONDITION(token); + AWS_PRECONDITION(token->retry_strategy); + AWS_PRECONDITION(token->retry_strategy->vtable->record_success); + + return token->retry_strategy->vtable->record_success(token); +} + +void aws_retry_strategy_release_retry_token(struct aws_retry_token *token) { + if (token) { + AWS_PRECONDITION(token->retry_strategy); + AWS_PRECONDITION(token->retry_strategy->vtable->release_token); + + token->retry_strategy->vtable->release_token(token); + } +} diff --git a/contrib/restricted/aws/aws-c-io/source/s2n/s2n_tls_channel_handler.c b/contrib/restricted/aws/aws-c-io/source/s2n/s2n_tls_channel_handler.c index 9300125423..2aa1dbd82b 100644 --- a/contrib/restricted/aws/aws-c-io/source/s2n/s2n_tls_channel_handler.c +++ b/contrib/restricted/aws/aws-c-io/source/s2n/s2n_tls_channel_handler.c @@ -1,1093 +1,1093 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/tls_channel_handler.h> - -#include <aws/io/channel.h> -#include <aws/io/event_loop.h> -#include <aws/io/file_utils.h> -#include <aws/io/logging.h> -#include <aws/io/pki_utils.h> -#include <aws/io/private/tls_channel_handler_shared.h> -#include <aws/io/statistics.h> - -#include <aws/common/encoding.h> -#include <aws/common/string.h> -#include <aws/common/task_scheduler.h> -#include <aws/common/thread.h> - -#include <errno.h> -#include <inttypes.h> -#include <math.h> -#include <s2n.h> -#include <stdio.h> -#include <stdlib.h> - -#define EST_TLS_RECORD_OVERHEAD 53 /* 5 byte header + 32 + 16 bytes for padding */ -#define KB_1 1024 -#define MAX_RECORD_SIZE (KB_1 * 16) -#define EST_HANDSHAKE_SIZE (7 * KB_1) - -static const char *s_default_ca_dir = NULL; -static const char *s_default_ca_file = NULL; - -struct s2n_handler { - struct aws_channel_handler handler; - struct aws_tls_channel_handler_shared shared_state; - struct s2n_connection *connection; - struct aws_channel_slot *slot; - struct aws_linked_list input_queue; - struct aws_byte_buf protocol; - struct aws_byte_buf server_name; - aws_channel_on_message_write_completed_fn *latest_message_on_completion; - struct aws_channel_task sequential_tasks; - void *latest_message_completion_user_data; - aws_tls_on_negotiation_result_fn *on_negotiation_result; - aws_tls_on_data_read_fn *on_data_read; - aws_tls_on_error_fn *on_error; - void *user_data; - bool advertise_alpn_message; - bool negotiation_finished; -}; - -struct s2n_ctx { - struct aws_tls_ctx ctx; - struct s2n_config *s2n_config; -}; - -static const char *s_determine_default_pki_dir(void) { - /* debian variants */ - if (aws_path_exists("/etc/ssl/certs")) { - return "/etc/ssl/certs"; - } - - /* RHEL variants */ - if (aws_path_exists("/etc/pki/tls/certs")) { - return "/etc/pki/tls/certs"; - } - - /* android */ - if (aws_path_exists("/system/etc/security/cacerts")) { - return "/system/etc/security/cacerts"; - } - - /* Free BSD */ - if (aws_path_exists("/usr/local/share/certs")) { - return "/usr/local/share/certs"; - } - - /* Net BSD */ - if (aws_path_exists("/etc/openssl/certs")) { - return "/etc/openssl/certs"; - } - - return NULL; -} - -static const char *s_determine_default_pki_ca_file(void) { - /* debian variants */ - if (aws_path_exists("/etc/ssl/certs/ca-certificates.crt")) { - return "/etc/ssl/certs/ca-certificates.crt"; - } - - /* Old RHEL variants */ - if (aws_path_exists("/etc/pki/tls/certs/ca-bundle.crt")) { - return "/etc/pki/tls/certs/ca-bundle.crt"; - } - - /* Open SUSE */ - if (aws_path_exists("/etc/ssl/ca-bundle.pem")) { - return "/etc/ssl/ca-bundle.pem"; - } - - /* Open ELEC */ - if (aws_path_exists("/etc/pki/tls/cacert.pem")) { - return "/etc/pki/tls/cacert.pem"; - } - - /* Modern RHEL variants */ - if (aws_path_exists("/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem")) { - return "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem"; - } - - return NULL; -} - -void aws_tls_init_static_state(struct aws_allocator *alloc) { - (void)alloc; - AWS_LOGF_INFO(AWS_LS_IO_TLS, "static: Initializing TLS using s2n."); - - setenv("S2N_ENABLE_CLIENT_MODE", "1", 1); - setenv("S2N_DONT_MLOCK", "1", 1); - s2n_init(); - - s_default_ca_dir = s_determine_default_pki_dir(); - s_default_ca_file = s_determine_default_pki_ca_file(); - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, - "ctx: Based on OS, we detected the default PKI path as %s, and ca file as %s", - s_default_ca_dir, - s_default_ca_file); -} - -void aws_tls_clean_up_static_state(void) { - s2n_cleanup(); -} - -bool aws_tls_is_alpn_available(void) { - return true; -} - -bool aws_tls_is_cipher_pref_supported(enum aws_tls_cipher_pref cipher_pref) { - switch (cipher_pref) { - case AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT: - return true; - /* PQ Crypto no-ops on android for now */ -#ifndef ANDROID - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2019_06: - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2019_11: - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_02: - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2020_02: - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_07: - return true; -#endif - - default: - return false; - } -} - -static int s_generic_read(struct s2n_handler *handler, struct aws_byte_buf *buf) { - - size_t written = 0; - - while (!aws_linked_list_empty(&handler->input_queue) && written < buf->len) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&handler->input_queue); - struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); - - size_t remaining_message_len = message->message_data.len - message->copy_mark; - size_t remaining_buf_len = buf->len - written; - - size_t to_write = remaining_message_len < remaining_buf_len ? remaining_message_len : remaining_buf_len; - - struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data); - aws_byte_cursor_advance(&message_cursor, message->copy_mark); - aws_byte_cursor_read(&message_cursor, buf->buffer + written, to_write); - - written += to_write; - - message->copy_mark += to_write; - - if (message->copy_mark == message->message_data.len) { - aws_mem_release(message->allocator, message); - } else { - aws_linked_list_push_front(&handler->input_queue, &message->queueing_handle); - } - } - - if (written) { - return (int)written; - } - - errno = EAGAIN; - return -1; -} - -static int s_s2n_handler_recv(void *io_context, uint8_t *buf, uint32_t len) { - struct s2n_handler *handler = (struct s2n_handler *)io_context; - - struct aws_byte_buf read_buffer = aws_byte_buf_from_array(buf, len); - return s_generic_read(handler, &read_buffer); -} - -static int s_generic_send(struct s2n_handler *handler, struct aws_byte_buf *buf) { - - struct aws_byte_cursor buffer_cursor = aws_byte_cursor_from_buf(buf); - - size_t processed = 0; - while (processed < buf->len) { - struct aws_io_message *message = aws_channel_acquire_message_from_pool( - handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, buf->len - processed); - - if (!message) { - errno = ENOMEM; - return -1; - } - - const size_t overhead = aws_channel_slot_upstream_message_overhead(handler->slot); - const size_t available_msg_write_capacity = buffer_cursor.len - overhead; - - const size_t to_write = message->message_data.capacity > available_msg_write_capacity - ? available_msg_write_capacity - : message->message_data.capacity; - - struct aws_byte_cursor chunk = aws_byte_cursor_advance(&buffer_cursor, to_write); - if (aws_byte_buf_append(&message->message_data, &chunk)) { - aws_mem_release(message->allocator, message); - return -1; - } - processed += message->message_data.len; - - if (processed == buf->len) { - message->on_completion = handler->latest_message_on_completion; - message->user_data = handler->latest_message_completion_user_data; - handler->latest_message_on_completion = NULL; - handler->latest_message_completion_user_data = NULL; - } - - if (aws_channel_slot_send_message(handler->slot, message, AWS_CHANNEL_DIR_WRITE)) { - aws_mem_release(message->allocator, message); - errno = EPIPE; - return -1; - } - } - - if (processed) { - return (int)processed; - } - - errno = EAGAIN; - return -1; -} - -static int s_s2n_handler_send(void *io_context, const uint8_t *buf, uint32_t len) { - struct s2n_handler *handler = (struct s2n_handler *)io_context; - struct aws_byte_buf send_buf = aws_byte_buf_from_array(buf, len); - - return s_generic_send(handler, &send_buf); -} - -static void s_s2n_handler_destroy(struct aws_channel_handler *handler) { - if (handler) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - aws_tls_channel_handler_shared_clean_up(&s2n_handler->shared_state); - s2n_connection_free(s2n_handler->connection); - aws_mem_release(handler->alloc, (void *)s2n_handler); - } -} - -static void s_on_negotiation_result( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - int error_code, - void *user_data) { - - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - - aws_on_tls_negotiation_completed(&s2n_handler->shared_state, error_code); - - if (s2n_handler->on_negotiation_result) { - s2n_handler->on_negotiation_result(handler, slot, error_code, user_data); - } -} - -static int s_drive_negotiation(struct aws_channel_handler *handler) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - - aws_on_drive_tls_negotiation(&s2n_handler->shared_state); - - s2n_blocked_status blocked = S2N_NOT_BLOCKED; - do { - int negotiation_code = s2n_negotiate(s2n_handler->connection, &blocked); - - int s2n_error = s2n_errno; - if (negotiation_code == S2N_ERR_T_OK) { - s2n_handler->negotiation_finished = true; - - const char *protocol = s2n_get_application_protocol(s2n_handler->connection); - if (protocol) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Alpn protocol negotiated as %s", (void *)handler, protocol); - s2n_handler->protocol = aws_byte_buf_from_c_str(protocol); - } - - const char *server_name = s2n_get_server_name(s2n_handler->connection); - - if (server_name) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Remote server name is %s", (void *)handler, server_name); - s2n_handler->server_name = aws_byte_buf_from_c_str(server_name); - } - - if (s2n_handler->slot->adj_right && s2n_handler->advertise_alpn_message && protocol) { - struct aws_io_message *message = aws_channel_acquire_message_from_pool( - s2n_handler->slot->channel, - AWS_IO_MESSAGE_APPLICATION_DATA, - sizeof(struct aws_tls_negotiated_protocol_message)); - message->message_tag = AWS_TLS_NEGOTIATED_PROTOCOL_MESSAGE; - struct aws_tls_negotiated_protocol_message *protocol_message = - (struct aws_tls_negotiated_protocol_message *)message->message_data.buffer; - - protocol_message->protocol = s2n_handler->protocol; - message->message_data.len = sizeof(struct aws_tls_negotiated_protocol_message); - if (aws_channel_slot_send_message(s2n_handler->slot, message, AWS_CHANNEL_DIR_READ)) { - aws_mem_release(message->allocator, message); - aws_channel_shutdown(s2n_handler->slot->channel, aws_last_error()); - return AWS_OP_SUCCESS; - } - } - - s_on_negotiation_result(handler, s2n_handler->slot, AWS_OP_SUCCESS, s2n_handler->user_data); - - break; - } - if (s2n_error_get_type(s2n_error) != S2N_ERR_T_BLOCKED) { - AWS_LOGF_WARN( - AWS_LS_IO_TLS, - "id=%p: negotiation failed with error %s (%s)", - (void *)handler, - s2n_strerror(s2n_error, "EN"), - s2n_strerror_debug(s2n_error, "EN")); - - if (s2n_error_get_type(s2n_error) == S2N_ERR_T_ALERT) { - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, - "id=%p: Alert code %d", - (void *)handler, - s2n_connection_get_alert(s2n_handler->connection)); - } - - const char *err_str = s2n_strerror_debug(s2n_error, NULL); - (void)err_str; - s2n_handler->negotiation_finished = false; - - aws_raise_error(AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); - - s_on_negotiation_result( - handler, s2n_handler->slot, AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE, s2n_handler->user_data); - - return AWS_OP_ERR; - } - } while (blocked == S2N_NOT_BLOCKED); - - return AWS_OP_SUCCESS; -} - -static void s_negotiation_task(struct aws_channel_task *task, void *arg, aws_task_status status) { - task->task_fn = NULL; - task->arg = NULL; - - if (status == AWS_TASK_STATUS_RUN_READY) { - struct aws_channel_handler *handler = arg; - s_drive_negotiation(handler); - } -} - -int aws_tls_client_handler_start_negotiation(struct aws_channel_handler *handler) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - - AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Kicking off TLS negotiation.", (void *)handler) - if (aws_channel_thread_is_callers_thread(s2n_handler->slot->channel)) { - return s_drive_negotiation(handler); - } - - aws_channel_task_init( - &s2n_handler->sequential_tasks, s_negotiation_task, handler, "s2n_channel_handler_negotiation"); - aws_channel_schedule_task_now(s2n_handler->slot->channel, &s2n_handler->sequential_tasks); - - return AWS_OP_SUCCESS; -} - -static int s_s2n_handler_process_read_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - - struct s2n_handler *s2n_handler = handler->impl; - - if (message) { - aws_linked_list_push_back(&s2n_handler->input_queue, &message->queueing_handle); - - if (!s2n_handler->negotiation_finished) { - size_t message_len = message->message_data.len; - if (!s_drive_negotiation(handler)) { - aws_channel_slot_increment_read_window(slot, message_len); - } else { - aws_channel_shutdown(s2n_handler->slot->channel, AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); - } - return AWS_OP_SUCCESS; - } - } - - s2n_blocked_status blocked = S2N_NOT_BLOCKED; - size_t downstream_window = SIZE_MAX; - if (slot->adj_right) { - downstream_window = aws_channel_slot_downstream_read_window(slot); - } - - size_t processed = 0; - AWS_LOGF_TRACE( - AWS_LS_IO_TLS, "id=%p: Downstream window %llu", (void *)handler, (unsigned long long)downstream_window); - - while (processed < downstream_window && blocked == S2N_NOT_BLOCKED) { - - struct aws_io_message *outgoing_read_message = aws_channel_acquire_message_from_pool( - slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, downstream_window - processed); - if (!outgoing_read_message) { - return AWS_OP_ERR; - } - - ssize_t read = s2n_recv( - s2n_handler->connection, - outgoing_read_message->message_data.buffer, - outgoing_read_message->message_data.capacity, - &blocked); - - AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Bytes read %lld", (void *)handler, (long long)read); - - /* weird race where we received an alert from the peer, but s2n doesn't tell us about it..... - * if this happens, it's a graceful shutdown, so kick it off here. - * - * In other words, s2n, upon graceful shutdown, follows the unix EOF idiom. So just shutdown with - * SUCCESS. - */ - if (read == 0) { - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, - "id=%p: Alert code %d", - (void *)handler, - s2n_connection_get_alert(s2n_handler->connection)); - aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - aws_channel_shutdown(slot->channel, AWS_OP_SUCCESS); - return AWS_OP_SUCCESS; - } - - if (read < 0) { - aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - continue; - }; - - processed += read; - outgoing_read_message->message_data.len = (size_t)read; - - if (s2n_handler->on_data_read) { - s2n_handler->on_data_read(handler, slot, &outgoing_read_message->message_data, s2n_handler->user_data); - } - - if (slot->adj_right) { - aws_channel_slot_send_message(slot, outgoing_read_message, AWS_CHANNEL_DIR_READ); - } else { - aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - } - } - - AWS_LOGF_TRACE( - AWS_LS_IO_TLS, - "id=%p: Remaining window for this event-loop tick: %llu", - (void *)handler, - (unsigned long long)downstream_window - processed); - - return AWS_OP_SUCCESS; -} - -static int s_s2n_handler_process_write_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - (void)slot; - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - - if (AWS_UNLIKELY(!s2n_handler->negotiation_finished)) { - return aws_raise_error(AWS_IO_TLS_ERROR_NOT_NEGOTIATED); - } - - s2n_handler->latest_message_on_completion = message->on_completion; - s2n_handler->latest_message_completion_user_data = message->user_data; - - s2n_blocked_status blocked; - ssize_t write_code = - s2n_send(s2n_handler->connection, message->message_data.buffer, (ssize_t)message->message_data.len, &blocked); - - AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Bytes written: %llu", (void *)handler, (unsigned long long)write_code); - - ssize_t message_len = (ssize_t)message->message_data.len; - - if (write_code < message_len) { - return aws_raise_error(AWS_IO_TLS_ERROR_WRITE_FAILURE); - } - - aws_mem_release(message->allocator, message); - - return AWS_OP_SUCCESS; -} - -static int s_s2n_handler_shutdown( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int error_code, - bool abort_immediately) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - - if (dir == AWS_CHANNEL_DIR_WRITE) { - if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Shutting down write direction", (void *)handler) - s2n_blocked_status blocked; - /* make a best effort, but the channel is going away after this run, so.... you only get one shot anyways */ - s2n_shutdown(s2n_handler->connection, &blocked); - } - } else { - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, "id=%p: Shutting down read direction with error code %d", (void *)handler, error_code); - - while (!aws_linked_list_empty(&s2n_handler->input_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&s2n_handler->input_queue); - struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); - aws_mem_release(message->allocator, message); - } - } - - return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); -} - -static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status) { - task->task_fn = NULL; - task->arg = NULL; - - if (status == AWS_TASK_STATUS_RUN_READY) { - struct aws_channel_handler *handler = (struct aws_channel_handler *)arg; - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - s_s2n_handler_process_read_message(handler, s2n_handler->slot, NULL); - } -} - -static int s_s2n_handler_increment_read_window( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - size_t size) { - (void)size; - struct s2n_handler *s2n_handler = handler->impl; - - size_t downstream_size = aws_channel_slot_downstream_read_window(slot); - size_t current_window_size = slot->window_size; - - AWS_LOGF_TRACE( - AWS_LS_IO_TLS, "id=%p: Increment read window message received %llu", (void *)handler, (unsigned long long)size); - - size_t likely_records_count = (size_t)ceil((double)(downstream_size) / (double)(MAX_RECORD_SIZE)); - size_t offset_size = aws_mul_size_saturating(likely_records_count, EST_TLS_RECORD_OVERHEAD); - size_t total_desired_size = aws_add_size_saturating(offset_size, downstream_size); - - if (total_desired_size > current_window_size) { - size_t window_update_size = total_desired_size - current_window_size; - AWS_LOGF_TRACE( - AWS_LS_IO_TLS, - "id=%p: Propagating read window increment of size %llu", - (void *)handler, - (unsigned long long)window_update_size); - aws_channel_slot_increment_read_window(slot, window_update_size); - } - - if (s2n_handler->negotiation_finished && !s2n_handler->sequential_tasks.node.next) { - /* TLS requires full records before it can decrypt anything. As a result we need to check everything we've - * buffered instead of just waiting on a read from the socket, or we'll hit a deadlock. - * - * We have messages in a queue and they need to be run after the socket has popped (even if it didn't have data - * to read). Alternatively, s2n reads entire records at a time, so we'll need to grab whatever we can and we - * have no idea what's going on inside there. So we need to attempt another read.*/ - aws_channel_task_init( - &s2n_handler->sequential_tasks, s_run_read, handler, "s2n_channel_handler_read_on_window_increment"); - aws_channel_schedule_task_now(slot->channel, &s2n_handler->sequential_tasks); - } - - return AWS_OP_SUCCESS; -} - -static size_t s_s2n_handler_message_overhead(struct aws_channel_handler *handler) { - (void)handler; - return EST_TLS_RECORD_OVERHEAD; -} - -static size_t s_s2n_handler_initial_window_size(struct aws_channel_handler *handler) { - (void)handler; - - return EST_HANDSHAKE_SIZE; -} - -static void s_s2n_handler_reset_statistics(struct aws_channel_handler *handler) { - struct s2n_handler *s2n_handler = handler->impl; - - aws_crt_statistics_tls_reset(&s2n_handler->shared_state.stats); -} - -static void s_s2n_handler_gather_statistics(struct aws_channel_handler *handler, struct aws_array_list *stats) { - struct s2n_handler *s2n_handler = handler->impl; - - void *stats_base = &s2n_handler->shared_state.stats; - aws_array_list_push_back(stats, &stats_base); -} - -struct aws_byte_buf aws_tls_handler_protocol(struct aws_channel_handler *handler) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - return s2n_handler->protocol; -} - -struct aws_byte_buf aws_tls_handler_server_name(struct aws_channel_handler *handler) { - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - return s2n_handler->server_name; -} - -static struct aws_channel_handler_vtable s_handler_vtable = { - .destroy = s_s2n_handler_destroy, - .process_read_message = s_s2n_handler_process_read_message, - .process_write_message = s_s2n_handler_process_write_message, - .shutdown = s_s2n_handler_shutdown, - .increment_read_window = s_s2n_handler_increment_read_window, - .initial_window_size = s_s2n_handler_initial_window_size, - .message_overhead = s_s2n_handler_message_overhead, - .reset_statistics = s_s2n_handler_reset_statistics, - .gather_statistics = s_s2n_handler_gather_statistics, -}; - -static int s_parse_protocol_preferences( - struct aws_string *alpn_list_str, - const char protocol_output[4][128], - size_t *protocol_count) { - size_t max_count = *protocol_count; - *protocol_count = 0; - - struct aws_byte_cursor alpn_list_buffer[4]; - AWS_ZERO_ARRAY(alpn_list_buffer); - struct aws_array_list alpn_list; - struct aws_byte_cursor user_alpn_str = aws_byte_cursor_from_string(alpn_list_str); - - aws_array_list_init_static(&alpn_list, alpn_list_buffer, 4, sizeof(struct aws_byte_cursor)); - - if (aws_byte_cursor_split_on_char(&user_alpn_str, ';', &alpn_list)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - return AWS_OP_ERR; - } - - size_t protocols_list_len = aws_array_list_length(&alpn_list); - if (protocols_list_len < 1) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - return AWS_OP_ERR; - } - - for (size_t i = 0; i < protocols_list_len && i < max_count; ++i) { - struct aws_byte_cursor cursor; - AWS_ZERO_STRUCT(cursor); - if (aws_array_list_get_at(&alpn_list, (void *)&cursor, (size_t)i)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - return AWS_OP_ERR; - } - AWS_FATAL_ASSERT(cursor.ptr && cursor.len > 0); - memcpy((void *)protocol_output[i], cursor.ptr, cursor.len); - *protocol_count += 1; - } - - return AWS_OP_SUCCESS; -} - -static size_t s_tl_cleanup_key = 0; /* Address of variable serves as key in hash table */ - -/* - * This local object is added to the table of every event loop that has a (s2n) tls connection - * added to it at some point in time - */ -static struct aws_event_loop_local_object s_tl_cleanup_object = {.key = &s_tl_cleanup_key, - .object = NULL, - .on_object_removed = NULL}; - -static void s_aws_cleanup_s2n_thread_local_state(void *user_data) { - (void)user_data; - - s2n_cleanup(); -} - -/* s2n allocates thread-local data structures. We need to clean these up when the event loop's thread exits. */ -static int s_s2n_tls_channel_handler_schedule_thread_local_cleanup(struct aws_channel_slot *slot) { - struct aws_channel *channel = slot->channel; - - struct aws_event_loop_local_object existing_marker; - AWS_ZERO_STRUCT(existing_marker); - - /* - * Check whether another s2n_tls_channel_handler has already scheduled the cleanup task. - */ - if (aws_channel_fetch_local_object(channel, &s_tl_cleanup_key, &existing_marker)) { - /* Doesn't exist in event loop table: add it and add the at-exit cleanup callback */ - if (aws_channel_put_local_object(channel, &s_tl_cleanup_key, &s_tl_cleanup_object)) { - return AWS_OP_ERR; - } - - aws_thread_current_at_exit(s_aws_cleanup_s2n_thread_local_state, NULL); - } - - return AWS_OP_SUCCESS; -} - -static struct aws_channel_handler *s_new_tls_handler( - struct aws_allocator *allocator, - struct aws_tls_connection_options *options, - struct aws_channel_slot *slot, - s2n_mode mode) { - - AWS_ASSERT(options->ctx); - struct s2n_handler *s2n_handler = aws_mem_calloc(allocator, 1, sizeof(struct s2n_handler)); - if (!s2n_handler) { - return NULL; - } - - struct s2n_ctx *s2n_ctx = (struct s2n_ctx *)options->ctx->impl; - s2n_handler->connection = s2n_connection_new(mode); - - if (!s2n_handler->connection) { - goto cleanup_s2n_handler; - } - - aws_tls_channel_handler_shared_init(&s2n_handler->shared_state, &s2n_handler->handler, options); - - s2n_handler->handler.impl = s2n_handler; - s2n_handler->handler.alloc = allocator; - s2n_handler->handler.vtable = &s_handler_vtable; - s2n_handler->handler.slot = slot; - s2n_handler->user_data = options->user_data; - s2n_handler->on_data_read = options->on_data_read; - s2n_handler->on_error = options->on_error; - s2n_handler->on_negotiation_result = options->on_negotiation_result; - s2n_handler->advertise_alpn_message = options->advertise_alpn_message; - - s2n_handler->latest_message_completion_user_data = NULL; - s2n_handler->latest_message_on_completion = NULL; - s2n_handler->slot = slot; - aws_linked_list_init(&s2n_handler->input_queue); - - s2n_handler->protocol = aws_byte_buf_from_array(NULL, 0); - - if (options->server_name) { - - if (s2n_set_server_name(s2n_handler->connection, aws_string_c_str(options->server_name))) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_conn; - } - } - - s2n_handler->negotiation_finished = false; - - s2n_connection_set_recv_cb(s2n_handler->connection, s_s2n_handler_recv); - s2n_connection_set_recv_ctx(s2n_handler->connection, s2n_handler); - s2n_connection_set_send_cb(s2n_handler->connection, s_s2n_handler_send); - s2n_connection_set_send_ctx(s2n_handler->connection, s2n_handler); - s2n_connection_set_blinding(s2n_handler->connection, S2N_SELF_SERVICE_BLINDING); - - if (options->alpn_list) { - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, - "id=%p: Setting ALPN list %s", - (void *)&s2n_handler->handler, - aws_string_c_str(options->alpn_list)); - - const char protocols_cpy[4][128]; - AWS_ZERO_ARRAY(protocols_cpy); - size_t protocols_size = 4; - if (s_parse_protocol_preferences(options->alpn_list, protocols_cpy, &protocols_size)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_conn; - } - - const char *protocols[4]; - AWS_ZERO_ARRAY(protocols); - for (size_t i = 0; i < protocols_size; ++i) { - protocols[i] = protocols_cpy[i]; - } - - if (s2n_connection_set_protocol_preferences( - s2n_handler->connection, (const char *const *)protocols, (int)protocols_size)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_conn; - } - } - - if (s2n_connection_set_config(s2n_handler->connection, s2n_ctx->s2n_config)) { - AWS_LOGF_WARN( - AWS_LS_IO_TLS, - "id=%p: configuration error %s (%s)", - (void *)&s2n_handler->handler, - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_conn; - } - - if (s_s2n_tls_channel_handler_schedule_thread_local_cleanup(slot)) { - goto cleanup_conn; - } - - return &s2n_handler->handler; - -cleanup_conn: - s2n_connection_free(s2n_handler->connection); - -cleanup_s2n_handler: - aws_mem_release(allocator, s2n_handler); - - return NULL; -} - -struct aws_channel_handler *aws_tls_client_handler_new( - struct aws_allocator *allocator, - struct aws_tls_connection_options *options, - struct aws_channel_slot *slot) { - - return s_new_tls_handler(allocator, options, slot, S2N_CLIENT); -} - -struct aws_channel_handler *aws_tls_server_handler_new( - struct aws_allocator *allocator, - struct aws_tls_connection_options *options, - struct aws_channel_slot *slot) { - - return s_new_tls_handler(allocator, options, slot, S2N_SERVER); -} - -static void s_s2n_ctx_destroy(struct s2n_ctx *s2n_ctx) { - if (s2n_ctx != NULL) { - s2n_config_free(s2n_ctx->s2n_config); - aws_mem_release(s2n_ctx->ctx.alloc, s2n_ctx); - } -} - -static struct aws_tls_ctx *s_tls_ctx_new( - struct aws_allocator *alloc, - const struct aws_tls_ctx_options *options, - s2n_mode mode) { - struct s2n_ctx *s2n_ctx = aws_mem_calloc(alloc, 1, sizeof(struct s2n_ctx)); - - if (!s2n_ctx) { - return NULL; - } - - if (!aws_tls_is_cipher_pref_supported(options->cipher_pref)) { - aws_raise_error(AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED); - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: TLS Cipher Preference is not supported: %d.", options->cipher_pref); - return NULL; - } - - s2n_ctx->ctx.alloc = alloc; - s2n_ctx->ctx.impl = s2n_ctx; - aws_ref_count_init(&s2n_ctx->ctx.ref_count, s2n_ctx, (aws_simple_completion_callback *)s_s2n_ctx_destroy); - s2n_ctx->s2n_config = s2n_config_new(); - - if (!s2n_ctx->s2n_config) { - goto cleanup_s2n_ctx; - } - - switch (options->minimum_tls_version) { - case AWS_IO_SSLv3: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "CloudFront-SSL-v-3"); - break; - case AWS_IO_TLSv1: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "CloudFront-TLS-1-0-2014"); - break; - case AWS_IO_TLSv1_1: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-1-2017-01"); - break; - case AWS_IO_TLSv1_2: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-2-Ext-2018-06"); - break; - case AWS_IO_TLSv1_3: - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "TLS 1.3 is not supported yet."); - /* sorry guys, we'll add this as soon as s2n does. */ - aws_raise_error(AWS_IO_TLS_VERSION_UNSUPPORTED); - goto cleanup_s2n_ctx; - case AWS_IO_TLS_VER_SYS_DEFAULTS: - default: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-1-2017-01"); - } - - switch (options->cipher_pref) { - case AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT: - /* No-Op, if the user configured a minimum_tls_version then a version-specific Cipher Preference was set */ - break; - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2019_06: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2019-06"); - break; - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2019_11: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "PQ-SIKE-TEST-TLS-1-0-2019-11"); - break; - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_02: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2020-02"); - break; - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2020_02: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "PQ-SIKE-TEST-TLS-1-0-2020-02"); - break; - case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_07: - s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2020-07"); - break; - default: - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Unrecognized TLS Cipher Preference: %d", options->cipher_pref); - aws_raise_error(AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED); - goto cleanup_s2n_ctx; - } - - if (options->certificate.len && options->private_key.len) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "ctx: Certificate and key have been set, setting them up now."); - - if (!aws_text_is_utf8(options->certificate.buffer, options->certificate.len)) { - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: failed to import certificate, must be ASCII/UTF-8 encoded"); - aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); - goto cleanup_s2n_ctx; - } - - if (!aws_text_is_utf8(options->private_key.buffer, options->private_key.len)) { - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: failed to import private key, must be ASCII/UTF-8 encoded"); - aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); - goto cleanup_s2n_ctx; - } - - int err_code = s2n_config_add_cert_chain_and_key( - s2n_ctx->s2n_config, (const char *)options->certificate.buffer, (const char *)options->private_key.buffer); - - if (mode == S2N_CLIENT) { - s2n_config_set_client_auth_type(s2n_ctx->s2n_config, S2N_CERT_AUTH_REQUIRED); - } - - if (err_code != S2N_ERR_T_OK) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: configuration error %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (options->verify_peer) { - if (s2n_config_set_check_stapled_ocsp_response(s2n_ctx->s2n_config, 1) == S2N_SUCCESS) { - if (s2n_config_set_status_request_type(s2n_ctx->s2n_config, S2N_STATUS_REQUEST_OCSP) != S2N_SUCCESS) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: ocsp status request cannot be set: %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } else { - if (s2n_error_get_type(s2n_errno) == S2N_ERR_T_USAGE) { - AWS_LOGF_INFO(AWS_LS_IO_TLS, "ctx: cannot enable ocsp stapling: %s", s2n_strerror(s2n_errno, "EN")); - } else { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: cannot enable ocsp stapling: %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (options->ca_path) { - if (s2n_config_set_verification_ca_location( - s2n_ctx->s2n_config, NULL, aws_string_c_str(options->ca_path))) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: configuration error %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Failed to set ca_path %s\n", aws_string_c_str(options->ca_path)); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (options->ca_file.len) { - if (s2n_config_add_pem_to_trust_store(s2n_ctx->s2n_config, (const char *)options->ca_file.buffer)) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: configuration error %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Failed to set ca_file %s\n", (const char *)options->ca_file.buffer); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (!options->ca_path && !options->ca_file.len) { - if (s2n_config_set_verification_ca_location(s2n_ctx->s2n_config, s_default_ca_file, s_default_ca_dir)) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: configuration error %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, "Failed to set ca_path: %s and ca_file %s\n", s_default_ca_dir, s_default_ca_file); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (mode == S2N_SERVER && s2n_config_set_client_auth_type(s2n_ctx->s2n_config, S2N_CERT_AUTH_REQUIRED)) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "ctx: configuration error %s (%s)", - s2n_strerror(s2n_errno, "EN"), - s2n_strerror_debug(s2n_errno, "EN")); - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } else if (mode != S2N_SERVER) { - AWS_LOGF_WARN( - AWS_LS_IO_TLS, - "ctx: X.509 validation has been disabled. " - "If this is not running in a test environment, this is likely a security vulnerability."); - if (s2n_config_disable_x509_verification(s2n_ctx->s2n_config)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (options->alpn_list) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "ctx: Setting ALPN list %s", aws_string_c_str(options->alpn_list)); - const char protocols_cpy[4][128]; - AWS_ZERO_ARRAY(protocols_cpy); - size_t protocols_size = 4; - if (s_parse_protocol_preferences(options->alpn_list, protocols_cpy, &protocols_size)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - - const char *protocols[4]; - AWS_ZERO_ARRAY(protocols); - for (size_t i = 0; i < protocols_size; ++i) { - protocols[i] = protocols_cpy[i]; - } - - if (s2n_config_set_protocol_preferences(s2n_ctx->s2n_config, protocols, (int)protocols_size)) { - aws_raise_error(AWS_IO_TLS_CTX_ERROR); - goto cleanup_s2n_config; - } - } - - if (options->max_fragment_size == 512) { - s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_512); - } else if (options->max_fragment_size == 1024) { - s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_1024); - } else if (options->max_fragment_size == 2048) { - s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_2048); - } else if (options->max_fragment_size == 4096) { - s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_4096); - } - - return &s2n_ctx->ctx; - -cleanup_s2n_config: - s2n_config_free(s2n_ctx->s2n_config); - -cleanup_s2n_ctx: - aws_mem_release(alloc, s2n_ctx); - - return NULL; -} - -struct aws_tls_ctx *aws_tls_server_ctx_new(struct aws_allocator *alloc, const struct aws_tls_ctx_options *options) { - aws_io_fatal_assert_library_initialized(); - return s_tls_ctx_new(alloc, options, S2N_SERVER); -} - -struct aws_tls_ctx *aws_tls_client_ctx_new(struct aws_allocator *alloc, const struct aws_tls_ctx_options *options) { - aws_io_fatal_assert_library_initialized(); - return s_tls_ctx_new(alloc, options, S2N_CLIENT); -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/tls_channel_handler.h> + +#include <aws/io/channel.h> +#include <aws/io/event_loop.h> +#include <aws/io/file_utils.h> +#include <aws/io/logging.h> +#include <aws/io/pki_utils.h> +#include <aws/io/private/tls_channel_handler_shared.h> +#include <aws/io/statistics.h> + +#include <aws/common/encoding.h> +#include <aws/common/string.h> +#include <aws/common/task_scheduler.h> +#include <aws/common/thread.h> + +#include <errno.h> +#include <inttypes.h> +#include <math.h> +#include <s2n.h> +#include <stdio.h> +#include <stdlib.h> + +#define EST_TLS_RECORD_OVERHEAD 53 /* 5 byte header + 32 + 16 bytes for padding */ +#define KB_1 1024 +#define MAX_RECORD_SIZE (KB_1 * 16) +#define EST_HANDSHAKE_SIZE (7 * KB_1) + +static const char *s_default_ca_dir = NULL; +static const char *s_default_ca_file = NULL; + +struct s2n_handler { + struct aws_channel_handler handler; + struct aws_tls_channel_handler_shared shared_state; + struct s2n_connection *connection; + struct aws_channel_slot *slot; + struct aws_linked_list input_queue; + struct aws_byte_buf protocol; + struct aws_byte_buf server_name; + aws_channel_on_message_write_completed_fn *latest_message_on_completion; + struct aws_channel_task sequential_tasks; + void *latest_message_completion_user_data; + aws_tls_on_negotiation_result_fn *on_negotiation_result; + aws_tls_on_data_read_fn *on_data_read; + aws_tls_on_error_fn *on_error; + void *user_data; + bool advertise_alpn_message; + bool negotiation_finished; +}; + +struct s2n_ctx { + struct aws_tls_ctx ctx; + struct s2n_config *s2n_config; +}; + +static const char *s_determine_default_pki_dir(void) { + /* debian variants */ + if (aws_path_exists("/etc/ssl/certs")) { + return "/etc/ssl/certs"; + } + + /* RHEL variants */ + if (aws_path_exists("/etc/pki/tls/certs")) { + return "/etc/pki/tls/certs"; + } + + /* android */ + if (aws_path_exists("/system/etc/security/cacerts")) { + return "/system/etc/security/cacerts"; + } + + /* Free BSD */ + if (aws_path_exists("/usr/local/share/certs")) { + return "/usr/local/share/certs"; + } + + /* Net BSD */ + if (aws_path_exists("/etc/openssl/certs")) { + return "/etc/openssl/certs"; + } + + return NULL; +} + +static const char *s_determine_default_pki_ca_file(void) { + /* debian variants */ + if (aws_path_exists("/etc/ssl/certs/ca-certificates.crt")) { + return "/etc/ssl/certs/ca-certificates.crt"; + } + + /* Old RHEL variants */ + if (aws_path_exists("/etc/pki/tls/certs/ca-bundle.crt")) { + return "/etc/pki/tls/certs/ca-bundle.crt"; + } + + /* Open SUSE */ + if (aws_path_exists("/etc/ssl/ca-bundle.pem")) { + return "/etc/ssl/ca-bundle.pem"; + } + + /* Open ELEC */ + if (aws_path_exists("/etc/pki/tls/cacert.pem")) { + return "/etc/pki/tls/cacert.pem"; + } + + /* Modern RHEL variants */ + if (aws_path_exists("/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem")) { + return "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem"; + } + + return NULL; +} + +void aws_tls_init_static_state(struct aws_allocator *alloc) { + (void)alloc; + AWS_LOGF_INFO(AWS_LS_IO_TLS, "static: Initializing TLS using s2n."); + + setenv("S2N_ENABLE_CLIENT_MODE", "1", 1); + setenv("S2N_DONT_MLOCK", "1", 1); + s2n_init(); + + s_default_ca_dir = s_determine_default_pki_dir(); + s_default_ca_file = s_determine_default_pki_ca_file(); + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "ctx: Based on OS, we detected the default PKI path as %s, and ca file as %s", + s_default_ca_dir, + s_default_ca_file); +} + +void aws_tls_clean_up_static_state(void) { + s2n_cleanup(); +} + +bool aws_tls_is_alpn_available(void) { + return true; +} + +bool aws_tls_is_cipher_pref_supported(enum aws_tls_cipher_pref cipher_pref) { + switch (cipher_pref) { + case AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT: + return true; + /* PQ Crypto no-ops on android for now */ +#ifndef ANDROID + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2019_06: + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2019_11: + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_02: + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2020_02: + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_07: + return true; +#endif + + default: + return false; + } +} + +static int s_generic_read(struct s2n_handler *handler, struct aws_byte_buf *buf) { + + size_t written = 0; + + while (!aws_linked_list_empty(&handler->input_queue) && written < buf->len) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&handler->input_queue); + struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); + + size_t remaining_message_len = message->message_data.len - message->copy_mark; + size_t remaining_buf_len = buf->len - written; + + size_t to_write = remaining_message_len < remaining_buf_len ? remaining_message_len : remaining_buf_len; + + struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data); + aws_byte_cursor_advance(&message_cursor, message->copy_mark); + aws_byte_cursor_read(&message_cursor, buf->buffer + written, to_write); + + written += to_write; + + message->copy_mark += to_write; + + if (message->copy_mark == message->message_data.len) { + aws_mem_release(message->allocator, message); + } else { + aws_linked_list_push_front(&handler->input_queue, &message->queueing_handle); + } + } + + if (written) { + return (int)written; + } + + errno = EAGAIN; + return -1; +} + +static int s_s2n_handler_recv(void *io_context, uint8_t *buf, uint32_t len) { + struct s2n_handler *handler = (struct s2n_handler *)io_context; + + struct aws_byte_buf read_buffer = aws_byte_buf_from_array(buf, len); + return s_generic_read(handler, &read_buffer); +} + +static int s_generic_send(struct s2n_handler *handler, struct aws_byte_buf *buf) { + + struct aws_byte_cursor buffer_cursor = aws_byte_cursor_from_buf(buf); + + size_t processed = 0; + while (processed < buf->len) { + struct aws_io_message *message = aws_channel_acquire_message_from_pool( + handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, buf->len - processed); + + if (!message) { + errno = ENOMEM; + return -1; + } + + const size_t overhead = aws_channel_slot_upstream_message_overhead(handler->slot); + const size_t available_msg_write_capacity = buffer_cursor.len - overhead; + + const size_t to_write = message->message_data.capacity > available_msg_write_capacity + ? available_msg_write_capacity + : message->message_data.capacity; + + struct aws_byte_cursor chunk = aws_byte_cursor_advance(&buffer_cursor, to_write); + if (aws_byte_buf_append(&message->message_data, &chunk)) { + aws_mem_release(message->allocator, message); + return -1; + } + processed += message->message_data.len; + + if (processed == buf->len) { + message->on_completion = handler->latest_message_on_completion; + message->user_data = handler->latest_message_completion_user_data; + handler->latest_message_on_completion = NULL; + handler->latest_message_completion_user_data = NULL; + } + + if (aws_channel_slot_send_message(handler->slot, message, AWS_CHANNEL_DIR_WRITE)) { + aws_mem_release(message->allocator, message); + errno = EPIPE; + return -1; + } + } + + if (processed) { + return (int)processed; + } + + errno = EAGAIN; + return -1; +} + +static int s_s2n_handler_send(void *io_context, const uint8_t *buf, uint32_t len) { + struct s2n_handler *handler = (struct s2n_handler *)io_context; + struct aws_byte_buf send_buf = aws_byte_buf_from_array(buf, len); + + return s_generic_send(handler, &send_buf); +} + +static void s_s2n_handler_destroy(struct aws_channel_handler *handler) { + if (handler) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + aws_tls_channel_handler_shared_clean_up(&s2n_handler->shared_state); + s2n_connection_free(s2n_handler->connection); + aws_mem_release(handler->alloc, (void *)s2n_handler); + } +} + +static void s_on_negotiation_result( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int error_code, + void *user_data) { + + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + + aws_on_tls_negotiation_completed(&s2n_handler->shared_state, error_code); + + if (s2n_handler->on_negotiation_result) { + s2n_handler->on_negotiation_result(handler, slot, error_code, user_data); + } +} + +static int s_drive_negotiation(struct aws_channel_handler *handler) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + + aws_on_drive_tls_negotiation(&s2n_handler->shared_state); + + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + do { + int negotiation_code = s2n_negotiate(s2n_handler->connection, &blocked); + + int s2n_error = s2n_errno; + if (negotiation_code == S2N_ERR_T_OK) { + s2n_handler->negotiation_finished = true; + + const char *protocol = s2n_get_application_protocol(s2n_handler->connection); + if (protocol) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Alpn protocol negotiated as %s", (void *)handler, protocol); + s2n_handler->protocol = aws_byte_buf_from_c_str(protocol); + } + + const char *server_name = s2n_get_server_name(s2n_handler->connection); + + if (server_name) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Remote server name is %s", (void *)handler, server_name); + s2n_handler->server_name = aws_byte_buf_from_c_str(server_name); + } + + if (s2n_handler->slot->adj_right && s2n_handler->advertise_alpn_message && protocol) { + struct aws_io_message *message = aws_channel_acquire_message_from_pool( + s2n_handler->slot->channel, + AWS_IO_MESSAGE_APPLICATION_DATA, + sizeof(struct aws_tls_negotiated_protocol_message)); + message->message_tag = AWS_TLS_NEGOTIATED_PROTOCOL_MESSAGE; + struct aws_tls_negotiated_protocol_message *protocol_message = + (struct aws_tls_negotiated_protocol_message *)message->message_data.buffer; + + protocol_message->protocol = s2n_handler->protocol; + message->message_data.len = sizeof(struct aws_tls_negotiated_protocol_message); + if (aws_channel_slot_send_message(s2n_handler->slot, message, AWS_CHANNEL_DIR_READ)) { + aws_mem_release(message->allocator, message); + aws_channel_shutdown(s2n_handler->slot->channel, aws_last_error()); + return AWS_OP_SUCCESS; + } + } + + s_on_negotiation_result(handler, s2n_handler->slot, AWS_OP_SUCCESS, s2n_handler->user_data); + + break; + } + if (s2n_error_get_type(s2n_error) != S2N_ERR_T_BLOCKED) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "id=%p: negotiation failed with error %s (%s)", + (void *)handler, + s2n_strerror(s2n_error, "EN"), + s2n_strerror_debug(s2n_error, "EN")); + + if (s2n_error_get_type(s2n_error) == S2N_ERR_T_ALERT) { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: Alert code %d", + (void *)handler, + s2n_connection_get_alert(s2n_handler->connection)); + } + + const char *err_str = s2n_strerror_debug(s2n_error, NULL); + (void)err_str; + s2n_handler->negotiation_finished = false; + + aws_raise_error(AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); + + s_on_negotiation_result( + handler, s2n_handler->slot, AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE, s2n_handler->user_data); + + return AWS_OP_ERR; + } + } while (blocked == S2N_NOT_BLOCKED); + + return AWS_OP_SUCCESS; +} + +static void s_negotiation_task(struct aws_channel_task *task, void *arg, aws_task_status status) { + task->task_fn = NULL; + task->arg = NULL; + + if (status == AWS_TASK_STATUS_RUN_READY) { + struct aws_channel_handler *handler = arg; + s_drive_negotiation(handler); + } +} + +int aws_tls_client_handler_start_negotiation(struct aws_channel_handler *handler) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + + AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Kicking off TLS negotiation.", (void *)handler) + if (aws_channel_thread_is_callers_thread(s2n_handler->slot->channel)) { + return s_drive_negotiation(handler); + } + + aws_channel_task_init( + &s2n_handler->sequential_tasks, s_negotiation_task, handler, "s2n_channel_handler_negotiation"); + aws_channel_schedule_task_now(s2n_handler->slot->channel, &s2n_handler->sequential_tasks); + + return AWS_OP_SUCCESS; +} + +static int s_s2n_handler_process_read_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + + struct s2n_handler *s2n_handler = handler->impl; + + if (message) { + aws_linked_list_push_back(&s2n_handler->input_queue, &message->queueing_handle); + + if (!s2n_handler->negotiation_finished) { + size_t message_len = message->message_data.len; + if (!s_drive_negotiation(handler)) { + aws_channel_slot_increment_read_window(slot, message_len); + } else { + aws_channel_shutdown(s2n_handler->slot->channel, AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); + } + return AWS_OP_SUCCESS; + } + } + + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + size_t downstream_window = SIZE_MAX; + if (slot->adj_right) { + downstream_window = aws_channel_slot_downstream_read_window(slot); + } + + size_t processed = 0; + AWS_LOGF_TRACE( + AWS_LS_IO_TLS, "id=%p: Downstream window %llu", (void *)handler, (unsigned long long)downstream_window); + + while (processed < downstream_window && blocked == S2N_NOT_BLOCKED) { + + struct aws_io_message *outgoing_read_message = aws_channel_acquire_message_from_pool( + slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, downstream_window - processed); + if (!outgoing_read_message) { + return AWS_OP_ERR; + } + + ssize_t read = s2n_recv( + s2n_handler->connection, + outgoing_read_message->message_data.buffer, + outgoing_read_message->message_data.capacity, + &blocked); + + AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Bytes read %lld", (void *)handler, (long long)read); + + /* weird race where we received an alert from the peer, but s2n doesn't tell us about it..... + * if this happens, it's a graceful shutdown, so kick it off here. + * + * In other words, s2n, upon graceful shutdown, follows the unix EOF idiom. So just shutdown with + * SUCCESS. + */ + if (read == 0) { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: Alert code %d", + (void *)handler, + s2n_connection_get_alert(s2n_handler->connection)); + aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); + aws_channel_shutdown(slot->channel, AWS_OP_SUCCESS); + return AWS_OP_SUCCESS; + } + + if (read < 0) { + aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); + continue; + }; + + processed += read; + outgoing_read_message->message_data.len = (size_t)read; + + if (s2n_handler->on_data_read) { + s2n_handler->on_data_read(handler, slot, &outgoing_read_message->message_data, s2n_handler->user_data); + } + + if (slot->adj_right) { + aws_channel_slot_send_message(slot, outgoing_read_message, AWS_CHANNEL_DIR_READ); + } else { + aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); + } + } + + AWS_LOGF_TRACE( + AWS_LS_IO_TLS, + "id=%p: Remaining window for this event-loop tick: %llu", + (void *)handler, + (unsigned long long)downstream_window - processed); + + return AWS_OP_SUCCESS; +} + +static int s_s2n_handler_process_write_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + (void)slot; + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + + if (AWS_UNLIKELY(!s2n_handler->negotiation_finished)) { + return aws_raise_error(AWS_IO_TLS_ERROR_NOT_NEGOTIATED); + } + + s2n_handler->latest_message_on_completion = message->on_completion; + s2n_handler->latest_message_completion_user_data = message->user_data; + + s2n_blocked_status blocked; + ssize_t write_code = + s2n_send(s2n_handler->connection, message->message_data.buffer, (ssize_t)message->message_data.len, &blocked); + + AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Bytes written: %llu", (void *)handler, (unsigned long long)write_code); + + ssize_t message_len = (ssize_t)message->message_data.len; + + if (write_code < message_len) { + return aws_raise_error(AWS_IO_TLS_ERROR_WRITE_FAILURE); + } + + aws_mem_release(message->allocator, message); + + return AWS_OP_SUCCESS; +} + +static int s_s2n_handler_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int error_code, + bool abort_immediately) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + + if (dir == AWS_CHANNEL_DIR_WRITE) { + if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Shutting down write direction", (void *)handler) + s2n_blocked_status blocked; + /* make a best effort, but the channel is going away after this run, so.... you only get one shot anyways */ + s2n_shutdown(s2n_handler->connection, &blocked); + } + } else { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, "id=%p: Shutting down read direction with error code %d", (void *)handler, error_code); + + while (!aws_linked_list_empty(&s2n_handler->input_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&s2n_handler->input_queue); + struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); + aws_mem_release(message->allocator, message); + } + } + + return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); +} + +static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status) { + task->task_fn = NULL; + task->arg = NULL; + + if (status == AWS_TASK_STATUS_RUN_READY) { + struct aws_channel_handler *handler = (struct aws_channel_handler *)arg; + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + s_s2n_handler_process_read_message(handler, s2n_handler->slot, NULL); + } +} + +static int s_s2n_handler_increment_read_window( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + size_t size) { + (void)size; + struct s2n_handler *s2n_handler = handler->impl; + + size_t downstream_size = aws_channel_slot_downstream_read_window(slot); + size_t current_window_size = slot->window_size; + + AWS_LOGF_TRACE( + AWS_LS_IO_TLS, "id=%p: Increment read window message received %llu", (void *)handler, (unsigned long long)size); + + size_t likely_records_count = (size_t)ceil((double)(downstream_size) / (double)(MAX_RECORD_SIZE)); + size_t offset_size = aws_mul_size_saturating(likely_records_count, EST_TLS_RECORD_OVERHEAD); + size_t total_desired_size = aws_add_size_saturating(offset_size, downstream_size); + + if (total_desired_size > current_window_size) { + size_t window_update_size = total_desired_size - current_window_size; + AWS_LOGF_TRACE( + AWS_LS_IO_TLS, + "id=%p: Propagating read window increment of size %llu", + (void *)handler, + (unsigned long long)window_update_size); + aws_channel_slot_increment_read_window(slot, window_update_size); + } + + if (s2n_handler->negotiation_finished && !s2n_handler->sequential_tasks.node.next) { + /* TLS requires full records before it can decrypt anything. As a result we need to check everything we've + * buffered instead of just waiting on a read from the socket, or we'll hit a deadlock. + * + * We have messages in a queue and they need to be run after the socket has popped (even if it didn't have data + * to read). Alternatively, s2n reads entire records at a time, so we'll need to grab whatever we can and we + * have no idea what's going on inside there. So we need to attempt another read.*/ + aws_channel_task_init( + &s2n_handler->sequential_tasks, s_run_read, handler, "s2n_channel_handler_read_on_window_increment"); + aws_channel_schedule_task_now(slot->channel, &s2n_handler->sequential_tasks); + } + + return AWS_OP_SUCCESS; +} + +static size_t s_s2n_handler_message_overhead(struct aws_channel_handler *handler) { + (void)handler; + return EST_TLS_RECORD_OVERHEAD; +} + +static size_t s_s2n_handler_initial_window_size(struct aws_channel_handler *handler) { + (void)handler; + + return EST_HANDSHAKE_SIZE; +} + +static void s_s2n_handler_reset_statistics(struct aws_channel_handler *handler) { + struct s2n_handler *s2n_handler = handler->impl; + + aws_crt_statistics_tls_reset(&s2n_handler->shared_state.stats); +} + +static void s_s2n_handler_gather_statistics(struct aws_channel_handler *handler, struct aws_array_list *stats) { + struct s2n_handler *s2n_handler = handler->impl; + + void *stats_base = &s2n_handler->shared_state.stats; + aws_array_list_push_back(stats, &stats_base); +} + +struct aws_byte_buf aws_tls_handler_protocol(struct aws_channel_handler *handler) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + return s2n_handler->protocol; +} + +struct aws_byte_buf aws_tls_handler_server_name(struct aws_channel_handler *handler) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + return s2n_handler->server_name; +} + +static struct aws_channel_handler_vtable s_handler_vtable = { + .destroy = s_s2n_handler_destroy, + .process_read_message = s_s2n_handler_process_read_message, + .process_write_message = s_s2n_handler_process_write_message, + .shutdown = s_s2n_handler_shutdown, + .increment_read_window = s_s2n_handler_increment_read_window, + .initial_window_size = s_s2n_handler_initial_window_size, + .message_overhead = s_s2n_handler_message_overhead, + .reset_statistics = s_s2n_handler_reset_statistics, + .gather_statistics = s_s2n_handler_gather_statistics, +}; + +static int s_parse_protocol_preferences( + struct aws_string *alpn_list_str, + const char protocol_output[4][128], + size_t *protocol_count) { + size_t max_count = *protocol_count; + *protocol_count = 0; + + struct aws_byte_cursor alpn_list_buffer[4]; + AWS_ZERO_ARRAY(alpn_list_buffer); + struct aws_array_list alpn_list; + struct aws_byte_cursor user_alpn_str = aws_byte_cursor_from_string(alpn_list_str); + + aws_array_list_init_static(&alpn_list, alpn_list_buffer, 4, sizeof(struct aws_byte_cursor)); + + if (aws_byte_cursor_split_on_char(&user_alpn_str, ';', &alpn_list)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + return AWS_OP_ERR; + } + + size_t protocols_list_len = aws_array_list_length(&alpn_list); + if (protocols_list_len < 1) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + return AWS_OP_ERR; + } + + for (size_t i = 0; i < protocols_list_len && i < max_count; ++i) { + struct aws_byte_cursor cursor; + AWS_ZERO_STRUCT(cursor); + if (aws_array_list_get_at(&alpn_list, (void *)&cursor, (size_t)i)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + return AWS_OP_ERR; + } + AWS_FATAL_ASSERT(cursor.ptr && cursor.len > 0); + memcpy((void *)protocol_output[i], cursor.ptr, cursor.len); + *protocol_count += 1; + } + + return AWS_OP_SUCCESS; +} + +static size_t s_tl_cleanup_key = 0; /* Address of variable serves as key in hash table */ + +/* + * This local object is added to the table of every event loop that has a (s2n) tls connection + * added to it at some point in time + */ +static struct aws_event_loop_local_object s_tl_cleanup_object = {.key = &s_tl_cleanup_key, + .object = NULL, + .on_object_removed = NULL}; + +static void s_aws_cleanup_s2n_thread_local_state(void *user_data) { + (void)user_data; + + s2n_cleanup(); +} + +/* s2n allocates thread-local data structures. We need to clean these up when the event loop's thread exits. */ +static int s_s2n_tls_channel_handler_schedule_thread_local_cleanup(struct aws_channel_slot *slot) { + struct aws_channel *channel = slot->channel; + + struct aws_event_loop_local_object existing_marker; + AWS_ZERO_STRUCT(existing_marker); + + /* + * Check whether another s2n_tls_channel_handler has already scheduled the cleanup task. + */ + if (aws_channel_fetch_local_object(channel, &s_tl_cleanup_key, &existing_marker)) { + /* Doesn't exist in event loop table: add it and add the at-exit cleanup callback */ + if (aws_channel_put_local_object(channel, &s_tl_cleanup_key, &s_tl_cleanup_object)) { + return AWS_OP_ERR; + } + + aws_thread_current_at_exit(s_aws_cleanup_s2n_thread_local_state, NULL); + } + + return AWS_OP_SUCCESS; +} + +static struct aws_channel_handler *s_new_tls_handler( + struct aws_allocator *allocator, + struct aws_tls_connection_options *options, + struct aws_channel_slot *slot, + s2n_mode mode) { + + AWS_ASSERT(options->ctx); + struct s2n_handler *s2n_handler = aws_mem_calloc(allocator, 1, sizeof(struct s2n_handler)); + if (!s2n_handler) { + return NULL; + } + + struct s2n_ctx *s2n_ctx = (struct s2n_ctx *)options->ctx->impl; + s2n_handler->connection = s2n_connection_new(mode); + + if (!s2n_handler->connection) { + goto cleanup_s2n_handler; + } + + aws_tls_channel_handler_shared_init(&s2n_handler->shared_state, &s2n_handler->handler, options); + + s2n_handler->handler.impl = s2n_handler; + s2n_handler->handler.alloc = allocator; + s2n_handler->handler.vtable = &s_handler_vtable; + s2n_handler->handler.slot = slot; + s2n_handler->user_data = options->user_data; + s2n_handler->on_data_read = options->on_data_read; + s2n_handler->on_error = options->on_error; + s2n_handler->on_negotiation_result = options->on_negotiation_result; + s2n_handler->advertise_alpn_message = options->advertise_alpn_message; + + s2n_handler->latest_message_completion_user_data = NULL; + s2n_handler->latest_message_on_completion = NULL; + s2n_handler->slot = slot; + aws_linked_list_init(&s2n_handler->input_queue); + + s2n_handler->protocol = aws_byte_buf_from_array(NULL, 0); + + if (options->server_name) { + + if (s2n_set_server_name(s2n_handler->connection, aws_string_c_str(options->server_name))) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_conn; + } + } + + s2n_handler->negotiation_finished = false; + + s2n_connection_set_recv_cb(s2n_handler->connection, s_s2n_handler_recv); + s2n_connection_set_recv_ctx(s2n_handler->connection, s2n_handler); + s2n_connection_set_send_cb(s2n_handler->connection, s_s2n_handler_send); + s2n_connection_set_send_ctx(s2n_handler->connection, s2n_handler); + s2n_connection_set_blinding(s2n_handler->connection, S2N_SELF_SERVICE_BLINDING); + + if (options->alpn_list) { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: Setting ALPN list %s", + (void *)&s2n_handler->handler, + aws_string_c_str(options->alpn_list)); + + const char protocols_cpy[4][128]; + AWS_ZERO_ARRAY(protocols_cpy); + size_t protocols_size = 4; + if (s_parse_protocol_preferences(options->alpn_list, protocols_cpy, &protocols_size)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_conn; + } + + const char *protocols[4]; + AWS_ZERO_ARRAY(protocols); + for (size_t i = 0; i < protocols_size; ++i) { + protocols[i] = protocols_cpy[i]; + } + + if (s2n_connection_set_protocol_preferences( + s2n_handler->connection, (const char *const *)protocols, (int)protocols_size)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_conn; + } + } + + if (s2n_connection_set_config(s2n_handler->connection, s2n_ctx->s2n_config)) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "id=%p: configuration error %s (%s)", + (void *)&s2n_handler->handler, + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_conn; + } + + if (s_s2n_tls_channel_handler_schedule_thread_local_cleanup(slot)) { + goto cleanup_conn; + } + + return &s2n_handler->handler; + +cleanup_conn: + s2n_connection_free(s2n_handler->connection); + +cleanup_s2n_handler: + aws_mem_release(allocator, s2n_handler); + + return NULL; +} + +struct aws_channel_handler *aws_tls_client_handler_new( + struct aws_allocator *allocator, + struct aws_tls_connection_options *options, + struct aws_channel_slot *slot) { + + return s_new_tls_handler(allocator, options, slot, S2N_CLIENT); +} + +struct aws_channel_handler *aws_tls_server_handler_new( + struct aws_allocator *allocator, + struct aws_tls_connection_options *options, + struct aws_channel_slot *slot) { + + return s_new_tls_handler(allocator, options, slot, S2N_SERVER); +} + +static void s_s2n_ctx_destroy(struct s2n_ctx *s2n_ctx) { + if (s2n_ctx != NULL) { + s2n_config_free(s2n_ctx->s2n_config); + aws_mem_release(s2n_ctx->ctx.alloc, s2n_ctx); + } +} + +static struct aws_tls_ctx *s_tls_ctx_new( + struct aws_allocator *alloc, + const struct aws_tls_ctx_options *options, + s2n_mode mode) { + struct s2n_ctx *s2n_ctx = aws_mem_calloc(alloc, 1, sizeof(struct s2n_ctx)); + + if (!s2n_ctx) { + return NULL; + } + + if (!aws_tls_is_cipher_pref_supported(options->cipher_pref)) { + aws_raise_error(AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED); + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: TLS Cipher Preference is not supported: %d.", options->cipher_pref); + return NULL; + } + + s2n_ctx->ctx.alloc = alloc; + s2n_ctx->ctx.impl = s2n_ctx; + aws_ref_count_init(&s2n_ctx->ctx.ref_count, s2n_ctx, (aws_simple_completion_callback *)s_s2n_ctx_destroy); + s2n_ctx->s2n_config = s2n_config_new(); + + if (!s2n_ctx->s2n_config) { + goto cleanup_s2n_ctx; + } + + switch (options->minimum_tls_version) { + case AWS_IO_SSLv3: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "CloudFront-SSL-v-3"); + break; + case AWS_IO_TLSv1: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "CloudFront-TLS-1-0-2014"); + break; + case AWS_IO_TLSv1_1: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-1-2017-01"); + break; + case AWS_IO_TLSv1_2: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-2-Ext-2018-06"); + break; + case AWS_IO_TLSv1_3: + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "TLS 1.3 is not supported yet."); + /* sorry guys, we'll add this as soon as s2n does. */ + aws_raise_error(AWS_IO_TLS_VERSION_UNSUPPORTED); + goto cleanup_s2n_ctx; + case AWS_IO_TLS_VER_SYS_DEFAULTS: + default: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "ELBSecurityPolicy-TLS-1-1-2017-01"); + } + + switch (options->cipher_pref) { + case AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT: + /* No-Op, if the user configured a minimum_tls_version then a version-specific Cipher Preference was set */ + break; + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2019_06: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2019-06"); + break; + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2019_11: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "PQ-SIKE-TEST-TLS-1-0-2019-11"); + break; + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_02: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2020-02"); + break; + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_SIKE_TLSv1_0_2020_02: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "PQ-SIKE-TEST-TLS-1-0-2020-02"); + break; + case AWS_IO_TLS_CIPHER_PREF_KMS_PQ_TLSv1_0_2020_07: + s2n_config_set_cipher_preferences(s2n_ctx->s2n_config, "KMS-PQ-TLS-1-0-2020-07"); + break; + default: + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Unrecognized TLS Cipher Preference: %d", options->cipher_pref); + aws_raise_error(AWS_IO_TLS_CIPHER_PREF_UNSUPPORTED); + goto cleanup_s2n_ctx; + } + + if (options->certificate.len && options->private_key.len) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "ctx: Certificate and key have been set, setting them up now."); + + if (!aws_text_is_utf8(options->certificate.buffer, options->certificate.len)) { + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: failed to import certificate, must be ASCII/UTF-8 encoded"); + aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); + goto cleanup_s2n_ctx; + } + + if (!aws_text_is_utf8(options->private_key.buffer, options->private_key.len)) { + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "static: failed to import private key, must be ASCII/UTF-8 encoded"); + aws_raise_error(AWS_IO_FILE_VALIDATION_FAILURE); + goto cleanup_s2n_ctx; + } + + int err_code = s2n_config_add_cert_chain_and_key( + s2n_ctx->s2n_config, (const char *)options->certificate.buffer, (const char *)options->private_key.buffer); + + if (mode == S2N_CLIENT) { + s2n_config_set_client_auth_type(s2n_ctx->s2n_config, S2N_CERT_AUTH_REQUIRED); + } + + if (err_code != S2N_ERR_T_OK) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: configuration error %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (options->verify_peer) { + if (s2n_config_set_check_stapled_ocsp_response(s2n_ctx->s2n_config, 1) == S2N_SUCCESS) { + if (s2n_config_set_status_request_type(s2n_ctx->s2n_config, S2N_STATUS_REQUEST_OCSP) != S2N_SUCCESS) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: ocsp status request cannot be set: %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } else { + if (s2n_error_get_type(s2n_errno) == S2N_ERR_T_USAGE) { + AWS_LOGF_INFO(AWS_LS_IO_TLS, "ctx: cannot enable ocsp stapling: %s", s2n_strerror(s2n_errno, "EN")); + } else { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: cannot enable ocsp stapling: %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (options->ca_path) { + if (s2n_config_set_verification_ca_location( + s2n_ctx->s2n_config, NULL, aws_string_c_str(options->ca_path))) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: configuration error %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Failed to set ca_path %s\n", aws_string_c_str(options->ca_path)); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (options->ca_file.len) { + if (s2n_config_add_pem_to_trust_store(s2n_ctx->s2n_config, (const char *)options->ca_file.buffer)) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: configuration error %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + AWS_LOGF_ERROR(AWS_LS_IO_TLS, "Failed to set ca_file %s\n", (const char *)options->ca_file.buffer); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (!options->ca_path && !options->ca_file.len) { + if (s2n_config_set_verification_ca_location(s2n_ctx->s2n_config, s_default_ca_file, s_default_ca_dir)) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: configuration error %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, "Failed to set ca_path: %s and ca_file %s\n", s_default_ca_dir, s_default_ca_file); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (mode == S2N_SERVER && s2n_config_set_client_auth_type(s2n_ctx->s2n_config, S2N_CERT_AUTH_REQUIRED)) { + AWS_LOGF_ERROR( + AWS_LS_IO_TLS, + "ctx: configuration error %s (%s)", + s2n_strerror(s2n_errno, "EN"), + s2n_strerror_debug(s2n_errno, "EN")); + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } else if (mode != S2N_SERVER) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "ctx: X.509 validation has been disabled. " + "If this is not running in a test environment, this is likely a security vulnerability."); + if (s2n_config_disable_x509_verification(s2n_ctx->s2n_config)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (options->alpn_list) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "ctx: Setting ALPN list %s", aws_string_c_str(options->alpn_list)); + const char protocols_cpy[4][128]; + AWS_ZERO_ARRAY(protocols_cpy); + size_t protocols_size = 4; + if (s_parse_protocol_preferences(options->alpn_list, protocols_cpy, &protocols_size)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + + const char *protocols[4]; + AWS_ZERO_ARRAY(protocols); + for (size_t i = 0; i < protocols_size; ++i) { + protocols[i] = protocols_cpy[i]; + } + + if (s2n_config_set_protocol_preferences(s2n_ctx->s2n_config, protocols, (int)protocols_size)) { + aws_raise_error(AWS_IO_TLS_CTX_ERROR); + goto cleanup_s2n_config; + } + } + + if (options->max_fragment_size == 512) { + s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_512); + } else if (options->max_fragment_size == 1024) { + s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_1024); + } else if (options->max_fragment_size == 2048) { + s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_2048); + } else if (options->max_fragment_size == 4096) { + s2n_config_send_max_fragment_length(s2n_ctx->s2n_config, S2N_TLS_MAX_FRAG_LEN_4096); + } + + return &s2n_ctx->ctx; + +cleanup_s2n_config: + s2n_config_free(s2n_ctx->s2n_config); + +cleanup_s2n_ctx: + aws_mem_release(alloc, s2n_ctx); + + return NULL; +} + +struct aws_tls_ctx *aws_tls_server_ctx_new(struct aws_allocator *alloc, const struct aws_tls_ctx_options *options) { + aws_io_fatal_assert_library_initialized(); + return s_tls_ctx_new(alloc, options, S2N_SERVER); +} + +struct aws_tls_ctx *aws_tls_client_ctx_new(struct aws_allocator *alloc, const struct aws_tls_ctx_options *options) { + aws_io_fatal_assert_library_initialized(); + return s_tls_ctx_new(alloc, options, S2N_CLIENT); +} diff --git a/contrib/restricted/aws/aws-c-io/source/socket_channel_handler.c b/contrib/restricted/aws/aws-c-io/source/socket_channel_handler.c index 40a178123b..86aeced988 100644 --- a/contrib/restricted/aws/aws-c-io/source/socket_channel_handler.c +++ b/contrib/restricted/aws/aws-c-io/source/socket_channel_handler.c @@ -1,421 +1,421 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/socket_channel_handler.h> - -#include <aws/common/error.h> -#include <aws/common/task_scheduler.h> - -#include <aws/io/event_loop.h> -#include <aws/io/logging.h> -#include <aws/io/socket.h> -#include <aws/io/statistics.h> - -#if _MSC_VER -# pragma warning(disable : 4204) /* non-constant aggregate initializer */ -#endif - -struct socket_handler { - struct aws_socket *socket; - struct aws_channel_slot *slot; - size_t max_rw_size; - struct aws_channel_task read_task_storage; - struct aws_channel_task shutdown_task_storage; - struct aws_crt_statistics_socket stats; - int shutdown_err_code; - bool shutdown_in_progress; -}; - -static int s_socket_process_read_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - (void)handler; - (void)slot; - (void)message; - - AWS_LOGF_FATAL( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: process_read_message called on " - "socket handler. This should never happen", - (void *)handler); - - /*since a socket handler will ALWAYS be the first handler in a channel, - * this should NEVER happen, if it does it's a programmer error.*/ - AWS_ASSERT(0); - return aws_raise_error(AWS_IO_CHANNEL_ERROR_ERROR_CANT_ACCEPT_INPUT); -} - -/* invoked by the socket when a write has completed or failed. */ -static void s_on_socket_write_complete( - struct aws_socket *socket, - int error_code, - size_t amount_written, - void *user_data) { - - if (user_data) { - struct aws_io_message *message = user_data; - struct aws_channel *channel = message->owning_channel; - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "static: write of size %llu, completed on channel %p", - (unsigned long long)amount_written, - (void *)channel); - - if (message->on_completion) { - message->on_completion(channel, message, error_code, message->user_data); - } - - if (socket && socket->handler) { - struct socket_handler *socket_handler = socket->handler->impl; - socket_handler->stats.bytes_written += amount_written; - } - - aws_mem_release(message->allocator, message); - - if (error_code) { - aws_channel_shutdown(channel, error_code); - } - } -} - -static int s_socket_process_write_message( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - struct aws_io_message *message) { - (void)slot; - struct socket_handler *socket_handler = handler->impl; - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: writing message of size %llu", - (void *)handler, - (unsigned long long)message->message_data.len); - - if (!aws_socket_is_open(socket_handler->socket)) { - return aws_raise_error(AWS_IO_SOCKET_CLOSED); - } - - struct aws_byte_cursor cursor = aws_byte_cursor_from_buf(&message->message_data); - if (aws_socket_write(socket_handler->socket, &cursor, s_on_socket_write_complete, message)) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -static void s_read_task(struct aws_channel_task *task, void *arg, aws_task_status status); - -static void s_on_readable_notification(struct aws_socket *socket, int error_code, void *user_data); - -/* Ok this next function is VERY important for how back pressure works. Here's what it's supposed to be doing: - * - * See how much data downstream is willing to accept. - * See how much we're actually willing to read per event loop tick (usually 16 kb). - * Take the minimum of those two. - * Try and read as much as possible up to the calculated max read. - * If we didn't read up to the max_read, we go back to waiting on the event loop to tell us we can read more. - * If we did read up to the max_read, we stop reading immediately and wait for either for a window update, - * or schedule a task to enforce fairness for other sockets in the event loop if we read up to the max - * read per event loop tick. - */ -static void s_do_read(struct socket_handler *socket_handler) { - - size_t downstream_window = aws_channel_slot_downstream_read_window(socket_handler->slot); - size_t max_to_read = - downstream_window > socket_handler->max_rw_size ? socket_handler->max_rw_size : downstream_window; - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: invoking read. Downstream window %llu, max_to_read %llu", - (void *)socket_handler->slot->handler, - (unsigned long long)downstream_window, - (unsigned long long)max_to_read); - - if (max_to_read == 0) { - return; - } - - size_t total_read = 0; - size_t read = 0; - while (total_read < max_to_read && !socket_handler->shutdown_in_progress) { - size_t iter_max_read = max_to_read - total_read; - - struct aws_io_message *message = aws_channel_acquire_message_from_pool( - socket_handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, iter_max_read); - - if (!message) { - break; - } - - if (aws_socket_read(socket_handler->socket, &message->message_data, &read)) { - aws_mem_release(message->allocator, message); - break; - } - - total_read += read; - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: read %llu from socket", - (void *)socket_handler->slot->handler, - (unsigned long long)read); - - if (aws_channel_slot_send_message(socket_handler->slot, message, AWS_CHANNEL_DIR_READ)) { - aws_mem_release(message->allocator, message); - break; - } - } - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: total read on this tick %llu", - (void *)&socket_handler->slot->handler, - (unsigned long long)total_read); - - socket_handler->stats.bytes_read += total_read; - - /* resubscribe as long as there's no error, just return if we're in a would block scenario. */ - if (total_read < max_to_read) { - int last_error = aws_last_error(); - - if (last_error != AWS_IO_READ_WOULD_BLOCK && !socket_handler->shutdown_in_progress) { - aws_channel_shutdown(socket_handler->slot->channel, last_error); - } - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: out of data to read on socket. " - "Waiting on event-loop notification.", - (void *)socket_handler->slot->handler); - return; - } - /* in this case, everything was fine, but there's still pending reads. We need to schedule a task to do the read - * again. */ - if (!socket_handler->shutdown_in_progress && total_read == socket_handler->max_rw_size && - !socket_handler->read_task_storage.task_fn) { - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: more data is pending read, but we've exceeded " - "the max read on this tick. Scheduling a task to read on next tick.", - (void *)socket_handler->slot->handler); - aws_channel_task_init( - &socket_handler->read_task_storage, s_read_task, socket_handler, "socket_handler_re_read"); - aws_channel_schedule_task_now(socket_handler->slot->channel, &socket_handler->read_task_storage); - } -} - -/* the socket is either readable or errored out. If it's readable, kick off s_do_read() to do its thing. - * If an error, start the channel shutdown process. */ -static void s_on_readable_notification(struct aws_socket *socket, int error_code, void *user_data) { - (void)socket; - - struct socket_handler *socket_handler = user_data; - AWS_LOGF_TRACE(AWS_LS_IO_SOCKET_HANDLER, "id=%p: socket is now readable", (void *)socket_handler->slot->handler); - - /* read regardless so we can pick up data that was sent prior to the close. For example, peer sends a TLS ALERT - * then immediately closes the socket. On some platforms, we'll never see the readable flag. So we want to make - * sure we read the ALERT, otherwise, we'll end up telling the user that the channel shutdown because of a socket - * closure, when in reality it was a TLS error */ - s_do_read(socket_handler); - - if (error_code && !socket_handler->shutdown_in_progress) { - aws_channel_shutdown(socket_handler->slot->channel, error_code); - } -} - -/* Either the result of a context switch (for fairness in the event loop), or a window update. */ -static void s_read_task(struct aws_channel_task *task, void *arg, aws_task_status status) { - task->task_fn = NULL; - task->arg = NULL; - - if (status == AWS_TASK_STATUS_RUN_READY) { - struct socket_handler *socket_handler = arg; - s_do_read(socket_handler); - } -} - -static int s_socket_increment_read_window( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - size_t size) { - (void)size; - - struct socket_handler *socket_handler = handler->impl; - - if (!socket_handler->shutdown_in_progress && !socket_handler->read_task_storage.task_fn) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: increment read window message received, scheduling" - " task for another read operation.", - (void *)handler); - - aws_channel_task_init( - &socket_handler->read_task_storage, s_read_task, socket_handler, "socket_handler_read_on_window_increment"); - aws_channel_schedule_task_now(slot->channel, &socket_handler->read_task_storage); - } - - return AWS_OP_SUCCESS; -} - -static void s_close_task(struct aws_channel_task *task, void *arg, aws_task_status status) { - (void)task; - (void)status; - - struct aws_channel_handler *handler = arg; - struct socket_handler *socket_handler = handler->impl; - - /* - * Run this unconditionally regardless of status, otherwise channel will not - * finish shutting down properly - */ - - /* this only happens in write direction. */ - /* we also don't care about the free_scarce_resource_immediately - * code since we're always the last one in the shutdown sequence. */ - aws_channel_slot_on_handler_shutdown_complete( - socket_handler->slot, AWS_CHANNEL_DIR_WRITE, socket_handler->shutdown_err_code, false); -} - -static int s_socket_shutdown( - struct aws_channel_handler *handler, - struct aws_channel_slot *slot, - enum aws_channel_direction dir, - int error_code, - bool free_scarce_resource_immediately) { - struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; - - socket_handler->shutdown_in_progress = true; - if (dir == AWS_CHANNEL_DIR_READ) { - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: shutting down read direction with error_code %d", - (void *)handler, - error_code); - if (free_scarce_resource_immediately && aws_socket_is_open(socket_handler->socket)) { - if (aws_socket_close(socket_handler->socket)) { - return AWS_OP_ERR; - } - } - - return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, free_scarce_resource_immediately); - } - - AWS_LOGF_TRACE( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: shutting down write direction with error_code %d", - (void *)handler, - error_code); - if (aws_socket_is_open(socket_handler->socket)) { - aws_socket_close(socket_handler->socket); - } - - /* Schedule a task to complete the shutdown, in case a do_read task is currently pending. - * It's OK to delay the shutdown, even when free_scarce_resources_immediately is true, - * because the socket has been closed: mitigating the risk that the socket is still being abused by - * a hostile peer. */ - aws_channel_task_init(&socket_handler->shutdown_task_storage, s_close_task, handler, "socket_handler_close"); - socket_handler->shutdown_err_code = error_code; - aws_channel_schedule_task_now(slot->channel, &socket_handler->shutdown_task_storage); - return AWS_OP_SUCCESS; -} - -static size_t s_message_overhead(struct aws_channel_handler *handler) { - (void)handler; - return 0; -} - -static size_t s_socket_initial_window_size(struct aws_channel_handler *handler) { - (void)handler; - return SIZE_MAX; -} - -static void s_socket_destroy(struct aws_channel_handler *handler) { - if (handler != NULL) { - struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; - if (socket_handler != NULL) { - aws_crt_statistics_socket_cleanup(&socket_handler->stats); - } - - aws_mem_release(handler->alloc, handler); - } -} - -static void s_reset_statistics(struct aws_channel_handler *handler) { - struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; - - aws_crt_statistics_socket_reset(&socket_handler->stats); -} - -void s_gather_statistics(struct aws_channel_handler *handler, struct aws_array_list *stats_list) { - struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; - - void *stats_base = &socket_handler->stats; - aws_array_list_push_back(stats_list, &stats_base); -} - -static struct aws_channel_handler_vtable s_vtable = { - .process_read_message = s_socket_process_read_message, - .destroy = s_socket_destroy, - .process_write_message = s_socket_process_write_message, - .initial_window_size = s_socket_initial_window_size, - .increment_read_window = s_socket_increment_read_window, - .shutdown = s_socket_shutdown, - .message_overhead = s_message_overhead, - .reset_statistics = s_reset_statistics, - .gather_statistics = s_gather_statistics, -}; - -struct aws_channel_handler *aws_socket_handler_new( - struct aws_allocator *allocator, - struct aws_socket *socket, - struct aws_channel_slot *slot, - size_t max_read_size) { - - /* make sure something has assigned this socket to an event loop, in client mode this will already have occurred. - In server mode, someone should have assigned it before calling us.*/ - AWS_ASSERT(aws_socket_get_event_loop(socket)); - - struct aws_channel_handler *handler = NULL; - - struct socket_handler *impl = NULL; - - if (!aws_mem_acquire_many( - allocator, 2, &handler, sizeof(struct aws_channel_handler), &impl, sizeof(struct socket_handler))) { - return NULL; - } - - impl->socket = socket; - impl->slot = slot; - impl->max_rw_size = max_read_size; - AWS_ZERO_STRUCT(impl->read_task_storage); - AWS_ZERO_STRUCT(impl->shutdown_task_storage); - impl->shutdown_in_progress = false; - if (aws_crt_statistics_socket_init(&impl->stats)) { - goto cleanup_handler; - } - - AWS_LOGF_DEBUG( - AWS_LS_IO_SOCKET_HANDLER, - "id=%p: Socket handler created with max_read_size of %llu", - (void *)handler, - (unsigned long long)max_read_size); - - handler->alloc = allocator; - handler->impl = impl; - handler->vtable = &s_vtable; - handler->slot = slot; - if (aws_socket_subscribe_to_readable_events(socket, s_on_readable_notification, impl)) { - goto cleanup_handler; - } - - socket->handler = handler; - - return handler; - -cleanup_handler: - aws_mem_release(allocator, handler); - - return NULL; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/socket_channel_handler.h> + +#include <aws/common/error.h> +#include <aws/common/task_scheduler.h> + +#include <aws/io/event_loop.h> +#include <aws/io/logging.h> +#include <aws/io/socket.h> +#include <aws/io/statistics.h> + +#if _MSC_VER +# pragma warning(disable : 4204) /* non-constant aggregate initializer */ +#endif + +struct socket_handler { + struct aws_socket *socket; + struct aws_channel_slot *slot; + size_t max_rw_size; + struct aws_channel_task read_task_storage; + struct aws_channel_task shutdown_task_storage; + struct aws_crt_statistics_socket stats; + int shutdown_err_code; + bool shutdown_in_progress; +}; + +static int s_socket_process_read_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + (void)handler; + (void)slot; + (void)message; + + AWS_LOGF_FATAL( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: process_read_message called on " + "socket handler. This should never happen", + (void *)handler); + + /*since a socket handler will ALWAYS be the first handler in a channel, + * this should NEVER happen, if it does it's a programmer error.*/ + AWS_ASSERT(0); + return aws_raise_error(AWS_IO_CHANNEL_ERROR_ERROR_CANT_ACCEPT_INPUT); +} + +/* invoked by the socket when a write has completed or failed. */ +static void s_on_socket_write_complete( + struct aws_socket *socket, + int error_code, + size_t amount_written, + void *user_data) { + + if (user_data) { + struct aws_io_message *message = user_data; + struct aws_channel *channel = message->owning_channel; + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "static: write of size %llu, completed on channel %p", + (unsigned long long)amount_written, + (void *)channel); + + if (message->on_completion) { + message->on_completion(channel, message, error_code, message->user_data); + } + + if (socket && socket->handler) { + struct socket_handler *socket_handler = socket->handler->impl; + socket_handler->stats.bytes_written += amount_written; + } + + aws_mem_release(message->allocator, message); + + if (error_code) { + aws_channel_shutdown(channel, error_code); + } + } +} + +static int s_socket_process_write_message( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_io_message *message) { + (void)slot; + struct socket_handler *socket_handler = handler->impl; + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: writing message of size %llu", + (void *)handler, + (unsigned long long)message->message_data.len); + + if (!aws_socket_is_open(socket_handler->socket)) { + return aws_raise_error(AWS_IO_SOCKET_CLOSED); + } + + struct aws_byte_cursor cursor = aws_byte_cursor_from_buf(&message->message_data); + if (aws_socket_write(socket_handler->socket, &cursor, s_on_socket_write_complete, message)) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +static void s_read_task(struct aws_channel_task *task, void *arg, aws_task_status status); + +static void s_on_readable_notification(struct aws_socket *socket, int error_code, void *user_data); + +/* Ok this next function is VERY important for how back pressure works. Here's what it's supposed to be doing: + * + * See how much data downstream is willing to accept. + * See how much we're actually willing to read per event loop tick (usually 16 kb). + * Take the minimum of those two. + * Try and read as much as possible up to the calculated max read. + * If we didn't read up to the max_read, we go back to waiting on the event loop to tell us we can read more. + * If we did read up to the max_read, we stop reading immediately and wait for either for a window update, + * or schedule a task to enforce fairness for other sockets in the event loop if we read up to the max + * read per event loop tick. + */ +static void s_do_read(struct socket_handler *socket_handler) { + + size_t downstream_window = aws_channel_slot_downstream_read_window(socket_handler->slot); + size_t max_to_read = + downstream_window > socket_handler->max_rw_size ? socket_handler->max_rw_size : downstream_window; + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: invoking read. Downstream window %llu, max_to_read %llu", + (void *)socket_handler->slot->handler, + (unsigned long long)downstream_window, + (unsigned long long)max_to_read); + + if (max_to_read == 0) { + return; + } + + size_t total_read = 0; + size_t read = 0; + while (total_read < max_to_read && !socket_handler->shutdown_in_progress) { + size_t iter_max_read = max_to_read - total_read; + + struct aws_io_message *message = aws_channel_acquire_message_from_pool( + socket_handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, iter_max_read); + + if (!message) { + break; + } + + if (aws_socket_read(socket_handler->socket, &message->message_data, &read)) { + aws_mem_release(message->allocator, message); + break; + } + + total_read += read; + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: read %llu from socket", + (void *)socket_handler->slot->handler, + (unsigned long long)read); + + if (aws_channel_slot_send_message(socket_handler->slot, message, AWS_CHANNEL_DIR_READ)) { + aws_mem_release(message->allocator, message); + break; + } + } + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: total read on this tick %llu", + (void *)&socket_handler->slot->handler, + (unsigned long long)total_read); + + socket_handler->stats.bytes_read += total_read; + + /* resubscribe as long as there's no error, just return if we're in a would block scenario. */ + if (total_read < max_to_read) { + int last_error = aws_last_error(); + + if (last_error != AWS_IO_READ_WOULD_BLOCK && !socket_handler->shutdown_in_progress) { + aws_channel_shutdown(socket_handler->slot->channel, last_error); + } + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: out of data to read on socket. " + "Waiting on event-loop notification.", + (void *)socket_handler->slot->handler); + return; + } + /* in this case, everything was fine, but there's still pending reads. We need to schedule a task to do the read + * again. */ + if (!socket_handler->shutdown_in_progress && total_read == socket_handler->max_rw_size && + !socket_handler->read_task_storage.task_fn) { + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: more data is pending read, but we've exceeded " + "the max read on this tick. Scheduling a task to read on next tick.", + (void *)socket_handler->slot->handler); + aws_channel_task_init( + &socket_handler->read_task_storage, s_read_task, socket_handler, "socket_handler_re_read"); + aws_channel_schedule_task_now(socket_handler->slot->channel, &socket_handler->read_task_storage); + } +} + +/* the socket is either readable or errored out. If it's readable, kick off s_do_read() to do its thing. + * If an error, start the channel shutdown process. */ +static void s_on_readable_notification(struct aws_socket *socket, int error_code, void *user_data) { + (void)socket; + + struct socket_handler *socket_handler = user_data; + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET_HANDLER, "id=%p: socket is now readable", (void *)socket_handler->slot->handler); + + /* read regardless so we can pick up data that was sent prior to the close. For example, peer sends a TLS ALERT + * then immediately closes the socket. On some platforms, we'll never see the readable flag. So we want to make + * sure we read the ALERT, otherwise, we'll end up telling the user that the channel shutdown because of a socket + * closure, when in reality it was a TLS error */ + s_do_read(socket_handler); + + if (error_code && !socket_handler->shutdown_in_progress) { + aws_channel_shutdown(socket_handler->slot->channel, error_code); + } +} + +/* Either the result of a context switch (for fairness in the event loop), or a window update. */ +static void s_read_task(struct aws_channel_task *task, void *arg, aws_task_status status) { + task->task_fn = NULL; + task->arg = NULL; + + if (status == AWS_TASK_STATUS_RUN_READY) { + struct socket_handler *socket_handler = arg; + s_do_read(socket_handler); + } +} + +static int s_socket_increment_read_window( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + size_t size) { + (void)size; + + struct socket_handler *socket_handler = handler->impl; + + if (!socket_handler->shutdown_in_progress && !socket_handler->read_task_storage.task_fn) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: increment read window message received, scheduling" + " task for another read operation.", + (void *)handler); + + aws_channel_task_init( + &socket_handler->read_task_storage, s_read_task, socket_handler, "socket_handler_read_on_window_increment"); + aws_channel_schedule_task_now(slot->channel, &socket_handler->read_task_storage); + } + + return AWS_OP_SUCCESS; +} + +static void s_close_task(struct aws_channel_task *task, void *arg, aws_task_status status) { + (void)task; + (void)status; + + struct aws_channel_handler *handler = arg; + struct socket_handler *socket_handler = handler->impl; + + /* + * Run this unconditionally regardless of status, otherwise channel will not + * finish shutting down properly + */ + + /* this only happens in write direction. */ + /* we also don't care about the free_scarce_resource_immediately + * code since we're always the last one in the shutdown sequence. */ + aws_channel_slot_on_handler_shutdown_complete( + socket_handler->slot, AWS_CHANNEL_DIR_WRITE, socket_handler->shutdown_err_code, false); +} + +static int s_socket_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + enum aws_channel_direction dir, + int error_code, + bool free_scarce_resource_immediately) { + struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; + + socket_handler->shutdown_in_progress = true; + if (dir == AWS_CHANNEL_DIR_READ) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: shutting down read direction with error_code %d", + (void *)handler, + error_code); + if (free_scarce_resource_immediately && aws_socket_is_open(socket_handler->socket)) { + if (aws_socket_close(socket_handler->socket)) { + return AWS_OP_ERR; + } + } + + return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, free_scarce_resource_immediately); + } + + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: shutting down write direction with error_code %d", + (void *)handler, + error_code); + if (aws_socket_is_open(socket_handler->socket)) { + aws_socket_close(socket_handler->socket); + } + + /* Schedule a task to complete the shutdown, in case a do_read task is currently pending. + * It's OK to delay the shutdown, even when free_scarce_resources_immediately is true, + * because the socket has been closed: mitigating the risk that the socket is still being abused by + * a hostile peer. */ + aws_channel_task_init(&socket_handler->shutdown_task_storage, s_close_task, handler, "socket_handler_close"); + socket_handler->shutdown_err_code = error_code; + aws_channel_schedule_task_now(slot->channel, &socket_handler->shutdown_task_storage); + return AWS_OP_SUCCESS; +} + +static size_t s_message_overhead(struct aws_channel_handler *handler) { + (void)handler; + return 0; +} + +static size_t s_socket_initial_window_size(struct aws_channel_handler *handler) { + (void)handler; + return SIZE_MAX; +} + +static void s_socket_destroy(struct aws_channel_handler *handler) { + if (handler != NULL) { + struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; + if (socket_handler != NULL) { + aws_crt_statistics_socket_cleanup(&socket_handler->stats); + } + + aws_mem_release(handler->alloc, handler); + } +} + +static void s_reset_statistics(struct aws_channel_handler *handler) { + struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; + + aws_crt_statistics_socket_reset(&socket_handler->stats); +} + +void s_gather_statistics(struct aws_channel_handler *handler, struct aws_array_list *stats_list) { + struct socket_handler *socket_handler = (struct socket_handler *)handler->impl; + + void *stats_base = &socket_handler->stats; + aws_array_list_push_back(stats_list, &stats_base); +} + +static struct aws_channel_handler_vtable s_vtable = { + .process_read_message = s_socket_process_read_message, + .destroy = s_socket_destroy, + .process_write_message = s_socket_process_write_message, + .initial_window_size = s_socket_initial_window_size, + .increment_read_window = s_socket_increment_read_window, + .shutdown = s_socket_shutdown, + .message_overhead = s_message_overhead, + .reset_statistics = s_reset_statistics, + .gather_statistics = s_gather_statistics, +}; + +struct aws_channel_handler *aws_socket_handler_new( + struct aws_allocator *allocator, + struct aws_socket *socket, + struct aws_channel_slot *slot, + size_t max_read_size) { + + /* make sure something has assigned this socket to an event loop, in client mode this will already have occurred. + In server mode, someone should have assigned it before calling us.*/ + AWS_ASSERT(aws_socket_get_event_loop(socket)); + + struct aws_channel_handler *handler = NULL; + + struct socket_handler *impl = NULL; + + if (!aws_mem_acquire_many( + allocator, 2, &handler, sizeof(struct aws_channel_handler), &impl, sizeof(struct socket_handler))) { + return NULL; + } + + impl->socket = socket; + impl->slot = slot; + impl->max_rw_size = max_read_size; + AWS_ZERO_STRUCT(impl->read_task_storage); + AWS_ZERO_STRUCT(impl->shutdown_task_storage); + impl->shutdown_in_progress = false; + if (aws_crt_statistics_socket_init(&impl->stats)) { + goto cleanup_handler; + } + + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET_HANDLER, + "id=%p: Socket handler created with max_read_size of %llu", + (void *)handler, + (unsigned long long)max_read_size); + + handler->alloc = allocator; + handler->impl = impl; + handler->vtable = &s_vtable; + handler->slot = slot; + if (aws_socket_subscribe_to_readable_events(socket, s_on_readable_notification, impl)) { + goto cleanup_handler; + } + + socket->handler = handler; + + return handler; + +cleanup_handler: + aws_mem_release(allocator, handler); + + return NULL; +} diff --git a/contrib/restricted/aws/aws-c-io/source/statistics.c b/contrib/restricted/aws/aws-c-io/source/statistics.c index 52290a7558..ba10eaa901 100644 --- a/contrib/restricted/aws/aws-c-io/source/statistics.c +++ b/contrib/restricted/aws/aws-c-io/source/statistics.c @@ -1,44 +1,44 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/statistics.h> - -#include <aws/io/channel.h> -#include <aws/io/logging.h> - -int aws_crt_statistics_socket_init(struct aws_crt_statistics_socket *stats) { - AWS_ZERO_STRUCT(*stats); - stats->category = AWSCRT_STAT_CAT_SOCKET; - - return AWS_OP_SUCCESS; -} - -void aws_crt_statistics_socket_cleanup(struct aws_crt_statistics_socket *stats) { - (void)stats; -} - -void aws_crt_statistics_socket_reset(struct aws_crt_statistics_socket *stats) { - stats->bytes_read = 0; - stats->bytes_written = 0; -} - -int aws_crt_statistics_tls_init(struct aws_crt_statistics_tls *stats) { - AWS_ZERO_STRUCT(*stats); - stats->category = AWSCRT_STAT_CAT_TLS; - stats->handshake_status = AWS_TLS_NEGOTIATION_STATUS_NONE; - - return AWS_OP_SUCCESS; -} - -void aws_crt_statistics_tls_cleanup(struct aws_crt_statistics_tls *stats) { - (void)stats; -} - -void aws_crt_statistics_tls_reset(struct aws_crt_statistics_tls *stats) { - /* - * We currently don't have any resettable tls statistics yet, but they may be added in the future. - */ - (void)stats; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/statistics.h> + +#include <aws/io/channel.h> +#include <aws/io/logging.h> + +int aws_crt_statistics_socket_init(struct aws_crt_statistics_socket *stats) { + AWS_ZERO_STRUCT(*stats); + stats->category = AWSCRT_STAT_CAT_SOCKET; + + return AWS_OP_SUCCESS; +} + +void aws_crt_statistics_socket_cleanup(struct aws_crt_statistics_socket *stats) { + (void)stats; +} + +void aws_crt_statistics_socket_reset(struct aws_crt_statistics_socket *stats) { + stats->bytes_read = 0; + stats->bytes_written = 0; +} + +int aws_crt_statistics_tls_init(struct aws_crt_statistics_tls *stats) { + AWS_ZERO_STRUCT(*stats); + stats->category = AWSCRT_STAT_CAT_TLS; + stats->handshake_status = AWS_TLS_NEGOTIATION_STATUS_NONE; + + return AWS_OP_SUCCESS; +} + +void aws_crt_statistics_tls_cleanup(struct aws_crt_statistics_tls *stats) { + (void)stats; +} + +void aws_crt_statistics_tls_reset(struct aws_crt_statistics_tls *stats) { + /* + * We currently don't have any resettable tls statistics yet, but they may be added in the future. + */ + (void)stats; +} diff --git a/contrib/restricted/aws/aws-c-io/source/stream.c b/contrib/restricted/aws/aws-c-io/source/stream.c index 69c73ab243..d6b565b0b6 100644 --- a/contrib/restricted/aws/aws-c-io/source/stream.c +++ b/contrib/restricted/aws/aws-c-io/source/stream.c @@ -1,368 +1,368 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/stream.h> - -#include <aws/io/file_utils.h> - -#include <errno.h> - -#if _MSC_VER -# pragma warning(disable : 4996) /* fopen */ -#endif - -int aws_input_stream_seek(struct aws_input_stream *stream, aws_off_t offset, enum aws_stream_seek_basis basis) { - AWS_ASSERT(stream && stream->vtable && stream->vtable->seek); - - return stream->vtable->seek(stream, offset, basis); -} - -int aws_input_stream_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { - AWS_ASSERT(stream && stream->vtable && stream->vtable->read); - AWS_ASSERT(dest); - AWS_ASSERT(dest->len <= dest->capacity); - - /* Deal with this edge case here, instead of relying on every implementation to do it right. */ - if (dest->capacity == dest->len) { - return AWS_OP_SUCCESS; - } - - /* Prevent implementations from accidentally overwriting existing data in the buffer. - * Hand them a "safe" buffer that starts where the existing data ends. */ - const void *safe_buf_start = dest->buffer + dest->len; - const size_t safe_buf_capacity = dest->capacity - dest->len; - struct aws_byte_buf safe_buf = aws_byte_buf_from_empty_array(safe_buf_start, safe_buf_capacity); - - int read_result = stream->vtable->read(stream, &safe_buf); - - /* Ensure the implementation did not commit forbidden acts upon the buffer */ - AWS_FATAL_ASSERT( - (safe_buf.buffer == safe_buf_start) && (safe_buf.capacity == safe_buf_capacity) && - (safe_buf.len <= safe_buf_capacity)); - - if (read_result == AWS_OP_SUCCESS) { - /* Update the actual buffer */ - dest->len += safe_buf.len; - } - - return read_result; -} - -int aws_input_stream_get_status(struct aws_input_stream *stream, struct aws_stream_status *status) { - AWS_ASSERT(stream && stream->vtable && stream->vtable->get_status); - - return stream->vtable->get_status(stream, status); -} - -int aws_input_stream_get_length(struct aws_input_stream *stream, int64_t *out_length) { - AWS_ASSERT(stream && stream->vtable && stream->vtable->get_length); - - return stream->vtable->get_length(stream, out_length); -} - -void aws_input_stream_destroy(struct aws_input_stream *stream) { - if (stream != NULL) { - AWS_ASSERT(stream->vtable && stream->vtable->destroy); - - stream->vtable->destroy(stream); - } -} - -/* - * cursor stream implementation - */ - -struct aws_input_stream_byte_cursor_impl { - struct aws_byte_cursor original_cursor; - struct aws_byte_cursor current_cursor; -}; - -/* - * This is an ugly function that, in the absence of better guidance, is designed to handle all possible combinations of - * aws_off_t (int32_t, int64_t) x all possible combinations of size_t (uint32_t, uint64_t). Whether the anomalous - * combination of int64_t vs. uint32_t is even possible on any real platform is unknown. If size_t ever exceeds 64 bits - * this function will fail badly. - * - * Safety and invariant assumptions are sprinkled via comments. The overall strategy is to cast up to 64 bits and - * perform all arithmetic there, being careful with signed vs. unsigned to prevent bad operations. - * - * Assumption #1: aws_off_t resolves to a signed integer 64 bits or smaller - * Assumption #2: size_t resolves to an unsigned integer 64 bits or smaller - */ - -AWS_STATIC_ASSERT(sizeof(aws_off_t) <= 8); -AWS_STATIC_ASSERT(sizeof(size_t) <= 8); - -static int s_aws_input_stream_byte_cursor_seek( - struct aws_input_stream *stream, - aws_off_t offset, - enum aws_stream_seek_basis basis) { - struct aws_input_stream_byte_cursor_impl *impl = stream->impl; - - uint64_t final_offset = 0; - int64_t checked_offset = offset; /* safe by assumption 1 */ - - switch (basis) { - case AWS_SSB_BEGIN: - /* - * (uint64_t)checked_offset -- safe by virtue of the earlier is-negative check + Assumption 1 - * (uint64_t)impl->original_cursor.len -- safe via assumption 2 - */ - if (checked_offset < 0 || (uint64_t)checked_offset > (uint64_t)impl->original_cursor.len) { - return aws_raise_error(AWS_IO_STREAM_INVALID_SEEK_POSITION); - } - - /* safe because negative offsets were turned into an error */ - final_offset = (uint64_t)checked_offset; - break; - - case AWS_SSB_END: - /* - * -checked_offset -- safe as long checked_offset is not INT64_MIN which was previously checked - * (uint64_t)(-checked_offset) -- safe because (-checked_offset) is positive (and < INT64_MAX < UINT64_MAX) - */ - if (checked_offset > 0 || checked_offset == INT64_MIN || - (uint64_t)(-checked_offset) > (uint64_t)impl->original_cursor.len) { - return aws_raise_error(AWS_IO_STREAM_INVALID_SEEK_POSITION); - } - - /* cases that would make this unsafe became errors with previous conditional */ - final_offset = (uint64_t)impl->original_cursor.len - (uint64_t)(-checked_offset); - break; - } - - /* true because we already validated against (impl->original_cursor.len) which is <= SIZE_MAX */ - AWS_ASSERT(final_offset <= SIZE_MAX); - - /* safe via previous assert */ - size_t final_offset_sz = (size_t)final_offset; - - /* sanity */ - AWS_ASSERT(final_offset_sz <= impl->current_cursor.len); - - impl->current_cursor = impl->original_cursor; - - /* let's skip advance */ - impl->current_cursor.ptr += final_offset_sz; - impl->current_cursor.len -= final_offset_sz; - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_byte_cursor_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { - struct aws_input_stream_byte_cursor_impl *impl = stream->impl; - - size_t actually_read = dest->capacity - dest->len; - if (actually_read > impl->current_cursor.len) { - actually_read = impl->current_cursor.len; - } - - if (!aws_byte_buf_write(dest, impl->current_cursor.ptr, actually_read)) { - return aws_raise_error(AWS_IO_STREAM_READ_FAILED); - } - - aws_byte_cursor_advance(&impl->current_cursor, actually_read); - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_byte_cursor_get_status( - struct aws_input_stream *stream, - struct aws_stream_status *status) { - struct aws_input_stream_byte_cursor_impl *impl = stream->impl; - - status->is_end_of_stream = impl->current_cursor.len == 0; - status->is_valid = true; - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_byte_cursor_get_length(struct aws_input_stream *stream, int64_t *out_length) { - struct aws_input_stream_byte_cursor_impl *impl = stream->impl; - -#if SIZE_MAX > INT64_MAX - size_t length = impl->original_cursor.len; - if (length > INT64_MAX) { - return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED); - } -#endif - - *out_length = (int64_t)impl->original_cursor.len; - - return AWS_OP_SUCCESS; -} - -static void s_aws_input_stream_byte_cursor_destroy(struct aws_input_stream *stream) { - aws_mem_release(stream->allocator, stream); -} - -static struct aws_input_stream_vtable s_aws_input_stream_byte_cursor_vtable = { - .seek = s_aws_input_stream_byte_cursor_seek, - .read = s_aws_input_stream_byte_cursor_read, - .get_status = s_aws_input_stream_byte_cursor_get_status, - .get_length = s_aws_input_stream_byte_cursor_get_length, - .destroy = s_aws_input_stream_byte_cursor_destroy}; - -struct aws_input_stream *aws_input_stream_new_from_cursor( - struct aws_allocator *allocator, - const struct aws_byte_cursor *cursor) { - - struct aws_input_stream *input_stream = NULL; - struct aws_input_stream_byte_cursor_impl *impl = NULL; - - aws_mem_acquire_many( - allocator, - 2, - &input_stream, - sizeof(struct aws_input_stream), - &impl, - sizeof(struct aws_input_stream_byte_cursor_impl)); - - if (!input_stream) { - return NULL; - } - - AWS_ZERO_STRUCT(*input_stream); - AWS_ZERO_STRUCT(*impl); - - input_stream->allocator = allocator; - input_stream->vtable = &s_aws_input_stream_byte_cursor_vtable; - input_stream->impl = impl; - - impl->original_cursor = *cursor; - impl->current_cursor = *cursor; - - return input_stream; -} - -/* - * file-based input stream - */ -struct aws_input_stream_file_impl { - FILE *file; - bool close_on_clean_up; -}; - -static int s_aws_input_stream_file_seek( - struct aws_input_stream *stream, - aws_off_t offset, - enum aws_stream_seek_basis basis) { - struct aws_input_stream_file_impl *impl = stream->impl; - - int whence = (basis == AWS_SSB_BEGIN) ? SEEK_SET : SEEK_END; - if (aws_fseek(impl->file, offset, whence)) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_file_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { - struct aws_input_stream_file_impl *impl = stream->impl; - - size_t max_read = dest->capacity - dest->len; - size_t actually_read = fread(dest->buffer + dest->len, 1, max_read, impl->file); - if (actually_read == 0) { - if (ferror(impl->file)) { - return aws_raise_error(AWS_IO_STREAM_READ_FAILED); - } - } - - dest->len += actually_read; - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_file_get_status(struct aws_input_stream *stream, struct aws_stream_status *status) { - struct aws_input_stream_file_impl *impl = stream->impl; - - status->is_end_of_stream = feof(impl->file) != 0; - status->is_valid = ferror(impl->file) == 0; - - return AWS_OP_SUCCESS; -} - -static int s_aws_input_stream_file_get_length(struct aws_input_stream *stream, int64_t *length) { - struct aws_input_stream_file_impl *impl = stream->impl; - - return aws_file_get_length(impl->file, length); -} - -static void s_aws_input_stream_file_destroy(struct aws_input_stream *stream) { - struct aws_input_stream_file_impl *impl = stream->impl; - - if (impl->close_on_clean_up && impl->file) { - fclose(impl->file); - } - - aws_mem_release(stream->allocator, stream); -} - -static struct aws_input_stream_vtable s_aws_input_stream_file_vtable = { - .seek = s_aws_input_stream_file_seek, - .read = s_aws_input_stream_file_read, - .get_status = s_aws_input_stream_file_get_status, - .get_length = s_aws_input_stream_file_get_length, - .destroy = s_aws_input_stream_file_destroy}; - -struct aws_input_stream *aws_input_stream_new_from_file(struct aws_allocator *allocator, const char *file_name) { - - struct aws_input_stream *input_stream = NULL; - struct aws_input_stream_file_impl *impl = NULL; - - aws_mem_acquire_many( - allocator, 2, &input_stream, sizeof(struct aws_input_stream), &impl, sizeof(struct aws_input_stream_file_impl)); - - if (!input_stream) { - return NULL; - } - - AWS_ZERO_STRUCT(*input_stream); - AWS_ZERO_STRUCT(*impl); - - input_stream->allocator = allocator; - input_stream->vtable = &s_aws_input_stream_file_vtable; - input_stream->impl = impl; - - impl->file = fopen(file_name, "r"); - if (impl->file == NULL) { - aws_translate_and_raise_io_error(errno); - goto on_error; - } - - impl->close_on_clean_up = true; - - return input_stream; - -on_error: - - aws_input_stream_destroy(input_stream); - - return NULL; -} - -struct aws_input_stream *aws_input_stream_new_from_open_file(struct aws_allocator *allocator, FILE *file) { - struct aws_input_stream *input_stream = NULL; - struct aws_input_stream_file_impl *impl = NULL; - - aws_mem_acquire_many( - allocator, 2, &input_stream, sizeof(struct aws_input_stream), &impl, sizeof(struct aws_input_stream_file_impl)); - - if (!input_stream) { - return NULL; - } - - AWS_ZERO_STRUCT(*input_stream); - AWS_ZERO_STRUCT(*impl); - - input_stream->allocator = allocator; - input_stream->vtable = &s_aws_input_stream_file_vtable; - input_stream->impl = impl; - - impl->file = file; - impl->close_on_clean_up = false; - - return input_stream; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/stream.h> + +#include <aws/io/file_utils.h> + +#include <errno.h> + +#if _MSC_VER +# pragma warning(disable : 4996) /* fopen */ +#endif + +int aws_input_stream_seek(struct aws_input_stream *stream, aws_off_t offset, enum aws_stream_seek_basis basis) { + AWS_ASSERT(stream && stream->vtable && stream->vtable->seek); + + return stream->vtable->seek(stream, offset, basis); +} + +int aws_input_stream_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { + AWS_ASSERT(stream && stream->vtable && stream->vtable->read); + AWS_ASSERT(dest); + AWS_ASSERT(dest->len <= dest->capacity); + + /* Deal with this edge case here, instead of relying on every implementation to do it right. */ + if (dest->capacity == dest->len) { + return AWS_OP_SUCCESS; + } + + /* Prevent implementations from accidentally overwriting existing data in the buffer. + * Hand them a "safe" buffer that starts where the existing data ends. */ + const void *safe_buf_start = dest->buffer + dest->len; + const size_t safe_buf_capacity = dest->capacity - dest->len; + struct aws_byte_buf safe_buf = aws_byte_buf_from_empty_array(safe_buf_start, safe_buf_capacity); + + int read_result = stream->vtable->read(stream, &safe_buf); + + /* Ensure the implementation did not commit forbidden acts upon the buffer */ + AWS_FATAL_ASSERT( + (safe_buf.buffer == safe_buf_start) && (safe_buf.capacity == safe_buf_capacity) && + (safe_buf.len <= safe_buf_capacity)); + + if (read_result == AWS_OP_SUCCESS) { + /* Update the actual buffer */ + dest->len += safe_buf.len; + } + + return read_result; +} + +int aws_input_stream_get_status(struct aws_input_stream *stream, struct aws_stream_status *status) { + AWS_ASSERT(stream && stream->vtable && stream->vtable->get_status); + + return stream->vtable->get_status(stream, status); +} + +int aws_input_stream_get_length(struct aws_input_stream *stream, int64_t *out_length) { + AWS_ASSERT(stream && stream->vtable && stream->vtable->get_length); + + return stream->vtable->get_length(stream, out_length); +} + +void aws_input_stream_destroy(struct aws_input_stream *stream) { + if (stream != NULL) { + AWS_ASSERT(stream->vtable && stream->vtable->destroy); + + stream->vtable->destroy(stream); + } +} + +/* + * cursor stream implementation + */ + +struct aws_input_stream_byte_cursor_impl { + struct aws_byte_cursor original_cursor; + struct aws_byte_cursor current_cursor; +}; + +/* + * This is an ugly function that, in the absence of better guidance, is designed to handle all possible combinations of + * aws_off_t (int32_t, int64_t) x all possible combinations of size_t (uint32_t, uint64_t). Whether the anomalous + * combination of int64_t vs. uint32_t is even possible on any real platform is unknown. If size_t ever exceeds 64 bits + * this function will fail badly. + * + * Safety and invariant assumptions are sprinkled via comments. The overall strategy is to cast up to 64 bits and + * perform all arithmetic there, being careful with signed vs. unsigned to prevent bad operations. + * + * Assumption #1: aws_off_t resolves to a signed integer 64 bits or smaller + * Assumption #2: size_t resolves to an unsigned integer 64 bits or smaller + */ + +AWS_STATIC_ASSERT(sizeof(aws_off_t) <= 8); +AWS_STATIC_ASSERT(sizeof(size_t) <= 8); + +static int s_aws_input_stream_byte_cursor_seek( + struct aws_input_stream *stream, + aws_off_t offset, + enum aws_stream_seek_basis basis) { + struct aws_input_stream_byte_cursor_impl *impl = stream->impl; + + uint64_t final_offset = 0; + int64_t checked_offset = offset; /* safe by assumption 1 */ + + switch (basis) { + case AWS_SSB_BEGIN: + /* + * (uint64_t)checked_offset -- safe by virtue of the earlier is-negative check + Assumption 1 + * (uint64_t)impl->original_cursor.len -- safe via assumption 2 + */ + if (checked_offset < 0 || (uint64_t)checked_offset > (uint64_t)impl->original_cursor.len) { + return aws_raise_error(AWS_IO_STREAM_INVALID_SEEK_POSITION); + } + + /* safe because negative offsets were turned into an error */ + final_offset = (uint64_t)checked_offset; + break; + + case AWS_SSB_END: + /* + * -checked_offset -- safe as long checked_offset is not INT64_MIN which was previously checked + * (uint64_t)(-checked_offset) -- safe because (-checked_offset) is positive (and < INT64_MAX < UINT64_MAX) + */ + if (checked_offset > 0 || checked_offset == INT64_MIN || + (uint64_t)(-checked_offset) > (uint64_t)impl->original_cursor.len) { + return aws_raise_error(AWS_IO_STREAM_INVALID_SEEK_POSITION); + } + + /* cases that would make this unsafe became errors with previous conditional */ + final_offset = (uint64_t)impl->original_cursor.len - (uint64_t)(-checked_offset); + break; + } + + /* true because we already validated against (impl->original_cursor.len) which is <= SIZE_MAX */ + AWS_ASSERT(final_offset <= SIZE_MAX); + + /* safe via previous assert */ + size_t final_offset_sz = (size_t)final_offset; + + /* sanity */ + AWS_ASSERT(final_offset_sz <= impl->current_cursor.len); + + impl->current_cursor = impl->original_cursor; + + /* let's skip advance */ + impl->current_cursor.ptr += final_offset_sz; + impl->current_cursor.len -= final_offset_sz; + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_byte_cursor_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { + struct aws_input_stream_byte_cursor_impl *impl = stream->impl; + + size_t actually_read = dest->capacity - dest->len; + if (actually_read > impl->current_cursor.len) { + actually_read = impl->current_cursor.len; + } + + if (!aws_byte_buf_write(dest, impl->current_cursor.ptr, actually_read)) { + return aws_raise_error(AWS_IO_STREAM_READ_FAILED); + } + + aws_byte_cursor_advance(&impl->current_cursor, actually_read); + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_byte_cursor_get_status( + struct aws_input_stream *stream, + struct aws_stream_status *status) { + struct aws_input_stream_byte_cursor_impl *impl = stream->impl; + + status->is_end_of_stream = impl->current_cursor.len == 0; + status->is_valid = true; + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_byte_cursor_get_length(struct aws_input_stream *stream, int64_t *out_length) { + struct aws_input_stream_byte_cursor_impl *impl = stream->impl; + +#if SIZE_MAX > INT64_MAX + size_t length = impl->original_cursor.len; + if (length > INT64_MAX) { + return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED); + } +#endif + + *out_length = (int64_t)impl->original_cursor.len; + + return AWS_OP_SUCCESS; +} + +static void s_aws_input_stream_byte_cursor_destroy(struct aws_input_stream *stream) { + aws_mem_release(stream->allocator, stream); +} + +static struct aws_input_stream_vtable s_aws_input_stream_byte_cursor_vtable = { + .seek = s_aws_input_stream_byte_cursor_seek, + .read = s_aws_input_stream_byte_cursor_read, + .get_status = s_aws_input_stream_byte_cursor_get_status, + .get_length = s_aws_input_stream_byte_cursor_get_length, + .destroy = s_aws_input_stream_byte_cursor_destroy}; + +struct aws_input_stream *aws_input_stream_new_from_cursor( + struct aws_allocator *allocator, + const struct aws_byte_cursor *cursor) { + + struct aws_input_stream *input_stream = NULL; + struct aws_input_stream_byte_cursor_impl *impl = NULL; + + aws_mem_acquire_many( + allocator, + 2, + &input_stream, + sizeof(struct aws_input_stream), + &impl, + sizeof(struct aws_input_stream_byte_cursor_impl)); + + if (!input_stream) { + return NULL; + } + + AWS_ZERO_STRUCT(*input_stream); + AWS_ZERO_STRUCT(*impl); + + input_stream->allocator = allocator; + input_stream->vtable = &s_aws_input_stream_byte_cursor_vtable; + input_stream->impl = impl; + + impl->original_cursor = *cursor; + impl->current_cursor = *cursor; + + return input_stream; +} + +/* + * file-based input stream + */ +struct aws_input_stream_file_impl { + FILE *file; + bool close_on_clean_up; +}; + +static int s_aws_input_stream_file_seek( + struct aws_input_stream *stream, + aws_off_t offset, + enum aws_stream_seek_basis basis) { + struct aws_input_stream_file_impl *impl = stream->impl; + + int whence = (basis == AWS_SSB_BEGIN) ? SEEK_SET : SEEK_END; + if (aws_fseek(impl->file, offset, whence)) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_file_read(struct aws_input_stream *stream, struct aws_byte_buf *dest) { + struct aws_input_stream_file_impl *impl = stream->impl; + + size_t max_read = dest->capacity - dest->len; + size_t actually_read = fread(dest->buffer + dest->len, 1, max_read, impl->file); + if (actually_read == 0) { + if (ferror(impl->file)) { + return aws_raise_error(AWS_IO_STREAM_READ_FAILED); + } + } + + dest->len += actually_read; + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_file_get_status(struct aws_input_stream *stream, struct aws_stream_status *status) { + struct aws_input_stream_file_impl *impl = stream->impl; + + status->is_end_of_stream = feof(impl->file) != 0; + status->is_valid = ferror(impl->file) == 0; + + return AWS_OP_SUCCESS; +} + +static int s_aws_input_stream_file_get_length(struct aws_input_stream *stream, int64_t *length) { + struct aws_input_stream_file_impl *impl = stream->impl; + + return aws_file_get_length(impl->file, length); +} + +static void s_aws_input_stream_file_destroy(struct aws_input_stream *stream) { + struct aws_input_stream_file_impl *impl = stream->impl; + + if (impl->close_on_clean_up && impl->file) { + fclose(impl->file); + } + + aws_mem_release(stream->allocator, stream); +} + +static struct aws_input_stream_vtable s_aws_input_stream_file_vtable = { + .seek = s_aws_input_stream_file_seek, + .read = s_aws_input_stream_file_read, + .get_status = s_aws_input_stream_file_get_status, + .get_length = s_aws_input_stream_file_get_length, + .destroy = s_aws_input_stream_file_destroy}; + +struct aws_input_stream *aws_input_stream_new_from_file(struct aws_allocator *allocator, const char *file_name) { + + struct aws_input_stream *input_stream = NULL; + struct aws_input_stream_file_impl *impl = NULL; + + aws_mem_acquire_many( + allocator, 2, &input_stream, sizeof(struct aws_input_stream), &impl, sizeof(struct aws_input_stream_file_impl)); + + if (!input_stream) { + return NULL; + } + + AWS_ZERO_STRUCT(*input_stream); + AWS_ZERO_STRUCT(*impl); + + input_stream->allocator = allocator; + input_stream->vtable = &s_aws_input_stream_file_vtable; + input_stream->impl = impl; + + impl->file = fopen(file_name, "r"); + if (impl->file == NULL) { + aws_translate_and_raise_io_error(errno); + goto on_error; + } + + impl->close_on_clean_up = true; + + return input_stream; + +on_error: + + aws_input_stream_destroy(input_stream); + + return NULL; +} + +struct aws_input_stream *aws_input_stream_new_from_open_file(struct aws_allocator *allocator, FILE *file) { + struct aws_input_stream *input_stream = NULL; + struct aws_input_stream_file_impl *impl = NULL; + + aws_mem_acquire_many( + allocator, 2, &input_stream, sizeof(struct aws_input_stream), &impl, sizeof(struct aws_input_stream_file_impl)); + + if (!input_stream) { + return NULL; + } + + AWS_ZERO_STRUCT(*input_stream); + AWS_ZERO_STRUCT(*impl); + + input_stream->allocator = allocator; + input_stream->vtable = &s_aws_input_stream_file_vtable; + input_stream->impl = impl; + + impl->file = file; + impl->close_on_clean_up = false; + + return input_stream; +} diff --git a/contrib/restricted/aws/aws-c-io/source/tls_channel_handler.c b/contrib/restricted/aws/aws-c-io/source/tls_channel_handler.c index bf0f3be9f2..4892f1a46d 100644 --- a/contrib/restricted/aws/aws-c-io/source/tls_channel_handler.c +++ b/contrib/restricted/aws/aws-c-io/source/tls_channel_handler.c @@ -1,463 +1,463 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/channel.h> -#include <aws/io/file_utils.h> -#include <aws/io/logging.h> -#include <aws/io/tls_channel_handler.h> - -#define AWS_DEFAULT_TLS_TIMEOUT_MS 10000 - -#include <aws/common/string.h> - -void aws_tls_ctx_options_init_default_client(struct aws_tls_ctx_options *options, struct aws_allocator *allocator) { - AWS_ZERO_STRUCT(*options); - options->allocator = allocator; - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; - options->verify_peer = true; - options->max_fragment_size = g_aws_channel_max_fragment_size; -} - -void aws_tls_ctx_options_clean_up(struct aws_tls_ctx_options *options) { - if (options->ca_file.len) { - aws_byte_buf_clean_up(&options->ca_file); - } - - if (options->ca_path) { - aws_string_destroy(options->ca_path); - } - - if (options->certificate.len) { - aws_byte_buf_clean_up(&options->certificate); - } - - if (options->private_key.len) { - aws_byte_buf_clean_up_secure(&options->private_key); - } - -#ifdef __APPLE__ - if (options->pkcs12.len) { - aws_byte_buf_clean_up_secure(&options->pkcs12); - } - - if (options->pkcs12_password.len) { - aws_byte_buf_clean_up_secure(&options->pkcs12_password); - } -#endif - - if (options->alpn_list) { - aws_string_destroy(options->alpn_list); - } - - AWS_ZERO_STRUCT(*options); -} - -static int s_load_null_terminated_buffer_from_cursor( - struct aws_byte_buf *load_into, - struct aws_allocator *allocator, - const struct aws_byte_cursor *from) { - if (from->ptr[from->len - 1] == 0) { - if (aws_byte_buf_init_copy_from_cursor(load_into, allocator, *from)) { - return AWS_OP_ERR; - } - - load_into->len -= 1; - } else { - if (aws_byte_buf_init(load_into, allocator, from->len + 1)) { - return AWS_OP_ERR; - } - - memcpy(load_into->buffer, from->ptr, from->len); - load_into->buffer[from->len] = 0; - load_into->len = from->len; - } - - return AWS_OP_SUCCESS; -} - -#if !defined(AWS_OS_IOS) - -int aws_tls_ctx_options_init_client_mtls( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const struct aws_byte_cursor *cert, - const struct aws_byte_cursor *pkey) { - AWS_ZERO_STRUCT(*options); - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; - options->verify_peer = true; - options->allocator = allocator; - options->max_fragment_size = g_aws_channel_max_fragment_size; - - /* s2n relies on null terminated c_strings, so we need to make sure we're properly - * terminated, but we don't want length to reflect the terminator because - * Apple and Windows will fail hard if you use a null terminator. */ - if (s_load_null_terminated_buffer_from_cursor(&options->certificate, allocator, cert)) { - return AWS_OP_ERR; - } - - if (s_load_null_terminated_buffer_from_cursor(&options->private_key, allocator, pkey)) { - aws_byte_buf_clean_up(&options->certificate); - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_init_client_mtls_from_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *cert_path, - const char *pkey_path) { - AWS_ZERO_STRUCT(*options); - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; - options->verify_peer = true; - options->allocator = allocator; - options->max_fragment_size = g_aws_channel_max_fragment_size; - - if (aws_byte_buf_init_from_file(&options->certificate, allocator, cert_path)) { - return AWS_OP_ERR; - } - - if (aws_byte_buf_init_from_file(&options->private_key, allocator, pkey_path)) { - aws_byte_buf_clean_up(&options->certificate); - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -#endif /* AWS_OS_IOS */ - -#ifdef _WIN32 -void aws_tls_ctx_options_init_client_mtls_from_system_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *cert_reg_path) { - AWS_ZERO_STRUCT(*options); - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->verify_peer = true; - options->allocator = allocator; - options->max_fragment_size = g_aws_channel_max_fragment_size; - options->system_certificate_path = cert_reg_path; -} - -void aws_tls_ctx_options_init_default_server_from_system_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *cert_reg_path) { - aws_tls_ctx_options_init_client_mtls_from_system_path(options, allocator, cert_reg_path); - options->verify_peer = false; -} -#endif /* _WIN32 */ - -#ifdef __APPLE__ -int aws_tls_ctx_options_init_client_mtls_pkcs12_from_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *pkcs12_path, - struct aws_byte_cursor *pkcs_pwd) { - AWS_ZERO_STRUCT(*options); - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->verify_peer = true; - options->allocator = allocator; - options->max_fragment_size = g_aws_channel_max_fragment_size; - - if (aws_byte_buf_init_from_file(&options->pkcs12, allocator, pkcs12_path)) { - return AWS_OP_ERR; - } - - if (aws_byte_buf_init_copy_from_cursor(&options->pkcs12_password, allocator, *pkcs_pwd)) { - aws_byte_buf_clean_up_secure(&options->pkcs12); - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_init_client_mtls_pkcs12( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - struct aws_byte_cursor *pkcs12, - struct aws_byte_cursor *pkcs_pwd) { - AWS_ZERO_STRUCT(*options); - options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; - options->verify_peer = true; - options->allocator = allocator; - options->max_fragment_size = g_aws_channel_max_fragment_size; - - if (s_load_null_terminated_buffer_from_cursor(&options->pkcs12, allocator, pkcs12)) { - return AWS_OP_ERR; - } - - if (s_load_null_terminated_buffer_from_cursor(&options->pkcs12_password, allocator, pkcs_pwd)) { - aws_byte_buf_clean_up_secure(&options->pkcs12); - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_init_server_pkcs12_from_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *pkcs12_path, - struct aws_byte_cursor *pkcs_password) { - if (aws_tls_ctx_options_init_client_mtls_pkcs12_from_path(options, allocator, pkcs12_path, pkcs_password)) { - return AWS_OP_ERR; - } - - options->verify_peer = false; - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_init_server_pkcs12( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - struct aws_byte_cursor *pkcs12, - struct aws_byte_cursor *pkcs_password) { - if (aws_tls_ctx_options_init_client_mtls_pkcs12(options, allocator, pkcs12, pkcs_password)) { - return AWS_OP_ERR; - } - - options->verify_peer = false; - return AWS_OP_SUCCESS; -} - -#endif /* __APPLE__ */ - -#if !defined(AWS_OS_IOS) - -int aws_tls_ctx_options_init_default_server_from_path( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - const char *cert_path, - const char *pkey_path) { - if (aws_tls_ctx_options_init_client_mtls_from_path(options, allocator, cert_path, pkey_path)) { - return AWS_OP_ERR; - } - - options->verify_peer = false; - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_init_default_server( - struct aws_tls_ctx_options *options, - struct aws_allocator *allocator, - struct aws_byte_cursor *cert, - struct aws_byte_cursor *pkey) { - if (aws_tls_ctx_options_init_client_mtls(options, allocator, cert, pkey)) { - return AWS_OP_ERR; - } - - options->verify_peer = false; - return AWS_OP_SUCCESS; -} - -#endif /* AWS_OS_IOS */ - -int aws_tls_ctx_options_set_alpn_list(struct aws_tls_ctx_options *options, const char *alpn_list) { - options->alpn_list = aws_string_new_from_c_str(options->allocator, alpn_list); - if (!options->alpn_list) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -void aws_tls_ctx_options_set_verify_peer(struct aws_tls_ctx_options *options, bool verify_peer) { - options->verify_peer = verify_peer; -} - -void aws_tls_ctx_options_set_minimum_tls_version( - struct aws_tls_ctx_options *options, - enum aws_tls_versions minimum_tls_version) { - options->minimum_tls_version = minimum_tls_version; -} - -int aws_tls_ctx_options_override_default_trust_store_from_path( - struct aws_tls_ctx_options *options, - const char *ca_path, - const char *ca_file) { - - if (ca_path) { - options->ca_path = aws_string_new_from_c_str(options->allocator, ca_path); - if (!options->ca_path) { - return AWS_OP_ERR; - } - } - - if (ca_file) { - if (aws_byte_buf_init_from_file(&options->ca_file, options->allocator, ca_file)) { - return AWS_OP_ERR; - } - } - - return AWS_OP_SUCCESS; -} - -int aws_tls_ctx_options_override_default_trust_store( - struct aws_tls_ctx_options *options, - const struct aws_byte_cursor *ca_file) { - - /* s2n relies on null terminated c_strings, so we need to make sure we're properly - * terminated, but we don't want length to reflect the terminator because - * Apple and Windows will fail hard if you use a null terminator. */ - if (s_load_null_terminated_buffer_from_cursor(&options->ca_file, options->allocator, ca_file)) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -void aws_tls_connection_options_init_from_ctx( - struct aws_tls_connection_options *conn_options, - struct aws_tls_ctx *ctx) { - AWS_ZERO_STRUCT(*conn_options); - /* the assumption here, is that if it was set in the context, we WANT it to be NULL here unless it's different. - * so only set verify peer at this point. */ - conn_options->ctx = aws_tls_ctx_acquire(ctx); - - conn_options->timeout_ms = AWS_DEFAULT_TLS_TIMEOUT_MS; -} - -int aws_tls_connection_options_copy( - struct aws_tls_connection_options *to, - const struct aws_tls_connection_options *from) { - /* copy everything copyable over, then override the rest with deep copies. */ - *to = *from; - - to->ctx = aws_tls_ctx_acquire(from->ctx); - - if (from->alpn_list) { - to->alpn_list = aws_string_new_from_string(from->alpn_list->allocator, from->alpn_list); - - if (!to->alpn_list) { - return AWS_OP_ERR; - } - } - - if (from->server_name) { - to->server_name = aws_string_new_from_string(from->server_name->allocator, from->server_name); - - if (!to->server_name) { - aws_string_destroy(to->server_name); - return AWS_OP_ERR; - } - } - - return AWS_OP_SUCCESS; -} - -void aws_tls_connection_options_clean_up(struct aws_tls_connection_options *connection_options) { - aws_tls_ctx_release(connection_options->ctx); - - if (connection_options->alpn_list) { - aws_string_destroy(connection_options->alpn_list); - } - - if (connection_options->server_name) { - aws_string_destroy(connection_options->server_name); - } - - AWS_ZERO_STRUCT(*connection_options); -} - -void aws_tls_connection_options_set_callbacks( - struct aws_tls_connection_options *conn_options, - aws_tls_on_negotiation_result_fn *on_negotiation_result, - aws_tls_on_data_read_fn *on_data_read, - aws_tls_on_error_fn *on_error, - void *user_data) { - conn_options->on_negotiation_result = on_negotiation_result; - conn_options->on_data_read = on_data_read; - conn_options->on_error = on_error; - conn_options->user_data = user_data; -} - -int aws_tls_connection_options_set_server_name( - struct aws_tls_connection_options *conn_options, - struct aws_allocator *allocator, - struct aws_byte_cursor *server_name) { - conn_options->server_name = aws_string_new_from_cursor(allocator, server_name); - if (!conn_options->server_name) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_tls_connection_options_set_alpn_list( - struct aws_tls_connection_options *conn_options, - struct aws_allocator *allocator, - const char *alpn_list) { - - conn_options->alpn_list = aws_string_new_from_c_str(allocator, alpn_list); - if (!conn_options->alpn_list) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -int aws_channel_setup_client_tls( - struct aws_channel_slot *right_of_slot, - struct aws_tls_connection_options *tls_options) { - - AWS_FATAL_ASSERT(right_of_slot != NULL); - struct aws_channel *channel = right_of_slot->channel; - struct aws_allocator *allocator = right_of_slot->alloc; - - struct aws_channel_slot *tls_slot = aws_channel_slot_new(channel); - - /* as far as cleanup goes, since this stuff is being added to a channel, the caller will free this memory - when they clean up the channel. */ - if (!tls_slot) { - return AWS_OP_ERR; - } - - struct aws_channel_handler *tls_handler = aws_tls_client_handler_new(allocator, tls_options, tls_slot); - if (!tls_handler) { - aws_mem_release(allocator, tls_slot); - return AWS_OP_ERR; - } - - /* - * From here on out, channel shutdown will handle slot/handler cleanup - */ - aws_channel_slot_insert_right(right_of_slot, tls_slot); - AWS_LOGF_TRACE( - AWS_LS_IO_CHANNEL, - "id=%p: Setting up client TLS with handler %p on slot %p", - (void *)channel, - (void *)tls_handler, - (void *)tls_slot); - - if (aws_channel_slot_set_handler(tls_slot, tls_handler) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - - if (aws_tls_client_handler_start_negotiation(tls_handler) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - - return AWS_OP_SUCCESS; -} - -struct aws_tls_ctx *aws_tls_ctx_acquire(struct aws_tls_ctx *ctx) { - if (ctx != NULL) { - aws_ref_count_acquire(&ctx->ref_count); - } - - return ctx; -} - -void aws_tls_ctx_release(struct aws_tls_ctx *ctx) { - if (ctx != NULL) { - aws_ref_count_release(&ctx->ref_count); - } -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/channel.h> +#include <aws/io/file_utils.h> +#include <aws/io/logging.h> +#include <aws/io/tls_channel_handler.h> + +#define AWS_DEFAULT_TLS_TIMEOUT_MS 10000 + +#include <aws/common/string.h> + +void aws_tls_ctx_options_init_default_client(struct aws_tls_ctx_options *options, struct aws_allocator *allocator) { + AWS_ZERO_STRUCT(*options); + options->allocator = allocator; + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; + options->verify_peer = true; + options->max_fragment_size = g_aws_channel_max_fragment_size; +} + +void aws_tls_ctx_options_clean_up(struct aws_tls_ctx_options *options) { + if (options->ca_file.len) { + aws_byte_buf_clean_up(&options->ca_file); + } + + if (options->ca_path) { + aws_string_destroy(options->ca_path); + } + + if (options->certificate.len) { + aws_byte_buf_clean_up(&options->certificate); + } + + if (options->private_key.len) { + aws_byte_buf_clean_up_secure(&options->private_key); + } + +#ifdef __APPLE__ + if (options->pkcs12.len) { + aws_byte_buf_clean_up_secure(&options->pkcs12); + } + + if (options->pkcs12_password.len) { + aws_byte_buf_clean_up_secure(&options->pkcs12_password); + } +#endif + + if (options->alpn_list) { + aws_string_destroy(options->alpn_list); + } + + AWS_ZERO_STRUCT(*options); +} + +static int s_load_null_terminated_buffer_from_cursor( + struct aws_byte_buf *load_into, + struct aws_allocator *allocator, + const struct aws_byte_cursor *from) { + if (from->ptr[from->len - 1] == 0) { + if (aws_byte_buf_init_copy_from_cursor(load_into, allocator, *from)) { + return AWS_OP_ERR; + } + + load_into->len -= 1; + } else { + if (aws_byte_buf_init(load_into, allocator, from->len + 1)) { + return AWS_OP_ERR; + } + + memcpy(load_into->buffer, from->ptr, from->len); + load_into->buffer[from->len] = 0; + load_into->len = from->len; + } + + return AWS_OP_SUCCESS; +} + +#if !defined(AWS_OS_IOS) + +int aws_tls_ctx_options_init_client_mtls( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const struct aws_byte_cursor *cert, + const struct aws_byte_cursor *pkey) { + AWS_ZERO_STRUCT(*options); + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; + options->verify_peer = true; + options->allocator = allocator; + options->max_fragment_size = g_aws_channel_max_fragment_size; + + /* s2n relies on null terminated c_strings, so we need to make sure we're properly + * terminated, but we don't want length to reflect the terminator because + * Apple and Windows will fail hard if you use a null terminator. */ + if (s_load_null_terminated_buffer_from_cursor(&options->certificate, allocator, cert)) { + return AWS_OP_ERR; + } + + if (s_load_null_terminated_buffer_from_cursor(&options->private_key, allocator, pkey)) { + aws_byte_buf_clean_up(&options->certificate); + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_init_client_mtls_from_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *cert_path, + const char *pkey_path) { + AWS_ZERO_STRUCT(*options); + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->cipher_pref = AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT; + options->verify_peer = true; + options->allocator = allocator; + options->max_fragment_size = g_aws_channel_max_fragment_size; + + if (aws_byte_buf_init_from_file(&options->certificate, allocator, cert_path)) { + return AWS_OP_ERR; + } + + if (aws_byte_buf_init_from_file(&options->private_key, allocator, pkey_path)) { + aws_byte_buf_clean_up(&options->certificate); + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +#endif /* AWS_OS_IOS */ + +#ifdef _WIN32 +void aws_tls_ctx_options_init_client_mtls_from_system_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *cert_reg_path) { + AWS_ZERO_STRUCT(*options); + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->verify_peer = true; + options->allocator = allocator; + options->max_fragment_size = g_aws_channel_max_fragment_size; + options->system_certificate_path = cert_reg_path; +} + +void aws_tls_ctx_options_init_default_server_from_system_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *cert_reg_path) { + aws_tls_ctx_options_init_client_mtls_from_system_path(options, allocator, cert_reg_path); + options->verify_peer = false; +} +#endif /* _WIN32 */ + +#ifdef __APPLE__ +int aws_tls_ctx_options_init_client_mtls_pkcs12_from_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *pkcs12_path, + struct aws_byte_cursor *pkcs_pwd) { + AWS_ZERO_STRUCT(*options); + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->verify_peer = true; + options->allocator = allocator; + options->max_fragment_size = g_aws_channel_max_fragment_size; + + if (aws_byte_buf_init_from_file(&options->pkcs12, allocator, pkcs12_path)) { + return AWS_OP_ERR; + } + + if (aws_byte_buf_init_copy_from_cursor(&options->pkcs12_password, allocator, *pkcs_pwd)) { + aws_byte_buf_clean_up_secure(&options->pkcs12); + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_init_client_mtls_pkcs12( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + struct aws_byte_cursor *pkcs12, + struct aws_byte_cursor *pkcs_pwd) { + AWS_ZERO_STRUCT(*options); + options->minimum_tls_version = AWS_IO_TLS_VER_SYS_DEFAULTS; + options->verify_peer = true; + options->allocator = allocator; + options->max_fragment_size = g_aws_channel_max_fragment_size; + + if (s_load_null_terminated_buffer_from_cursor(&options->pkcs12, allocator, pkcs12)) { + return AWS_OP_ERR; + } + + if (s_load_null_terminated_buffer_from_cursor(&options->pkcs12_password, allocator, pkcs_pwd)) { + aws_byte_buf_clean_up_secure(&options->pkcs12); + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_init_server_pkcs12_from_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *pkcs12_path, + struct aws_byte_cursor *pkcs_password) { + if (aws_tls_ctx_options_init_client_mtls_pkcs12_from_path(options, allocator, pkcs12_path, pkcs_password)) { + return AWS_OP_ERR; + } + + options->verify_peer = false; + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_init_server_pkcs12( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + struct aws_byte_cursor *pkcs12, + struct aws_byte_cursor *pkcs_password) { + if (aws_tls_ctx_options_init_client_mtls_pkcs12(options, allocator, pkcs12, pkcs_password)) { + return AWS_OP_ERR; + } + + options->verify_peer = false; + return AWS_OP_SUCCESS; +} + +#endif /* __APPLE__ */ + +#if !defined(AWS_OS_IOS) + +int aws_tls_ctx_options_init_default_server_from_path( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + const char *cert_path, + const char *pkey_path) { + if (aws_tls_ctx_options_init_client_mtls_from_path(options, allocator, cert_path, pkey_path)) { + return AWS_OP_ERR; + } + + options->verify_peer = false; + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_init_default_server( + struct aws_tls_ctx_options *options, + struct aws_allocator *allocator, + struct aws_byte_cursor *cert, + struct aws_byte_cursor *pkey) { + if (aws_tls_ctx_options_init_client_mtls(options, allocator, cert, pkey)) { + return AWS_OP_ERR; + } + + options->verify_peer = false; + return AWS_OP_SUCCESS; +} + +#endif /* AWS_OS_IOS */ + +int aws_tls_ctx_options_set_alpn_list(struct aws_tls_ctx_options *options, const char *alpn_list) { + options->alpn_list = aws_string_new_from_c_str(options->allocator, alpn_list); + if (!options->alpn_list) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +void aws_tls_ctx_options_set_verify_peer(struct aws_tls_ctx_options *options, bool verify_peer) { + options->verify_peer = verify_peer; +} + +void aws_tls_ctx_options_set_minimum_tls_version( + struct aws_tls_ctx_options *options, + enum aws_tls_versions minimum_tls_version) { + options->minimum_tls_version = minimum_tls_version; +} + +int aws_tls_ctx_options_override_default_trust_store_from_path( + struct aws_tls_ctx_options *options, + const char *ca_path, + const char *ca_file) { + + if (ca_path) { + options->ca_path = aws_string_new_from_c_str(options->allocator, ca_path); + if (!options->ca_path) { + return AWS_OP_ERR; + } + } + + if (ca_file) { + if (aws_byte_buf_init_from_file(&options->ca_file, options->allocator, ca_file)) { + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + +int aws_tls_ctx_options_override_default_trust_store( + struct aws_tls_ctx_options *options, + const struct aws_byte_cursor *ca_file) { + + /* s2n relies on null terminated c_strings, so we need to make sure we're properly + * terminated, but we don't want length to reflect the terminator because + * Apple and Windows will fail hard if you use a null terminator. */ + if (s_load_null_terminated_buffer_from_cursor(&options->ca_file, options->allocator, ca_file)) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +void aws_tls_connection_options_init_from_ctx( + struct aws_tls_connection_options *conn_options, + struct aws_tls_ctx *ctx) { + AWS_ZERO_STRUCT(*conn_options); + /* the assumption here, is that if it was set in the context, we WANT it to be NULL here unless it's different. + * so only set verify peer at this point. */ + conn_options->ctx = aws_tls_ctx_acquire(ctx); + + conn_options->timeout_ms = AWS_DEFAULT_TLS_TIMEOUT_MS; +} + +int aws_tls_connection_options_copy( + struct aws_tls_connection_options *to, + const struct aws_tls_connection_options *from) { + /* copy everything copyable over, then override the rest with deep copies. */ + *to = *from; + + to->ctx = aws_tls_ctx_acquire(from->ctx); + + if (from->alpn_list) { + to->alpn_list = aws_string_new_from_string(from->alpn_list->allocator, from->alpn_list); + + if (!to->alpn_list) { + return AWS_OP_ERR; + } + } + + if (from->server_name) { + to->server_name = aws_string_new_from_string(from->server_name->allocator, from->server_name); + + if (!to->server_name) { + aws_string_destroy(to->server_name); + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + +void aws_tls_connection_options_clean_up(struct aws_tls_connection_options *connection_options) { + aws_tls_ctx_release(connection_options->ctx); + + if (connection_options->alpn_list) { + aws_string_destroy(connection_options->alpn_list); + } + + if (connection_options->server_name) { + aws_string_destroy(connection_options->server_name); + } + + AWS_ZERO_STRUCT(*connection_options); +} + +void aws_tls_connection_options_set_callbacks( + struct aws_tls_connection_options *conn_options, + aws_tls_on_negotiation_result_fn *on_negotiation_result, + aws_tls_on_data_read_fn *on_data_read, + aws_tls_on_error_fn *on_error, + void *user_data) { + conn_options->on_negotiation_result = on_negotiation_result; + conn_options->on_data_read = on_data_read; + conn_options->on_error = on_error; + conn_options->user_data = user_data; +} + +int aws_tls_connection_options_set_server_name( + struct aws_tls_connection_options *conn_options, + struct aws_allocator *allocator, + struct aws_byte_cursor *server_name) { + conn_options->server_name = aws_string_new_from_cursor(allocator, server_name); + if (!conn_options->server_name) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_tls_connection_options_set_alpn_list( + struct aws_tls_connection_options *conn_options, + struct aws_allocator *allocator, + const char *alpn_list) { + + conn_options->alpn_list = aws_string_new_from_c_str(allocator, alpn_list); + if (!conn_options->alpn_list) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +int aws_channel_setup_client_tls( + struct aws_channel_slot *right_of_slot, + struct aws_tls_connection_options *tls_options) { + + AWS_FATAL_ASSERT(right_of_slot != NULL); + struct aws_channel *channel = right_of_slot->channel; + struct aws_allocator *allocator = right_of_slot->alloc; + + struct aws_channel_slot *tls_slot = aws_channel_slot_new(channel); + + /* as far as cleanup goes, since this stuff is being added to a channel, the caller will free this memory + when they clean up the channel. */ + if (!tls_slot) { + return AWS_OP_ERR; + } + + struct aws_channel_handler *tls_handler = aws_tls_client_handler_new(allocator, tls_options, tls_slot); + if (!tls_handler) { + aws_mem_release(allocator, tls_slot); + return AWS_OP_ERR; + } + + /* + * From here on out, channel shutdown will handle slot/handler cleanup + */ + aws_channel_slot_insert_right(right_of_slot, tls_slot); + AWS_LOGF_TRACE( + AWS_LS_IO_CHANNEL, + "id=%p: Setting up client TLS with handler %p on slot %p", + (void *)channel, + (void *)tls_handler, + (void *)tls_slot); + + if (aws_channel_slot_set_handler(tls_slot, tls_handler) != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + + if (aws_tls_client_handler_start_negotiation(tls_handler) != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +struct aws_tls_ctx *aws_tls_ctx_acquire(struct aws_tls_ctx *ctx) { + if (ctx != NULL) { + aws_ref_count_acquire(&ctx->ref_count); + } + + return ctx; +} + +void aws_tls_ctx_release(struct aws_tls_ctx *ctx) { + if (ctx != NULL) { + aws_ref_count_release(&ctx->ref_count); + } +} diff --git a/contrib/restricted/aws/aws-c-io/source/tls_channel_handler_shared.c b/contrib/restricted/aws/aws-c-io/source/tls_channel_handler_shared.c index 0a35e78b67..0cdbfd8e29 100644 --- a/contrib/restricted/aws/aws-c-io/source/tls_channel_handler_shared.c +++ b/contrib/restricted/aws/aws-c-io/source/tls_channel_handler_shared.c @@ -1,65 +1,65 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include <aws/io/private/tls_channel_handler_shared.h> - -#include <aws/common/clock.h> -#include <aws/io/tls_channel_handler.h> - -static void s_tls_timeout_task_fn(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) { - (void)channel_task; - - if (status != AWS_TASK_STATUS_RUN_READY) { - return; - } - - struct aws_tls_channel_handler_shared *tls_handler_shared = arg; - if (tls_handler_shared->stats.handshake_status != AWS_TLS_NEGOTIATION_STATUS_ONGOING) { - return; - } - - struct aws_channel *channel = tls_handler_shared->handler->slot->channel; - aws_channel_shutdown(channel, AWS_IO_TLS_NEGOTIATION_TIMEOUT); -} - -void aws_tls_channel_handler_shared_init( - struct aws_tls_channel_handler_shared *tls_handler_shared, - struct aws_channel_handler *handler, - struct aws_tls_connection_options *options) { - tls_handler_shared->handler = handler; - tls_handler_shared->tls_timeout_ms = options->timeout_ms; - aws_crt_statistics_tls_init(&tls_handler_shared->stats); - aws_channel_task_init(&tls_handler_shared->timeout_task, s_tls_timeout_task_fn, tls_handler_shared, "tls_timeout"); -} - -void aws_tls_channel_handler_shared_clean_up(struct aws_tls_channel_handler_shared *tls_handler_shared) { - (void)tls_handler_shared; -} - -void aws_on_drive_tls_negotiation(struct aws_tls_channel_handler_shared *tls_handler_shared) { - if (tls_handler_shared->stats.handshake_status == AWS_TLS_NEGOTIATION_STATUS_NONE) { - tls_handler_shared->stats.handshake_status = AWS_TLS_NEGOTIATION_STATUS_ONGOING; - - uint64_t now = 0; - aws_channel_current_clock_time(tls_handler_shared->handler->slot->channel, &now); - tls_handler_shared->stats.handshake_start_ns = now; - - if (tls_handler_shared->tls_timeout_ms > 0) { - uint64_t timeout_ns = - now + aws_timestamp_convert( - tls_handler_shared->tls_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); - - aws_channel_schedule_task_future( - tls_handler_shared->handler->slot->channel, &tls_handler_shared->timeout_task, timeout_ns); - } - } -} - -void aws_on_tls_negotiation_completed(struct aws_tls_channel_handler_shared *tls_handler_shared, int error_code) { - tls_handler_shared->stats.handshake_status = - (error_code == AWS_ERROR_SUCCESS) ? AWS_TLS_NEGOTIATION_STATUS_SUCCESS : AWS_TLS_NEGOTIATION_STATUS_FAILURE; - aws_channel_current_clock_time( - tls_handler_shared->handler->slot->channel, &tls_handler_shared->stats.handshake_end_ns); -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/io/private/tls_channel_handler_shared.h> + +#include <aws/common/clock.h> +#include <aws/io/tls_channel_handler.h> + +static void s_tls_timeout_task_fn(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) { + (void)channel_task; + + if (status != AWS_TASK_STATUS_RUN_READY) { + return; + } + + struct aws_tls_channel_handler_shared *tls_handler_shared = arg; + if (tls_handler_shared->stats.handshake_status != AWS_TLS_NEGOTIATION_STATUS_ONGOING) { + return; + } + + struct aws_channel *channel = tls_handler_shared->handler->slot->channel; + aws_channel_shutdown(channel, AWS_IO_TLS_NEGOTIATION_TIMEOUT); +} + +void aws_tls_channel_handler_shared_init( + struct aws_tls_channel_handler_shared *tls_handler_shared, + struct aws_channel_handler *handler, + struct aws_tls_connection_options *options) { + tls_handler_shared->handler = handler; + tls_handler_shared->tls_timeout_ms = options->timeout_ms; + aws_crt_statistics_tls_init(&tls_handler_shared->stats); + aws_channel_task_init(&tls_handler_shared->timeout_task, s_tls_timeout_task_fn, tls_handler_shared, "tls_timeout"); +} + +void aws_tls_channel_handler_shared_clean_up(struct aws_tls_channel_handler_shared *tls_handler_shared) { + (void)tls_handler_shared; +} + +void aws_on_drive_tls_negotiation(struct aws_tls_channel_handler_shared *tls_handler_shared) { + if (tls_handler_shared->stats.handshake_status == AWS_TLS_NEGOTIATION_STATUS_NONE) { + tls_handler_shared->stats.handshake_status = AWS_TLS_NEGOTIATION_STATUS_ONGOING; + + uint64_t now = 0; + aws_channel_current_clock_time(tls_handler_shared->handler->slot->channel, &now); + tls_handler_shared->stats.handshake_start_ns = now; + + if (tls_handler_shared->tls_timeout_ms > 0) { + uint64_t timeout_ns = + now + aws_timestamp_convert( + tls_handler_shared->tls_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); + + aws_channel_schedule_task_future( + tls_handler_shared->handler->slot->channel, &tls_handler_shared->timeout_task, timeout_ns); + } + } +} + +void aws_on_tls_negotiation_completed(struct aws_tls_channel_handler_shared *tls_handler_shared, int error_code) { + tls_handler_shared->stats.handshake_status = + (error_code == AWS_ERROR_SUCCESS) ? AWS_TLS_NEGOTIATION_STATUS_SUCCESS : AWS_TLS_NEGOTIATION_STATUS_FAILURE; + aws_channel_current_clock_time( + tls_handler_shared->handler->slot->channel, &tls_handler_shared->stats.handshake_end_ns); +} diff --git a/contrib/restricted/aws/aws-c-io/source/uri.c b/contrib/restricted/aws/aws-c-io/source/uri.c index bb0cf01ae4..313acb79f6 100644 --- a/contrib/restricted/aws/aws-c-io/source/uri.c +++ b/contrib/restricted/aws/aws-c-io/source/uri.c @@ -1,562 +1,562 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#include <aws/io/uri.h> - -#include <aws/common/common.h> - -#include <ctype.h> -#include <inttypes.h> -#include <stdio.h> -#include <string.h> - -#if _MSC_VER -# pragma warning(disable : 4221) /* aggregate initializer using local variable addresses */ -# pragma warning(disable : 4204) /* non-constant aggregate initializer */ -# pragma warning(disable : 4996) /* sprintf */ -#endif - -enum parser_state { - ON_SCHEME, - ON_AUTHORITY, - ON_PATH, - ON_QUERY_STRING, - FINISHED, - ERROR, -}; - -struct uri_parser { - struct aws_uri *uri; - enum parser_state state; -}; - -typedef void(parse_fn)(struct uri_parser *parser, struct aws_byte_cursor *str); - -static void s_parse_scheme(struct uri_parser *parser, struct aws_byte_cursor *str); -static void s_parse_authority(struct uri_parser *parser, struct aws_byte_cursor *str); -static void s_parse_path(struct uri_parser *parser, struct aws_byte_cursor *str); -static void s_parse_query_string(struct uri_parser *parser, struct aws_byte_cursor *str); - -static parse_fn *s_states[] = { - [ON_SCHEME] = s_parse_scheme, - [ON_AUTHORITY] = s_parse_authority, - [ON_PATH] = s_parse_path, - [ON_QUERY_STRING] = s_parse_query_string, -}; - -static int s_init_from_uri_str(struct aws_uri *uri) { - struct uri_parser parser = { - .state = ON_SCHEME, - .uri = uri, - }; - - struct aws_byte_cursor uri_cur = aws_byte_cursor_from_buf(&uri->uri_str); - - while (parser.state < FINISHED) { - s_states[parser.state](&parser, &uri_cur); - } - - /* Each state function sets the next state, if something goes wrong it sets it to ERROR which is > FINISHED */ - if (parser.state == FINISHED) { - return AWS_OP_SUCCESS; - } - - aws_byte_buf_clean_up(&uri->uri_str); - AWS_ZERO_STRUCT(*uri); - return AWS_OP_ERR; -} - -int aws_uri_init_parse(struct aws_uri *uri, struct aws_allocator *allocator, const struct aws_byte_cursor *uri_str) { - AWS_ZERO_STRUCT(*uri); - uri->self_size = sizeof(struct aws_uri); - uri->allocator = allocator; - - if (aws_byte_buf_init_copy_from_cursor(&uri->uri_str, allocator, *uri_str)) { - return AWS_OP_ERR; - } - - return s_init_from_uri_str(uri); -} - -int aws_uri_init_from_builder_options( - struct aws_uri *uri, - struct aws_allocator *allocator, - struct aws_uri_builder_options *options) { - - AWS_ZERO_STRUCT(*uri); - - if (options->query_string.len && options->query_params) { - return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - } - - uri->self_size = sizeof(struct aws_uri); - uri->allocator = allocator; - - size_t buffer_size = 0; - if (options->scheme.len) { - /* 3 for :// */ - buffer_size += options->scheme.len + 3; - } - - buffer_size += options->host_name.len; - - if (options->port) { - /* max strlen of a 16 bit integer is 5 */ - buffer_size += 6; - } - - buffer_size += options->path.len; - - if (options->query_params) { - size_t query_len = aws_array_list_length(options->query_params); - if (query_len) { - /* for the '?' */ - buffer_size += 1; - for (size_t i = 0; i < query_len; ++i) { - struct aws_uri_param *uri_param_ptr = NULL; - aws_array_list_get_at_ptr(options->query_params, (void **)&uri_param_ptr, i); - /* 2 == 1 for '&' and 1 for '='. who cares if we over-allocate a little? */ - buffer_size += uri_param_ptr->key.len + uri_param_ptr->value.len + 2; - } - } - } else if (options->query_string.len) { - /* for the '?' */ - buffer_size += 1; - buffer_size += options->query_string.len; - } - - if (aws_byte_buf_init(&uri->uri_str, allocator, buffer_size)) { - return AWS_OP_ERR; - } - - uri->uri_str.len = 0; - if (options->scheme.len) { - aws_byte_buf_append(&uri->uri_str, &options->scheme); - struct aws_byte_cursor scheme_app = aws_byte_cursor_from_c_str("://"); - aws_byte_buf_append(&uri->uri_str, &scheme_app); - } - - aws_byte_buf_append(&uri->uri_str, &options->host_name); - - struct aws_byte_cursor port_app = aws_byte_cursor_from_c_str(":"); - if (options->port) { - aws_byte_buf_append(&uri->uri_str, &port_app); - char port_arr[6] = {0}; - sprintf(port_arr, "%" PRIu16, options->port); - struct aws_byte_cursor port_csr = aws_byte_cursor_from_c_str(port_arr); - aws_byte_buf_append(&uri->uri_str, &port_csr); - } - - aws_byte_buf_append(&uri->uri_str, &options->path); - - struct aws_byte_cursor query_app = aws_byte_cursor_from_c_str("?"); - - if (options->query_params) { - struct aws_byte_cursor query_param_app = aws_byte_cursor_from_c_str("&"); - struct aws_byte_cursor key_value_delim = aws_byte_cursor_from_c_str("="); - - aws_byte_buf_append(&uri->uri_str, &query_app); - size_t query_len = aws_array_list_length(options->query_params); - for (size_t i = 0; i < query_len; ++i) { - struct aws_uri_param *uri_param_ptr = NULL; - aws_array_list_get_at_ptr(options->query_params, (void **)&uri_param_ptr, i); - aws_byte_buf_append(&uri->uri_str, &uri_param_ptr->key); - aws_byte_buf_append(&uri->uri_str, &key_value_delim); - aws_byte_buf_append(&uri->uri_str, &uri_param_ptr->value); - - if (i < query_len - 1) { - aws_byte_buf_append(&uri->uri_str, &query_param_app); - } - } - } else if (options->query_string.len) { - aws_byte_buf_append(&uri->uri_str, &query_app); - aws_byte_buf_append(&uri->uri_str, &options->query_string); - } - - return s_init_from_uri_str(uri); -} - -void aws_uri_clean_up(struct aws_uri *uri) { - if (uri->uri_str.allocator) { - aws_byte_buf_clean_up(&uri->uri_str); - } - AWS_ZERO_STRUCT(*uri); -} - -const struct aws_byte_cursor *aws_uri_scheme(const struct aws_uri *uri) { - return &uri->scheme; -} - -const struct aws_byte_cursor *aws_uri_authority(const struct aws_uri *uri) { - return &uri->authority; -} - -const struct aws_byte_cursor *aws_uri_path(const struct aws_uri *uri) { - return &uri->path; -} - -const struct aws_byte_cursor *aws_uri_query_string(const struct aws_uri *uri) { - return &uri->query_string; -} - -const struct aws_byte_cursor *aws_uri_path_and_query(const struct aws_uri *uri) { - return &uri->path_and_query; -} - -const struct aws_byte_cursor *aws_uri_host_name(const struct aws_uri *uri) { - return &uri->host_name; -} - -uint16_t aws_uri_port(const struct aws_uri *uri) { - return uri->port; -} - -bool aws_uri_query_string_next_param(const struct aws_uri *uri, struct aws_uri_param *param) { - /* If param is zeroed, then this is the first run. */ - bool first_run = param->value.ptr == NULL; - - /* aws_byte_cursor_next_split() is used to iterate over params in the query string. - * It takes an in/out substring arg similar to how this function works */ - struct aws_byte_cursor substr; - if (first_run) { - /* substring must be zeroed to start */ - AWS_ZERO_STRUCT(substr); - } else { - /* re-assemble substring which contained key and value */ - substr.ptr = param->key.ptr; - substr.len = (param->value.ptr - param->key.ptr) + param->value.len; - } - - /* The do-while is to skip over any empty substrings */ - do { - if (!aws_byte_cursor_next_split(&uri->query_string, '&', &substr)) { - /* no more splits, done iterating */ - return false; - } - } while (substr.len == 0); - - uint8_t *delim = memchr(substr.ptr, '=', substr.len); - if (delim) { - param->key.ptr = substr.ptr; - param->key.len = delim - substr.ptr; - param->value.ptr = delim + 1; - param->value.len = substr.len - param->key.len - 1; - } else { - /* no '=', key gets substring, value is blank */ - param->key = substr; - param->value.ptr = substr.ptr + substr.len; - param->value.len = 0; - } - - return true; -} - -int aws_uri_query_string_params(const struct aws_uri *uri, struct aws_array_list *out_params) { - struct aws_uri_param param; - AWS_ZERO_STRUCT(param); - while (aws_uri_query_string_next_param(uri, ¶m)) { - if (aws_array_list_push_back(out_params, ¶m)) { - return AWS_OP_ERR; - } - } - - return AWS_OP_SUCCESS; -} - -static void s_parse_scheme(struct uri_parser *parser, struct aws_byte_cursor *str) { - uint8_t *location_of_colon = memchr(str->ptr, ':', str->len); - - if (!location_of_colon) { - parser->state = ON_AUTHORITY; - return; - } - - /* make sure we didn't just pick up the port by mistake */ - if ((size_t)(location_of_colon - str->ptr) < str->len && *(location_of_colon + 1) != '/') { - parser->state = ON_AUTHORITY; - return; - } - - const size_t scheme_len = location_of_colon - str->ptr; - parser->uri->scheme = aws_byte_cursor_advance(str, scheme_len); - - if (str->len < 3 || str->ptr[0] != ':' || str->ptr[1] != '/' || str->ptr[2] != '/') { - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - parser->state = ERROR; - return; - } - - /* advance past the "://" */ - aws_byte_cursor_advance(str, 3); - parser->state = ON_AUTHORITY; -} - -static const char *s_default_path = "/"; - -static void s_parse_authority(struct uri_parser *parser, struct aws_byte_cursor *str) { - uint8_t *location_of_slash = memchr(str->ptr, '/', str->len); - uint8_t *location_of_qmark = memchr(str->ptr, '?', str->len); - - if (!location_of_slash && !location_of_qmark && str->len) { - parser->uri->authority.ptr = str->ptr; - parser->uri->authority.len = str->len; - - parser->uri->path.ptr = (uint8_t *)s_default_path; - parser->uri->path.len = 1; - parser->uri->path_and_query = parser->uri->path; - parser->state = FINISHED; - aws_byte_cursor_advance(str, parser->uri->authority.len); - } else if (!str->len) { - parser->state = ERROR; - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - return; - } else { - uint8_t *end = str->ptr + str->len; - if (location_of_slash) { - parser->state = ON_PATH; - end = location_of_slash; - } else if (location_of_qmark) { - parser->state = ON_QUERY_STRING; - end = location_of_qmark; - } - - parser->uri->authority = aws_byte_cursor_advance(str, end - str->ptr); - } - - struct aws_byte_cursor authority_parse_csr = parser->uri->authority; - - if (authority_parse_csr.len) { - uint8_t *port_delim = memchr(authority_parse_csr.ptr, ':', authority_parse_csr.len); - - if (!port_delim) { - parser->uri->port = 0; - parser->uri->host_name = parser->uri->authority; - return; - } - - parser->uri->host_name.ptr = authority_parse_csr.ptr; - parser->uri->host_name.len = port_delim - authority_parse_csr.ptr; - - size_t port_len = parser->uri->authority.len - parser->uri->host_name.len - 1; - port_delim += 1; - for (size_t i = 0; i < port_len; ++i) { - if (!aws_isdigit(port_delim[i])) { - parser->state = ERROR; - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - return; - } - } - - if (port_len > 5) { - parser->state = ERROR; - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - return; - } - - /* why 6? because the port is a 16-bit unsigned integer*/ - char atoi_buf[6] = {0}; - memcpy(atoi_buf, port_delim, port_len); - int port_int = atoi(atoi_buf); - if (port_int > UINT16_MAX) { - parser->state = ERROR; - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - return; - } - - parser->uri->port = (uint16_t)port_int; - } -} - -static void s_parse_path(struct uri_parser *parser, struct aws_byte_cursor *str) { - parser->uri->path_and_query = *str; - - uint8_t *location_of_q_mark = memchr(str->ptr, '?', str->len); - - if (!location_of_q_mark) { - parser->uri->path.ptr = str->ptr; - parser->uri->path.len = str->len; - parser->state = FINISHED; - aws_byte_cursor_advance(str, parser->uri->path.len); - return; - } - - if (!str->len) { - parser->state = ERROR; - aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - return; - } - - parser->uri->path.ptr = str->ptr; - parser->uri->path.len = location_of_q_mark - str->ptr; - aws_byte_cursor_advance(str, parser->uri->path.len); - parser->state = ON_QUERY_STRING; -} - -static void s_parse_query_string(struct uri_parser *parser, struct aws_byte_cursor *str) { - if (!parser->uri->path_and_query.ptr) { - parser->uri->path_and_query = *str; - } - /* we don't want the '?' character. */ - if (str->len) { - parser->uri->query_string.ptr = str->ptr + 1; - parser->uri->query_string.len = str->len - 1; - } - - aws_byte_cursor_advance(str, parser->uri->query_string.len + 1); - parser->state = FINISHED; -} - -static uint8_t s_to_uppercase_hex(uint8_t value) { - AWS_ASSERT(value < 16); - - if (value < 10) { - return (uint8_t)('0' + value); - } - - return (uint8_t)('A' + value - 10); -} - -typedef void(unchecked_append_canonicalized_character_fn)(struct aws_byte_buf *buffer, uint8_t value); - -/* - * Appends a character or its hex encoding to the buffer. We reserve enough space up front so that - * we can do this with raw pointers rather than multiple function calls/cursors/etc... - * - * This function is for the uri path - */ -static void s_unchecked_append_canonicalized_path_character(struct aws_byte_buf *buffer, uint8_t value) { - AWS_ASSERT(buffer->len + 3 <= buffer->capacity); - - uint8_t *dest_ptr = buffer->buffer + buffer->len; - - if (aws_isalnum(value)) { - ++buffer->len; - *dest_ptr = value; - return; - } - - switch (value) { - case '-': - case '_': - case '.': - case '~': - case '$': - case '&': - case ',': - case '/': - case ':': - case ';': - case '=': - case '@': { - ++buffer->len; - *dest_ptr = value; - return; - } - - default: - buffer->len += 3; - *dest_ptr++ = '%'; - *dest_ptr++ = s_to_uppercase_hex(value >> 4); - *dest_ptr = s_to_uppercase_hex(value & 0x0F); - return; - } -} - -/* - * Appends a character or its hex encoding to the buffer. We reserve enough space up front so that - * we can do this with raw pointers rather than multiple function calls/cursors/etc... - * - * This function is for query params - */ -static void s_raw_append_canonicalized_param_character(struct aws_byte_buf *buffer, uint8_t value) { - AWS_ASSERT(buffer->len + 3 <= buffer->capacity); - - uint8_t *dest_ptr = buffer->buffer + buffer->len; - - if (aws_isalnum(value)) { - ++buffer->len; - *dest_ptr = value; - return; - } - - switch (value) { - case '-': - case '_': - case '.': - case '~': { - ++buffer->len; - *dest_ptr = value; - return; - } - - default: - buffer->len += 3; - *dest_ptr++ = '%'; - *dest_ptr++ = s_to_uppercase_hex(value >> 4); - *dest_ptr = s_to_uppercase_hex(value & 0x0F); - return; - } -} - -/* - * Writes a cursor to a buffer using the supplied encoding function. - */ -static int s_encode_cursor_to_buffer( - struct aws_byte_buf *buffer, - const struct aws_byte_cursor *cursor, - unchecked_append_canonicalized_character_fn *append_canonicalized_character) { - uint8_t *current_ptr = cursor->ptr; - uint8_t *end_ptr = cursor->ptr + cursor->len; - - /* - * reserve room up front for the worst possible case: everything gets % encoded - */ - size_t capacity_needed = 0; - if (AWS_UNLIKELY(aws_mul_size_checked(3, cursor->len, &capacity_needed))) { - return AWS_OP_ERR; - } - - if (aws_byte_buf_reserve_relative(buffer, capacity_needed)) { - return AWS_OP_ERR; - } - - while (current_ptr < end_ptr) { - append_canonicalized_character(buffer, *current_ptr); - ++current_ptr; - } - - return AWS_OP_SUCCESS; -} - -int aws_byte_buf_append_encoding_uri_path(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { - return s_encode_cursor_to_buffer(buffer, cursor, s_unchecked_append_canonicalized_path_character); -} - -int aws_byte_buf_append_encoding_uri_param(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { - return s_encode_cursor_to_buffer(buffer, cursor, s_raw_append_canonicalized_param_character); -} - -int aws_byte_buf_append_decoding_uri(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { - /* reserve room up front for worst possible case: no % and everything copies over 1:1 */ - if (aws_byte_buf_reserve_relative(buffer, cursor->len)) { - return AWS_OP_ERR; - } - - /* advance over cursor */ - struct aws_byte_cursor advancing = *cursor; - uint8_t c; - while (aws_byte_cursor_read_u8(&advancing, &c)) { - - if (c == '%') { - /* two hex characters following '%' are the byte's value */ - if (AWS_UNLIKELY(aws_byte_cursor_read_hex_u8(&advancing, &c) == false)) { - return aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); - } - } - - buffer->buffer[buffer->len++] = c; - } - - return AWS_OP_SUCCESS; -} +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/io/uri.h> + +#include <aws/common/common.h> + +#include <ctype.h> +#include <inttypes.h> +#include <stdio.h> +#include <string.h> + +#if _MSC_VER +# pragma warning(disable : 4221) /* aggregate initializer using local variable addresses */ +# pragma warning(disable : 4204) /* non-constant aggregate initializer */ +# pragma warning(disable : 4996) /* sprintf */ +#endif + +enum parser_state { + ON_SCHEME, + ON_AUTHORITY, + ON_PATH, + ON_QUERY_STRING, + FINISHED, + ERROR, +}; + +struct uri_parser { + struct aws_uri *uri; + enum parser_state state; +}; + +typedef void(parse_fn)(struct uri_parser *parser, struct aws_byte_cursor *str); + +static void s_parse_scheme(struct uri_parser *parser, struct aws_byte_cursor *str); +static void s_parse_authority(struct uri_parser *parser, struct aws_byte_cursor *str); +static void s_parse_path(struct uri_parser *parser, struct aws_byte_cursor *str); +static void s_parse_query_string(struct uri_parser *parser, struct aws_byte_cursor *str); + +static parse_fn *s_states[] = { + [ON_SCHEME] = s_parse_scheme, + [ON_AUTHORITY] = s_parse_authority, + [ON_PATH] = s_parse_path, + [ON_QUERY_STRING] = s_parse_query_string, +}; + +static int s_init_from_uri_str(struct aws_uri *uri) { + struct uri_parser parser = { + .state = ON_SCHEME, + .uri = uri, + }; + + struct aws_byte_cursor uri_cur = aws_byte_cursor_from_buf(&uri->uri_str); + + while (parser.state < FINISHED) { + s_states[parser.state](&parser, &uri_cur); + } + + /* Each state function sets the next state, if something goes wrong it sets it to ERROR which is > FINISHED */ + if (parser.state == FINISHED) { + return AWS_OP_SUCCESS; + } + + aws_byte_buf_clean_up(&uri->uri_str); + AWS_ZERO_STRUCT(*uri); + return AWS_OP_ERR; +} + +int aws_uri_init_parse(struct aws_uri *uri, struct aws_allocator *allocator, const struct aws_byte_cursor *uri_str) { + AWS_ZERO_STRUCT(*uri); + uri->self_size = sizeof(struct aws_uri); + uri->allocator = allocator; + + if (aws_byte_buf_init_copy_from_cursor(&uri->uri_str, allocator, *uri_str)) { + return AWS_OP_ERR; + } + + return s_init_from_uri_str(uri); +} + +int aws_uri_init_from_builder_options( + struct aws_uri *uri, + struct aws_allocator *allocator, + struct aws_uri_builder_options *options) { + + AWS_ZERO_STRUCT(*uri); + + if (options->query_string.len && options->query_params) { + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + uri->self_size = sizeof(struct aws_uri); + uri->allocator = allocator; + + size_t buffer_size = 0; + if (options->scheme.len) { + /* 3 for :// */ + buffer_size += options->scheme.len + 3; + } + + buffer_size += options->host_name.len; + + if (options->port) { + /* max strlen of a 16 bit integer is 5 */ + buffer_size += 6; + } + + buffer_size += options->path.len; + + if (options->query_params) { + size_t query_len = aws_array_list_length(options->query_params); + if (query_len) { + /* for the '?' */ + buffer_size += 1; + for (size_t i = 0; i < query_len; ++i) { + struct aws_uri_param *uri_param_ptr = NULL; + aws_array_list_get_at_ptr(options->query_params, (void **)&uri_param_ptr, i); + /* 2 == 1 for '&' and 1 for '='. who cares if we over-allocate a little? */ + buffer_size += uri_param_ptr->key.len + uri_param_ptr->value.len + 2; + } + } + } else if (options->query_string.len) { + /* for the '?' */ + buffer_size += 1; + buffer_size += options->query_string.len; + } + + if (aws_byte_buf_init(&uri->uri_str, allocator, buffer_size)) { + return AWS_OP_ERR; + } + + uri->uri_str.len = 0; + if (options->scheme.len) { + aws_byte_buf_append(&uri->uri_str, &options->scheme); + struct aws_byte_cursor scheme_app = aws_byte_cursor_from_c_str("://"); + aws_byte_buf_append(&uri->uri_str, &scheme_app); + } + + aws_byte_buf_append(&uri->uri_str, &options->host_name); + + struct aws_byte_cursor port_app = aws_byte_cursor_from_c_str(":"); + if (options->port) { + aws_byte_buf_append(&uri->uri_str, &port_app); + char port_arr[6] = {0}; + sprintf(port_arr, "%" PRIu16, options->port); + struct aws_byte_cursor port_csr = aws_byte_cursor_from_c_str(port_arr); + aws_byte_buf_append(&uri->uri_str, &port_csr); + } + + aws_byte_buf_append(&uri->uri_str, &options->path); + + struct aws_byte_cursor query_app = aws_byte_cursor_from_c_str("?"); + + if (options->query_params) { + struct aws_byte_cursor query_param_app = aws_byte_cursor_from_c_str("&"); + struct aws_byte_cursor key_value_delim = aws_byte_cursor_from_c_str("="); + + aws_byte_buf_append(&uri->uri_str, &query_app); + size_t query_len = aws_array_list_length(options->query_params); + for (size_t i = 0; i < query_len; ++i) { + struct aws_uri_param *uri_param_ptr = NULL; + aws_array_list_get_at_ptr(options->query_params, (void **)&uri_param_ptr, i); + aws_byte_buf_append(&uri->uri_str, &uri_param_ptr->key); + aws_byte_buf_append(&uri->uri_str, &key_value_delim); + aws_byte_buf_append(&uri->uri_str, &uri_param_ptr->value); + + if (i < query_len - 1) { + aws_byte_buf_append(&uri->uri_str, &query_param_app); + } + } + } else if (options->query_string.len) { + aws_byte_buf_append(&uri->uri_str, &query_app); + aws_byte_buf_append(&uri->uri_str, &options->query_string); + } + + return s_init_from_uri_str(uri); +} + +void aws_uri_clean_up(struct aws_uri *uri) { + if (uri->uri_str.allocator) { + aws_byte_buf_clean_up(&uri->uri_str); + } + AWS_ZERO_STRUCT(*uri); +} + +const struct aws_byte_cursor *aws_uri_scheme(const struct aws_uri *uri) { + return &uri->scheme; +} + +const struct aws_byte_cursor *aws_uri_authority(const struct aws_uri *uri) { + return &uri->authority; +} + +const struct aws_byte_cursor *aws_uri_path(const struct aws_uri *uri) { + return &uri->path; +} + +const struct aws_byte_cursor *aws_uri_query_string(const struct aws_uri *uri) { + return &uri->query_string; +} + +const struct aws_byte_cursor *aws_uri_path_and_query(const struct aws_uri *uri) { + return &uri->path_and_query; +} + +const struct aws_byte_cursor *aws_uri_host_name(const struct aws_uri *uri) { + return &uri->host_name; +} + +uint16_t aws_uri_port(const struct aws_uri *uri) { + return uri->port; +} + +bool aws_uri_query_string_next_param(const struct aws_uri *uri, struct aws_uri_param *param) { + /* If param is zeroed, then this is the first run. */ + bool first_run = param->value.ptr == NULL; + + /* aws_byte_cursor_next_split() is used to iterate over params in the query string. + * It takes an in/out substring arg similar to how this function works */ + struct aws_byte_cursor substr; + if (first_run) { + /* substring must be zeroed to start */ + AWS_ZERO_STRUCT(substr); + } else { + /* re-assemble substring which contained key and value */ + substr.ptr = param->key.ptr; + substr.len = (param->value.ptr - param->key.ptr) + param->value.len; + } + + /* The do-while is to skip over any empty substrings */ + do { + if (!aws_byte_cursor_next_split(&uri->query_string, '&', &substr)) { + /* no more splits, done iterating */ + return false; + } + } while (substr.len == 0); + + uint8_t *delim = memchr(substr.ptr, '=', substr.len); + if (delim) { + param->key.ptr = substr.ptr; + param->key.len = delim - substr.ptr; + param->value.ptr = delim + 1; + param->value.len = substr.len - param->key.len - 1; + } else { + /* no '=', key gets substring, value is blank */ + param->key = substr; + param->value.ptr = substr.ptr + substr.len; + param->value.len = 0; + } + + return true; +} + +int aws_uri_query_string_params(const struct aws_uri *uri, struct aws_array_list *out_params) { + struct aws_uri_param param; + AWS_ZERO_STRUCT(param); + while (aws_uri_query_string_next_param(uri, ¶m)) { + if (aws_array_list_push_back(out_params, ¶m)) { + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + +static void s_parse_scheme(struct uri_parser *parser, struct aws_byte_cursor *str) { + uint8_t *location_of_colon = memchr(str->ptr, ':', str->len); + + if (!location_of_colon) { + parser->state = ON_AUTHORITY; + return; + } + + /* make sure we didn't just pick up the port by mistake */ + if ((size_t)(location_of_colon - str->ptr) < str->len && *(location_of_colon + 1) != '/') { + parser->state = ON_AUTHORITY; + return; + } + + const size_t scheme_len = location_of_colon - str->ptr; + parser->uri->scheme = aws_byte_cursor_advance(str, scheme_len); + + if (str->len < 3 || str->ptr[0] != ':' || str->ptr[1] != '/' || str->ptr[2] != '/') { + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + parser->state = ERROR; + return; + } + + /* advance past the "://" */ + aws_byte_cursor_advance(str, 3); + parser->state = ON_AUTHORITY; +} + +static const char *s_default_path = "/"; + +static void s_parse_authority(struct uri_parser *parser, struct aws_byte_cursor *str) { + uint8_t *location_of_slash = memchr(str->ptr, '/', str->len); + uint8_t *location_of_qmark = memchr(str->ptr, '?', str->len); + + if (!location_of_slash && !location_of_qmark && str->len) { + parser->uri->authority.ptr = str->ptr; + parser->uri->authority.len = str->len; + + parser->uri->path.ptr = (uint8_t *)s_default_path; + parser->uri->path.len = 1; + parser->uri->path_and_query = parser->uri->path; + parser->state = FINISHED; + aws_byte_cursor_advance(str, parser->uri->authority.len); + } else if (!str->len) { + parser->state = ERROR; + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + return; + } else { + uint8_t *end = str->ptr + str->len; + if (location_of_slash) { + parser->state = ON_PATH; + end = location_of_slash; + } else if (location_of_qmark) { + parser->state = ON_QUERY_STRING; + end = location_of_qmark; + } + + parser->uri->authority = aws_byte_cursor_advance(str, end - str->ptr); + } + + struct aws_byte_cursor authority_parse_csr = parser->uri->authority; + + if (authority_parse_csr.len) { + uint8_t *port_delim = memchr(authority_parse_csr.ptr, ':', authority_parse_csr.len); + + if (!port_delim) { + parser->uri->port = 0; + parser->uri->host_name = parser->uri->authority; + return; + } + + parser->uri->host_name.ptr = authority_parse_csr.ptr; + parser->uri->host_name.len = port_delim - authority_parse_csr.ptr; + + size_t port_len = parser->uri->authority.len - parser->uri->host_name.len - 1; + port_delim += 1; + for (size_t i = 0; i < port_len; ++i) { + if (!aws_isdigit(port_delim[i])) { + parser->state = ERROR; + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + return; + } + } + + if (port_len > 5) { + parser->state = ERROR; + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + return; + } + + /* why 6? because the port is a 16-bit unsigned integer*/ + char atoi_buf[6] = {0}; + memcpy(atoi_buf, port_delim, port_len); + int port_int = atoi(atoi_buf); + if (port_int > UINT16_MAX) { + parser->state = ERROR; + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + return; + } + + parser->uri->port = (uint16_t)port_int; + } +} + +static void s_parse_path(struct uri_parser *parser, struct aws_byte_cursor *str) { + parser->uri->path_and_query = *str; + + uint8_t *location_of_q_mark = memchr(str->ptr, '?', str->len); + + if (!location_of_q_mark) { + parser->uri->path.ptr = str->ptr; + parser->uri->path.len = str->len; + parser->state = FINISHED; + aws_byte_cursor_advance(str, parser->uri->path.len); + return; + } + + if (!str->len) { + parser->state = ERROR; + aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + return; + } + + parser->uri->path.ptr = str->ptr; + parser->uri->path.len = location_of_q_mark - str->ptr; + aws_byte_cursor_advance(str, parser->uri->path.len); + parser->state = ON_QUERY_STRING; +} + +static void s_parse_query_string(struct uri_parser *parser, struct aws_byte_cursor *str) { + if (!parser->uri->path_and_query.ptr) { + parser->uri->path_and_query = *str; + } + /* we don't want the '?' character. */ + if (str->len) { + parser->uri->query_string.ptr = str->ptr + 1; + parser->uri->query_string.len = str->len - 1; + } + + aws_byte_cursor_advance(str, parser->uri->query_string.len + 1); + parser->state = FINISHED; +} + +static uint8_t s_to_uppercase_hex(uint8_t value) { + AWS_ASSERT(value < 16); + + if (value < 10) { + return (uint8_t)('0' + value); + } + + return (uint8_t)('A' + value - 10); +} + +typedef void(unchecked_append_canonicalized_character_fn)(struct aws_byte_buf *buffer, uint8_t value); + +/* + * Appends a character or its hex encoding to the buffer. We reserve enough space up front so that + * we can do this with raw pointers rather than multiple function calls/cursors/etc... + * + * This function is for the uri path + */ +static void s_unchecked_append_canonicalized_path_character(struct aws_byte_buf *buffer, uint8_t value) { + AWS_ASSERT(buffer->len + 3 <= buffer->capacity); + + uint8_t *dest_ptr = buffer->buffer + buffer->len; + + if (aws_isalnum(value)) { + ++buffer->len; + *dest_ptr = value; + return; + } + + switch (value) { + case '-': + case '_': + case '.': + case '~': + case '$': + case '&': + case ',': + case '/': + case ':': + case ';': + case '=': + case '@': { + ++buffer->len; + *dest_ptr = value; + return; + } + + default: + buffer->len += 3; + *dest_ptr++ = '%'; + *dest_ptr++ = s_to_uppercase_hex(value >> 4); + *dest_ptr = s_to_uppercase_hex(value & 0x0F); + return; + } +} + +/* + * Appends a character or its hex encoding to the buffer. We reserve enough space up front so that + * we can do this with raw pointers rather than multiple function calls/cursors/etc... + * + * This function is for query params + */ +static void s_raw_append_canonicalized_param_character(struct aws_byte_buf *buffer, uint8_t value) { + AWS_ASSERT(buffer->len + 3 <= buffer->capacity); + + uint8_t *dest_ptr = buffer->buffer + buffer->len; + + if (aws_isalnum(value)) { + ++buffer->len; + *dest_ptr = value; + return; + } + + switch (value) { + case '-': + case '_': + case '.': + case '~': { + ++buffer->len; + *dest_ptr = value; + return; + } + + default: + buffer->len += 3; + *dest_ptr++ = '%'; + *dest_ptr++ = s_to_uppercase_hex(value >> 4); + *dest_ptr = s_to_uppercase_hex(value & 0x0F); + return; + } +} + +/* + * Writes a cursor to a buffer using the supplied encoding function. + */ +static int s_encode_cursor_to_buffer( + struct aws_byte_buf *buffer, + const struct aws_byte_cursor *cursor, + unchecked_append_canonicalized_character_fn *append_canonicalized_character) { + uint8_t *current_ptr = cursor->ptr; + uint8_t *end_ptr = cursor->ptr + cursor->len; + + /* + * reserve room up front for the worst possible case: everything gets % encoded + */ + size_t capacity_needed = 0; + if (AWS_UNLIKELY(aws_mul_size_checked(3, cursor->len, &capacity_needed))) { + return AWS_OP_ERR; + } + + if (aws_byte_buf_reserve_relative(buffer, capacity_needed)) { + return AWS_OP_ERR; + } + + while (current_ptr < end_ptr) { + append_canonicalized_character(buffer, *current_ptr); + ++current_ptr; + } + + return AWS_OP_SUCCESS; +} + +int aws_byte_buf_append_encoding_uri_path(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { + return s_encode_cursor_to_buffer(buffer, cursor, s_unchecked_append_canonicalized_path_character); +} + +int aws_byte_buf_append_encoding_uri_param(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { + return s_encode_cursor_to_buffer(buffer, cursor, s_raw_append_canonicalized_param_character); +} + +int aws_byte_buf_append_decoding_uri(struct aws_byte_buf *buffer, const struct aws_byte_cursor *cursor) { + /* reserve room up front for worst possible case: no % and everything copies over 1:1 */ + if (aws_byte_buf_reserve_relative(buffer, cursor->len)) { + return AWS_OP_ERR; + } + + /* advance over cursor */ + struct aws_byte_cursor advancing = *cursor; + uint8_t c; + while (aws_byte_cursor_read_u8(&advancing, &c)) { + + if (c == '%') { + /* two hex characters following '%' are the byte's value */ + if (AWS_UNLIKELY(aws_byte_cursor_read_hex_u8(&advancing, &c) == false)) { + return aws_raise_error(AWS_ERROR_MALFORMED_INPUT_STRING); + } + } + + buffer->buffer[buffer->len++] = c; + } + + return AWS_OP_SUCCESS; +} |