diff options
author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300 |
---|---|---|
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300 |
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/grpcio/py3/grpc/_channel.py | |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/grpcio/py3/grpc/_channel.py')
-rw-r--r-- | contrib/python/grpcio/py3/grpc/_channel.py | 1767 |
1 files changed, 1767 insertions, 0 deletions
diff --git a/contrib/python/grpcio/py3/grpc/_channel.py b/contrib/python/grpcio/py3/grpc/_channel.py new file mode 100644 index 0000000000..d31344fd0e --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/_channel.py @@ -0,0 +1,1767 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Invocation-side implementation of gRPC Python.""" + +import copy +import functools +import logging +import os +import sys +import threading +import time +import types +from typing import (Any, Callable, Iterator, List, Optional, Sequence, Set, + Tuple, Union) + +import grpc # pytype: disable=pyi-error +from grpc import _common # pytype: disable=pyi-error +from grpc import _compression # pytype: disable=pyi-error +from grpc import _grpcio_metadata # pytype: disable=pyi-error +from grpc._cython import cygrpc +from grpc._typing import ChannelArgumentType +from grpc._typing import DeserializingFunction +from grpc._typing import IntegratedCallFactory +from grpc._typing import MetadataType +from grpc._typing import NullaryCallbackType +from grpc._typing import ResponseType +from grpc._typing import SerializingFunction +from grpc._typing import UserTag +import grpc.experimental # pytype: disable=pyi-error + +_LOGGER = logging.getLogger(__name__) + +_USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__) + +_EMPTY_FLAGS = 0 + +# NOTE(rbellevi): No guarantees are given about the maintenance of this +# environment variable. +_DEFAULT_SINGLE_THREADED_UNARY_STREAM = os.getenv( + "GRPC_SINGLE_THREADED_UNARY_STREAM") is not None + +_UNARY_UNARY_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client, +) +_UNARY_STREAM_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_status_on_client, +) +_STREAM_UNARY_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client, +) +_STREAM_STREAM_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_status_on_client, +) + +_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( + 'Exception calling channel subscription callback!') + +_OK_RENDEZVOUS_REPR_FORMAT = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '>') + +_NON_OK_RENDEZVOUS_REPR_FORMAT = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + '>') + + +def _deadline(timeout: Optional[float]) -> Optional[float]: + return None if timeout is None else time.time() + timeout + + +def _unknown_code_details(unknown_cygrpc_code: Optional[grpc.StatusCode], + details: Optional[str]) -> str: + return 'Server sent unknown code {} and details "{}"'.format( + unknown_cygrpc_code, details) + + +class _RPCState(object): + condition: threading.Condition + due: Set[cygrpc.OperationType] + initial_metadata: Optional[MetadataType] + response: Any + trailing_metadata: Optional[MetadataType] + code: Optional[grpc.StatusCode] + details: Optional[str] + debug_error_string: Optional[str] + cancelled: bool + callbacks: List[NullaryCallbackType] + fork_epoch: Optional[int] + + def __init__(self, due: Sequence[cygrpc.OperationType], + initial_metadata: Optional[MetadataType], + trailing_metadata: Optional[MetadataType], + code: Optional[grpc.StatusCode], details: Optional[str]): + # `condition` guards all members of _RPCState. `notify_all` is called on + # `condition` when the state of the RPC has changed. + self.condition = threading.Condition() + + # The cygrpc.OperationType objects representing events due from the RPC's + # completion queue. If an operation is in `due`, it is guaranteed that + # `operate()` has been called on a corresponding operation. But the + # converse is not true. That is, in the case of failed `operate()` + # calls, there may briefly be events in `due` that do not correspond to + # operations submitted to Core. + self.due = set(due) + self.initial_metadata = initial_metadata + self.response = None + self.trailing_metadata = trailing_metadata + self.code = code + self.details = details + self.debug_error_string = None + + # The semantics of grpc.Future.cancel and grpc.Future.cancelled are + # slightly wonky, so they have to be tracked separately from the rest of the + # result of the RPC. This field tracks whether cancellation was requested + # prior to termination of the RPC. + self.cancelled = False + self.callbacks = [] + self.fork_epoch = cygrpc.get_fork_epoch() + + def reset_postfork_child(self): + self.condition = threading.Condition() + + +def _abort(state: _RPCState, code: grpc.StatusCode, details: str) -> None: + if state.code is None: + state.code = code + state.details = details + if state.initial_metadata is None: + state.initial_metadata = () + state.trailing_metadata = () + + +def _handle_event( + event: cygrpc.BaseEvent, state: _RPCState, + response_deserializer: Optional[DeserializingFunction] +) -> List[NullaryCallbackType]: + callbacks = [] + for batch_operation in event.batch_operations: + operation_type = batch_operation.type() + state.due.remove(operation_type) + if operation_type == cygrpc.OperationType.receive_initial_metadata: + state.initial_metadata = batch_operation.initial_metadata() + elif operation_type == cygrpc.OperationType.receive_message: + serialized_response = batch_operation.message() + if serialized_response is not None: + response = _common.deserialize(serialized_response, + response_deserializer) + if response is None: + details = 'Exception deserializing response!' + _abort(state, grpc.StatusCode.INTERNAL, details) + else: + state.response = response + elif operation_type == cygrpc.OperationType.receive_status_on_client: + state.trailing_metadata = batch_operation.trailing_metadata() + if state.code is None: + code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get( + batch_operation.code()) + if code is None: + state.code = grpc.StatusCode.UNKNOWN + state.details = _unknown_code_details( + code, batch_operation.details()) + else: + state.code = code + state.details = batch_operation.details() + state.debug_error_string = batch_operation.error_string() + callbacks.extend(state.callbacks) + state.callbacks = None + return callbacks + + +def _event_handler( + state: _RPCState, + response_deserializer: Optional[DeserializingFunction]) -> UserTag: + + def handle_event(event): + with state.condition: + callbacks = _handle_event(event, state, response_deserializer) + state.condition.notify_all() + done = not state.due + for callback in callbacks: + try: + callback() + except Exception as e: # pylint: disable=broad-except + # NOTE(rbellevi): We suppress but log errors here so as not to + # kill the channel spin thread. + logging.error('Exception in callback %s: %s', + repr(callback.func), repr(e)) + return done and state.fork_epoch >= cygrpc.get_fork_epoch() + + return handle_event + + +# TODO(xuanwn): Create a base class for IntegratedCall and SegregatedCall. +#pylint: disable=too-many-statements +def _consume_request_iterator(request_iterator: Iterator, state: _RPCState, + call: Union[cygrpc.IntegratedCall, + cygrpc.SegregatedCall], + request_serializer: SerializingFunction, + event_handler: Optional[UserTag]) -> None: + """Consume a request supplied by the user.""" + + def consume_request_iterator(): # pylint: disable=too-many-branches + # Iterate over the request iterator until it is exhausted or an error + # condition is encountered. + while True: + return_from_user_request_generator_invoked = False + try: + # The thread may die in user-code. Do not block fork for this. + cygrpc.enter_user_request_generator() + request = next(request_iterator) + except StopIteration: + break + except Exception: # pylint: disable=broad-except + cygrpc.return_from_user_request_generator() + return_from_user_request_generator_invoked = True + code = grpc.StatusCode.UNKNOWN + details = 'Exception iterating requests!' + _LOGGER.exception(details) + call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) + return + finally: + if not return_from_user_request_generator_invoked: + cygrpc.return_from_user_request_generator() + serialized_request = _common.serialize(request, request_serializer) + with state.condition: + if state.code is None and not state.cancelled: + if serialized_request is None: + code = grpc.StatusCode.INTERNAL + details = 'Exception serializing request!' + call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details) + _abort(state, code, details) + return + else: + state.due.add(cygrpc.OperationType.send_message) + operations = (cygrpc.SendMessageOperation( + serialized_request, _EMPTY_FLAGS),) + operating = call.operate(operations, event_handler) + if not operating: + state.due.remove(cygrpc.OperationType.send_message) + return + + def _done(): + return (state.code is not None or + cygrpc.OperationType.send_message + not in state.due) + + _common.wait(state.condition.wait, + _done, + spin_cb=functools.partial( + cygrpc.block_if_fork_in_progress, + state)) + if state.code is not None: + return + else: + return + with state.condition: + if state.code is None: + state.due.add(cygrpc.OperationType.send_close_from_client) + operations = ( + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),) + operating = call.operate(operations, event_handler) + if not operating: + state.due.remove( + cygrpc.OperationType.send_close_from_client) + + consumption_thread = cygrpc.ForkManagedThread( + target=consume_request_iterator) + consumption_thread.setDaemon(True) + consumption_thread.start() + + +def _rpc_state_string(class_name: str, rpc_state: _RPCState) -> str: + """Calculates error string for RPC.""" + with rpc_state.condition: + if rpc_state.code is None: + return '<{} object>'.format(class_name) + elif rpc_state.code is grpc.StatusCode.OK: + return _OK_RENDEZVOUS_REPR_FORMAT.format(class_name, rpc_state.code, + rpc_state.details) + else: + return _NON_OK_RENDEZVOUS_REPR_FORMAT.format( + class_name, rpc_state.code, rpc_state.details, + rpc_state.debug_error_string) + + +class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): + """An RPC error not tied to the execution of a particular RPC. + + The RPC represented by the state object must not be in-progress or + cancelled. + + Attributes: + _state: An instance of _RPCState. + """ + _state: _RPCState + + def __init__(self, state: _RPCState): + with state.condition: + self._state = _RPCState((), copy.deepcopy(state.initial_metadata), + copy.deepcopy(state.trailing_metadata), + state.code, copy.deepcopy(state.details)) + self._state.response = copy.copy(state.response) + self._state.debug_error_string = copy.copy(state.debug_error_string) + + def initial_metadata(self) -> Optional[MetadataType]: + return self._state.initial_metadata + + def trailing_metadata(self) -> Optional[MetadataType]: + return self._state.trailing_metadata + + def code(self) -> Optional[grpc.StatusCode]: + return self._state.code + + def details(self) -> Optional[str]: + return _common.decode(self._state.details) + + def debug_error_string(self) -> Optional[str]: + return _common.decode(self._state.debug_error_string) + + def _repr(self) -> str: + return _rpc_state_string(self.__class__.__name__, self._state) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + def cancel(self) -> bool: + """See grpc.Future.cancel.""" + return False + + def cancelled(self) -> bool: + """See grpc.Future.cancelled.""" + return False + + def running(self) -> bool: + """See grpc.Future.running.""" + return False + + def done(self) -> bool: + """See grpc.Future.done.""" + return True + + def result(self, timeout: Optional[float] = None) -> Any: # pylint: disable=unused-argument + """See grpc.Future.result.""" + raise self + + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: # pylint: disable=unused-argument + """See grpc.Future.exception.""" + return self + + def traceback( + self, + timeout: Optional[float] = None # pylint: disable=unused-argument + ) -> Optional[types.TracebackType]: + """See grpc.Future.traceback.""" + try: + raise self + except grpc.RpcError: + return sys.exc_info()[2] + + def add_done_callback( + self, + fn: Callable[[grpc.Future], None], + timeout: Optional[float] = None) -> None: # pylint: disable=unused-argument + """See grpc.Future.add_done_callback.""" + fn(self) + + +class _Rendezvous(grpc.RpcError, grpc.RpcContext): + """An RPC iterator. + + Attributes: + _state: An instance of _RPCState. + _call: An instance of SegregatedCall or IntegratedCall. + In either case, the _call object is expected to have operate, cancel, + and next_event methods. + _response_deserializer: A callable taking bytes and return a Python + object. + _deadline: A float representing the deadline of the RPC in seconds. Or + possibly None, to represent an RPC with no deadline at all. + """ + _state: _RPCState + _call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall] + _response_deserializer: Optional[DeserializingFunction] + _deadline: Optional[float] + + def __init__(self, state: _RPCState, call: Union[cygrpc.SegregatedCall, + cygrpc.IntegratedCall], + response_deserializer: Optional[DeserializingFunction], + deadline: Optional[float]): + super(_Rendezvous, self).__init__() + self._state = state + self._call = call + self._response_deserializer = response_deserializer + self._deadline = deadline + + def is_active(self) -> bool: + """See grpc.RpcContext.is_active""" + with self._state.condition: + return self._state.code is None + + def time_remaining(self) -> Optional[float]: + """See grpc.RpcContext.time_remaining""" + with self._state.condition: + if self._deadline is None: + return None + else: + return max(self._deadline - time.time(), 0) + + def cancel(self) -> bool: + """See grpc.RpcContext.cancel""" + with self._state.condition: + if self._state.code is None: + code = grpc.StatusCode.CANCELLED + details = 'Locally cancelled by application!' + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details) + self._state.cancelled = True + _abort(self._state, code, details) + self._state.condition.notify_all() + return True + else: + return False + + def add_callback(self, callback: NullaryCallbackType) -> bool: + """See grpc.RpcContext.add_callback""" + with self._state.condition: + if self._state.callbacks is None: + return False + else: + self._state.callbacks.append(callback) + return True + + def __iter__(self): + return self + + def next(self): + return self._next() + + def __next__(self): + return self._next() + + def _next(self): + raise NotImplementedError() + + def debug_error_string(self) -> Optional[str]: + raise NotImplementedError() + + def _repr(self) -> str: + return _rpc_state_string(self.__class__.__name__, self._state) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + def __del__(self) -> None: + with self._state.condition: + if self._state.code is None: + self._state.code = grpc.StatusCode.CANCELLED + self._state.details = 'Cancelled upon garbage collection!' + self._state.cancelled = True + self._call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code], + self._state.details) + self._state.condition.notify_all() + + +class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: disable=too-many-ancestors + """An RPC iterator operating entirely on a single thread. + + The __next__ method of _SingleThreadedRendezvous does not depend on the + existence of any other thread, including the "channel spin thread". + However, this means that its interface is entirely synchronous. So this + class cannot completely fulfill the grpc.Future interface. The result, + exception, and traceback methods will never block and will instead raise + an exception if calling the method would result in blocking. + + This means that these methods are safe to call from add_done_callback + handlers. + """ + _state: _RPCState + + def _is_complete(self) -> bool: + return self._state.code is not None + + def cancelled(self) -> bool: + with self._state.condition: + return self._state.cancelled + + def running(self) -> bool: + with self._state.condition: + return self._state.code is None + + def done(self) -> bool: + with self._state.condition: + return self._state.code is not None + + def result(self, timeout: Optional[float] = None) -> Any: + """Returns the result of the computation or raises its exception. + + This method will never block. Instead, it will raise an exception + if calling this method would otherwise result in blocking. + + Since this method will never block, any `timeout` argument passed will + be ignored. + """ + del timeout + with self._state.condition: + if not self._is_complete(): + raise grpc.experimental.UsageError( + "_SingleThreadedRendezvous only supports result() when the RPC is complete." + ) + if self._state.code is grpc.StatusCode.OK: + return self._state.response + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + raise self + + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: + """Return the exception raised by the computation. + + This method will never block. Instead, it will raise an exception + if calling this method would otherwise result in blocking. + + Since this method will never block, any `timeout` argument passed will + be ignored. + """ + del timeout + with self._state.condition: + if not self._is_complete(): + raise grpc.experimental.UsageError( + "_SingleThreadedRendezvous only supports exception() when the RPC is complete." + ) + if self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + return self + + def traceback( + self, + timeout: Optional[float] = None) -> Optional[types.TracebackType]: + """Access the traceback of the exception raised by the computation. + + This method will never block. Instead, it will raise an exception + if calling this method would otherwise result in blocking. + + Since this method will never block, any `timeout` argument passed will + be ignored. + """ + del timeout + with self._state.condition: + if not self._is_complete(): + raise grpc.experimental.UsageError( + "_SingleThreadedRendezvous only supports traceback() when the RPC is complete." + ) + if self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + try: + raise self + except grpc.RpcError: + return sys.exc_info()[2] + + def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None: + with self._state.condition: + if self._state.code is None: + self._state.callbacks.append(functools.partial(fn, self)) + return + + fn(self) + + def initial_metadata(self) -> Optional[MetadataType]: + """See grpc.Call.initial_metadata""" + with self._state.condition: + # NOTE(gnossen): Based on our initial call batch, we are guaranteed + # to receive initial metadata before any messages. + while self._state.initial_metadata is None: + self._consume_next_event() + return self._state.initial_metadata + + def trailing_metadata(self) -> Optional[MetadataType]: + """See grpc.Call.trailing_metadata""" + with self._state.condition: + if self._state.trailing_metadata is None: + raise grpc.experimental.UsageError( + "Cannot get trailing metadata until RPC is completed.") + return self._state.trailing_metadata + + def code(self) -> Optional[grpc.StatusCode]: + """See grpc.Call.code""" + with self._state.condition: + if self._state.code is None: + raise grpc.experimental.UsageError( + "Cannot get code until RPC is completed.") + return self._state.code + + def details(self) -> Optional[str]: + """See grpc.Call.details""" + with self._state.condition: + if self._state.details is None: + raise grpc.experimental.UsageError( + "Cannot get details until RPC is completed.") + return _common.decode(self._state.details) + + def _consume_next_event(self) -> Optional[cygrpc.BaseEvent]: + event = self._call.next_event() + with self._state.condition: + callbacks = _handle_event(event, self._state, + self._response_deserializer) + for callback in callbacks: + # NOTE(gnossen): We intentionally allow exceptions to bubble up + # to the user when running on a single thread. + callback() + return event + + def _next_response(self) -> Any: + while True: + self._consume_next_event() + with self._state.condition: + if self._state.response is not None: + response = self._state.response + self._state.response = None + return response + elif cygrpc.OperationType.receive_message not in self._state.due: + if self._state.code is grpc.StatusCode.OK: + raise StopIteration() + elif self._state.code is not None: + raise self + + def _next(self) -> Any: + with self._state.condition: + if self._state.code is None: + # We tentatively add the operation as expected and remove + # it if the enqueue operation fails. This allows us to guarantee that + # if an event has been submitted to the core completion queue, + # it is in `due`. If we waited until after a successful + # enqueue operation then a signal could interrupt this + # thread between the enqueue operation and the addition of the + # operation to `due`. This would cause an exception on the + # channel spin thread when the operation completes and no + # corresponding operation would be present in state.due. + # Note that, since `condition` is held through this block, there is + # no data race on `due`. + self._state.due.add(cygrpc.OperationType.receive_message) + operating = self._call.operate( + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None) + if not operating: + self._state.due.remove(cygrpc.OperationType.receive_message) + elif self._state.code is grpc.StatusCode.OK: + raise StopIteration() + else: + raise self + return self._next_response() + + def debug_error_string(self) -> Optional[str]: + with self._state.condition: + if self._state.debug_error_string is None: + raise grpc.experimental.UsageError( + "Cannot get debug error string until RPC is completed.") + return _common.decode(self._state.debug_error_string) + + +class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: disable=too-many-ancestors + """An RPC iterator that depends on a channel spin thread. + + This iterator relies upon a per-channel thread running in the background, + dequeueing events from the completion queue, and notifying threads waiting + on the threading.Condition object in the _RPCState object. + + This extra thread allows _MultiThreadedRendezvous to fulfill the grpc.Future interface + and to mediate a bidirection streaming RPC. + """ + _state: _RPCState + + def initial_metadata(self) -> Optional[MetadataType]: + """See grpc.Call.initial_metadata""" + with self._state.condition: + + def _done(): + return self._state.initial_metadata is not None + + _common.wait(self._state.condition.wait, _done) + return self._state.initial_metadata + + def trailing_metadata(self) -> Optional[MetadataType]: + """See grpc.Call.trailing_metadata""" + with self._state.condition: + + def _done(): + return self._state.trailing_metadata is not None + + _common.wait(self._state.condition.wait, _done) + return self._state.trailing_metadata + + def code(self) -> Optional[grpc.StatusCode]: + """See grpc.Call.code""" + with self._state.condition: + + def _done(): + return self._state.code is not None + + _common.wait(self._state.condition.wait, _done) + return self._state.code + + def details(self) -> Optional[str]: + """See grpc.Call.details""" + with self._state.condition: + + def _done(): + return self._state.details is not None + + _common.wait(self._state.condition.wait, _done) + return _common.decode(self._state.details) + + def debug_error_string(self) -> Optional[str]: + with self._state.condition: + + def _done(): + return self._state.debug_error_string is not None + + _common.wait(self._state.condition.wait, _done) + return _common.decode(self._state.debug_error_string) + + def cancelled(self) -> bool: + with self._state.condition: + return self._state.cancelled + + def running(self) -> bool: + with self._state.condition: + return self._state.code is None + + def done(self) -> bool: + with self._state.condition: + return self._state.code is not None + + def _is_complete(self) -> bool: + return self._state.code is not None + + def result(self, timeout: Optional[float] = None) -> Any: + """Returns the result of the computation or raises its exception. + + See grpc.Future.result for the full API contract. + """ + with self._state.condition: + timed_out = _common.wait(self._state.condition.wait, + self._is_complete, + timeout=timeout) + if timed_out: + raise grpc.FutureTimeoutError() + else: + if self._state.code is grpc.StatusCode.OK: + return self._state.response + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + raise self + + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: + """Return the exception raised by the computation. + + See grpc.Future.exception for the full API contract. + """ + with self._state.condition: + timed_out = _common.wait(self._state.condition.wait, + self._is_complete, + timeout=timeout) + if timed_out: + raise grpc.FutureTimeoutError() + else: + if self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + return self + + def traceback( + self, + timeout: Optional[float] = None) -> Optional[types.TracebackType]: + """Access the traceback of the exception raised by the computation. + + See grpc.future.traceback for the full API contract. + """ + with self._state.condition: + timed_out = _common.wait(self._state.condition.wait, + self._is_complete, + timeout=timeout) + if timed_out: + raise grpc.FutureTimeoutError() + else: + if self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + try: + raise self + except grpc.RpcError: + return sys.exc_info()[2] + + def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None: + with self._state.condition: + if self._state.code is None: + self._state.callbacks.append(functools.partial(fn, self)) + return + + fn(self) + + def _next(self) -> Any: + with self._state.condition: + if self._state.code is None: + event_handler = _event_handler(self._state, + self._response_deserializer) + self._state.due.add(cygrpc.OperationType.receive_message) + operating = self._call.operate( + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), + event_handler) + if not operating: + self._state.due.remove(cygrpc.OperationType.receive_message) + elif self._state.code is grpc.StatusCode.OK: + raise StopIteration() + else: + raise self + + def _response_ready(): + return (self._state.response is not None or + (cygrpc.OperationType.receive_message + not in self._state.due and + self._state.code is not None)) + + _common.wait(self._state.condition.wait, _response_ready) + if self._state.response is not None: + response = self._state.response + self._state.response = None + return response + elif cygrpc.OperationType.receive_message not in self._state.due: + if self._state.code is grpc.StatusCode.OK: + raise StopIteration() + elif self._state.code is not None: + raise self + + +def _start_unary_request( + request: Any, timeout: Optional[float], + request_serializer: SerializingFunction +) -> Tuple[Optional[float], Optional[bytes], Optional[grpc.RpcError]]: + deadline = _deadline(timeout) + serialized_request = _common.serialize(request, request_serializer) + if serialized_request is None: + state = _RPCState((), (), (), grpc.StatusCode.INTERNAL, + 'Exception serializing request!') + error = _InactiveRpcError(state) + return deadline, None, error + else: + return deadline, serialized_request, None + + +def _end_unary_response_blocking( + state: _RPCState, call: cygrpc.SegregatedCall, with_call: bool, + deadline: Optional[float] +) -> Union[ResponseType, Tuple[ResponseType, grpc.Call]]: + if state.code is grpc.StatusCode.OK: + if with_call: + rendezvous = _MultiThreadedRendezvous(state, call, None, deadline) + return state.response, rendezvous + else: + return state.response + else: + raise _InactiveRpcError(state) # pytype: disable=not-instantiable + + +def _stream_unary_invocation_operations( + metadata: Optional[MetadataType], + initial_metadata_flags: int) -> Sequence[Sequence[cygrpc.Operation]]: + return ( + ( + cygrpc.SendInitialMetadataOperation(metadata, + initial_metadata_flags), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + + +def _stream_unary_invocation_operations_and_tags( + metadata: Optional[MetadataType], initial_metadata_flags: int +) -> Sequence[Tuple[Sequence[cygrpc.Operation], Optional[UserTag]]]: + return tuple(( + operations, + None, + ) for operations in _stream_unary_invocation_operations( + metadata, initial_metadata_flags)) + + +def _determine_deadline(user_deadline: Optional[float]) -> Optional[float]: + parent_deadline = cygrpc.get_deadline_from_context() + if parent_deadline is None and user_deadline is None: + return None + elif parent_deadline is not None and user_deadline is None: + return parent_deadline + elif user_deadline is not None and parent_deadline is None: + return user_deadline + else: + return min(parent_deadline, user_deadline) + + +class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any + + # pylint: disable=too-many-arguments + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction]): + self._channel = channel + self._managed_call = managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._context = cygrpc.build_census_context() + + def _prepare( + self, request: Any, timeout: Optional[float], + metadata: Optional[MetadataType], wait_for_ready: Optional[bool], + compression: Optional[grpc.Compression] + ) -> Tuple[Optional[_RPCState], Optional[Sequence[cygrpc.Operation]], + Optional[float], Optional[grpc.RpcError]]: + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) + if serialized_request is None: + return None, None, None, rendezvous + else: + state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) + operations = ( + cygrpc.SendInitialMetadataOperation(augmented_metadata, + initial_metadata_flags), + cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ) + return state, operations, deadline, None + + def _blocking( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: + state, operations, deadline, rendezvous = self._prepare( + request, timeout, metadata, wait_for_ready, compression) + if state is None: + raise rendezvous # pylint: disable-msg=raising-bad-type + else: + call = self._channel.segregated_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, _determine_deadline(deadline), metadata, + None if credentials is None else credentials._credentials, (( + operations, + None, + ),), self._context) + event = call.next_event() + _handle_event(event, state, self._response_deserializer) + return state, call + + def __call__(self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + state, call, = self._blocking(request, timeout, metadata, credentials, + wait_for_ready, compression) + return _end_unary_response_blocking(state, call, False, None) + + def with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + state, call, = self._blocking(request, timeout, metadata, credentials, + wait_for_ready, compression) + return _end_unary_response_blocking(state, call, True, None) + + def future( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: + state, operations, deadline, rendezvous = self._prepare( + request, timeout, metadata, wait_for_ready, compression) + if state is None: + raise rendezvous # pylint: disable-msg=raising-bad-type + else: + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, + None if credentials is None else credentials._credentials, + (operations,), event_handler, self._context) + return _MultiThreadedRendezvous(state, call, + self._response_deserializer, + deadline) + + +class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _channel: cygrpc.Channel + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any + + # pylint: disable=too-many-arguments + def __init__(self, channel: cygrpc.Channel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction): + self._channel = channel + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._context = cygrpc.build_census_context() + + def __call__( # pylint: disable=too-many-locals + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _SingleThreadedRendezvous: + deadline = _deadline(timeout) + serialized_request = _common.serialize(request, + self._request_serializer) + if serialized_request is None: + state = _RPCState((), (), (), grpc.StatusCode.INTERNAL, + 'Exception serializing request!') + raise _InactiveRpcError(state) + + state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) + call_credentials = None if credentials is None else credentials._credentials + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) + operations = ( + (cygrpc.SendInitialMetadataOperation(augmented_metadata, + initial_metadata_flags), + cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS)), + (cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + operations_and_tags = tuple((ops, None) for ops in operations) + call = self._channel.segregated_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, _determine_deadline(deadline), metadata, call_credentials, + operations_and_tags, self._context) + return _SingleThreadedRendezvous(state, call, + self._response_deserializer, deadline) + + +class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any + + # pylint: disable=too-many-arguments + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction): + self._channel = channel + self._managed_call = managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._context = cygrpc.build_census_context() + + def __call__( # pylint: disable=too-many-locals + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[ + grpc.Compression] = None) -> _MultiThreadedRendezvous: + deadline, serialized_request, rendezvous = _start_unary_request( + request, timeout, self._request_serializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + if serialized_request is None: + raise rendezvous # pylint: disable-msg=raising-bad-type + else: + augmented_metadata = _compression.augment_metadata( + metadata, compression) + state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) + operations = ( + ( + cygrpc.SendInitialMetadataOperation(augmented_metadata, + initial_metadata_flags), + cygrpc.SendMessageOperation(serialized_request, + _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + call = self._managed_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, _determine_deadline(deadline), metadata, + None if credentials is None else credentials._credentials, + operations, _event_handler(state, self._response_deserializer), + self._context) + return _MultiThreadedRendezvous(state, call, + self._response_deserializer, + deadline) + + +class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any + + # pylint: disable=too-many-arguments + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction]): + self._channel = channel + self._managed_call = managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._context = cygrpc.build_census_context() + + def _blocking( + self, request_iterator: Iterator, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], compression: Optional[grpc.Compression] + ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: + deadline = _deadline(timeout) + state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) + call = self._channel.segregated_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, _determine_deadline(deadline), augmented_metadata, + None if credentials is None else credentials._credentials, + _stream_unary_invocation_operations_and_tags( + augmented_metadata, initial_metadata_flags), self._context) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, None) + while True: + event = call.next_event() + with state.condition: + _handle_event(event, state, self._response_deserializer) + state.condition.notify_all() + if not state.due: + break + return state, call + + def __call__(self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials, wait_for_ready, compression) + return _end_unary_response_blocking(state, call, False, None) + + def with_call( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + state, call, = self._blocking(request_iterator, timeout, metadata, + credentials, wait_for_ready, compression) + return _end_unary_response_blocking(state, call, True, None) + + def future( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: + deadline = _deadline(timeout) + state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) + event_handler = _event_handler(state, self._response_deserializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) + call = self._managed_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, augmented_metadata, + None if credentials is None else credentials._credentials, + _stream_unary_invocation_operations(metadata, + initial_metadata_flags), + event_handler, self._context) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) + return _MultiThreadedRendezvous(state, call, + self._response_deserializer, deadline) + + +class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any + + # pylint: disable=too-many-arguments + def __init__(self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): + self._channel = channel + self._managed_call = managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._context = cygrpc.build_census_context() + + def __call__( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: + deadline = _deadline(timeout) + state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) + operations = ( + ( + cygrpc.SendInitialMetadataOperation(augmented_metadata, + initial_metadata_flags), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), + ) + event_handler = _event_handler(state, self._response_deserializer) + call = self._managed_call( + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, _determine_deadline(deadline), augmented_metadata, + None if credentials is None else credentials._credentials, + operations, event_handler, self._context) + _consume_request_iterator(request_iterator, state, call, + self._request_serializer, event_handler) + return _MultiThreadedRendezvous(state, call, + self._response_deserializer, deadline) + + +class _InitialMetadataFlags(int): + """Stores immutable initial metadata flags""" + + def __new__(cls, value: int = _EMPTY_FLAGS): + value &= cygrpc.InitialMetadataFlags.used_mask + return super(_InitialMetadataFlags, cls).__new__(cls, value) + + def with_wait_for_ready(self, wait_for_ready: Optional[bool]) -> int: + if wait_for_ready is not None: + if wait_for_ready: + return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ + cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + elif not wait_for_ready: + return self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ + cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + return self + + +class _ChannelCallState(object): + channel: cygrpc.Channel + managed_calls: int + threading: bool + + def __init__(self, channel: cygrpc.Channel): + self.lock = threading.Lock() + self.channel = channel + self.managed_calls = 0 + self.threading = False + + def reset_postfork_child(self) -> None: + self.managed_calls = 0 + + def __del__(self): + try: + self.channel.close(cygrpc.StatusCode.cancelled, + 'Channel deallocated!') + except (TypeError, AttributeError): + pass + + +def _run_channel_spin_thread(state: _ChannelCallState) -> None: + + def channel_spin(): + while True: + cygrpc.block_if_fork_in_progress(state) + event = state.channel.next_call_event() + if event.completion_type == cygrpc.CompletionType.queue_timeout: + continue + call_completed = event.tag(event) + if call_completed: + with state.lock: + state.managed_calls -= 1 + if state.managed_calls == 0: + return + + channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin) + channel_spin_thread.setDaemon(True) + channel_spin_thread.start() + + +def _channel_managed_call_management(state: _ChannelCallState): + + # pylint: disable=too-many-arguments + def create(flags: int, method: bytes, host: Optional[str], + deadline: Optional[float], metadata: Optional[MetadataType], + credentials: Optional[cygrpc.CallCredentials], + operations: Sequence[Sequence[cygrpc.Operation]], + event_handler: UserTag, context) -> cygrpc.IntegratedCall: + """Creates a cygrpc.IntegratedCall. + + Args: + flags: An integer bitfield of call flags. + method: The RPC method. + host: A host string for the created call. + deadline: A float to be the deadline of the created call or None if + the call is to have an infinite deadline. + metadata: The metadata for the call or None. + credentials: A cygrpc.CallCredentials or None. + operations: A sequence of sequences of cygrpc.Operations to be + started on the call. + event_handler: A behavior to call to handle the events resultant from + the operations on the call. + context: Context object for distributed tracing. + Returns: + A cygrpc.IntegratedCall with which to conduct an RPC. + """ + operations_and_tags = tuple(( + operation, + event_handler, + ) for operation in operations) + with state.lock: + call = state.channel.integrated_call(flags, method, host, deadline, + metadata, credentials, + operations_and_tags, context) + if state.managed_calls == 0: + state.managed_calls = 1 + _run_channel_spin_thread(state) + else: + state.managed_calls += 1 + return call + + return create + + +class _ChannelConnectivityState(object): + lock: threading.RLock + channel: grpc.Channel + polling: bool + connectivity: grpc.ChannelConnectivity + try_to_connect: bool + # TODO(xuanwn): Refactor this: https://github.com/grpc/grpc/issues/31704 + callbacks_and_connectivities: List[Sequence[Union[Callable[ + [grpc.ChannelConnectivity], None], Optional[grpc.ChannelConnectivity]]]] + delivering: bool + + def __init__(self, channel: grpc.Channel): + self.lock = threading.RLock() + self.channel = channel + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + + def reset_postfork_child(self) -> None: + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + + +def _deliveries( + state: _ChannelConnectivityState +) -> List[Callable[[grpc.ChannelConnectivity], None]]: + callbacks_needing_update = [] + for callback_and_connectivity in state.callbacks_and_connectivities: + callback, callback_connectivity, = callback_and_connectivity + if callback_connectivity is not state.connectivity: + callbacks_needing_update.append(callback) + callback_and_connectivity[1] = state.connectivity + return callbacks_needing_update + + +def _deliver( + state: _ChannelConnectivityState, + initial_connectivity: grpc.ChannelConnectivity, + initial_callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]] +) -> None: + connectivity = initial_connectivity + callbacks = initial_callbacks + while True: + for callback in callbacks: + cygrpc.block_if_fork_in_progress(state) + try: + callback(connectivity) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE) + with state.lock: + callbacks = _deliveries(state) + if callbacks: + connectivity = state.connectivity + else: + state.delivering = False + return + + +def _spawn_delivery( + state: _ChannelConnectivityState, + callbacks: Sequence[Callable[[grpc.ChannelConnectivity], + None]]) -> None: + delivering_thread = cygrpc.ForkManagedThread(target=_deliver, + args=( + state, + state.connectivity, + callbacks, + )) + delivering_thread.setDaemon(True) + delivering_thread.start() + state.delivering = True + + +# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll. +def _poll_connectivity(state: _ChannelConnectivityState, channel: grpc.Channel, + initial_try_to_connect: bool) -> None: + try_to_connect = initial_try_to_connect + connectivity = channel.check_connectivity_state(try_to_connect) + with state.lock: + state.connectivity = ( + _common. + CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[connectivity]) + callbacks = tuple( + callback for callback, unused_but_known_to_be_none_connectivity in + state.callbacks_and_connectivities) + for callback_and_connectivity in state.callbacks_and_connectivities: + callback_and_connectivity[1] = state.connectivity + if callbacks: + _spawn_delivery(state, callbacks) + while True: + event = channel.watch_connectivity_state(connectivity, + time.time() + 0.2) + cygrpc.block_if_fork_in_progress(state) + with state.lock: + if not state.callbacks_and_connectivities and not state.try_to_connect: + state.polling = False + state.connectivity = None + break + try_to_connect = state.try_to_connect + state.try_to_connect = False + if event.success or try_to_connect: + connectivity = channel.check_connectivity_state(try_to_connect) + with state.lock: + state.connectivity = ( + _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ + connectivity]) + if not state.delivering: + callbacks = _deliveries(state) + if callbacks: + _spawn_delivery(state, callbacks) + + +def _subscribe(state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], + None], try_to_connect: bool) -> None: + with state.lock: + if not state.callbacks_and_connectivities and not state.polling: + polling_thread = cygrpc.ForkManagedThread( + target=_poll_connectivity, + args=(state, state.channel, bool(try_to_connect))) + polling_thread.setDaemon(True) + polling_thread.start() + state.polling = True + state.callbacks_and_connectivities.append([callback, None]) + elif not state.delivering and state.connectivity is not None: + _spawn_delivery(state, (callback,)) + state.try_to_connect |= bool(try_to_connect) + state.callbacks_and_connectivities.append( + [callback, state.connectivity]) + else: + state.try_to_connect |= bool(try_to_connect) + state.callbacks_and_connectivities.append([callback, None]) + + +def _unsubscribe(state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], None]) -> None: + with state.lock: + for index, (subscribed_callback, unused_connectivity) in enumerate( + state.callbacks_and_connectivities): + if callback == subscribed_callback: + state.callbacks_and_connectivities.pop(index) + break + + +def _augment_options( + base_options: Sequence[ChannelArgumentType], + compression: Optional[grpc.Compression] +) -> Sequence[ChannelArgumentType]: + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + (( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ),) + + +def _separate_channel_options( + options: Sequence[ChannelArgumentType] +) -> Tuple[Sequence[ChannelArgumentType], Sequence[ChannelArgumentType]]: + """Separates core channel options from Python channel options.""" + core_options = [] + python_options = [] + for pair in options: + if pair[0] == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream: + python_options.append(pair) + else: + core_options.append(pair) + return python_options, core_options + + +class Channel(grpc.Channel): + """A cygrpc.Channel-backed implementation of grpc.Channel.""" + _single_threaded_unary_stream: bool + _channel: cygrpc.Channel + _call_state: _ChannelCallState + _connectivity_state: _ChannelConnectivityState + + def __init__(self, target: str, options: Sequence[ChannelArgumentType], + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression]): + """Constructor. + + Args: + target: The target to which to connect. + options: Configuration options for the channel. + credentials: A cygrpc.ChannelCredentials or None. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. + """ + python_options, core_options = _separate_channel_options(options) + self._single_threaded_unary_stream = _DEFAULT_SINGLE_THREADED_UNARY_STREAM + self._process_python_options(python_options) + self._channel = cygrpc.Channel( + _common.encode(target), _augment_options(core_options, compression), + credentials) + self._call_state = _ChannelCallState(self._channel) + self._connectivity_state = _ChannelConnectivityState(self._channel) + cygrpc.fork_register_channel(self) + if cygrpc.g_gevent_activated: + cygrpc.gevent_increment_channel_count() + + def _process_python_options( + self, python_options: Sequence[ChannelArgumentType]) -> None: + """Sets channel attributes according to python-only channel options.""" + for pair in python_options: + if pair[0] == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream: + self._single_threaded_unary_stream = True + + def subscribe(self, + callback: Callable[[grpc.ChannelConnectivity], None], + try_to_connect: Optional[bool] = None) -> None: + _subscribe(self._connectivity_state, callback, try_to_connect) + + def unsubscribe( + self, callback: Callable[[grpc.ChannelConnectivity], None]) -> None: + _unsubscribe(self._connectivity_state, callback) + + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryUnaryMultiCallable: + return _UnaryUnaryMultiCallable( + self._channel, _channel_managed_call_management(self._call_state), + _common.encode(method), request_serializer, response_deserializer) + + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryStreamMultiCallable: + # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC + # on a single Python thread results in an appreciable speed-up. However, + # due to slight differences in capability, the multi-threaded variant + # remains the default. + if self._single_threaded_unary_stream: + return _SingleThreadedUnaryStreamMultiCallable( + self._channel, _common.encode(method), request_serializer, + response_deserializer) + else: + return _UnaryStreamMultiCallable( + self._channel, + _channel_managed_call_management(self._call_state), + _common.encode(method), request_serializer, + response_deserializer) + + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamUnaryMultiCallable: + return _StreamUnaryMultiCallable( + self._channel, _channel_managed_call_management(self._call_state), + _common.encode(method), request_serializer, response_deserializer) + + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamStreamMultiCallable: + return _StreamStreamMultiCallable( + self._channel, _channel_managed_call_management(self._call_state), + _common.encode(method), request_serializer, response_deserializer) + + def _unsubscribe_all(self) -> None: + state = self._connectivity_state + if state: + with state.lock: + del state.callbacks_and_connectivities[:] + + def _close(self) -> None: + self._unsubscribe_all() + self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') + cygrpc.fork_unregister_channel(self) + if cygrpc.g_gevent_activated: + cygrpc.gevent_decrement_channel_count() + + def _close_on_fork(self) -> None: + self._unsubscribe_all() + self._channel.close_on_fork(cygrpc.StatusCode.cancelled, + 'Channel closed due to fork') + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._close() + return False + + def close(self) -> None: + self._close() + + def __del__(self): + # TODO(https://github.com/grpc/grpc/issues/12531): Several releases + # after 1.12 (1.16 or thereabouts?) add a "self._channel.close" call + # here (or more likely, call self._close() here). We don't do this today + # because many valid use cases today allow the channel to be deleted + # immediately after stubs are created. After a sufficient period of time + # has passed for all users to be trusted to freeze out to their channels + # for as long as they are in use and to close them after using them, + # then deletion of this grpc._channel.Channel instance can be made to + # effect closure of the underlying cygrpc.Channel instance. + try: + self._unsubscribe_all() + except: # pylint: disable=bare-except + # Exceptions in __del__ are ignored by Python anyway, but they can + # keep spamming logs. Just silence them. + pass |