# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""

import collections
import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import grpc

from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction


class _ServicePipeline(object):
    """Chain of server interceptors run in order ahead of a final thunk.

    Each interceptor is handed a continuation; invoking that continuation
    runs the next interceptor in the chain, and once every interceptor has
    been visited the original thunk is called with the context.
    """
    # Variable-length tuple: one entry per interceptor passed to __init__.
    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        """Store ``interceptors`` as a tuple, preserving their order."""
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a one-argument callable that resumes the chain at ``index``."""

        def resume(context):
            return self._intercept_at(thunk, index, context)

        return resume

    def _intercept_at(
            self, thunk: Callable, index: int,
            context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
        """Run the interceptor at ``index``, or ``thunk`` when none remain.

        Args:
          thunk: Callable invoked with the context after the last interceptor.
          index: Position in ``self.interceptors`` to run next.
          context: Value forwarded to the interceptor (or to ``thunk``).

        Returns:
          Whatever the interceptor's ``intercept_service`` (or ``thunk``)
          returns.
        """
        if index >= len(self.interceptors):
            # Chain exhausted: hand control to the original thunk.
            return thunk(context)
        interceptor = self.interceptors[index]
        next_step = self._continuation(thunk, index + 1)
        return interceptor.intercept_service(next_step, context)

    def execute(self, thunk: Callable,
                context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
        """Run the whole chain, starting from the first interceptor."""
        return self._intercept_at(thunk, 0, context)


def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]]
) -> Optional[_ServicePipeline]:
    """Build a pipeline for ``interceptors``.

    Returns:
      A ``_ServicePipeline`` wrapping ``interceptors``, or None when
      ``interceptors`` is falsy (None or empty).
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)


class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails', [
            'method',
            'timeout',
            'metadata',
            'credentials',
            'wait_for_ready',
            'compression',
        ]), grpc.ClientCallDetails):
    """Named-tuple record of the per-call settings given to interceptors.

    Fields, in positional order: method, timeout, metadata, credentials,
    wait_for_ready, compression.
    """


def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails
) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool,
           grpc.Compression]:
    """Extract the call settings from ``call_details`` with per-field fallback.

    Each field is read from ``call_details``; when that read raises
    AttributeError the same field is read from ``default_details`` instead.

    Args:
      call_details: Object that may define only some of the fields.
      default_details: Object consulted for any field ``call_details`` lacks.

    Returns:
      A tuple ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)``.
    """

    def resolve(field_name):
        # Prefer the supplied details; fall back only when the attribute
        # lookup itself fails.
        try:
            return getattr(call_details, field_name)
        except AttributeError:
            return getattr(default_details, field_name)

    return (
        resolve('method'),
        resolve('timeout'),
        resolve('metadata'),
        resolve('credentials'),
        resolve('wait_for_ready'),
        resolve('compression'),
    )


