aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/grpcio/py3/grpc/aio/_call.py
diff options
context:
space:
mode:
authornkozlovskiy <nmk@ydb.tech>2023-09-29 12:24:06 +0300
committernkozlovskiy <nmk@ydb.tech>2023-09-29 12:41:34 +0300
commite0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
treebca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/grpcio/py3/grpc/aio/_call.py
parent38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
downloadydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz
add ydb deps
Diffstat (limited to 'contrib/python/grpcio/py3/grpc/aio/_call.py')
-rw-r--r--contrib/python/grpcio/py3/grpc/aio/_call.py649
1 files changed, 649 insertions, 0 deletions
diff --git a/contrib/python/grpcio/py3/grpc/aio/_call.py b/contrib/python/grpcio/py3/grpc/aio/_call.py
new file mode 100644
index 00000000000..37ba945da73
--- /dev/null
+++ b/contrib/python/grpcio/py3/grpc/aio/_call.py
@@ -0,0 +1,649 @@
+# Copyright 2019 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Invocation-side implementation of gRPC Asyncio Python."""
+
+import asyncio
+import enum
+from functools import partial
+import inspect
+import logging
+import traceback
+from typing import AsyncIterator, Optional, Tuple
+
+import grpc
+from grpc import _common
+from grpc._cython import cygrpc
+
+from . import _base_call
+from ._metadata import Metadata
+from ._typing import DeserializingFunction
+from ._typing import DoneCallbackType
+from ._typing import MetadatumType
+from ._typing import RequestIterableType
+from ._typing import RequestType
+from ._typing import ResponseType
+from ._typing import SerializingFunction
+
+__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
+
+_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
+_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
+_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
+_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
+_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.'
+
+_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
+ '\tstatus = {}\n'
+ '\tdetails = "{}"\n'
+ '>')
+
+_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
+ '\tstatus = {}\n'
+ '\tdetails = "{}"\n'
+ '\tdebug_error_string = "{}"\n'
+ '>')
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class AioRpcError(grpc.RpcError):
+ """An implementation of RpcError to be used by the asynchronous API.
+
+ Raised RpcError is a snapshot of the final status of the RPC, values are
+ determined. Hence, its methods no longer needs to be coroutines.
+ """
+
+ _code: grpc.StatusCode
+ _details: Optional[str]
+ _initial_metadata: Optional[Metadata]
+ _trailing_metadata: Optional[Metadata]
+ _debug_error_string: Optional[str]
+
+ def __init__(self,
+ code: grpc.StatusCode,
+ initial_metadata: Metadata,
+ trailing_metadata: Metadata,
+ details: Optional[str] = None,
+ debug_error_string: Optional[str] = None) -> None:
+ """Constructor.
+
+ Args:
+ code: The status code with which the RPC has been finalized.
+ details: Optional details explaining the reason of the error.
+ initial_metadata: Optional initial metadata that could be sent by the
+ Server.
+ trailing_metadata: Optional metadata that could be sent by the Server.
+ """
+
+ super().__init__()
+ self._code = code
+ self._details = details
+ self._initial_metadata = initial_metadata
+ self._trailing_metadata = trailing_metadata
+ self._debug_error_string = debug_error_string
+
+ def code(self) -> grpc.StatusCode:
+ """Accesses the status code sent by the server.
+
+ Returns:
+ The `grpc.StatusCode` status code.
+ """
+ return self._code
+
+ def details(self) -> Optional[str]:
+ """Accesses the details sent by the server.
+
+ Returns:
+ The description of the error.
+ """
+ return self._details
+
+ def initial_metadata(self) -> Metadata:
+ """Accesses the initial metadata sent by the server.
+
+ Returns:
+ The initial metadata received.
+ """
+ return self._initial_metadata
+
+ def trailing_metadata(self) -> Metadata:
+ """Accesses the trailing metadata sent by the server.
+
+ Returns:
+ The trailing metadata received.
+ """
+ return self._trailing_metadata
+
+ def debug_error_string(self) -> str:
+ """Accesses the debug error string sent by the server.
+
+ Returns:
+ The debug error string received.
+ """
+ return self._debug_error_string
+
+ def _repr(self) -> str:
+ """Assembles the error string for the RPC error."""
+ return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
+ self._code, self._details,
+ self._debug_error_string)
+
+ def __repr__(self) -> str:
+ return self._repr()
+
+ def __str__(self) -> str:
+ return self._repr()
+
+
+def _create_rpc_error(initial_metadata: Metadata,
+ status: cygrpc.AioRpcStatus) -> AioRpcError:
+ return AioRpcError(
+ _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
+ Metadata.from_tuple(initial_metadata),
+ Metadata.from_tuple(status.trailing_metadata()),
+ details=status.details(),
+ debug_error_string=status.debug_error_string(),
+ )
+
+
+class Call:
+ """Base implementation of client RPC Call object.
+
+ Implements logic around final status, metadata and cancellation.
+ """
+ _loop: asyncio.AbstractEventLoop
+ _code: grpc.StatusCode
+ _cython_call: cygrpc._AioCall
+ _metadata: Tuple[MetadatumType, ...]
+ _request_serializer: SerializingFunction
+ _response_deserializer: DeserializingFunction
+
+ def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
+ request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction,
+ loop: asyncio.AbstractEventLoop) -> None:
+ self._loop = loop
+ self._cython_call = cython_call
+ self._metadata = tuple(metadata)
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __del__(self) -> None:
+ # The '_cython_call' object might be destructed before Call object
+ if hasattr(self, '_cython_call'):
+ if not self._cython_call.done():
+ self._cancel(_GC_CANCELLATION_DETAILS)
+
+ def cancelled(self) -> bool:
+ return self._cython_call.cancelled()
+
+ def _cancel(self, details: str) -> bool:
+ """Forwards the application cancellation reasoning."""
+ if not self._cython_call.done():
+ self._cython_call.cancel(details)
+ return True
+ else:
+ return False
+
+ def cancel(self) -> bool:
+ return self._cancel(_LOCAL_CANCELLATION_DETAILS)
+
+ def done(self) -> bool:
+ return self._cython_call.done()
+
+ def add_done_callback(self, callback: DoneCallbackType) -> None:
+ cb = partial(callback, self)
+ self._cython_call.add_done_callback(cb)
+
+ def time_remaining(self) -> Optional[float]:
+ return self._cython_call.time_remaining()
+
+ async def initial_metadata(self) -> Metadata:
+ raw_metadata_tuple = await self._cython_call.initial_metadata()
+ return Metadata.from_tuple(raw_metadata_tuple)
+
+ async def trailing_metadata(self) -> Metadata:
+ raw_metadata_tuple = (await
+ self._cython_call.status()).trailing_metadata()
+ return Metadata.from_tuple(raw_metadata_tuple)
+
+ async def code(self) -> grpc.StatusCode:
+ cygrpc_code = (await self._cython_call.status()).code()
+ return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]
+
+ async def details(self) -> str:
+ return (await self._cython_call.status()).details()
+
+ async def debug_error_string(self) -> str:
+ return (await self._cython_call.status()).debug_error_string()
+
+ async def _raise_for_status(self) -> None:
+ if self._cython_call.is_locally_cancelled():
+ raise asyncio.CancelledError()
+ code = await self.code()
+ if code != grpc.StatusCode.OK:
+ raise _create_rpc_error(await self.initial_metadata(), await
+ self._cython_call.status())
+
+ def _repr(self) -> str:
+ return repr(self._cython_call)
+
+ def __repr__(self) -> str:
+ return self._repr()
+
+ def __str__(self) -> str:
+ return self._repr()
+
+
+class _APIStyle(enum.IntEnum):
+ UNKNOWN = 0
+ ASYNC_GENERATOR = 1
+ READER_WRITER = 2
+
+
+class _UnaryResponseMixin(Call):
+ _call_response: asyncio.Task
+
+ def _init_unary_response_mixin(self, response_task: asyncio.Task):
+ self._call_response = response_task
+
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._call_response.cancel()
+ return True
+ else:
+ return False
+
+ def __await__(self) -> ResponseType:
+ """Wait till the ongoing RPC request finishes."""
+ try:
+ response = yield from self._call_response
+ except asyncio.CancelledError:
+ # Even if we caught all other CancelledError, there is still
+ # this corner case. If the application cancels immediately after
+ # the Call object is created, we will observe this
+ # `CancelledError`.
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ # NOTE(lidiz) If we raise RpcError in the task, and users doesn't
+ # 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
+ # Instead, if we move the exception raising here, the spam stops.
+ # Unfortunately, there can only be one 'yield from' in '__await__'. So,
+ # we need to access the private instance variable.
+ if response is cygrpc.EOF:
+ if self._cython_call.is_locally_cancelled():
+ raise asyncio.CancelledError()
+ else:
+ raise _create_rpc_error(self._cython_call._initial_metadata,
+ self._cython_call._status)
+ else:
+ return response
+
+
+class _StreamResponseMixin(Call):
+ _message_aiter: AsyncIterator[ResponseType]
+ _preparation: asyncio.Task
+ _response_style: _APIStyle
+
+ def _init_stream_response_mixin(self, preparation: asyncio.Task):
+ self._message_aiter = None
+ self._preparation = preparation
+ self._response_style = _APIStyle.UNKNOWN
+
+ def _update_response_style(self, style: _APIStyle):
+ if self._response_style is _APIStyle.UNKNOWN:
+ self._response_style = style
+ elif self._response_style is not style:
+ raise cygrpc.UsageError(_API_STYLE_ERROR)
+
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._preparation.cancel()
+ return True
+ else:
+ return False
+
+ async def _fetch_stream_responses(self) -> ResponseType:
+ message = await self._read()
+ while message is not cygrpc.EOF:
+ yield message
+ message = await self._read()
+
+ # If the read operation failed, Core should explain why.
+ await self._raise_for_status()
+
+ def __aiter__(self) -> AsyncIterator[ResponseType]:
+ self._update_response_style(_APIStyle.ASYNC_GENERATOR)
+ if self._message_aiter is None:
+ self._message_aiter = self._fetch_stream_responses()
+ return self._message_aiter
+
+ async def _read(self) -> ResponseType:
+ # Wait for the request being sent
+ await self._preparation
+
+ # Reads response message from Core
+ try:
+ raw_response = await self._cython_call.receive_serialized_message()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ if raw_response is cygrpc.EOF:
+ return cygrpc.EOF
+ else:
+ return _common.deserialize(raw_response,
+ self._response_deserializer)
+
+ async def read(self) -> ResponseType:
+ if self.done():
+ await self._raise_for_status()
+ return cygrpc.EOF
+ self._update_response_style(_APIStyle.READER_WRITER)
+
+ response_message = await self._read()
+
+ if response_message is cygrpc.EOF:
+ # If the read operation failed, Core should explain why.
+ await self._raise_for_status()
+ return response_message
+
+
+class _StreamRequestMixin(Call):
+ _metadata_sent: asyncio.Event
+ _done_writing_flag: bool
+ _async_request_poller: Optional[asyncio.Task]
+ _request_style: _APIStyle
+
+ def _init_stream_request_mixin(
+ self, request_iterator: Optional[RequestIterableType]):
+ self._metadata_sent = asyncio.Event()
+ self._done_writing_flag = False
+
+ # If user passes in an async iterator, create a consumer Task.
+ if request_iterator is not None:
+ self._async_request_poller = self._loop.create_task(
+ self._consume_request_iterator(request_iterator))
+ self._request_style = _APIStyle.ASYNC_GENERATOR
+ else:
+ self._async_request_poller = None
+ self._request_style = _APIStyle.READER_WRITER
+
+ def _raise_for_different_style(self, style: _APIStyle):
+ if self._request_style is not style:
+ raise cygrpc.UsageError(_API_STYLE_ERROR)
+
+ def cancel(self) -> bool:
+ if super().cancel():
+ if self._async_request_poller is not None:
+ self._async_request_poller.cancel()
+ return True
+ else:
+ return False
+
+ def _metadata_sent_observer(self):
+ self._metadata_sent.set()
+
+ async def _consume_request_iterator(
+ self, request_iterator: RequestIterableType) -> None:
+ try:
+ if inspect.isasyncgen(request_iterator) or hasattr(
+ request_iterator, '__aiter__'):
+ async for request in request_iterator:
+ try:
+ await self._write(request)
+ except AioRpcError as rpc_error:
+ _LOGGER.debug(
+ 'Exception while consuming the request_iterator: %s',
+ rpc_error)
+ return
+ else:
+ for request in request_iterator:
+ try:
+ await self._write(request)
+ except AioRpcError as rpc_error:
+ _LOGGER.debug(
+ 'Exception while consuming the request_iterator: %s',
+ rpc_error)
+ return
+
+ await self._done_writing()
+ except: # pylint: disable=bare-except
+ # Client iterators can raise exceptions, which we should handle by
+ # cancelling the RPC and logging the client's error. No exceptions
+ # should escape this function.
+ _LOGGER.debug('Client request_iterator raised exception:\n%s',
+ traceback.format_exc())
+ self.cancel()
+
+ async def _write(self, request: RequestType) -> None:
+ if self.done():
+ raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
+ if self._done_writing_flag:
+ raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
+ if not self._metadata_sent.is_set():
+ await self._metadata_sent.wait()
+ if self.done():
+ await self._raise_for_status()
+
+ serialized_request = _common.serialize(request,
+ self._request_serializer)
+ try:
+ await self._cython_call.send_serialized_message(serialized_request)
+ except cygrpc.InternalError:
+ await self._raise_for_status()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ async def _done_writing(self) -> None:
+ if self.done():
+ # If the RPC is finished, do nothing.
+ return
+ if not self._done_writing_flag:
+ # If the done writing is not sent before, try to send it.
+ self._done_writing_flag = True
+ try:
+ await self._cython_call.send_receive_close()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ async def write(self, request: RequestType) -> None:
+ self._raise_for_different_style(_APIStyle.READER_WRITER)
+ await self._write(request)
+
+ async def done_writing(self) -> None:
+ """Signal peer that client is done writing.
+
+ This method is idempotent.
+ """
+ self._raise_for_different_style(_APIStyle.READER_WRITER)
+ await self._done_writing()
+
+ async def wait_for_connection(self) -> None:
+ await self._metadata_sent.wait()
+ if self.done():
+ await self._raise_for_status()
+
+
+class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
+ """Object for managing unary-unary RPC calls.
+
+ Returned when an instance of `UnaryUnaryMultiCallable` object is called.
+ """
+ _request: RequestType
+ _invocation_task: asyncio.Task
+
+ # pylint: disable=too-many-arguments
+ def __init__(self, request: RequestType, deadline: Optional[float],
+ metadata: Metadata,
+ credentials: Optional[grpc.CallCredentials],
+ wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
+ method: bytes, request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction,
+ loop: asyncio.AbstractEventLoop) -> None:
+ super().__init__(
+ channel.call(method, deadline, credentials, wait_for_ready),
+ metadata, request_serializer, response_deserializer, loop)
+ self._request = request
+ self._invocation_task = loop.create_task(self._invoke())
+ self._init_unary_response_mixin(self._invocation_task)
+
+ async def _invoke(self) -> ResponseType:
+ serialized_request = _common.serialize(self._request,
+ self._request_serializer)
+
+ # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
+ # because the asyncio.Task class do not cache the exception object.
+ # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
+ try:
+ serialized_response = await self._cython_call.unary_unary(
+ serialized_request, self._metadata)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+
+ if self._cython_call.is_ok():
+ return _common.deserialize(serialized_response,
+ self._response_deserializer)
+ else:
+ return cygrpc.EOF
+
+ async def wait_for_connection(self) -> None:
+ await self._invocation_task
+ if self.done():
+ await self._raise_for_status()
+
+
+class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
+ """Object for managing unary-stream RPC calls.
+
+ Returned when an instance of `UnaryStreamMultiCallable` object is called.
+ """
+ _request: RequestType
+ _send_unary_request_task: asyncio.Task
+
+ # pylint: disable=too-many-arguments
+ def __init__(self, request: RequestType, deadline: Optional[float],
+ metadata: Metadata,
+ credentials: Optional[grpc.CallCredentials],
+ wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
+ method: bytes, request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction,
+ loop: asyncio.AbstractEventLoop) -> None:
+ super().__init__(
+ channel.call(method, deadline, credentials, wait_for_ready),
+ metadata, request_serializer, response_deserializer, loop)
+ self._request = request
+ self._send_unary_request_task = loop.create_task(
+ self._send_unary_request())
+ self._init_stream_response_mixin(self._send_unary_request_task)
+
+ async def _send_unary_request(self) -> ResponseType:
+ serialized_request = _common.serialize(self._request,
+ self._request_serializer)
+ try:
+ await self._cython_call.initiate_unary_stream(
+ serialized_request, self._metadata)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ async def wait_for_connection(self) -> None:
+ await self._send_unary_request_task
+ if self.done():
+ await self._raise_for_status()
+
+
+class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
+ _base_call.StreamUnaryCall):
+ """Object for managing stream-unary RPC calls.
+
+ Returned when an instance of `StreamUnaryMultiCallable` object is called.
+ """
+
+ # pylint: disable=too-many-arguments
+ def __init__(self, request_iterator: Optional[RequestIterableType],
+ deadline: Optional[float], metadata: Metadata,
+ credentials: Optional[grpc.CallCredentials],
+ wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
+ method: bytes, request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction,
+ loop: asyncio.AbstractEventLoop) -> None:
+ super().__init__(
+ channel.call(method, deadline, credentials, wait_for_ready),
+ metadata, request_serializer, response_deserializer, loop)
+
+ self._init_stream_request_mixin(request_iterator)
+ self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))
+
+ async def _conduct_rpc(self) -> ResponseType:
+ try:
+ serialized_response = await self._cython_call.stream_unary(
+ self._metadata, self._metadata_sent_observer)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+
+ if self._cython_call.is_ok():
+ return _common.deserialize(serialized_response,
+ self._response_deserializer)
+ else:
+ return cygrpc.EOF
+
+
+class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
+ _base_call.StreamStreamCall):
+ """Object for managing stream-stream RPC calls.
+
+ Returned when an instance of `StreamStreamMultiCallable` object is called.
+ """
+ _initializer: asyncio.Task
+
+ # pylint: disable=too-many-arguments
+ def __init__(self, request_iterator: Optional[RequestIterableType],
+ deadline: Optional[float], metadata: Metadata,
+ credentials: Optional[grpc.CallCredentials],
+ wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
+ method: bytes, request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction,
+ loop: asyncio.AbstractEventLoop) -> None:
+ super().__init__(
+ channel.call(method, deadline, credentials, wait_for_ready),
+ metadata, request_serializer, response_deserializer, loop)
+ self._initializer = self._loop.create_task(self._prepare_rpc())
+ self._init_stream_request_mixin(request_iterator)
+ self._init_stream_response_mixin(self._initializer)
+
+ async def _prepare_rpc(self):
+ """This method prepares the RPC for receiving/sending messages.
+
+ All other operations around the stream should only happen after the
+ completion of this method.
+ """
+ try:
+ await self._cython_call.initiate_stream_stream(
+ self._metadata, self._metadata_sent_observer)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ # No need to raise RpcError here, because no one will `await` this task.