diff options
author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300 |
---|---|---|
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300 |
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/grpcio/py2/grpc/_server.py | |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/grpcio/py2/grpc/_server.py')
-rw-r--r-- | contrib/python/grpcio/py2/grpc/_server.py | 1003 |
1 files changed, 1003 insertions, 0 deletions
diff --git a/contrib/python/grpcio/py2/grpc/_server.py b/contrib/python/grpcio/py2/grpc/_server.py new file mode 100644 index 0000000000..58ab14feb3 --- /dev/null +++ b/contrib/python/grpcio/py2/grpc/_server.py @@ -0,0 +1,1003 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Service-side implementation of gRPC Python.""" + +import collections +from concurrent import futures +import enum +import logging +import threading +import time + +import grpc +from grpc import _common +from grpc import _compression +from grpc import _interceptor +from grpc._cython import cygrpc +import six + +_LOGGER = logging.getLogger(__name__) + +_SHUTDOWN_TAG = 'shutdown' +_REQUEST_CALL_TAG = 'request_call' + +_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server' +_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata' +_RECEIVE_MESSAGE_TOKEN = 'receive_message' +_SEND_MESSAGE_TOKEN = 'send_message' +_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( + 'send_initial_metadata * send_message') +_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server' +_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( + 'send_initial_metadata * send_status_from_server') + +_OPEN = 'open' +_CLOSED = 'closed' +_CANCELLED = 'cancelled' + +_EMPTY_FLAGS = 0 + +_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 +_INF_TIMEOUT = 1e9 + + +def _serialized_request(request_event): + return request_event.batch_operations[0].message() + + +def _application_code(code): + cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) + return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code + + +def _completion_code(state): + if state.code is None: + return cygrpc.StatusCode.ok + else: + return _application_code(state.code) + + +def _abortion_code(state, code): + if state.code is None: + return code + else: + return _application_code(state.code) + + +def _details(state): + return b'' if state.details is None else state.details + + +class _HandlerCallDetails( + collections.namedtuple('_HandlerCallDetails', ( + 'method', + 'invocation_metadata', + )), grpc.HandlerCallDetails): + pass + + +class _RPCState(object): + + def __init__(self): + self.condition = threading.Condition() + self.due = set() + self.request = None + self.client = _OPEN + self.initial_metadata_allowed = True + self.compression_algorithm = None + self.disable_next_compression = False + self.trailing_metadata = None + self.code = None + self.details = None + self.statused = False + self.rpc_errors = [] + self.callbacks = [] + self.aborted = False + + +def _raise_rpc_error(state): + rpc_error = grpc.RpcError() + state.rpc_errors.append(rpc_error) + raise rpc_error + + +def _possibly_finish_call(state, token): + state.due.remove(token) + if not _is_rpc_state_active(state) and not state.due: + callbacks = state.callbacks + state.callbacks = None + return state, callbacks + else: + return None, () + + +def _send_status_from_server(state, token): + + def send_status_from_server(unused_send_status_from_server_event): + with state.condition: + return _possibly_finish_call(state, token) + + return send_status_from_server + + +def _get_initial_metadata(state, metadata): + with state.condition: + if state.compression_algorithm: + compression_metadata = ( + _compression.compression_algorithm_to_metadata( + state.compression_algorithm),) + if metadata is None: + return compression_metadata + else: + return compression_metadata + tuple(metadata) + else: + return metadata + + +def _get_initial_metadata_operation(state, metadata): + operation = cygrpc.SendInitialMetadataOperation( + _get_initial_metadata(state, metadata), _EMPTY_FLAGS) + return operation + + +def _abort(state, call, code, details): + if state.client is not _CANCELLED: + effective_code = _abortion_code(state, code) + effective_details = details if state.details is None else state.details + if state.initial_metadata_allowed: + operations = ( + _get_initial_metadata_operation(state, None), + cygrpc.SendStatusFromServerOperation(state.trailing_metadata, + effective_code, + effective_details, + _EMPTY_FLAGS), + ) + token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN + else: + operations = (cygrpc.SendStatusFromServerOperation( + state.trailing_metadata, effective_code, effective_details, + _EMPTY_FLAGS),) + token = _SEND_STATUS_FROM_SERVER_TOKEN + call.start_server_batch(operations, + _send_status_from_server(state, token)) + state.statused = True + state.due.add(token) + + +def _receive_close_on_server(state): + + def receive_close_on_server(receive_close_on_server_event): + with state.condition: + if receive_close_on_server_event.batch_operations[0].cancelled(): + state.client = _CANCELLED + elif state.client is _OPEN: + state.client = _CLOSED + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN) + + return receive_close_on_server + + +def _receive_message(state, call, request_deserializer): + + def receive_message(receive_message_event): + serialized_request = _serialized_request(receive_message_event) + if serialized_request is None: + with state.condition: + if state.client is _OPEN: + state.client = _CLOSED + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) + else: + request = _common.deserialize(serialized_request, + request_deserializer) + with state.condition: + if request is None: + _abort(state, call, cygrpc.StatusCode.internal, + b'Exception deserializing request!') + else: + state.request = request + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) + + return receive_message + + +def _send_initial_metadata(state): + + def send_initial_metadata(unused_send_initial_metadata_event): + with state.condition: + return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) + + return send_initial_metadata + + +def _send_message(state, token): + + def send_message(unused_send_message_event): + with state.condition: + state.condition.notify_all() + return _possibly_finish_call(state, token) + + return send_message + + +class _Context(grpc.ServicerContext): + + def __init__(self, rpc_event, state, request_deserializer): + self._rpc_event = rpc_event + self._state = state + self._request_deserializer = request_deserializer + + def is_active(self): + with self._state.condition: + return _is_rpc_state_active(self._state) + + def time_remaining(self): + return max(self._rpc_event.call_details.deadline - time.time(), 0) + + def cancel(self): + self._rpc_event.call.cancel() + + def add_callback(self, callback): + with self._state.condition: + if self._state.callbacks is None: + return False + else: + self._state.callbacks.append(callback) + return True + + def disable_next_message_compression(self): + with self._state.condition: + self._state.disable_next_compression = True + + def invocation_metadata(self): + return self._rpc_event.invocation_metadata + + def peer(self): + return _common.decode(self._rpc_event.call.peer()) + + def peer_identities(self): + return cygrpc.peer_identities(self._rpc_event.call) + + def peer_identity_key(self): + id_key = cygrpc.peer_identity_key(self._rpc_event.call) + return id_key if id_key is None else _common.decode(id_key) + + def auth_context(self): + return { + _common.decode(key): value for key, value in six.iteritems( + cygrpc.auth_context(self._rpc_event.call)) + } + + def set_compression(self, compression): + with self._state.condition: + self._state.compression_algorithm = compression + + def send_initial_metadata(self, initial_metadata): + with self._state.condition: + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + else: + if self._state.initial_metadata_allowed: + operation = _get_initial_metadata_operation( + self._state, initial_metadata) + self._rpc_event.call.start_server_batch( + (operation,), _send_initial_metadata(self._state)) + self._state.initial_metadata_allowed = False + self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) + else: + raise ValueError('Initial metadata no longer allowed!') + + def set_trailing_metadata(self, trailing_metadata): + with self._state.condition: + self._state.trailing_metadata = trailing_metadata + + def trailing_metadata(self): + return self._state.trailing_metadata + + def abort(self, code, details): + # treat OK like other invalid arguments: fail the RPC + if code == grpc.StatusCode.OK: + _LOGGER.error( + 'abort() called with StatusCode.OK; returning UNKNOWN') + code = grpc.StatusCode.UNKNOWN + details = '' + with self._state.condition: + self._state.code = code + self._state.details = _common.encode(details) + self._state.aborted = True + raise Exception() + + def abort_with_status(self, status): + self._state.trailing_metadata = status.trailing_metadata + self.abort(status.code, status.details) + + def set_code(self, code): + with self._state.condition: + self._state.code = code + + def code(self): + return self._state.code + + def set_details(self, details): + with self._state.condition: + self._state.details = _common.encode(details) + + def details(self): + return self._state.details + + def _finalize_state(self): + pass + + +class _RequestIterator(object): + + def __init__(self, state, call, request_deserializer): + self._state = state + self._call = call + self._request_deserializer = request_deserializer + + def _raise_or_start_receive_message(self): + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + elif not _is_rpc_state_active(self._state): + raise StopIteration() + else: + self._call.start_server_batch( + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), + _receive_message(self._state, self._call, + self._request_deserializer)) + self._state.due.add(_RECEIVE_MESSAGE_TOKEN) + + def _look_for_request(self): + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + elif (self._state.request is None and + _RECEIVE_MESSAGE_TOKEN not in self._state.due): + raise StopIteration() + else: + request = self._state.request + self._state.request = None + return request + + raise AssertionError() # should never run + + def _next(self): + with self._state.condition: + self._raise_or_start_receive_message() + while True: + self._state.condition.wait() + request = self._look_for_request() + if request is not None: + return request + + def __iter__(self): + return self + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + +def _unary_request(rpc_event, state, request_deserializer): + + def unary_request(): + with state.condition: + if not _is_rpc_state_active(state): + return None + else: + rpc_event.call.start_server_batch( + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), + _receive_message(state, rpc_event.call, + request_deserializer)) + state.due.add(_RECEIVE_MESSAGE_TOKEN) + while True: + state.condition.wait() + if state.request is None: + if state.client is _CLOSED: + details = '"{}" requires exactly one request message.'.format( + rpc_event.call_details.method) + _abort(state, rpc_event.call, + cygrpc.StatusCode.unimplemented, + _common.encode(details)) + return None + elif state.client is _CANCELLED: + return None + else: + request = state.request + state.request = None + return request + + return unary_request + + +def _call_behavior(rpc_event, + state, + behavior, + argument, + request_deserializer, + send_response_callback=None): + from grpc import _create_servicer_context + with _create_servicer_context(rpc_event, state, + request_deserializer) as context: + try: + response_or_iterator = None + if send_response_callback is not None: + response_or_iterator = behavior(argument, context, + send_response_callback) + else: + response_or_iterator = behavior(argument, context) + return response_or_iterator, True + except Exception as exception: # pylint: disable=broad-except + with state.condition: + if state.aborted: + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + b'RPC Aborted') + elif exception not in state.rpc_errors: + details = 'Exception calling application: {}'.format( + exception) + _LOGGER.exception(details) + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + _common.encode(details)) + return None, False + + +def _take_response_from_response_iterator(rpc_event, state, response_iterator): + try: + return next(response_iterator), True + except StopIteration: + return None, True + except Exception as exception: # pylint: disable=broad-except + with state.condition: + if state.aborted: + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + b'RPC Aborted') + elif exception not in state.rpc_errors: + details = 'Exception iterating responses: {}'.format(exception) + _LOGGER.exception(details) + _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, + _common.encode(details)) + return None, False + + +def _serialize_response(rpc_event, state, response, response_serializer): + serialized_response = _common.serialize(response, response_serializer) + if serialized_response is None: + with state.condition: + _abort(state, rpc_event.call, cygrpc.StatusCode.internal, + b'Failed to serialize response!') + return None + else: + return serialized_response + + +def _get_send_message_op_flags_from_state(state): + if state.disable_next_compression: + return cygrpc.WriteFlag.no_compress + else: + return _EMPTY_FLAGS + + +def _reset_per_message_state(state): + with state.condition: + state.disable_next_compression = False + + +def _send_response(rpc_event, state, serialized_response): + with state.condition: + if not _is_rpc_state_active(state): + return False + else: + if state.initial_metadata_allowed: + operations = ( + _get_initial_metadata_operation(state, None), + cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state)), + ) + state.initial_metadata_allowed = False + token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN + else: + operations = (cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state)),) + token = _SEND_MESSAGE_TOKEN + rpc_event.call.start_server_batch(operations, + _send_message(state, token)) + state.due.add(token) + _reset_per_message_state(state) + while True: + state.condition.wait() + if token not in state.due: + return _is_rpc_state_active(state) + + +def _status(rpc_event, state, serialized_response): + with state.condition: + if state.client is not _CANCELLED: + code = _completion_code(state) + details = _details(state) + operations = [ + cygrpc.SendStatusFromServerOperation(state.trailing_metadata, + code, details, + _EMPTY_FLAGS), + ] + if state.initial_metadata_allowed: + operations.append(_get_initial_metadata_operation(state, None)) + if serialized_response is not None: + operations.append( + cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state))) + rpc_event.call.start_server_batch( + operations, + _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) + state.statused = True + _reset_per_message_state(state) + state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) + + +def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, + request_deserializer, response_serializer): + cygrpc.install_context_from_request_call_event(rpc_event) + try: + argument = argument_thunk() + if argument is not None: + response, proceed = _call_behavior(rpc_event, state, behavior, + argument, request_deserializer) + if proceed: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + _status(rpc_event, state, serialized_response) + finally: + cygrpc.uninstall_context() + + +def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, + request_deserializer, response_serializer): + cygrpc.install_context_from_request_call_event(rpc_event) + + def send_response(response): + if response is None: + _status(rpc_event, state, None) + else: + serialized_response = _serialize_response(rpc_event, state, + response, + response_serializer) + if serialized_response is not None: + _send_response(rpc_event, state, serialized_response) + + try: + argument = argument_thunk() + if argument is not None: + if hasattr(behavior, 'experimental_non_blocking' + ) and behavior.experimental_non_blocking: + _call_behavior(rpc_event, + state, + behavior, + argument, + request_deserializer, + send_response_callback=send_response) + else: + response_iterator, proceed = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer) + if proceed: + _send_message_callback_to_blocking_iterator_adapter( + rpc_event, state, send_response, response_iterator) + finally: + cygrpc.uninstall_context() + + +def _is_rpc_state_active(state): + return state.client is not _CANCELLED and not state.statused + + +def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state, + send_response_callback, + response_iterator): + while True: + response, proceed = _take_response_from_response_iterator( + rpc_event, state, response_iterator) + if proceed: + send_response_callback(response) + if not _is_rpc_state_active(state): + break + else: + break + + +def _select_thread_pool_for_behavior(behavior, default_thread_pool): + if hasattr(behavior, 'experimental_thread_pool') and isinstance( + behavior.experimental_thread_pool, futures.ThreadPoolExecutor): + return behavior.experimental_thread_pool + else: + return default_thread_pool + + +def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool): + unary_request = _unary_request(rpc_event, state, + method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary, + default_thread_pool) + return thread_pool.submit(_unary_response_in_pool, rpc_event, state, + method_handler.unary_unary, unary_request, + method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool): + unary_request = _unary_request(rpc_event, state, + method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream, + default_thread_pool) + return thread_pool.submit(_stream_response_in_pool, rpc_event, state, + method_handler.unary_stream, unary_request, + method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool): + request_iterator = _RequestIterator(state, rpc_event.call, + method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary, + default_thread_pool) + return thread_pool.submit(_unary_response_in_pool, rpc_event, state, + method_handler.stream_unary, + lambda: request_iterator, + method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_stream_stream(rpc_event, state, method_handler, + default_thread_pool): + request_iterator = _RequestIterator(state, rpc_event.call, + method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream, + default_thread_pool) + return thread_pool.submit(_stream_response_in_pool, rpc_event, state, + method_handler.stream_stream, + lambda: request_iterator, + method_handler.request_deserializer, + method_handler.response_serializer) + + +def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): + + def query_handlers(handler_call_details): + for generic_handler in generic_handlers: + method_handler = generic_handler.service(handler_call_details) + if method_handler is not None: + return method_handler + return None + + handler_call_details = _HandlerCallDetails( + _common.decode(rpc_event.call_details.method), + rpc_event.invocation_metadata) + + if interceptor_pipeline is not None: + return interceptor_pipeline.execute(query_handlers, + handler_call_details) + else: + return query_handlers(handler_call_details) + + +def _reject_rpc(rpc_event, status, details): + rpc_state = _RPCState() + operations = ( + _get_initial_metadata_operation(rpc_state, None), + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation(None, status, details, + _EMPTY_FLAGS), + ) + rpc_event.call.start_server_batch(operations, lambda ignored_event: ( + rpc_state, + (), + )) + return rpc_state + + +def _handle_with_method_handler(rpc_event, method_handler, thread_pool): + state = _RPCState() + with state.condition: + rpc_event.call.start_server_batch( + (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), + _receive_close_on_server(state)) + state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) + if method_handler.request_streaming: + if method_handler.response_streaming: + return state, _handle_stream_stream(rpc_event, state, + method_handler, thread_pool) + else: + return state, _handle_stream_unary(rpc_event, state, + method_handler, thread_pool) + else: + if method_handler.response_streaming: + return state, _handle_unary_stream(rpc_event, state, + method_handler, thread_pool) + else: + return state, _handle_unary_unary(rpc_event, state, + method_handler, thread_pool) + + +def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, + concurrency_exceeded): + if not rpc_event.success: + return None, None + if rpc_event.call_details.method is not None: + try: + method_handler = _find_method_handler(rpc_event, generic_handlers, + interceptor_pipeline) + except Exception as exception: # pylint: disable=broad-except + details = 'Exception servicing handler: {}'.format(exception) + _LOGGER.exception(details) + return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown, + b'Error in service handler!'), None + if method_handler is None: + return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented, + b'Method not found!'), None + elif concurrency_exceeded: + return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted, + b'Concurrent RPC limit exceeded!'), None + else: + return _handle_with_method_handler(rpc_event, method_handler, + thread_pool) + else: + return None, None + + +@enum.unique +class _ServerStage(enum.Enum): + STOPPED = 'stopped' + STARTED = 'started' + GRACE = 'grace' + + +class _ServerState(object): + + # pylint: disable=too-many-arguments + def __init__(self, completion_queue, server, generic_handlers, + interceptor_pipeline, thread_pool, maximum_concurrent_rpcs): + self.lock = threading.RLock() + self.completion_queue = completion_queue + self.server = server + self.generic_handlers = list(generic_handlers) + self.interceptor_pipeline = interceptor_pipeline + self.thread_pool = thread_pool + self.stage = _ServerStage.STOPPED + self.termination_event = threading.Event() + self.shutdown_events = [self.termination_event] + self.maximum_concurrent_rpcs = maximum_concurrent_rpcs + self.active_rpc_count = 0 + + # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. + self.rpc_states = set() + self.due = set() + + # A "volatile" flag to interrupt the daemon serving thread + self.server_deallocated = False + + +def _add_generic_handlers(state, generic_handlers): + with state.lock: + state.generic_handlers.extend(generic_handlers) + + +def _add_insecure_port(state, address): + with state.lock: + return state.server.add_http2_port(address) + + +def _add_secure_port(state, address, server_credentials): + with state.lock: + return state.server.add_http2_port(address, + server_credentials._credentials) + + +def _request_call(state): + state.server.request_call(state.completion_queue, state.completion_queue, + _REQUEST_CALL_TAG) + state.due.add(_REQUEST_CALL_TAG) + + +# TODO(https://github.com/grpc/grpc/issues/6597): delete this function. +def _stop_serving(state): + if not state.rpc_states and not state.due: + state.server.destroy() + for shutdown_event in state.shutdown_events: + shutdown_event.set() + state.stage = _ServerStage.STOPPED + return True + else: + return False + + +def _on_call_completed(state): + with state.lock: + state.active_rpc_count -= 1 + + +def _process_event_and_continue(state, event): + should_continue = True + if event.tag is _SHUTDOWN_TAG: + with state.lock: + state.due.remove(_SHUTDOWN_TAG) + if _stop_serving(state): + should_continue = False + elif event.tag is _REQUEST_CALL_TAG: + with state.lock: + state.due.remove(_REQUEST_CALL_TAG) + concurrency_exceeded = ( + state.maximum_concurrent_rpcs is not None and + state.active_rpc_count >= state.maximum_concurrent_rpcs) + rpc_state, rpc_future = _handle_call(event, state.generic_handlers, + state.interceptor_pipeline, + state.thread_pool, + concurrency_exceeded) + if rpc_state is not None: + state.rpc_states.add(rpc_state) + if rpc_future is not None: + state.active_rpc_count += 1 + rpc_future.add_done_callback( + lambda unused_future: _on_call_completed(state)) + if state.stage is _ServerStage.STARTED: + _request_call(state) + elif _stop_serving(state): + should_continue = False + else: + rpc_state, callbacks = event.tag(event) + for callback in callbacks: + try: + callback() + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Exception calling callback!') + if rpc_state is not None: + with state.lock: + state.rpc_states.remove(rpc_state) + if _stop_serving(state): + should_continue = False + return should_continue + + +def _serve(state): + while True: + timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S + event = state.completion_queue.poll(timeout) + if state.server_deallocated: + _begin_shutdown_once(state) + if event.completion_type != cygrpc.CompletionType.queue_timeout: + if not _process_event_and_continue(state, event): + return + # We want to force the deletion of the previous event + # ~before~ we poll again; if the event has a reference + # to a shutdown Call object, this can induce spinlock. + event = None + + +def _begin_shutdown_once(state): + with state.lock: + if state.stage is _ServerStage.STARTED: + state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) + state.stage = _ServerStage.GRACE + state.due.add(_SHUTDOWN_TAG) + + +def _stop(state, grace): + with state.lock: + if state.stage is _ServerStage.STOPPED: + shutdown_event = threading.Event() + shutdown_event.set() + return shutdown_event + else: + _begin_shutdown_once(state) + shutdown_event = threading.Event() + state.shutdown_events.append(shutdown_event) + if grace is None: + state.server.cancel_all_calls() + else: + + def cancel_all_calls_after_grace(): + shutdown_event.wait(timeout=grace) + with state.lock: + state.server.cancel_all_calls() + + thread = threading.Thread(target=cancel_all_calls_after_grace) + thread.start() + return shutdown_event + shutdown_event.wait() + return shutdown_event + + +def _start(state): + with state.lock: + if state.stage is not _ServerStage.STOPPED: + raise ValueError('Cannot start already-started server!') + state.server.start() + state.stage = _ServerStage.STARTED + _request_call(state) + + thread = threading.Thread(target=_serve, args=(state,)) + thread.daemon = True + thread.start() + + +def _validate_generic_rpc_handlers(generic_rpc_handlers): + for generic_rpc_handler in generic_rpc_handlers: + service_attribute = getattr(generic_rpc_handler, 'service', None) + if service_attribute is None: + raise AttributeError( + '"{}" must conform to grpc.GenericRpcHandler type but does ' + 'not have "service" method!'.format(generic_rpc_handler)) + + +def _augment_options(base_options, compression): + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + + +class _Server(grpc.Server): + + # pylint: disable=too-many-arguments + def __init__(self, thread_pool, generic_handlers, interceptors, options, + maximum_concurrent_rpcs, compression, xds): + completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server(_augment_options(options, compression), xds) + server.register_completion_queue(completion_queue) + self._state = _ServerState(completion_queue, server, generic_handlers, + _interceptor.service_pipeline(interceptors), + thread_pool, maximum_concurrent_rpcs) + + def add_generic_rpc_handlers(self, generic_rpc_handlers): + _validate_generic_rpc_handlers(generic_rpc_handlers) + _add_generic_handlers(self._state, generic_rpc_handlers) + + def add_insecure_port(self, address): + return _common.validate_port_binding_result( + address, _add_insecure_port(self._state, _common.encode(address))) + + def add_secure_port(self, address, server_credentials): + return _common.validate_port_binding_result( + address, + _add_secure_port(self._state, _common.encode(address), + server_credentials)) + + def start(self): + _start(self._state) + + def wait_for_termination(self, timeout=None): + # NOTE(https://bugs.python.org/issue35935) + # Remove this workaround once threading.Event.wait() is working with + # CTRL+C across platforms. + return _common.wait(self._state.termination_event.wait, + self._state.termination_event.is_set, + timeout=timeout) + + def stop(self, grace): + return _stop(self._state, grace) + + def __del__(self): + if hasattr(self, '_state'): + # We can not grab a lock in __del__(), so set a flag to signal the + # serving daemon thread (if it exists) to initiate shutdown. + self._state.server_deallocated = True + + +def create_server(thread_pool, generic_rpc_handlers, interceptors, options, + maximum_concurrent_rpcs, compression, xds): + _validate_generic_rpc_handlers(generic_rpc_handlers) + return _Server(thread_pool, generic_rpc_handlers, interceptors, options, + maximum_concurrent_rpcs, compression, xds) |