diff options
| author | nkozlovskiy <[email protected]> | 2023-09-29 12:24:06 +0300 |
|---|---|---|
| committer | nkozlovskiy <[email protected]> | 2023-09-29 12:41:34 +0300 |
| commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
| tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/grpcio/py3/grpc/aio | |
| parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
add ydb deps
Diffstat (limited to 'contrib/python/grpcio/py3/grpc/aio')
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/__init__.py | 95 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_base_call.py | 248 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_base_channel.py | 348 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_base_server.py | 369 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_call.py | 649 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_channel.py | 492 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_interceptor.py | 1001 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_metadata.py | 120 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_server.py | 209 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_typing.py | 35 | ||||
| -rw-r--r-- | contrib/python/grpcio/py3/grpc/aio/_utils.py | 22 |
11 files changed, 3588 insertions, 0 deletions
diff --git a/contrib/python/grpcio/py3/grpc/aio/__init__.py b/contrib/python/grpcio/py3/grpc/aio/__init__.py new file mode 100644 index 00000000000..3436d2ef98c --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/__init__.py @@ -0,0 +1,95 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""gRPC's Asynchronous Python API. + +gRPC Async API objects may only be used on the thread on which they were +created. AsyncIO doesn't provide thread safety for most of its APIs. +""" + +from typing import Any, Optional, Sequence, Tuple + +import grpc +from grpc._cython.cygrpc import AbortError +from grpc._cython.cygrpc import BaseError +from grpc._cython.cygrpc import EOF +from grpc._cython.cygrpc import InternalError +from grpc._cython.cygrpc import UsageError +from grpc._cython.cygrpc import init_grpc_aio +from grpc._cython.cygrpc import shutdown_grpc_aio + +from ._base_call import Call +from ._base_call import RpcContext +from ._base_call import StreamStreamCall +from ._base_call import StreamUnaryCall +from ._base_call import UnaryStreamCall +from ._base_call import UnaryUnaryCall +from ._base_channel import Channel +from ._base_channel import StreamStreamMultiCallable +from ._base_channel import StreamUnaryMultiCallable +from ._base_channel import UnaryStreamMultiCallable +from ._base_channel import UnaryUnaryMultiCallable +from ._base_server import Server +from ._base_server import ServicerContext +from ._call import AioRpcError +from ._channel import insecure_channel +from ._channel import secure_channel +from ._interceptor import ClientCallDetails +from ._interceptor import ClientInterceptor +from ._interceptor import InterceptedUnaryUnaryCall +from ._interceptor import ServerInterceptor +from ._interceptor import StreamStreamClientInterceptor +from ._interceptor import StreamUnaryClientInterceptor +from ._interceptor import UnaryStreamClientInterceptor +from ._interceptor import UnaryUnaryClientInterceptor +from ._metadata import Metadata +from ._server import server +from ._typing import ChannelArgumentType + +################################### __all__ ################################# + +__all__ = ( + 'init_grpc_aio', + 'shutdown_grpc_aio', + 'AioRpcError', + 'RpcContext', + 'Call', + 'UnaryUnaryCall', + 'UnaryStreamCall', + 'StreamUnaryCall', + 'StreamStreamCall', + 'Channel', + 'UnaryUnaryMultiCallable', + 'UnaryStreamMultiCallable', + 'StreamUnaryMultiCallable', + 'StreamStreamMultiCallable', + 'ClientCallDetails', + 'ClientInterceptor', + 'UnaryStreamClientInterceptor', + 'UnaryUnaryClientInterceptor', + 'StreamUnaryClientInterceptor', + 'StreamStreamClientInterceptor', + 'InterceptedUnaryUnaryCall', + 'ServerInterceptor', + 'insecure_channel', + 'server', + 'Server', + 'ServicerContext', + 'EOF', + 'secure_channel', + 'AbortError', + 'BaseError', + 'UsageError', + 'InternalError', + 'Metadata', +) diff --git a/contrib/python/grpcio/py3/grpc/aio/_base_call.py b/contrib/python/grpcio/py3/grpc/aio/_base_call.py new file mode 100644 index 00000000000..029584e94a5 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_base_call.py @@ -0,0 +1,248 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base classes for client-side Call objects. + +Call objects represents the RPC itself, and offer methods to access / modify +its information. They also offer methods to manipulate the life-cycle of the +RPC, e.g. cancellation. +""" + +from abc import ABCMeta +from abc import abstractmethod +from typing import AsyncIterable, Awaitable, Generic, Optional, Union + +import grpc + +from ._metadata import Metadata +from ._typing import DoneCallbackType +from ._typing import EOFType +from ._typing import RequestType +from ._typing import ResponseType + +__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' + + +class RpcContext(metaclass=ABCMeta): + """Provides RPC-related information and action.""" + + @abstractmethod + def cancelled(self) -> bool: + """Return True if the RPC is cancelled. + + The RPC is cancelled when the cancellation was requested with cancel(). + + Returns: + A bool indicates whether the RPC is cancelled or not. + """ + + @abstractmethod + def done(self) -> bool: + """Return True if the RPC is done. + + An RPC is done if the RPC is completed, cancelled or aborted. + + Returns: + A bool indicates if the RPC is done. + """ + + @abstractmethod + def time_remaining(self) -> Optional[float]: + """Describes the length of allowed time remaining for the RPC. + + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the RPC to complete before it is considered to have + timed out, or None if no deadline was specified for the RPC. + """ + + @abstractmethod + def cancel(self) -> bool: + """Cancels the RPC. + + Idempotent and has no effect if the RPC has already terminated. + + Returns: + A bool indicates if the cancellation is performed or not. + """ + + @abstractmethod + def add_done_callback(self, callback: DoneCallbackType) -> None: + """Registers a callback to be called on RPC termination. + + Args: + callback: A callable object will be called with the call object as + its only argument. + """ + + +class Call(RpcContext, metaclass=ABCMeta): + """The abstract base class of an RPC on the client-side.""" + + @abstractmethod + async def initial_metadata(self) -> Metadata: + """Accesses the initial metadata sent by the server. + + Returns: + The initial :term:`metadata`. + """ + + @abstractmethod + async def trailing_metadata(self) -> Metadata: + """Accesses the trailing metadata sent by the server. + + Returns: + The trailing :term:`metadata`. + """ + + @abstractmethod + async def code(self) -> grpc.StatusCode: + """Accesses the status code sent by the server. + + Returns: + The StatusCode value for the RPC. + """ + + @abstractmethod + async def details(self) -> str: + """Accesses the details sent by the server. + + Returns: + The details string of the RPC. + """ + + @abstractmethod + async def wait_for_connection(self) -> None: + """Waits until connected to peer and raises aio.AioRpcError if failed. + + This is an EXPERIMENTAL method. + + This method ensures the RPC has been successfully connected. Otherwise, + an AioRpcError will be raised to explain the reason of the connection + failure. + + This method is recommended for building retry mechanisms. + """ + + +class UnaryUnaryCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + """The abstract base class of an unary-unary RPC on the client-side.""" + + @abstractmethod + def __await__(self) -> Awaitable[ResponseType]: + """Await the response message to be ready. + + Returns: + The response message of the RPC. + """ + + +class UnaryStreamCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + + @abstractmethod + def __aiter__(self) -> AsyncIterable[ResponseType]: + """Returns the async iterable representation that yields messages. + + Under the hood, it is calling the "read" method. + + Returns: + An async iterable object that yields messages. + """ + + @abstractmethod + async def read(self) -> Union[EOFType, ResponseType]: + """Reads one message from the stream. + + Read operations must be serialized when called from multiple + coroutines. + + Returns: + A response message, or an `grpc.aio.EOF` to indicate the end of the + stream. + """ + + +class StreamUnaryCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + + @abstractmethod + async def write(self, request: RequestType) -> None: + """Writes one message to the stream. + + Raises: + An RpcError exception if the write failed. + """ + + @abstractmethod + async def done_writing(self) -> None: + """Notifies server that the client is done sending messages. + + After done_writing is called, any additional invocation to the write + function will fail. This function is idempotent. + """ + + @abstractmethod + def __await__(self) -> Awaitable[ResponseType]: + """Await the response message to be ready. + + Returns: + The response message of the stream. + """ + + +class StreamStreamCall(Generic[RequestType, ResponseType], + Call, + metaclass=ABCMeta): + + @abstractmethod + def __aiter__(self) -> AsyncIterable[ResponseType]: + """Returns the async iterable representation that yields messages. + + Under the hood, it is calling the "read" method. + + Returns: + An async iterable object that yields messages. + """ + + @abstractmethod + async def read(self) -> Union[EOFType, ResponseType]: + """Reads one message from the stream. + + Read operations must be serialized when called from multiple + coroutines. + + Returns: + A response message, or an `grpc.aio.EOF` to indicate the end of the + stream. + """ + + @abstractmethod + async def write(self, request: RequestType) -> None: + """Writes one message to the stream. + + Raises: + An RpcError exception if the write failed. + """ + + @abstractmethod + async def done_writing(self) -> None: + """Notifies server that the client is done sending messages. + + After done_writing is called, any additional invocation to the write + function will fail. This function is idempotent. + """ diff --git a/contrib/python/grpcio/py3/grpc/aio/_base_channel.py b/contrib/python/grpcio/py3/grpc/aio/_base_channel.py new file mode 100644 index 00000000000..4135e4796c7 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_base_channel.py @@ -0,0 +1,348 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base classes for Channel objects and Multicallable objects.""" + +import abc +from typing import Any, Optional + +import grpc + +from . import _base_call +from ._typing import DeserializingFunction +from ._typing import MetadataType +from ._typing import RequestIterableType +from ._typing import SerializingFunction + + +class UnaryUnaryMultiCallable(abc.ABC): + """Enables asynchronous invocation of a unary-call RPC.""" + + @abc.abstractmethod + def __call__( + self, + request: Any, + *, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.UnaryUnaryCall: + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + + Returns: + A UnaryUnaryCall object. + + Raises: + RpcError: Indicates that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + +class UnaryStreamMultiCallable(abc.ABC): + """Enables asynchronous invocation of a server-streaming RPC.""" + + @abc.abstractmethod + def __call__( + self, + request: Any, + *, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.UnaryStreamCall: + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + + Returns: + A UnaryStreamCall object. + + Raises: + RpcError: Indicates that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + +class StreamUnaryMultiCallable(abc.ABC): + """Enables asynchronous invocation of a client-streaming RPC.""" + + @abc.abstractmethod + def __call__( + self, + request_iterator: Optional[RequestIterableType] = None, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamUnaryCall: + """Asynchronously invokes the underlying RPC. + + Args: + request_iterator: An optional async iterable or iterable of request + messages for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + + Returns: + A StreamUnaryCall object. + + Raises: + RpcError: Indicates that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + +class StreamStreamMultiCallable(abc.ABC): + """Enables asynchronous invocation of a bidirectional-streaming RPC.""" + + @abc.abstractmethod + def __call__( + self, + request_iterator: Optional[RequestIterableType] = None, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamStreamCall: + """Asynchronously invokes the underlying RPC. + + Args: + request_iterator: An optional async iterable or iterable of request + messages for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + + Returns: + A StreamStreamCall object. + + Raises: + RpcError: Indicates that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + +class Channel(abc.ABC): + """Enables asynchronous RPC invocation as a client. + + Channel objects implement the Asynchronous Context Manager (aka. async + with) type, although they are not supportted to be entered and exited + multiple times. + """ + + @abc.abstractmethod + async def __aenter__(self): + """Starts an asynchronous context manager. + + Returns: + Channel the channel that was instantiated. + """ + + @abc.abstractmethod + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Finishes the asynchronous context manager by closing the channel. + + Still active RPCs will be cancelled. + """ + + @abc.abstractmethod + async def close(self, grace: Optional[float] = None): + """Closes this Channel and releases all resources held by it. + + This method immediately stops the channel from executing new RPCs in + all cases. + + If a grace period is specified, this method wait until all active + RPCs are finshed, once the grace period is reached the ones that haven't + been terminated are cancelled. If a grace period is not specified + (by passing None for grace), all existing RPCs are cancelled immediately. + + This method is idempotent. + """ + + @abc.abstractmethod + def get_state(self, + try_to_connect: bool = False) -> grpc.ChannelConnectivity: + """Checks the connectivity state of a channel. + + This is an EXPERIMENTAL API. + + If the channel reaches a stable connectivity state, it is guaranteed + that the return value of this function will eventually converge to that + state. + + Args: + try_to_connect: a bool indicate whether the Channel should try to + connect to peer or not. + + Returns: A ChannelConnectivity object. + """ + + @abc.abstractmethod + async def wait_for_state_change( + self, + last_observed_state: grpc.ChannelConnectivity, + ) -> None: + """Waits for a change in connectivity state. + + This is an EXPERIMENTAL API. + + The function blocks until there is a change in the channel connectivity + state from the "last_observed_state". If the state is already + different, this function will return immediately. + + There is an inherent race between the invocation of + "Channel.wait_for_state_change" and "Channel.get_state". The state can + change arbitrary many times during the race, so there is no way to + observe every state transition. + + If there is a need to put a timeout for this function, please refer to + "asyncio.wait_for". + + Args: + last_observed_state: A grpc.ChannelConnectivity object representing + the last known state. + """ + + @abc.abstractmethod + async def channel_ready(self) -> None: + """Creates a coroutine that blocks until the Channel is READY.""" + + @abc.abstractmethod + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> UnaryUnaryMultiCallable: + """Creates a UnaryUnaryMultiCallable for a unary-unary method. + + Args: + method: The name of the RPC method. + request_serializer: Optional :term:`serializer` for serializing the request + message. Request goes unserialized in case None is passed. + response_deserializer: Optional :term:`deserializer` for deserializing the + response message. Response goes undeserialized in case None + is passed. + + Returns: + A UnaryUnaryMultiCallable value for the named unary-unary method. + """ + + @abc.abstractmethod + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> UnaryStreamMultiCallable: + """Creates a UnaryStreamMultiCallable for a unary-stream method. + + Args: + method: The name of the RPC method. + request_serializer: Optional :term:`serializer` for serializing the request + message. Request goes unserialized in case None is passed. + response_deserializer: Optional :term:`deserializer` for deserializing the + response message. Response goes undeserialized in case None + is passed. + + Returns: + A UnarySteramMultiCallable value for the named unary-stream method. + """ + + @abc.abstractmethod + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamUnaryMultiCallable: + """Creates a StreamUnaryMultiCallable for a stream-unary method. + + Args: + method: The name of the RPC method. + request_serializer: Optional :term:`serializer` for serializing the request + message. Request goes unserialized in case None is passed. + response_deserializer: Optional :term:`deserializer` for deserializing the + response message. Response goes undeserialized in case None + is passed. + + Returns: + A StreamUnaryMultiCallable value for the named stream-unary method. + """ + + @abc.abstractmethod + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamStreamMultiCallable: + """Creates a StreamStreamMultiCallable for a stream-stream method. + + Args: + method: The name of the RPC method. + request_serializer: Optional :term:`serializer` for serializing the request + message. Request goes unserialized in case None is passed. + response_deserializer: Optional :term:`deserializer` for deserializing the + response message. Response goes undeserialized in case None + is passed. + + Returns: + A StreamStreamMultiCallable value for the named stream-stream method. + """ diff --git a/contrib/python/grpcio/py3/grpc/aio/_base_server.py b/contrib/python/grpcio/py3/grpc/aio/_base_server.py new file mode 100644 index 00000000000..a86bbbad09f --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_base_server.py @@ -0,0 +1,369 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base classes for server-side classes.""" + +import abc +from typing import Generic, Iterable, Mapping, NoReturn, Optional, Sequence + +import grpc + +from ._metadata import Metadata +from ._typing import DoneCallbackType +from ._typing import MetadataType +from ._typing import RequestType +from ._typing import ResponseType + + +class Server(abc.ABC): + """Serves RPCs.""" + + @abc.abstractmethod + def add_generic_rpc_handlers( + self, + generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: + """Registers GenericRpcHandlers with this Server. + + This method is only safe to call before the server is started. + + Args: + generic_rpc_handlers: A sequence of GenericRpcHandlers that will be + used to service RPCs. + """ + + @abc.abstractmethod + def add_insecure_port(self, address: str) -> int: + """Opens an insecure port for accepting RPCs. + + A port is a communication endpoint that used by networking protocols, + like TCP and UDP. To date, we only support TCP. + + This method may only be called before starting the server. + + Args: + address: The address for which to open a port. If the port is 0, + or not specified in the address, then the gRPC runtime will choose a port. + + Returns: + An integer port on which the server will accept RPC requests. + """ + + @abc.abstractmethod + def add_secure_port(self, address: str, + server_credentials: grpc.ServerCredentials) -> int: + """Opens a secure port for accepting RPCs. + + A port is a communication endpoint that used by networking protocols, + like TCP and UDP. To date, we only support TCP. + + This method may only be called before starting the server. + + Args: + address: The address for which to open a port. + if the port is 0, or not specified in the address, then the gRPC + runtime will choose a port. + server_credentials: A ServerCredentials object. + + Returns: + An integer port on which the server will accept RPC requests. + """ + + @abc.abstractmethod + async def start(self) -> None: + """Starts this Server. + + This method may only be called once. (i.e. it is not idempotent). + """ + + @abc.abstractmethod + async def stop(self, grace: Optional[float]) -> None: + """Stops this Server. + + This method immediately stops the server from servicing new RPCs in + all cases. + + If a grace period is specified, this method returns immediately and all + RPCs active at the end of the grace period are aborted. If a grace + period is not specified (by passing None for grace), all existing RPCs + are aborted immediately and this method blocks until the last RPC + handler terminates. + + This method is idempotent and may be called at any time. Passing a + smaller grace value in a subsequent call will have the effect of + stopping the Server sooner (passing None will have the effect of + stopping the server immediately). Passing a larger grace value in a + subsequent call will not have the effect of stopping the server later + (i.e. the most restrictive grace value is used). + + Args: + grace: A duration of time in seconds or None. + """ + + @abc.abstractmethod + async def wait_for_termination(self, + timeout: Optional[float] = None) -> bool: + """Continues current coroutine once the server stops. + + This is an EXPERIMENTAL API. + + The wait will not consume computational resources during blocking, and + it will block until one of the two following conditions are met: + + 1) The server is stopped or terminated; + 2) A timeout occurs if timeout is not `None`. + + The timeout argument works in the same way as `threading.Event.wait()`. + https://docs.python.org/3/library/threading.html#threading.Event.wait + + Args: + timeout: A floating point number specifying a timeout for the + operation in seconds. + + Returns: + A bool indicates if the operation times out. + """ + + +# pylint: disable=too-many-public-methods +class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): + """A context object passed to method implementations.""" + + @abc.abstractmethod + async def read(self) -> RequestType: + """Reads one message from the RPC. + + Only one read operation is allowed simultaneously. + + Returns: + A response message of the RPC. + + Raises: + An RpcError exception if the read failed. + """ + + @abc.abstractmethod + async def write(self, message: ResponseType) -> None: + """Writes one message to the RPC. + + Only one write operation is allowed simultaneously. + + Raises: + An RpcError exception if the write failed. + """ + + @abc.abstractmethod + async def send_initial_metadata(self, + initial_metadata: MetadataType) -> None: + """Sends the initial metadata value to the client. + + This method need not be called by implementations if they have no + metadata to add to what the gRPC runtime will transmit. + + Args: + initial_metadata: The initial :term:`metadata`. + """ + + @abc.abstractmethod + async def abort( + self, + code: grpc.StatusCode, + details: str = '', + trailing_metadata: MetadataType = tuple()) -> NoReturn: + """Raises an exception to terminate the RPC with a non-OK status. + + The code and details passed as arguments will supercede any existing + ones. + + Args: + code: A StatusCode object to be sent to the client. + It must not be StatusCode.OK. + details: A UTF-8-encodable string to be sent to the client upon + termination of the RPC. + trailing_metadata: A sequence of tuple represents the trailing + :term:`metadata`. + + Raises: + Exception: An exception is always raised to signal the abortion the + RPC to the gRPC runtime. + """ + + @abc.abstractmethod + def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: + """Sends the trailing metadata for the RPC. + + This method need not be called by implementations if they have no + metadata to add to what the gRPC runtime will transmit. + + Args: + trailing_metadata: The trailing :term:`metadata`. + """ + + @abc.abstractmethod + def invocation_metadata(self) -> Optional[Metadata]: + """Accesses the metadata sent by the client. + + Returns: + The invocation :term:`metadata`. + """ + + @abc.abstractmethod + def set_code(self, code: grpc.StatusCode) -> None: + """Sets the value to be used as status code upon RPC completion. + + This method need not be called by method implementations if they wish + the gRPC runtime to determine the status code of the RPC. + + Args: + code: A StatusCode object to be sent to the client. + """ + + @abc.abstractmethod + def set_details(self, details: str) -> None: + """Sets the value to be used the as detail string upon RPC completion. + + This method need not be called by method implementations if they have + no details to transmit. + + Args: + details: A UTF-8-encodable string to be sent to the client upon + termination of the RPC. + """ + + @abc.abstractmethod + def set_compression(self, compression: grpc.Compression) -> None: + """Set the compression algorithm to be used for the entire call. + + Args: + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + """ + + @abc.abstractmethod + def disable_next_message_compression(self) -> None: + """Disables compression for the next response message. + + This method will override any compression configuration set during + server creation or set on the call. + """ + + @abc.abstractmethod + def peer(self) -> str: + """Identifies the peer that invoked the RPC being serviced. + + Returns: + A string identifying the peer that invoked the RPC being serviced. + The string format is determined by gRPC runtime. + """ + + @abc.abstractmethod + def peer_identities(self) -> Optional[Iterable[bytes]]: + """Gets one or more peer identity(s). + + Equivalent to + servicer_context.auth_context().get(servicer_context.peer_identity_key()) + + Returns: + An iterable of the identities, or None if the call is not + authenticated. Each identity is returned as a raw bytes type. + """ + + @abc.abstractmethod + def peer_identity_key(self) -> Optional[str]: + """The auth property used to identify the peer. + + For example, "x509_common_name" or "x509_subject_alternative_name" are + used to identify an SSL peer. + + Returns: + The auth property (string) that indicates the + peer identity, or None if the call is not authenticated. + """ + + @abc.abstractmethod + def auth_context(self) -> Mapping[str, Iterable[bytes]]: + """Gets the auth context for the call. + + Returns: + A map of strings to an iterable of bytes for each auth property. + """ + + def time_remaining(self) -> float: + """Describes the length of allowed time remaining for the RPC. + + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the RPC to complete before it is considered to have + timed out, or None if no deadline was specified for the RPC. + """ + + def trailing_metadata(self): + """Access value to be used as trailing metadata upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The trailing :term:`metadata` for the RPC. + """ + raise NotImplementedError() + + def code(self): + """Accesses the value to be used as status code upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The StatusCode value for the RPC. + """ + raise NotImplementedError() + + def details(self): + """Accesses the value to be used as detail string upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The details string of the RPC. + """ + raise NotImplementedError() + + def add_done_callback(self, callback: DoneCallbackType) -> None: + """Registers a callback to be called on RPC termination. + + This is an EXPERIMENTAL API. + + Args: + callback: A callable object will be called with the servicer context + object as its only argument. + """ + + def cancelled(self) -> bool: + """Return True if the RPC is cancelled. + + The RPC is cancelled when the cancellation was requested with cancel(). + + This is an EXPERIMENTAL API. + + Returns: + A bool indicates whether the RPC is cancelled or not. + """ + + def done(self) -> bool: + """Return True if the RPC is done. + + An RPC is done if the RPC is completed, cancelled or aborted. + + This is an EXPERIMENTAL API. + + Returns: + A bool indicates if the RPC is done. + """ diff --git a/contrib/python/grpcio/py3/grpc/aio/_call.py b/contrib/python/grpcio/py3/grpc/aio/_call.py new file mode 100644 index 00000000000..37ba945da73 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_call.py @@ -0,0 +1,649 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Invocation-side implementation of gRPC Asyncio Python.""" + +import asyncio +import enum +from functools import partial +import inspect +import logging +import traceback +from typing import AsyncIterator, Optional, Tuple + +import grpc +from grpc import _common +from grpc._cython import cygrpc + +from . import _base_call +from ._metadata import Metadata +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import MetadatumType +from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseType +from ._typing import SerializingFunction + +__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' + +_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' +_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' +_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' +_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' +_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.' + +_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '>') + +_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' + '\tstatus = {}\n' + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + '>') + +_LOGGER = logging.getLogger(__name__) + + +class AioRpcError(grpc.RpcError): + """An implementation of RpcError to be used by the asynchronous API. + + Raised RpcError is a snapshot of the final status of the RPC, values are + determined. Hence, its methods no longer needs to be coroutines. + """ + + _code: grpc.StatusCode + _details: Optional[str] + _initial_metadata: Optional[Metadata] + _trailing_metadata: Optional[Metadata] + _debug_error_string: Optional[str] + + def __init__(self, + code: grpc.StatusCode, + initial_metadata: Metadata, + trailing_metadata: Metadata, + details: Optional[str] = None, + debug_error_string: Optional[str] = None) -> None: + """Constructor. + + Args: + code: The status code with which the RPC has been finalized. + details: Optional details explaining the reason of the error. + initial_metadata: Optional initial metadata that could be sent by the + Server. + trailing_metadata: Optional metadata that could be sent by the Server. + """ + + super().__init__() + self._code = code + self._details = details + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata + self._debug_error_string = debug_error_string + + def code(self) -> grpc.StatusCode: + """Accesses the status code sent by the server. + + Returns: + The `grpc.StatusCode` status code. + """ + return self._code + + def details(self) -> Optional[str]: + """Accesses the details sent by the server. + + Returns: + The description of the error. + """ + return self._details + + def initial_metadata(self) -> Metadata: + """Accesses the initial metadata sent by the server. + + Returns: + The initial metadata received. + """ + return self._initial_metadata + + def trailing_metadata(self) -> Metadata: + """Accesses the trailing metadata sent by the server. + + Returns: + The trailing metadata received. + """ + return self._trailing_metadata + + def debug_error_string(self) -> str: + """Accesses the debug error string sent by the server. + + Returns: + The debug error string received. + """ + return self._debug_error_string + + def _repr(self) -> str: + """Assembles the error string for the RPC error.""" + return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__, + self._code, self._details, + self._debug_error_string) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + +def _create_rpc_error(initial_metadata: Metadata, + status: cygrpc.AioRpcStatus) -> AioRpcError: + return AioRpcError( + _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], + Metadata.from_tuple(initial_metadata), + Metadata.from_tuple(status.trailing_metadata()), + details=status.details(), + debug_error_string=status.debug_error_string(), + ) + + +class Call: + """Base implementation of client RPC Call object. + + Implements logic around final status, metadata and cancellation. + """ + _loop: asyncio.AbstractEventLoop + _code: grpc.StatusCode + _cython_call: cygrpc._AioCall + _metadata: Tuple[MetadatumType, ...] + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + + def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._cython_call = cython_call + self._metadata = tuple(metadata) + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __del__(self) -> None: + # The '_cython_call' object might be destructed before Call object + if hasattr(self, '_cython_call'): + if not self._cython_call.done(): + self._cancel(_GC_CANCELLATION_DETAILS) + + def cancelled(self) -> bool: + return self._cython_call.cancelled() + + def _cancel(self, details: str) -> bool: + """Forwards the application cancellation reasoning.""" + if not self._cython_call.done(): + self._cython_call.cancel(details) + return True + else: + return False + + def cancel(self) -> bool: + return self._cancel(_LOCAL_CANCELLATION_DETAILS) + + def done(self) -> bool: + return self._cython_call.done() + + def add_done_callback(self, callback: DoneCallbackType) -> None: + cb = partial(callback, self) + self._cython_call.add_done_callback(cb) + + def time_remaining(self) -> Optional[float]: + return self._cython_call.time_remaining() + + async def initial_metadata(self) -> Metadata: + raw_metadata_tuple = await self._cython_call.initial_metadata() + return Metadata.from_tuple(raw_metadata_tuple) + + async def trailing_metadata(self) -> Metadata: + raw_metadata_tuple = (await + self._cython_call.status()).trailing_metadata() + return Metadata.from_tuple(raw_metadata_tuple) + + async def code(self) -> grpc.StatusCode: + cygrpc_code = (await self._cython_call.status()).code() + return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code] + + async def details(self) -> str: + return (await self._cython_call.status()).details() + + async def debug_error_string(self) -> str: + return (await self._cython_call.status()).debug_error_string() + + async def _raise_for_status(self) -> None: + if self._cython_call.is_locally_cancelled(): + raise asyncio.CancelledError() + code = await self.code() + if code != grpc.StatusCode.OK: + raise _create_rpc_error(await self.initial_metadata(), await + self._cython_call.status()) + + def _repr(self) -> str: + return repr(self._cython_call) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + +class _APIStyle(enum.IntEnum): + UNKNOWN = 0 + ASYNC_GENERATOR = 1 + READER_WRITER = 2 + + +class _UnaryResponseMixin(Call): + _call_response: asyncio.Task + + def _init_unary_response_mixin(self, response_task: asyncio.Task): + self._call_response = response_task + + def cancel(self) -> bool: + if super().cancel(): + self._call_response.cancel() + return True + else: + return False + + def __await__(self) -> ResponseType: + """Wait till the ongoing RPC request finishes.""" + try: + response = yield from self._call_response + except asyncio.CancelledError: + # Even if we caught all other CancelledError, there is still + # this corner case. If the application cancels immediately after + # the Call object is created, we will observe this + # `CancelledError`. + if not self.cancelled(): + self.cancel() + raise + + # NOTE(lidiz) If we raise RpcError in the task, and users doesn't + # 'await' on it. AsyncIO will log 'Task exception was never retrieved'. + # Instead, if we move the exception raising here, the spam stops. + # Unfortunately, there can only be one 'yield from' in '__await__'. So, + # we need to access the private instance variable. + if response is cygrpc.EOF: + if self._cython_call.is_locally_cancelled(): + raise asyncio.CancelledError() + else: + raise _create_rpc_error(self._cython_call._initial_metadata, + self._cython_call._status) + else: + return response + + +class _StreamResponseMixin(Call): + _message_aiter: AsyncIterator[ResponseType] + _preparation: asyncio.Task + _response_style: _APIStyle + + def _init_stream_response_mixin(self, preparation: asyncio.Task): + self._message_aiter = None + self._preparation = preparation + self._response_style = _APIStyle.UNKNOWN + + def _update_response_style(self, style: _APIStyle): + if self._response_style is _APIStyle.UNKNOWN: + self._response_style = style + elif self._response_style is not style: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + def cancel(self) -> bool: + if super().cancel(): + self._preparation.cancel() + return True + else: + return False + + async def _fetch_stream_responses(self) -> ResponseType: + message = await self._read() + while message is not cygrpc.EOF: + yield message + message = await self._read() + + # If the read operation failed, Core should explain why. + await self._raise_for_status() + + def __aiter__(self) -> AsyncIterator[ResponseType]: + self._update_response_style(_APIStyle.ASYNC_GENERATOR) + if self._message_aiter is None: + self._message_aiter = self._fetch_stream_responses() + return self._message_aiter + + async def _read(self) -> ResponseType: + # Wait for the request being sent + await self._preparation + + # Reads response message from Core + try: + raw_response = await self._cython_call.receive_serialized_message() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + if raw_response is cygrpc.EOF: + return cygrpc.EOF + else: + return _common.deserialize(raw_response, + self._response_deserializer) + + async def read(self) -> ResponseType: + if self.done(): + await self._raise_for_status() + return cygrpc.EOF + self._update_response_style(_APIStyle.READER_WRITER) + + response_message = await self._read() + + if response_message is cygrpc.EOF: + # If the read operation failed, Core should explain why. + await self._raise_for_status() + return response_message + + +class _StreamRequestMixin(Call): + _metadata_sent: asyncio.Event + _done_writing_flag: bool + _async_request_poller: Optional[asyncio.Task] + _request_style: _APIStyle + + def _init_stream_request_mixin( + self, request_iterator: Optional[RequestIterableType]): + self._metadata_sent = asyncio.Event() + self._done_writing_flag = False + + # If user passes in an async iterator, create a consumer Task. + if request_iterator is not None: + self._async_request_poller = self._loop.create_task( + self._consume_request_iterator(request_iterator)) + self._request_style = _APIStyle.ASYNC_GENERATOR + else: + self._async_request_poller = None + self._request_style = _APIStyle.READER_WRITER + + def _raise_for_different_style(self, style: _APIStyle): + if self._request_style is not style: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + def cancel(self) -> bool: + if super().cancel(): + if self._async_request_poller is not None: + self._async_request_poller.cancel() + return True + else: + return False + + def _metadata_sent_observer(self): + self._metadata_sent.set() + + async def _consume_request_iterator( + self, request_iterator: RequestIterableType) -> None: + try: + if inspect.isasyncgen(request_iterator) or hasattr( + request_iterator, '__aiter__'): + async for request in request_iterator: + try: + await self._write(request) + except AioRpcError as rpc_error: + _LOGGER.debug( + 'Exception while consuming the request_iterator: %s', + rpc_error) + return + else: + for request in request_iterator: + try: + await self._write(request) + except AioRpcError as rpc_error: + _LOGGER.debug( + 'Exception while consuming the request_iterator: %s', + rpc_error) + return + + await self._done_writing() + except: # pylint: disable=bare-except + # Client iterators can raise exceptions, which we should handle by + # cancelling the RPC and logging the client's error. No exceptions + # should escape this function. + _LOGGER.debug('Client request_iterator raised exception:\n%s', + traceback.format_exc()) + self.cancel() + + async def _write(self, request: RequestType) -> None: + if self.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + if self._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + if not self._metadata_sent.is_set(): + await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() + + serialized_request = _common.serialize(request, + self._request_serializer) + try: + await self._cython_call.send_serialized_message(serialized_request) + except cygrpc.InternalError: + await self._raise_for_status() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def _done_writing(self) -> None: + if self.done(): + # If the RPC is finished, do nothing. + return + if not self._done_writing_flag: + # If the done writing is not sent before, try to send it. + self._done_writing_flag = True + try: + await self._cython_call.send_receive_close() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def write(self, request: RequestType) -> None: + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._write(request) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._done_writing() + + async def wait_for_connection(self) -> None: + await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() + + +class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): + """Object for managing unary-unary RPC calls. + + Returned when an instance of `UnaryUnaryMultiCallable` object is called. + """ + _request: RequestType + _invocation_task: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__(self, request: RequestType, deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, request_serializer, response_deserializer, loop) + self._request = request + self._invocation_task = loop.create_task(self._invoke()) + self._init_unary_response_mixin(self._invocation_task) + + async def _invoke(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) + + # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, + # because the asyncio.Task class do not cache the exception object. + # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 + try: + serialized_response = await self._cython_call.unary_unary( + serialized_request, self._metadata) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + + if self._cython_call.is_ok(): + return _common.deserialize(serialized_response, + self._response_deserializer) + else: + return cygrpc.EOF + + async def wait_for_connection(self) -> None: + await self._invocation_task + if self.done(): + await self._raise_for_status() + + +class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): + """Object for managing unary-stream RPC calls. + + Returned when an instance of `UnaryStreamMultiCallable` object is called. + """ + _request: RequestType + _send_unary_request_task: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__(self, request: RequestType, deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, request_serializer, response_deserializer, loop) + self._request = request + self._send_unary_request_task = loop.create_task( + self._send_unary_request()) + self._init_stream_response_mixin(self._send_unary_request_task) + + async def _send_unary_request(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) + try: + await self._cython_call.initiate_unary_stream( + serialized_request, self._metadata) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def wait_for_connection(self) -> None: + await self._send_unary_request_task + if self.done(): + await self._raise_for_status() + + +class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, + _base_call.StreamUnaryCall): + """Object for managing stream-unary RPC calls. + + Returned when an instance of `StreamUnaryMultiCallable` object is called. + """ + + # pylint: disable=too-many-arguments + def __init__(self, request_iterator: Optional[RequestIterableType], + deadline: Optional[float], metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, request_serializer, response_deserializer, loop) + + self._init_stream_request_mixin(request_iterator) + self._init_unary_response_mixin(loop.create_task(self._conduct_rpc())) + + async def _conduct_rpc(self) -> ResponseType: + try: + serialized_response = await self._cython_call.stream_unary( + self._metadata, self._metadata_sent_observer) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + if self._cython_call.is_ok(): + return _common.deserialize(serialized_response, + self._response_deserializer) + else: + return cygrpc.EOF + + +class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, + _base_call.StreamStreamCall): + """Object for managing stream-stream RPC calls. + + Returned when an instance of `StreamStreamMultiCallable` object is called. + """ + _initializer: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__(self, request_iterator: Optional[RequestIterableType], + deadline: Optional[float], metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, request_serializer, response_deserializer, loop) + self._initializer = self._loop.create_task(self._prepare_rpc()) + self._init_stream_request_mixin(request_iterator) + self._init_stream_response_mixin(self._initializer) + + async def _prepare_rpc(self): + """This method prepares the RPC for receiving/sending messages. + + All other operations around the stream should only happen after the + completion of this method. + """ + try: + await self._cython_call.initiate_stream_stream( + self._metadata, self._metadata_sent_observer) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + # No need to raise RpcError here, because no one will `await` this task. diff --git a/contrib/python/grpcio/py3/grpc/aio/_channel.py b/contrib/python/grpcio/py3/grpc/aio/_channel.py new file mode 100644 index 00000000000..a6fb2221250 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_channel.py @@ -0,0 +1,492 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Invocation-side implementation of gRPC Asyncio Python.""" + +import asyncio +import sys +from typing import Any, Iterable, List, Optional, Sequence + +import grpc +from grpc import _common +from grpc import _compression +from grpc import _grpcio_metadata +from grpc._cython import cygrpc + +from . import _base_call +from . import _base_channel +from ._call import StreamStreamCall +from ._call import StreamUnaryCall +from ._call import UnaryStreamCall +from ._call import UnaryUnaryCall +from ._interceptor import ClientInterceptor +from ._interceptor import InterceptedStreamStreamCall +from ._interceptor import InterceptedStreamUnaryCall +from ._interceptor import InterceptedUnaryStreamCall +from ._interceptor import InterceptedUnaryUnaryCall +from ._interceptor import StreamStreamClientInterceptor +from ._interceptor import StreamUnaryClientInterceptor +from ._interceptor import UnaryStreamClientInterceptor +from ._interceptor import UnaryUnaryClientInterceptor +from ._metadata import Metadata +from ._typing import ChannelArgumentType +from ._typing import DeserializingFunction +from ._typing import RequestIterableType +from ._typing import SerializingFunction +from ._utils import _timeout_to_deadline + +_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) + +if sys.version_info[1] < 7: + + def _all_tasks() -> Iterable[asyncio.Task]: + return asyncio.Task.all_tasks() +else: + + def _all_tasks() -> Iterable[asyncio.Task]: + return asyncio.all_tasks() + + +def _augment_channel_arguments(base_options: ChannelArgumentType, + compression: Optional[grpc.Compression]): + compression_channel_argument = _compression.create_channel_option( + compression) + user_agent_channel_argument = (( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ),) + return tuple(base_options + ) + compression_channel_argument + user_agent_channel_argument + + +class _BaseMultiCallable: + """Base class of all multi callable objects. + + Handles the initialization logic and stores common attributes. + """ + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + _interceptors: Optional[Sequence[ClientInterceptor]] + _references: List[Any] + _loop: asyncio.AbstractEventLoop + + # pylint: disable=too-many-arguments + def __init__( + self, + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + interceptors: Optional[Sequence[ClientInterceptor]], + references: List[Any], + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._channel = channel + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + self._interceptors = interceptors + self._references = references + + @staticmethod + def _init_metadata( + metadata: Optional[Metadata] = None, + compression: Optional[grpc.Compression] = None) -> Metadata: + """Based on the provided values for <metadata> or <compression> initialise the final + metadata, as it should be used for the current call. + """ + metadata = metadata or Metadata() + if compression: + metadata = Metadata( + *_compression.augment_metadata(metadata, compression)) + return metadata + + +class UnaryUnaryMultiCallable(_BaseMultiCallable, + _base_channel.UnaryUnaryMultiCallable): + + def __call__( + self, + request: Any, + *, + timeout: Optional[float] = None, + metadata: Optional[Metadata] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.UnaryUnaryCall: + + metadata = self._init_metadata(metadata, compression) + if not self._interceptors: + call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), + metadata, credentials, wait_for_ready, + self._channel, self._method, + self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedUnaryUnaryCall( + self._interceptors, request, timeout, metadata, credentials, + wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) + + return call + + +class UnaryStreamMultiCallable(_BaseMultiCallable, + _base_channel.UnaryStreamMultiCallable): + + def __call__( + self, + request: Any, + *, + timeout: Optional[float] = None, + metadata: Optional[Metadata] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.UnaryStreamCall: + + metadata = self._init_metadata(metadata, compression) + deadline = _timeout_to_deadline(timeout) + + if not self._interceptors: + call = UnaryStreamCall(request, deadline, metadata, credentials, + wait_for_ready, self._channel, self._method, + self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedUnaryStreamCall( + self._interceptors, request, deadline, metadata, credentials, + wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) + + return call + + +class StreamUnaryMultiCallable(_BaseMultiCallable, + _base_channel.StreamUnaryMultiCallable): + + def __call__( + self, + request_iterator: Optional[RequestIterableType] = None, + timeout: Optional[float] = None, + metadata: Optional[Metadata] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamUnaryCall: + + metadata = self._init_metadata(metadata, compression) + deadline = _timeout_to_deadline(timeout) + + if not self._interceptors: + call = StreamUnaryCall(request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, + self._method, self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedStreamUnaryCall( + self._interceptors, request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) + + return call + + +class StreamStreamMultiCallable(_BaseMultiCallable, + _base_channel.StreamStreamMultiCallable): + + def __call__( + self, + request_iterator: Optional[RequestIterableType] = None, + timeout: Optional[float] = None, + metadata: Optional[Metadata] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _base_call.StreamStreamCall: + + metadata = self._init_metadata(metadata, compression) + deadline = _timeout_to_deadline(timeout) + + if not self._interceptors: + call = StreamStreamCall(request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, + self._method, self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedStreamStreamCall( + self._interceptors, request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) + + return call + + +class Channel(_base_channel.Channel): + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _unary_unary_interceptors: List[UnaryUnaryClientInterceptor] + _unary_stream_interceptors: List[UnaryStreamClientInterceptor] + _stream_unary_interceptors: List[StreamUnaryClientInterceptor] + _stream_stream_interceptors: List[StreamStreamClientInterceptor] + + def __init__(self, target: str, options: ChannelArgumentType, + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression], + interceptors: Optional[Sequence[ClientInterceptor]]): + """Constructor. + + Args: + target: The target to which to connect. + options: Configuration options for the channel. + credentials: A cygrpc.ChannelCredentials or None. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. + interceptors: An optional list of interceptors that would be used for + intercepting any RPC executed with that channel. + """ + self._unary_unary_interceptors = [] + self._unary_stream_interceptors = [] + self._stream_unary_interceptors = [] + self._stream_stream_interceptors = [] + + if interceptors is not None: + for interceptor in interceptors: + if isinstance(interceptor, UnaryUnaryClientInterceptor): + self._unary_unary_interceptors.append(interceptor) + elif isinstance(interceptor, UnaryStreamClientInterceptor): + self._unary_stream_interceptors.append(interceptor) + elif isinstance(interceptor, StreamUnaryClientInterceptor): + self._stream_unary_interceptors.append(interceptor) + elif isinstance(interceptor, StreamStreamClientInterceptor): + self._stream_stream_interceptors.append(interceptor) + else: + raise ValueError( + "Interceptor {} must be ".format(interceptor) + + "{} or ".format(UnaryUnaryClientInterceptor.__name__) + + "{} or ".format(UnaryStreamClientInterceptor.__name__) + + "{} or ".format(StreamUnaryClientInterceptor.__name__) + + "{}. ".format(StreamStreamClientInterceptor.__name__)) + + self._loop = cygrpc.get_working_loop() + self._channel = cygrpc.AioChannel( + _common.encode(target), + _augment_channel_arguments(options, compression), credentials, + self._loop) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._close(None) + + async def _close(self, grace): # pylint: disable=too-many-branches + if self._channel.closed(): + return + + # No new calls will be accepted by the Cython channel. + self._channel.closing() + + # Iterate through running tasks + tasks = _all_tasks() + calls = [] + call_tasks = [] + for task in tasks: + try: + stack = task.get_stack(limit=1) + except AttributeError as attribute_error: + # NOTE(lidiz) tl;dr: If the Task is created with a CPython + # object, it will trigger AttributeError. + # + # In the global finalizer, the event loop schedules + # a CPython PyAsyncGenAThrow object. + # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484 + # + # However, the PyAsyncGenAThrow object is written in C and + # failed to include the normal Python frame objects. Hence, + # this exception is a false negative, and it is safe to ignore + # the failure. It is fixed by https://github.com/python/cpython/pull/18669, + # but not available until 3.9 or 3.8.3. So, we have to keep it + # for a while. + # TODO(lidiz) drop this hack after 3.8 deprecation + if 'frame' in str(attribute_error): + continue + else: + raise + + # If the Task is created by a C-extension, the stack will be empty. + if not stack: + continue + + # Locate ones created by `aio.Call`. + frame = stack[0] + candidate = frame.f_locals.get('self') + if candidate: + if isinstance(candidate, _base_call.Call): + if hasattr(candidate, '_channel'): + # For intercepted Call object + if candidate._channel is not self._channel: + continue + elif hasattr(candidate, '_cython_call'): + # For normal Call object + if candidate._cython_call._channel is not self._channel: + continue + else: + # Unidentified Call object + raise cygrpc.InternalError( + f'Unrecognized call object: {candidate}') + + calls.append(candidate) + call_tasks.append(task) + + # If needed, try to wait for them to finish. + # Call objects are not always awaitables. + if grace and call_tasks: + await asyncio.wait(call_tasks, timeout=grace) + + # Time to cancel existing calls. + for call in calls: + call.cancel() + + # Destroy the channel + self._channel.close() + + async def close(self, grace: Optional[float] = None): + await self._close(grace) + + def __del__(self): + if hasattr(self, '_channel'): + if not self._channel.closed(): + self._channel.close() + + def get_state(self, + try_to_connect: bool = False) -> grpc.ChannelConnectivity: + result = self._channel.check_connectivity_state(try_to_connect) + return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result] + + async def wait_for_state_change( + self, + last_observed_state: grpc.ChannelConnectivity, + ) -> None: + assert await self._channel.watch_connectivity_state( + last_observed_state.value[0], None) + + async def channel_ready(self) -> None: + state = self.get_state(try_to_connect=True) + while state != grpc.ChannelConnectivity.READY: + await self.wait_for_state_change(state) + state = self.get_state(try_to_connect=True) + + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> UnaryUnaryMultiCallable: + return UnaryUnaryMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, + self._unary_unary_interceptors, [self], + self._loop) + + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> UnaryStreamMultiCallable: + return UnaryStreamMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, + self._unary_stream_interceptors, [self], + self._loop) + + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamUnaryMultiCallable: + return StreamUnaryMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, + self._stream_unary_interceptors, [self], + self._loop) + + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> StreamStreamMultiCallable: + return StreamStreamMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, + self._stream_stream_interceptors, + [self], self._loop) + + +def insecure_channel( + target: str, + options: Optional[ChannelArgumentType] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[ClientInterceptor]] = None): + """Creates an insecure asynchronous Channel to a server. + + Args: + target: The server address + options: An optional list of key-value pairs (:term:`channel_arguments` + in gRPC Core runtime) to configure the channel. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. + interceptors: An optional sequence of interceptors that will be executed for + any call executed with this channel. + + Returns: + A Channel. + """ + return Channel(target, () if options is None else options, None, + compression, interceptors) + + +def secure_channel(target: str, + credentials: grpc.ChannelCredentials, + options: Optional[ChannelArgumentType] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[ClientInterceptor]] = None): + """Creates a secure asynchronous Channel to a server. + + Args: + target: The server address. + credentials: A ChannelCredentials instance. + options: An optional list of key-value pairs (:term:`channel_arguments` + in gRPC Core runtime) to configure the channel. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. + interceptors: An optional sequence of interceptors that will be executed for + any call executed with this channel. + + Returns: + An aio.Channel. + """ + return Channel(target, () if options is None else options, + credentials._credentials, compression, interceptors) diff --git a/contrib/python/grpcio/py3/grpc/aio/_interceptor.py b/contrib/python/grpcio/py3/grpc/aio/_interceptor.py new file mode 100644 index 00000000000..05f166e3b0b --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_interceptor.py @@ -0,0 +1,1001 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptors implementation of gRPC Asyncio Python.""" +from abc import ABCMeta +from abc import abstractmethod +import asyncio +import collections +import functools +from typing import (AsyncIterable, Awaitable, Callable, Iterator, List, + Optional, Sequence, Union) + +import grpc +from grpc._cython import cygrpc + +from . import _base_call +from ._call import AioRpcError +from ._call import StreamStreamCall +from ._call import StreamUnaryCall +from ._call import UnaryStreamCall +from ._call import UnaryUnaryCall +from ._call import _API_STYLE_ERROR +from ._call import _RPC_ALREADY_FINISHED_DETAILS +from ._call import _RPC_HALF_CLOSED_DETAILS +from ._metadata import Metadata +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseIterableType +from ._typing import ResponseType +from ._typing import SerializingFunction +from ._utils import _timeout_to_deadline + +_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' + + +class ServerInterceptor(metaclass=ABCMeta): + """Affords intercepting incoming RPCs on the service-side. + + This is an EXPERIMENTAL API. + """ + + @abstractmethod + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], + Awaitable[grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + """Intercepts incoming RPCs before handing them over to a handler. + + Args: + continuation: A function that takes a HandlerCallDetails and + proceeds to invoke the next interceptor in the chain, if any, + or the RPC handler lookup logic, with the call details passed + as an argument, and returns an RpcMethodHandler instance if + the RPC is considered serviced, or None otherwise. + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced if the + interceptor chooses to service this RPC, or None otherwise. + """ + + +class ClientCallDetails( + collections.namedtuple( + 'ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), + grpc.ClientCallDetails): + """Describes an RPC to be invoked. + + This is an EXPERIMENTAL API. + + Args: + method: The method name of the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: Optional metadata to be transmitted to the service-side of + the RPC. + credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + """ + + method: str + timeout: Optional[float] + metadata: Optional[Metadata] + credentials: Optional[grpc.CallCredentials] + wait_for_ready: Optional[bool] + + +class ClientInterceptor(metaclass=ABCMeta): + """Base class used for all Aio Client Interceptor classes""" + + +class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting unary-unary invocations.""" + + @abstractmethod + async def intercept_unary_unary( + self, continuation: Callable[[ClientCallDetails, RequestType], + UnaryUnaryCall], + client_call_details: ClientCallDetails, + request: RequestType) -> Union[UnaryUnaryCall, ResponseType]: + """Intercepts a unary-unary invocation asynchronously. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + An object with the RPC response. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting unary-stream invocations.""" + + @abstractmethod + async def intercept_unary_stream( + self, continuation: Callable[[ClientCallDetails, RequestType], + UnaryStreamCall], + client_call_details: ClientCallDetails, request: RequestType + ) -> Union[ResponseIterableType, UnaryStreamCall]: + """Intercepts a unary-stream invocation asynchronously. + + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + The RPC Call or an asynchronous iterator. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-unary invocations.""" + + @abstractmethod + async def intercept_stream_unary( + self, + continuation: Callable[[ClientCallDetails, RequestType], + StreamUnaryCall], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> StreamUnaryCall: + """Intercepts a stream-unary invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-stream invocations.""" + + @abstractmethod + async def intercept_stream_stream( + self, + continuation: Callable[[ClientCallDetails, RequestType], + StreamStreamCall], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> Union[ResponseIterableType, StreamStreamCall]: + """Intercepts a stream-stream invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call or an asynchronous iterator. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class InterceptedCall: + """Base implementation for all intercepted call arities. + + Interceptors might have some work to do before the RPC invocation with + the capacity of changing the invocation parameters, and some work to do + after the RPC invocation with the capacity for accessing to the wrapped + `UnaryUnaryCall`. + + It handles also early and later cancellations, when the RPC has not even + started and the execution is still held by the interceptors or when the + RPC has finished but again the execution is still held by the interceptors. + + Once the RPC is finally executed, all methods are finally done against the + intercepted call, being at the same time the same call returned to the + interceptors. + + As a base class for all of the interceptors implements the logic around + final status, metadata and cancellation. + """ + + _interceptors_task: asyncio.Task + _pending_add_done_callbacks: Sequence[DoneCallbackType] + + def __init__(self, interceptors_task: asyncio.Task) -> None: + self._interceptors_task = interceptors_task + self._pending_add_done_callbacks = [] + self._interceptors_task.add_done_callback( + self._fire_or_add_pending_done_callbacks) + + def __del__(self): + self.cancel() + + def _fire_or_add_pending_done_callbacks( + self, interceptors_task: asyncio.Task) -> None: + + if not self._pending_add_done_callbacks: + return + + call_completed = False + + try: + call = interceptors_task.result() + if call.done(): + call_completed = True + except (AioRpcError, asyncio.CancelledError): + call_completed = True + + if call_completed: + for callback in self._pending_add_done_callbacks: + callback(self) + else: + for callback in self._pending_add_done_callbacks: + callback = functools.partial(self._wrap_add_done_callback, + callback) + call.add_done_callback(callback) + + self._pending_add_done_callbacks = [] + + def _wrap_add_done_callback(self, callback: DoneCallbackType, + unused_call: _base_call.Call) -> None: + callback(self) + + def cancel(self) -> bool: + if not self._interceptors_task.done(): + # There is no yet the intercepted call available, + # Trying to cancel it by using the generic Asyncio + # cancellation method. + return self._interceptors_task.cancel() + + try: + call = self._interceptors_task.result() + except AioRpcError: + return False + except asyncio.CancelledError: + return False + + return call.cancel() + + def cancelled(self) -> bool: + if not self._interceptors_task.done(): + return False + + try: + call = self._interceptors_task.result() + except AioRpcError as err: + return err.code() == grpc.StatusCode.CANCELLED + except asyncio.CancelledError: + return True + + return call.cancelled() + + def done(self) -> bool: + if not self._interceptors_task.done(): + return False + + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + return True + + return call.done() + + def add_done_callback(self, callback: DoneCallbackType) -> None: + if not self._interceptors_task.done(): + self._pending_add_done_callbacks.append(callback) + return + + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + callback(self) + return + + if call.done(): + callback(self) + else: + callback = functools.partial(self._wrap_add_done_callback, callback) + call.add_done_callback(callback) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[Metadata]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.initial_metadata() + except asyncio.CancelledError: + return None + + return await call.initial_metadata() + + async def trailing_metadata(self) -> Optional[Metadata]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.trailing_metadata() + except asyncio.CancelledError: + return None + + return await call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.code() + except asyncio.CancelledError: + return grpc.StatusCode.CANCELLED + + return await call.code() + + async def details(self) -> str: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.details() + except asyncio.CancelledError: + return _LOCAL_CANCELLATION_DETAILS + + return await call.details() + + async def debug_error_string(self) -> Optional[str]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.debug_error_string() + except asyncio.CancelledError: + return '' + + return await call.debug_error_string() + + async def wait_for_connection(self) -> None: + call = await self._interceptors_task + return await call.wait_for_connection() + + +class _InterceptedUnaryResponseMixin: + + def __await__(self): + call = yield from self._interceptors_task.__await__() + response = yield from call.__await__() + return response + + +class _InterceptedStreamResponseMixin: + _response_aiter: Optional[AsyncIterable[ResponseType]] + + def _init_stream_response_mixin(self) -> None: + # Is initalized later, otherwise if the iterator is not finnally + # consumed a logging warning is emmited by Asyncio. + self._response_aiter = None + + async def _wait_for_interceptor_task_response_iterator( + self) -> ResponseType: + call = await self._interceptors_task + async for response in call: + yield response + + def __aiter__(self) -> AsyncIterable[ResponseType]: + if self._response_aiter is None: + self._response_aiter = self._wait_for_interceptor_task_response_iterator( + ) + return self._response_aiter + + async def read(self) -> ResponseType: + if self._response_aiter is None: + self._response_aiter = self._wait_for_interceptor_task_response_iterator( + ) + return await self._response_aiter.asend(None) + + +class _InterceptedStreamRequestMixin: + + _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] + _write_to_iterator_queue: Optional[asyncio.Queue] + _status_code_task: Optional[asyncio.Task] + + _FINISH_ITERATOR_SENTINEL = object() + + def _init_stream_request_mixin( + self, request_iterator: Optional[RequestIterableType] + ) -> RequestIterableType: + + if request_iterator is None: + # We provide our own request iterator which is a proxy + # of the futures writes that will be done by the caller. + self._write_to_iterator_queue = asyncio.Queue(maxsize=1) + self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator( + ) + self._status_code_task = None + request_iterator = self._write_to_iterator_async_gen + else: + self._write_to_iterator_queue = None + + return request_iterator + + async def _proxy_writes_as_request_iterator(self): + await self._interceptors_task + + while True: + value = await self._write_to_iterator_queue.get() + if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL: + break + yield value + + async def _write_to_iterator_queue_interruptible(self, request: RequestType, + call: InterceptedCall): + # Write the specified 'request' to the request iterator queue using the + # specified 'call' to allow for interruption of the write in the case + # of abrupt termination of the call. + if self._status_code_task is None: + self._status_code_task = self._loop.create_task(call.code()) + + await asyncio.wait( + (self._loop.create_task(self._write_to_iterator_queue.put(request)), + self._status_code_task), + return_when=asyncio.FIRST_COMPLETED) + + async def write(self, request: RequestType) -> None: + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except (asyncio.CancelledError, AioRpcError): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + elif call._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + + await self._write_to_iterator_queue_interruptible(request, call) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except asyncio.CancelledError: + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + await self._write_to_iterator_queue_interruptible( + _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call) + + +class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, + _base_call.UnaryUnaryCall): + """Used for running a `UnaryUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> UnaryUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType) -> _base_call.UnaryUnaryCall: + + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) + call_or_response = await interceptors[0].intercept_unary_unary( + continuation, client_call_details, request) + + if isinstance(call_or_response, _base_call.UnaryUnaryCall): + return call_or_response + else: + return UnaryUnaryCallResponse(call_or_response) + + else: + return UnaryUnaryCall( + request, _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(list(interceptors), client_call_details, + request) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, + InterceptedCall, _base_call.UnaryStreamCall): + """Used for running a `UnaryStreamCall` wrapped by interceptors.""" + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], + request: RequestType, timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + self._init_stream_response_mixin() + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> UnaryStreamCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[UnaryStreamClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> _base_call.UnaryUnaryCall: + + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) + + call_or_response_iterator = await interceptors[ + 0].intercept_unary_stream(continuation, client_call_details, + request) + + if isinstance(call_or_response_iterator, + _base_call.UnaryStreamCall): + self._last_returned_call_from_interceptors = call_or_response_iterator + else: + self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = UnaryStreamCall( + request, _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + return self._last_returned_call_from_interceptors + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(list(interceptors), client_call_details, + request) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, _base_call.StreamUnaryCall): + """Used for running a `StreamUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + request_iterator = self._init_stream_request_mixin(request_iterator) + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request_iterator, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[StreamUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> StreamUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType + ) -> _base_call.StreamUnaryCall: + + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) + + return await interceptors[0].intercept_stream_unary( + continuation, client_call_details, request_iterator) + else: + return StreamUnaryCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(list(interceptors), client_call_details, + request_iterator) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, _base_call.StreamStreamCall): + """Used for running a `StreamStreamCall` wrapped by interceptors.""" + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + self._init_stream_response_mixin() + request_iterator = self._init_stream_request_mixin(request_iterator) + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request_iterator, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[StreamStreamClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> StreamStreamCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[StreamStreamClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType + ) -> _base_call.StreamStreamCall: + + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) + + call_or_response_iterator = await interceptors[ + 0].intercept_stream_stream(continuation, + client_call_details, + request_iterator) + + if isinstance(call_or_response_iterator, + _base_call.StreamStreamCall): + self._last_returned_call_from_interceptors = call_or_response_iterator + else: + self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = StreamStreamCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + return self._last_returned_call_from_interceptors + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(list(interceptors), client_call_details, + request_iterator) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with a response.""" + _response: ResponseType + + def __init__(self, response: ResponseType) -> None: + self._response = response + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[Metadata]: + return None + + async def trailing_metadata(self) -> Optional[Metadata]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.OK + + async def details(self) -> str: + return '' + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + if False: # pylint: disable=using-constant-test + # This code path is never used, but a yield statement is needed + # for telling the interpreter that __await__ is a generator. + yield None + return self._response + + async def wait_for_connection(self) -> None: + pass + + +class _StreamCallResponseIterator: + + _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall] + _response_iterator: AsyncIterable[ResponseType] + + def __init__(self, call: Union[_base_call.UnaryStreamCall, + _base_call.StreamStreamCall], + response_iterator: AsyncIterable[ResponseType]) -> None: + self._response_iterator = response_iterator + self._call = call + + def cancel(self) -> bool: + return self._call.cancel() + + def cancelled(self) -> bool: + return self._call.cancelled() + + def done(self) -> bool: + return self._call.done() + + def add_done_callback(self, callback) -> None: + self._call.add_done_callback(callback) + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() + + async def initial_metadata(self) -> Optional[Metadata]: + return await self._call.initial_metadata() + + async def trailing_metadata(self) -> Optional[Metadata]: + return await self._call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await self._call.code() + + async def details(self) -> str: + return await self._call.details() + + async def debug_error_string(self) -> Optional[str]: + return await self._call.debug_error_string() + + def __aiter__(self): + return self._response_iterator.__aiter__() + + async def wait_for_connection(self) -> None: + return await self._call.wait_for_connection() + + +class UnaryStreamCallResponseIterator(_StreamCallResponseIterator, + _base_call.UnaryStreamCall): + """UnaryStreamCall class wich uses an alternative response iterator.""" + + async def read(self) -> ResponseType: + # Behind the scenes everyting goes through the + # async iterator. So this path should not be reached. + raise NotImplementedError() + + +class StreamStreamCallResponseIterator(_StreamCallResponseIterator, + _base_call.StreamStreamCall): + """StreamStreamCall class wich uses an alternative response iterator.""" + + async def read(self) -> ResponseType: + # Behind the scenes everyting goes through the + # async iterator. So this path should not be reached. + raise NotImplementedError() + + async def write(self, request: RequestType) -> None: + # Behind the scenes everyting goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise NotImplementedError() + + async def done_writing(self) -> None: + # Behind the scenes everyting goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise NotImplementedError() + + @property + def _done_writing_flag(self) -> bool: + return self._call._done_writing_flag diff --git a/contrib/python/grpcio/py3/grpc/aio/_metadata.py b/contrib/python/grpcio/py3/grpc/aio/_metadata.py new file mode 100644 index 00000000000..970f62c0590 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_metadata.py @@ -0,0 +1,120 @@ +# Copyright 2020 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the metadata abstraction for gRPC Asyncio Python.""" +from collections import OrderedDict +from collections import abc +from typing import Any, Iterator, List, Tuple, Union + +MetadataKey = str +MetadataValue = Union[str, bytes] + + +class Metadata(abc.Mapping): + """Metadata abstraction for the asynchronous calls and interceptors. + + The metadata is a mapping from str -> List[str] + + Traits + * Multiple entries are allowed for the same key + * The order of the values by key is preserved + * Getting by an element by key, retrieves the first mapped value + * Supports an immutable view of the data + * Allows partial mutation on the data without recreating the new object from scratch. + """ + + def __init__(self, *args: Tuple[MetadataKey, MetadataValue]) -> None: + self._metadata = OrderedDict() + for md_key, md_value in args: + self.add(md_key, md_value) + + @classmethod + def from_tuple(cls, raw_metadata: tuple): + if raw_metadata: + return cls(*raw_metadata) + return cls() + + def add(self, key: MetadataKey, value: MetadataValue) -> None: + self._metadata.setdefault(key, []) + self._metadata[key].append(value) + + def __len__(self) -> int: + """Return the total number of elements that there are in the metadata, + including multiple values for the same key. + """ + return sum(map(len, self._metadata.values())) + + def __getitem__(self, key: MetadataKey) -> MetadataValue: + """When calling <metadata>[<key>], the first element of all those + mapped for <key> is returned. + """ + try: + return self._metadata[key][0] + except (ValueError, IndexError) as e: + raise KeyError("{0!r}".format(key)) from e + + def __setitem__(self, key: MetadataKey, value: MetadataValue) -> None: + """Calling metadata[<key>] = <value> + Maps <value> to the first instance of <key>. + """ + if key not in self: + self._metadata[key] = [value] + else: + current_values = self.get_all(key) + self._metadata[key] = [value, *current_values[1:]] + + def __delitem__(self, key: MetadataKey) -> None: + """``del metadata[<key>]`` deletes the first mapping for <key>.""" + current_values = self.get_all(key) + if not current_values: + raise KeyError(repr(key)) + self._metadata[key] = current_values[1:] + + def delete_all(self, key: MetadataKey) -> None: + """Delete all mappings for <key>.""" + del self._metadata[key] + + def __iter__(self) -> Iterator[Tuple[MetadataKey, MetadataValue]]: + for key, values in self._metadata.items(): + for value in values: + yield (key, value) + + def get_all(self, key: MetadataKey) -> List[MetadataValue]: + """For compatibility with other Metadata abstraction objects (like in Java), + this would return all items under the desired <key>. + """ + return self._metadata.get(key, []) + + def set_all(self, key: MetadataKey, values: List[MetadataValue]) -> None: + self._metadata[key] = values + + def __contains__(self, key: MetadataKey) -> bool: + return key in self._metadata + + def __eq__(self, other: Any) -> bool: + if isinstance(other, self.__class__): + return self._metadata == other._metadata + if isinstance(other, tuple): + return tuple(self) == other + return NotImplemented # pytype: disable=bad-return-type + + def __add__(self, other: Any) -> 'Metadata': + if isinstance(other, self.__class__): + return Metadata(*(tuple(self) + tuple(other))) + if isinstance(other, tuple): + return Metadata(*(tuple(self) + other)) + return NotImplemented # pytype: disable=bad-return-type + + def __repr__(self) -> str: + view = tuple(self) + return "{0}({1!r})".format(self.__class__.__name__, view) diff --git a/contrib/python/grpcio/py3/grpc/aio/_server.py b/contrib/python/grpcio/py3/grpc/aio/_server.py new file mode 100644 index 00000000000..1465ab6bbb0 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_server.py @@ -0,0 +1,209 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Server-side implementation of gRPC Asyncio Python.""" + +from concurrent.futures import Executor +from typing import Any, Optional, Sequence + +import grpc +from grpc import _common +from grpc import _compression +from grpc._cython import cygrpc + +from . import _base_server +from ._interceptor import ServerInterceptor +from ._typing import ChannelArgumentType + + +def _augment_channel_arguments(base_options: ChannelArgumentType, + compression: Optional[grpc.Compression]): + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + + +class Server(_base_server.Server): + """Serves RPCs.""" + + def __init__(self, thread_pool: Optional[Executor], + generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]], + interceptors: Optional[Sequence[Any]], + options: ChannelArgumentType, + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression]): + self._loop = cygrpc.get_working_loop() + if interceptors: + invalid_interceptors = [ + interceptor for interceptor in interceptors + if not isinstance(interceptor, ServerInterceptor) + ] + if invalid_interceptors: + raise ValueError( + 'Interceptor must be ServerInterceptor, the ' + f'following are invalid: {invalid_interceptors}') + self._server = cygrpc.AioServer( + self._loop, thread_pool, generic_handlers, interceptors, + _augment_channel_arguments(options, compression), + maximum_concurrent_rpcs) + + def add_generic_rpc_handlers( + self, + generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: + """Registers GenericRpcHandlers with this Server. + + This method is only safe to call before the server is started. + + Args: + generic_rpc_handlers: A sequence of GenericRpcHandlers that will be + used to service RPCs. + """ + self._server.add_generic_rpc_handlers(generic_rpc_handlers) + + def add_insecure_port(self, address: str) -> int: + """Opens an insecure port for accepting RPCs. + + This method may only be called before starting the server. + + Args: + address: The address for which to open a port. If the port is 0, + or not specified in the address, then the gRPC runtime will choose a port. + + Returns: + An integer port on which the server will accept RPC requests. + """ + return _common.validate_port_binding_result( + address, self._server.add_insecure_port(_common.encode(address))) + + def add_secure_port(self, address: str, + server_credentials: grpc.ServerCredentials) -> int: + """Opens a secure port for accepting RPCs. + + This method may only be called before starting the server. + + Args: + address: The address for which to open a port. + if the port is 0, or not specified in the address, then the gRPC + runtime will choose a port. + server_credentials: A ServerCredentials object. + + Returns: + An integer port on which the server will accept RPC requests. + """ + return _common.validate_port_binding_result( + address, + self._server.add_secure_port(_common.encode(address), + server_credentials)) + + async def start(self) -> None: + """Starts this Server. + + This method may only be called once. (i.e. it is not idempotent). + """ + await self._server.start() + + async def stop(self, grace: Optional[float]) -> None: + """Stops this Server. + + This method immediately stops the server from servicing new RPCs in + all cases. + + If a grace period is specified, this method returns immediately and all + RPCs active at the end of the grace period are aborted. If a grace + period is not specified (by passing None for grace), all existing RPCs + are aborted immediately and this method blocks until the last RPC + handler terminates. + + This method is idempotent and may be called at any time. Passing a + smaller grace value in a subsequent call will have the effect of + stopping the Server sooner (passing None will have the effect of + stopping the server immediately). Passing a larger grace value in a + subsequent call will not have the effect of stopping the server later + (i.e. the most restrictive grace value is used). + + Args: + grace: A duration of time in seconds or None. + """ + await self._server.shutdown(grace) + + async def wait_for_termination(self, + timeout: Optional[float] = None) -> bool: + """Block current coroutine until the server stops. + + This is an EXPERIMENTAL API. + + The wait will not consume computational resources during blocking, and + it will block until one of the two following conditions are met: + + 1) The server is stopped or terminated; + 2) A timeout occurs if timeout is not `None`. + + The timeout argument works in the same way as `threading.Event.wait()`. + https://docs.python.org/3/library/threading.html#threading.Event.wait + + Args: + timeout: A floating point number specifying a timeout for the + operation in seconds. + + Returns: + A bool indicates if the operation times out. + """ + return await self._server.wait_for_termination(timeout) + + def __del__(self): + """Schedules a graceful shutdown in current event loop. + + The Cython AioServer doesn't hold a ref-count to this class. It should + be safe to slightly extend the underlying Cython object's life span. + """ + if hasattr(self, '_server'): + if self._server.is_running(): + cygrpc.schedule_coro_threadsafe( + self._server.shutdown(None), + self._loop, + ) + + +def server(migration_thread_pool: Optional[Executor] = None, + handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None, + interceptors: Optional[Sequence[Any]] = None, + options: Optional[ChannelArgumentType] = None, + maximum_concurrent_rpcs: Optional[int] = None, + compression: Optional[grpc.Compression] = None): + """Creates a Server with which RPCs can be serviced. + + Args: + migration_thread_pool: A futures.ThreadPoolExecutor to be used by the + Server to execute non-AsyncIO RPC handlers for migration purpose. + handlers: An optional list of GenericRpcHandlers used for executing RPCs. + More handlers may be added by calling add_generic_rpc_handlers any time + before the server is started. + interceptors: An optional list of ServerInterceptor objects that observe + and optionally manipulate the incoming RPCs before handing them over to + handlers. The interceptors are given control in the order they are + specified. This is an EXPERIMENTAL API. + options: An optional list of key-value pairs (:term:`channel_arguments` in gRPC runtime) + to configure the channel. + maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server + will service before returning RESOURCE_EXHAUSTED status, or None to + indicate no limit. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This compression algorithm will be used for the + lifetime of the server unless overridden by set_compression. + + Returns: + A Server object. + """ + return Server(migration_thread_pool, () if handlers is None else handlers, + () if interceptors is None else interceptors, + () if options is None else options, maximum_concurrent_rpcs, + compression) diff --git a/contrib/python/grpcio/py3/grpc/aio/_typing.py b/contrib/python/grpcio/py3/grpc/aio/_typing.py new file mode 100644 index 00000000000..f9c0eb10fc7 --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_typing.py @@ -0,0 +1,35 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common types for gRPC Async API""" + +from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple, + TypeVar, Union) + +from grpc._cython.cygrpc import EOF + +from ._metadata import Metadata +from ._metadata import MetadataKey +from ._metadata import MetadataValue + +RequestType = TypeVar('RequestType') +ResponseType = TypeVar('ResponseType') +SerializingFunction = Callable[[Any], bytes] +DeserializingFunction = Callable[[bytes], Any] +MetadatumType = Tuple[MetadataKey, MetadataValue] +MetadataType = Union[Metadata, Sequence[MetadatumType]] +ChannelArgumentType = Sequence[Tuple[str, Any]] +EOFType = type(EOF) +DoneCallbackType = Callable[[Any], None] +RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]] +ResponseIterableType = AsyncIterable[Any] diff --git a/contrib/python/grpcio/py3/grpc/aio/_utils.py b/contrib/python/grpcio/py3/grpc/aio/_utils.py new file mode 100644 index 00000000000..e5772dce2da --- /dev/null +++ b/contrib/python/grpcio/py3/grpc/aio/_utils.py @@ -0,0 +1,22 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal utilities used by the gRPC Aio module.""" +import time +from typing import Optional + + +def _timeout_to_deadline(timeout: Optional[float]) -> Optional[float]: + if timeout is None: + return None + return time.time() + timeout |
