diff options
author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300 |
---|---|---|
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300 |
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/grpcio/py3/grpc/_interceptor.py | |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/grpcio/py3/grpc/_interceptor.py')
-rw-r--r-- | contrib/python/grpcio/py3/grpc/_interceptor.py | 638 |
1 files changed, 638 insertions, 0 deletions
diff --git a/contrib/python/grpcio/py3/grpc/_interceptor.py b/contrib/python/grpcio/py3/grpc/_interceptor.py new file mode 100644 index 0000000000..865ff17d35 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/_interceptor.py @@ -0,0 +1,638 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of gRPC Python interceptors.""" + +import collections +import sys +import types +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import grpc + +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import MetadataType +from ._typing import RequestIterableType +from ._typing import SerializingFunction + + +class _ServicePipeline(object): + interceptors: Tuple[grpc.ServerInterceptor] + + def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]): + self.interceptors = tuple(interceptors) + + def _continuation(self, thunk: Callable, index: int) -> Callable: + return lambda context: self._intercept_at(thunk, index, context) + + def _intercept_at( + self, thunk: Callable, index: int, + context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: + if index < len(self.interceptors): + interceptor = self.interceptors[index] + thunk = self._continuation(thunk, index + 1) + return interceptor.intercept_service(thunk, context) + else: + return thunk(context) + + def execute(self, thunk: Callable, + context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: + return self._intercept_at(thunk, 0, context) + + +def service_pipeline( + interceptors: Optional[Sequence[grpc.ServerInterceptor]] +) -> Optional[_ServicePipeline]: + return _ServicePipeline(interceptors) if interceptors else None + + +class _ClientCallDetails( + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials', + 'wait_for_ready', 'compression')), + grpc.ClientCallDetails): + pass + + +def _unwrap_client_call_details( + call_details: grpc.ClientCallDetails, + default_details: grpc.ClientCallDetails +) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool, + grpc.Compression]: + try: + method = call_details.method # pytype: disable=attribute-error + except AttributeError: + method = default_details.method # pytype: disable=attribute-error + + try: + timeout = call_details.timeout # pytype: disable=attribute-error + except AttributeError: + timeout = default_details.timeout # pytype: disable=attribute-error + + try: + metadata = call_details.metadata # pytype: disable=attribute-error + except AttributeError: + metadata = default_details.metadata # pytype: disable=attribute-error + + try: + credentials = call_details.credentials # pytype: disable=attribute-error + except AttributeError: + credentials = default_details.credentials # pytype: disable=attribute-error + + try: + wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error + except AttributeError: + wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error + + try: + compression = call_details.compression # pytype: disable=attribute-error + except AttributeError: + compression = default_details.compression # pytype: disable=attribute-error + + return method, timeout, metadata, credentials, wait_for_ready, compression + + +class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors + _exception: Exception + _traceback: types.TracebackType + + def __init__(self, exception: Exception, traceback: types.TracebackType): + super(_FailureOutcome, self).__init__() + self._exception = exception + self._traceback = traceback + + def initial_metadata(self) -> Optional[MetadataType]: + return None + + def trailing_metadata(self) -> Optional[MetadataType]: + return None + + def code(self) -> Optional[grpc.StatusCode]: + return grpc.StatusCode.INTERNAL + + def details(self) -> Optional[str]: + return 'Exception raised while intercepting the RPC' + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def is_active(self) -> bool: + return False + + def time_remaining(self) -> Optional[float]: + return None + + def running(self) -> bool: + return False + + def done(self) -> bool: + return True + + def result(self, ignored_timeout: Optional[float] = None): + raise self._exception + + def exception( + self, + ignored_timeout: Optional[float] = None) -> Optional[Exception]: + return self._exception + + def traceback( + self, + ignored_timeout: Optional[float] = None + ) -> Optional[types.TracebackType]: + return self._traceback + + def add_callback(self, unused_callback) -> bool: + return False + + def add_done_callback(self, fn: DoneCallbackType) -> None: + fn(self) + + def __iter__(self): + return self + + def __next__(self): + raise self._exception + + def next(self): + return self.__next__() + + +class _UnaryOutcome(grpc.Call, grpc.Future): + _response: Any + _call: grpc.Call + + def __init__(self, response: Any, call: grpc.Call): + self._response = response + self._call = call + + def initial_metadata(self) -> Optional[MetadataType]: + return self._call.initial_metadata() + + def trailing_metadata(self) -> Optional[MetadataType]: + return self._call.trailing_metadata() + + def code(self) -> Optional[grpc.StatusCode]: + return self._call.code() + + def details(self) -> Optional[str]: + return self._call.details() + + def is_active(self) -> bool: + return self._call.is_active() + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() + + def cancel(self) -> bool: + return self._call.cancel() + + def add_callback(self, callback) -> bool: + return self._call.add_callback(callback) + + def cancelled(self) -> bool: + return False + + def running(self) -> bool: + return False + + def done(self) -> bool: + return True + + def result(self, ignored_timeout: Optional[float] = None): + return self._response + + def exception(self, ignored_timeout: Optional[float] = None): + return None + + def traceback(self, ignored_timeout: Optional[float] = None): + return None + + def add_done_callback(self, fn: DoneCallbackType) -> None: + fn(self) + + +class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.UnaryUnaryClientInterceptor + + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.UnaryUnaryClientInterceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + response, ignored_call = self._with_call(request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression) + return response + + def _with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + try: + response, call = self._thunk(new_method).with_call( + request, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + return _UnaryOutcome(response, call) + except grpc.RpcError as rpc_error: + return rpc_error + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + call = self._interceptor.intercept_unary_unary(continuation, + client_call_details, + request) + return call.result(), call + + def with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + return self._with_call(request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression) + + def future(self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + return self._thunk(new_method).future( + request, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + + try: + return self._interceptor.intercept_unary_unary( + continuation, client_call_details, request) + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + +class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.UnaryStreamClientInterceptor + + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.UnaryStreamClientInterceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + return self._thunk(new_method)(request, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + + try: + return self._interceptor.intercept_unary_stream( + continuation, client_call_details, request) + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + +class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.StreamUnaryClientInterceptor + + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.StreamUnaryClientInterceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + response, ignored_call = self._with_call(request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression) + return response + + def _with_call( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request_iterator): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + try: + response, call = self._thunk(new_method).with_call( + request_iterator, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + return _UnaryOutcome(response, call) + except grpc.RpcError as rpc_error: + return rpc_error + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + call = self._interceptor.intercept_stream_unary(continuation, + client_call_details, + request_iterator) + return call.result(), call + + def with_call( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: + return self._with_call(request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression) + + def future(self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request_iterator): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + return self._thunk(new_method).future( + request_iterator, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + + try: + return self._interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator) + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + +class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.StreamStreamClientInterceptor + + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.StreamStreamClientInterceptor): + self._thunk = thunk + self._method = method + self._interceptor = interceptor + + def __call__(self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) + + def continuation(new_details, request_iterator): + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) + return self._thunk(new_method)(request_iterator, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression) + + try: + return self._interceptor.intercept_stream_stream( + continuation, client_call_details, request_iterator) + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + +class _Channel(grpc.Channel): + _channel: grpc.Channel + _interceptor: Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor] + + def __init__(self, channel: grpc.Channel, + interceptor: Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor]): + self._channel = channel + self._interceptor = interceptor + + def subscribe(self, + callback: Callable, + try_to_connect: Optional[bool] = False): + self._channel.subscribe(callback, try_to_connect=try_to_connect) + + def unsubscribe(self, callback: Callable): + self._channel.unsubscribe(callback) + + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryUnaryMultiCallable: + thunk = lambda m: self._channel.unary_unary(m, request_serializer, + response_deserializer) + if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): + return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryStreamMultiCallable: + thunk = lambda m: self._channel.unary_stream(m, request_serializer, + response_deserializer) + if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): + return _UnaryStreamMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamUnaryMultiCallable: + thunk = lambda m: self._channel.stream_unary(m, request_serializer, + response_deserializer) + if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): + return _StreamUnaryMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamStreamMultiCallable: + thunk = lambda m: self._channel.stream_stream(m, request_serializer, + response_deserializer) + if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): + return _StreamStreamMultiCallable(thunk, method, self._interceptor) + else: + return thunk(method) + + def _close(self): + self._channel.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._close() + return False + + def close(self): + self._channel.close() + + +def intercept_channel( + channel: grpc.Channel, + *interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor]]] +) -> grpc.Channel: + for interceptor in reversed(list(interceptors)): + if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \ + not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \ + not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \ + not isinstance(interceptor, grpc.StreamStreamClientInterceptor): + raise TypeError('interceptor must be ' + 'grpc.UnaryUnaryClientInterceptor or ' + 'grpc.UnaryStreamClientInterceptor or ' + 'grpc.StreamUnaryClientInterceptor or ' + 'grpc.StreamStreamClientInterceptor or ') + channel = _Channel(channel, interceptor) + return channel |