class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """Already-finished outcome carrying an exception and its traceback.

    The object reports itself as done, not running, not cancelled and not
    active. Asking for its result, or iterating it, raises the stored
    exception.
    """
    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # --- Call status: fixed values, no underlying call is consulted. ---

    def code(self) -> Optional[grpc.StatusCode]:
        """Return ``grpc.StatusCode.INTERNAL``."""
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        """Return a fixed description of the failure."""
        return 'Exception raised while intercepting the RPC'

    def initial_metadata(self) -> Optional[MetadataType]:
        """Return None; no metadata is stored on this outcome."""
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return None; no metadata is stored on this outcome."""
        return None

    def is_active(self) -> bool:
        """Return False."""
        return False

    def time_remaining(self) -> Optional[float]:
        """Return None."""
        return None

    def add_callback(self, unused_callback) -> bool:
        """Return False without registering or invoking the callback."""
        return False

    # --- Future state: always reported as complete. ---

    def cancel(self) -> bool:
        """Return False; nothing is cancelled."""
        return False

    def cancelled(self) -> bool:
        """Return False."""
        return False

    def running(self) -> bool:
        """Return False."""
        return False

    def done(self) -> bool:
        """Return True."""
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Raise the stored exception; the timeout is not used."""
        raise self._exception

    def exception(
            self,
            ignored_timeout: Optional[float] = None) -> Optional[Exception]:
        """Return the stored exception; the timeout is not used."""
        return self._exception

    def traceback(
        self,
        ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Return the stored traceback; the timeout is not used."""
        return self._traceback

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` with this outcome immediately."""
        fn(self)

    # --- Iterator protocol: every step raises the stored exception. ---

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        """Alias that forwards to ``__next__``."""
        return self.__next__()


class _UnaryOutcome(grpc.Call, grpc.Future):
    """Already-finished outcome pairing a response with its call object.

    Call-level queries are forwarded to the wrapped call; future-level
    queries answer with fixed values describing a successful, completed
    computation whose result is the stored response.
    """
    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    # --- Forwarded to the wrapped call. ---

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the wrapped call's status code."""
        return self._call.code()

    def details(self) -> Optional[str]:
        """Return the wrapped call's details."""
        return self._call.details()

    def initial_metadata(self) -> Optional[MetadataType]:
        """Return the wrapped call's initial metadata."""
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the wrapped call's trailing metadata."""
        return self._call.trailing_metadata()

    def is_active(self) -> bool:
        """Return the wrapped call's ``is_active()`` value."""
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        """Return the wrapped call's ``time_remaining()`` value."""
        return self._call.time_remaining()

    def add_callback(self, callback) -> bool:
        """Forward ``callback`` to the wrapped call's ``add_callback``."""
        return self._call.add_callback(callback)

    def cancel(self) -> bool:
        """Forward to the wrapped call's ``cancel``."""
        return self._call.cancel()

    # --- Fixed answers: the outcome is complete and holds a response. ---

    def done(self) -> bool:
        """Return True."""
        return True

    def running(self) -> bool:
        """Return False."""
        return False

    def cancelled(self) -> bool:
        """Return False."""
        return False

    def result(self, ignored_timeout: Optional[float] = None):
        """Return the stored response; the timeout is not used."""
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        """Return None; the timeout is not used."""
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        """Return None; the timeout is not used."""
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` with this outcome immediately."""
        fn(self)


class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable whose invocations go through an interceptor.

    ``thunk`` maps a method name to the multi-callable that actually performs
    the RPC. It is only invoked from the continuation handed to the
    interceptor, using whatever call details the interceptor passes along.
    """
    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.UnaryUnaryClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request: Any,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None) -> Any:
        """Run the intercepted RPC and return only its response."""
        outcome = self._with_call(request,
                                  timeout=timeout,
                                  metadata=metadata,
                                  credentials=credentials,
                                  wait_for_ready=wait_for_ready,
                                  compression=compression)
        # Element 0 is the response; the call object is not exposed here.
        return outcome[0]

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Run the intercepted RPC and return ``(response, call)``.

        The continuation given to the interceptor never raises from the
        underlying invocation: a ``grpc.RpcError`` is returned as-is and any
        other ``Exception`` is wrapped in a ``_FailureOutcome``. The failure
        then surfaces when ``result()`` is called on whatever the interceptor
        returned. Exceptions raised by the interceptor itself propagate.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            try:
                response, call = self._thunk(target_method).with_call(
                    request, **options)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as error:  # pylint:disable=broad-except
                _, _, trace = sys.exc_info()
                return _FailureOutcome(error, trace)

        outcome = self._interceptor.intercept_unary_unary(
            continuation, defaults, request)
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Run the intercepted RPC and return ``(response, call)``."""
        return self._with_call(request,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request: Any,
               timeout: Optional[float] = None,
               metadata: Optional[MetadataType] = None,
               credentials: Optional[grpc.CallCredentials] = None,
               wait_for_ready: Optional[bool] = None,
               compression: Optional[grpc.Compression] = None) -> Any:
        """Start the intercepted RPC and return what the interceptor returns.

        An ``Exception`` raised while the interceptor runs is not propagated;
        it is returned wrapped in a ``_FailureOutcome``.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method).future(request, **options)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, defaults, request)
        except Exception as error:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(error, trace)


class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable whose invocations go through an interceptor.

    ``thunk`` maps a method name to the multi-callable that actually performs
    the RPC; it is only invoked from the continuation handed to the
    interceptor.
    """
    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.UnaryStreamClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request: Any,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None):
        """Start the intercepted RPC and return what the interceptor returns.

        An ``Exception`` raised while the interceptor runs is not propagated;
        it is returned wrapped in a ``_FailureOutcome``, which raises the
        exception when iterated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            multi_callable = self._thunk(target_method)
            return multi_callable(request, **options)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, defaults, request)
        except Exception as error:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(error, trace)


class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable whose invocations go through an interceptor.

    ``thunk`` maps a method name to the multi-callable that actually performs
    the RPC. It is only invoked from the continuation handed to the
    interceptor, using whatever call details the interceptor passes along.
    """
    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.StreamUnaryClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator: RequestIterableType,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None) -> Any:
        """Run the intercepted RPC and return only its response."""
        outcome = self._with_call(request_iterator,
                                  timeout=timeout,
                                  metadata=metadata,
                                  credentials=credentials,
                                  wait_for_ready=wait_for_ready,
                                  compression=compression)
        # Element 0 is the response; the call object is not exposed here.
        return outcome[0]

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Run the intercepted RPC and return ``(response, call)``.

        The continuation given to the interceptor never raises from the
        underlying invocation: a ``grpc.RpcError`` is returned as-is and any
        other ``Exception`` is wrapped in a ``_FailureOutcome``. The failure
        then surfaces when ``result()`` is called on whatever the interceptor
        returned. Exceptions raised by the interceptor itself propagate.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            try:
                response, call = self._thunk(target_method).with_call(
                    request_iterator, **options)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as error:  # pylint:disable=broad-except
                _, _, trace = sys.exc_info()
                return _FailureOutcome(error, trace)

        outcome = self._interceptor.intercept_stream_unary(
            continuation, defaults, request_iterator)
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Run the intercepted RPC and return ``(response, call)``."""
        return self._with_call(request_iterator,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request_iterator: RequestIterableType,
               timeout: Optional[float] = None,
               metadata: Optional[MetadataType] = None,
               credentials: Optional[grpc.CallCredentials] = None,
               wait_for_ready: Optional[bool] = None,
               compression: Optional[grpc.Compression] = None) -> Any:
        """Start the intercepted RPC and return what the interceptor returns.

        An ``Exception`` raised while the interceptor runs is not propagated;
        it is returned wrapped in a ``_FailureOutcome``.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method).future(request_iterator,
                                                     **options)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, defaults, request_iterator)
        except Exception as error:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(error, trace)


class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable whose invocations go through an interceptor.

    ``thunk`` maps a method name to the multi-callable that actually performs
    the RPC; it is only invoked from the continuation handed to the
    interceptor.
    """
    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.StreamStreamClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator: RequestIterableType,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None):
        """Start the intercepted RPC and return what the interceptor returns.

        An ``Exception`` raised while the interceptor runs is not propagated;
        it is returned wrapped in a ``_FailureOutcome``, which raises the
        exception when iterated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, defaults)
            options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            multi_callable = self._thunk(target_method)
            return multi_callable(request_iterator, **options)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, defaults, request_iterator)
        except Exception as error:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(error, trace)


class _Channel(grpc.Channel):
    """Channel wrapper that applies one client interceptor to a channel.

    For each RPC arity, the multi-callable is wrapped only when the
    interceptor is an instance of the matching interceptor interface;
    otherwise the wrapped channel's own multi-callable is returned unchanged.
    Subscription and close operations are forwarded to the wrapped channel.
    """
    _channel: grpc.Channel
    _interceptor: Union[grpc.UnaryUnaryClientInterceptor,
                        grpc.UnaryStreamClientInterceptor,
                        grpc.StreamStreamClientInterceptor,
                        grpc.StreamUnaryClientInterceptor]

    def __init__(self, channel: grpc.Channel,
                 interceptor: Union[grpc.UnaryUnaryClientInterceptor,
                                    grpc.UnaryStreamClientInterceptor,
                                    grpc.StreamStreamClientInterceptor,
                                    grpc.StreamUnaryClientInterceptor]):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self,
                  callback: Callable,
                  try_to_connect: Optional[bool] = False):
        """Forward the subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        """Forward the unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.UnaryUnaryMultiCallable:
        """Return a unary-unary multi-callable, intercepted when applicable."""

        def make_callable(name):
            # The method name is supplied later so a continuation can pick it.
            return self._channel.unary_unary(name, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return make_callable(method)
        return _UnaryUnaryMultiCallable(make_callable, method,
                                        self._interceptor)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.UnaryStreamMultiCallable:
        """Return a unary-stream multi-callable, intercepted when applicable."""

        def make_callable(name):
            return self._channel.unary_stream(name, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryStreamClientInterceptor):
            return make_callable(method)
        return _UnaryStreamMultiCallable(make_callable, method,
                                         self._interceptor)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.StreamUnaryMultiCallable:
        """Return a stream-unary multi-callable, intercepted when applicable."""

        def make_callable(name):
            return self._channel.stream_unary(name, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamUnaryClientInterceptor):
            return make_callable(method)
        return _StreamUnaryMultiCallable(make_callable, method,
                                         self._interceptor)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.StreamStreamMultiCallable:
        """Return a stream-stream multi-callable, intercepted when applicable."""

        def make_callable(name):
            return self._channel.stream_stream(name, request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return make_callable(method)
        return _StreamStreamMultiCallable(make_callable, method,
                                          self._interceptor)

    def _close(self):
        """Close the wrapped channel."""
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close on exit; returning False lets any in-flight exception
        # propagate out of the ``with`` block.
        self._close()
        return False

    def close(self):
        """Close the wrapped channel."""
        self._channel.close()


def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[grpc.UnaryUnaryClientInterceptor,
                         grpc.UnaryStreamClientInterceptor,
                         grpc.StreamStreamClientInterceptor,
                         grpc.StreamUnaryClientInterceptor]
) -> grpc.Channel:
    """Wrap ``channel`` so that its calls pass through ``interceptors``.

    Interceptors are applied in reverse order, so the last one wraps
    ``channel`` directly and the first one ends up as the outermost wrapper.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors, each an instance of at
        least one of the four client interceptor interfaces.

    Returns:
      The outermost wrapping channel, or ``channel`` itself when no
      interceptors are given.

    Raises:
      TypeError: If any interceptor implements none of the four client
        interceptor interfaces.
    """
    supported_interfaces = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is a tuple, so it can be reversed without copying.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, supported_interfaces):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel